~rcr/rirc

65e3c0364517289a2743131caf7cae710d1c7140 — Richard Robbins 1 year, 4 months ago 7bc8d18
fix ping timeout not generating disconnect error
3 files changed, 140 insertions(+), 124 deletions(-)

M src/io.c
M src/io.h
M src/io_net.h
M src/io.c => src/io.c +137 -123
@@ 1,3 1,5 @@
#include "src/io.h"

#include <errno.h>
#include <netdb.h>
#include <pthread.h>


@@ 12,13 14,13 @@

#include "mbedtls/ctr_drbg.h"
#include "mbedtls/entropy.h"
#include "mbedtls/error.h"
#include "mbedtls/net_sockets.h"
#include "mbedtls/ssl.h"
#include "mbedtls/x509.h"
#include "mbedtls/x509_crt.h"

#include "config.h"
#include "rirc.h"
#include "src/io.h"
#include "src/io_net.h"
#include "utils/utils.h"



@@ 117,35 119,35 @@ struct connection
		IO_ST_PING, /* Socket connected, network state in question */
	} st_cur, /* current thread state */
	  st_new; /* new thread state */
	mbedtls_net_context ssl_fd;
	mbedtls_ssl_config ssl_conf;
	mbedtls_ssl_context ssl_ctx;
	mbedtls_net_context tls_fd;
	mbedtls_ssl_config  tls_conf;
	mbedtls_ssl_context tls_ctx;
	pthread_mutex_t mtx;
	pthread_t tid;
	unsigned ping;
	unsigned rx_sleep;
};

static const char* io_tls_strerror(int, char*, size_t);
static enum io_state_t io_state_cxed(struct connection*);
static enum io_state_t io_state_cxng(struct connection*);
static enum io_state_t io_state_ping(struct connection*);
static enum io_state_t io_state_rxng(struct connection*);
static int io_cx_read(struct connection*);
static int io_cx_read(struct connection*, unsigned);
static void io_fatal(const char*, int);
static void io_sig_handle(int);
static void io_sig_init(void);
static void io_ssl_init(void);
static void io_ssl_term(void);
static void io_tls_init(void);
static void io_tls_term(void);
static void io_tty_init(void);
static void io_tty_term(void);
static void io_tty_winsize(void);
static void* io_thread(void*);

static int io_running;
static mbedtls_ctr_drbg_context ssl_ctr_drbg;
static mbedtls_entropy_context ssl_entropy;
static mbedtls_ssl_config ssl_conf;
static mbedtls_x509_crt ssl_cacert;
static mbedtls_ctr_drbg_context tls_ctr_drbg;
static mbedtls_entropy_context  tls_entropy;
static mbedtls_x509_crt         tls_x509_crt;
static pthread_mutex_t cb_mutex = PTHREAD_MUTEX_INITIALIZER;
static struct termios term;
static unsigned io_cols;


@@ 268,7 270,7 @@ io_sendf(struct connection *cx, const char *fmt, ...)
	written = 0;

	do {
		if ((ret = mbedtls_ssl_write(&(cx->ssl_ctx), sendbuf + ret, len - ret)) < 0) {
		if ((ret = mbedtls_ssl_write(&(cx->tls_ctx), sendbuf + ret, len - ret)) < 0) {
			switch (ret) {
				case MBEDTLS_ERR_SSL_WANT_READ:
				case MBEDTLS_ERR_SSL_WANT_WRITE:


@@ 290,7 292,7 @@ io_init(void)
{
	io_sig_init();
	io_tty_init();
	io_ssl_init();
	io_tls_init();
}

void


@@ 373,6 375,13 @@ io_err(int err)
	}
}

const char*
io_tls_strerror(int err, char *buf, size_t len)
{
	mbedtls_strerror(err, buf, len);
	return buf;
}

static enum io_state_t
io_state_rxng(struct connection *cx)
{


@@ 397,8 406,7 @@ io_state_rxng(struct connection *cx)
static enum io_state_t
io_state_cxng(struct connection *cx)
{
	char addr_buf[INET6_ADDRSTRLEN];
	char vrfy_buf[512];
	char buf[MIN(INET6_ADDRSTRLEN, 512)];
	enum io_state_t st = IO_ST_RXNG;
	int ret;
	int soc;


@@ 406,95 414,109 @@ io_state_cxng(struct connection *cx)

	io_cb_info(cx, "Connecting to %s:%s", cx->host, cx->port);

	if ((ret = io_net_connect(&soc, cx->host, cx->port)) != IO_NET_ERR_NONE) {
	if ((ret = io_net_connect(&soc, cx->host, cx->port)) != 0) {
		switch (ret) {
			case IO_NET_ERR_EINTR:
				st = IO_ST_DXED;
				goto error_net;
				goto err_net;
			case IO_NET_ERR_SOCKET_FAILED:
				io_cb_err(cx, " ... Failed to obtain socket");
				goto error_net;
				goto err_net;
			case IO_NET_ERR_UNKNOWN_HOST:
				io_cb_err(cx, " ... Failed to resolve host");
				goto error_net;
				goto err_net;
			case IO_NET_ERR_CONNECT_FAILED:
				io_cb_err(cx, " ... Failed to connect to host");
				goto error_net;
				goto err_net;
			default:
				fatal("unknown net error");
		}
	}

	if ((ret = io_net_ip_str(soc, addr_buf, sizeof(addr_buf))) != IO_NET_ERR_NONE) {
	if ((ret = io_net_ip_str(soc, buf, sizeof(buf))) != 0) {
		if (ret == IO_NET_ERR_EINTR) {
			st = IO_ST_DXED;
			goto error_net;
			goto err_net;
		}
		io_cb_info(cx, " ... Connected (failed to optain IP address)");
	} else {
		io_cb_info(cx, " ... Connected to [%s]", addr_buf);
		io_cb_info(cx, " ... Connected to [%s]", buf);
	}

	io_cb_info(cx, " ... Establishing SSL");
	io_cb_info(cx, " ... Establishing TLS connection");

	mbedtls_net_init(&(cx->ssl_fd));
	mbedtls_ssl_init(&(cx->ssl_ctx));
	mbedtls_ssl_config_init(&(cx->ssl_conf));
	mbedtls_net_init(&(cx->tls_fd));
	mbedtls_ssl_init(&(cx->tls_ctx));
	mbedtls_ssl_config_init(&(cx->tls_conf));

	cx->ssl_conf = ssl_conf;
	cx->ssl_fd.fd = soc;
	cx->tls_fd.fd = soc;

	if ((ret = mbedtls_net_set_block(&(cx->ssl_fd))) != 0) {
		io_cb_err(cx, " ... mbedtls_net_set_block failure");
		goto error_ssl;
	if (mbedtls_ssl_config_defaults(
			&(cx->tls_conf),
			MBEDTLS_SSL_IS_CLIENT,
			MBEDTLS_SSL_TRANSPORT_STREAM,
			MBEDTLS_SSL_PRESET_DEFAULT) != 0) {
		io_cb_err(cx, " ... mbedtls_ssl_config_defaults: %s", io_tls_strerror(ret, buf, sizeof(buf)));
		goto err_tls;
	}

	if ((ret = mbedtls_ssl_setup(&(cx->ssl_ctx), &(cx->ssl_conf))) != 0) {
		io_cb_err(cx, " ... mbedtls_ssl_setup failure");
		goto error_ssl;
	mbedtls_ssl_conf_ca_chain(&(cx->tls_conf), &tls_x509_crt, NULL);
	mbedtls_ssl_conf_rng(&(cx->tls_conf), mbedtls_ctr_drbg_random, &tls_ctr_drbg);

	if ((ret = mbedtls_net_set_block(&(cx->tls_fd))) != 0) {
		io_cb_err(cx, " ... mbedtls_net_set_block: %s", io_tls_strerror(ret, buf, sizeof(buf)));
		goto err_tls;
	}

	if ((ret = mbedtls_ssl_set_hostname(&(cx->ssl_ctx), cx->host)) != 0) {
		io_cb_err(cx, " ... mbedtls_ssl_set_hostname failure");
		goto error_ssl;
	if ((ret = mbedtls_ssl_setup(&(cx->tls_ctx), &(cx->tls_conf))) != 0) {
		io_cb_err(cx, " ... mbedtls_ssl_setup: %s", io_tls_strerror(ret, buf, sizeof(buf)));
		goto err_tls;
	}

	if ((ret = mbedtls_ssl_set_hostname(&(cx->tls_ctx), cx->host)) != 0) {
		io_cb_err(cx, " ... mbedtls_ssl_set_hostname: %s", io_tls_strerror(ret, buf, sizeof(buf)));
		goto err_tls;
	}

	mbedtls_ssl_set_bio(
		&(cx->ssl_ctx),
		&(cx->ssl_fd),
		&(cx->tls_ctx),
		&(cx->tls_fd),
		mbedtls_net_send,
		NULL,
		mbedtls_net_recv_timeout);

	while ((ret = mbedtls_ssl_handshake(&(cx->ssl_ctx))) != 0) {
	while ((ret = mbedtls_ssl_handshake(&(cx->tls_ctx))) != 0) {
		if (ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE) {
			io_cb_err(cx, " ... mbedtls_ssl_handshake failure");
			goto error_ssl;
			io_cb_err(cx, " ... mbedtls_ssl_handshake: %s", io_tls_strerror(ret, buf, sizeof(buf)));
			goto err_tls;
		}
	}

	if ((cert_ret = mbedtls_ssl_get_verify_result(&(cx->ssl_ctx))) != 0) {
		if (mbedtls_x509_crt_verify_info(vrfy_buf, sizeof(vrfy_buf), "", cert_ret) <= 0) {
	if ((cert_ret = mbedtls_ssl_get_verify_result(&(cx->tls_ctx))) != 0) {
		if (mbedtls_x509_crt_verify_info(buf, sizeof(buf), "", cert_ret) <= 0) {
			io_cb_err(cx, " ... failed to verify cert: unknown failure");
		} else {
			io_cb_err(cx, " ... failed to verify cert: %s", vrfy_buf);
			io_cb_err(cx, " ... failed to verify cert: %s", buf);
		}
		goto error_ssl;
		goto err_tls;
	}

	io_cb_info(cx, " ... SSL connection established");
	io_cb_info(cx, " ...   - version:     %s", mbedtls_ssl_get_version(&(cx->ssl_ctx)));
	io_cb_info(cx, " ...   - ciphersuite: %s", mbedtls_ssl_get_ciphersuite(&(cx->ssl_ctx)));
	io_cb_info(cx, " ... TLS connection established");
	io_cb_info(cx, " ...   - version:     %s", mbedtls_ssl_get_version(&(cx->tls_ctx)));
	io_cb_info(cx, " ...   - ciphersuite: %s", mbedtls_ssl_get_ciphersuite(&(cx->tls_ctx)));

	return IO_ST_CXED;

error_ssl:
err_tls:

	io_cb_err(cx, " ... TLS connection failure");

	mbedtls_ssl_config_free(&(cx->tls_conf));
	mbedtls_ssl_free(&(cx->tls_ctx));

	mbedtls_net_free(&(cx->ssl_fd));
	mbedtls_ssl_free(&(cx->ssl_ctx));
	mbedtls_ssl_config_free(&(cx->ssl_conf));
err_net:

error_net:
	mbedtls_net_free(&(cx->tls_fd));

	return st;
}


@@ 503,64 525,56 @@ static enum io_state_t
io_state_cxed(struct connection *cx)
{
	int ret;
	enum io_state_t st = IO_ST_CXNG;

	mbedtls_ssl_conf_read_timeout(&(cx->ssl_conf), SEC_IN_MS(IO_PING_MIN));

	while ((ret = io_cx_read(cx)) > 0)
	while ((ret = io_cx_read(cx, IO_PING_MIN)) > 0)
		continue;

	if (ret == MBEDTLS_ERR_SSL_TIMEOUT)
		return IO_ST_PING;

	switch (ret) {
		case MBEDTLS_ERR_SSL_WANT_READ:
		case MBEDTLS_ERR_SSL_WANT_WRITE:
			break;
		case MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY:
			io_cb_info(cx, "connection closed gracefully");
			break;
		case MBEDTLS_ERR_SSL_TIMEOUT:
			return IO_ST_PING;
		case MBEDTLS_ERR_NET_CONN_RESET:
		case 0:
			io_cb_err(cx, "connection reset by peer");
			break;
		default:
			io_cb_err(cx, "connection ssl error");
			io_cb_err(cx, "connection tls error");
			break;
	}

	mbedtls_net_free(&(cx->ssl_fd));
	mbedtls_ssl_free(&(cx->ssl_ctx));
	mbedtls_ssl_config_free(&(cx->ssl_conf));
	mbedtls_net_free(&(cx->tls_fd));
	mbedtls_ssl_config_free(&(cx->tls_conf));
	mbedtls_ssl_free(&(cx->tls_ctx));

	return st;
	return IO_ST_CXNG;
}

static enum io_state_t
io_state_ping(struct connection *cx)
{
	int ret;
	enum io_state_t st = IO_ST_CXNG;

	mbedtls_ssl_conf_read_timeout(&(cx->ssl_conf), SEC_IN_MS(IO_PING_REFRESH));

	while ((ret = io_cx_read(cx)) <= 0 && ret == MBEDTLS_ERR_SSL_TIMEOUT) {
		if ((cx->ping += IO_PING_REFRESH) < IO_PING_MAX) {
			io_cb_ping_n(cx, cx->ping);
		} else {
			break;
		}
	}
	if (cx->ping >= IO_PING_MAX)
		return IO_ST_CXNG;

	if (ret > 0)
	if ((ret = io_cx_read(cx, IO_PING_REFRESH)) > 0)
		return IO_ST_CXED;

	if (ret == MBEDTLS_ERR_SSL_TIMEOUT)
		return IO_ST_PING;

	switch (ret) {
		case MBEDTLS_ERR_SSL_WANT_READ:
		case MBEDTLS_ERR_SSL_WANT_WRITE:
			break;
		case MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY:
			break;
		case MBEDTLS_ERR_SSL_TIMEOUT: /* read timeout */
			io_cb_err(cx, "connection timeout (%u)", cx->ping);
			io_cb_info(cx, "connection closed gracefully");
			break;
		case MBEDTLS_ERR_NET_CONN_RESET:
		case 0:


@@ 571,11 585,11 @@ io_state_ping(struct connection *cx)
			break;
	}

	mbedtls_net_free(&(cx->ssl_fd));
	mbedtls_ssl_free(&(cx->ssl_ctx));
	mbedtls_ssl_config_free(&(cx->ssl_conf));
	mbedtls_net_free(&(cx->tls_fd));
	mbedtls_ssl_config_free(&(cx->tls_conf));
	mbedtls_ssl_free(&(cx->tls_ctx));

	return st;
	return IO_ST_CXNG;
}

static void*


@@ 622,11 636,15 @@ io_thread(void *arg)

		/* State transitions */
		switch (ST_X(st_old, st_new)) {
			case ST_X(IO_ST_DXED, IO_ST_CXNG): /* A1 */
			case ST_X(IO_ST_RXNG, IO_ST_CXNG): /* A2,C */
				break;
			case ST_X(IO_ST_CXED, IO_ST_CXNG): /* F1 */
				io_cb_dxed(cx);
				break;
			case ST_X(IO_ST_PING, IO_ST_CXNG): /* F2 */
				io_cb_err(cx, "connection timeout (%u)", cx->ping);
				io_cb_dxed(cx);
			case ST_X(IO_ST_DXED, IO_ST_CXNG): /* A1 */
			case ST_X(IO_ST_RXNG, IO_ST_CXNG): /* A2,C */
				break;
			case ST_X(IO_ST_RXNG, IO_ST_DXED): /* B1 */
			case ST_X(IO_ST_CXNG, IO_ST_DXED): /* B2 */


@@ 650,6 668,7 @@ io_thread(void *arg)
				io_cb_ping_1(cx, cx->ping);
				break;
			case ST_X(IO_ST_PING, IO_ST_PING): /* H */
				cx->ping += IO_PING_REFRESH;
				io_cb_ping_n(cx, cx->ping);
				break;
			case ST_X(IO_ST_PING, IO_ST_CXED): /* I */


@@ 666,12 685,14 @@ io_thread(void *arg)
}

static int
io_cx_read(struct connection *cx)
io_cx_read(struct connection *cx, unsigned timeout)
{
	int ret;
	unsigned char ssl_readbuf[1024];

	if ((ret = mbedtls_ssl_read(&(cx->ssl_ctx), ssl_readbuf, sizeof(ssl_readbuf))) > 0) {
	mbedtls_ssl_conf_read_timeout(&(cx->tls_conf), SEC_IN_MS(timeout));

	if ((ret = mbedtls_ssl_read(&(cx->tls_ctx), ssl_readbuf, sizeof(ssl_readbuf))) > 0) {
		PT_LK(&cb_mutex);
		io_cb_read_soc((char *)ssl_readbuf, (size_t)ret,  cx->obj);
		PT_UL(&cb_mutex);


@@ 718,53 739,46 @@ io_sig_init(void)
}

static void
io_ssl_init(void)
io_tls_init(void)
{
	const char *tls_pers = "rirc-drbg-ctr-pers";
	char buf[512];
	int err;
	struct timespec ts;

	mbedtls_ssl_config_init(&ssl_conf);
	mbedtls_ctr_drbg_init(&ssl_ctr_drbg);
	mbedtls_entropy_init(&ssl_entropy);
	mbedtls_x509_crt_init(&ssl_cacert);
	mbedtls_ctr_drbg_init(&tls_ctr_drbg);
	mbedtls_entropy_init(&tls_entropy);
	mbedtls_x509_crt_init(&tls_x509_crt);

	if (mbedtls_x509_crt_parse_path(&ssl_cacert, ca_cert_path) != 0) {
		fatal("ssl init failed: mbedtls_x509_crt_parse_path");
	}
	if (atexit(io_tls_term) != 0)
		fatal("atexit");

	if (mbedtls_ctr_drbg_seed(
			&ssl_ctr_drbg,
			mbedtls_entropy_func,
			&ssl_entropy,
			(unsigned char *)tls_pers,
			strlen(tls_pers)) != 0) {
		fatal("ssl init failed: mbedtls_ctr_drbg_seed");
	}
	if (timespec_get(&ts, TIME_UTC) != TIME_UTC)
		fatal("timespec_get");

	if (mbedtls_ssl_config_defaults(
			&ssl_conf,
			MBEDTLS_SSL_IS_CLIENT,
			MBEDTLS_SSL_TRANSPORT_STREAM,
			MBEDTLS_SSL_PRESET_DEFAULT) != 0) {
		fatal("ssl init failed: mbedtls_ssl_config_defaults");
	}
	if (snprintf(buf, sizeof(buf), "rirc-%lu-%lu", ts.tv_sec, ts.tv_nsec) < 0)
		fatal("snprintf");

	mbedtls_ssl_conf_ca_chain(&ssl_conf, &ssl_cacert, NULL);
	mbedtls_ssl_conf_read_timeout(&ssl_conf, SEC_IN_MS(IO_PING_MIN));
	mbedtls_ssl_conf_rng(&ssl_conf, mbedtls_ctr_drbg_random, &ssl_ctr_drbg);
	if ((err = mbedtls_ctr_drbg_seed(
			&tls_ctr_drbg,
			mbedtls_entropy_func,
			&tls_entropy,
			(const unsigned char *)buf,
			strlen(buf))) != 0) {
		fatal("mbedtls_ctr_drbg_seed: %s", io_tls_strerror(err, buf, sizeof(buf)));
	}

	if (atexit(io_ssl_term) != 0)
		fatal("atexit");
	if ((err = mbedtls_x509_crt_parse_path(&tls_x509_crt, ca_cert_path)) != 0)
		fatal("mbedtls_x509_crt_parse_path: %s", io_tls_strerror(err, buf, sizeof(buf)));
}

static void
io_ssl_term(void)
io_tls_term(void)
{
	/* Exit handler, must return normally */

	mbedtls_ctr_drbg_free(&ssl_ctr_drbg);
	mbedtls_entropy_free(&ssl_entropy);
	mbedtls_ssl_config_free(&ssl_conf);
	mbedtls_x509_crt_free(&ssl_cacert);
	mbedtls_ctr_drbg_free(&tls_ctr_drbg);
	mbedtls_entropy_free(&tls_entropy);
	mbedtls_x509_crt_free(&tls_x509_crt);
}

static void

M src/io.h => src/io.h +2 -0
@@ 79,6 79,8 @@
 * a call to io_stop
 */

#include <stddef.h>

struct connection;

enum io_sig_t

M src/io_net.h => src/io_net.h +1 -1
@@ 5,7 5,7 @@

enum io_net_err
{
	IO_NET_ERR_NONE,
	IO_NET_ERR_NONE = 0,
	IO_NET_ERR_SOCKET_FAILED,
	IO_NET_ERR_UNKNOWN_HOST,
	IO_NET_ERR_CONNECT_FAILED,