diff --git a/src/blockstore/implementations/versioncounting/ClientIdAndBlockKey.h b/src/blockstore/implementations/versioncounting/ClientIdAndBlockKey.h new file mode 100644 index 00000000..aa83ca18 --- /dev/null +++ b/src/blockstore/implementations/versioncounting/ClientIdAndBlockKey.h @@ -0,0 +1,33 @@ +#pragma once +#ifndef MESSMER_BLOCKSTORE_IMPLEMENTATIONS_VERSIONCOUNTING_CLIENTIDANDBLOCKKEY_H_ +#define MESSMER_BLOCKSTORE_IMPLEMENTATIONS_VERSIONCOUNTING_CLIENTIDANDBLOCKKEY_H_ + +#include + +namespace blockstore { + namespace versioncounting { + + struct ClientIdAndBlockKey { + uint32_t clientId; + Key blockKey; + }; + + } +} + +// Allow using it in std::unordered_set / std::unordered_map +namespace std { + template<> struct hash { + size_t operator()(const blockstore::versioncounting::ClientIdAndBlockKey &ref) const { + return std::hash()(ref.clientId) ^ std::hash()(ref.blockKey); + } + }; + + template<> struct equal_to { + size_t operator()(const blockstore::versioncounting::ClientIdAndBlockKey &lhs, const blockstore::versioncounting::ClientIdAndBlockKey &rhs) const { + return lhs.clientId == rhs.clientId && lhs.blockKey == rhs.blockKey; + } + }; +} + +#endif diff --git a/src/blockstore/implementations/versioncounting/KnownBlockVersions.cpp b/src/blockstore/implementations/versioncounting/KnownBlockVersions.cpp index 63003783..6c0d23f0 100644 --- a/src/blockstore/implementations/versioncounting/KnownBlockVersions.cpp +++ b/src/blockstore/implementations/versioncounting/KnownBlockVersions.cpp @@ -1,6 +1,8 @@ #include #include #include "KnownBlockVersions.h" +#include +#include namespace bf = boost::filesystem; using std::unordered_map; @@ -8,7 +10,10 @@ using std::pair; using std::string; using boost::optional; using boost::none; +using cpputils::Data; using cpputils::Random; +using cpputils::Serializer; +using cpputils::Deserializer; namespace blockstore { namespace versioncounting { @@ -31,10 +36,10 @@ KnownBlockVersions::~KnownBlockVersions() { } } -bool KnownBlockVersions::checkAndUpdateVersion(const Key &key, uint64_t version) { +bool KnownBlockVersions::checkAndUpdateVersion(uint32_t clientId, const Key &key, uint64_t version) { ASSERT(_valid, "Object not valid due to a std::move"); - uint64_t &found = _knownVersions[key]; // If the entry doesn't exist, this creates it with value 0. + uint64_t &found = _knownVersions[{clientId, key}]; // If the entry doesn't exist, this creates it with value 0. if (found > version) { return false; } @@ -44,73 +49,59 @@ bool KnownBlockVersions::checkAndUpdateVersion(const Key &key, uint64_t version) } void KnownBlockVersions::updateVersion(const Key &key, uint64_t version) { - if (!checkAndUpdateVersion(key, version)) { + if (!checkAndUpdateVersion(_myClientId, key, version)) { throw std::logic_error("Tried to decrease block version"); } } void KnownBlockVersions::_loadStateFile() { - std::ifstream file(_stateFilePath.native().c_str()); - if (!file.good()) { + optional file = Data::LoadFromFile(_stateFilePath); + if (file == none) { // File doesn't exist means we loaded empty state. Assign a random client id. _myClientId = *reinterpret_cast(Random::PseudoRandom().getFixedSize().data()); return; } - _checkHeader(&file); - file.read((char*)&_myClientId, sizeof(_myClientId)); - ASSERT(file.good(), "Error reading file"); - uint64_t numEntries; - file.read((char*)&numEntries, sizeof(numEntries)); - ASSERT(file.good(), "Error reading file"); + + Deserializer deserializer(&*file); + if (HEADER != deserializer.readString()) { + throw std::runtime_error("Invalid local state: Invalid integrity file header."); + } + _myClientId = deserializer.readUint32(); + uint64_t numEntries = deserializer.readUint64(); _knownVersions.clear(); _knownVersions.reserve(static_cast(1.2 * numEntries)); // Reserve for factor 1.2 more, so the file system doesn't immediately have to resize it on the first new block. for (uint64_t i = 0 ; i < numEntries; ++i) { - auto entry = _readEntry(&file); + auto entry = _readEntry(&deserializer); _knownVersions.insert(entry); } - _checkIsEof(&file); + deserializer.finished(); }; -void KnownBlockVersions::_checkHeader(std::ifstream *file) { - char actualHeader[HEADER.size()]; - file->read(actualHeader, HEADER.size()); - ASSERT(file->good(), "Error reading file"); - if (HEADER != string(actualHeader, HEADER.size())) { - throw std::runtime_error("Invalid local state: Invalid integrity file header."); - } -} +pair KnownBlockVersions::_readEntry(Deserializer *deserializer) { + uint32_t clientId = deserializer->readUint32(); + Key blockKey = deserializer->readFixedSizeData(); + uint64_t version = deserializer->readUint64(); -pair KnownBlockVersions::_readEntry(std::ifstream *file) { - pair result(Key::Null(), 0); - - file->read((char*)result.first.data(), result.first.BINARY_LENGTH); - ASSERT(file->good(), "Error reading file"); - file->read((char*)&result.second, sizeof(result.second)); - ASSERT(file->good(), "Error reading file"); - - return result; + return {{clientId, blockKey}, version}; }; -void KnownBlockVersions::_checkIsEof(std::ifstream *file) { - char dummy; - file->read(&dummy, sizeof(dummy)); - if (!file->eof()) { - throw std::runtime_error("There are more entries in the file than advertised"); - } -} - void KnownBlockVersions::_saveStateFile() const { - std::ofstream file(_stateFilePath.native().c_str()); - file.write(HEADER.c_str(), HEADER.size()); - file.write((char*)&_myClientId, sizeof(_myClientId)); uint64_t numEntries = _knownVersions.size(); - file.write((char*)&numEntries, sizeof(numEntries)); + + Serializer serializer(Serializer::StringSize(HEADER) + sizeof(uint32_t) + sizeof(uint64_t) + numEntries * (sizeof(uint32_t) + Key::BINARY_LENGTH + sizeof(uint64_t))); + serializer.writeString(HEADER); + serializer.writeUint32(_myClientId); + serializer.writeUint64(numEntries); + for (const auto &entry : _knownVersions) { - file.write((char*)entry.first.data(), entry.first.BINARY_LENGTH); - file.write((char*)&entry.second, sizeof(entry.second)); + serializer.writeUint32(entry.first.clientId); + serializer.writeFixedSizeData(entry.first.blockKey); + serializer.writeUint64(entry.second); } + + serializer.finished().StoreToFile(_stateFilePath); } uint32_t KnownBlockVersions::myClientId() const { diff --git a/src/blockstore/implementations/versioncounting/KnownBlockVersions.h b/src/blockstore/implementations/versioncounting/KnownBlockVersions.h index 4522af46..4e696f37 100644 --- a/src/blockstore/implementations/versioncounting/KnownBlockVersions.h +++ b/src/blockstore/implementations/versioncounting/KnownBlockVersions.h @@ -6,6 +6,8 @@ #include #include #include +#include "ClientIdAndBlockKey.h" +#include namespace blockstore { namespace versioncounting { @@ -17,14 +19,14 @@ namespace blockstore { ~KnownBlockVersions(); __attribute__((warn_unused_result)) - bool checkAndUpdateVersion(const Key &key, uint64_t version); + bool checkAndUpdateVersion(uint32_t clientId, const Key &key, uint64_t version); void updateVersion(const Key &key, uint64_t version); uint32_t myClientId() const; private: - std::unordered_map _knownVersions; + std::unordered_map _knownVersions; boost::filesystem::path _stateFilePath; uint32_t _myClientId; bool _valid; @@ -32,9 +34,7 @@ namespace blockstore { static const std::string HEADER; void _loadStateFile(); - static void _checkHeader(std::ifstream *file); - static std::pair _readEntry(std::ifstream *file); - static void _checkIsEof(std::ifstream *file); + static std::pair _readEntry(cpputils::Deserializer *deserializer); void _saveStateFile() const; DISALLOW_COPY_AND_ASSIGN(KnownBlockVersions); diff --git a/src/blockstore/implementations/versioncounting/VersionCountingBlock.h b/src/blockstore/implementations/versioncounting/VersionCountingBlock.h index b6fa00f2..726bdd1b 100644 --- a/src/blockstore/implementations/versioncounting/VersionCountingBlock.h +++ b/src/blockstore/implementations/versioncounting/VersionCountingBlock.h @@ -54,7 +54,8 @@ private: static cpputils::Data _prependHeaderToData(uint32_t myClientId, uint64_t version, cpputils::Data data); static void _checkFormatHeader(const cpputils::Data &data); static uint64_t _readVersion(const cpputils::Data &data); - static bool _versionIsNondecreasing(const Key &key, uint64_t version, KnownBlockVersions *knownBlockVersions); + static uint32_t _readClientId(const cpputils::Data &data); + static bool _versionIsNondecreasing(uint32_t clientId, const Key &key, uint64_t version, KnownBlockVersions *knownBlockVersions); // This header is prepended to blocks to allow future versions to have compatibility. static constexpr uint16_t FORMAT_VERSION_HEADER = 0; @@ -93,8 +94,9 @@ inline boost::optional> VersionCounti cpputils::Data data(baseBlock->size()); std::memcpy(data.data(), baseBlock->data(), data.size()); _checkFormatHeader(data); + uint32_t lastClientId = _readClientId(data); uint64_t version = _readVersion(data); - if(!_versionIsNondecreasing(baseBlock->key(), version, knownBlockVersions)) { + if(!_versionIsNondecreasing(lastClientId, baseBlock->key(), version, knownBlockVersions)) { //The stored key in the block data is incorrect - an attacker might have exchanged the contents with the encrypted data from a different block cpputils::logging::LOG(cpputils::logging::WARN) << "Decrypting block " << baseBlock->key().ToString() << " failed due to wrong version number. Was the block rolled back by an attacker?"; return boost::none; @@ -108,14 +110,20 @@ inline void VersionCountingBlock::_checkFormatHeader(const cpputils::Data &data) } } +inline uint32_t VersionCountingBlock::_readClientId(const cpputils::Data &data) { + uint32_t clientId; + std::memcpy(&clientId, data.dataOffset(sizeof(FORMAT_VERSION_HEADER)), sizeof(clientId)); + return clientId; +} + inline uint64_t VersionCountingBlock::_readVersion(const cpputils::Data &data) { uint64_t version; std::memcpy(&version, data.dataOffset(sizeof(FORMAT_VERSION_HEADER) + sizeof(uint32_t)), sizeof(version)); return version; } -inline bool VersionCountingBlock::_versionIsNondecreasing(const Key &key, uint64_t version, KnownBlockVersions *knownBlockVersions) { - return knownBlockVersions->checkAndUpdateVersion(key, version); +inline bool VersionCountingBlock::_versionIsNondecreasing(uint32_t clientId, const Key &key, uint64_t version, KnownBlockVersions *knownBlockVersions) { + return knownBlockVersions->checkAndUpdateVersion(clientId, key, version); } inline VersionCountingBlock::VersionCountingBlock(cpputils::unique_ref baseBlock, cpputils::Data dataWithHeader, uint64_t version, KnownBlockVersions *knownBlockVersions) diff --git a/src/cpp-utils/data/Deserializer.h b/src/cpp-utils/data/Deserializer.h index a51a92a6..3533b75c 100644 --- a/src/cpp-utils/data/Deserializer.h +++ b/src/cpp-utils/data/Deserializer.h @@ -5,6 +5,7 @@ #include "Data.h" #include "../macros.h" #include "../assert/assert.h" +#include "FixedSizeData.h" namespace cpputils { class Deserializer final { @@ -21,6 +22,7 @@ namespace cpputils { int64_t readInt64(); std::string readString(); Data readData(); + template FixedSizeData readFixedSizeData(); Data readTailData(); void finished(); @@ -28,6 +30,7 @@ namespace cpputils { private: template DataType _read(); Data _readData(size_t size); + void _readData(void *target, size_t size); size_t _pos; const Data *_source; @@ -96,8 +99,19 @@ namespace cpputils { inline Data Deserializer::_readData(size_t size) { Data result(size); - std::memcpy(static_cast(result.data()), static_cast(_source->dataOffset(_pos)), size); + _readData(result.data(), size); + return result; + } + + inline void Deserializer::_readData(void *target, size_t size) { + std::memcpy(static_cast(target), static_cast(_source->dataOffset(_pos)), size); _pos += size; + } + + template + inline FixedSizeData Deserializer::readFixedSizeData() { + FixedSizeData result(FixedSizeData::Null()); + _readData(result.data(), SIZE); return result; } diff --git a/src/cpp-utils/data/Serializer.h b/src/cpp-utils/data/Serializer.h index f1f8b42c..dbfef896 100644 --- a/src/cpp-utils/data/Serializer.h +++ b/src/cpp-utils/data/Serializer.h @@ -3,6 +3,7 @@ #define MESSMER_CPPUTILS_DATA_SERIALIZER_H #include "Data.h" +#include "FixedSizeData.h" #include "../macros.h" #include "../assert/assert.h" #include @@ -24,6 +25,7 @@ namespace cpputils { void writeInt64(int64_t value); void writeString(const std::string &value); void writeData(const Data &value); + template void writeFixedSizeData(const FixedSizeData &value); // Write the data as last element when serializing. // It does not store a data size but limits the size by the size of the serialization result @@ -36,7 +38,7 @@ namespace cpputils { private: template void _write(DataType obj); - void _writeData(const Data &value); + void _writeData(const void *data, size_t count); size_t _pos; Data _result; @@ -91,33 +93,33 @@ namespace cpputils { inline void Serializer::writeData(const Data &data) { writeUint64(data.size()); - _writeData(data); + _writeData(data.data(), data.size()); } inline size_t Serializer::DataSize(const Data &data) { return sizeof(uint64_t) + data.size(); } - inline void Serializer::writeTailData(const Data &data) { - ASSERT(_pos + data.size() == _result.size(), "Not enough data given to write until the end of the stream"); - _writeData(data); + template + inline void Serializer::writeFixedSizeData(const FixedSizeData &data) { + _writeData(data.data(), SIZE); } - inline void Serializer::_writeData(const Data &data) { - if (_pos + data.size() > _result.size()) { + inline void Serializer::writeTailData(const Data &data) { + ASSERT(_pos + data.size() == _result.size(), "Not enough data given to write until the end of the stream"); + _writeData(data.data(), data.size()); + } + + inline void Serializer::_writeData(const void *data, size_t count) { + if (_pos + count > _result.size()) { throw std::runtime_error("Serialization failed - size overflow"); } - std::memcpy(static_cast(_result.dataOffset(_pos)), static_cast(data.data()), data.size()); - _pos += data.size(); + std::memcpy(static_cast(_result.dataOffset(_pos)), static_cast(data), count); + _pos += count; } inline void Serializer::writeString(const std::string &value) { - size_t size = value.size() + 1; // +1 for the nullbyte - if (_pos + size > _result.size()) { - throw std::runtime_error("Serialization failed - size overflow"); - } - std::memcpy(static_cast(_result.dataOffset(_pos)), value.c_str(), size); - _pos += size; + _writeData(value.c_str(), value.size() + 1); // +1 for the nullbyte } inline size_t Serializer::StringSize(const std::string &value) { diff --git a/test/blockstore/implementations/versioncounting/KnownBlockVersionsTest.cpp b/test/blockstore/implementations/versioncounting/KnownBlockVersionsTest.cpp index 372d2d64..fefd9279 100644 --- a/test/blockstore/implementations/versioncounting/KnownBlockVersionsTest.cpp +++ b/test/blockstore/implementations/versioncounting/KnownBlockVersionsTest.cpp @@ -11,9 +11,16 @@ public: blockstore::Key key = blockstore::Key::FromString("1491BB4932A389EE14BC7090AC772972"); blockstore::Key key2 = blockstore::Key::FromString("C772972491BB4932A1389EE14BC7090A"); + uint32_t clientId = 0x12345678; + uint32_t clientId2 = 0x23456789; TempFile stateFile; KnownBlockVersions testobj; + + void EXPECT_VERSION_IS(uint64_t version, KnownBlockVersions *testobj, blockstore::Key &key, uint32_t clientId) { + EXPECT_FALSE(testobj->checkAndUpdateVersion(clientId, key, version-1)); + EXPECT_TRUE(testobj->checkAndUpdateVersion(clientId, key, version)); + } }; TEST_F(KnownBlockVersionsTest, update_newEntry_zero) { @@ -46,80 +53,93 @@ TEST_F(KnownBlockVersionsTest, update_existingEntry_invalid) { ); } +TEST_F(KnownBlockVersionsTest, update_updatesOwnClientId) { + testobj.updateVersion(key, 100); + EXPECT_VERSION_IS(100, &testobj, key, testobj.myClientId()); +} + TEST_F(KnownBlockVersionsTest, checkAndUpdate_newEntry_zero) { - EXPECT_TRUE(testobj.checkAndUpdateVersion(key, 0)); + EXPECT_TRUE(testobj.checkAndUpdateVersion(clientId, key, 0)); } TEST_F(KnownBlockVersionsTest, checkAndUpdate_newEntry_nonzero) { - EXPECT_TRUE(testobj.checkAndUpdateVersion(key, 100)); + EXPECT_TRUE(testobj.checkAndUpdateVersion(clientId, key, 100)); } TEST_F(KnownBlockVersionsTest, checkAndUpdate_existingEntry_equal_zero) { - EXPECT_TRUE(testobj.checkAndUpdateVersion(key, 0)); - EXPECT_TRUE(testobj.checkAndUpdateVersion(key, 0)); + EXPECT_TRUE(testobj.checkAndUpdateVersion(clientId, key, 0)); + EXPECT_TRUE(testobj.checkAndUpdateVersion(clientId, key, 0)); } TEST_F(KnownBlockVersionsTest, checkAndUpdate_existingEntry_equal_nonzero) { - EXPECT_TRUE(testobj.checkAndUpdateVersion(key, 100)); - EXPECT_TRUE(testobj.checkAndUpdateVersion(key, 100)); + EXPECT_TRUE(testobj.checkAndUpdateVersion(clientId, key, 100)); + EXPECT_TRUE(testobj.checkAndUpdateVersion(clientId, key, 100)); } TEST_F(KnownBlockVersionsTest, checkAndUpdate_existingEntry_nonequal) { - EXPECT_TRUE(testobj.checkAndUpdateVersion(key, 100)); - EXPECT_TRUE(testobj.checkAndUpdateVersion(key, 101)); + EXPECT_TRUE(testobj.checkAndUpdateVersion(clientId, key, 100)); + EXPECT_TRUE(testobj.checkAndUpdateVersion(clientId, key, 101)); } TEST_F(KnownBlockVersionsTest, checkAndUpdate_existingEntry_invalid) { - EXPECT_TRUE(testobj.checkAndUpdateVersion(key, 100)); - EXPECT_FALSE(testobj.checkAndUpdateVersion(key, 99)); + EXPECT_TRUE(testobj.checkAndUpdateVersion(clientId, key, 100)); + EXPECT_FALSE(testobj.checkAndUpdateVersion(clientId, key, 99)); } TEST_F(KnownBlockVersionsTest, checkAndUpdate_existingEntry_invalidDoesntModifyEntry) { - EXPECT_TRUE(testobj.checkAndUpdateVersion(key, 100)); - EXPECT_FALSE(testobj.checkAndUpdateVersion(key, 99)); + EXPECT_TRUE(testobj.checkAndUpdateVersion(clientId, key, 100)); + EXPECT_FALSE(testobj.checkAndUpdateVersion(clientId, key, 99)); - EXPECT_FALSE(testobj.checkAndUpdateVersion(key, 99)); - EXPECT_TRUE(testobj.checkAndUpdateVersion(key, 100)); + EXPECT_VERSION_IS(100, &testobj, key, clientId); } -TEST_F(KnownBlockVersionsTest, checkAndUpdate_twoEntriesDontInfluenceEachOther) { - testobj.updateVersion(key, 100); - testobj.updateVersion(key2, 100); +TEST_F(KnownBlockVersionsTest, checkAndUpdate_twoEntriesDontInfluenceEachOther_differentKeys) { + // Setup + EXPECT_TRUE(testobj.checkAndUpdateVersion(clientId, key, 100)); + EXPECT_TRUE(testobj.checkAndUpdateVersion(clientId, key2, 100)); + EXPECT_TRUE(testobj.checkAndUpdateVersion(clientId, key, 150)); - testobj.updateVersion(key, 150); + // Checks + EXPECT_VERSION_IS(150, &testobj, key, clientId); + EXPECT_VERSION_IS(100, &testobj, key2, clientId); +} - EXPECT_FALSE(testobj.checkAndUpdateVersion(key, 149)); - EXPECT_TRUE(testobj.checkAndUpdateVersion(key, 150)); - EXPECT_FALSE(testobj.checkAndUpdateVersion(key2, 99)); - EXPECT_TRUE(testobj.checkAndUpdateVersion(key2, 100)); +TEST_F(KnownBlockVersionsTest, checkAndUpdate_twoEntriesDontInfluenceEachOther_differentClientIds) { + // Setup + EXPECT_TRUE(testobj.checkAndUpdateVersion(clientId, key, 100)); + EXPECT_TRUE(testobj.checkAndUpdateVersion(clientId2, key, 100)); + EXPECT_TRUE(testobj.checkAndUpdateVersion(clientId, key, 150)); + + EXPECT_VERSION_IS(150, &testobj, key, clientId); + EXPECT_VERSION_IS(100, &testobj, key, clientId2); } TEST_F(KnownBlockVersionsTest, saveAndLoad_empty) { TempFile stateFile(false); KnownBlockVersions(stateFile.path()); - EXPECT_TRUE(KnownBlockVersions(stateFile.path()).checkAndUpdateVersion(key, 0)); + EXPECT_TRUE(KnownBlockVersions(stateFile.path()).checkAndUpdateVersion(clientId, key, 0)); } TEST_F(KnownBlockVersionsTest, saveAndLoad_oneentry) { TempFile stateFile(false); - KnownBlockVersions(stateFile.path()).updateVersion(key, 100); + EXPECT_TRUE(KnownBlockVersions(stateFile.path()).checkAndUpdateVersion(clientId, key, 100)); - EXPECT_FALSE(KnownBlockVersions(stateFile.path()).checkAndUpdateVersion(key, 99)); - EXPECT_TRUE(KnownBlockVersions(stateFile.path()).checkAndUpdateVersion(key, 100)); + KnownBlockVersions obj(stateFile.path()); + EXPECT_VERSION_IS(100, &obj, key, clientId); } -TEST_F(KnownBlockVersionsTest, saveAndLoad_twoentries) { +TEST_F(KnownBlockVersionsTest, saveAndLoad_threeentries) { TempFile stateFile(false); { KnownBlockVersions obj(stateFile.path()); obj.updateVersion(key, 100); obj.updateVersion(key2, 50); + EXPECT_TRUE(obj.checkAndUpdateVersion(clientId, key, 150)); } KnownBlockVersions obj(stateFile.path()); - EXPECT_FALSE(obj.checkAndUpdateVersion(key, 99)); - EXPECT_TRUE(obj.checkAndUpdateVersion(key, 100)); - EXPECT_FALSE(obj.checkAndUpdateVersion(key2, 49)); - EXPECT_TRUE(obj.checkAndUpdateVersion(key2, 50)); + EXPECT_VERSION_IS(100, &obj, key, obj.myClientId()); + EXPECT_VERSION_IS(50, &obj, key2, obj.myClientId()); + EXPECT_VERSION_IS(150, &obj, key, clientId); }