Merge branch 'release/0.10' of https://github.com/cryfs/cryfs into release/0.10

This commit is contained in:
Sebastian Messmer 2019-01-27 05:24:39 -08:00
commit 1cdf530548
28 changed files with 1118 additions and 22 deletions

View File

@ -479,7 +479,7 @@ jobs:
OMP_NUM_THREADS: "1"
CXXFLAGS: "-O2 -fsanitize=thread -fno-omit-frame-pointer"
BUILD_TYPE: "Debug"
GTEST_ARGS: "--gtest_filter=-LoggingTest.LoggingAlsoWorksAfterFork:AssertTest_DebugBuild.*:CliTest_Setup.*:CliTest_IntegrityCheck.*:*/CliTest_WrongEnvironment.*:CliTest_Unmount.*"
GTEST_ARGS: "--gtest_filter=-LoggingTest.LoggingAlsoWorksAfterFork:AssertTest_DebugBuild.*:SignalCatcherTest.*_thenDies:CliTest_Setup.*:CliTest_IntegrityCheck.*:*/CliTest_WrongEnvironment.*:CliTest_Unmount.*"
CMAKE_FLAGS: ""
RUN_TESTS: true
clang_tidy:

View File

@ -52,6 +52,10 @@ void BlobOnBlocks::flush() {
_datatree->flush();
}
uint32_t BlobOnBlocks::numNodes() const {
return _datatree->numNodes();
}
const BlockId &BlobOnBlocks::blockId() const {
return _datatree->blockId();
}

View File

@ -35,6 +35,8 @@ public:
void flush() override;
uint32_t numNodes() const override;
cpputils::unique_ref<parallelaccessdatatreestore::DataTreeRef> releaseTree();
private:

View File

@ -13,6 +13,7 @@
#include <cpp-utils/assert/assert.h>
#include "impl/LeafTraverser.h"
#include <boost/thread.hpp>
#include <blobstore/implementations/onblocks/utils/Math.h>
using blockstore::BlockId;
using blobstore::onblocks::datanodestore::DataNodeStore;
@ -64,7 +65,16 @@ unique_ref<DataNode> DataTree::releaseRootNode() {
return std::move(_rootNode);
}
//TODO Test numLeaves(), for example also two configurations with same number of bytes but different number of leaves (last leaf has 0 bytes)
uint32_t DataTree::numNodes() const {
uint32_t numNodesCurrentLevel = numLeaves();
uint32_t totalNumNodes = numNodesCurrentLevel;
for(size_t level = 0; level < _rootNode->depth(); ++level) {
numNodesCurrentLevel = blobstore::onblocks::utils::ceilDivision(numNodesCurrentLevel, static_cast<uint32_t>(_nodeStore->layout().maxChildrenPerInnerNode()));
totalNumNodes += numNodesCurrentLevel;
}
return totalNumNodes;
}
uint32_t DataTree::numLeaves() const {
shared_lock<shared_mutex> lock(_treeStructureMutex);

View File

@ -40,6 +40,7 @@ public:
void resizeNumBytes(uint64_t newNumBytes);
uint32_t numNodes() const;
uint32_t numLeaves() const;
uint64_t numBytes() const;

View File

@ -54,6 +54,10 @@ public:
return _baseTree->flush();
}
uint32_t numNodes() const {
return _baseTree->numNodes();
}
private:
datatreestore::DataTree *_baseTree;

View File

@ -26,6 +26,8 @@ public:
virtual void flush() = 0;
virtual uint32_t numNodes() const = 0;
//TODO Test tryRead
};

View File

@ -2,11 +2,14 @@
#include "IntegrityBlockStore2.h"
#include "KnownBlockVersions.h"
#include <cpp-utils/data/SerializationHelper.h>
#include <cpp-utils/process/SignalCatcher.h>
#include <cpp-utils/io/ProgressBar.h>
using cpputils::Data;
using cpputils::unique_ref;
using cpputils::serialize;
using cpputils::deserialize;
using cpputils::SignalCatcher;
using std::string;
using boost::optional;
using boost::none;
@ -195,22 +198,32 @@ void IntegrityBlockStore2::forEachBlock(std::function<void (const BlockId &)> ca
#ifndef CRYFS_NO_COMPATIBILITY
void IntegrityBlockStore2::migrateFromBlockstoreWithoutVersionNumbers(BlockStore2 *baseBlockStore, const boost::filesystem::path &integrityFilePath, uint32_t myClientId) {
std::cout << "Migrating file system for integrity features. Please don't interrupt this process. This can take a while..." << std::flush;
SignalCatcher signalCatcher;
KnownBlockVersions knownBlockVersions(integrityFilePath, myClientId);
baseBlockStore->forEachBlock([&baseBlockStore, &knownBlockVersions] (const BlockId &blockId) {
uint64_t numProcessedBlocks = 0;
cpputils::ProgressBar progressbar("Migrating file system for integrity features. This can take a while...", baseBlockStore->numBlocks());
baseBlockStore->forEachBlock([&] (const BlockId &blockId) {
if (signalCatcher.signal_occurred()) {
throw std::runtime_error("Caught signal");
}
migrateBlockFromBlockstoreWithoutVersionNumbers(baseBlockStore, blockId, &knownBlockVersions);
progressbar.update(++numProcessedBlocks);
});
std::cout << "done" << std::endl;
}
void IntegrityBlockStore2::migrateBlockFromBlockstoreWithoutVersionNumbers(blockstore::BlockStore2* baseBlockStore, const blockstore::BlockId& blockId, KnownBlockVersions *knownBlockVersions) {
uint64_t version = knownBlockVersions->incrementVersion(blockId);
auto data_ = baseBlockStore->load(blockId);
if (data_ == boost::none) {
LOG(WARN, "Block not found, but was returned from forEachBlock before");
return;
}
if (0 != _readFormatHeader(*data_)) {
// already migrated
return;
}
uint64_t version = knownBlockVersions->incrementVersion(blockId);
cpputils::Data data = std::move(*data_);
cpputils::Data dataWithHeader = _prependHeaderToData(blockId, knownBlockVersions->myClientId(), version, data);
baseBlockStore->store(blockId, dataWithHeader);

View File

@ -11,6 +11,7 @@ set(SOURCES
crypto/hash/Hash.cpp
process/daemonize.cpp
process/subprocess.cpp
process/SignalCatcher.cpp
tempfile/TempFile.cpp
tempfile/TempDir.cpp
network/HttpClient.cpp
@ -22,10 +23,12 @@ set(SOURCES
io/IOStreamConsole.cpp
io/NoninteractiveConsole.cpp
io/pipestream.cpp
io/ProgressBar.cpp
thread/LoopThread.cpp
thread/ThreadSystem.cpp
thread/debugging_nonwindows.cpp
thread/debugging_windows.cpp
thread/LeftRight.cpp
random/Random.cpp
random/RandomGeneratorThread.cpp
random/OSRandomGenerator.cpp

View File

@ -0,0 +1,35 @@
#include "ProgressBar.h"
#include <iostream>
#include <limits>
#include <mutex>
#include "IOStreamConsole.h"
using std::string;
namespace cpputils {
ProgressBar::ProgressBar(const char* preamble, uint64_t max_value)
: ProgressBar(std::make_shared<IOStreamConsole>(), preamble, max_value) {}
ProgressBar::ProgressBar(std::shared_ptr<Console> console, const char* preamble, uint64_t max_value)
: _console(std::move(console))
, _preamble(string("\r") + preamble + " ")
, _max_value(max_value)
, _lastPercentage(std::numeric_limits<decltype(_lastPercentage)>::max()) {
ASSERT(_max_value > 0, "Progress bar can't handle max_value of 0");
_console->print("\n");
// show progress bar. _lastPercentage is different to zero, so it shows.
update(0);
}
void ProgressBar::update(uint64_t value) {
const size_t percentage = (100 * value) / _max_value;
if (percentage != _lastPercentage) {
_console->print(_preamble + std::to_string(percentage) + "%");
_lastPercentage = percentage;
}
}
}

View File

@ -0,0 +1,31 @@
#pragma once
#ifndef MESSMER_CPPUTILS_IO_PROGRESSBAR_H
#define MESSMER_CPPUTILS_IO_PROGRESSBAR_H
#include <cpp-utils/macros.h>
#include <string>
#include <memory>
#include "Console.h"
namespace cpputils {
class ProgressBar final {
public:
explicit ProgressBar(std::shared_ptr<Console> console, const char* preamble, uint64_t max_value);
explicit ProgressBar(const char* preamble, uint64_t max_value);
void update(uint64_t value);
private:
std::shared_ptr<Console> _console;
std::string _preamble;
uint64_t _max_value;
size_t _lastPercentage;
DISALLOW_COPY_AND_ASSIGN(ProgressBar);
};
}
#endif

View File

@ -0,0 +1,239 @@
#include "SignalCatcher.h"
#include <algorithm>
#include <stdexcept>
#include <vector>
#include <cpp-utils/assert/assert.h>
#include <cpp-utils/thread/LeftRight.h>
using std::make_unique;
using std::vector;
using std::pair;
namespace cpputils {
namespace {
void got_signal(int signal);
using SignalHandlerFunction = void(int);
constexpr SignalHandlerFunction* signal_catcher_function = &got_signal;
#if !defined(_MSC_VER)
class SignalHandlerRAII final {
public:
SignalHandlerRAII(int signal)
: _old_handler(), _signal(signal) {
struct sigaction new_signal_handler{};
std::memset(&new_signal_handler, 0, sizeof(new_signal_handler));
new_signal_handler.sa_handler = signal_catcher_function; // NOLINT(cppcoreguidelines-pro-type-union-access)
new_signal_handler.sa_flags = SA_RESTART;
int error = sigfillset(&new_signal_handler.sa_mask); // block all signals while signal handler is running
if (0 != error) {
throw std::runtime_error("Error calling sigfillset. Errno: " + std::to_string(errno));
}
_sigaction(_signal, &new_signal_handler, &_old_handler);
}
~SignalHandlerRAII() {
// reset to old signal handler
struct sigaction removed_handler{};
_sigaction(_signal, &_old_handler, &removed_handler);
if (signal_catcher_function != removed_handler.sa_handler) { // NOLINT(cppcoreguidelines-pro-type-union-access)
ASSERT(false, "Signal handler screwup. We just replaced a signal handler that wasn't our own.");
}
}
private:
static void _sigaction(int signal, struct sigaction *new_handler, struct sigaction *old_handler) {
int error = sigaction(signal, new_handler, old_handler);
if (0 != error) {
throw std::runtime_error("Error calling sigaction. Errno: " + std::to_string(errno));
}
}
struct sigaction _old_handler;
int _signal;
DISALLOW_COPY_AND_ASSIGN(SignalHandlerRAII);
};
#else
class SignalHandlerRAII final {
public:
SignalHandlerRAII(int signal)
: _old_handler(nullptr), _signal(signal) {
_old_handler = ::signal(_signal, signal_catcher_function);
if (_old_handler == SIG_ERR) {
throw std::logic_error("Error calling signal(). Errno: " + std::to_string(errno));
}
}
~SignalHandlerRAII() {
// reset to old signal handler
SignalHandlerFunction* error = ::signal(_signal, _old_handler);
if (error == SIG_ERR) {
throw std::logic_error("Error resetting signal(). Errno: " + std::to_string(errno));
}
if (error != signal_catcher_function) {
throw std::logic_error("Signal handler screwup. We just replaced a signal handler that wasn't our own.");
}
}
private:
SignalHandlerFunction* _old_handler;
int _signal;
DISALLOW_COPY_AND_ASSIGN(SignalHandlerRAII);
};
// The Linux default behavior (i.e. the way we set up sigaction above) is to disable signal processing while the signal
// handler is running and to re-enable the custom handler once processing is finished. The Windows default behavior
// is to reset the handler to the default handler directly before executing the handler, i.e. the handler will only
// be called once. To fix this, we use this RAII class on Windows, of which an instance will live in the signal handler.
// In its constructor, it disables signal handling, and in its destructor it re-sets the custom handler.
// This is not perfect since there is a small time window between calling the signal handler and calling the constructor
// of this class, but it's the best we can do.
class SignalHandlerRunningRAII final {
public:
SignalHandlerRunningRAII(int signal) : _signal(signal) {
SignalHandlerFunction* old_handler = ::signal(_signal, SIG_IGN);
if (old_handler == SIG_ERR) {
throw std::logic_error("Error disabling signal(). Errno: " + std::to_string(errno));
}
if (old_handler != SIG_DFL) {
// see description above, we expected the signal handler to be reset.
throw std::logic_error("We expected windows to reset the signal handler but it didn't. Did the Windows API change?");
}
}
~SignalHandlerRunningRAII() {
SignalHandlerFunction* old_handler = ::signal(_signal, signal_catcher_function);
if (old_handler == SIG_ERR) {
throw std::logic_error("Error resetting signal() after calling handler. Errno: " + std::to_string(errno));
}
if (old_handler != SIG_IGN) {
throw std::logic_error("Weird, we just did set the signal handler to ignore. Why isn't it still ignore?");
}
}
private:
int _signal;
};
#endif
class SignalCatcherRegistry final {
public:
void add(int signal, std::atomic<bool>* signal_occurred_flag) {
_catchers.write([&] (auto& catchers) {
catchers.emplace_back(signal, signal_occurred_flag);
});
}
void remove(std::atomic<bool>* catcher) {
_catchers.write([&] (auto& catchers) {
auto found = std::find_if(catchers.rbegin(), catchers.rend(), [catcher] (const auto& entry) {return entry.second == catcher;});
ASSERT(found != catchers.rend(), "Signal handler not found");
catchers.erase(--found.base()); // decrement because it's a reverse iterator
});
}
~SignalCatcherRegistry() {
ASSERT(0 == _catchers.read([] (auto& catchers) {return catchers.size();}), "Leftover signal catchers that weren't destroyed");
}
std::atomic<bool>* find(int signal) {
// this is called in a signal handler and must be mutex-free.
return _catchers.read([&](auto& catchers) {
auto found = std::find_if(catchers.rbegin(), catchers.rend(), [signal](const auto& entry) {return entry.first == signal; });
ASSERT(found != catchers.rend(), "Signal handler not found");
return found->second;
});
}
static SignalCatcherRegistry& singleton() {
static SignalCatcherRegistry _singleton;
return _singleton;
}
private:
SignalCatcherRegistry() = default;
// using LeftRight datastructure because we need mutex-free reads. Signal handlers can't use mutexes.
LeftRight<vector<pair<int, std::atomic<bool>*>>> _catchers;
DISALLOW_COPY_AND_ASSIGN(SignalCatcherRegistry);
};
void got_signal(int signal) {
#if defined(_MSC_VER)
// Only needed on Windows, Linux does this by default. See comment on SignalHandlerRunningRAII class.
SignalHandlerRunningRAII disable_signal_processing_while_handler_running_and_reset_handler_afterwards(signal);
#endif
std::atomic<bool>* catcher = SignalCatcherRegistry::singleton().find(signal);
*catcher = true;
}
class SignalCatcherRegisterer final {
public:
SignalCatcherRegisterer(int signal, std::atomic<bool>* catcher)
: _catcher(catcher) {
SignalCatcherRegistry::singleton().add(signal, _catcher);
}
~SignalCatcherRegisterer() {
SignalCatcherRegistry::singleton().remove(_catcher);
}
private:
std::atomic<bool>* _catcher;
DISALLOW_COPY_AND_ASSIGN(SignalCatcherRegisterer);
};
}
namespace details {
class SignalCatcherImpl final {
public:
SignalCatcherImpl(int signal, std::atomic<bool>* signal_occurred_flag)
: _registerer(signal, signal_occurred_flag)
, _handler(signal) {
// note: the order of the members ensures that:
// - when registering the signal handler fails, the registerer will be destroyed, unregistering the signal_occurred_flag,
// i.e. there is no leak.
// Allow only the set of signals that is supported on all platforms, see for Windows: https://docs.microsoft.com/en-us/cpp/c-runtime-library/reference/signal?view=vs-2017
ASSERT(signal == SIGABRT || signal == SIGFPE || signal == SIGILL || signal == SIGINT || signal == SIGSEGV || signal == SIGTERM, "Unknown signal");
}
private:
SignalCatcherRegisterer _registerer;
SignalHandlerRAII _handler;
DISALLOW_COPY_AND_ASSIGN(SignalCatcherImpl);
};
}
SignalCatcher::SignalCatcher(std::initializer_list<int> signals)
: _signal_occurred(false)
, _impls() {
// note: the order of the members ensures that:
// - when the signal handler is set, the _signal_occurred flag is already initialized.
// - the _signal_occurred flag will not be destructed as long as the signal handler might be called (i.e. as long as _impls lives)
_impls.reserve(signals.size());
for (int signal : signals) {
_impls.emplace_back(make_unique<details::SignalCatcherImpl>(signal, &_signal_occurred));
}
}
SignalCatcher::~SignalCatcher() {}
}

View File

@ -0,0 +1,42 @@
#pragma once
#ifndef MESSMER_CPPUTILS_PROCESS_SIGNALCATCHER_H_
#define MESSMER_CPPUTILS_PROCESS_SIGNALCATCHER_H_
#include <cpp-utils/macros.h>
#include <atomic>
#include <csignal>
#include <memory>
#include <vector>
namespace cpputils {
namespace details {
class SignalCatcherImpl;
}
/*
* While an instance of this class is in scope, the specified signal (e.g. SIGINT)
* is caught and doesn't exit the application. You can poll if the signal occurred.
*/
class SignalCatcher final {
public:
SignalCatcher(): SignalCatcher({SIGINT, SIGTERM}) {}
SignalCatcher(std::initializer_list<int> signals);
~SignalCatcher();
bool signal_occurred() const {
return _signal_occurred;
}
private:
// note: _signal_occurred must be initialized before _impl because _impl might use it
std::atomic<bool> _signal_occurred;
std::vector<std::unique_ptr<details::SignalCatcherImpl>> _impls;
DISALLOW_COPY_AND_ASSIGN(SignalCatcher);
};
}
#endif

View File

@ -0,0 +1 @@
#include "LeftRight.h"

View File

@ -0,0 +1,165 @@
#include <atomic>
#include <functional>
#include <mutex>
#include <thread>
#include <cpp-utils/macros.h>
#include <array>
namespace cpputils {
namespace detail {
struct IncrementRAII final {
public:
explicit IncrementRAII(std::atomic<int32_t> *counter): _counter(counter) {
++(*_counter);
}
~IncrementRAII() {
--(*_counter);
}
private:
std::atomic<int32_t> *_counter;
DISALLOW_COPY_AND_ASSIGN(IncrementRAII);
};
}
// LeftRight wait-free readers synchronization primitive
// https://hal.archives-ouvertes.fr/hal-01207881/document
template <class T>
class LeftRight final {
public:
LeftRight()
: _writeMutex()
, _foregroundCounterIndex{0}
, _foregroundDataIndex{0}
, _counters{{{0}, {0}}}
, _data{{{}, {}}}
, _inDestruction(false) {}
~LeftRight() {
// from now on, no new readers/writers will be accepted (see asserts in read()/write())
_inDestruction = true;
// wait until any potentially running writers are finished
{
std::unique_lock<std::mutex> lock(_writeMutex);
}
// wait until any potentially running readers are finished
while (_counters[0].load() != 0 || _counters[1].load() != 0) {
std::this_thread::yield();
}
}
template <typename F>
auto read(F&& readFunc) const {
if(_inDestruction.load()) {
throw std::logic_error("Issued LeftRight::read() after the destructor started running");
}
detail::IncrementRAII _increment_counter(&_counters[_foregroundCounterIndex.load()]); // NOLINT(cppcoreguidelines-pro-bounds-constant-array-index)
return readFunc(_data[_foregroundDataIndex.load()]); // NOLINT(cppcoreguidelines-pro-bounds-constant-array-index)
}
// Throwing from write would result in invalid state
template <typename F>
auto write(F&& writeFunc) {
if(_inDestruction.load()) {
throw std::logic_error("Issued LeftRight::read() after the destructor started running");
}
std::unique_lock<std::mutex> lock(_writeMutex);
return _write(writeFunc);
}
private:
template <class F>
auto _write(const F& writeFunc) {
/*
* Assume, A is in background and B in foreground. In simplified terms, we want to do the following:
* 1. Write to A (old background)
* 2. Switch A/B
* 3. Write to B (new background)
*
* More detailed algorithm (explanations on why this is important are below in code):
* 1. Write to A
* 2. Switch A/B data pointers
* 3. Wait until A counter is zero
* 4. Switch A/B counters
* 5. Wait until B counter is zero
* 6. Write to B
*/
auto localDataIndex = _foregroundDataIndex.load();
// 1. Write to A
_callWriteFuncOnBackgroundInstance(writeFunc, localDataIndex);
// 2. Switch A/B data pointers
localDataIndex = localDataIndex ^ 1;
_foregroundDataIndex = localDataIndex;
/*
* 3. Wait until A counter is zero
*
* In the previous write run, A was foreground and B was background.
* There was a time after switching _foregroundDataIndex (B to foreground) and before switching _foregroundCounterIndex,
* in which new readers could have read B but incremented A's counter.
*
* In this current run, we just switched _foregroundDataIndex (A back to foreground), but before writing to
* the new background B, we have to make sure A's counter was zero briefly, so all these old readers are gone.
*/
auto localCounterIndex = _foregroundCounterIndex.load();
_waitForBackgroundCounterToBeZero(localCounterIndex);
/*
*4. Switch A/B counters
*
* Now that we know all readers on B are really gone, we can switch the counters and have new readers
* increment A's counter again, which is the correct counter since they're reading A.
*/
localCounterIndex = localCounterIndex ^ 1;
_foregroundCounterIndex = localCounterIndex;
/*
* 5. Wait until B counter is zero
*
* This waits for all the readers on B that came in while both data and counter for B was in foreground,
* i.e. normal readers that happened outside of that brief gap between switching data and counter.
*/
_waitForBackgroundCounterToBeZero(localCounterIndex);
// 6. Write to B
_callWriteFuncOnBackgroundInstance(writeFunc, localDataIndex);
}
template<class F>
auto _callWriteFuncOnBackgroundInstance(const F& writeFunc, uint8_t localDataIndex) {
try {
return writeFunc(_data[localDataIndex ^ 1]); // NOLINT(cppcoreguidelines-pro-bounds-constant-array-index)
} catch (...) {
// recover invariant by copying from the foreground instance
_data[localDataIndex ^ 1] = _data[localDataIndex]; // NOLINT(cppcoreguidelines-pro-bounds-constant-array-index)
// rethrow
throw;
}
}
void _waitForBackgroundCounterToBeZero(uint8_t counterIndex) {
while (_counters[counterIndex ^ 1].load() != 0) { // NOLINT(cppcoreguidelines-pro-bounds-constant-array-index)
std::this_thread::yield();
}
}
std::mutex _writeMutex;
std::atomic<uint8_t> _foregroundCounterIndex;
std::atomic<uint8_t> _foregroundDataIndex;
mutable std::array<std::atomic<int32_t>, 2> _counters;
std::array<T, 2> _data;
std::atomic<bool> _inDestruction;
};
}

View File

@ -32,6 +32,7 @@ CryConfig::CryConfig()
, _exclusiveClientId(none)
#ifndef CRYFS_NO_COMPATIBILITY
, _hasVersionNumbers(true)
, _hasParentPointers(true)
#endif
{
}
@ -53,6 +54,7 @@ CryConfig CryConfig::load(const Data &data) {
cfg._exclusiveClientId = pt.get_optional<uint32_t>("cryfs.exclusiveClientId");
#ifndef CRYFS_NO_COMPATIBILITY
cfg._hasVersionNumbers = pt.get<bool>("cryfs.migrations.hasVersionNumbers", false);
cfg._hasParentPointers = pt.get<bool>("cryfs.migrations.hasParentPointers", false);
#endif
optional<string> filesystemIdOpt = pt.get_optional<string>("cryfs.filesystemId");
@ -81,6 +83,7 @@ Data CryConfig::save() const {
}
#ifndef CRYFS_NO_COMPATIBILITY
pt.put<bool>("cryfs.migrations.hasVersionNumbers", _hasVersionNumbers);
pt.put<bool>("cryfs.migrations.hasParentPointers", _hasParentPointers);
#endif
stringstream stream;
@ -172,6 +175,14 @@ bool CryConfig::HasVersionNumbers() const {
void CryConfig::SetHasVersionNumbers(bool value) {
_hasVersionNumbers = value;
}
bool CryConfig::HasParentPointers() const {
return _hasParentPointers;
}
void CryConfig::SetHasParentPointers(bool value) {
_hasParentPointers = value;
}
#endif
}

View File

@ -56,6 +56,11 @@ public:
// Version numbers cannot be disabled, but the file system will be migrated to version numbers automatically.
bool HasVersionNumbers() const;
void SetHasVersionNumbers(bool value);
// This is a trigger to recognize old file systems that didn't have version numbers.
// Version numbers cannot be disabled, but the file system will be migrated to version numbers automatically.
bool HasParentPointers() const;
void SetHasParentPointers(bool value);
#endif
static CryConfig load(const cpputils::Data &data);
@ -73,6 +78,7 @@ private:
boost::optional<uint32_t> _exclusiveClientId;
#ifndef CRYFS_NO_COMPATIBILITY
bool _hasVersionNumbers;
bool _hasParentPointers;
#endif
CryConfig &operator=(const CryConfig &rhs) = delete;

View File

@ -80,7 +80,14 @@ unique_ref<fsblobstore::FsBlobStore> CryDevice::MigrateOrCreateFsBlobStore(uniqu
if ("" == rootBlobId) {
return make_unique_ref<FsBlobStore>(std::move(blobStore));
}
return FsBlobStore::migrateIfNeeded(std::move(blobStore), BlockId::FromString(rootBlobId));
if (!configFile->config()->HasParentPointers()) {
auto result = FsBlobStore::migrate(std::move(blobStore), BlockId::FromString(rootBlobId));
// Don't migrate again if it was successful
configFile->config()->SetHasParentPointers(true);
configFile->save();
return result;
}
return make_unique_ref<FsBlobStore>(std::move(blobStore));
}
#endif
@ -106,6 +113,7 @@ unique_ref<BlockStore2> CryDevice::CreateIntegrityEncryptedBlockStore(unique_ref
if (!configFile->config()->HasVersionNumbers()) {
IntegrityBlockStore2::migrateFromBlockstoreWithoutVersionNumbers(encryptedBlockStore.get(), integrityFilePath, myClientId);
configFile->config()->SetBlocksizeBytes(configFile->config()->BlocksizeBytes() + IntegrityBlockStore2::HEADER_LENGTH - blockstore::BlockId::BINARY_LENGTH); // Minus BlockId size because EncryptedBlockStore doesn't store the BlockId anymore (that was moved to IntegrityBlockStore)
// Don't migrate again if it was successful
configFile->config()->SetHasVersionNumbers(true);
configFile->save();
}

View File

@ -2,9 +2,13 @@
#include "FileBlob.h"
#include "DirBlob.h"
#include "SymlinkBlob.h"
#include <cryfs/config/CryConfigFile.h>
#include <cpp-utils/io/ProgressBar.h>
#include <cpp-utils/process/SignalCatcher.h>
using cpputils::unique_ref;
using cpputils::make_unique_ref;
using cpputils::SignalCatcher;
using blobstore::BlobStore;
using blockstore::BlockId;
using boost::none;
@ -31,33 +35,41 @@ boost::optional<unique_ref<FsBlob>> FsBlobStore::load(const blockstore::BlockId
}
#ifndef CRYFS_NO_COMPATIBILITY
unique_ref<FsBlobStore> FsBlobStore::migrateIfNeeded(unique_ref<BlobStore> blobStore, const blockstore::BlockId &rootBlobId) {
unique_ref<FsBlobStore> FsBlobStore::migrate(unique_ref<BlobStore> blobStore, const blockstore::BlockId &rootBlobId) {
SignalCatcher signalCatcher;
auto rootBlob = blobStore->load(rootBlobId);
ASSERT(rootBlob != none, "Could not load root blob");
uint16_t format = FsBlobView::getFormatVersionHeader(**rootBlob);
auto fsBlobStore = make_unique_ref<FsBlobStore>(std::move(blobStore));
if (format == 0) {
// migration needed
std::cout << "Migrating file system for conflict resolution features. Please don't interrupt this process. This can take a while..." << std::flush;
fsBlobStore->_migrate(std::move(*rootBlob), blockstore::BlockId::Null());
std::cout << "done" << std::endl;
}
uint64_t migratedBlocks = 0;
cpputils::ProgressBar progressbar("Migrating file system for conflict resolution features. This can take a while...", fsBlobStore->numBlocks());
fsBlobStore->_migrate(std::move(*rootBlob), blockstore::BlockId::Null(), &signalCatcher, [&] (uint32_t numNodes) {
migratedBlocks += numNodes;
progressbar.update(migratedBlocks);
});
return fsBlobStore;
}
void FsBlobStore::_migrate(unique_ref<blobstore::Blob> node, const blockstore::BlockId &parentId) {
void FsBlobStore::_migrate(unique_ref<blobstore::Blob> node, const blockstore::BlockId &parentId, SignalCatcher* signalCatcher, std::function<void(uint32_t numNodes)> perBlobCallback) {
FsBlobView::migrate(node.get(), parentId);
perBlobCallback(node->numNodes());
if (FsBlobView::blobType(*node) == FsBlobView::BlobType::DIR) {
DirBlob dir(std::move(node), _getLstatSize());
vector<fspp::Dir::Entry> children;
dir.AppendChildrenTo(&children);
for (const auto &child : children) {
if (signalCatcher->signal_occurred()) {
// on a SIGINT or SIGTERM, cancel migration but gracefully shutdown, i.e. call destructors.
throw std::runtime_error("Caught signal");
}
auto childEntry = dir.GetChild(child.name);
ASSERT(childEntry != none, "Couldn't load child, although it was returned as a child in the lsit.");
ASSERT(childEntry != none, "Couldn't load child, although it was returned as a child in the list.");
auto childBlob = _baseBlobStore->load(childEntry->blockId());
ASSERT(childBlob != none, "Couldn't load child blob");
_migrate(std::move(*childBlob), dir.blockId());
_migrate(std::move(*childBlob), dir.blockId(), signalCatcher, perBlobCallback);
}
}
}

View File

@ -8,6 +8,9 @@
#include "FileBlob.h"
#include "DirBlob.h"
#include "SymlinkBlob.h"
#ifndef CRYFS_NO_COMPATIBILITY
#include <cpp-utils/process/SignalCatcher.h>
#endif
namespace cryfs {
namespace fsblobstore {
@ -29,13 +32,13 @@ namespace cryfs {
uint64_t virtualBlocksizeBytes() const;
#ifndef CRYFS_NO_COMPATIBILITY
static cpputils::unique_ref<FsBlobStore> migrateIfNeeded(cpputils::unique_ref<blobstore::BlobStore> blobStore, const blockstore::BlockId &blockId);
static cpputils::unique_ref<FsBlobStore> migrate(cpputils::unique_ref<blobstore::BlobStore> blobStore, const blockstore::BlockId &blockId);
#endif
private:
#ifndef CRYFS_NO_COMPATIBILITY
void _migrate(cpputils::unique_ref<blobstore::Blob> node, const blockstore::BlockId &parentId);
void _migrate(cpputils::unique_ref<blobstore::Blob> node, const blockstore::BlockId &parentId, cpputils::SignalCatcher* signalCatcher, std::function<void(uint32_t numNodes)> perBlobCallback);
#endif
std::function<fspp::num_bytes_t(const blockstore::BlockId &)> _getLstatSize();

View File

@ -10,7 +10,11 @@ namespace cryfs {
void FsBlobView::migrate(blobstore::Blob *blob, const blockstore::BlockId &parentId) {
constexpr unsigned int OLD_HEADER_SIZE = sizeof(FORMAT_VERSION_HEADER) + sizeof(uint8_t);
ASSERT(FsBlobView::getFormatVersionHeader(*blob) == 0, "Block already migrated");
if(FsBlobView::getFormatVersionHeader(*blob) != 0) {
// blob already migrated
return;
}
// Resize blob and move data back
cpputils::Data data = blob->readAll();
blob->resize(blob->size() + blockstore::BlockId::BINARY_LENGTH);

View File

@ -85,6 +85,10 @@ namespace cryfs {
return _baseBlob->flush();
}
uint32_t numNodes() const override {
return _baseBlob->numNodes();
}
cpputils::unique_ref<blobstore::Blob> releaseBaseBlob() {
return std::move(_baseBlob);
}

View File

@ -163,6 +163,18 @@ int main(int argc, char* argv[]) {
std::cerr << "Error loading config file" << std::endl;
exit(1);
}
const auto& config_ = config->configFile.config();
std::cout << "Loading filesystem of version " << config_->Version() << std::endl;
#ifndef CRYFS_NO_COMPATIBILITY
const bool is_correct_format = config_->Version() == CryConfig::FilesystemFormatVersion && !config_->HasParentPointers() && !config_->HasVersionNumbers();
#else
const bool is_correct_format = config_->Version() == CryConfig::FilesystemFormatVersion;
#endif
if (!is_correct_format) {
// TODO At this point, the cryfs.config file was already switched to 0.10 format. We should probably not do that.
std::cerr << "The filesystem is not in the 0.10 format. It needs to be migrated. The cryfs-stats tool unfortunately can't handle this, please mount and unmount the filesystem once." << std::endl;
exit(1);
}
cout << "Listing all blocks..." << flush;
set<BlockId> unaccountedBlocks = _getAllBlockIds(basedir, *config, localStateDir);

View File

@ -21,50 +21,68 @@ INSTANTIATE_TEST_CASE_P(EmptyLastLeaf, DataTreeTest_NumStoredBytes_P, Values(0u)
INSTANTIATE_TEST_CASE_P(HalfFullLastLeaf, DataTreeTest_NumStoredBytes_P, Values(5u, 10u));
INSTANTIATE_TEST_CASE_P(FullLastLeaf, DataTreeTest_NumStoredBytes_P, Values(static_cast<uint32_t>(DataNodeLayout(DataTreeTest_NumStoredBytes::BLOCKSIZE_BYTES).maxBytesPerLeaf())));
//TODO Test numLeaves() and numNodes() also two configurations with same number of bytes but different number of leaves (last leaf has 0 bytes)
TEST_P(DataTreeTest_NumStoredBytes_P, SingleLeaf) {
BlockId blockId = CreateLeafWithSize(GetParam())->blockId();
auto tree = treeStore.load(blockId).value();
EXPECT_EQ(GetParam(), tree->numBytes());
EXPECT_EQ(1, tree->numLeaves());
EXPECT_EQ(1, tree->numNodes());
}
TEST_P(DataTreeTest_NumStoredBytes_P, TwoLeafTree) {
BlockId blockId = CreateTwoLeafWithSecondLeafSize(GetParam())->blockId();
auto tree = treeStore.load(blockId).value();
EXPECT_EQ(nodeStore->layout().maxBytesPerLeaf() + GetParam(), tree->numBytes());
EXPECT_EQ(2, tree->numLeaves());
EXPECT_EQ(3, tree->numNodes());
}
TEST_P(DataTreeTest_NumStoredBytes_P, FullTwolevelTree) {
BlockId blockId = CreateFullTwoLevelWithLastLeafSize(GetParam())->blockId();
auto tree = treeStore.load(blockId).value();
EXPECT_EQ(nodeStore->layout().maxBytesPerLeaf()*(nodeStore->layout().maxChildrenPerInnerNode()-1) + GetParam(), tree->numBytes());
EXPECT_EQ(nodeStore->layout().maxChildrenPerInnerNode(), tree->numLeaves());
EXPECT_EQ(1 + nodeStore->layout().maxChildrenPerInnerNode(), tree->numNodes());
}
TEST_P(DataTreeTest_NumStoredBytes_P, ThreeLevelTreeWithOneChild) {
BlockId blockId = CreateThreeLevelWithOneChildAndLastLeafSize(GetParam())->blockId();
auto tree = treeStore.load(blockId).value();
EXPECT_EQ(nodeStore->layout().maxBytesPerLeaf() + GetParam(), tree->numBytes());
EXPECT_EQ(2, tree->numLeaves());
EXPECT_EQ(4, tree->numNodes());
}
TEST_P(DataTreeTest_NumStoredBytes_P, ThreeLevelTreeWithTwoChildren) {
BlockId blockId = CreateThreeLevelWithTwoChildrenAndLastLeafSize(GetParam())->blockId();
auto tree = treeStore.load(blockId).value();
EXPECT_EQ(nodeStore->layout().maxBytesPerLeaf()*nodeStore->layout().maxChildrenPerInnerNode() + nodeStore->layout().maxBytesPerLeaf() + GetParam(), tree->numBytes());
EXPECT_EQ(2 + nodeStore->layout().maxChildrenPerInnerNode(), tree->numLeaves());
EXPECT_EQ(5 + nodeStore->layout().maxChildrenPerInnerNode(), tree->numNodes());
}
TEST_P(DataTreeTest_NumStoredBytes_P, ThreeLevelTreeWithThreeChildren) {
BlockId blockId = CreateThreeLevelWithThreeChildrenAndLastLeafSize(GetParam())->blockId();
auto tree = treeStore.load(blockId).value();
EXPECT_EQ(2*nodeStore->layout().maxBytesPerLeaf()*nodeStore->layout().maxChildrenPerInnerNode() + nodeStore->layout().maxBytesPerLeaf() + GetParam(), tree->numBytes());
EXPECT_EQ(2 + 2*nodeStore->layout().maxChildrenPerInnerNode(), tree->numLeaves());
EXPECT_EQ(6 + 2*nodeStore->layout().maxChildrenPerInnerNode(), tree->numNodes());
}
TEST_P(DataTreeTest_NumStoredBytes_P, FullThreeLevelTree) {
BlockId blockId = CreateFullThreeLevelWithLastLeafSize(GetParam())->blockId();
auto tree = treeStore.load(blockId).value();
EXPECT_EQ(nodeStore->layout().maxBytesPerLeaf()*nodeStore->layout().maxChildrenPerInnerNode()*(nodeStore->layout().maxChildrenPerInnerNode()-1) + nodeStore->layout().maxBytesPerLeaf()*(nodeStore->layout().maxChildrenPerInnerNode()-1) + GetParam(), tree->numBytes());
EXPECT_EQ(nodeStore->layout().maxChildrenPerInnerNode()*nodeStore->layout().maxChildrenPerInnerNode(), tree->numLeaves());
EXPECT_EQ(1 + nodeStore->layout().maxChildrenPerInnerNode() + nodeStore->layout().maxChildrenPerInnerNode()*nodeStore->layout().maxChildrenPerInnerNode(), tree->numNodes());
}
TEST_P(DataTreeTest_NumStoredBytes_P, FourLevelMinDataTree) {
BlockId blockId = CreateFourLevelMinDataWithLastLeafSize(GetParam())->blockId();
auto tree = treeStore.load(blockId).value();
EXPECT_EQ(nodeStore->layout().maxBytesPerLeaf()*nodeStore->layout().maxChildrenPerInnerNode()*nodeStore->layout().maxChildrenPerInnerNode() + GetParam(), tree->numBytes());
EXPECT_EQ(1 + nodeStore->layout().maxChildrenPerInnerNode()*nodeStore->layout().maxChildrenPerInnerNode(), tree->numLeaves());
EXPECT_EQ(5 + nodeStore->layout().maxChildrenPerInnerNode() + nodeStore->layout().maxChildrenPerInnerNode()*nodeStore->layout().maxChildrenPerInnerNode(), tree->numNodes());
}

View File

@ -16,6 +16,7 @@ set(SOURCES
process/daemonize_include_test.cpp
process/subprocess_include_test.cpp
process/SubprocessTest.cpp
process/SignalCatcherTest.cpp
tempfile/TempFileTest.cpp
tempfile/TempFileIncludeTest.cpp
tempfile/TempDirIncludeTest.cpp
@ -28,6 +29,7 @@ set(SOURCES
io/ConsoleTest_Print.cpp
io/ConsoleTest_Ask.cpp
io/ConsoleTest_AskPassword.cpp
io/ProgressBarTest.cpp
random/RandomIncludeTest.cpp
lock/LockPoolIncludeTest.cpp
lock/ConditionBarrierIncludeTest.cpp
@ -55,6 +57,7 @@ set(SOURCES
system/HomedirTest.cpp
system/EnvTest.cpp
thread/debugging_test.cpp
thread/LeftRightTest.cpp
value_type/ValueTypeTest.cpp
)

View File

@ -0,0 +1,66 @@
#include <cpp-utils/io/ProgressBar.h>
#include <gmock/gmock.h>
using cpputils::ProgressBar;
using std::make_shared;
class MockConsole final: public cpputils::Console {
public:
void EXPECT_OUTPUT(const char* expected) {
EXPECT_EQ(expected, _output);
_output = "";
}
void print(const std::string& text) override {
_output += text;
}
unsigned int ask(const std::string&, const std::vector<std::string>&) override {
EXPECT_TRUE(false);
return 0;
}
bool askYesNo(const std::string&, bool) override {
EXPECT_TRUE(false);
return false;
}
std::string askPassword(const std::string&) override {
EXPECT_TRUE(false);
return "";
}
private:
std::string _output;
};
TEST(ProgressBarTest, testProgressBar) {
auto console = make_shared<MockConsole>();
ProgressBar bar(console, "Preamble", 2000);
console->EXPECT_OUTPUT("\n\rPreamble 0%");
// when updating to 0, doesn't reprint
bar.update(0);
console->EXPECT_OUTPUT("");
// update to half
bar.update(1000);
console->EXPECT_OUTPUT("\rPreamble 50%");
// when updating to same value, doesn't reprint
bar.update(1000);
console->EXPECT_OUTPUT("");
// when updating to value with same percentage, doesn't reprint
bar.update(1001);
console->EXPECT_OUTPUT("");
// update to 0
bar.update(0);
console->EXPECT_OUTPUT("\rPreamble 0%");
// update to full
bar.update(2000);
console->EXPECT_OUTPUT("\rPreamble 100%");
}

View File

@ -0,0 +1,193 @@
#include <cpp-utils/process/SignalCatcher.h>
#include <gtest/gtest.h>
#include <csignal>
using cpputils::SignalCatcher;
namespace {
void raise_signal(int signal) {
int error = ::raise(signal);
if (error != 0) {
throw std::runtime_error("Error raising signal");
}
}
}
TEST(SignalCatcherTest, givenNoSignalCatcher_whenRaisingSigint_thenDies) {
EXPECT_DEATH(
raise_signal(SIGINT),
""
);
}
TEST(SignalCatcherTest, givenNoSignalCatcher_whenRaisingSigterm_thenDies) {
EXPECT_DEATH(
raise_signal(SIGTERM),
""
);
}
TEST(SignalCatcherTest, givenSigIntCatcher_whenRaisingSigInt_thenCatches) {
SignalCatcher catcher({SIGINT});
EXPECT_FALSE(catcher.signal_occurred());
raise_signal(SIGINT);
EXPECT_TRUE(catcher.signal_occurred());
// raise again
raise_signal(SIGINT);
EXPECT_TRUE(catcher.signal_occurred());
}
TEST(SignalCatcherTest, givenSigTermCatcher_whenRaisingSigTerm_thenCatches) {
SignalCatcher catcher({SIGTERM});
EXPECT_FALSE(catcher.signal_occurred());
raise_signal(SIGTERM);
EXPECT_TRUE(catcher.signal_occurred());
// raise again
raise_signal(SIGTERM);
EXPECT_TRUE(catcher.signal_occurred());
}
TEST(SignalCatcherTest, givenSigIntAndSigTermCatcher_whenRaisingSigInt_thenCatches) {
SignalCatcher catcher({SIGINT, SIGTERM});
EXPECT_FALSE(catcher.signal_occurred());
raise_signal(SIGINT);
EXPECT_TRUE(catcher.signal_occurred());
// raise again
raise_signal(SIGINT);
EXPECT_TRUE(catcher.signal_occurred());
}
TEST(SignalCatcherTest, givenSigIntAndSigTermCatcher_whenRaisingSigTerm_thenCatches) {
SignalCatcher catcher({SIGINT, SIGTERM});
EXPECT_FALSE(catcher.signal_occurred());
raise_signal(SIGTERM);
EXPECT_TRUE(catcher.signal_occurred());
// raise again
raise_signal(SIGTERM);
EXPECT_TRUE(catcher.signal_occurred());
}
TEST(SignalCatcherTest, givenSigIntAndSigTermCatcher_whenRaisingSigIntAndSigTerm_thenCatches) {
SignalCatcher catcher({SIGINT, SIGTERM});
EXPECT_FALSE(catcher.signal_occurred());
raise_signal(SIGTERM);
EXPECT_TRUE(catcher.signal_occurred());
raise_signal(SIGINT);
EXPECT_TRUE(catcher.signal_occurred());
}
TEST(SignalCatcherTest, givenSigIntCatcherAndSigTermCatcher_whenRaisingSignalsInOrder_thenCorrectCatcherCatches) {
SignalCatcher sigintCatcher({SIGINT});
SignalCatcher sigtermCatcher({SIGTERM});
EXPECT_FALSE(sigintCatcher.signal_occurred());
raise_signal(SIGINT);
EXPECT_TRUE(sigintCatcher.signal_occurred());
EXPECT_FALSE(sigtermCatcher.signal_occurred());
raise_signal(SIGTERM);
EXPECT_TRUE(sigtermCatcher.signal_occurred());
}
TEST(SignalCatcherTest, givenSigIntCatcherAndSigTermCatcher_whenRaisingSignalsInReverseOrder_thenCorrectCatcherCatches) {
SignalCatcher sigintCatcher({SIGINT});
SignalCatcher sigtermCatcher({SIGTERM});
EXPECT_FALSE(sigtermCatcher.signal_occurred());
raise_signal(SIGTERM);
EXPECT_TRUE(sigtermCatcher.signal_occurred());
EXPECT_FALSE(sigintCatcher.signal_occurred());
raise_signal(SIGINT);
EXPECT_TRUE(sigintCatcher.signal_occurred());
}
TEST(SignalCatcherTest, givenNestedSigIntCatchers_whenRaisingSignals_thenCorrectCatcherCatches) {
SignalCatcher outerCatcher({SIGINT});
{
SignalCatcher middleCatcher({SIGINT});
EXPECT_FALSE(middleCatcher.signal_occurred());
raise_signal(SIGINT);
EXPECT_TRUE(middleCatcher.signal_occurred());
{
SignalCatcher innerCatcher({SIGINT});
EXPECT_FALSE(innerCatcher.signal_occurred());
raise_signal(SIGINT);
EXPECT_TRUE(innerCatcher.signal_occurred());
}
}
EXPECT_FALSE(outerCatcher.signal_occurred());
raise_signal(SIGINT);
EXPECT_TRUE(outerCatcher.signal_occurred());
}
TEST(SignalCatcherTest, givenExpiredSigIntCatcher_whenRaisingSigInt_thenDies) {
{
SignalCatcher catcher({SIGINT});
}
EXPECT_DEATH(
raise_signal(SIGINT),
""
);
}
TEST(SignalCatcherTest, givenExpiredSigTermCatcher_whenRaisingSigTerm_thenDies) {
{
SignalCatcher catcher({SIGTERM});
}
EXPECT_DEATH(
raise_signal(SIGTERM),
""
);
}
TEST(SignalCatcherTest, givenExpiredSigIntCatcherAndSigTermCatcher_whenRaisingSigTerm_thenDies) {
{
SignalCatcher sigIntCatcher({SIGTERM});
SignalCatcher sigTermCatcer({SIGTERM});
}
EXPECT_DEATH(
raise_signal(SIGTERM),
""
);
}
TEST(SignalCatcherTest, givenSigTermCatcherAndExpiredSigIntCatcher_whenRaisingSigTerm_thenCatches) {
SignalCatcher sigTermCatcher({SIGTERM});
{
SignalCatcher sigIntCatcher({SIGINT});
}
EXPECT_FALSE(sigTermCatcher.signal_occurred());
raise_signal(SIGTERM);
EXPECT_TRUE(sigTermCatcher.signal_occurred());
}
TEST(SignalCatcherTest, givenSigTermCatcherAndExpiredSigIntCatcher_whenRaisingSigInt_thenDies) {
SignalCatcher sigTermCacher({SIGTERM});
{
SignalCatcher sigIntCatcher({SIGINT});
}
EXPECT_DEATH(
raise_signal(SIGINT),
""
);
}

View File

@ -0,0 +1,204 @@
#include <cpp-utils/thread/LeftRight.h>
#include <gtest/gtest.h>
#include <vector>
using cpputils::LeftRight;
using std::vector;
TEST(LeftRightTest, givenInt_whenWritingAndReading_thenChangesArePresent) {
LeftRight<int> obj;
obj.write([] (auto& obj) {obj = 5;});
int read = obj.read([] (auto& obj) {return obj;});
EXPECT_EQ(5, read);
// check changes are also present in background copy
obj.write([] (auto&) {}); // this switches to the background copy
read = obj.read([] (auto& obj) {return obj;});
EXPECT_EQ(5, read);
}
TEST(LeftRightTest, givenVector_whenWritingAndReading_thenChangesArePresent) {
LeftRight<vector<int>> obj;
obj.write([] (auto& obj) {obj.push_back(5);});
vector<int> read = obj.read([] (auto& obj) {return obj;});
EXPECT_EQ((vector<int>{5}), read);
obj.write([] (auto& obj) {obj.push_back(6);});
read = obj.read([] (auto& obj) {return obj;});
EXPECT_EQ((vector<int>{5, 6}), read);
}
TEST(LeftRightTest, readsCanBeConcurrent) {
LeftRight<int> obj;
std::atomic<int> num_running_readers{0};
std::thread reader1([&] () {
obj.read([&] (auto&) {
++num_running_readers;
while(num_running_readers.load() < 2) {}
});
});
std::thread reader2([&] () {
obj.read([&] (auto&) {
++num_running_readers;
while(num_running_readers.load() < 2) {}
});
});
// the threads only finish after both entered the read function.
// if LeftRight didn't allow concurrency, this would cause a deadlock.
reader1.join();
reader2.join();
}
TEST(LeftRightTest, writesCanBeConcurrentWithReads_readThenWrite) {
LeftRight<int> obj;
std::atomic<bool> reader_running{false};
std::atomic<bool> writer_running{false};
std::thread reader([&] () {
obj.read([&] (auto&) {
reader_running = true;
while(!writer_running.load()) {}
});
});
std::thread writer([&] () {
// run read first, write second
while (!reader_running.load()) {}
obj.write([&] (auto&) {
writer_running = true;
});
});
// the threads only finish after both entered the read function.
// if LeftRight didn't allow concurrency, this would cause a deadlock.
reader.join();
writer.join();
}
TEST(LeftRightTest, writesCanBeConcurrentWithReads_writeThenRead) {
LeftRight<int> obj;
std::atomic<bool> writer_running{false};
std::atomic<bool> reader_running{false};
std::thread writer([&] () {
obj.read([&] (auto&) {
writer_running = true;
while(!reader_running.load()) {}
});
});
std::thread reader([&] () {
// run write first, read second
while (!writer_running.load()) {}
obj.read([&] (auto&) {
reader_running = true;
});
});
// the threads only finish after both entered the read function.
// if LeftRight didn't allow concurrency, this would cause a deadlock.
writer.join();
reader.join();
}
TEST(LeftRightTest, writesCannotBeConcurrentWithWrites) {
LeftRight<int> obj;
std::atomic<bool> first_writer_started{false};
std::atomic<bool> first_writer_finished{false};
std::thread writer1([&] () {
obj.write([&] (auto&) {
first_writer_started = true;
std::this_thread::sleep_for(std::chrono::milliseconds(50));
first_writer_finished = true;
});
});
std::thread writer2([&] () {
// make sure the other writer runs first
while (!first_writer_started.load()) {}
obj.write([&] (auto&) {
// expect the other writer finished before this one starts
EXPECT_TRUE(first_writer_finished.load());
});
});
writer1.join();
writer2.join();
}
namespace {
class MyException : std::exception {};
}
TEST(LeftRightTest, whenReadThrowsException_thenThrowsThrough) {
LeftRight<int> obj;
EXPECT_THROW(
obj.read([](auto&) {throw MyException();}),
MyException
);
}
TEST(LeftRightTest, whenWriteThrowsException_thenThrowsThrough) {
LeftRight<int> obj;
EXPECT_THROW(
obj.write([](auto&) {throw MyException();}),
MyException
);
}
TEST(LeftRightTest, givenInt_whenWriteThrowsException_thenResetsToOldState) {
LeftRight<int> obj;
obj.write([](auto& obj) {obj = 5;});
EXPECT_THROW(
obj.write([](auto& obj) {
obj = 6;
throw MyException();
}),
MyException
);
// check reading it returns old value
int read = obj.read([] (auto& obj) {return obj;});
EXPECT_EQ(5, read);
// check changes are also present in background copy
obj.write([] (auto&) {}); // this switches to the background copy
read = obj.read([] (auto& obj) {return obj;});
EXPECT_EQ(5, read);
}
TEST(LeftRightTest, givenVector_whenWriteThrowsException_thenResetsToOldState) {
LeftRight<vector<int>> obj;
obj.write([](auto& obj) {obj.push_back(5);});
EXPECT_THROW(
obj.write([](auto& obj) {
obj.push_back(6);
throw MyException();
}),
MyException
);
// check reading it returns old value
vector<int> read = obj.read([] (auto& obj) {return obj;});
EXPECT_EQ((vector<int>{5}), read);
// check changes are also present in background copy
obj.write([] (auto&) {}); // this switches to the background copy
read = obj.read([] (auto& obj) {return obj;});
EXPECT_EQ((vector<int>{5}), read);
}