diff --git a/.circleci/config.yml b/.circleci/config.yml index d3723416..ff163f73 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -479,7 +479,7 @@ jobs: OMP_NUM_THREADS: "1" CXXFLAGS: "-O2 -fsanitize=thread -fno-omit-frame-pointer" BUILD_TYPE: "Debug" - GTEST_ARGS: "--gtest_filter=-LoggingTest.LoggingAlsoWorksAfterFork:AssertTest_DebugBuild.*:CliTest_Setup.*:CliTest_IntegrityCheck.*:*/CliTest_WrongEnvironment.*:CliTest_Unmount.*" + GTEST_ARGS: "--gtest_filter=-LoggingTest.LoggingAlsoWorksAfterFork:AssertTest_DebugBuild.*:SignalCatcherTest.*_thenDies:CliTest_Setup.*:CliTest_IntegrityCheck.*:*/CliTest_WrongEnvironment.*:CliTest_Unmount.*" CMAKE_FLAGS: "" RUN_TESTS: true clang_tidy: diff --git a/src/blobstore/implementations/onblocks/BlobOnBlocks.cpp b/src/blobstore/implementations/onblocks/BlobOnBlocks.cpp index 3c08ce9c..b9db5b1c 100644 --- a/src/blobstore/implementations/onblocks/BlobOnBlocks.cpp +++ b/src/blobstore/implementations/onblocks/BlobOnBlocks.cpp @@ -52,6 +52,10 @@ void BlobOnBlocks::flush() { _datatree->flush(); } +uint32_t BlobOnBlocks::numNodes() const { + return _datatree->numNodes(); +} + const BlockId &BlobOnBlocks::blockId() const { return _datatree->blockId(); } diff --git a/src/blobstore/implementations/onblocks/BlobOnBlocks.h b/src/blobstore/implementations/onblocks/BlobOnBlocks.h index fe24e884..e1f0bee3 100644 --- a/src/blobstore/implementations/onblocks/BlobOnBlocks.h +++ b/src/blobstore/implementations/onblocks/BlobOnBlocks.h @@ -35,6 +35,8 @@ public: void flush() override; + uint32_t numNodes() const override; + cpputils::unique_ref releaseTree(); private: diff --git a/src/blobstore/implementations/onblocks/datatreestore/DataTree.cpp b/src/blobstore/implementations/onblocks/datatreestore/DataTree.cpp index e2e39251..e04b3c89 100644 --- a/src/blobstore/implementations/onblocks/datatreestore/DataTree.cpp +++ b/src/blobstore/implementations/onblocks/datatreestore/DataTree.cpp @@ -13,6 +13,7 @@ #include #include "impl/LeafTraverser.h" #include +#include using blockstore::BlockId; using blobstore::onblocks::datanodestore::DataNodeStore; @@ -64,7 +65,16 @@ unique_ref DataTree::releaseRootNode() { return std::move(_rootNode); } -//TODO Test numLeaves(), for example also two configurations with same number of bytes but different number of leaves (last leaf has 0 bytes) +uint32_t DataTree::numNodes() const { + uint32_t numNodesCurrentLevel = numLeaves(); + uint32_t totalNumNodes = numNodesCurrentLevel; + for(size_t level = 0; level < _rootNode->depth(); ++level) { + numNodesCurrentLevel = blobstore::onblocks::utils::ceilDivision(numNodesCurrentLevel, static_cast(_nodeStore->layout().maxChildrenPerInnerNode())); + totalNumNodes += numNodesCurrentLevel; + } + return totalNumNodes; +} + uint32_t DataTree::numLeaves() const { shared_lock lock(_treeStructureMutex); diff --git a/src/blobstore/implementations/onblocks/datatreestore/DataTree.h b/src/blobstore/implementations/onblocks/datatreestore/DataTree.h index 6154ca87..c26f3b75 100644 --- a/src/blobstore/implementations/onblocks/datatreestore/DataTree.h +++ b/src/blobstore/implementations/onblocks/datatreestore/DataTree.h @@ -40,6 +40,7 @@ public: void resizeNumBytes(uint64_t newNumBytes); + uint32_t numNodes() const; uint32_t numLeaves() const; uint64_t numBytes() const; diff --git a/src/blobstore/implementations/onblocks/parallelaccessdatatreestore/DataTreeRef.h b/src/blobstore/implementations/onblocks/parallelaccessdatatreestore/DataTreeRef.h index b6e8b1ea..71c780ac 100644 --- a/src/blobstore/implementations/onblocks/parallelaccessdatatreestore/DataTreeRef.h +++ b/src/blobstore/implementations/onblocks/parallelaccessdatatreestore/DataTreeRef.h @@ -54,6 +54,10 @@ public: return _baseTree->flush(); } + uint32_t numNodes() const { + return _baseTree->numNodes(); + } + private: datatreestore::DataTree *_baseTree; diff --git a/src/blobstore/interface/Blob.h b/src/blobstore/interface/Blob.h index 51b84e3d..64a55b7f 100644 --- a/src/blobstore/interface/Blob.h +++ b/src/blobstore/interface/Blob.h @@ -26,6 +26,8 @@ public: virtual void flush() = 0; + virtual uint32_t numNodes() const = 0; + //TODO Test tryRead }; diff --git a/src/blockstore/implementations/integrity/IntegrityBlockStore2.cpp b/src/blockstore/implementations/integrity/IntegrityBlockStore2.cpp index 62ae8f43..06add856 100644 --- a/src/blockstore/implementations/integrity/IntegrityBlockStore2.cpp +++ b/src/blockstore/implementations/integrity/IntegrityBlockStore2.cpp @@ -2,11 +2,14 @@ #include "IntegrityBlockStore2.h" #include "KnownBlockVersions.h" #include +#include +#include using cpputils::Data; using cpputils::unique_ref; using cpputils::serialize; using cpputils::deserialize; +using cpputils::SignalCatcher; using std::string; using boost::optional; using boost::none; @@ -195,22 +198,32 @@ void IntegrityBlockStore2::forEachBlock(std::function ca #ifndef CRYFS_NO_COMPATIBILITY void IntegrityBlockStore2::migrateFromBlockstoreWithoutVersionNumbers(BlockStore2 *baseBlockStore, const boost::filesystem::path &integrityFilePath, uint32_t myClientId) { - std::cout << "Migrating file system for integrity features. Please don't interrupt this process. This can take a while..." << std::flush; + SignalCatcher signalCatcher; + KnownBlockVersions knownBlockVersions(integrityFilePath, myClientId); - baseBlockStore->forEachBlock([&baseBlockStore, &knownBlockVersions] (const BlockId &blockId) { + uint64_t numProcessedBlocks = 0; + cpputils::ProgressBar progressbar("Migrating file system for integrity features. This can take a while...", baseBlockStore->numBlocks()); + baseBlockStore->forEachBlock([&] (const BlockId &blockId) { + if (signalCatcher.signal_occurred()) { + throw std::runtime_error("Caught signal"); + } migrateBlockFromBlockstoreWithoutVersionNumbers(baseBlockStore, blockId, &knownBlockVersions); + progressbar.update(++numProcessedBlocks); }); - std::cout << "done" << std::endl; } void IntegrityBlockStore2::migrateBlockFromBlockstoreWithoutVersionNumbers(blockstore::BlockStore2* baseBlockStore, const blockstore::BlockId& blockId, KnownBlockVersions *knownBlockVersions) { - uint64_t version = knownBlockVersions->incrementVersion(blockId); - auto data_ = baseBlockStore->load(blockId); if (data_ == boost::none) { LOG(WARN, "Block not found, but was returned from forEachBlock before"); return; } + if (0 != _readFormatHeader(*data_)) { + // already migrated + return; + } + + uint64_t version = knownBlockVersions->incrementVersion(blockId); cpputils::Data data = std::move(*data_); cpputils::Data dataWithHeader = _prependHeaderToData(blockId, knownBlockVersions->myClientId(), version, data); baseBlockStore->store(blockId, dataWithHeader); diff --git a/src/cpp-utils/CMakeLists.txt b/src/cpp-utils/CMakeLists.txt index 2d5db214..500e6ec1 100644 --- a/src/cpp-utils/CMakeLists.txt +++ b/src/cpp-utils/CMakeLists.txt @@ -11,6 +11,7 @@ set(SOURCES crypto/hash/Hash.cpp process/daemonize.cpp process/subprocess.cpp + process/SignalCatcher.cpp tempfile/TempFile.cpp tempfile/TempDir.cpp network/HttpClient.cpp @@ -22,10 +23,12 @@ set(SOURCES io/IOStreamConsole.cpp io/NoninteractiveConsole.cpp io/pipestream.cpp + io/ProgressBar.cpp thread/LoopThread.cpp thread/ThreadSystem.cpp thread/debugging_nonwindows.cpp thread/debugging_windows.cpp + thread/LeftRight.cpp random/Random.cpp random/RandomGeneratorThread.cpp random/OSRandomGenerator.cpp diff --git a/src/cpp-utils/io/ProgressBar.cpp b/src/cpp-utils/io/ProgressBar.cpp new file mode 100644 index 00000000..97aa499a --- /dev/null +++ b/src/cpp-utils/io/ProgressBar.cpp @@ -0,0 +1,35 @@ +#include "ProgressBar.h" +#include +#include +#include +#include "IOStreamConsole.h" + +using std::string; + +namespace cpputils { + +ProgressBar::ProgressBar(const char* preamble, uint64_t max_value) +: ProgressBar(std::make_shared(), preamble, max_value) {} + +ProgressBar::ProgressBar(std::shared_ptr console, const char* preamble, uint64_t max_value) +: _console(std::move(console)) +, _preamble(string("\r") + preamble + " ") +, _max_value(max_value) +, _lastPercentage(std::numeric_limits::max()) { + ASSERT(_max_value > 0, "Progress bar can't handle max_value of 0"); + + _console->print("\n"); + + // show progress bar. _lastPercentage is different to zero, so it shows. + update(0); +} + +void ProgressBar::update(uint64_t value) { + const size_t percentage = (100 * value) / _max_value; + if (percentage != _lastPercentage) { + _console->print(_preamble + std::to_string(percentage) + "%"); + _lastPercentage = percentage; + } +} + +} diff --git a/src/cpp-utils/io/ProgressBar.h b/src/cpp-utils/io/ProgressBar.h new file mode 100644 index 00000000..1bbd7841 --- /dev/null +++ b/src/cpp-utils/io/ProgressBar.h @@ -0,0 +1,31 @@ +#pragma once +#ifndef MESSMER_CPPUTILS_IO_PROGRESSBAR_H +#define MESSMER_CPPUTILS_IO_PROGRESSBAR_H + +#include +#include +#include +#include "Console.h" + +namespace cpputils { + +class ProgressBar final { +public: + explicit ProgressBar(std::shared_ptr console, const char* preamble, uint64_t max_value); + explicit ProgressBar(const char* preamble, uint64_t max_value); + + void update(uint64_t value); + +private: + + std::shared_ptr _console; + std::string _preamble; + uint64_t _max_value; + size_t _lastPercentage; + + DISALLOW_COPY_AND_ASSIGN(ProgressBar); +}; + +} + +#endif diff --git a/src/cpp-utils/process/SignalCatcher.cpp b/src/cpp-utils/process/SignalCatcher.cpp new file mode 100644 index 00000000..d0113a2e --- /dev/null +++ b/src/cpp-utils/process/SignalCatcher.cpp @@ -0,0 +1,239 @@ +#include "SignalCatcher.h" + +#include +#include +#include +#include +#include + +using std::make_unique; +using std::vector; +using std::pair; + +namespace cpputils { + +namespace { + +void got_signal(int signal); + +using SignalHandlerFunction = void(int); + +constexpr SignalHandlerFunction* signal_catcher_function = &got_signal; + +#if !defined(_MSC_VER) + +class SignalHandlerRAII final { +public: + SignalHandlerRAII(int signal) + : _old_handler(), _signal(signal) { + struct sigaction new_signal_handler{}; + std::memset(&new_signal_handler, 0, sizeof(new_signal_handler)); + new_signal_handler.sa_handler = signal_catcher_function; // NOLINT(cppcoreguidelines-pro-type-union-access) + new_signal_handler.sa_flags = SA_RESTART; + int error = sigfillset(&new_signal_handler.sa_mask); // block all signals while signal handler is running + if (0 != error) { + throw std::runtime_error("Error calling sigfillset. Errno: " + std::to_string(errno)); + } + _sigaction(_signal, &new_signal_handler, &_old_handler); + } + + ~SignalHandlerRAII() { + // reset to old signal handler + struct sigaction removed_handler{}; + _sigaction(_signal, &_old_handler, &removed_handler); + if (signal_catcher_function != removed_handler.sa_handler) { // NOLINT(cppcoreguidelines-pro-type-union-access) + ASSERT(false, "Signal handler screwup. We just replaced a signal handler that wasn't our own."); + } + } + +private: + static void _sigaction(int signal, struct sigaction *new_handler, struct sigaction *old_handler) { + int error = sigaction(signal, new_handler, old_handler); + if (0 != error) { + throw std::runtime_error("Error calling sigaction. Errno: " + std::to_string(errno)); + } + } + + struct sigaction _old_handler; + int _signal; + + DISALLOW_COPY_AND_ASSIGN(SignalHandlerRAII); +}; + +#else + +class SignalHandlerRAII final { +public: + SignalHandlerRAII(int signal) + : _old_handler(nullptr), _signal(signal) { + _old_handler = ::signal(_signal, signal_catcher_function); + if (_old_handler == SIG_ERR) { + throw std::logic_error("Error calling signal(). Errno: " + std::to_string(errno)); + } + } + + ~SignalHandlerRAII() { + // reset to old signal handler + SignalHandlerFunction* error = ::signal(_signal, _old_handler); + if (error == SIG_ERR) { + throw std::logic_error("Error resetting signal(). Errno: " + std::to_string(errno)); + } + if (error != signal_catcher_function) { + throw std::logic_error("Signal handler screwup. We just replaced a signal handler that wasn't our own."); + } + } + +private: + + SignalHandlerFunction* _old_handler; + int _signal; + + DISALLOW_COPY_AND_ASSIGN(SignalHandlerRAII); +}; + +// The Linux default behavior (i.e. the way we set up sigaction above) is to disable signal processing while the signal +// handler is running and to re-enable the custom handler once processing is finished. The Windows default behavior +// is to reset the handler to the default handler directly before executing the handler, i.e. the handler will only +// be called once. To fix this, we use this RAII class on Windows, of which an instance will live in the signal handler. +// In its constructor, it disables signal handling, and in its destructor it re-sets the custom handler. +// This is not perfect since there is a small time window between calling the signal handler and calling the constructor +// of this class, but it's the best we can do. +class SignalHandlerRunningRAII final { +public: + SignalHandlerRunningRAII(int signal) : _signal(signal) { + SignalHandlerFunction* old_handler = ::signal(_signal, SIG_IGN); + if (old_handler == SIG_ERR) { + throw std::logic_error("Error disabling signal(). Errno: " + std::to_string(errno)); + } + if (old_handler != SIG_DFL) { + // see description above, we expected the signal handler to be reset. + throw std::logic_error("We expected windows to reset the signal handler but it didn't. Did the Windows API change?"); + } + } + + ~SignalHandlerRunningRAII() { + SignalHandlerFunction* old_handler = ::signal(_signal, signal_catcher_function); + if (old_handler == SIG_ERR) { + throw std::logic_error("Error resetting signal() after calling handler. Errno: " + std::to_string(errno)); + } + if (old_handler != SIG_IGN) { + throw std::logic_error("Weird, we just did set the signal handler to ignore. Why isn't it still ignore?"); + } + } + +private: + int _signal; +}; + +#endif + +class SignalCatcherRegistry final { +public: + void add(int signal, std::atomic* signal_occurred_flag) { + _catchers.write([&] (auto& catchers) { + catchers.emplace_back(signal, signal_occurred_flag); + }); + } + + void remove(std::atomic* catcher) { + _catchers.write([&] (auto& catchers) { + auto found = std::find_if(catchers.rbegin(), catchers.rend(), [catcher] (const auto& entry) {return entry.second == catcher;}); + ASSERT(found != catchers.rend(), "Signal handler not found"); + catchers.erase(--found.base()); // decrement because it's a reverse iterator + }); + } + + ~SignalCatcherRegistry() { + ASSERT(0 == _catchers.read([] (auto& catchers) {return catchers.size();}), "Leftover signal catchers that weren't destroyed"); + } + + std::atomic* find(int signal) { + // this is called in a signal handler and must be mutex-free. + return _catchers.read([&](auto& catchers) { + auto found = std::find_if(catchers.rbegin(), catchers.rend(), [signal](const auto& entry) {return entry.first == signal; }); + ASSERT(found != catchers.rend(), "Signal handler not found"); + return found->second; + }); + } + + static SignalCatcherRegistry& singleton() { + static SignalCatcherRegistry _singleton; + return _singleton; + } + +private: + SignalCatcherRegistry() = default; + + // using LeftRight datastructure because we need mutex-free reads. Signal handlers can't use mutexes. + LeftRight*>>> _catchers; + + DISALLOW_COPY_AND_ASSIGN(SignalCatcherRegistry); +}; + +void got_signal(int signal) { +#if defined(_MSC_VER) + // Only needed on Windows, Linux does this by default. See comment on SignalHandlerRunningRAII class. + SignalHandlerRunningRAII disable_signal_processing_while_handler_running_and_reset_handler_afterwards(signal); +#endif + std::atomic* catcher = SignalCatcherRegistry::singleton().find(signal); + *catcher = true; +} + +class SignalCatcherRegisterer final { +public: + SignalCatcherRegisterer(int signal, std::atomic* catcher) + : _catcher(catcher) { + SignalCatcherRegistry::singleton().add(signal, _catcher); + } + + ~SignalCatcherRegisterer() { + SignalCatcherRegistry::singleton().remove(_catcher); + } + +private: + std::atomic* _catcher; + + DISALLOW_COPY_AND_ASSIGN(SignalCatcherRegisterer); +}; + +} + +namespace details { + +class SignalCatcherImpl final { +public: + SignalCatcherImpl(int signal, std::atomic* signal_occurred_flag) + : _registerer(signal, signal_occurred_flag) + , _handler(signal) { + // note: the order of the members ensures that: + // - when registering the signal handler fails, the registerer will be destroyed, unregistering the signal_occurred_flag, + // i.e. there is no leak. + + // Allow only the set of signals that is supported on all platforms, see for Windows: https://docs.microsoft.com/en-us/cpp/c-runtime-library/reference/signal?view=vs-2017 + ASSERT(signal == SIGABRT || signal == SIGFPE || signal == SIGILL || signal == SIGINT || signal == SIGSEGV || signal == SIGTERM, "Unknown signal"); + } +private: + SignalCatcherRegisterer _registerer; + SignalHandlerRAII _handler; + + DISALLOW_COPY_AND_ASSIGN(SignalCatcherImpl); +}; + +} + +SignalCatcher::SignalCatcher(std::initializer_list signals) +: _signal_occurred(false) +, _impls() { + // note: the order of the members ensures that: + // - when the signal handler is set, the _signal_occurred flag is already initialized. + // - the _signal_occurred flag will not be destructed as long as the signal handler might be called (i.e. as long as _impls lives) + + _impls.reserve(signals.size()); + for (int signal : signals) { + _impls.emplace_back(make_unique(signal, &_signal_occurred)); + } +} + +SignalCatcher::~SignalCatcher() {} + +} diff --git a/src/cpp-utils/process/SignalCatcher.h b/src/cpp-utils/process/SignalCatcher.h new file mode 100644 index 00000000..87d3c948 --- /dev/null +++ b/src/cpp-utils/process/SignalCatcher.h @@ -0,0 +1,42 @@ +#pragma once +#ifndef MESSMER_CPPUTILS_PROCESS_SIGNALCATCHER_H_ +#define MESSMER_CPPUTILS_PROCESS_SIGNALCATCHER_H_ + +#include +#include +#include +#include +#include + +namespace cpputils { + +namespace details { +class SignalCatcherImpl; +} + +/* + * While an instance of this class is in scope, the specified signal (e.g. SIGINT) + * is caught and doesn't exit the application. You can poll if the signal occurred. + */ +class SignalCatcher final { +public: + SignalCatcher(): SignalCatcher({SIGINT, SIGTERM}) {} + + SignalCatcher(std::initializer_list signals); + ~SignalCatcher(); + + bool signal_occurred() const { + return _signal_occurred; + } + +private: + // note: _signal_occurred must be initialized before _impl because _impl might use it + std::atomic _signal_occurred; + std::vector> _impls; + + DISALLOW_COPY_AND_ASSIGN(SignalCatcher); +}; + +} + +#endif diff --git a/src/cpp-utils/thread/LeftRight.cpp b/src/cpp-utils/thread/LeftRight.cpp new file mode 100644 index 00000000..d0c6cbd2 --- /dev/null +++ b/src/cpp-utils/thread/LeftRight.cpp @@ -0,0 +1 @@ +#include "LeftRight.h" diff --git a/src/cpp-utils/thread/LeftRight.h b/src/cpp-utils/thread/LeftRight.h new file mode 100644 index 00000000..1870eee9 --- /dev/null +++ b/src/cpp-utils/thread/LeftRight.h @@ -0,0 +1,165 @@ +#include +#include +#include +#include +#include +#include + +namespace cpputils { + +namespace detail { + +struct IncrementRAII final { +public: + explicit IncrementRAII(std::atomic *counter): _counter(counter) { + ++(*_counter); + } + + ~IncrementRAII() { + --(*_counter); + } +private: + std::atomic *_counter; + + DISALLOW_COPY_AND_ASSIGN(IncrementRAII); +}; + +} + +// LeftRight wait-free readers synchronization primitive +// https://hal.archives-ouvertes.fr/hal-01207881/document +template +class LeftRight final { +public: + LeftRight() + : _writeMutex() + , _foregroundCounterIndex{0} + , _foregroundDataIndex{0} + , _counters{{{0}, {0}}} + , _data{{{}, {}}} + , _inDestruction(false) {} + + ~LeftRight() { + // from now on, no new readers/writers will be accepted (see asserts in read()/write()) + _inDestruction = true; + + // wait until any potentially running writers are finished + { + std::unique_lock lock(_writeMutex); + } + + // wait until any potentially running readers are finished + while (_counters[0].load() != 0 || _counters[1].load() != 0) { + std::this_thread::yield(); + } + } + + template + auto read(F&& readFunc) const { + if(_inDestruction.load()) { + throw std::logic_error("Issued LeftRight::read() after the destructor started running"); + } + + detail::IncrementRAII _increment_counter(&_counters[_foregroundCounterIndex.load()]); // NOLINT(cppcoreguidelines-pro-bounds-constant-array-index) + return readFunc(_data[_foregroundDataIndex.load()]); // NOLINT(cppcoreguidelines-pro-bounds-constant-array-index) + } + + // Throwing from write would result in invalid state + template + auto write(F&& writeFunc) { + if(_inDestruction.load()) { + throw std::logic_error("Issued LeftRight::read() after the destructor started running"); + } + + std::unique_lock lock(_writeMutex); + return _write(writeFunc); + } + +private: + template + auto _write(const F& writeFunc) { + /* + * Assume, A is in background and B in foreground. In simplified terms, we want to do the following: + * 1. Write to A (old background) + * 2. Switch A/B + * 3. Write to B (new background) + * + * More detailed algorithm (explanations on why this is important are below in code): + * 1. Write to A + * 2. Switch A/B data pointers + * 3. Wait until A counter is zero + * 4. Switch A/B counters + * 5. Wait until B counter is zero + * 6. Write to B + */ + + auto localDataIndex = _foregroundDataIndex.load(); + + // 1. Write to A + _callWriteFuncOnBackgroundInstance(writeFunc, localDataIndex); + + // 2. Switch A/B data pointers + localDataIndex = localDataIndex ^ 1; + _foregroundDataIndex = localDataIndex; + + /* + * 3. Wait until A counter is zero + * + * In the previous write run, A was foreground and B was background. + * There was a time after switching _foregroundDataIndex (B to foreground) and before switching _foregroundCounterIndex, + * in which new readers could have read B but incremented A's counter. + * + * In this current run, we just switched _foregroundDataIndex (A back to foreground), but before writing to + * the new background B, we have to make sure A's counter was zero briefly, so all these old readers are gone. + */ + auto localCounterIndex = _foregroundCounterIndex.load(); + _waitForBackgroundCounterToBeZero(localCounterIndex); + + /* + *4. Switch A/B counters + * + * Now that we know all readers on B are really gone, we can switch the counters and have new readers + * increment A's counter again, which is the correct counter since they're reading A. + */ + localCounterIndex = localCounterIndex ^ 1; + _foregroundCounterIndex = localCounterIndex; + + /* + * 5. Wait until B counter is zero + * + * This waits for all the readers on B that came in while both data and counter for B was in foreground, + * i.e. normal readers that happened outside of that brief gap between switching data and counter. + */ + _waitForBackgroundCounterToBeZero(localCounterIndex); + + // 6. Write to B + _callWriteFuncOnBackgroundInstance(writeFunc, localDataIndex); + } + + template + auto _callWriteFuncOnBackgroundInstance(const F& writeFunc, uint8_t localDataIndex) { + try { + return writeFunc(_data[localDataIndex ^ 1]); // NOLINT(cppcoreguidelines-pro-bounds-constant-array-index) + } catch (...) { + // recover invariant by copying from the foreground instance + _data[localDataIndex ^ 1] = _data[localDataIndex]; // NOLINT(cppcoreguidelines-pro-bounds-constant-array-index) + // rethrow + throw; + } + } + + void _waitForBackgroundCounterToBeZero(uint8_t counterIndex) { + while (_counters[counterIndex ^ 1].load() != 0) { // NOLINT(cppcoreguidelines-pro-bounds-constant-array-index) + std::this_thread::yield(); + } + } + + std::mutex _writeMutex; + std::atomic _foregroundCounterIndex; + std::atomic _foregroundDataIndex; + mutable std::array, 2> _counters; + std::array _data; + std::atomic _inDestruction; +}; + +} diff --git a/src/cryfs/config/CryConfig.cpp b/src/cryfs/config/CryConfig.cpp index 3558e693..01b49843 100644 --- a/src/cryfs/config/CryConfig.cpp +++ b/src/cryfs/config/CryConfig.cpp @@ -32,6 +32,7 @@ CryConfig::CryConfig() , _exclusiveClientId(none) #ifndef CRYFS_NO_COMPATIBILITY , _hasVersionNumbers(true) +, _hasParentPointers(true) #endif { } @@ -53,6 +54,7 @@ CryConfig CryConfig::load(const Data &data) { cfg._exclusiveClientId = pt.get_optional("cryfs.exclusiveClientId"); #ifndef CRYFS_NO_COMPATIBILITY cfg._hasVersionNumbers = pt.get("cryfs.migrations.hasVersionNumbers", false); + cfg._hasParentPointers = pt.get("cryfs.migrations.hasParentPointers", false); #endif optional filesystemIdOpt = pt.get_optional("cryfs.filesystemId"); @@ -81,6 +83,7 @@ Data CryConfig::save() const { } #ifndef CRYFS_NO_COMPATIBILITY pt.put("cryfs.migrations.hasVersionNumbers", _hasVersionNumbers); + pt.put("cryfs.migrations.hasParentPointers", _hasParentPointers); #endif stringstream stream; @@ -172,6 +175,14 @@ bool CryConfig::HasVersionNumbers() const { void CryConfig::SetHasVersionNumbers(bool value) { _hasVersionNumbers = value; } + +bool CryConfig::HasParentPointers() const { + return _hasParentPointers; +} + +void CryConfig::SetHasParentPointers(bool value) { + _hasParentPointers = value; +} #endif } diff --git a/src/cryfs/config/CryConfig.h b/src/cryfs/config/CryConfig.h index 26fdcaeb..618d67e3 100644 --- a/src/cryfs/config/CryConfig.h +++ b/src/cryfs/config/CryConfig.h @@ -56,6 +56,11 @@ public: // Version numbers cannot be disabled, but the file system will be migrated to version numbers automatically. bool HasVersionNumbers() const; void SetHasVersionNumbers(bool value); + + // This is a trigger to recognize old file systems that didn't have version numbers. + // Version numbers cannot be disabled, but the file system will be migrated to version numbers automatically. + bool HasParentPointers() const; + void SetHasParentPointers(bool value); #endif static CryConfig load(const cpputils::Data &data); @@ -73,6 +78,7 @@ private: boost::optional _exclusiveClientId; #ifndef CRYFS_NO_COMPATIBILITY bool _hasVersionNumbers; + bool _hasParentPointers; #endif CryConfig &operator=(const CryConfig &rhs) = delete; diff --git a/src/cryfs/filesystem/CryDevice.cpp b/src/cryfs/filesystem/CryDevice.cpp index c6d9ee7e..c09c4a2c 100644 --- a/src/cryfs/filesystem/CryDevice.cpp +++ b/src/cryfs/filesystem/CryDevice.cpp @@ -80,7 +80,14 @@ unique_ref CryDevice::MigrateOrCreateFsBlobStore(uniqu if ("" == rootBlobId) { return make_unique_ref(std::move(blobStore)); } - return FsBlobStore::migrateIfNeeded(std::move(blobStore), BlockId::FromString(rootBlobId)); + if (!configFile->config()->HasParentPointers()) { + auto result = FsBlobStore::migrate(std::move(blobStore), BlockId::FromString(rootBlobId)); + // Don't migrate again if it was successful + configFile->config()->SetHasParentPointers(true); + configFile->save(); + return result; + } + return make_unique_ref(std::move(blobStore)); } #endif @@ -106,6 +113,7 @@ unique_ref CryDevice::CreateIntegrityEncryptedBlockStore(unique_ref if (!configFile->config()->HasVersionNumbers()) { IntegrityBlockStore2::migrateFromBlockstoreWithoutVersionNumbers(encryptedBlockStore.get(), integrityFilePath, myClientId); configFile->config()->SetBlocksizeBytes(configFile->config()->BlocksizeBytes() + IntegrityBlockStore2::HEADER_LENGTH - blockstore::BlockId::BINARY_LENGTH); // Minus BlockId size because EncryptedBlockStore doesn't store the BlockId anymore (that was moved to IntegrityBlockStore) + // Don't migrate again if it was successful configFile->config()->SetHasVersionNumbers(true); configFile->save(); } diff --git a/src/cryfs/filesystem/fsblobstore/FsBlobStore.cpp b/src/cryfs/filesystem/fsblobstore/FsBlobStore.cpp index 6d116358..a7136f6b 100644 --- a/src/cryfs/filesystem/fsblobstore/FsBlobStore.cpp +++ b/src/cryfs/filesystem/fsblobstore/FsBlobStore.cpp @@ -2,9 +2,13 @@ #include "FileBlob.h" #include "DirBlob.h" #include "SymlinkBlob.h" +#include +#include +#include using cpputils::unique_ref; using cpputils::make_unique_ref; +using cpputils::SignalCatcher; using blobstore::BlobStore; using blockstore::BlockId; using boost::none; @@ -31,33 +35,41 @@ boost::optional> FsBlobStore::load(const blockstore::BlockId } #ifndef CRYFS_NO_COMPATIBILITY - unique_ref FsBlobStore::migrateIfNeeded(unique_ref blobStore, const blockstore::BlockId &rootBlobId) { + unique_ref FsBlobStore::migrate(unique_ref blobStore, const blockstore::BlockId &rootBlobId) { + SignalCatcher signalCatcher; + auto rootBlob = blobStore->load(rootBlobId); ASSERT(rootBlob != none, "Could not load root blob"); - uint16_t format = FsBlobView::getFormatVersionHeader(**rootBlob); auto fsBlobStore = make_unique_ref(std::move(blobStore)); - if (format == 0) { - // migration needed - std::cout << "Migrating file system for conflict resolution features. Please don't interrupt this process. This can take a while..." << std::flush; - fsBlobStore->_migrate(std::move(*rootBlob), blockstore::BlockId::Null()); - std::cout << "done" << std::endl; - } + + uint64_t migratedBlocks = 0; + cpputils::ProgressBar progressbar("Migrating file system for conflict resolution features. This can take a while...", fsBlobStore->numBlocks()); + fsBlobStore->_migrate(std::move(*rootBlob), blockstore::BlockId::Null(), &signalCatcher, [&] (uint32_t numNodes) { + migratedBlocks += numNodes; + progressbar.update(migratedBlocks); + }); + return fsBlobStore; } - void FsBlobStore::_migrate(unique_ref node, const blockstore::BlockId &parentId) { + void FsBlobStore::_migrate(unique_ref node, const blockstore::BlockId &parentId, SignalCatcher* signalCatcher, std::function perBlobCallback) { FsBlobView::migrate(node.get(), parentId); + perBlobCallback(node->numNodes()); if (FsBlobView::blobType(*node) == FsBlobView::BlobType::DIR) { DirBlob dir(std::move(node), _getLstatSize()); vector children; dir.AppendChildrenTo(&children); for (const auto &child : children) { + if (signalCatcher->signal_occurred()) { + // on a SIGINT or SIGTERM, cancel migration but gracefully shutdown, i.e. call destructors. + throw std::runtime_error("Caught signal"); + } auto childEntry = dir.GetChild(child.name); - ASSERT(childEntry != none, "Couldn't load child, although it was returned as a child in the lsit."); + ASSERT(childEntry != none, "Couldn't load child, although it was returned as a child in the list."); auto childBlob = _baseBlobStore->load(childEntry->blockId()); ASSERT(childBlob != none, "Couldn't load child blob"); - _migrate(std::move(*childBlob), dir.blockId()); + _migrate(std::move(*childBlob), dir.blockId(), signalCatcher, perBlobCallback); } } } diff --git a/src/cryfs/filesystem/fsblobstore/FsBlobStore.h b/src/cryfs/filesystem/fsblobstore/FsBlobStore.h index 48c6909a..d55aa0dd 100644 --- a/src/cryfs/filesystem/fsblobstore/FsBlobStore.h +++ b/src/cryfs/filesystem/fsblobstore/FsBlobStore.h @@ -8,6 +8,9 @@ #include "FileBlob.h" #include "DirBlob.h" #include "SymlinkBlob.h" +#ifndef CRYFS_NO_COMPATIBILITY +#include +#endif namespace cryfs { namespace fsblobstore { @@ -29,13 +32,13 @@ namespace cryfs { uint64_t virtualBlocksizeBytes() const; #ifndef CRYFS_NO_COMPATIBILITY - static cpputils::unique_ref migrateIfNeeded(cpputils::unique_ref blobStore, const blockstore::BlockId &blockId); + static cpputils::unique_ref migrate(cpputils::unique_ref blobStore, const blockstore::BlockId &blockId); #endif private: #ifndef CRYFS_NO_COMPATIBILITY - void _migrate(cpputils::unique_ref node, const blockstore::BlockId &parentId); + void _migrate(cpputils::unique_ref node, const blockstore::BlockId &parentId, cpputils::SignalCatcher* signalCatcher, std::function perBlobCallback); #endif std::function _getLstatSize(); diff --git a/src/cryfs/filesystem/fsblobstore/FsBlobView.cpp b/src/cryfs/filesystem/fsblobstore/FsBlobView.cpp index 32d59240..c1f7e57d 100644 --- a/src/cryfs/filesystem/fsblobstore/FsBlobView.cpp +++ b/src/cryfs/filesystem/fsblobstore/FsBlobView.cpp @@ -10,7 +10,11 @@ namespace cryfs { void FsBlobView::migrate(blobstore::Blob *blob, const blockstore::BlockId &parentId) { constexpr unsigned int OLD_HEADER_SIZE = sizeof(FORMAT_VERSION_HEADER) + sizeof(uint8_t); - ASSERT(FsBlobView::getFormatVersionHeader(*blob) == 0, "Block already migrated"); + if(FsBlobView::getFormatVersionHeader(*blob) != 0) { + // blob already migrated + return; + } + // Resize blob and move data back cpputils::Data data = blob->readAll(); blob->resize(blob->size() + blockstore::BlockId::BINARY_LENGTH); diff --git a/src/cryfs/filesystem/fsblobstore/FsBlobView.h b/src/cryfs/filesystem/fsblobstore/FsBlobView.h index f13bc46d..cb6ca8ad 100644 --- a/src/cryfs/filesystem/fsblobstore/FsBlobView.h +++ b/src/cryfs/filesystem/fsblobstore/FsBlobView.h @@ -85,6 +85,10 @@ namespace cryfs { return _baseBlob->flush(); } + uint32_t numNodes() const override { + return _baseBlob->numNodes(); + } + cpputils::unique_ref releaseBaseBlob() { return std::move(_baseBlob); } diff --git a/src/stats/main.cpp b/src/stats/main.cpp index 6012d97f..f5da47ee 100644 --- a/src/stats/main.cpp +++ b/src/stats/main.cpp @@ -163,6 +163,18 @@ int main(int argc, char* argv[]) { std::cerr << "Error loading config file" << std::endl; exit(1); } + const auto& config_ = config->configFile.config(); + std::cout << "Loading filesystem of version " << config_->Version() << std::endl; +#ifndef CRYFS_NO_COMPATIBILITY + const bool is_correct_format = config_->Version() == CryConfig::FilesystemFormatVersion && !config_->HasParentPointers() && !config_->HasVersionNumbers(); +#else + const bool is_correct_format = config_->Version() == CryConfig::FilesystemFormatVersion; +#endif + if (!is_correct_format) { + // TODO At this point, the cryfs.config file was already switched to 0.10 format. We should probably not do that. + std::cerr << "The filesystem is not in the 0.10 format. It needs to be migrated. The cryfs-stats tool unfortunately can't handle this, please mount and unmount the filesystem once." << std::endl; + exit(1); + } cout << "Listing all blocks..." << flush; set unaccountedBlocks = _getAllBlockIds(basedir, *config, localStateDir); diff --git a/test/blobstore/implementations/onblocks/datatreestore/DataTreeTest_NumStoredBytes.cpp b/test/blobstore/implementations/onblocks/datatreestore/DataTreeTest_NumStoredBytes.cpp index dc3bedbd..12d53621 100644 --- a/test/blobstore/implementations/onblocks/datatreestore/DataTreeTest_NumStoredBytes.cpp +++ b/test/blobstore/implementations/onblocks/datatreestore/DataTreeTest_NumStoredBytes.cpp @@ -21,50 +21,68 @@ INSTANTIATE_TEST_CASE_P(EmptyLastLeaf, DataTreeTest_NumStoredBytes_P, Values(0u) INSTANTIATE_TEST_CASE_P(HalfFullLastLeaf, DataTreeTest_NumStoredBytes_P, Values(5u, 10u)); INSTANTIATE_TEST_CASE_P(FullLastLeaf, DataTreeTest_NumStoredBytes_P, Values(static_cast(DataNodeLayout(DataTreeTest_NumStoredBytes::BLOCKSIZE_BYTES).maxBytesPerLeaf()))); +//TODO Test numLeaves() and numNodes() also two configurations with same number of bytes but different number of leaves (last leaf has 0 bytes) + TEST_P(DataTreeTest_NumStoredBytes_P, SingleLeaf) { BlockId blockId = CreateLeafWithSize(GetParam())->blockId(); auto tree = treeStore.load(blockId).value(); EXPECT_EQ(GetParam(), tree->numBytes()); + EXPECT_EQ(1, tree->numLeaves()); + EXPECT_EQ(1, tree->numNodes()); } TEST_P(DataTreeTest_NumStoredBytes_P, TwoLeafTree) { BlockId blockId = CreateTwoLeafWithSecondLeafSize(GetParam())->blockId(); auto tree = treeStore.load(blockId).value(); EXPECT_EQ(nodeStore->layout().maxBytesPerLeaf() + GetParam(), tree->numBytes()); + EXPECT_EQ(2, tree->numLeaves()); + EXPECT_EQ(3, tree->numNodes()); } TEST_P(DataTreeTest_NumStoredBytes_P, FullTwolevelTree) { BlockId blockId = CreateFullTwoLevelWithLastLeafSize(GetParam())->blockId(); auto tree = treeStore.load(blockId).value(); EXPECT_EQ(nodeStore->layout().maxBytesPerLeaf()*(nodeStore->layout().maxChildrenPerInnerNode()-1) + GetParam(), tree->numBytes()); + EXPECT_EQ(nodeStore->layout().maxChildrenPerInnerNode(), tree->numLeaves()); + EXPECT_EQ(1 + nodeStore->layout().maxChildrenPerInnerNode(), tree->numNodes()); } TEST_P(DataTreeTest_NumStoredBytes_P, ThreeLevelTreeWithOneChild) { BlockId blockId = CreateThreeLevelWithOneChildAndLastLeafSize(GetParam())->blockId(); auto tree = treeStore.load(blockId).value(); EXPECT_EQ(nodeStore->layout().maxBytesPerLeaf() + GetParam(), tree->numBytes()); + EXPECT_EQ(2, tree->numLeaves()); + EXPECT_EQ(4, tree->numNodes()); } TEST_P(DataTreeTest_NumStoredBytes_P, ThreeLevelTreeWithTwoChildren) { BlockId blockId = CreateThreeLevelWithTwoChildrenAndLastLeafSize(GetParam())->blockId(); auto tree = treeStore.load(blockId).value(); EXPECT_EQ(nodeStore->layout().maxBytesPerLeaf()*nodeStore->layout().maxChildrenPerInnerNode() + nodeStore->layout().maxBytesPerLeaf() + GetParam(), tree->numBytes()); + EXPECT_EQ(2 + nodeStore->layout().maxChildrenPerInnerNode(), tree->numLeaves()); + EXPECT_EQ(5 + nodeStore->layout().maxChildrenPerInnerNode(), tree->numNodes()); } TEST_P(DataTreeTest_NumStoredBytes_P, ThreeLevelTreeWithThreeChildren) { BlockId blockId = CreateThreeLevelWithThreeChildrenAndLastLeafSize(GetParam())->blockId(); auto tree = treeStore.load(blockId).value(); EXPECT_EQ(2*nodeStore->layout().maxBytesPerLeaf()*nodeStore->layout().maxChildrenPerInnerNode() + nodeStore->layout().maxBytesPerLeaf() + GetParam(), tree->numBytes()); + EXPECT_EQ(2 + 2*nodeStore->layout().maxChildrenPerInnerNode(), tree->numLeaves()); + EXPECT_EQ(6 + 2*nodeStore->layout().maxChildrenPerInnerNode(), tree->numNodes()); } TEST_P(DataTreeTest_NumStoredBytes_P, FullThreeLevelTree) { BlockId blockId = CreateFullThreeLevelWithLastLeafSize(GetParam())->blockId(); auto tree = treeStore.load(blockId).value(); EXPECT_EQ(nodeStore->layout().maxBytesPerLeaf()*nodeStore->layout().maxChildrenPerInnerNode()*(nodeStore->layout().maxChildrenPerInnerNode()-1) + nodeStore->layout().maxBytesPerLeaf()*(nodeStore->layout().maxChildrenPerInnerNode()-1) + GetParam(), tree->numBytes()); + EXPECT_EQ(nodeStore->layout().maxChildrenPerInnerNode()*nodeStore->layout().maxChildrenPerInnerNode(), tree->numLeaves()); + EXPECT_EQ(1 + nodeStore->layout().maxChildrenPerInnerNode() + nodeStore->layout().maxChildrenPerInnerNode()*nodeStore->layout().maxChildrenPerInnerNode(), tree->numNodes()); } TEST_P(DataTreeTest_NumStoredBytes_P, FourLevelMinDataTree) { BlockId blockId = CreateFourLevelMinDataWithLastLeafSize(GetParam())->blockId(); auto tree = treeStore.load(blockId).value(); EXPECT_EQ(nodeStore->layout().maxBytesPerLeaf()*nodeStore->layout().maxChildrenPerInnerNode()*nodeStore->layout().maxChildrenPerInnerNode() + GetParam(), tree->numBytes()); + EXPECT_EQ(1 + nodeStore->layout().maxChildrenPerInnerNode()*nodeStore->layout().maxChildrenPerInnerNode(), tree->numLeaves()); + EXPECT_EQ(5 + nodeStore->layout().maxChildrenPerInnerNode() + nodeStore->layout().maxChildrenPerInnerNode()*nodeStore->layout().maxChildrenPerInnerNode(), tree->numNodes()); } diff --git a/test/cpp-utils/CMakeLists.txt b/test/cpp-utils/CMakeLists.txt index d8efa2e6..831683c9 100644 --- a/test/cpp-utils/CMakeLists.txt +++ b/test/cpp-utils/CMakeLists.txt @@ -16,6 +16,7 @@ set(SOURCES process/daemonize_include_test.cpp process/subprocess_include_test.cpp process/SubprocessTest.cpp + process/SignalCatcherTest.cpp tempfile/TempFileTest.cpp tempfile/TempFileIncludeTest.cpp tempfile/TempDirIncludeTest.cpp @@ -28,6 +29,7 @@ set(SOURCES io/ConsoleTest_Print.cpp io/ConsoleTest_Ask.cpp io/ConsoleTest_AskPassword.cpp + io/ProgressBarTest.cpp random/RandomIncludeTest.cpp lock/LockPoolIncludeTest.cpp lock/ConditionBarrierIncludeTest.cpp @@ -55,6 +57,7 @@ set(SOURCES system/HomedirTest.cpp system/EnvTest.cpp thread/debugging_test.cpp + thread/LeftRightTest.cpp value_type/ValueTypeTest.cpp ) diff --git a/test/cpp-utils/io/ProgressBarTest.cpp b/test/cpp-utils/io/ProgressBarTest.cpp new file mode 100644 index 00000000..97ca2211 --- /dev/null +++ b/test/cpp-utils/io/ProgressBarTest.cpp @@ -0,0 +1,66 @@ +#include +#include + +using cpputils::ProgressBar; +using std::make_shared; + +class MockConsole final: public cpputils::Console { +public: + void EXPECT_OUTPUT(const char* expected) { + EXPECT_EQ(expected, _output); + _output = ""; + } + + void print(const std::string& text) override { + _output += text; + } + + unsigned int ask(const std::string&, const std::vector&) override { + EXPECT_TRUE(false); + return 0; + } + + bool askYesNo(const std::string&, bool) override { + EXPECT_TRUE(false); + return false; + } + + std::string askPassword(const std::string&) override { + EXPECT_TRUE(false); + return ""; + } + +private: + std::string _output; +}; + +TEST(ProgressBarTest, testProgressBar) { + auto console = make_shared(); + + ProgressBar bar(console, "Preamble", 2000); + console->EXPECT_OUTPUT("\n\rPreamble 0%"); + + // when updating to 0, doesn't reprint + bar.update(0); + console->EXPECT_OUTPUT(""); + + // update to half + bar.update(1000); + console->EXPECT_OUTPUT("\rPreamble 50%"); + + // when updating to same value, doesn't reprint + bar.update(1000); + console->EXPECT_OUTPUT(""); + + // when updating to value with same percentage, doesn't reprint + bar.update(1001); + console->EXPECT_OUTPUT(""); + + // update to 0 + bar.update(0); + console->EXPECT_OUTPUT("\rPreamble 0%"); + + // update to full + bar.update(2000); + console->EXPECT_OUTPUT("\rPreamble 100%"); +} diff --git a/test/cpp-utils/process/SignalCatcherTest.cpp b/test/cpp-utils/process/SignalCatcherTest.cpp new file mode 100644 index 00000000..f0d49fce --- /dev/null +++ b/test/cpp-utils/process/SignalCatcherTest.cpp @@ -0,0 +1,193 @@ +#include +#include +#include + +using cpputils::SignalCatcher; + +namespace { +void raise_signal(int signal) { + int error = ::raise(signal); + if (error != 0) { + throw std::runtime_error("Error raising signal"); + } +} +} + +TEST(SignalCatcherTest, givenNoSignalCatcher_whenRaisingSigint_thenDies) { + EXPECT_DEATH( + raise_signal(SIGINT), + "" + ); +} + +TEST(SignalCatcherTest, givenNoSignalCatcher_whenRaisingSigterm_thenDies) { + EXPECT_DEATH( + raise_signal(SIGTERM), + "" + ); +} + +TEST(SignalCatcherTest, givenSigIntCatcher_whenRaisingSigInt_thenCatches) { + SignalCatcher catcher({SIGINT}); + + EXPECT_FALSE(catcher.signal_occurred()); + raise_signal(SIGINT); + EXPECT_TRUE(catcher.signal_occurred()); + + // raise again + raise_signal(SIGINT); + EXPECT_TRUE(catcher.signal_occurred()); +} + +TEST(SignalCatcherTest, givenSigTermCatcher_whenRaisingSigTerm_thenCatches) { + SignalCatcher catcher({SIGTERM}); + + EXPECT_FALSE(catcher.signal_occurred()); + raise_signal(SIGTERM); + EXPECT_TRUE(catcher.signal_occurred()); + + // raise again + raise_signal(SIGTERM); + EXPECT_TRUE(catcher.signal_occurred()); +} + +TEST(SignalCatcherTest, givenSigIntAndSigTermCatcher_whenRaisingSigInt_thenCatches) { + SignalCatcher catcher({SIGINT, SIGTERM}); + + EXPECT_FALSE(catcher.signal_occurred()); + raise_signal(SIGINT); + EXPECT_TRUE(catcher.signal_occurred()); + + // raise again + raise_signal(SIGINT); + EXPECT_TRUE(catcher.signal_occurred()); +} + +TEST(SignalCatcherTest, givenSigIntAndSigTermCatcher_whenRaisingSigTerm_thenCatches) { + SignalCatcher catcher({SIGINT, SIGTERM}); + + EXPECT_FALSE(catcher.signal_occurred()); + raise_signal(SIGTERM); + EXPECT_TRUE(catcher.signal_occurred()); + + // raise again + raise_signal(SIGTERM); + EXPECT_TRUE(catcher.signal_occurred()); +} + +TEST(SignalCatcherTest, givenSigIntAndSigTermCatcher_whenRaisingSigIntAndSigTerm_thenCatches) { + SignalCatcher catcher({SIGINT, SIGTERM}); + + EXPECT_FALSE(catcher.signal_occurred()); + raise_signal(SIGTERM); + EXPECT_TRUE(catcher.signal_occurred()); + + raise_signal(SIGINT); + EXPECT_TRUE(catcher.signal_occurred()); +} + +TEST(SignalCatcherTest, givenSigIntCatcherAndSigTermCatcher_whenRaisingSignalsInOrder_thenCorrectCatcherCatches) { + SignalCatcher sigintCatcher({SIGINT}); + SignalCatcher sigtermCatcher({SIGTERM}); + + EXPECT_FALSE(sigintCatcher.signal_occurred()); + raise_signal(SIGINT); + EXPECT_TRUE(sigintCatcher.signal_occurred()); + + EXPECT_FALSE(sigtermCatcher.signal_occurred()); + raise_signal(SIGTERM); + EXPECT_TRUE(sigtermCatcher.signal_occurred()); +} + +TEST(SignalCatcherTest, givenSigIntCatcherAndSigTermCatcher_whenRaisingSignalsInReverseOrder_thenCorrectCatcherCatches) { + SignalCatcher sigintCatcher({SIGINT}); + SignalCatcher sigtermCatcher({SIGTERM}); + + EXPECT_FALSE(sigtermCatcher.signal_occurred()); + raise_signal(SIGTERM); + EXPECT_TRUE(sigtermCatcher.signal_occurred()); + + EXPECT_FALSE(sigintCatcher.signal_occurred()); + raise_signal(SIGINT); + EXPECT_TRUE(sigintCatcher.signal_occurred()); +} + +TEST(SignalCatcherTest, givenNestedSigIntCatchers_whenRaisingSignals_thenCorrectCatcherCatches) { + SignalCatcher outerCatcher({SIGINT}); + { + SignalCatcher middleCatcher({SIGINT}); + + EXPECT_FALSE(middleCatcher.signal_occurred()); + raise_signal(SIGINT); + EXPECT_TRUE(middleCatcher.signal_occurred()); + + { + SignalCatcher innerCatcher({SIGINT}); + + EXPECT_FALSE(innerCatcher.signal_occurred()); + raise_signal(SIGINT); + EXPECT_TRUE(innerCatcher.signal_occurred()); + } + } + + EXPECT_FALSE(outerCatcher.signal_occurred()); + raise_signal(SIGINT); + EXPECT_TRUE(outerCatcher.signal_occurred()); +} + +TEST(SignalCatcherTest, givenExpiredSigIntCatcher_whenRaisingSigInt_thenDies) { + { + SignalCatcher catcher({SIGINT}); + } + + EXPECT_DEATH( + raise_signal(SIGINT), + "" + ); +} + +TEST(SignalCatcherTest, givenExpiredSigTermCatcher_whenRaisingSigTerm_thenDies) { + { + SignalCatcher catcher({SIGTERM}); + } + + EXPECT_DEATH( + raise_signal(SIGTERM), + "" + ); +} + +TEST(SignalCatcherTest, givenExpiredSigIntCatcherAndSigTermCatcher_whenRaisingSigTerm_thenDies) { + { + SignalCatcher sigIntCatcher({SIGTERM}); + SignalCatcher sigTermCatcer({SIGTERM}); + } + + EXPECT_DEATH( + raise_signal(SIGTERM), + "" + ); +} + +TEST(SignalCatcherTest, givenSigTermCatcherAndExpiredSigIntCatcher_whenRaisingSigTerm_thenCatches) { + SignalCatcher sigTermCatcher({SIGTERM}); + { + SignalCatcher sigIntCatcher({SIGINT}); + } + + EXPECT_FALSE(sigTermCatcher.signal_occurred()); + raise_signal(SIGTERM); + EXPECT_TRUE(sigTermCatcher.signal_occurred()); +} + +TEST(SignalCatcherTest, givenSigTermCatcherAndExpiredSigIntCatcher_whenRaisingSigInt_thenDies) { + SignalCatcher sigTermCacher({SIGTERM}); + { + SignalCatcher sigIntCatcher({SIGINT}); + } + + EXPECT_DEATH( + raise_signal(SIGINT), + "" + ); +} diff --git a/test/cpp-utils/thread/LeftRightTest.cpp b/test/cpp-utils/thread/LeftRightTest.cpp new file mode 100644 index 00000000..2c2db300 --- /dev/null +++ b/test/cpp-utils/thread/LeftRightTest.cpp @@ -0,0 +1,204 @@ +#include +#include +#include + +using cpputils::LeftRight; +using std::vector; + +TEST(LeftRightTest, givenInt_whenWritingAndReading_thenChangesArePresent) { + LeftRight obj; + + obj.write([] (auto& obj) {obj = 5;}); + int read = obj.read([] (auto& obj) {return obj;}); + EXPECT_EQ(5, read); + + // check changes are also present in background copy + obj.write([] (auto&) {}); // this switches to the background copy + read = obj.read([] (auto& obj) {return obj;}); + EXPECT_EQ(5, read); +} + +TEST(LeftRightTest, givenVector_whenWritingAndReading_thenChangesArePresent) { + LeftRight> obj; + + obj.write([] (auto& obj) {obj.push_back(5);}); + vector read = obj.read([] (auto& obj) {return obj;}); + EXPECT_EQ((vector{5}), read); + + obj.write([] (auto& obj) {obj.push_back(6);}); + read = obj.read([] (auto& obj) {return obj;}); + EXPECT_EQ((vector{5, 6}), read); +} + +TEST(LeftRightTest, readsCanBeConcurrent) { + LeftRight obj; + std::atomic num_running_readers{0}; + + std::thread reader1([&] () { + obj.read([&] (auto&) { + ++num_running_readers; + while(num_running_readers.load() < 2) {} + }); + }); + + std::thread reader2([&] () { + obj.read([&] (auto&) { + ++num_running_readers; + while(num_running_readers.load() < 2) {} + }); + }); + + // the threads only finish after both entered the read function. + // if LeftRight didn't allow concurrency, this would cause a deadlock. + reader1.join(); + reader2.join(); +} + +TEST(LeftRightTest, writesCanBeConcurrentWithReads_readThenWrite) { + LeftRight obj; + std::atomic reader_running{false}; + std::atomic writer_running{false}; + + std::thread reader([&] () { + obj.read([&] (auto&) { + reader_running = true; + while(!writer_running.load()) {} + }); + }); + + std::thread writer([&] () { + // run read first, write second + while (!reader_running.load()) {} + + obj.write([&] (auto&) { + writer_running = true; + }); + }); + + // the threads only finish after both entered the read function. + // if LeftRight didn't allow concurrency, this would cause a deadlock. + reader.join(); + writer.join(); +} + +TEST(LeftRightTest, writesCanBeConcurrentWithReads_writeThenRead) { + LeftRight obj; + std::atomic writer_running{false}; + std::atomic reader_running{false}; + + std::thread writer([&] () { + obj.read([&] (auto&) { + writer_running = true; + while(!reader_running.load()) {} + }); + }); + + std::thread reader([&] () { + // run write first, read second + while (!writer_running.load()) {} + + obj.read([&] (auto&) { + reader_running = true; + }); + }); + + // the threads only finish after both entered the read function. + // if LeftRight didn't allow concurrency, this would cause a deadlock. + writer.join(); + reader.join(); +} + +TEST(LeftRightTest, writesCannotBeConcurrentWithWrites) { + LeftRight obj; + std::atomic first_writer_started{false}; + std::atomic first_writer_finished{false}; + + std::thread writer1([&] () { + obj.write([&] (auto&) { + first_writer_started = true; + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + first_writer_finished = true; + }); + }); + + std::thread writer2([&] () { + // make sure the other writer runs first + while (!first_writer_started.load()) {} + + obj.write([&] (auto&) { + // expect the other writer finished before this one starts + EXPECT_TRUE(first_writer_finished.load()); + }); + }); + + writer1.join(); + writer2.join(); +} + +namespace { +class MyException : std::exception {}; +} + +TEST(LeftRightTest, whenReadThrowsException_thenThrowsThrough) { + LeftRight obj; + + EXPECT_THROW( + obj.read([](auto&) {throw MyException();}), + MyException + ); +} + +TEST(LeftRightTest, whenWriteThrowsException_thenThrowsThrough) { + LeftRight obj; + + EXPECT_THROW( + obj.write([](auto&) {throw MyException();}), + MyException + ); +} + +TEST(LeftRightTest, givenInt_whenWriteThrowsException_thenResetsToOldState) { + LeftRight obj; + + obj.write([](auto& obj) {obj = 5;}); + + EXPECT_THROW( + obj.write([](auto& obj) { + obj = 6; + throw MyException(); + }), + MyException + ); + + // check reading it returns old value + int read = obj.read([] (auto& obj) {return obj;}); + EXPECT_EQ(5, read); + + // check changes are also present in background copy + obj.write([] (auto&) {}); // this switches to the background copy + read = obj.read([] (auto& obj) {return obj;}); + EXPECT_EQ(5, read); +} + +TEST(LeftRightTest, givenVector_whenWriteThrowsException_thenResetsToOldState) { + LeftRight> obj; + + obj.write([](auto& obj) {obj.push_back(5);}); + + EXPECT_THROW( + obj.write([](auto& obj) { + obj.push_back(6); + throw MyException(); + }), + MyException + ); + + // check reading it returns old value + vector read = obj.read([] (auto& obj) {return obj;}); + EXPECT_EQ((vector{5}), read); + + // check changes are also present in background copy + obj.write([] (auto&) {}); // this switches to the background copy + read = obj.read([] (auto& obj) {return obj;}); + EXPECT_EQ((vector{5}), read); +}