sanitize: null writes, write errors, oversized int,

and set incoming_end to size_t
This commit is contained in:
Loïc Gomez 2022-01-09 20:24:40 +01:00 committed by Pierre-Louis Bonicoli
parent b990a071b3
commit 368149575d
Signed by: pilou
GPG Key ID: 06914C4A5EDAA6DD
4 changed files with 71 additions and 25 deletions

View File

@ -212,6 +212,16 @@ static int _write_socket_SSL(connection_t *cn, char* message)
size = sizeof(char)*strlen(message); size = sizeof(char)*strlen(message);
// let's not ERR (SSL_write doesn't allow 0 len writes)
if (size == 0)
return WRITE_OK;
// this will fail anyways
if (size > INT_MAX) {
mylog(LOG_ERROR, "Message too long in SSL write_socket");
return WRITE_ERROR;
}
if (!cn->client && cn->cert == NULL) { if (!cn->client && cn->cert == NULL) {
cn->cert = mySSL_get_cert(cn->ssl_h); cn->cert = mySSL_get_cert(cn->ssl_h);
if (cn->cert == NULL) { if (cn->cert == NULL) {
@ -219,7 +229,7 @@ static int _write_socket_SSL(connection_t *cn, char* message)
return WRITE_ERROR; return WRITE_ERROR;
} }
} }
count = SSL_write(cn->ssl_h, (const void *)message, size); count = SSL_write(cn->ssl_h, (const void *)message, (int)size);
ERR_print_errors(errbio); ERR_print_errors(errbio);
if (count <= 0) { if (count <= 0) {
int err = SSL_get_error(cn->ssl_h, count); int err = SSL_get_error(cn->ssl_h, count);
@ -234,10 +244,10 @@ static int _write_socket_SSL(connection_t *cn, char* message)
} }
return WRITE_ERROR; return WRITE_ERROR;
} }
if (count != size) { if (count != (int)size) {
/* abnormal : openssl keeps writing until message is not fully /* abnormal : openssl keeps writing until message is not fully
* sent */ * sent */
mylog(LOG_DEBUG, "only %d written while message length is %d", mylog(LOG_ERROR, "SSL_write wrote only %d while message length is %d",
count,size); count,size);
} }
@ -474,9 +484,23 @@ list_t *read_lines(connection_t *cn, int *error)
/* returns 1 if connection must be notified */ /* returns 1 if connection must be notified */
static int read_socket_SSL(connection_t *cn) static int read_socket_SSL(connection_t *cn)
{ {
int max, count; int count;
size_t max;
if (cn == NULL)
return 0;
if (cn->incoming_end >= CONN_BUFFER_SIZE) {
mylog(LOG_ERROR, "read_socket_SSL: internal error");
return -1;
}
max = sizeof(char)*(CONN_BUFFER_SIZE - cn->incoming_end);
if (max > INT_MAX) {
mylog(LOG_ERROR, "read_socket_SSL: cannot read that much data");
return -1;
}
max = CONN_BUFFER_SIZE - cn->incoming_end;
if (!cn->client && cn->cert == NULL) { if (!cn->client && cn->cert == NULL) {
cn->cert = mySSL_get_cert(cn->ssl_h); cn->cert = mySSL_get_cert(cn->ssl_h);
if (cn->cert == NULL) { if (cn->cert == NULL) {
@ -484,8 +508,8 @@ static int read_socket_SSL(connection_t *cn)
return -1; return -1;
} }
} }
count = SSL_read(cn->ssl_h, (void *)cn->incoming + cn->incoming_end, count = SSL_read(cn->ssl_h, (void *)(cn->incoming + cn->incoming_end),
sizeof(char) * max); (int)max);
ERR_print_errors(errbio); ERR_print_errors(errbio);
if (count < 0) { if (count < 0) {
int err = SSL_get_error(cn->ssl_h, count); int err = SSL_get_error(cn->ssl_h, count);
@ -511,23 +535,29 @@ static int read_socket_SSL(connection_t *cn)
connection_close(cn); connection_close(cn);
} }
return 1; return 1;
} else {
cn->incoming_end += (size_t)count;
return 0;
} }
cn->incoming_end += count;
return 0;
} }
#endif #endif
/* returns 1 if connection must be notified */ /* returns 1 if connection must be notified */
static int read_socket(connection_t *cn) static int read_socket(connection_t *cn)
{ {
int max, count; ssize_t count;
size_t max;
if (cn == NULL) if (cn == NULL)
return 0; return 0;
max = CONN_BUFFER_SIZE - cn->incoming_end;
count = read(cn->handle, cn->incoming+cn->incoming_end, if (cn->incoming_end >= CONN_BUFFER_SIZE) {
sizeof(char)*max); mylog(LOG_ERROR, "read_socket: internal error");
return -1;
}
max = sizeof(char)*(CONN_BUFFER_SIZE - cn->incoming_end);
count = read(cn->handle, cn->incoming+cn->incoming_end, max);
if (count < 0) { if (count < 0) {
if (errno == EAGAIN || errno == EINTR || errno == EINPROGRESS) if (errno == EAGAIN || errno == EINTR || errno == EINPROGRESS)
return 0; return 0;
@ -544,10 +574,10 @@ static int read_socket(connection_t *cn)
connection_close(cn); connection_close(cn);
} }
return 1; return 1;
} else {
cn->incoming_end += (unsigned)count;
return 0;
} }
cn->incoming_end += count;
return 0;
} }
static void data_find_lines(connection_t *cn) static void data_find_lines(connection_t *cn)

View File

@ -75,7 +75,7 @@ typedef struct connection {
time_t connect_time; time_t connect_time;
time_t timeout; time_t timeout;
char *incoming; char *incoming;
unsigned incoming_end; size_t incoming_end;
list_t *outgoing; list_t *outgoing;
char *partial; char *partial;
list_t *incoming_lines; list_t *incoming_lines;

View File

@ -1294,7 +1294,9 @@ static void irc_copy_cli(struct link_client *src, struct link_client *dest,
} }
/* LINK(src) == LINK(dest) */ /* LINK(src) == LINK(dest) */
size_t len = strlen(irc_line_elem(line, 2)) + 5; size_t len = strlen(irc_line_elem(line, 2)) + 6;
// snprintf fix ^
// __builtin___snprintf_chk output may be truncated before the last format character
char *tmp; char *tmp;
if (len == 0) if (len == 0)
@ -2897,10 +2899,13 @@ static void server_set_prefix(struct link_server *s, const char *modes)
static int bip_get_index(const char* str, char car) static int bip_get_index(const char* str, char car)
{ {
char *cur; char *cur;
if ((cur = strchr(str, car))) long diff;
return cur - str + 1; if (!(cur = strchr(str, car)))
else
return 0; return 0;
diff = cur - str + 1;
if (diff > INT_MAX)
fatal("bip_get_index: string too long");
return (int)diff;
} }
static int bip_fls(long v) static int bip_fls(long v)

View File

@ -1116,13 +1116,21 @@ static list_t *log_backread(log_t *log, const char *storename, const char *dest)
static char *_log_wrap(const char *dest, const char *line) static char *_log_wrap(const char *dest, const char *line)
{ {
char *buf; char *buf;
size_t count; int count;
buf = bip_malloc((size_t)LOGLINE_MAXLEN + 1); buf = bip_malloc((size_t)LOGLINE_MAXLEN + 1);
count = snprintf(buf, (size_t)LOGLINE_MAXLEN + 1, count = snprintf(buf, (size_t)LOGLINE_MAXLEN + 1,
":" P_IRCMASK " PRIVMSG %s :%s\r\n", dest, line); ":" P_IRCMASK " PRIVMSG %s :%s\r\n", dest, line);
if (count < 0) {
mylog(LOG_ERROR, "_log_wrap: error on snprintf: %s",
strerror(errno));
buf[LOGLINE_MAXLEN - 2] = '\r';
buf[LOGLINE_MAXLEN - 1] = '\n';
buf[LOGLINE_MAXLEN] = 0;
return buf;
}
if (count >= LOGLINE_MAXLEN + 1) { if (count >= LOGLINE_MAXLEN + 1) {
mylog(LOG_DEBUG, "line too long"); mylog(LOG_WARN, "_log_wrap: line too long");
buf[LOGLINE_MAXLEN - 2] = '\r'; buf[LOGLINE_MAXLEN - 2] = '\r';
buf[LOGLINE_MAXLEN - 1] = '\n'; buf[LOGLINE_MAXLEN - 1] = '\n';
buf[LOGLINE_MAXLEN] = 0; buf[LOGLINE_MAXLEN] = 0;
@ -1172,13 +1180,16 @@ static size_t _log_write(log_t *logdata, logstore_t *store,
void log_write(log_t *logdata, const char *destination, const char *str) void log_write(log_t *logdata, const char *destination, const char *str)
{ {
logstore_t *store = log_find_file(logdata, destination); logstore_t *store = log_find_file(logdata, destination);
size_t written;
if (!store) { if (!store) {
mylog(LOG_ERROR, "Unable to find/create logfile for '%s'", mylog(LOG_ERROR, "Unable to find/create logfile for '%s'",
destination); destination);
return; return;
} }
_log_write(logdata, store, destination, str); written = _log_write(logdata, store, destination, str);
if (written <= 0)
mylog(LOG_WARN, "log_write to '%s' failed", destination);
} }
static list_t *log_all_logs = NULL; static list_t *log_all_logs = NULL;