386 lines
11 KiB
C++
386 lines
11 KiB
C++
/*
|
|
* Copyright (C) 2006-2016 Jacek Sieka, arnetheduck on gmail point com
|
|
*
|
|
* This program is free software; you can redistribute it and/or modify
|
|
* it under the terms of the GNU General Public License as published by
|
|
* the Free Software Foundation; either version 2 of the License, or
|
|
* (at your option) any later version.
|
|
*
|
|
* This program is distributed in the hope that it will be useful,
|
|
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
* GNU General Public License for more details.
|
|
*
|
|
* You should have received a copy of the GNU General Public License
|
|
* along with this program; if not, write to the Free Software
|
|
* Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
|
|
*/
|
|
|
|
#include "adchpp.h"
|
|
|
|
#include "SocketManager.h"
|
|
|
|
#include "LogManager.h"
|
|
#include "ManagedSocket.h"
|
|
#include "ServerInfo.h"
|
|
#include "SimpleXML.h"
|
|
#include "Core.h"
|
|
|
|
#ifdef HAVE_OPENSSL
|
|
#include <boost/asio/ssl.hpp>
|
|
#endif
|
|
|
|
#include <boost/date_time/posix_time/time_parsers.hpp>
|
|
#include <boost/asio/ip/tcp.hpp>
|
|
#include <boost/asio/ip/v6_only.hpp>
|
|
|
|
namespace adchpp {
|
|
|
|
using namespace std;
|
|
using namespace std::placeholders;
|
|
using namespace boost::asio;
|
|
using boost::system::error_code;
|
|
using boost::system::system_error;
|
|
|
|
/**
 * Creates the socket manager bound to `core` with its default tuning values:
 * bufferSize = 1 KiB (passed to SocketStream::setOptions by the acceptors via
 * getBufferSize(), presumably), maxBufferSize = 16 KiB, overflowTimeout = 60000
 * and disconnectTimeout = 10000 (presumably milliseconds — confirm in header).
 */
SocketManager::SocketManager(Core &core) :
	core(core),
	bufferSize(1 << 10),
	maxBufferSize(16 << 10),
	overflowTimeout(60000),
	disconnectTimeout(10000)
{
}
|
|
|
|
/// Log category used for all messages emitted by the socket manager.
const string SocketManager::className("SocketManager");
|
|
|
|
template<typename T>
|
|
class SocketStream : public AsyncStream {
|
|
public:
|
|
template<typename X>
|
|
SocketStream(X& x) : sock(x) { }
|
|
|
|
template<typename X, typename Y>
|
|
SocketStream(X& x, Y& y) : sock(x, y) { }
|
|
|
|
~SocketStream() { dcdebug("SocketStream deleted\n"); }
|
|
|
|
virtual size_t available() {
|
|
return sock.lowest_layer().available();
|
|
}
|
|
|
|
virtual void setOptions(size_t bufferSize) {
|
|
sock.lowest_layer().set_option(socket_base::receive_buffer_size(bufferSize));
|
|
sock.lowest_layer().set_option(socket_base::send_buffer_size(bufferSize));
|
|
}
|
|
|
|
virtual std::string getIp() {
|
|
try { return sock.lowest_layer().remote_endpoint().address().to_string(); }
|
|
catch(const system_error&) { return Util::emptyString; }
|
|
}
|
|
|
|
virtual void prepareRead(const BufferPtr& buf, const Handler& handler) {
|
|
if(buf) {
|
|
sock.async_read_some(buffer(buf->data(), buf->size()), handler);
|
|
} else {
|
|
sock.async_read_some(null_buffers(), handler);
|
|
}
|
|
}
|
|
|
|
virtual size_t read(const BufferPtr& buf) {
|
|
return sock.read_some(buffer(buf->data(), buf->size()));
|
|
}
|
|
|
|
virtual void write(const BufferList& bufs, const Handler& handler) {
|
|
if(bufs.size() == 1) {
|
|
sock.async_write_some(buffer(bufs[0]->data(), bufs[0]->size()), handler);
|
|
} else {
|
|
size_t n = std::min(bufs.size(), static_cast<size_t>(64));
|
|
std::vector<const_buffer> buffers;
|
|
buffers.reserve(n);
|
|
|
|
const size_t maxBytes = 1024;
|
|
|
|
for(size_t i = 0, total = 0; i < n && total < maxBytes; ++i) {
|
|
size_t left = maxBytes - total;
|
|
size_t bytes = min(bufs[i]->size(), left);
|
|
buffers.push_back(const_buffer(bufs[i]->data(), bytes));
|
|
total += bytes;
|
|
}
|
|
|
|
sock.async_write_some(buffers, handler);
|
|
}
|
|
}
|
|
|
|
T sock;
|
|
};
|
|
|
|
/**
 * Plain (unencrypted) TCP stream. There is no handshake, so init() invokes
 * postInit immediately.
 */
class SimpleSocketStream : public SocketStream<ip::tcp::socket> {
	typedef SocketStream<ip::tcp::socket> Stream;

	/** Posted by shutdown(): reports the result of the send-side shutdown with 0 bytes transferred. */
	struct ShutdownHandler {
		ShutdownHandler(const Handler& h, const error_code& ec) : h(h), ec(ec) { }
		void operator()() { h(ec, 0); }
		Handler h;
		error_code ec;
	};

public:
	SimpleSocketStream(boost::asio::io_service& x) : Stream(x) { }

	virtual void init(const std::function<void ()>& postInit) {
		postInit();
	}

	/**
	 * Half-closes the connection (no further sends) and reports completion to
	 * `handler` asynchronously through the io_service.
	 */
	virtual void shutdown(const Handler& handler) {
		// Use the non-throwing overload: shutdown() fails (e.g. ENOTCONN) once the
		// peer has reset the connection, and a thrown system_error could propagate
		// out of the calling asio handler and io_service::run(). Hand the error to
		// the completion handler instead.
		error_code ec;
		sock.shutdown(ip::tcp::socket::shutdown_send, ec);
		sock.get_io_service().post(ShutdownHandler(handler, ec));
	}

	virtual void close() {
		// Abortive close, just go away...
		if(sock.is_open()) {
			error_code ec;
			sock.close(ec); // Ignore errors
		}
	}
};
|
|
|
|
#ifdef HAVE_OPENSSL
|
|
|
|
/**
 * TLS stream: an ssl::stream layered over a TCP socket. init() runs the
 * server-side handshake; postInit is only invoked if the handshake succeeds.
 */
class TLSSocketStream : public SocketStream<ssl::stream<ip::tcp::socket> > {
	typedef SocketStream<ssl::stream<ip::tcp::socket> > Stream;

	/** Adapts async_shutdown's single-error completion to the two-argument Handler (0 bytes). */
	struct ShutdownHandler {
		ShutdownHandler(const Handler& h) : h(h) { }
		void operator()(const error_code &ec) { h(ec, 0); }
		Handler h;
	};

public:
	TLSSocketStream(io_service& x, ssl::context& y) : Stream(x, y) { }

	virtual void init(const std::function<void ()>& postInit) {
		// NOTE(review): captures raw `this`; assumes the stream outlives the pending
		// handshake (presumably guaranteed by the owning ManagedSocket — confirm).
		sock.async_handshake(ssl::stream_base::server, [this, postInit](const error_code& ec) {
			handleHandshake(ec, postInit);
		});
	}

	/// Asynchronously shuts down the SSL layer and reports the result to `handler`.
	virtual void shutdown(const Handler& handler) {
		ShutdownHandler onDone(handler);
		sock.async_shutdown(onDone);
	}

	/// Abortive close of the underlying TCP socket; errors are ignored.
	virtual void close() {
		auto& raw = sock.lowest_layer();
		if(!raw.is_open()) {
			return;
		}
		error_code ignored;
		raw.close(ignored);
	}

private:
	/// Continues initialisation only after a successful handshake; failures are dropped here.
	void handleHandshake(const error_code& ec, const std::function<void ()>& postInit) {
		if(ec) {
			return;
		}
		postInit();
	}
};
|
|
|
|
#endif
|
|
|
|
/**
 * Renders an endpoint for log output: "a.b.c.d:port" for IPv4, and
 * "[addr]:port" for anything else (IPv6), so the port is unambiguous.
 */
static string formatEndpoint(const ip::tcp::endpoint& ep) {
	const auto addr = ep.address();
	string host = addr.to_string();
	if(!addr.is_v4()) {
		host = '[' + host + ']';
	}
	return host + ':' + Util::toString(ep.port());
}
|
|
|
|
/**
 * Listens on one resolved endpoint of a configured server and keeps one
 * asynchronous accept outstanding on it.
 *
 * Each accepted (or failed) connection is wrapped in a ManagedSocket around a
 * SimpleSocketStream, or a TLSSocketStream when the server is secure and
 * OpenSSL support is compiled in. It is then handed to the SocketManager's
 * incoming handler. Every pending accept holds a shared_ptr to the factory
 * (shared_from_this()), so the factory stays alive while an accept is queued
 * even after SocketManager::closeFactories() has cleared its list.
 */
class SocketFactory : public enable_shared_from_this<SocketFactory>, boost::noncopyable {
public:
	/**
	 * Opens, configures, binds and starts listening on `endpoint`. For a secure
	 * server (with HAVE_OPENSSL), it also loads the certificate chain, private
	 * key and DH parameters. The listening address is logged only once all of
	 * that succeeded.
	 * @throws boost::system::system_error if any acceptor or ssl::context operation fails.
	 */
	SocketFactory(SocketManager& sm, const SocketManager::IncomingHandler& handler_, const ServerInfo& info, const ip::tcp::endpoint& endpoint) :
		sm(sm),
		acceptor(sm.io),
		handler(handler_)
	{
		acceptor.open(endpoint.protocol());
		acceptor.set_option(socket_base::reuse_address(true));
		if(endpoint.protocol() == ip::tcp::v6()) {
			// IPv6 acceptor accepts IPv6 only; IPv4 is presumably served by a separate
			// acceptor on the same port (one factory per resolved endpoint).
			acceptor.set_option(ip::v6_only(true));
		}

		acceptor.bind(endpoint);
		acceptor.listen(socket_base::max_connections);

#ifdef HAVE_OPENSSL
		if(info.secure()) {
			context.reset(new ssl::context(sm.io, ssl::context::sslv23_server));
			context->set_options(ssl::context::no_sslv2 | ssl::context::no_sslv3 | ssl::context::single_dh_use);
			//context->set_password_callback(boost::bind(&server::get_password, this));
			context->use_certificate_chain_file(info.TLSParams.cert);
			context->use_private_key_file(info.TLSParams.pkey, ssl::context::pem);
			context->use_tmp_dh_file(info.TLSParams.dh);
		}
#endif

		// Logged last so a failing TLS setup (which throws) never reports a listener.
		LOGC(sm.getCore(), SocketManager::className,
			"Listening on " + formatEndpoint(endpoint) +
			" (Encrypted: " + (info.secure() ? "Yes)" : "No)"));
	}

	/**
	 * Queues the next asynchronous accept, unless the manager is stopping (no
	 * `work` guard) or this factory's acceptor has been closed.
	 */
	void prepareAccept() {
		// Without the is_open() check, an accept on a closed acceptor fails at once
		// and handleAccept() would re-arm it again and again (busy loop) until the
		// work guard is dropped.
		if(!sm.work.get() || !acceptor.is_open()) {
			return;
		}

#ifdef HAVE_OPENSSL
		if(context) {
			acceptInto(make_shared<TLSSocketStream>(sm.io, *context));
			return;
		}
#endif
		acceptInto(make_shared<SimpleSocketStream>(sm.io));
	}

	/**
	 * Completion of one async_accept. On success, applies the manager's buffer
	 * size and records the peer address. The socket, successful or not, is then
	 * handed to completeAccept() and the next accept is queued.
	 */
	void handleAccept(const error_code& ec, const ManagedSocketPtr& socket) {
		error_code result = ec;
		if(!result) {
			try {
				socket->sock->setOptions(sm.getBufferSize());
				socket->setIp(socket->sock->getIp());
			} catch(const system_error& e) {
				// set_option throws; letting it escape would unwind out of io_service::run()
				// and end this accept loop. Report it as a failed accept instead.
				result = e.code();
			}
		}

		completeAccept(result, socket);

		prepareAccept();
	}

	/// Passes the socket to the incoming handler, then lets it process the accept result.
	void completeAccept(const error_code& ec, const ManagedSocketPtr& socket) {
		handler(socket);
		socket->completeAccept(ec);
	}

	/// Stops listening; errors are ignored since this only happens during teardown.
	void close() {
		error_code ec;
		acceptor.close(ec);
	}

	SocketManager &sm;
	ip::tcp::acceptor acceptor;
	SocketManager::IncomingHandler handler;

#ifdef HAVE_OPENSSL
	unique_ptr<ssl::context> context;
#endif

private:
	/// Wraps `stream` in a ManagedSocket and queues an accept into its lowest-layer TCP socket.
	template<typename StreamT>
	void acceptInto(const shared_ptr<StreamT>& stream) {
		auto socket = make_shared<ManagedSocket>(sm, stream);
		acceptor.async_accept(stream->sock.lowest_layer(),
			std::bind(&SocketFactory::handleAccept, shared_from_this(), std::placeholders::_1, socket));
	}
};
|
|
|
|
int SocketManager::run() {
|
|
LOG(SocketManager::className, "Starting");
|
|
|
|
work.reset(new io_service::work(io));
|
|
|
|
for(auto i = servers.begin(), iend = servers.end(); i != iend; ++i) {
|
|
auto& si = *i;
|
|
|
|
try {
|
|
using ip::tcp;
|
|
tcp::resolver r(io);
|
|
auto local = r.resolve(tcp::resolver::query(si->ip, si->port,
|
|
tcp::resolver::query::address_configured | tcp::resolver::query::passive));
|
|
|
|
for(auto i = local; i != tcp::resolver::iterator(); ++i) {
|
|
SocketFactoryPtr factory = make_shared<SocketFactory>(*this, incomingHandler, *si, *i);
|
|
factory->prepareAccept();
|
|
factories.push_back(factory);
|
|
}
|
|
} catch(const std::exception& e) {
|
|
LOG(SocketManager::className, "Error while loading server on port " + si->port +": " + e.what());
|
|
}
|
|
}
|
|
|
|
io.run();
|
|
|
|
io.reset();
|
|
|
|
return 0;
|
|
}
|
|
|
|
void SocketManager::closeFactories() {
|
|
for(auto i = factories.begin(), iend = factories.end(); i != iend; ++i) {
|
|
(*i)->close();
|
|
}
|
|
factories.clear();
|
|
}
|
|
|
|
/**
 * Queues `callback` on the io_service. io_service::post never runs it inline;
 * it executes later on a thread that is inside io.run().
 */
void SocketManager::addJob(const Callback& callback) throw() {
|
|
io.post(callback);
|
|
}
|
|
|
|
void SocketManager::addJob(const long msec, const Callback& callback) {
|
|
addJob(boost::posix_time::milliseconds(msec), callback);
|
|
}
|
|
|
|
/**
 * Runs `callback` once after a delay given as text (parsed by
 * boost::posix_time::duration_from_string, e.g. "00:00:05").
 */
void SocketManager::addJob(const std::string& time, const Callback& callback) {
	const auto delay = boost::posix_time::duration_from_string(time);
	addJob(delay, callback);
}
|
|
|
|
/** Runs `callback` every `msec` milliseconds; the returned functor cancels it. */
SocketManager::Callback SocketManager::addTimedJob(const long msec, const Callback& callback) {
	const boost::posix_time::milliseconds period(msec);
	return addTimedJob(period, callback);
}
|
|
|
|
/**
 * Runs `callback` periodically, with the period given as text (parsed by
 * boost::posix_time::duration_from_string). The returned functor cancels it.
 */
SocketManager::Callback SocketManager::addTimedJob(const std::string& time, const Callback& callback) {
	const auto period = boost::posix_time::duration_from_string(time);
	return addTimedJob(period, callback);
}
|
|
|
|
/**
 * Runs `callback` once after `duration`. Passing a zero (default-constructed)
 * repeat interval to setTimer marks a one-shot timer: handleWait fires the
 * callback once and then deletes the heap copy allocated here.
 */
void SocketManager::addJob(const deadline_timer::duration_type& duration, const Callback& callback) {
	auto timer = make_shared<timer_ptr::element_type>(io, duration);
	auto* job = new Callback(callback);
	setTimer(timer, deadline_timer::duration_type(), job);
}
|
|
|
|
/**
 * Runs `callback` every `duration`, starting one period from now.
 * The callback is copied to the heap (to avoid shutdown crashes, per the
 * original author). Ownership of that copy passes to the returned functor,
 * which cancels the timer and deletes the copy via cancelTimer().
 * NOTE(review): the returned functor must be invoked at most once (a second
 * call deletes the same pointer again); if it is never invoked, the copy and
 * the timer live until the io_service is torn down.
 */
SocketManager::Callback SocketManager::addTimedJob(const deadline_timer::duration_type& duration, const Callback& callback) {
	auto timer = make_shared<timer_ptr::element_type>(io, duration);
	auto* job = new Callback(callback);

	setTimer(timer, duration, job);

	return std::bind(&SocketManager::cancelTimer, this, timer, job);
}
|
|
|
|
/**
 * Arms `timer` and routes its completion to handleWait(). The timer pointer,
 * repeat interval and callback pointer are captured by value, so the timer
 * stays alive until the wait completes.
 */
void SocketManager::setTimer(timer_ptr timer, const deadline_timer::duration_type& duration, Callback* callback) {
	timer->async_wait([this, timer, duration, callback](const error_code& ec) {
		handleWait(timer, duration, ec, callback);
	});
}
|
|
|
|
/**
 * Completion handler for every wait started by setTimer().
 *
 * A zero `duration` marks a one-shot job (addJob). It owns `callback`, which is
 * deleted here whether or not the wait succeeded. A non-zero `duration` marks a
 * repeating job (addTimedJob). Its callback is owned by the cancel functor and
 * is never deleted here. On error (e.g. operation_aborted after cancelTimer())
 * the job neither runs nor re-arms.
 */
void SocketManager::handleWait(timer_ptr timer, const deadline_timer::duration_type& duration, const error_code& error, Callback* callback) {
	const bool repeating = duration.ticks() != 0;

	if(error) {
		if(!repeating) {
			delete callback;
		}
		return;
	}

	if(!repeating) {
		addJob(*callback);
		delete callback;
		return;
	}

	// Advance from the previous deadline rather than "now" so the period does not drift.
	timer->expires_at(timer->expires_at() + duration);
	setTimer(timer, duration, callback);

	// Executed via the io_service queue, not inline in this handler.
	addJob(*callback);
}
|
|
|
|
/**
 * Cancel functor target returned by addTimedJob(). It cancels any pending wait
 * (errors ignored) and frees the heap callback owned by that repeating job.
 * NOTE(review): per the Boost.Asio deadline_timer::cancel docs, a wait whose
 * deadline has already passed cannot be aborted. Its handler is delivered
 * with success. If such a handleWait is already queued when this runs, it will
 * dereference the callback deleted here. Confirm callers cannot hit this, or
 * switch the callback to shared ownership (needs header changes).
 */
void SocketManager::cancelTimer(timer_ptr timer, Callback* callback) {
	if(timer) {
		error_code ignored;
		timer->cancel(ignored);
	}

	delete callback;
}
|
|
|
|
/**
 * Stops the socket thread. It first closes all listening acceptors, then drops
 * the io_service::work guard (SocketFactory::prepareAccept() checks it before
 * re-arming an accept), and finally stops the io_service so io.run() in run()
 * returns.
 */
void SocketManager::shutdown() {
|
|
closeFactories();
|
|
|
|
work.reset();
|
|
io.stop();
|
|
}
|
|
|
|
/**
 * Configuration-load hook. It only clears the configured server list; `xml`
 * is not read here.
 * NOTE(review): servers are presumably repopulated elsewhere (e.g. a setter)
 * before run() is called — confirm.
 */
void SocketManager::onLoad(const SimpleXML& xml) throw() {
|
|
servers.clear();
|
|
}
|
|
|
|
}
|