Reworked the error handling of non-blocking reads and writes.

This should prevent busy loops where the socket is makred readable
but we are really only looking for it to become writable.
This commit is contained in:
Jan Vidar Krey 2012-10-16 20:15:38 +02:00
parent 50912bdf75
commit b1f2c93738
5 changed files with 130 additions and 95 deletions

View File

@ -1,6 +1,6 @@
/* /*
* uhub - A tiny ADC p2p connection hub * uhub - A tiny ADC p2p connection hub
* Copyright (C) 2007-2010, Jan Vidar Krey * Copyright (C) 2007-2012, Jan Vidar Krey
* *
* This program is free software; you can redistribute it and/or modify * This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by * it under the terms of the GNU General Public License as published by
@ -180,10 +180,8 @@ void net_con_close(struct net_connection* con)
g_backend->handler.con_del(g_backend->data, con); g_backend->handler.con_del(g_backend->data, con);
#ifdef SSL_SUPPORT #ifdef SSL_SUPPORT
#ifdef SSL_USE_OPENSSL
if (con->ssl) if (con->ssl)
net_ssl_shutdown(con); net_ssl_shutdown(con);
#endif /* SSL_USE_OPENSSL */
#endif /* SSL_SUPPORT */ #endif /* SSL_SUPPORT */
net_close(con->sd); net_close(con->sd);

View File

@ -1,6 +1,6 @@
/* /*
* uhub - A tiny ADC p2p connection hub * uhub - A tiny ADC p2p connection hub
* Copyright (C) 2007-2010, Jan Vidar Krey * Copyright (C) 2007-2012, Jan Vidar Krey
* *
* This program is free software; you can redistribute it and/or modify * This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by * it under the terms of the GNU General Public License as published by

View File

@ -25,6 +25,18 @@ void net_stats_add_tx(size_t bytes);
void net_stats_add_rx(size_t bytes); void net_stats_add_rx(size_t bytes);
#endif #endif
static int is_blocked_or_interrupted()
{
int err = net_error();
return
#ifdef WINSOCK
err == WSAEWOULDBLOCK
#else
err == EWOULDBLOCK
#endif
|| err == EINTR;
}
ssize_t net_con_send(struct net_connection* con, const void* buf, size_t len) ssize_t net_con_send(struct net_connection* con, const void* buf, size_t len)
{ {
int ret; int ret;
@ -35,13 +47,7 @@ ssize_t net_con_send(struct net_connection* con, const void* buf, size_t len)
ret = net_send(con->sd, buf, len, UHUB_SEND_SIGNAL); ret = net_send(con->sd, buf, len, UHUB_SEND_SIGNAL);
if (ret == -1) if (ret == -1)
{ {
if ( if (is_blocked_or_interrupted())
#ifdef WINSOCK
net_error() == WSAEWOULDBLOCK
#else
net_error() == EWOULDBLOCK
#endif
|| net_error() == EINTR)
return 0; return 0;
return -1; return -1;
} }
@ -65,13 +71,7 @@ ssize_t net_con_recv(struct net_connection* con, void* buf, size_t len)
ret = net_recv(con->sd, buf, len, 0); ret = net_recv(con->sd, buf, len, 0);
if (ret == -1) if (ret == -1)
{ {
if ( if (is_blocked_or_interrupted())
#ifdef WINSOCK
net_error() == WSAEWOULDBLOCK
#else
net_error() == EWOULDBLOCK
#endif
|| net_error() == EINTR)
return 0; return 0;
return -net_error(); return -net_error();
} }
@ -94,13 +94,7 @@ ssize_t net_con_peek(struct net_connection* con, void* buf, size_t len)
int ret = net_recv(con->sd, buf, len, MSG_PEEK); int ret = net_recv(con->sd, buf, len, MSG_PEEK);
if (ret == -1) if (ret == -1)
{ {
if ( if (is_blocked_or_interrupted())
#ifdef WINSOCK
net_error() == WSAEWOULDBLOCK
#else
net_error() == EWOULDBLOCK
#endif
|| net_error() == EINTR)
return 0; return 0;
return -net_error(); return -net_error();
} }
@ -149,12 +143,10 @@ void net_con_callback(struct net_connection* con, int events)
} }
#ifdef SSL_SUPPORT #ifdef SSL_SUPPORT
if (!con->ssl) if (con->ssl)
net_ssl_callback(con, events);
else
#endif #endif
con->callback(con, events, con->ptr); con->callback(con, events, con->ptr);
#ifdef SSL_SUPPORT
else
net_ssl_callback(con, events);
#endif
} }

View File

@ -30,6 +30,9 @@ struct net_ssl_openssl
{ {
SSL* ssl; SSL* ssl;
enum ssl_state state; enum ssl_state state;
uint32_t flags;
BIO* biow;
BIO* bior;
}; };
struct net_context_openssl struct net_context_openssl
@ -108,47 +111,6 @@ int ssl_check_private_key(struct ssl_context_handle* ctx_)
return 1; return 1;
} }
static int handle_openssl_error(struct net_connection* con, int ret)
{
struct net_ssl_openssl* handle = get_handle(con);
int error = SSL_get_error(handle->ssl, ret);
switch (error)
{
case SSL_ERROR_ZERO_RETURN:
LOG_PROTO("SSL_get_error: ret=%d, error=%d: SSL_ERROR_ZERO_RETURN", ret, error);
handle->state = tls_st_error;
return -1;
case SSL_ERROR_WANT_READ:
LOG_PROTO("SSL_get_error: ret=%d, error=%d: SSL_ERROR_WANT_READ", ret, error);
con->flags |= NET_WANT_SSL_READ;
net_con_update(con, NET_EVENT_READ);
return 0;
case SSL_ERROR_WANT_WRITE:
LOG_PROTO("SSL_get_error: ret=%d, error=%d: SSL_ERROR_WANT_WRITE", ret, error);
con->flags |= NET_WANT_SSL_WRITE;
net_con_update(con, NET_EVENT_READ | NET_EVENT_WRITE);
return 0;
case SSL_ERROR_SYSCALL:
LOG_PROTO("SSL_get_error: ret=%d, error=%d: SSL_ERROR_SYSCALL", ret, error);
/* if ret == 0, connection closed, if ret == -1, check with errno */
if (ret == 0)
return -1;
else
return -net_error();
case SSL_ERROR_SSL:
LOG_PROTO("SSL_get_error: ret=%d, error=%d: SSL_ERROR_SSL", ret, error);
/* internal openssl error */
handle->state = tls_st_error;
return -1;
}
return -1;
}
ssize_t net_con_ssl_accept(struct net_connection* con) ssize_t net_con_ssl_accept(struct net_connection* con)
{ {
struct net_ssl_openssl* handle = get_handle(con); struct net_ssl_openssl* handle = get_handle(con);
@ -164,7 +126,24 @@ ssize_t net_con_ssl_accept(struct net_connection* con)
} }
else else
{ {
return handle_openssl_error(con, ret); int err = SSL_get_error(handle->ssl, ret);
switch (err)
{
case SSL_ERROR_ZERO_RETURN:
// Not really an error, but SSL was shut down.
return -1;
case SSL_ERROR_WANT_READ:
net_con_update(con, NET_EVENT_READ);
return 0;
case SSL_ERROR_WANT_WRITE:
net_con_update(con, NET_EVENT_WRITE);
return 0;
case SSL_ERROR_SYSCALL:
return -1;
}
} }
return ret; return ret;
} }
@ -175,7 +154,6 @@ ssize_t net_con_ssl_connect(struct net_connection* con)
ssize_t ret; ssize_t ret;
handle->state = tls_st_connecting; handle->state = tls_st_connecting;
ret = SSL_connect(handle->ssl); ret = SSL_connect(handle->ssl);
#ifdef NETWORK_DUMP_DEBUG #ifdef NETWORK_DUMP_DEBUG
LOG_PROTO("SSL_connect() ret=%d", ret); LOG_PROTO("SSL_connect() ret=%d", ret);
@ -188,7 +166,24 @@ ssize_t net_con_ssl_connect(struct net_connection* con)
} }
else else
{ {
return handle_openssl_error(con, ret); int err = SSL_get_error(handle->ssl, ret);
switch (err)
{
case SSL_ERROR_ZERO_RETURN:
// Not really an error, but SSL was shut down.
return -1;
case SSL_ERROR_WANT_READ:
net_con_update(con, NET_EVENT_READ);
return 0;
case SSL_ERROR_WANT_WRITE:
net_con_update(con, NET_EVENT_WRITE);
return 0;
case SSL_ERROR_SYSCALL:
return -1;
}
} }
return ret; return ret;
} }
@ -209,6 +204,8 @@ ssize_t net_con_ssl_handshake(struct net_connection* con, enum net_con_ssl_mode
return -1; return -1;
} }
SSL_set_fd(handle->ssl, con->sd); SSL_set_fd(handle->ssl, con->sd);
handle->bior = SSL_get_rbio(handle->ssl);
handle->biow = SSL_get_wbio(handle->ssl);
con->ssl = (struct ssl_handle*) handle; con->ssl = (struct ssl_handle*) handle;
return net_con_ssl_accept(con); return net_con_ssl_accept(con);
} }
@ -216,6 +213,8 @@ ssize_t net_con_ssl_handshake(struct net_connection* con, enum net_con_ssl_mode
{ {
handle->ssl = SSL_new(SSL_CTX_new(TLSv1_method())); handle->ssl = SSL_new(SSL_CTX_new(TLSv1_method()));
SSL_set_fd(handle->ssl, con->sd); SSL_set_fd(handle->ssl, con->sd);
handle->bior = SSL_get_rbio(handle->ssl);
handle->biow = SSL_get_wbio(handle->ssl);
con->ssl = (struct ssl_handle*) handle; con->ssl = (struct ssl_handle*) handle;
return net_con_ssl_connect(con); return net_con_ssl_connect(con);
} }
@ -225,14 +224,40 @@ ssize_t net_con_ssl_handshake(struct net_connection* con, enum net_con_ssl_mode
ssize_t net_ssl_send(struct net_connection* con, const void* buf, size_t len) ssize_t net_ssl_send(struct net_connection* con, const void* buf, size_t len)
{ {
struct net_ssl_openssl* handle = get_handle(con); struct net_ssl_openssl* handle = get_handle(con);
// con->write_len = len;
uhub_assert(handle->state == tls_st_connected || handle->state == tls_st_need_write);
ERR_clear_error();
ssize_t ret = SSL_write(handle->ssl, buf, len); ssize_t ret = SSL_write(handle->ssl, buf, len);
LOG_PROTO("SSL_write(con=%p, buf=%p, len=" PRINTF_SIZE_T ") => %d", con, buf, len, ret); LOG_PROTO("SSL_write(con=%p, buf=%p, len=" PRINTF_SIZE_T ") => %d", con, buf, len, ret);
if (ret <= 0) if (ret > 0)
{ {
return handle_openssl_error(con, ret); handle->state = tls_st_connected;
}
return ret; return ret;
}
else if (ret <= 0)
{
int err = SSL_get_error(handle->ssl, ret);
switch (err)
{
case SSL_ERROR_ZERO_RETURN:
// Not really an error, but SSL was shut down.
return -1;
case SSL_ERROR_WANT_READ:
handle->state = tls_st_need_write;
net_con_update(con, NET_EVENT_READ);
return 0;
case SSL_ERROR_WANT_WRITE:
handle->state = tls_st_need_write;
net_con_update(con, NET_EVENT_WRITE);
return 0;
case SSL_ERROR_SYSCALL:
return -1;
}
}
} }
ssize_t net_ssl_recv(struct net_connection* con, void* buf, size_t len) ssize_t net_ssl_recv(struct net_connection* con, void* buf, size_t len)
@ -243,17 +268,40 @@ ssize_t net_ssl_recv(struct net_connection* con, void* buf, size_t len)
if (handle->state == tls_st_error) if (handle->state == tls_st_error)
return -1; return -1;
uhub_assert(handle->state == tls_st_connected || handle->state == tls_st_need_read);
ERR_clear_error();
ret = SSL_read(handle->ssl, buf, len); ret = SSL_read(handle->ssl, buf, len);
LOG_PROTO("SSL_read(con=%p, buf=%p, len=" PRINTF_SIZE_T ") => %d", con, buf, len, ret); LOG_PROTO("SSL_read(con=%p, buf=%p, len=" PRINTF_SIZE_T ") => %d", con, buf, len, ret);
if (ret > 0) if (ret > 0)
{ {
net_con_update(con, NET_EVENT_READ); handle->state = tls_st_connected;
}
else
{
return handle_openssl_error(con, ret);
}
return ret; return ret;
}
else if (ret <= 0)
{
int err = SSL_get_error(handle->ssl, ret);
switch (err)
{
case SSL_ERROR_ZERO_RETURN:
// Not really an error, but SSL was shut down.
return -1;
case SSL_ERROR_WANT_READ:
handle->state = tls_st_need_read;
net_con_update(con, NET_EVENT_READ);
return 0;
case SSL_ERROR_WANT_WRITE:
handle->state = tls_st_need_read;
net_con_update(con, NET_EVENT_WRITE);
return 0;
case SSL_ERROR_SYSCALL:
return -1;
}
}
} }
void net_ssl_shutdown(struct net_connection* con) void net_ssl_shutdown(struct net_connection* con)
@ -299,20 +347,15 @@ void net_ssl_callback(struct net_connection* con, int events)
con->callback(con, NET_EVENT_READ, con->ptr); con->callback(con, NET_EVENT_READ, con->ptr);
break; break;
case tls_st_need_read:
con->callback(con, NET_EVENT_READ, con->ptr);
break;
case tls_st_need_write:
con->callback(con, NET_EVENT_WRITE, con->ptr);
break;
case tls_st_connected: case tls_st_connected:
LOG_PROTO("tls_st_connected, events=%s%s, ssl_flags=%s%s", (events & NET_EVENT_READ ? "R" : ""), (events & NET_EVENT_WRITE ? "W" : ""), flags & NET_WANT_SSL_READ ? "R" : "", flags & NET_WANT_SSL_WRITE ? "W" : "");
if (events & NET_EVENT_WRITE && flags & NET_WANT_SSL_READ)
{
con->callback(con, events & NET_EVENT_READ, con->ptr);
return;
}
if (events & NET_EVENT_READ && flags & NET_WANT_SSL_WRITE)
{
con->callback(con, events & NET_EVENT_READ, con->ptr);
return;
}
con->callback(con, events, con->ptr); con->callback(con, events, con->ptr);
break; break;

View File

@ -32,6 +32,8 @@ enum ssl_state
tls_st_accepting, tls_st_accepting,
tls_st_connecting, tls_st_connecting,
tls_st_connected, tls_st_connected,
tls_st_need_read, /* special case of connected */
tls_st_need_write, /* special case of connected */
tls_st_disconnecting, tls_st_disconnecting,
}; };