SSL fixes, seems to work with stunnel4 as a client but not linuxdcpp using SSL directly.

This commit is contained in:
Jan Vidar Krey 2010-01-22 23:40:41 +01:00
parent 4aa65733d0
commit 84bd2591d6
3 changed files with 95 additions and 58 deletions

View File

@ -27,7 +27,7 @@ extern struct hub_info* g_hub;
#ifdef DEBUG_SENDQ
void debug_sendq_send(struct hub_user* user, int sent, int total)
{
LOG_DUMP("SEND: sd=%d, %d/%d bytes\n", user->net.connection.sd, sent, total);
LOG_DUMP("SEND: sd=%d, %d/%d bytes\n", user->connection->sd, sent, total);
if (sent == -1)
{
int err = net_error();
@ -162,7 +162,7 @@ void net_event(struct net_connection* con, int event, void *arg)
int flag_close = 0;
#ifdef DEBUG_SENDQ
LOG_TRACE("net_event() : fd=%d, ev=%d, arg=%p", fd, (int) event, arg);
LOG_TRACE("net_event() : fd=%d, ev=%d, arg=%p", con->sd, (int) event, arg);
#endif
if (event == NET_EVENT_TIMEOUT)

View File

@ -28,7 +28,6 @@
#define NET_CLEANUP 0x8000
#define NET_CON_STRUCT_BASIC \
int sd; /** socket descriptor */ \
uint32_t flags; /** Connection flags */ \
@ -38,6 +37,7 @@
#define NET_CON_STRUCT_SSL \
SSL* ssl; /** SSL handle */ \
uint32_t ssl_state; /** SSL state */ \
size_t write_len; /** Length of last SSL_write(), only used if flags is NET_WANT_SSL_READ. */ \
#ifdef SSL_SUPPORT

View File

@ -21,6 +21,17 @@
#include "network/common.h"
#ifdef SSL_SUPPORT
enum uhub_tls_state
{
tls_st_none,
tls_st_error,
tls_st_accepting,
tls_st_connecting,
tls_st_connected,
tls_st_disconnecting,
};
static int handle_openssl_error(struct net_connection* con, int ret)
{
uhub_assert(con);
@ -30,30 +41,19 @@ static int handle_openssl_error(struct net_connection* con, int ret)
{
case SSL_ERROR_ZERO_RETURN:
LOG_PROTO("SSL_get_error: ret=%d, error=%d: SSL_ERROR_ZERO_RETURN", ret, error);
con->ssl_state = tls_st_error;
return -1;
case SSL_ERROR_WANT_READ:
LOG_PROTO("SSL_get_error: ret=%d, error=%d: SSL_ERROR_WANT_READ", ret, error);
net_con_update(con, NET_EVENT_READ | NET_WANT_SSL_READ);
con->flags |= NET_WANT_SSL_READ;
net_con_update(con, NET_EVENT_READ);
return 0;
case SSL_ERROR_WANT_WRITE:
LOG_PROTO("SSL_get_error: ret=%d, error=%d: SSL_ERROR_WANT_WRITE", ret, error);
net_con_update(con, NET_EVENT_READ | NET_EVENT_WRITE | NET_WANT_SSL_WRITE);
return 0;
case SSL_ERROR_WANT_CONNECT:
LOG_PROTO("SSL_get_error: ret=%d, error=%d: SSL_ERROR_WANT_CONNECT", ret, error);
net_con_update(con, NET_EVENT_READ | NET_EVENT_WRITE | NET_WANT_SSL_CONNECT);
return 0;
case SSL_ERROR_WANT_ACCEPT:
LOG_PROTO("SSL_get_error: ret=%d, error=%d: SSL_ERROR_WANT_ACCEPT", ret, error);
net_con_update(con, NET_EVENT_READ | NET_EVENT_WRITE | NET_WANT_SSL_ACCEPT);
return 0;
case SSL_ERROR_WANT_X509_LOOKUP:
LOG_PROTO("SSL_get_error: ret=%d, error=%d: SSL_ERROR_WANT_X509_LOOKUP", ret, error);
con->flags |= NET_WANT_SSL_WRITE;
net_con_update(con, NET_EVENT_READ | NET_EVENT_WRITE);
return 0;
case SSL_ERROR_SYSCALL:
@ -67,6 +67,7 @@ static int handle_openssl_error(struct net_connection* con, int ret)
case SSL_ERROR_SSL:
LOG_PROTO("SSL_get_error: ret=%d, error=%d: SSL_ERROR_SSL", ret, error);
/* internal openssl error */
con->ssl_state = tls_st_error;
return -1;
}
@ -76,7 +77,7 @@ static int handle_openssl_error(struct net_connection* con, int ret)
ssize_t net_con_ssl_accept(struct net_connection* con)
{
uhub_assert(con);
con->ssl_state = tls_st_accepting;
ssize_t ret = SSL_accept(con->ssl);
#ifdef NETWORK_DUMP_DEBUG
LOG_PROTO("SSL_accept() ret=%d", ret);
@ -84,6 +85,7 @@ ssize_t net_con_ssl_accept(struct net_connection* con)
if (ret > 0)
{
net_con_update(con, NET_EVENT_READ);
con->ssl_state = tls_st_connected;
}
else
{
@ -96,12 +98,14 @@ ssize_t net_con_ssl_connect(struct net_connection* con)
{
uhub_assert(con);
con->ssl_state = tls_st_connecting;
ssize_t ret = SSL_connect(con->ssl);
#ifdef NETWORK_DUMP_DEBUG
LOG_PROTO("SSL_connect() ret=%d", ret);
#endif
if (ret > 0)
{
con->ssl_state = tls_st_connected;
net_con_update(con, NET_EVENT_READ);
}
else
@ -136,28 +140,42 @@ ssize_t net_con_ssl_handshake(struct net_connection* con, enum net_con_ssl_mode
ssize_t net_con_send(struct net_connection* con, const void* buf, size_t len)
{
int ret = net_send(con->sd, buf, len, UHUB_SEND_SIGNAL);
if (ret == -1)
int ret;
#ifdef SSL_SUPPORT
if (!con->ssl)
{
if (net_error() == EWOULDBLOCK || net_error() == EINTR)
return 0;
return -1;
#endif
ret = net_send(con->sd, buf, len, UHUB_SEND_SIGNAL);
if (ret == -1)
{
if (net_error() == EWOULDBLOCK || net_error() == EINTR)
return 0;
return -1;
}
#ifdef SSL_SUPPORT
}
else
{
con->write_len = len;
ret = SSL_write(con->ssl, buf, len);
LOG_PROTO("SSL_write(con=%p, buf=%p, len=" PRINTF_SIZE_T ") => %d", con, buf, len, ret);
if (ret <= 0)
{
return -handle_openssl_error(con, ret);
}
}
#endif
return ret;
}
ssize_t net_con_recv(struct net_connection* con, void* buf, size_t len)
{
uhub_assert(con);
int ret;
#ifdef SSL_SUPPORT
if (!net_con_is_ssl(con))
{
#endif
int ret = net_recv(con->sd, buf, len, 0);
#ifdef NETWORK_DUMP_DEBUG
LOG_PROTO("net_recv: ret=%d", ret);
#endif
ret = net_recv(con->sd, buf, len, 0);
if (ret == -1)
{
if (net_error() == EWOULDBLOCK || net_error() == EINTR)
@ -168,16 +186,15 @@ ssize_t net_con_recv(struct net_connection* con, void* buf, size_t len)
{
return -1;
}
return ret;
#ifdef SSL_SUPPORT
}
else
{
int ret = SSL_read(con->ssl, buf, len);
#ifdef NETWORK_DUMP_DEBUG
LOG_PROTO("net_recv: ret=%d", ret);
#endif
if (con->ssl_state == tls_st_error)
return -1;
ret = SSL_read(con->ssl, buf, len);
LOG_PROTO("SSL_read(con=%p, buf=%p, len=" PRINTF_SIZE_T ") => %d", con, buf, len, ret);
if (ret > 0)
{
net_con_update(con, NET_EVENT_READ);
@ -186,9 +203,9 @@ ssize_t net_con_recv(struct net_connection* con, void* buf, size_t len)
{
return -handle_openssl_error(con, ret);
}
return ret;
}
#endif
return ret;
}
ssize_t net_con_peek(struct net_connection* con, void* buf, size_t len)
@ -234,7 +251,7 @@ void* net_con_get_ptr(struct net_connection* con)
void net_con_callback(struct net_connection* con, int events)
{
if ((con->flags & NET_CLEANUP))
if (con->flags & NET_CLEANUP)
return;
if (events == NET_EVENT_TIMEOUT)
@ -254,27 +271,47 @@ void net_con_callback(struct net_connection* con, int events)
else
{
#ifdef NETWORK_DUMP_DEBUG
LOG_PROTO("net_con_event: events=%d, con=%p", ev, con);
LOG_PROTO("net_con_event: events=%d, con=%p, state=%d", events, con, con->ssl_state);
#endif
if (events == NET_EVENT_READ && con->flags & NET_WANT_SSL_READ)
switch (con->ssl_state)
{
con->callback(con, NET_EVENT_WRITE, con->ptr);
}
else if (events == NET_EVENT_WRITE && con->flags & NET_WANT_SSL_WRITE)
{
con->callback(con, events & NET_EVENT_READ, con->ptr);
}
if (con->flags & NET_WANT_SSL_ACCEPT)
{
net_con_ssl_accept(con);
}
else if (con->flags & NET_WANT_SSL_CONNECT)
{
net_con_ssl_connect(con);
}
else
{
con->callback(con, events, con->ptr);
case tls_st_none:
con->callback(con, events, con->ptr);
break;
case tls_st_accepting:
if (net_con_ssl_accept(con) < 0)
{
con->callback(con, NET_EVENT_READ, con->ptr);
}
break;
case tls_st_connecting:
if (net_con_ssl_connect(con) < 0)
{
con->callback(con, NET_EVENT_READ, con->ptr);
}
break;
case tls_st_connected:
LOG_PROTO("tls_st_connected, events=%s%s, ssl_flags=%s%s", (events & NET_EVENT_READ ? "R" : ""), (events & NET_EVENT_WRITE ? "W" : ""), con->flags & NET_WANT_SSL_READ ? "R" : "", con->flags & NET_WANT_SSL_WRITE ? "W" : "");
if (events & NET_EVENT_WRITE && con->flags & NET_WANT_SSL_READ)
{
con->callback(con, events & NET_EVENT_READ, con->ptr);
return;
}
if (events & NET_EVENT_READ && con->flags & NET_WANT_SSL_WRITE)
{
con->callback(con, events & NET_EVENT_READ, con->ptr);
return;
}
con->callback(con, events, con->ptr);
break;
case tls_st_disconnecting:
return;
}
}
#endif