diff --git a/src/network/backend.c b/src/network/backend.c index 351938c..384b14e 100644 --- a/src/network/backend.c +++ b/src/network/backend.c @@ -102,14 +102,7 @@ void net_backend_shutdown() } -void net_con_reinitialize(struct net_connection* con, net_connection_cb callback, const void* ptr, int events) -{ - con->callback = callback; - con->ptr = (void*) ptr; - net_con_update(con, events); -} - -void net_con_update(struct net_connection* con, int events) +void net_backend_update(struct net_connection* con, int events) { g_backend->handler.con_mod(g_backend->data, con, events); } diff --git a/src/network/backend.h b/src/network/backend.h index 40c8d82..0448d69 100644 --- a/src/network/backend.h +++ b/src/network/backend.h @@ -75,6 +75,14 @@ extern void net_backend_shutdown(); */ extern int net_backend_process(); +/** + * Update the event mask. + * + * @param con Connection handle. + * @param events Event mask (NET_EVENT_*) + */ +extern void net_backend_update(struct net_connection* con, int events); + /** * Get the current time. */ diff --git a/src/network/connection.c b/src/network/connection.c index 5497f2c..4ef20d7 100644 --- a/src/network/connection.c +++ b/src/network/connection.c @@ -19,6 +19,7 @@ #include "uhub.h" #include "network/common.h" +#include "network/backend.h" static int is_blocked_or_interrupted() { @@ -116,6 +117,23 @@ void* net_con_get_ptr(struct net_connection* con) return con->ptr; } +void net_con_update(struct net_connection* con, int events) +{ +#ifdef SSL_SUPPORT + if (con->ssl) + net_ssl_update(con, events); + else +#endif + net_backend_update(con, events); +} + +void net_con_reinitialize(struct net_connection* con, net_connection_cb callback, const void* ptr, int events) +{ + con->callback = callback; + con->ptr = (void*) ptr; + net_con_update(con, events); +} + void net_con_destroy(struct net_connection* con) { #ifdef SSL_SUPPORT diff --git a/src/network/openssl.c b/src/network/openssl.c index d2ef61f..09efb48 100644 --- a/src/network/openssl.c +++ b/src/network/openssl.c @@ -20,6 +20,7 @@ #include "uhub.h" #include "network/common.h" #include "network/tls.h" +#include "network/backend.h" #ifdef SSL_SUPPORT #ifdef SSL_USE_OPENSSL @@ -32,6 +33,9 @@ struct net_ssl_openssl SSL* ssl; BIO* bio; enum ssl_state state; + int events; + int ssl_read_events; + int ssl_write_events; uint32_t flags; size_t bytes_rx; size_t bytes_tx; @@ -158,7 +162,7 @@ int ssl_check_private_key(struct ssl_context_handle* ctx_) return 1; } -static int handle_openssl_error(struct net_connection* con, int ret, enum ssl_state forced_rwstate) +static int handle_openssl_error(struct net_connection* con, int ret, int read) { struct net_ssl_openssl* handle = get_handle(con); int err = SSL_get_error(handle->ssl, ret); @@ -169,13 +173,17 @@ static int handle_openssl_error(struct net_connection* con, int ret, enum ssl_st return -1; case SSL_ERROR_WANT_READ: - handle->state = forced_rwstate; - net_con_update(con, NET_EVENT_READ); + if (read) + handle->ssl_read_events = NET_EVENT_READ; + else + handle->ssl_write_events = NET_EVENT_READ; return 0; case SSL_ERROR_WANT_WRITE: - handle->state = forced_rwstate; - net_con_update(con, NET_EVENT_WRITE); + if (read) + handle->ssl_read_events = NET_EVENT_WRITE; + else + handle->ssl_write_events = NET_EVENT_WRITE; return 0; case SSL_ERROR_SYSCALL: @@ -249,25 +257,25 @@ ssize_t net_con_ssl_handshake(struct net_connection* con, enum net_con_ssl_mode con->ssl = (struct ssl_handle*) handle; return net_con_ssl_connect(con); } - } ssize_t net_ssl_send(struct net_connection* con, const void* buf, size_t len) { struct net_ssl_openssl* handle = get_handle(con); - uhub_assert(handle->state == tls_st_connected || handle->state == tls_st_need_write); + uhub_assert(handle->state == tls_st_connected); ERR_clear_error(); ssize_t ret = SSL_write(handle->ssl, buf, len); add_io_stats(handle); LOG_PROTO("SSL_write(con=%p, buf=%p, len=" PRINTF_SIZE_T ") => %d", con, buf, len, ret); if (ret > 0) - { - handle->state = tls_st_connected; - return ret; - } - return handle_openssl_error(con, ret, tls_st_need_write); + handle->ssl_write_events = 0; + else + ret = handle_openssl_error(con, ret, 0); + + net_ssl_update(con, handle->events); // Update backend only + return ret; } ssize_t net_ssl_recv(struct net_connection* con, void* buf, size_t len) @@ -278,7 +286,7 @@ ssize_t net_ssl_recv(struct net_connection* con, void* buf, size_t len) if (handle->state == tls_st_error) return -2; - uhub_assert(handle->state == tls_st_connected || handle->state == tls_st_need_read); + uhub_assert(handle->state == tls_st_connected); ERR_clear_error(); @@ -286,11 +294,19 @@ ssize_t net_ssl_recv(struct net_connection* con, void* buf, size_t len) add_io_stats(handle); LOG_PROTO("SSL_read(con=%p, buf=%p, len=" PRINTF_SIZE_T ") => %d", con, buf, len, ret); if (ret > 0) - { - handle->state = tls_st_connected; - return ret; - } - return handle_openssl_error(con, ret, tls_st_need_read); + handle->ssl_read_events = 0; + else + ret = handle_openssl_error(con, ret, 1); + + net_ssl_update(con, handle->events); // Update backend only + return ret; +} + +void net_ssl_update(struct net_connection* con, int events) +{ + struct net_ssl_openssl* handle = get_handle(con); + handle->events = events; + net_backend_update(con, handle->events | handle->ssl_read_events | handle->ssl_write_events); } void net_ssl_shutdown(struct net_connection* con) @@ -331,15 +347,11 @@ void net_ssl_callback(struct net_connection* con, int events) con->callback(con, NET_EVENT_READ, con->ptr); break; - case tls_st_need_read: - con->callback(con, NET_EVENT_READ, con->ptr); - break; - - case tls_st_need_write: - con->callback(con, NET_EVENT_WRITE, con->ptr); - break; - case tls_st_connected: + if (handle->ssl_read_events & events) + events |= NET_EVENT_READ; + if (handle->ssl_write_events & events) + events |= NET_EVENT_WRITE; con->callback(con, events, con->ptr); break; diff --git a/src/network/tls.h b/src/network/tls.h index b309709..b1822d0 100644 --- a/src/network/tls.h +++ b/src/network/tls.h @@ -32,8 +32,6 @@ enum ssl_state tls_st_accepting, tls_st_connecting, tls_st_connected, - tls_st_need_read, /* special case of connected */ - tls_st_need_write, /* special case of connected */ tls_st_disconnecting, }; @@ -90,6 +88,15 @@ extern ssize_t net_con_ssl_connect(struct net_connection*); extern ssize_t net_ssl_send(struct net_connection* con, const void* buf, size_t len); extern ssize_t net_ssl_recv(struct net_connection* con, void* buf, size_t len); +/** + * Update the event mask. Additional events may be requested depending on the + * needs of the TLS layer. + * + * @param con Connection handle. + * @param events Event mask (NET_EVENT_*) + */ +extern void net_ssl_update(struct net_connection* con, int events); + extern void net_ssl_shutdown(struct net_connection* con); extern void net_ssl_destroy(struct net_connection* con); extern void net_ssl_callback(struct net_connection* con, int events);