diff --git a/src/network/connection.h b/src/network/connection.h index e151d0c..490e636 100644 --- a/src/network/connection.h +++ b/src/network/connection.h @@ -28,6 +28,7 @@ #define NET_EVENT_TIMEOUT 0x0001 #define NET_EVENT_READ 0x0002 #define NET_EVENT_WRITE 0x0004 +#define NET_EVENT_ERROR 0x1000 struct net_connection { diff --git a/src/network/openssl.c b/src/network/openssl.c index affa715..d79552b 100644 --- a/src/network/openssl.c +++ b/src/network/openssl.c @@ -52,6 +52,27 @@ static struct net_ssl_openssl* get_handle(struct net_connection* con) return (struct net_ssl_openssl*) con->ssl; } +static const char* get_state_str(enum ssl_state state) +{ + switch (state) + { + case tls_st_none: return "tls_st_none"; + case tls_st_error: return "tls_st_error"; + case tls_st_accepting: return "tls_st_accepting"; + case tls_st_connecting: return "tls_st_connecting"; + case tls_st_connected: return "tls_st_connected"; + case tls_st_disconnecting: return "tls_st_disconnecting"; + } + uhub_assert(!"This should not happen - invalid state!"); + return "(UNKNOWN STATE)"; +} + +static void net_ssl_set_state(struct net_ssl_openssl* handle, enum ssl_state new_state) +{ + LOG_ERROR("net_ssl_set_state(): prev_state=%s, new_state=%s", get_state_str(handle->state), get_state_str(new_state)); + handle->state = new_state; +} + const char* net_ssl_get_provider() { return OPENSSL_VERSION_TEXT; @@ -73,9 +94,9 @@ int net_ssl_library_shutdown() ENGINE_cleanup(); CONF_modules_unload(1); - ERR_free_strings(); + ERR_free_strings(); EVP_cleanup(); - CRYPTO_cleanup_all_ex_data(); + CRYPTO_cleanup_all_ex_data(); // sk_SSL_COMP_free(SSL_COMP_get_compression_methods()); return 1; @@ -220,26 +241,32 @@ static int handle_openssl_error(struct net_connection* con, int ret, int read) handle->ssl_write_events = NET_EVENT_WRITE; return 0; + case SSL_ERROR_SSL: + net_ssl_set_state(handle, tls_st_error); + return -2; + case SSL_ERROR_SYSCALL: - handle->state = tls_st_error; + net_ssl_set_state(handle, tls_st_error); return -2; } + LOG_ERROR("WTF?: err=%d, ret=%d", err, ret); + return -2; } ssize_t net_con_ssl_accept(struct net_connection* con) { struct net_ssl_openssl* handle = get_handle(con); - handle->state = tls_st_accepting; ssize_t ret; + net_ssl_set_state(handle, tls_st_accepting); ret = SSL_accept(handle->ssl); LOG_PROTO("SSL_accept() ret=%d", ret); if (ret > 0) { net_con_update(con, NET_EVENT_READ); - handle->state = tls_st_connected; + net_ssl_set_state(handle, tls_st_connected); return ret; } return handle_openssl_error(con, ret, tls_st_accepting); @@ -249,18 +276,21 @@ ssize_t net_con_ssl_connect(struct net_connection* con) { struct net_ssl_openssl* handle = get_handle(con); ssize_t ret; - handle->state = tls_st_connecting; + net_ssl_set_state(handle, tls_st_connecting); ret = SSL_connect(handle->ssl); LOG_PROTO("SSL_connect() ret=%d", ret); if (ret > 0) { - handle->state = tls_st_connected; net_con_update(con, NET_EVENT_READ); + net_ssl_set_state(handle, tls_st_connected); return ret; } - return handle_openssl_error(con, ret, tls_st_connecting); + + ret = handle_openssl_error(con, ret, tls_st_connecting); + LOG_ERROR("net_con_ssl_connect: ret=%d", ret); + return ret; } ssize_t net_con_ssl_handshake(struct net_connection* con, enum net_con_ssl_mode ssl_mode, struct ssl_context_handle* ssl_ctx) @@ -298,7 +328,13 @@ ssize_t net_ssl_send(struct net_connection* con, const void* buf, size_t len) { struct net_ssl_openssl* handle = get_handle(con); + LOG_ERROR("net_ssl_send(), state=%d", (int) handle->state); + + if (handle->state == tls_st_error) + return -2; + uhub_assert(handle->state == tls_st_connected); + ERR_clear_error(); ssize_t ret = SSL_write(handle->ssl, buf, len); @@ -321,6 +357,9 @@ ssize_t net_ssl_recv(struct net_connection* con, void* buf, size_t len) if (handle->state == tls_st_error) return -2; + if (handle->state == tls_st_accepting || handle->state == tls_st_connecting) + return -1; + uhub_assert(handle->state == tls_st_connected); ERR_clear_error(); @@ -347,8 +386,11 @@ void net_ssl_update(struct net_connection* con, int events) void net_ssl_shutdown(struct net_connection* con) { struct net_ssl_openssl* handle = get_handle(con); - SSL_shutdown(handle->ssl); - SSL_clear(handle->ssl); + if (handle) + { + SSL_shutdown(handle->ssl); + SSL_clear(handle->ssl); + } } void net_ssl_destroy(struct net_connection* con) @@ -362,6 +404,7 @@ void net_ssl_destroy(struct net_connection* con) void net_ssl_callback(struct net_connection* con, int events) { struct net_ssl_openssl* handle = get_handle(con); + int ret; switch (handle->state) { @@ -370,7 +413,7 @@ void net_ssl_callback(struct net_connection* con, int events) break; case tls_st_error: - con->callback(con, NET_EVENT_READ, con->ptr); + con->callback(con, NET_EVENT_ERROR, con->ptr); break; case tls_st_accepting: @@ -379,8 +422,20 @@ void net_ssl_callback(struct net_connection* con, int events) break; case tls_st_connecting: - if (net_con_ssl_connect(con) != 0) + ret = net_con_ssl_connect(con); + if (ret == 0) + return; + + if (ret > 0) + { + LOG_ERROR("%p SSL connected!", con); con->callback(con, NET_EVENT_READ, con->ptr); + } + else + { + LOG_ERROR("%p SSL handshake failed!"); + con->callback(con, NET_EVENT_ERROR, con->ptr); + } break; case tls_st_connected: diff --git a/src/tools/adcclient.c b/src/tools/adcclient.c index 437cd3b..efd315c 100644 --- a/src/tools/adcclient.c +++ b/src/tools/adcclient.c @@ -190,7 +190,16 @@ static void event_callback(struct net_connection* con, int events, void *arg) return; } - ADC_client_on_connected_ssl(client); + if (events == NET_EVENT_ERROR) + { + ADC_client_on_disconnected(client); + client->callback(client, ADC_CLIENT_DISCONNECTED, 0); + return; + } + else + { + ADC_client_on_connected_ssl(client); + } break; #endif