From 368149575d2cf8aefc6b50983dddb6e1a25b6426 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lo=C3=AFc=20Gomez?= Date: Sun, 9 Jan 2022 20:24:40 +0100 Subject: [PATCH] sanitize: null writes, write errors, oversized int, and set incoming_end to size_t --- src/connection.c | 64 +++++++++++++++++++++++++++++++++++------------- src/connection.h | 2 +- src/irc.c | 13 +++++++--- src/log.c | 17 ++++++++++--- 4 files changed, 71 insertions(+), 25 deletions(-) diff --git a/src/connection.c b/src/connection.c index 54caa4b..fdbab10 100644 --- a/src/connection.c +++ b/src/connection.c @@ -212,6 +212,16 @@ static int _write_socket_SSL(connection_t *cn, char* message) size = sizeof(char)*strlen(message); + // let's not ERR (SSL_write doesn't allow 0 len writes) + if (size == 0) + return WRITE_OK; + + // this will fail anyways + if (size > INT_MAX) { + mylog(LOG_ERROR, "Message too long in SSL write_socket"); + return WRITE_ERROR; + } + if (!cn->client && cn->cert == NULL) { cn->cert = mySSL_get_cert(cn->ssl_h); if (cn->cert == NULL) { @@ -219,7 +229,7 @@ static int _write_socket_SSL(connection_t *cn, char* message) return WRITE_ERROR; } } - count = SSL_write(cn->ssl_h, (const void *)message, size); + count = SSL_write(cn->ssl_h, (const void *)message, (int)size); ERR_print_errors(errbio); if (count <= 0) { int err = SSL_get_error(cn->ssl_h, count); @@ -234,10 +244,10 @@ static int _write_socket_SSL(connection_t *cn, char* message) } return WRITE_ERROR; } - if (count != size) { + if (count != (int)size) { /* abnormal : openssl keeps writing until message is not fully * sent */ - mylog(LOG_DEBUG, "only %d written while message length is %d", + mylog(LOG_ERROR, "SSL_write wrote only %d while message length is %d", count,size); } @@ -474,9 +484,23 @@ list_t *read_lines(connection_t *cn, int *error) /* returns 1 if connection must be notified */ static int read_socket_SSL(connection_t *cn) { - int max, count; + int count; + size_t max; + + if (cn == NULL) + return 0; + + if (cn->incoming_end >= CONN_BUFFER_SIZE) { + mylog(LOG_ERROR, "read_socket_SSL: internal error"); + return -1; + } + + max = sizeof(char)*(CONN_BUFFER_SIZE - cn->incoming_end); + if (max > INT_MAX) { + mylog(LOG_ERROR, "read_socket_SSL: cannot read that much data"); + return -1; + } - max = CONN_BUFFER_SIZE - cn->incoming_end; if (!cn->client && cn->cert == NULL) { cn->cert = mySSL_get_cert(cn->ssl_h); if (cn->cert == NULL) { @@ -484,8 +508,8 @@ static int read_socket_SSL(connection_t *cn) return -1; } } - count = SSL_read(cn->ssl_h, (void *)cn->incoming + cn->incoming_end, - sizeof(char) * max); + count = SSL_read(cn->ssl_h, (void *)(cn->incoming + cn->incoming_end), + (int)max); ERR_print_errors(errbio); if (count < 0) { int err = SSL_get_error(cn->ssl_h, count); @@ -511,23 +535,29 @@ static int read_socket_SSL(connection_t *cn) connection_close(cn); } return 1; + } else { + cn->incoming_end += (size_t)count; + return 0; } - - cn->incoming_end += count; - return 0; } #endif /* returns 1 if connection must be notified */ static int read_socket(connection_t *cn) { - int max, count; + ssize_t count; + size_t max; if (cn == NULL) return 0; - max = CONN_BUFFER_SIZE - cn->incoming_end; - count = read(cn->handle, cn->incoming+cn->incoming_end, - sizeof(char)*max); + + if (cn->incoming_end >= CONN_BUFFER_SIZE) { + mylog(LOG_ERROR, "read_socket: internal error"); + return -1; + } + + max = sizeof(char)*(CONN_BUFFER_SIZE - cn->incoming_end); + count = read(cn->handle, cn->incoming+cn->incoming_end, max); if (count < 0) { if (errno == EAGAIN || errno == EINTR || errno == EINPROGRESS) return 0; @@ -544,10 +574,10 @@ static int read_socket(connection_t *cn) connection_close(cn); } return 1; + } else { + cn->incoming_end += (unsigned)count; + return 0; } - - cn->incoming_end += count; - return 0; } static void data_find_lines(connection_t *cn) diff --git a/src/connection.h b/src/connection.h index 5109936..b5f78a1 100644 --- a/src/connection.h +++ b/src/connection.h @@ -75,7 +75,7 @@ typedef struct connection { time_t connect_time; time_t timeout; char *incoming; - unsigned incoming_end; + size_t incoming_end; list_t *outgoing; char *partial; list_t *incoming_lines; diff --git a/src/irc.c b/src/irc.c index 4762751..d431178 100644 --- a/src/irc.c +++ b/src/irc.c @@ -1294,7 +1294,9 @@ static void irc_copy_cli(struct link_client *src, struct link_client *dest, } /* LINK(src) == LINK(dest) */ - size_t len = strlen(irc_line_elem(line, 2)) + 5; + size_t len = strlen(irc_line_elem(line, 2)) + 6; + // snprintf fix ^ + // ‘__builtin___snprintf_chk’ output may be truncated before the last format character char *tmp; if (len == 0) @@ -2897,10 +2899,13 @@ static void server_set_prefix(struct link_server *s, const char *modes) static int bip_get_index(const char* str, char car) { char *cur; - if ((cur = strchr(str, car))) - return cur - str + 1; - else + long diff; + if (!(cur = strchr(str, car))) return 0; + diff = cur - str + 1; + if (diff > INT_MAX) + fatal("bip_get_index: string too long"); + return (int)diff; } static int bip_fls(long v) diff --git a/src/log.c b/src/log.c index 6aea268..b32018e 100644 --- a/src/log.c +++ b/src/log.c @@ -1116,13 +1116,21 @@ static list_t *log_backread(log_t *log, const char *storename, const char *dest) static char *_log_wrap(const char *dest, const char *line) { char *buf; - size_t count; + int count; buf = bip_malloc((size_t)LOGLINE_MAXLEN + 1); count = snprintf(buf, (size_t)LOGLINE_MAXLEN + 1, ":" P_IRCMASK " PRIVMSG %s :%s\r\n", dest, line); + if (count < 0) { + mylog(LOG_ERROR, "_log_wrap: error on snprintf: %s", + strerror(errno)); + buf[LOGLINE_MAXLEN - 2] = '\r'; + buf[LOGLINE_MAXLEN - 1] = '\n'; + buf[LOGLINE_MAXLEN] = 0; + return buf; + } if (count >= LOGLINE_MAXLEN + 1) { - mylog(LOG_DEBUG, "line too long"); + mylog(LOG_WARN, "_log_wrap: line too long"); buf[LOGLINE_MAXLEN - 2] = '\r'; buf[LOGLINE_MAXLEN - 1] = '\n'; buf[LOGLINE_MAXLEN] = 0; @@ -1172,13 +1180,16 @@ static size_t _log_write(log_t *logdata, logstore_t *store, void log_write(log_t *logdata, const char *destination, const char *str) { logstore_t *store = log_find_file(logdata, destination); + size_t written; if (!store) { mylog(LOG_ERROR, "Unable to find/create logfile for '%s'", destination); return; } - _log_write(logdata, store, destination, str); + written = _log_write(logdata, store, destination, str); + if (written <= 0) + mylog(LOG_WARN, "log_write to '%s' failed", destination); } static list_t *log_all_logs = NULL;