~rcr/rirc

f1e3f011921bd44ac3ceb306698ad27d7da2b021 — Richard Robbins 2 months ago c677d85
refactor tls state to be per connection thread
1 files changed, 49 insertions(+), 57 deletions(-)

M src/io.c
M src/io.c => src/io.c +49 -57
@@ 114,9 114,12 @@ struct connection
		IO_ST_PING, /* Socket connected, network state in question */
	} st_cur, /* current thread state */
	  st_new; /* new thread state */
	mbedtls_ctr_drbg_context tls_ctr_drbg;
	mbedtls_entropy_context tls_entropy;
	mbedtls_net_context net_ctx;
	mbedtls_ssl_config  tls_conf;
	mbedtls_ssl_config tls_conf;
	mbedtls_ssl_context tls_ctx;
	mbedtls_x509_crt tls_x509_crt;
	pthread_mutex_t mtx;
	pthread_t tid;
	uint32_t flags;


@@ 138,9 141,6 @@ static void io_tty_winsize(void);
static void* io_thread(void*);

static int io_running;
static mbedtls_ctr_drbg_context tls_ctr_drbg;
static mbedtls_entropy_context  tls_entropy;
static mbedtls_x509_crt         tls_x509_crt;
static pthread_mutex_t io_cb_mutex = PTHREAD_MUTEX_INITIALIZER;
static struct termios term;
static volatile sig_atomic_t flag_sigwinch_cb; /* sigwinch callback */


@@ 153,8 153,6 @@ static void io_net_close(int);
static const char* io_tls_err(int);
static int io_tls_establish(struct connection*);
static int io_tls_x509_vrfy(struct connection*);
static void io_tls_init(void);
static void io_tls_term(void);

const char *ca_cert_paths[] = {
	"/etc/ssl/ca-bundle.pem",


@@ 309,7 307,6 @@ io_init(void)
{
	io_sig_init();
	io_tty_init();
	io_tls_init();
}

void


@@ 435,8 432,11 @@ io_state_cxed(struct connection *cx)
	mbedtls_net_free(&(cx->net_ctx));

	if (cx->flags & IO_TLS_ENABLED) {
		mbedtls_ctr_drbg_free(&(cx->tls_ctr_drbg));
		mbedtls_entropy_free(&(cx->tls_entropy));
		mbedtls_ssl_config_free(&(cx->tls_conf));
		mbedtls_ssl_free(&(cx->tls_ctx));
		mbedtls_x509_crt_free(&(cx->tls_x509_crt));
	}

	return IO_ST_CXNG;


@@ 474,8 474,11 @@ io_state_ping(struct connection *cx)
	mbedtls_net_free(&(cx->net_ctx));

	if (cx->flags & IO_TLS_ENABLED) {
		mbedtls_ctr_drbg_free(&(cx->tls_ctr_drbg));
		mbedtls_entropy_free(&(cx->tls_entropy));
		mbedtls_ssl_config_free(&(cx->tls_conf));
		mbedtls_ssl_free(&(cx->tls_ctx));
		mbedtls_x509_crt_free(&(cx->tls_x509_crt));
	}

	return IO_ST_CXNG;


@@ 776,12 779,17 @@ io_strerror(char *buf, size_t buflen)
static int
io_tls_establish(struct connection *cx)
{
	const unsigned char pers[] = "rirc-drbg-seed";

	int ret;

	io_info(cx, " .. Establishing TLS connection");

	mbedtls_ctr_drbg_init(&(cx->tls_ctr_drbg));
	mbedtls_entropy_init(&(cx->tls_entropy));
	mbedtls_ssl_init(&(cx->tls_ctx));
	mbedtls_ssl_config_init(&(cx->tls_conf));
	mbedtls_x509_crt_init(&(cx->tls_x509_crt));

	if ((ret = mbedtls_ssl_config_defaults(
			&(cx->tls_conf),


@@ 802,12 810,41 @@ io_tls_establish(struct connection *cx)
			MBEDTLS_SSL_MAJOR_VERSION_3,
			MBEDTLS_SSL_MINOR_VERSION_3);

	mbedtls_ssl_conf_rng(&(cx->tls_conf), mbedtls_ctr_drbg_random, &tls_ctr_drbg);
	if ((ret = mbedtls_ctr_drbg_seed(
			&(cx->tls_ctr_drbg),
			mbedtls_entropy_func,
			&(cx->tls_entropy),
			pers,
			sizeof(pers)))) {
		io_error(cx, " .. %s ", io_tls_err(ret));
		goto err;
	}

	if (ca_cert_path && *ca_cert_path) {

		if ((ret = mbedtls_x509_crt_parse_file(&(cx->tls_x509_crt), ca_cert_path)) < 0) {
			io_error(cx, "  .. Failed to load ca cert: '%s': %s", ca_cert_path, io_tls_err(ret));
			goto err;
		}

	} else {

		for (size_t i = 0; i < ARR_LEN(ca_cert_paths); i++) {
			if ((ret = mbedtls_x509_crt_parse_file(&(cx->tls_x509_crt), ca_cert_paths[i])) >= 0) {
				break;
			} else if (i == ARR_LEN(ca_cert_paths)) {
				io_error(cx, "  .. Failed to load ca cert: %s", io_tls_err(ret));
				goto err;
			}
		}
	}

	mbedtls_ssl_conf_rng(&(cx->tls_conf), mbedtls_ctr_drbg_random, &(cx->tls_ctr_drbg));

	if (cx->flags & IO_TLS_VRFY_DISABLED) {
		mbedtls_ssl_conf_authmode(&(cx->tls_conf), MBEDTLS_SSL_VERIFY_NONE);
	} else {
		mbedtls_ssl_conf_ca_chain(&(cx->tls_conf), &tls_x509_crt, NULL);
		mbedtls_ssl_conf_ca_chain(&(cx->tls_conf), &(cx->tls_x509_crt), NULL);

		if (cx->flags & IO_TLS_VRFY_OPTIONAL)
			mbedtls_ssl_conf_authmode(&(cx->tls_conf), MBEDTLS_SSL_VERIFY_OPTIONAL);


@@ 867,8 904,11 @@ err:

	io_error(cx, " .. TLS connection failure");

	mbedtls_ctr_drbg_free(&(cx->tls_ctr_drbg));
	mbedtls_entropy_free(&(cx->tls_entropy));
	mbedtls_ssl_config_free(&(cx->tls_conf));
	mbedtls_ssl_free(&(cx->tls_ctx));
	mbedtls_x509_crt_free(&(cx->tls_x509_crt));
	mbedtls_net_free(&(cx->net_ctx));

	return -1;


@@ 914,51 954,3 @@ io_tls_err(int err)

	return "Unknown error";
}

static void
io_tls_init(void)
{
	const unsigned char pers[] = "rirc-drbg-seed";
	int ret;

	mbedtls_ctr_drbg_init(&tls_ctr_drbg);
	mbedtls_entropy_init(&tls_entropy);
	mbedtls_x509_crt_init(&tls_x509_crt);

	if (atexit(io_tls_term))
		fatal("atexit");

	if ((ret = mbedtls_ctr_drbg_seed(
			&tls_ctr_drbg,
			mbedtls_entropy_func,
			&tls_entropy,
			pers,
			sizeof(pers)))) {
		fatal("mbedtls_ctr_drbg_seed: %s", io_tls_err(ret));
	}

	if (ca_cert_path && *ca_cert_path) {

		if ((ret = mbedtls_x509_crt_parse_file(&tls_x509_crt, ca_cert_path)) < 0)
			fatal("mbedtls_x509_crt_parse_file: '%s': %s", ca_cert_path, io_tls_err(ret));

	} else {

		for (size_t i = 0; i < ARR_LEN(ca_cert_paths); i++) {
			if ((ret = mbedtls_x509_crt_parse_file(&tls_x509_crt, ca_cert_paths[i])) >= 0)
				return;
		}

		fatal("Failed to load ca cert: %s", io_tls_err(ret));
	}
}

static void
io_tls_term(void)
{
	/* Exit handler, must return normally */

	mbedtls_ctr_drbg_free(&tls_ctr_drbg);
	mbedtls_entropy_free(&tls_entropy);
	mbedtls_x509_crt_free(&tls_x509_crt);
}