Implement SignalCatcher
This commit is contained in:
parent
97e0a7e031
commit
5386f5b0c9
@ -479,7 +479,7 @@ jobs:
|
||||
OMP_NUM_THREADS: "1"
|
||||
CXXFLAGS: "-O2 -fsanitize=thread -fno-omit-frame-pointer"
|
||||
BUILD_TYPE: "Debug"
|
||||
GTEST_ARGS: "--gtest_filter=-LoggingTest.LoggingAlsoWorksAfterFork:AssertTest_DebugBuild.*:CliTest_Setup.*:CliTest_IntegrityCheck.*:*/CliTest_WrongEnvironment.*:CliTest_Unmount.*"
|
||||
GTEST_ARGS: "--gtest_filter=-LoggingTest.LoggingAlsoWorksAfterFork:AssertTest_DebugBuild.*:SignalCatcherTest.*_thenDies:CliTest_Setup.*:CliTest_IntegrityCheck.*:*/CliTest_WrongEnvironment.*:CliTest_Unmount.*"
|
||||
CMAKE_FLAGS: ""
|
||||
RUN_TESTS: true
|
||||
clang_tidy:
|
||||
|
@ -11,6 +11,7 @@ set(SOURCES
|
||||
crypto/hash/Hash.cpp
|
||||
process/daemonize.cpp
|
||||
process/subprocess.cpp
|
||||
process/SignalCatcher.cpp
|
||||
tempfile/TempFile.cpp
|
||||
tempfile/TempDir.cpp
|
||||
network/HttpClient.cpp
|
||||
@ -26,6 +27,7 @@ set(SOURCES
|
||||
thread/ThreadSystem.cpp
|
||||
thread/debugging_nonwindows.cpp
|
||||
thread/debugging_windows.cpp
|
||||
thread/LeftRight.cpp
|
||||
random/Random.cpp
|
||||
random/RandomGeneratorThread.cpp
|
||||
random/OSRandomGenerator.cpp
|
||||
|
239
src/cpp-utils/process/SignalCatcher.cpp
Normal file
239
src/cpp-utils/process/SignalCatcher.cpp
Normal file
@ -0,0 +1,239 @@
|
||||
#include "SignalCatcher.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <stdexcept>
|
||||
#include <vector>
|
||||
#include <cpp-utils/assert/assert.h>
|
||||
#include <cpp-utils/thread/LeftRight.h>
|
||||
|
||||
using std::make_unique;
|
||||
using std::vector;
|
||||
using std::pair;
|
||||
|
||||
namespace cpputils {
|
||||
|
||||
namespace {
|
||||
|
||||
void got_signal(int signal);
|
||||
|
||||
using SignalHandlerFunction = void(int);
|
||||
|
||||
constexpr SignalHandlerFunction* signal_catcher_function = &got_signal;
|
||||
|
||||
#if !defined(_MSC_VER)
|
||||
|
||||
class SignalHandlerRAII final {
|
||||
public:
|
||||
SignalHandlerRAII(int signal)
|
||||
: _old_handler(), _signal(signal) {
|
||||
struct sigaction new_signal_handler{};
|
||||
std::memset(&new_signal_handler, 0, sizeof(new_signal_handler));
|
||||
new_signal_handler.sa_handler = signal_catcher_function; // NOLINT(cppcoreguidelines-pro-type-union-access)
|
||||
new_signal_handler.sa_flags = SA_RESTART;
|
||||
int error = sigfillset(&new_signal_handler.sa_mask); // block all signals while signal handler is running
|
||||
if (0 != error) {
|
||||
throw std::runtime_error("Error calling sigfillset. Errno: " + std::to_string(errno));
|
||||
}
|
||||
_sigaction(_signal, &new_signal_handler, &_old_handler);
|
||||
}
|
||||
|
||||
~SignalHandlerRAII() {
|
||||
// reset to old signal handler
|
||||
struct sigaction removed_handler{};
|
||||
_sigaction(_signal, &_old_handler, &removed_handler);
|
||||
if (signal_catcher_function != removed_handler.sa_handler) { // NOLINT(cppcoreguidelines-pro-type-union-access)
|
||||
ASSERT(false, "Signal handler screwup. We just replaced a signal handler that wasn't our own.");
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
static void _sigaction(int signal, struct sigaction *new_handler, struct sigaction *old_handler) {
|
||||
int error = sigaction(signal, new_handler, old_handler);
|
||||
if (0 != error) {
|
||||
throw std::runtime_error("Error calling sigaction. Errno: " + std::to_string(errno));
|
||||
}
|
||||
}
|
||||
|
||||
struct sigaction _old_handler;
|
||||
int _signal;
|
||||
|
||||
DISALLOW_COPY_AND_ASSIGN(SignalHandlerRAII);
|
||||
};
|
||||
|
||||
#else
|
||||
|
||||
class SignalHandlerRAII final {
|
||||
public:
|
||||
SignalHandlerRAII(int signal)
|
||||
: _old_handler(nullptr), _signal(signal) {
|
||||
_old_handler = ::signal(_signal, signal_catcher_function);
|
||||
if (_old_handler == SIG_ERR) {
|
||||
throw std::logic_error("Error calling signal(). Errno: " + std::to_string(errno));
|
||||
}
|
||||
}
|
||||
|
||||
~SignalHandlerRAII() {
|
||||
// reset to old signal handler
|
||||
SignalHandlerFunction* error = ::signal(_signal, _old_handler);
|
||||
if (error == SIG_ERR) {
|
||||
throw std::logic_error("Error resetting signal(). Errno: " + std::to_string(errno));
|
||||
}
|
||||
if (error != signal_catcher_function) {
|
||||
throw std::logic_error("Signal handler screwup. We just replaced a signal handler that wasn't our own.");
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
SignalHandlerFunction* _old_handler;
|
||||
int _signal;
|
||||
|
||||
DISALLOW_COPY_AND_ASSIGN(SignalHandlerRAII);
|
||||
};
|
||||
|
||||
// The Linux default behavior (i.e. the way we set up sigaction above) is to disable signal processing while the signal
|
||||
// handler is running and to re-enable the custom handler once processing is finished. The Windows default behavior
|
||||
// is to reset the handler to the default handler directly before executing the handler, i.e. the handler will only
|
||||
// be called once. To fix this, we use this RAII class on Windows, of which an instance will live in the signal handler.
|
||||
// In its constructor, it disables signal handling, and in its destructor it re-sets the custom handler.
|
||||
// This is not perfect since there is a small time window between calling the signal handler and calling the constructor
|
||||
// of this class, but it's the best we can do.
|
||||
class SignalHandlerRunningRAII final {
|
||||
public:
|
||||
SignalHandlerRunningRAII(int signal) : _signal(signal) {
|
||||
SignalHandlerFunction* old_handler = ::signal(_signal, SIG_IGN);
|
||||
if (old_handler == SIG_ERR) {
|
||||
throw std::logic_error("Error disabling signal(). Errno: " + std::to_string(errno));
|
||||
}
|
||||
if (old_handler != SIG_DFL) {
|
||||
// see description above, we expected the signal handler to be reset.
|
||||
throw std::logic_error("We expected windows to reset the signal handler but it didn't. Did the Windows API change?");
|
||||
}
|
||||
}
|
||||
|
||||
~SignalHandlerRunningRAII() {
|
||||
SignalHandlerFunction* old_handler = ::signal(_signal, signal_catcher_function);
|
||||
if (old_handler == SIG_ERR) {
|
||||
throw std::logic_error("Error resetting signal() after calling handler. Errno: " + std::to_string(errno));
|
||||
}
|
||||
if (old_handler != SIG_IGN) {
|
||||
throw std::logic_error("Weird, we just did set the signal handler to ignore. Why isn't it still ignore?");
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
int _signal;
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
class SignalCatcherRegistry final {
|
||||
public:
|
||||
void add(int signal, std::atomic<bool>* signal_occurred_flag) {
|
||||
_catchers.write([&] (auto& catchers) {
|
||||
catchers.emplace_back(signal, signal_occurred_flag);
|
||||
});
|
||||
}
|
||||
|
||||
void remove(std::atomic<bool>* catcher) {
|
||||
_catchers.write([&] (auto& catchers) {
|
||||
auto found = std::find_if(catchers.rbegin(), catchers.rend(), [catcher] (const auto& entry) {return entry.second == catcher;});
|
||||
ASSERT(found != catchers.rend(), "Signal handler not found");
|
||||
catchers.erase(--found.base()); // decrement because it's a reverse iterator
|
||||
});
|
||||
}
|
||||
|
||||
~SignalCatcherRegistry() {
|
||||
ASSERT(0 == _catchers.read([] (auto& catchers) {return catchers.size();}), "Leftover signal catchers that weren't destroyed");
|
||||
}
|
||||
|
||||
std::atomic<bool>* find(int signal) {
|
||||
// this is called in a signal handler and must be mutex-free.
|
||||
return _catchers.read([&](auto& catchers) {
|
||||
auto found = std::find_if(catchers.rbegin(), catchers.rend(), [signal](const auto& entry) {return entry.first == signal; });
|
||||
ASSERT(found != catchers.rend(), "Signal handler not found");
|
||||
return found->second;
|
||||
});
|
||||
}
|
||||
|
||||
static SignalCatcherRegistry& singleton() {
|
||||
static SignalCatcherRegistry _singleton;
|
||||
return _singleton;
|
||||
}
|
||||
|
||||
private:
|
||||
SignalCatcherRegistry() = default;
|
||||
|
||||
// using LeftRight datastructure because we need mutex-free reads. Signal handlers can't use mutexes.
|
||||
LeftRight<vector<pair<int, std::atomic<bool>*>>> _catchers;
|
||||
|
||||
DISALLOW_COPY_AND_ASSIGN(SignalCatcherRegistry);
|
||||
};
|
||||
|
||||
void got_signal(int signal) {
|
||||
#if defined(_MSC_VER)
|
||||
// Only needed on Windows, Linux does this by default. See comment on SignalHandlerRunningRAII class.
|
||||
SignalHandlerRunningRAII disable_signal_processing_while_handler_running_and_reset_handler_afterwards(signal);
|
||||
#endif
|
||||
std::atomic<bool>* catcher = SignalCatcherRegistry::singleton().find(signal);
|
||||
*catcher = true;
|
||||
}
|
||||
|
||||
class SignalCatcherRegisterer final {
|
||||
public:
|
||||
SignalCatcherRegisterer(int signal, std::atomic<bool>* catcher)
|
||||
: _catcher(catcher) {
|
||||
SignalCatcherRegistry::singleton().add(signal, _catcher);
|
||||
}
|
||||
|
||||
~SignalCatcherRegisterer() {
|
||||
SignalCatcherRegistry::singleton().remove(_catcher);
|
||||
}
|
||||
|
||||
private:
|
||||
std::atomic<bool>* _catcher;
|
||||
|
||||
DISALLOW_COPY_AND_ASSIGN(SignalCatcherRegisterer);
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
namespace details {
|
||||
|
||||
class SignalCatcherImpl final {
|
||||
public:
|
||||
SignalCatcherImpl(int signal, std::atomic<bool>* signal_occurred_flag)
|
||||
: _registerer(signal, signal_occurred_flag)
|
||||
, _handler(signal) {
|
||||
// note: the order of the members ensures that:
|
||||
// - when registering the signal handler fails, the registerer will be destroyed, unregistering the signal_occurred_flag,
|
||||
// i.e. there is no leak.
|
||||
|
||||
// Allow only the set of signals that is supported on all platforms, see for Windows: https://docs.microsoft.com/en-us/cpp/c-runtime-library/reference/signal?view=vs-2017
|
||||
ASSERT(signal == SIGABRT || signal == SIGFPE || signal == SIGILL || signal == SIGINT || signal == SIGSEGV || signal == SIGTERM, "Unknown signal");
|
||||
}
|
||||
private:
|
||||
SignalCatcherRegisterer _registerer;
|
||||
SignalHandlerRAII _handler;
|
||||
|
||||
DISALLOW_COPY_AND_ASSIGN(SignalCatcherImpl);
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
SignalCatcher::SignalCatcher(std::initializer_list<int> signals)
|
||||
: _signal_occurred(false)
|
||||
, _impls() {
|
||||
// note: the order of the members ensures that:
|
||||
// - when the signal handler is set, the _signal_occurred flag is already initialized.
|
||||
// - the _signal_occurred flag will not be destructed as long as the signal handler might be called (i.e. as long as _impls lives)
|
||||
|
||||
_impls.reserve(signals.size());
|
||||
for (int signal : signals) {
|
||||
_impls.emplace_back(make_unique<details::SignalCatcherImpl>(signal, &_signal_occurred));
|
||||
}
|
||||
}
|
||||
|
||||
SignalCatcher::~SignalCatcher() {}
|
||||
|
||||
}
|
42
src/cpp-utils/process/SignalCatcher.h
Normal file
42
src/cpp-utils/process/SignalCatcher.h
Normal file
@ -0,0 +1,42 @@
|
||||
#pragma once
|
||||
#ifndef MESSMER_CPPUTILS_PROCESS_SIGNALCATCHER_H_
|
||||
#define MESSMER_CPPUTILS_PROCESS_SIGNALCATCHER_H_
|
||||
|
||||
#include <cpp-utils/macros.h>
|
||||
#include <atomic>
|
||||
#include <csignal>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
namespace cpputils {
|
||||
|
||||
namespace details {
|
||||
class SignalCatcherImpl;
|
||||
}
|
||||
|
||||
/*
|
||||
* While an instance of this class is in scope, the specified signal (e.g. SIGINT)
|
||||
* is caught and doesn't exit the application. You can poll if the signal occurred.
|
||||
*/
|
||||
class SignalCatcher final {
|
||||
public:
|
||||
SignalCatcher(): SignalCatcher({SIGINT, SIGTERM}) {}
|
||||
|
||||
SignalCatcher(std::initializer_list<int> signals);
|
||||
~SignalCatcher();
|
||||
|
||||
bool signal_occurred() const {
|
||||
return _signal_occurred;
|
||||
}
|
||||
|
||||
private:
|
||||
// note: _signal_occurred must be initialized before _impl because _impl might use it
|
||||
std::atomic<bool> _signal_occurred;
|
||||
std::vector<std::unique_ptr<details::SignalCatcherImpl>> _impls;
|
||||
|
||||
DISALLOW_COPY_AND_ASSIGN(SignalCatcher);
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
#endif
|
@ -3,6 +3,7 @@
|
||||
#include "DirBlob.h"
|
||||
#include "SymlinkBlob.h"
|
||||
#include <cryfs/config/CryConfigFile.h>
|
||||
#include <cpp-utils/process/SignalCatcher.h>
|
||||
|
||||
using cpputils::unique_ref;
|
||||
using cpputils::make_unique_ref;
|
||||
|
@ -16,6 +16,7 @@ set(SOURCES
|
||||
process/daemonize_include_test.cpp
|
||||
process/subprocess_include_test.cpp
|
||||
process/SubprocessTest.cpp
|
||||
process/SignalCatcherTest.cpp
|
||||
tempfile/TempFileTest.cpp
|
||||
tempfile/TempFileIncludeTest.cpp
|
||||
tempfile/TempDirIncludeTest.cpp
|
||||
|
193
test/cpp-utils/process/SignalCatcherTest.cpp
Normal file
193
test/cpp-utils/process/SignalCatcherTest.cpp
Normal file
@ -0,0 +1,193 @@
|
||||
#include <cpp-utils/process/SignalCatcher.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include <csignal>
|
||||
|
||||
using cpputils::SignalCatcher;
|
||||
|
||||
namespace {
|
||||
void raise_signal(int signal) {
|
||||
int error = ::raise(signal);
|
||||
if (error != 0) {
|
||||
throw std::runtime_error("Error raising signal");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(SignalCatcherTest, givenNoSignalCatcher_whenRaisingSigint_thenDies) {
|
||||
EXPECT_DEATH(
|
||||
raise_signal(SIGINT),
|
||||
""
|
||||
);
|
||||
}
|
||||
|
||||
TEST(SignalCatcherTest, givenNoSignalCatcher_whenRaisingSigterm_thenDies) {
|
||||
EXPECT_DEATH(
|
||||
raise_signal(SIGTERM),
|
||||
""
|
||||
);
|
||||
}
|
||||
|
||||
TEST(SignalCatcherTest, givenSigIntCatcher_whenRaisingSigInt_thenCatches) {
|
||||
SignalCatcher catcher({SIGINT});
|
||||
|
||||
EXPECT_FALSE(catcher.signal_occurred());
|
||||
raise_signal(SIGINT);
|
||||
EXPECT_TRUE(catcher.signal_occurred());
|
||||
|
||||
// raise again
|
||||
raise_signal(SIGINT);
|
||||
EXPECT_TRUE(catcher.signal_occurred());
|
||||
}
|
||||
|
||||
TEST(SignalCatcherTest, givenSigTermCatcher_whenRaisingSigTerm_thenCatches) {
|
||||
SignalCatcher catcher({SIGTERM});
|
||||
|
||||
EXPECT_FALSE(catcher.signal_occurred());
|
||||
raise_signal(SIGTERM);
|
||||
EXPECT_TRUE(catcher.signal_occurred());
|
||||
|
||||
// raise again
|
||||
raise_signal(SIGTERM);
|
||||
EXPECT_TRUE(catcher.signal_occurred());
|
||||
}
|
||||
|
||||
TEST(SignalCatcherTest, givenSigIntAndSigTermCatcher_whenRaisingSigInt_thenCatches) {
|
||||
SignalCatcher catcher({SIGINT, SIGTERM});
|
||||
|
||||
EXPECT_FALSE(catcher.signal_occurred());
|
||||
raise_signal(SIGINT);
|
||||
EXPECT_TRUE(catcher.signal_occurred());
|
||||
|
||||
// raise again
|
||||
raise_signal(SIGINT);
|
||||
EXPECT_TRUE(catcher.signal_occurred());
|
||||
}
|
||||
|
||||
TEST(SignalCatcherTest, givenSigIntAndSigTermCatcher_whenRaisingSigTerm_thenCatches) {
|
||||
SignalCatcher catcher({SIGINT, SIGTERM});
|
||||
|
||||
EXPECT_FALSE(catcher.signal_occurred());
|
||||
raise_signal(SIGTERM);
|
||||
EXPECT_TRUE(catcher.signal_occurred());
|
||||
|
||||
// raise again
|
||||
raise_signal(SIGTERM);
|
||||
EXPECT_TRUE(catcher.signal_occurred());
|
||||
}
|
||||
|
||||
TEST(SignalCatcherTest, givenSigIntAndSigTermCatcher_whenRaisingSigIntAndSigTerm_thenCatches) {
|
||||
SignalCatcher catcher({SIGINT, SIGTERM});
|
||||
|
||||
EXPECT_FALSE(catcher.signal_occurred());
|
||||
raise_signal(SIGTERM);
|
||||
EXPECT_TRUE(catcher.signal_occurred());
|
||||
|
||||
raise_signal(SIGINT);
|
||||
EXPECT_TRUE(catcher.signal_occurred());
|
||||
}
|
||||
|
||||
TEST(SignalCatcherTest, givenSigIntCatcherAndSigTermCatcher_whenRaisingSignalsInOrder_thenCorrectCatcherCatches) {
|
||||
SignalCatcher sigintCatcher({SIGINT});
|
||||
SignalCatcher sigtermCatcher({SIGTERM});
|
||||
|
||||
EXPECT_FALSE(sigintCatcher.signal_occurred());
|
||||
raise_signal(SIGINT);
|
||||
EXPECT_TRUE(sigintCatcher.signal_occurred());
|
||||
|
||||
EXPECT_FALSE(sigtermCatcher.signal_occurred());
|
||||
raise_signal(SIGTERM);
|
||||
EXPECT_TRUE(sigtermCatcher.signal_occurred());
|
||||
}
|
||||
|
||||
TEST(SignalCatcherTest, givenSigIntCatcherAndSigTermCatcher_whenRaisingSignalsInReverseOrder_thenCorrectCatcherCatches) {
|
||||
SignalCatcher sigintCatcher({SIGINT});
|
||||
SignalCatcher sigtermCatcher({SIGTERM});
|
||||
|
||||
EXPECT_FALSE(sigtermCatcher.signal_occurred());
|
||||
raise_signal(SIGTERM);
|
||||
EXPECT_TRUE(sigtermCatcher.signal_occurred());
|
||||
|
||||
EXPECT_FALSE(sigintCatcher.signal_occurred());
|
||||
raise_signal(SIGINT);
|
||||
EXPECT_TRUE(sigintCatcher.signal_occurred());
|
||||
}
|
||||
|
||||
TEST(SignalCatcherTest, givenNestedSigIntCatchers_whenRaisingSignals_thenCorrectCatcherCatches) {
|
||||
SignalCatcher outerCatcher({SIGINT});
|
||||
{
|
||||
SignalCatcher middleCatcher({SIGINT});
|
||||
|
||||
EXPECT_FALSE(middleCatcher.signal_occurred());
|
||||
raise_signal(SIGINT);
|
||||
EXPECT_TRUE(middleCatcher.signal_occurred());
|
||||
|
||||
{
|
||||
SignalCatcher innerCatcher({SIGINT});
|
||||
|
||||
EXPECT_FALSE(innerCatcher.signal_occurred());
|
||||
raise_signal(SIGINT);
|
||||
EXPECT_TRUE(innerCatcher.signal_occurred());
|
||||
}
|
||||
}
|
||||
|
||||
EXPECT_FALSE(outerCatcher.signal_occurred());
|
||||
raise_signal(SIGINT);
|
||||
EXPECT_TRUE(outerCatcher.signal_occurred());
|
||||
}
|
||||
|
||||
TEST(SignalCatcherTest, givenExpiredSigIntCatcher_whenRaisingSigInt_thenDies) {
|
||||
{
|
||||
SignalCatcher catcher({SIGINT});
|
||||
}
|
||||
|
||||
EXPECT_DEATH(
|
||||
raise_signal(SIGINT),
|
||||
""
|
||||
);
|
||||
}
|
||||
|
||||
TEST(SignalCatcherTest, givenExpiredSigTermCatcher_whenRaisingSigTerm_thenDies) {
|
||||
{
|
||||
SignalCatcher catcher({SIGTERM});
|
||||
}
|
||||
|
||||
EXPECT_DEATH(
|
||||
raise_signal(SIGTERM),
|
||||
""
|
||||
);
|
||||
}
|
||||
|
||||
TEST(SignalCatcherTest, givenExpiredSigIntCatcherAndSigTermCatcher_whenRaisingSigTerm_thenDies) {
|
||||
{
|
||||
SignalCatcher sigIntCatcher({SIGTERM});
|
||||
SignalCatcher sigTermCatcer({SIGTERM});
|
||||
}
|
||||
|
||||
EXPECT_DEATH(
|
||||
raise_signal(SIGTERM),
|
||||
""
|
||||
);
|
||||
}
|
||||
|
||||
TEST(SignalCatcherTest, givenSigTermCatcherAndExpiredSigIntCatcher_whenRaisingSigTerm_thenCatches) {
|
||||
SignalCatcher sigTermCatcher({SIGTERM});
|
||||
{
|
||||
SignalCatcher sigIntCatcher({SIGINT});
|
||||
}
|
||||
|
||||
EXPECT_FALSE(sigTermCatcher.signal_occurred());
|
||||
raise_signal(SIGTERM);
|
||||
EXPECT_TRUE(sigTermCatcher.signal_occurred());
|
||||
}
|
||||
|
||||
TEST(SignalCatcherTest, givenSigTermCatcherAndExpiredSigIntCatcher_whenRaisingSigInt_thenDies) {
|
||||
SignalCatcher sigTermCacher({SIGTERM});
|
||||
{
|
||||
SignalCatcher sigIntCatcher({SIGINT});
|
||||
}
|
||||
|
||||
EXPECT_DEATH(
|
||||
raise_signal(SIGINT),
|
||||
""
|
||||
);
|
||||
}
|
Loading…
Reference in New Issue
Block a user