diff --git a/src/core/netevent.c b/src/core/netevent.c index 8b90b24..0afe7cf 100644 --- a/src/core/netevent.c +++ b/src/core/netevent.c @@ -27,7 +27,7 @@ extern struct hub_info* g_hub; #ifdef DEBUG_SENDQ void debug_sendq_send(struct hub_user* user, int sent, int total) { - LOG_DUMP("SEND: sd=%d, %d/%d bytes\n", user->net.connection.sd, sent, total); + LOG_DUMP("SEND: sd=%d, %d/%d bytes\n", user->connection->sd, sent, total); if (sent == -1) { int err = net_error(); @@ -162,7 +162,7 @@ void net_event(struct net_connection* con, int event, void *arg) int flag_close = 0; #ifdef DEBUG_SENDQ - LOG_TRACE("net_event() : fd=%d, ev=%d, arg=%p", fd, (int) event, arg); + LOG_TRACE("net_event() : fd=%d, ev=%d, arg=%p", con->sd, (int) event, arg); #endif if (event == NET_EVENT_TIMEOUT) diff --git a/src/network/common.h b/src/network/common.h index 5fe7fa8..014ddfb 100644 --- a/src/network/common.h +++ b/src/network/common.h @@ -28,7 +28,6 @@ #define NET_CLEANUP 0x8000 - #define NET_CON_STRUCT_BASIC \ int sd; /** socket descriptor */ \ uint32_t flags; /** Connection flags */ \ @@ -38,6 +37,7 @@ #define NET_CON_STRUCT_SSL \ SSL* ssl; /** SSL handle */ \ + uint32_t ssl_state; /** SSL state */ \ size_t write_len; /** Length of last SSL_write(), only used if flags is NET_WANT_SSL_READ. */ \ #ifdef SSL_SUPPORT diff --git a/src/network/connection.c b/src/network/connection.c index 27c2363..59c3795 100644 --- a/src/network/connection.c +++ b/src/network/connection.c @@ -21,6 +21,17 @@ #include "network/common.h" #ifdef SSL_SUPPORT + +enum uhub_tls_state +{ + tls_st_none, + tls_st_error, + tls_st_accepting, + tls_st_connecting, + tls_st_connected, + tls_st_disconnecting, +}; + static int handle_openssl_error(struct net_connection* con, int ret) { uhub_assert(con); @@ -30,30 +41,19 @@ static int handle_openssl_error(struct net_connection* con, int ret) { case SSL_ERROR_ZERO_RETURN: LOG_PROTO("SSL_get_error: ret=%d, error=%d: SSL_ERROR_ZERO_RETURN", ret, error); + con->ssl_state = tls_st_error; return -1; case SSL_ERROR_WANT_READ: LOG_PROTO("SSL_get_error: ret=%d, error=%d: SSL_ERROR_WANT_READ", ret, error); - net_con_update(con, NET_EVENT_READ | NET_WANT_SSL_READ); + con->flags |= NET_WANT_SSL_READ; + net_con_update(con, NET_EVENT_READ); return 0; case SSL_ERROR_WANT_WRITE: LOG_PROTO("SSL_get_error: ret=%d, error=%d: SSL_ERROR_WANT_WRITE", ret, error); - net_con_update(con, NET_EVENT_READ | NET_EVENT_WRITE | NET_WANT_SSL_WRITE); - return 0; - - case SSL_ERROR_WANT_CONNECT: - LOG_PROTO("SSL_get_error: ret=%d, error=%d: SSL_ERROR_WANT_CONNECT", ret, error); - net_con_update(con, NET_EVENT_READ | NET_EVENT_WRITE | NET_WANT_SSL_CONNECT); - return 0; - - case SSL_ERROR_WANT_ACCEPT: - LOG_PROTO("SSL_get_error: ret=%d, error=%d: SSL_ERROR_WANT_ACCEPT", ret, error); - net_con_update(con, NET_EVENT_READ | NET_EVENT_WRITE | NET_WANT_SSL_ACCEPT); - return 0; - - case SSL_ERROR_WANT_X509_LOOKUP: - LOG_PROTO("SSL_get_error: ret=%d, error=%d: SSL_ERROR_WANT_X509_LOOKUP", ret, error); + con->flags |= NET_WANT_SSL_WRITE; + net_con_update(con, NET_EVENT_READ | NET_EVENT_WRITE); return 0; case SSL_ERROR_SYSCALL: @@ -67,6 +67,7 @@ static int handle_openssl_error(struct net_connection* con, int ret) case SSL_ERROR_SSL: LOG_PROTO("SSL_get_error: ret=%d, error=%d: SSL_ERROR_SSL", ret, error); /* internal openssl error */ + con->ssl_state = tls_st_error; return -1; } @@ -76,7 +77,7 @@ static int handle_openssl_error(struct net_connection* con, int ret) ssize_t net_con_ssl_accept(struct net_connection* con) { uhub_assert(con); - + con->ssl_state = tls_st_accepting; ssize_t ret = SSL_accept(con->ssl); #ifdef NETWORK_DUMP_DEBUG LOG_PROTO("SSL_accept() ret=%d", ret); @@ -84,6 +85,7 @@ ssize_t net_con_ssl_accept(struct net_connection* con) if (ret > 0) { net_con_update(con, NET_EVENT_READ); + con->ssl_state = tls_st_connected; } else { @@ -96,12 +98,14 @@ ssize_t net_con_ssl_connect(struct net_connection* con) { uhub_assert(con); + con->ssl_state = tls_st_connecting; ssize_t ret = SSL_connect(con->ssl); #ifdef NETWORK_DUMP_DEBUG LOG_PROTO("SSL_connect() ret=%d", ret); #endif if (ret > 0) { + con->ssl_state = tls_st_connected; net_con_update(con, NET_EVENT_READ); } else @@ -136,28 +140,42 @@ ssize_t net_con_ssl_handshake(struct net_connection* con, enum net_con_ssl_mode ssize_t net_con_send(struct net_connection* con, const void* buf, size_t len) { - int ret = net_send(con->sd, buf, len, UHUB_SEND_SIGNAL); - if (ret == -1) + int ret; +#ifdef SSL_SUPPORT + if (!con->ssl) { - if (net_error() == EWOULDBLOCK || net_error() == EINTR) - return 0; - return -1; +#endif + ret = net_send(con->sd, buf, len, UHUB_SEND_SIGNAL); + if (ret == -1) + { + if (net_error() == EWOULDBLOCK || net_error() == EINTR) + return 0; + return -1; + } +#ifdef SSL_SUPPORT } + else + { + con->write_len = len; + ret = SSL_write(con->ssl, buf, len); + LOG_PROTO("SSL_write(con=%p, buf=%p, len=" PRINTF_SIZE_T ") => %d", con, buf, len, ret); + if (ret <= 0) + { + return -handle_openssl_error(con, ret); + } + } +#endif return ret; } ssize_t net_con_recv(struct net_connection* con, void* buf, size_t len) { - uhub_assert(con); - + int ret; #ifdef SSL_SUPPORT if (!net_con_is_ssl(con)) { #endif - int ret = net_recv(con->sd, buf, len, 0); -#ifdef NETWORK_DUMP_DEBUG - LOG_PROTO("net_recv: ret=%d", ret); -#endif + ret = net_recv(con->sd, buf, len, 0); if (ret == -1) { if (net_error() == EWOULDBLOCK || net_error() == EINTR) @@ -168,16 +186,15 @@ ssize_t net_con_recv(struct net_connection* con, void* buf, size_t len) { return -1; } - - return ret; #ifdef SSL_SUPPORT } else { - int ret = SSL_read(con->ssl, buf, len); -#ifdef NETWORK_DUMP_DEBUG - LOG_PROTO("net_recv: ret=%d", ret); -#endif + if (con->ssl_state == tls_st_error) + return -1; + + ret = SSL_read(con->ssl, buf, len); + LOG_PROTO("SSL_read(con=%p, buf=%p, len=" PRINTF_SIZE_T ") => %d", con, buf, len, ret); if (ret > 0) { net_con_update(con, NET_EVENT_READ); @@ -186,9 +203,9 @@ ssize_t net_con_recv(struct net_connection* con, void* buf, size_t len) { return -handle_openssl_error(con, ret); } - return ret; } #endif + return ret; } ssize_t net_con_peek(struct net_connection* con, void* buf, size_t len) @@ -234,7 +251,7 @@ void* net_con_get_ptr(struct net_connection* con) void net_con_callback(struct net_connection* con, int events) { - if ((con->flags & NET_CLEANUP)) + if (con->flags & NET_CLEANUP) return; if (events == NET_EVENT_TIMEOUT) @@ -254,27 +271,47 @@ void net_con_callback(struct net_connection* con, int events) else { #ifdef NETWORK_DUMP_DEBUG - LOG_PROTO("net_con_event: events=%d, con=%p", ev, con); + LOG_PROTO("net_con_event: events=%d, con=%p, state=%d", events, con, con->ssl_state); #endif - if (events == NET_EVENT_READ && con->flags & NET_WANT_SSL_READ) + switch (con->ssl_state) { - con->callback(con, NET_EVENT_WRITE, con->ptr); - } - else if (events == NET_EVENT_WRITE && con->flags & NET_WANT_SSL_WRITE) - { - con->callback(con, events & NET_EVENT_READ, con->ptr); - } - if (con->flags & NET_WANT_SSL_ACCEPT) - { - net_con_ssl_accept(con); - } - else if (con->flags & NET_WANT_SSL_CONNECT) - { - net_con_ssl_connect(con); - } - else - { - con->callback(con, events, con->ptr); + case tls_st_none: + con->callback(con, events, con->ptr); + break; + + case tls_st_accepting: + if (net_con_ssl_accept(con) < 0) + { + con->callback(con, NET_EVENT_READ, con->ptr); + } + break; + + case tls_st_connecting: + if (net_con_ssl_connect(con) < 0) + { + con->callback(con, NET_EVENT_READ, con->ptr); + } + break; + + case tls_st_connected: + LOG_PROTO("tls_st_connected, events=%s%s, ssl_flags=%s%s", (events & NET_EVENT_READ ? "R" : ""), (events & NET_EVENT_WRITE ? "W" : ""), con->flags & NET_WANT_SSL_READ ? "R" : "", con->flags & NET_WANT_SSL_WRITE ? "W" : ""); + if (events & NET_EVENT_WRITE && con->flags & NET_WANT_SSL_READ) + { + con->callback(con, events & NET_EVENT_READ, con->ptr); + return; + } + + if (events & NET_EVENT_READ && con->flags & NET_WANT_SSL_WRITE) + { + con->callback(con, events & NET_EVENT_READ, con->ptr); + return; + } + + con->callback(con, events, con->ptr); + break; + + case tls_st_disconnecting: + return; } } #endif