diff --git a/src/network/backend.c b/src/network/backend.c index 62c892d..e347b3e 100644 --- a/src/network/backend.c +++ b/src/network/backend.c @@ -1,6 +1,6 @@ /* * uhub - A tiny ADC p2p connection hub - * Copyright (C) 2007-2010, Jan Vidar Krey + * Copyright (C) 2007-2012, Jan Vidar Krey * * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by @@ -180,10 +180,8 @@ void net_con_close(struct net_connection* con) g_backend->handler.con_del(g_backend->data, con); #ifdef SSL_SUPPORT -#ifdef SSL_USE_OPENSSL if (con->ssl) net_ssl_shutdown(con); -#endif /* SSL_USE_OPENSSL */ #endif /* SSL_SUPPORT */ net_close(con->sd); diff --git a/src/network/backend.h b/src/network/backend.h index ba5febe..b908f28 100644 --- a/src/network/backend.h +++ b/src/network/backend.h @@ -1,6 +1,6 @@ /* * uhub - A tiny ADC p2p connection hub - * Copyright (C) 2007-2010, Jan Vidar Krey + * Copyright (C) 2007-2012, Jan Vidar Krey * * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by diff --git a/src/network/connection.c b/src/network/connection.c index 56a27c3..5f664cf 100644 --- a/src/network/connection.c +++ b/src/network/connection.c @@ -25,6 +25,18 @@ void net_stats_add_tx(size_t bytes); void net_stats_add_rx(size_t bytes); #endif +static int is_blocked_or_interrupted() +{ + int err = net_error(); + return +#ifdef WINSOCK + err == WSAEWOULDBLOCK +#else + err == EWOULDBLOCK +#endif + || err == EINTR; +} + ssize_t net_con_send(struct net_connection* con, const void* buf, size_t len) { int ret; @@ -35,13 +47,7 @@ ssize_t net_con_send(struct net_connection* con, const void* buf, size_t len) ret = net_send(con->sd, buf, len, UHUB_SEND_SIGNAL); if (ret == -1) { - if ( -#ifdef WINSOCK - net_error() == WSAEWOULDBLOCK -#else - net_error() == EWOULDBLOCK -#endif - || net_error() == EINTR) + if (is_blocked_or_interrupted()) return 0; return -1; } @@ -65,13 +71,7 @@ ssize_t net_con_recv(struct net_connection* con, void* buf, size_t len) ret = net_recv(con->sd, buf, len, 0); if (ret == -1) { - if ( -#ifdef WINSOCK - net_error() == WSAEWOULDBLOCK -#else - net_error() == EWOULDBLOCK -#endif - || net_error() == EINTR) + if (is_blocked_or_interrupted()) return 0; return -net_error(); } @@ -94,13 +94,7 @@ ssize_t net_con_peek(struct net_connection* con, void* buf, size_t len) int ret = net_recv(con->sd, buf, len, MSG_PEEK); if (ret == -1) { - if ( -#ifdef WINSOCK - net_error() == WSAEWOULDBLOCK -#else - net_error() == EWOULDBLOCK -#endif - || net_error() == EINTR) + if (is_blocked_or_interrupted()) return 0; return -net_error(); } @@ -149,12 +143,10 @@ void net_con_callback(struct net_connection* con, int events) } #ifdef SSL_SUPPORT - if (!con->ssl) + if (con->ssl) + net_ssl_callback(con, events); + else #endif con->callback(con, events, con->ptr); -#ifdef SSL_SUPPORT - else - net_ssl_callback(con, events); -#endif } diff --git a/src/network/openssl.c b/src/network/openssl.c index 9dc0866..cd09d4d 100644 --- a/src/network/openssl.c +++ b/src/network/openssl.c @@ -30,6 +30,9 @@ struct net_ssl_openssl { SSL* ssl; enum ssl_state state; + uint32_t flags; + BIO* biow; + BIO* bior; }; struct net_context_openssl @@ -108,47 +111,6 @@ int ssl_check_private_key(struct ssl_context_handle* ctx_) return 1; } -static int handle_openssl_error(struct net_connection* con, int ret) -{ - struct net_ssl_openssl* handle = get_handle(con); - - int error = SSL_get_error(handle->ssl, ret); - switch (error) - { - case SSL_ERROR_ZERO_RETURN: - LOG_PROTO("SSL_get_error: ret=%d, error=%d: SSL_ERROR_ZERO_RETURN", ret, error); - handle->state = tls_st_error; - return -1; - - case SSL_ERROR_WANT_READ: - LOG_PROTO("SSL_get_error: ret=%d, error=%d: SSL_ERROR_WANT_READ", ret, error); - con->flags |= NET_WANT_SSL_READ; - net_con_update(con, NET_EVENT_READ); - return 0; - - case SSL_ERROR_WANT_WRITE: - LOG_PROTO("SSL_get_error: ret=%d, error=%d: SSL_ERROR_WANT_WRITE", ret, error); - con->flags |= NET_WANT_SSL_WRITE; - net_con_update(con, NET_EVENT_READ | NET_EVENT_WRITE); - return 0; - - case SSL_ERROR_SYSCALL: - LOG_PROTO("SSL_get_error: ret=%d, error=%d: SSL_ERROR_SYSCALL", ret, error); - /* if ret == 0, connection closed, if ret == -1, check with errno */ - if (ret == 0) - return -1; - else - return -net_error(); - - case SSL_ERROR_SSL: - LOG_PROTO("SSL_get_error: ret=%d, error=%d: SSL_ERROR_SSL", ret, error); - /* internal openssl error */ - handle->state = tls_st_error; - return -1; - } - return -1; -} - ssize_t net_con_ssl_accept(struct net_connection* con) { struct net_ssl_openssl* handle = get_handle(con); @@ -164,7 +126,24 @@ ssize_t net_con_ssl_accept(struct net_connection* con) } else { - return handle_openssl_error(con, ret); + int err = SSL_get_error(handle->ssl, ret); + switch (err) + { + case SSL_ERROR_ZERO_RETURN: + // Not really an error, but SSL was shut down. + return -1; + + case SSL_ERROR_WANT_READ: + net_con_update(con, NET_EVENT_READ); + return 0; + + case SSL_ERROR_WANT_WRITE: + net_con_update(con, NET_EVENT_WRITE); + return 0; + + case SSL_ERROR_SYSCALL: + return -1; + } } return ret; } @@ -175,7 +154,6 @@ ssize_t net_con_ssl_connect(struct net_connection* con) ssize_t ret; handle->state = tls_st_connecting; - ret = SSL_connect(handle->ssl); #ifdef NETWORK_DUMP_DEBUG LOG_PROTO("SSL_connect() ret=%d", ret); @@ -188,7 +166,24 @@ ssize_t net_con_ssl_connect(struct net_connection* con) } else { - return handle_openssl_error(con, ret); + int err = SSL_get_error(handle->ssl, ret); + switch (err) + { + case SSL_ERROR_ZERO_RETURN: + // Not really an error, but SSL was shut down. + return -1; + + case SSL_ERROR_WANT_READ: + net_con_update(con, NET_EVENT_READ); + return 0; + + case SSL_ERROR_WANT_WRITE: + net_con_update(con, NET_EVENT_WRITE); + return 0; + + case SSL_ERROR_SYSCALL: + return -1; + } } return ret; } @@ -209,6 +204,8 @@ ssize_t net_con_ssl_handshake(struct net_connection* con, enum net_con_ssl_mode return -1; } SSL_set_fd(handle->ssl, con->sd); + handle->bior = SSL_get_rbio(handle->ssl); + handle->biow = SSL_get_wbio(handle->ssl); con->ssl = (struct ssl_handle*) handle; return net_con_ssl_accept(con); } @@ -216,6 +213,8 @@ ssize_t net_con_ssl_handshake(struct net_connection* con, enum net_con_ssl_mode { handle->ssl = SSL_new(SSL_CTX_new(TLSv1_method())); SSL_set_fd(handle->ssl, con->sd); + handle->bior = SSL_get_rbio(handle->ssl); + handle->biow = SSL_get_wbio(handle->ssl); con->ssl = (struct ssl_handle*) handle; return net_con_ssl_connect(con); } @@ -225,14 +224,40 @@ ssize_t net_con_ssl_handshake(struct net_connection* con, enum net_con_ssl_mode ssize_t net_ssl_send(struct net_connection* con, const void* buf, size_t len) { struct net_ssl_openssl* handle = get_handle(con); -// con->write_len = len; + + uhub_assert(handle->state == tls_st_connected || handle->state == tls_st_need_write); + + ERR_clear_error(); ssize_t ret = SSL_write(handle->ssl, buf, len); LOG_PROTO("SSL_write(con=%p, buf=%p, len=" PRINTF_SIZE_T ") => %d", con, buf, len, ret); - if (ret <= 0) + if (ret > 0) { - return handle_openssl_error(con, ret); + handle->state = tls_st_connected; + return ret; + } + else if (ret <= 0) + { + int err = SSL_get_error(handle->ssl, ret); + switch (err) + { + case SSL_ERROR_ZERO_RETURN: + // Not really an error, but SSL was shut down. + return -1; + + case SSL_ERROR_WANT_READ: + handle->state = tls_st_need_write; + net_con_update(con, NET_EVENT_READ); + return 0; + + case SSL_ERROR_WANT_WRITE: + handle->state = tls_st_need_write; + net_con_update(con, NET_EVENT_WRITE); + return 0; + + case SSL_ERROR_SYSCALL: + return -1; + } } - return ret; } ssize_t net_ssl_recv(struct net_connection* con, void* buf, size_t len) @@ -243,17 +268,40 @@ ssize_t net_ssl_recv(struct net_connection* con, void* buf, size_t len) if (handle->state == tls_st_error) return -1; + uhub_assert(handle->state == tls_st_connected || handle->state == tls_st_need_read); + + ERR_clear_error(); + ret = SSL_read(handle->ssl, buf, len); LOG_PROTO("SSL_read(con=%p, buf=%p, len=" PRINTF_SIZE_T ") => %d", con, buf, len, ret); if (ret > 0) { - net_con_update(con, NET_EVENT_READ); + handle->state = tls_st_connected; + return ret; } - else + else if (ret <= 0) { - return handle_openssl_error(con, ret); + int err = SSL_get_error(handle->ssl, ret); + switch (err) + { + case SSL_ERROR_ZERO_RETURN: + // Not really an error, but SSL was shut down. + return -1; + + case SSL_ERROR_WANT_READ: + handle->state = tls_st_need_read; + net_con_update(con, NET_EVENT_READ); + return 0; + + case SSL_ERROR_WANT_WRITE: + handle->state = tls_st_need_read; + net_con_update(con, NET_EVENT_WRITE); + return 0; + + case SSL_ERROR_SYSCALL: + return -1; + } } - return ret; } void net_ssl_shutdown(struct net_connection* con) @@ -299,20 +347,15 @@ void net_ssl_callback(struct net_connection* con, int events) con->callback(con, NET_EVENT_READ, con->ptr); break; + case tls_st_need_read: + con->callback(con, NET_EVENT_READ, con->ptr); + break; + + case tls_st_need_write: + con->callback(con, NET_EVENT_WRITE, con->ptr); + break; + case tls_st_connected: - LOG_PROTO("tls_st_connected, events=%s%s, ssl_flags=%s%s", (events & NET_EVENT_READ ? "R" : ""), (events & NET_EVENT_WRITE ? "W" : ""), flags & NET_WANT_SSL_READ ? "R" : "", flags & NET_WANT_SSL_WRITE ? "W" : ""); - if (events & NET_EVENT_WRITE && flags & NET_WANT_SSL_READ) - { - con->callback(con, events & NET_EVENT_READ, con->ptr); - return; - } - - if (events & NET_EVENT_READ && flags & NET_WANT_SSL_WRITE) - { - con->callback(con, events & NET_EVENT_READ, con->ptr); - return; - } - con->callback(con, events, con->ptr); break; diff --git a/src/network/tls.h b/src/network/tls.h index 08b4da6..25391a4 100644 --- a/src/network/tls.h +++ b/src/network/tls.h @@ -32,6 +32,8 @@ enum ssl_state tls_st_accepting, tls_st_connecting, tls_st_connected, + tls_st_need_read, /* special case of connected */ + tls_st_need_write, /* special case of connected */ tls_st_disconnecting, };