* Prevent rollback to the "newest" version of a client when this version was superseded by a version from a different client.

* Use mutex/locks to secure access to KnownBlockVersions
This commit is contained in:
Sebastian Messmer 2016-06-22 16:27:35 -07:00
parent 263c540cd0
commit 4d1f7a46b9
6 changed files with 222 additions and 52 deletions

View File

@ -1,13 +1,13 @@
#include <fstream> #include <fstream>
#include <cpp-utils/random/Random.h> #include <cpp-utils/random/Random.h>
#include "KnownBlockVersions.h" #include "KnownBlockVersions.h"
#include <cpp-utils/data/Serializer.h>
#include <cpp-utils/data/Deserializer.h>
namespace bf = boost::filesystem; namespace bf = boost::filesystem;
using std::unordered_map; using std::unordered_map;
using std::pair; using std::pair;
using std::string; using std::string;
using std::unique_lock;
using std::mutex;
using boost::optional; using boost::optional;
using boost::none; using boost::none;
using cpputils::Data; using cpputils::Data;
@ -21,30 +21,50 @@ namespace versioncounting {
const string KnownBlockVersions::HEADER = "cryfs.integritydata.knownblockversions;0"; const string KnownBlockVersions::HEADER = "cryfs.integritydata.knownblockversions;0";
KnownBlockVersions::KnownBlockVersions(const bf::path &stateFilePath) KnownBlockVersions::KnownBlockVersions(const bf::path &stateFilePath)
:_knownVersions(), _stateFilePath(stateFilePath), _myClientId(0), _valid(true) { :_knownVersions(), _lastUpdateClientId(), _stateFilePath(stateFilePath), _myClientId(0), _mutex(), _valid(true) {
unique_lock<mutex> lock(_mutex);
_loadStateFile(); _loadStateFile();
} }
KnownBlockVersions::KnownBlockVersions(KnownBlockVersions &&rhs) KnownBlockVersions::KnownBlockVersions(KnownBlockVersions &&rhs)
: _knownVersions(std::move(rhs._knownVersions)), _stateFilePath(std::move(rhs._stateFilePath)), _myClientId(rhs._myClientId), _valid(true) { : _knownVersions(), _lastUpdateClientId(), _stateFilePath(), _myClientId(0), _mutex(), _valid(true) {
unique_lock<mutex> rhsLock(rhs._mutex);
unique_lock<mutex> lock(_mutex);
_knownVersions = std::move(rhs._knownVersions);
_lastUpdateClientId = std::move(rhs._lastUpdateClientId);
_stateFilePath = std::move(rhs._stateFilePath);
_myClientId = rhs._myClientId;
rhs._valid = false; rhs._valid = false;
} }
KnownBlockVersions::~KnownBlockVersions() { KnownBlockVersions::~KnownBlockVersions() {
unique_lock<mutex> lock(_mutex);
if (_valid) { if (_valid) {
_saveStateFile(); _saveStateFile();
} }
} }
bool KnownBlockVersions::checkAndUpdateVersion(uint32_t clientId, const Key &key, uint64_t version) { bool KnownBlockVersions::checkAndUpdateVersion(uint32_t clientId, const Key &key, uint64_t version) {
unique_lock<mutex> lock(_mutex);
ASSERT(version > 0, "Version has to be >0"); // Otherwise we wouldn't handle notexisting entries correctly.
ASSERT(_valid, "Object not valid due to a std::move"); ASSERT(_valid, "Object not valid due to a std::move");
uint64_t &found = _knownVersions[{clientId, key}]; // If the entry doesn't exist, this creates it with value 0. uint64_t &found = _knownVersions[{clientId, key}]; // If the entry doesn't exist, this creates it with value 0.
if (found > version) { if (found > version) {
// This client already published a newer block version. Rollbacks are not allowed.
return false;
}
uint32_t &lastUpdateClientId = _lastUpdateClientId[key]; // If entry doesn't exist, this creates it with value 0. However, in this case, found == 0 (and version > 0), which means found != version.
if (found == version && lastUpdateClientId != clientId) {
// This is a roll back to the "newest" block of client [clientId], which was since then superseded by a version from client _lastUpdateClientId[key].
// This is not allowed.
return false; return false;
} }
found = version; found = version;
lastUpdateClientId = clientId;
return true; return true;
} }
@ -67,19 +87,46 @@ void KnownBlockVersions::_loadStateFile() {
throw std::runtime_error("Invalid local state: Invalid integrity file header."); throw std::runtime_error("Invalid local state: Invalid integrity file header.");
} }
_myClientId = deserializer.readUint32(); _myClientId = deserializer.readUint32();
uint64_t numEntries = deserializer.readUint64(); _deserializeKnownVersions(&deserializer);
_deserializeLastUpdateClientIds(&deserializer);
_knownVersions.clear();
_knownVersions.reserve(static_cast<uint64_t>(1.2 * numEntries)); // Reserve for factor 1.2 more, so the file system doesn't immediately have to resize it on the first new block.
for (uint64_t i = 0 ; i < numEntries; ++i) {
auto entry = _readEntry(&deserializer);
_knownVersions.insert(entry);
}
deserializer.finished(); deserializer.finished();
}; };
pair<ClientIdAndBlockKey, uint64_t> KnownBlockVersions::_readEntry(Deserializer *deserializer) {
void KnownBlockVersions::_saveStateFile() const {
Serializer serializer(
Serializer::StringSize(HEADER) + sizeof(uint32_t) +
sizeof(uint64_t) + _knownVersions.size() * (sizeof(uint32_t) + Key::BINARY_LENGTH + sizeof(uint64_t)) +
sizeof(uint64_t) + _lastUpdateClientId.size() * (Key::BINARY_LENGTH + sizeof(uint32_t)));
serializer.writeString(HEADER);
serializer.writeUint32(_myClientId);
_serializeKnownVersions(&serializer);
_serializeLastUpdateClientIds(&serializer);
serializer.finished().StoreToFile(_stateFilePath);
}
void KnownBlockVersions::_deserializeKnownVersions(Deserializer *deserializer) {
uint64_t numEntries = deserializer->readUint64();
_knownVersions.clear();
_knownVersions.reserve(static_cast<uint64_t>(1.2 * numEntries)); // Reserve for factor 1.2 more, so the file system doesn't immediately have to resize it on the first new block.
for (uint64_t i = 0 ; i < numEntries; ++i) {
auto entry = _deserializeKnownVersionsEntry(deserializer);
_knownVersions.insert(entry);
}
}
void KnownBlockVersions::_serializeKnownVersions(Serializer *serializer) const {
uint64_t numEntries = _knownVersions.size();
serializer->writeUint64(numEntries);
for (const auto &entry : _knownVersions) {
_serializeKnownVersionsEntry(serializer, entry);
}
}
pair<ClientIdAndBlockKey, uint64_t> KnownBlockVersions::_deserializeKnownVersionsEntry(Deserializer *deserializer) {
uint32_t clientId = deserializer->readUint32(); uint32_t clientId = deserializer->readUint32();
Key blockKey = deserializer->readFixedSizeData<Key::BINARY_LENGTH>(); Key blockKey = deserializer->readFixedSizeData<Key::BINARY_LENGTH>();
uint64_t version = deserializer->readUint64(); uint64_t version = deserializer->readUint64();
@ -87,21 +134,41 @@ pair<ClientIdAndBlockKey, uint64_t> KnownBlockVersions::_readEntry(Deserializer
return {{clientId, blockKey}, version}; return {{clientId, blockKey}, version};
}; };
void KnownBlockVersions::_saveStateFile() const { void KnownBlockVersions::_serializeKnownVersionsEntry(Serializer *serializer, const pair<ClientIdAndBlockKey, uint64_t> &entry) {
uint64_t numEntries = _knownVersions.size(); serializer->writeUint32(entry.first.clientId);
serializer->writeFixedSizeData<Key::BINARY_LENGTH>(entry.first.blockKey);
serializer->writeUint64(entry.second);
}
Serializer serializer(Serializer::StringSize(HEADER) + sizeof(uint32_t) + sizeof(uint64_t) + numEntries * (sizeof(uint32_t) + Key::BINARY_LENGTH + sizeof(uint64_t))); void KnownBlockVersions::_deserializeLastUpdateClientIds(Deserializer *deserializer) {
serializer.writeString(HEADER); uint64_t numEntries = deserializer->readUint64();
serializer.writeUint32(_myClientId); _lastUpdateClientId.clear();
serializer.writeUint64(numEntries); _lastUpdateClientId.reserve(static_cast<uint64_t>(1.2 * numEntries)); // Reserve for factor 1.2 more, so the file system doesn't immediately have to resize it on the first new block.
for (uint64_t i = 0 ; i < numEntries; ++i) {
for (const auto &entry : _knownVersions) { auto entry = _deserializeLastUpdateClientIdEntry(deserializer);
serializer.writeUint32(entry.first.clientId); _lastUpdateClientId.insert(entry);
serializer.writeFixedSizeData<Key::BINARY_LENGTH>(entry.first.blockKey);
serializer.writeUint64(entry.second);
} }
}
serializer.finished().StoreToFile(_stateFilePath); void KnownBlockVersions::_serializeLastUpdateClientIds(Serializer *serializer) const {
uint64_t numEntries = _lastUpdateClientId.size();
serializer->writeUint64(numEntries);
for (const auto &entry : _lastUpdateClientId) {
_serializeLastUpdateClientIdEntry(serializer, entry);
}
}
pair<Key, uint32_t> KnownBlockVersions::_deserializeLastUpdateClientIdEntry(Deserializer *deserializer) {
Key blockKey = deserializer->readFixedSizeData<Key::BINARY_LENGTH>();
uint32_t clientId = deserializer->readUint32();
return {blockKey, clientId};
};
void KnownBlockVersions::_serializeLastUpdateClientIdEntry(Serializer *serializer, const pair<Key, uint32_t> &entry) {
serializer->writeFixedSizeData<Key::BINARY_LENGTH>(entry.first);
serializer->writeUint32(entry.second);
} }
uint32_t KnownBlockVersions::myClientId() const { uint32_t KnownBlockVersions::myClientId() const {

View File

@ -8,6 +8,8 @@
#include <boost/optional.hpp> #include <boost/optional.hpp>
#include "ClientIdAndBlockKey.h" #include "ClientIdAndBlockKey.h"
#include <cpp-utils/data/Deserializer.h> #include <cpp-utils/data/Deserializer.h>
#include <cpp-utils/data/Serializer.h>
#include <mutex>
namespace blockstore { namespace blockstore {
namespace versioncounting { namespace versioncounting {
@ -27,16 +29,30 @@ namespace blockstore {
private: private:
std::unordered_map<ClientIdAndBlockKey, uint64_t> _knownVersions; std::unordered_map<ClientIdAndBlockKey, uint64_t> _knownVersions;
std::unordered_map<Key, uint32_t> _lastUpdateClientId; // The client who last updated the block
boost::filesystem::path _stateFilePath; boost::filesystem::path _stateFilePath;
uint32_t _myClientId; uint32_t _myClientId;
mutable std::mutex _mutex;
bool _valid; bool _valid;
static const std::string HEADER; static const std::string HEADER;
void _loadStateFile(); void _loadStateFile();
static std::pair<ClientIdAndBlockKey, uint64_t> _readEntry(cpputils::Deserializer *deserializer);
void _saveStateFile() const; void _saveStateFile() const;
void _deserializeKnownVersions(cpputils::Deserializer *deserializer);
void _serializeKnownVersions(cpputils::Serializer *serializer) const;
static std::pair<ClientIdAndBlockKey, uint64_t> _deserializeKnownVersionsEntry(cpputils::Deserializer *deserializer);
static void _serializeKnownVersionsEntry(cpputils::Serializer *serializer, const std::pair<ClientIdAndBlockKey, uint64_t> &entry);
void _deserializeLastUpdateClientIds(cpputils::Deserializer *deserializer);
void _serializeLastUpdateClientIds(cpputils::Serializer *serializer) const;
static std::pair<Key, uint32_t> _deserializeLastUpdateClientIdEntry(cpputils::Deserializer *deserializer);
static void _serializeLastUpdateClientIdEntry(cpputils::Serializer *serializer, const std::pair<Key, uint32_t> &entry);
DISALLOW_COPY_AND_ASSIGN(KnownBlockVersions); DISALLOW_COPY_AND_ASSIGN(KnownBlockVersions);
}; };

View File

@ -2,6 +2,8 @@
namespace blockstore { namespace blockstore {
namespace versioncounting { namespace versioncounting {
constexpr unsigned int VersionCountingBlock::CLIENTID_HEADER_OFFSET;
constexpr unsigned int VersionCountingBlock::VERSION_HEADER_OFFSET;
constexpr unsigned int VersionCountingBlock::HEADER_LENGTH; constexpr unsigned int VersionCountingBlock::HEADER_LENGTH;
constexpr uint16_t VersionCountingBlock::FORMAT_VERSION_HEADER; constexpr uint16_t VersionCountingBlock::FORMAT_VERSION_HEADER;
constexpr uint64_t VersionCountingBlock::VERSION_ZERO; constexpr uint64_t VersionCountingBlock::VERSION_ZERO;

View File

@ -16,10 +16,10 @@
#include <cpp-utils/data/DataUtils.h> #include <cpp-utils/data/DataUtils.h>
#include <mutex> #include <mutex>
#include <cpp-utils/logging/logging.h> #include <cpp-utils/logging/logging.h>
#include "../../../../vendor/googletest/gtest-1.7.0/googletest/include/gtest/gtest_prod.h"
namespace blockstore { namespace blockstore {
namespace versioncounting { namespace versioncounting {
class VersionCountingBlockStore;
// TODO Is an implementation that doesn't keep an in-memory copy but just passes through write() calls to the underlying block store (including a write call to the version number each time) faster? // TODO Is an implementation that doesn't keep an in-memory copy but just passes through write() calls to the underlying block store (including a write call to the version number each time) faster?
@ -59,12 +59,16 @@ private:
// This header is prepended to blocks to allow future versions to have compatibility. // This header is prepended to blocks to allow future versions to have compatibility.
static constexpr uint16_t FORMAT_VERSION_HEADER = 0; static constexpr uint16_t FORMAT_VERSION_HEADER = 0;
static constexpr uint64_t VERSION_ZERO = 0; static constexpr uint64_t VERSION_ZERO = 1; // lowest block version is '1', because that is required by class KnownBlockVersions.
static constexpr unsigned int HEADER_LENGTH = sizeof(FORMAT_VERSION_HEADER) + sizeof(uint32_t) + sizeof(VERSION_ZERO);
std::mutex _mutex; std::mutex _mutex;
DISALLOW_COPY_AND_ASSIGN(VersionCountingBlock); DISALLOW_COPY_AND_ASSIGN(VersionCountingBlock);
public:
static constexpr unsigned int CLIENTID_HEADER_OFFSET = sizeof(FORMAT_VERSION_HEADER);
static constexpr unsigned int VERSION_HEADER_OFFSET = sizeof(FORMAT_VERSION_HEADER) + sizeof(uint32_t);
static constexpr unsigned int HEADER_LENGTH = sizeof(FORMAT_VERSION_HEADER) + sizeof(uint32_t) + sizeof(VERSION_ZERO);
}; };
@ -84,8 +88,8 @@ inline cpputils::Data VersionCountingBlock::_prependHeaderToData(uint32_t myClie
static_assert(HEADER_LENGTH == sizeof(FORMAT_VERSION_HEADER) + sizeof(myClientId) + sizeof(version), "Wrong header length"); static_assert(HEADER_LENGTH == sizeof(FORMAT_VERSION_HEADER) + sizeof(myClientId) + sizeof(version), "Wrong header length");
cpputils::Data result(data.size() + HEADER_LENGTH); cpputils::Data result(data.size() + HEADER_LENGTH);
std::memcpy(result.dataOffset(0), &FORMAT_VERSION_HEADER, sizeof(FORMAT_VERSION_HEADER)); std::memcpy(result.dataOffset(0), &FORMAT_VERSION_HEADER, sizeof(FORMAT_VERSION_HEADER));
std::memcpy(result.dataOffset(sizeof(FORMAT_VERSION_HEADER)), &myClientId, sizeof(myClientId)); std::memcpy(result.dataOffset(CLIENTID_HEADER_OFFSET), &myClientId, sizeof(myClientId));
std::memcpy(result.dataOffset(sizeof(FORMAT_VERSION_HEADER)+sizeof(myClientId)), &version, sizeof(version)); std::memcpy(result.dataOffset(VERSION_HEADER_OFFSET), &version, sizeof(version));
std::memcpy((uint8_t*)result.dataOffset(HEADER_LENGTH), data.data(), data.size()); std::memcpy((uint8_t*)result.dataOffset(HEADER_LENGTH), data.data(), data.size());
return result; return result;
} }
@ -112,13 +116,13 @@ inline void VersionCountingBlock::_checkFormatHeader(const cpputils::Data &data)
inline uint32_t VersionCountingBlock::_readClientId(const cpputils::Data &data) { inline uint32_t VersionCountingBlock::_readClientId(const cpputils::Data &data) {
uint32_t clientId; uint32_t clientId;
std::memcpy(&clientId, data.dataOffset(sizeof(FORMAT_VERSION_HEADER)), sizeof(clientId)); std::memcpy(&clientId, data.dataOffset(CLIENTID_HEADER_OFFSET), sizeof(clientId));
return clientId; return clientId;
} }
inline uint64_t VersionCountingBlock::_readVersion(const cpputils::Data &data) { inline uint64_t VersionCountingBlock::_readVersion(const cpputils::Data &data) {
uint64_t version; uint64_t version;
std::memcpy(&version, data.dataOffset(sizeof(FORMAT_VERSION_HEADER) + sizeof(uint32_t)), sizeof(version)); std::memcpy(&version, data.dataOffset(VERSION_HEADER_OFFSET), sizeof(version));
return version; return version;
} }
@ -170,8 +174,8 @@ inline void VersionCountingBlock::_storeToBaseBlock() {
if (_dataChanged) { if (_dataChanged) {
++_version; ++_version;
uint32_t myClientId = _knownBlockVersions->myClientId(); uint32_t myClientId = _knownBlockVersions->myClientId();
std::memcpy(_dataWithHeader.dataOffset(sizeof(FORMAT_VERSION_HEADER)), &myClientId, sizeof(myClientId)); std::memcpy(_dataWithHeader.dataOffset(CLIENTID_HEADER_OFFSET), &myClientId, sizeof(myClientId));
std::memcpy(_dataWithHeader.dataOffset(sizeof(FORMAT_VERSION_HEADER) + sizeof(myClientId)), &_version, sizeof(_version)); std::memcpy(_dataWithHeader.dataOffset(VERSION_HEADER_OFFSET), &_version, sizeof(_version));
_baseBlock->write(_dataWithHeader.data(), 0, _dataWithHeader.size()); _baseBlock->write(_dataWithHeader.data(), 0, _dataWithHeader.size());
_knownBlockVersions->updateVersion(key(), _version); _knownBlockVersions->updateVersion(key(), _version);
_dataChanged = false; _dataChanged = false;

View File

@ -19,24 +19,24 @@ public:
void EXPECT_VERSION_IS(uint64_t version, KnownBlockVersions *testobj, blockstore::Key &key, uint32_t clientId) { void EXPECT_VERSION_IS(uint64_t version, KnownBlockVersions *testobj, blockstore::Key &key, uint32_t clientId) {
EXPECT_FALSE(testobj->checkAndUpdateVersion(clientId, key, version-1)); EXPECT_FALSE(testobj->checkAndUpdateVersion(clientId, key, version-1));
EXPECT_TRUE(testobj->checkAndUpdateVersion(clientId, key, version)); EXPECT_TRUE(testobj->checkAndUpdateVersion(clientId, key, version+1));
} }
}; };
TEST_F(KnownBlockVersionsTest, update_newEntry_zero) { TEST_F(KnownBlockVersionsTest, update_newEntry_lowversion) {
testobj.updateVersion(key, 0); testobj.updateVersion(key, 1);
} }
TEST_F(KnownBlockVersionsTest, update_newEntry_nonzero) { TEST_F(KnownBlockVersionsTest, update_newEntry_highversion) {
testobj.updateVersion(key, 100); testobj.updateVersion(key, 100);
} }
TEST_F(KnownBlockVersionsTest, update_existingEntry_equal_zero) { TEST_F(KnownBlockVersionsTest, update_existingEntry_equal_lowversion) {
testobj.updateVersion(key, 0); testobj.updateVersion(key, 1);
testobj.updateVersion(key, 0); testobj.updateVersion(key, 1);
} }
TEST_F(KnownBlockVersionsTest, update_existingEntry_equal_nonzero) { TEST_F(KnownBlockVersionsTest, update_existingEntry_equal_highversion) {
testobj.updateVersion(key, 100); testobj.updateVersion(key, 100);
testobj.updateVersion(key, 100); testobj.updateVersion(key, 100);
} }
@ -58,20 +58,20 @@ TEST_F(KnownBlockVersionsTest, update_updatesOwnClientId) {
EXPECT_VERSION_IS(100, &testobj, key, testobj.myClientId()); EXPECT_VERSION_IS(100, &testobj, key, testobj.myClientId());
} }
TEST_F(KnownBlockVersionsTest, checkAndUpdate_newEntry_zero) { TEST_F(KnownBlockVersionsTest, checkAndUpdate_newEntry_lowversion) {
EXPECT_TRUE(testobj.checkAndUpdateVersion(clientId, key, 0)); EXPECT_TRUE(testobj.checkAndUpdateVersion(clientId, key, 1));
} }
TEST_F(KnownBlockVersionsTest, checkAndUpdate_newEntry_nonzero) { TEST_F(KnownBlockVersionsTest, checkAndUpdate_newEntry_highversion) {
EXPECT_TRUE(testobj.checkAndUpdateVersion(clientId, key, 100)); EXPECT_TRUE(testobj.checkAndUpdateVersion(clientId, key, 100));
} }
TEST_F(KnownBlockVersionsTest, checkAndUpdate_existingEntry_equal_zero) { TEST_F(KnownBlockVersionsTest, checkAndUpdate_existingEntry_equal_lowversion) {
EXPECT_TRUE(testobj.checkAndUpdateVersion(clientId, key, 0)); EXPECT_TRUE(testobj.checkAndUpdateVersion(clientId, key, 1));
EXPECT_TRUE(testobj.checkAndUpdateVersion(clientId, key, 0)); EXPECT_TRUE(testobj.checkAndUpdateVersion(clientId, key, 1));
} }
TEST_F(KnownBlockVersionsTest, checkAndUpdate_existingEntry_equal_nonzero) { TEST_F(KnownBlockVersionsTest, checkAndUpdate_existingEntry_equal_highversion) {
EXPECT_TRUE(testobj.checkAndUpdateVersion(clientId, key, 100)); EXPECT_TRUE(testobj.checkAndUpdateVersion(clientId, key, 100));
EXPECT_TRUE(testobj.checkAndUpdateVersion(clientId, key, 100)); EXPECT_TRUE(testobj.checkAndUpdateVersion(clientId, key, 100));
} }
@ -114,11 +114,22 @@ TEST_F(KnownBlockVersionsTest, checkAndUpdate_twoEntriesDontInfluenceEachOther_d
EXPECT_VERSION_IS(100, &testobj, key, clientId2); EXPECT_VERSION_IS(100, &testobj, key, clientId2);
} }
TEST_F(KnownBlockVersionsTest, checkAndUpdate_allowsRollbackToSameClientWithSameVersionNumber) {
EXPECT_TRUE(testobj.checkAndUpdateVersion(clientId, key, 100));
EXPECT_TRUE(testobj.checkAndUpdateVersion(clientId, key, 100));
}
TEST_F(KnownBlockVersionsTest, checkAndUpdate_doesntAllowRollbackToOldClientWithSameVersionNumber) {
EXPECT_TRUE(testobj.checkAndUpdateVersion(clientId, key, 100));
EXPECT_TRUE(testobj.checkAndUpdateVersion(clientId2, key, 10));
EXPECT_FALSE(testobj.checkAndUpdateVersion(clientId, key, 100));
}
TEST_F(KnownBlockVersionsTest, saveAndLoad_empty) { TEST_F(KnownBlockVersionsTest, saveAndLoad_empty) {
TempFile stateFile(false); TempFile stateFile(false);
KnownBlockVersions(stateFile.path()); KnownBlockVersions(stateFile.path());
EXPECT_TRUE(KnownBlockVersions(stateFile.path()).checkAndUpdateVersion(clientId, key, 0)); EXPECT_TRUE(KnownBlockVersions(stateFile.path()).checkAndUpdateVersion(clientId, key, 1));
} }
TEST_F(KnownBlockVersionsTest, saveAndLoad_oneentry) { TEST_F(KnownBlockVersionsTest, saveAndLoad_oneentry) {
@ -143,3 +154,16 @@ TEST_F(KnownBlockVersionsTest, saveAndLoad_threeentries) {
EXPECT_VERSION_IS(50, &obj, key2, obj.myClientId()); EXPECT_VERSION_IS(50, &obj, key2, obj.myClientId());
EXPECT_VERSION_IS(150, &obj, key, clientId); EXPECT_VERSION_IS(150, &obj, key, clientId);
} }
TEST_F(KnownBlockVersionsTest, saveAndLoad_lastUpdateClientIdIsStored) {
{
KnownBlockVersions obj(stateFile.path());
EXPECT_TRUE(obj.checkAndUpdateVersion(clientId, key, 100));
EXPECT_TRUE(obj.checkAndUpdateVersion(clientId2, key, 10));
}
KnownBlockVersions obj(stateFile.path());
EXPECT_FALSE(obj.checkAndUpdateVersion(clientId, key, 100));
EXPECT_TRUE(obj.checkAndUpdateVersion(clientId2, key, 10));
EXPECT_TRUE(obj.checkAndUpdateVersion(clientId, key, 101));
}

View File

@ -46,6 +46,13 @@ public:
return result; return result;
} }
Data loadBlock(const blockstore::Key &key) {
auto block = blockStore->load(key).value();
Data result(block->size());
std::memcpy(result.data(), block->data(), data.size());
return result;
}
void modifyBlock(const blockstore::Key &key) { void modifyBlock(const blockstore::Key &key) {
auto block = blockStore->load(key).value(); auto block = blockStore->load(key).value();
uint64_t data = 5; uint64_t data = 5;
@ -58,11 +65,27 @@ public:
block->write(data.data(), 0, data.size()); block->write(data.data(), 0, data.size());
} }
void decreaseVersionNumber(const blockstore::Key &key) {
auto baseBlock = baseBlockStore->load(key).value();
uint64_t version = *(uint64_t*)((uint8_t*)baseBlock->data()+VersionCountingBlock::VERSION_HEADER_OFFSET);
ASSERT(version > 1, "Can't decrease the lowest allowed version number");
version -= 1;
baseBlock->write((char*)&version, VersionCountingBlock::VERSION_HEADER_OFFSET, sizeof(version));
}
void changeClientId(const blockstore::Key &key) {
auto baseBlock = baseBlockStore->load(key).value();
uint32_t clientId = *(uint32_t*)((uint8_t*)baseBlock->data()+VersionCountingBlock::CLIENTID_HEADER_OFFSET);
clientId += 1;
baseBlock->write((char*)&clientId, VersionCountingBlock::CLIENTID_HEADER_OFFSET, sizeof(clientId));
}
private: private:
DISALLOW_COPY_AND_ASSIGN(VersionCountingBlockStoreTest); DISALLOW_COPY_AND_ASSIGN(VersionCountingBlockStoreTest);
}; };
TEST_F(VersionCountingBlockStoreTest, DoesntAllowRollbacks) { // Test that a decreasing version number is not allowed
TEST_F(VersionCountingBlockStoreTest, RollbackPrevention_DoesntAllowDecreasingVersionNumberForSameClient_1) {
auto key = CreateBlockReturnKey(); auto key = CreateBlockReturnKey();
Data oldBaseBlock = loadBaseBlock(key); Data oldBaseBlock = loadBaseBlock(key);
modifyBlock(key); modifyBlock(key);
@ -70,6 +93,40 @@ TEST_F(VersionCountingBlockStoreTest, DoesntAllowRollbacks) {
EXPECT_EQ(boost::none, blockStore->load(key)); EXPECT_EQ(boost::none, blockStore->load(key));
} }
TEST_F(VersionCountingBlockStoreTest, RollbackPrevention_DoesntAllowDecreasingVersionNumberForSameClient_2) {
auto key = CreateBlockReturnKey();
// Increase the version number
modifyBlock(key);
// Decrease the version number again
decreaseVersionNumber(key);
EXPECT_EQ(boost::none, blockStore->load(key));
}
// Test that a different client doesn't need to have a higher version number (i.e. version numbers are per client).
TEST_F(VersionCountingBlockStoreTest, RollbackPrevention_DoesAllowDecreasingVersionNumberForDifferentClient) {
auto key = CreateBlockReturnKey();
// Increase the version number
modifyBlock(key);
// Fake a modification by a different client with lower version numbers
changeClientId(key);
decreaseVersionNumber(key);
EXPECT_NE(boost::none, blockStore->load(key));
}
// Test that it doesn't allow a rollback to the "newest" block of a client, when this block was superseded by a version of a different client
TEST_F(VersionCountingBlockStoreTest, RollbackPrevention_DoesntAllowSameVersionNumberForOldClient) {
auto key = CreateBlockReturnKey();
// Increase the version number
modifyBlock(key);
Data oldBaseBlock = loadBaseBlock(key);
// Fake a modification by a different client with lower version numbers
changeClientId(key);
loadBlock(key); // make the block store know about this other client's modification
// Rollback to old client
rollbackBaseBlock(key, oldBaseBlock);
EXPECT_EQ(boost::none, blockStore->load(key));
}
TEST_F(VersionCountingBlockStoreTest, PhysicalBlockSize_zerophysical) { TEST_F(VersionCountingBlockStoreTest, PhysicalBlockSize_zerophysical) {
EXPECT_EQ(0u, blockStore->blockSizeFromPhysicalBlockSize(0)); EXPECT_EQ(0u, blockStore->blockSizeFromPhysicalBlockSize(0));
} }