Fix TLS protocol mismatch assert causing the hub to shutdown.

This commit is contained in:
Jan Vidar Krey 2014-10-16 23:08:17 +02:00
parent 1da917e5b9
commit 76ff2a1a13
3 changed files with 78 additions and 13 deletions

View File

@ -28,6 +28,7 @@
#define NET_EVENT_TIMEOUT 0x0001 #define NET_EVENT_TIMEOUT 0x0001
#define NET_EVENT_READ 0x0002 #define NET_EVENT_READ 0x0002
#define NET_EVENT_WRITE 0x0004 #define NET_EVENT_WRITE 0x0004
#define NET_EVENT_ERROR 0x1000
struct net_connection struct net_connection
{ {

View File

@ -52,6 +52,27 @@ static struct net_ssl_openssl* get_handle(struct net_connection* con)
return (struct net_ssl_openssl*) con->ssl; return (struct net_ssl_openssl*) con->ssl;
} }
static const char* get_state_str(enum ssl_state state)
{
switch (state)
{
case tls_st_none: return "tls_st_none";
case tls_st_error: return "tls_st_error";
case tls_st_accepting: return "tls_st_accepting";
case tls_st_connecting: return "tls_st_connecting";
case tls_st_connected: return "tls_st_connected";
case tls_st_disconnecting: return "tls_st_disconnecting";
}
uhub_assert(!"This should not happen - invalid state!");
return "(UNKNOWN STATE)";
}
static void net_ssl_set_state(struct net_ssl_openssl* handle, enum ssl_state new_state)
{
LOG_ERROR("net_ssl_set_state(): prev_state=%s, new_state=%s", get_state_str(handle->state), get_state_str(new_state));
handle->state = new_state;
}
const char* net_ssl_get_provider() const char* net_ssl_get_provider()
{ {
return OPENSSL_VERSION_TEXT; return OPENSSL_VERSION_TEXT;
@ -220,26 +241,32 @@ static int handle_openssl_error(struct net_connection* con, int ret, int read)
handle->ssl_write_events = NET_EVENT_WRITE; handle->ssl_write_events = NET_EVENT_WRITE;
return 0; return 0;
case SSL_ERROR_SSL:
net_ssl_set_state(handle, tls_st_error);
return -2;
case SSL_ERROR_SYSCALL: case SSL_ERROR_SYSCALL:
handle->state = tls_st_error; net_ssl_set_state(handle, tls_st_error);
return -2; return -2;
} }
LOG_ERROR("WTF?: err=%d, ret=%d", err, ret);
return -2; return -2;
} }
ssize_t net_con_ssl_accept(struct net_connection* con) ssize_t net_con_ssl_accept(struct net_connection* con)
{ {
struct net_ssl_openssl* handle = get_handle(con); struct net_ssl_openssl* handle = get_handle(con);
handle->state = tls_st_accepting;
ssize_t ret; ssize_t ret;
net_ssl_set_state(handle, tls_st_accepting);
ret = SSL_accept(handle->ssl); ret = SSL_accept(handle->ssl);
LOG_PROTO("SSL_accept() ret=%d", ret); LOG_PROTO("SSL_accept() ret=%d", ret);
if (ret > 0) if (ret > 0)
{ {
net_con_update(con, NET_EVENT_READ); net_con_update(con, NET_EVENT_READ);
handle->state = tls_st_connected; net_ssl_set_state(handle, tls_st_connected);
return ret; return ret;
} }
return handle_openssl_error(con, ret, tls_st_accepting); return handle_openssl_error(con, ret, tls_st_accepting);
@ -249,18 +276,21 @@ ssize_t net_con_ssl_connect(struct net_connection* con)
{ {
struct net_ssl_openssl* handle = get_handle(con); struct net_ssl_openssl* handle = get_handle(con);
ssize_t ret; ssize_t ret;
handle->state = tls_st_connecting; net_ssl_set_state(handle, tls_st_connecting);
ret = SSL_connect(handle->ssl); ret = SSL_connect(handle->ssl);
LOG_PROTO("SSL_connect() ret=%d", ret); LOG_PROTO("SSL_connect() ret=%d", ret);
if (ret > 0) if (ret > 0)
{ {
handle->state = tls_st_connected;
net_con_update(con, NET_EVENT_READ); net_con_update(con, NET_EVENT_READ);
net_ssl_set_state(handle, tls_st_connected);
return ret; return ret;
} }
return handle_openssl_error(con, ret, tls_st_connecting);
ret = handle_openssl_error(con, ret, tls_st_connecting);
LOG_ERROR("net_con_ssl_connect: ret=%d", ret);
return ret;
} }
ssize_t net_con_ssl_handshake(struct net_connection* con, enum net_con_ssl_mode ssl_mode, struct ssl_context_handle* ssl_ctx) ssize_t net_con_ssl_handshake(struct net_connection* con, enum net_con_ssl_mode ssl_mode, struct ssl_context_handle* ssl_ctx)
@ -298,8 +328,14 @@ ssize_t net_ssl_send(struct net_connection* con, const void* buf, size_t len)
{ {
struct net_ssl_openssl* handle = get_handle(con); struct net_ssl_openssl* handle = get_handle(con);
LOG_ERROR("net_ssl_send(), state=%d", (int) handle->state);
if (handle->state == tls_st_error)
return -2;
uhub_assert(handle->state == tls_st_connected); uhub_assert(handle->state == tls_st_connected);
ERR_clear_error(); ERR_clear_error();
ssize_t ret = SSL_write(handle->ssl, buf, len); ssize_t ret = SSL_write(handle->ssl, buf, len);
add_io_stats(handle); add_io_stats(handle);
@ -321,6 +357,9 @@ ssize_t net_ssl_recv(struct net_connection* con, void* buf, size_t len)
if (handle->state == tls_st_error) if (handle->state == tls_st_error)
return -2; return -2;
if (handle->state == tls_st_accepting || handle->state == tls_st_connecting)
return -1;
uhub_assert(handle->state == tls_st_connected); uhub_assert(handle->state == tls_st_connected);
ERR_clear_error(); ERR_clear_error();
@ -347,9 +386,12 @@ void net_ssl_update(struct net_connection* con, int events)
void net_ssl_shutdown(struct net_connection* con) void net_ssl_shutdown(struct net_connection* con)
{ {
struct net_ssl_openssl* handle = get_handle(con); struct net_ssl_openssl* handle = get_handle(con);
if (handle)
{
SSL_shutdown(handle->ssl); SSL_shutdown(handle->ssl);
SSL_clear(handle->ssl); SSL_clear(handle->ssl);
} }
}
void net_ssl_destroy(struct net_connection* con) void net_ssl_destroy(struct net_connection* con)
{ {
@ -362,6 +404,7 @@ void net_ssl_destroy(struct net_connection* con)
void net_ssl_callback(struct net_connection* con, int events) void net_ssl_callback(struct net_connection* con, int events)
{ {
struct net_ssl_openssl* handle = get_handle(con); struct net_ssl_openssl* handle = get_handle(con);
int ret;
switch (handle->state) switch (handle->state)
{ {
@ -370,7 +413,7 @@ void net_ssl_callback(struct net_connection* con, int events)
break; break;
case tls_st_error: case tls_st_error:
con->callback(con, NET_EVENT_READ, con->ptr); con->callback(con, NET_EVENT_ERROR, con->ptr);
break; break;
case tls_st_accepting: case tls_st_accepting:
@ -379,8 +422,20 @@ void net_ssl_callback(struct net_connection* con, int events)
break; break;
case tls_st_connecting: case tls_st_connecting:
if (net_con_ssl_connect(con) != 0) ret = net_con_ssl_connect(con);
if (ret == 0)
return;
if (ret > 0)
{
LOG_ERROR("%p SSL connected!", con);
con->callback(con, NET_EVENT_READ, con->ptr); con->callback(con, NET_EVENT_READ, con->ptr);
}
else
{
LOG_ERROR("%p SSL handshake failed!");
con->callback(con, NET_EVENT_ERROR, con->ptr);
}
break; break;
case tls_st_connected: case tls_st_connected:

View File

@ -190,7 +190,16 @@ static void event_callback(struct net_connection* con, int events, void *arg)
return; return;
} }
if (events == NET_EVENT_ERROR)
{
ADC_client_on_disconnected(client);
client->callback(client, ADC_CLIENT_DISCONNECTED, 0);
return;
}
else
{
ADC_client_on_connected_ssl(client); ADC_client_on_connected_ssl(client);
}
break; break;
#endif #endif