diff --git a/ixwebsocket/IXWebSocket.cpp b/ixwebsocket/IXWebSocket.cpp index f778dd30..f9b8075d 100644 --- a/ixwebsocket/IXWebSocket.cpp +++ b/ixwebsocket/IXWebSocket.cpp @@ -256,7 +256,8 @@ namespace ix WebSocketInitResult WebSocket::connectToSocket(std::unique_ptr socket, int timeoutSecs, bool enablePerMessageDeflate, - HttpRequestPtr request) + HttpRequestPtr request, + int sendTimeoutSecs) { { std::lock_guard lock(_configMutex); @@ -264,8 +265,8 @@ namespace ix _perMessageDeflateOptions, _socketTLSOptions, _enablePong, _pingIntervalSecs); } - WebSocketInitResult status = - _ws.connectToSocket(std::move(socket), timeoutSecs, enablePerMessageDeflate, request); + WebSocketInitResult status = _ws.connectToSocket( + std::move(socket), timeoutSecs, enablePerMessageDeflate, request, sendTimeoutSecs); if (!status.success) { return status; diff --git a/ixwebsocket/IXWebSocket.h b/ixwebsocket/IXWebSocket.h index 21292e7d..1070836f 100644 --- a/ixwebsocket/IXWebSocket.h +++ b/ixwebsocket/IXWebSocket.h @@ -135,7 +135,8 @@ namespace ix WebSocketInitResult connectToSocket(std::unique_ptr, int timeoutSecs, bool enablePerMessageDeflate, - HttpRequestPtr request = nullptr); + HttpRequestPtr request = nullptr, + int sendTimeoutSecs = -1); WebSocketTransport _ws; diff --git a/ixwebsocket/IXWebSocketCloseConstants.cpp b/ixwebsocket/IXWebSocketCloseConstants.cpp index d8ba57f6..8a266584 100644 --- a/ixwebsocket/IXWebSocketCloseConstants.cpp +++ b/ixwebsocket/IXWebSocketCloseConstants.cpp @@ -19,6 +19,7 @@ namespace ix const std::string WebSocketCloseConstants::kInternalErrorMessage("Internal error"); const std::string WebSocketCloseConstants::kAbnormalCloseMessage("Abnormal closure"); const std::string WebSocketCloseConstants::kPingTimeoutMessage("Ping timeout"); + const std::string WebSocketCloseConstants::kSendTimeoutMessage("Send timeout"); const std::string WebSocketCloseConstants::kProtocolErrorMessage("Protocol error"); const std::string WebSocketCloseConstants::kNoStatusCodeErrorMessage("No status code"); const std::string WebSocketCloseConstants::kProtocolErrorReservedBitUsed("Reserved bit used"); diff --git a/ixwebsocket/IXWebSocketCloseConstants.h b/ixwebsocket/IXWebSocketCloseConstants.h index 145777b9..f7094d82 100644 --- a/ixwebsocket/IXWebSocketCloseConstants.h +++ b/ixwebsocket/IXWebSocketCloseConstants.h @@ -24,6 +24,7 @@ namespace ix static const std::string kInternalErrorMessage; static const std::string kAbnormalCloseMessage; static const std::string kPingTimeoutMessage; + static const std::string kSendTimeoutMessage; static const std::string kProtocolErrorMessage; static const std::string kNoStatusCodeErrorMessage; static const std::string kProtocolErrorReservedBitUsed; diff --git a/ixwebsocket/IXWebSocketServer.cpp b/ixwebsocket/IXWebSocketServer.cpp index cb6988a5..70f9568f 100644 --- a/ixwebsocket/IXWebSocketServer.cpp +++ b/ixwebsocket/IXWebSocketServer.cpp @@ -20,6 +20,7 @@ namespace ix const int WebSocketServer::kDefaultHandShakeTimeoutSecs(3); // 3 seconds const bool WebSocketServer::kDefaultEnablePong(true); const int WebSocketServer::kPingIntervalSeconds(-1); // disable heartbeat + const int WebSocketServer::kSendTimeoutSeconds(-1); WebSocketServer::WebSocketServer(int port, const std::string& host, @@ -27,12 +28,14 @@ namespace ix size_t maxConnections, int handshakeTimeoutSecs, int addressFamily, - int pingIntervalSeconds) + int pingIntervalSeconds, + int sendTimeoutSeconds) : SocketServer(port, host, backlog, maxConnections, addressFamily) , _handshakeTimeoutSecs(handshakeTimeoutSecs) , _enablePong(kDefaultEnablePong) , _enablePerMessageDeflate(true) , _pingIntervalSeconds(pingIntervalSeconds) + , _sendTimeoutSeconds(sendTimeoutSeconds) { } @@ -144,8 +147,11 @@ namespace ix _clients.insert(webSocket); } - auto status = webSocket->connectToSocket( - std::move(socket), _handshakeTimeoutSecs, _enablePerMessageDeflate, request); + auto status = webSocket->connectToSocket(std::move(socket), + _handshakeTimeoutSecs, + _enablePerMessageDeflate, + request, + _sendTimeoutSeconds); if (status.success) { // Process incoming messages and execute callbacks diff --git a/ixwebsocket/IXWebSocketServer.h b/ixwebsocket/IXWebSocketServer.h index 7636074e..a4a1d79f 100644 --- a/ixwebsocket/IXWebSocketServer.h +++ b/ixwebsocket/IXWebSocketServer.h @@ -34,7 +34,8 @@ namespace ix size_t maxConnections = SocketServer::kDefaultMaxConnections, int handshakeTimeoutSecs = WebSocketServer::kDefaultHandShakeTimeoutSecs, int addressFamily = SocketServer::kDefaultAddressFamily, - int pingIntervalSeconds = WebSocketServer::kPingIntervalSeconds); + int pingIntervalSeconds = WebSocketServer::kPingIntervalSeconds, + int sendTimeoutSeconds = WebSocketServer::kSendTimeoutSeconds); virtual ~WebSocketServer(); virtual void stop() final; @@ -63,6 +64,7 @@ namespace ix bool _enablePong; bool _enablePerMessageDeflate; int _pingIntervalSeconds; + int _sendTimeoutSeconds; OnConnectionCallback _onConnectionCallback; OnClientMessageCallback _onClientMessageCallback; @@ -72,6 +74,7 @@ namespace ix const static bool kDefaultEnablePong; const static int kPingIntervalSeconds; + const static int kSendTimeoutSeconds; // Methods virtual void handleConnection(std::unique_ptr socket, diff --git a/ixwebsocket/IXWebSocketTransport.cpp b/ixwebsocket/IXWebSocketTransport.cpp index 1d381ecb..39194f43 100644 --- a/ixwebsocket/IXWebSocketTransport.cpp +++ b/ixwebsocket/IXWebSocketTransport.cpp @@ -61,6 +61,7 @@ namespace ix WebSocketTransport::WebSocketTransport() : _useMask(true) , _blockingSend(false) + , _sendTimeoutSecs(-1) , _receivedMessageCompressed(false) , _readyState(ReadyState::CLOSED) , _closeCode(WebSocketCloseConstants::kInternalErrorCode) @@ -172,13 +173,15 @@ namespace ix WebSocketInitResult WebSocketTransport::connectToSocket(std::unique_ptr socket, int timeoutSecs, bool enablePerMessageDeflate, - HttpRequestPtr request) + HttpRequestPtr request, + int sendTimeoutSecs) { std::lock_guard lock(_socketMutex); // Server should not mask the data it sends to the client _useMask = false; _blockingSend = true; + _sendTimeoutSecs = sendTimeoutSecs; _socket = std::move(socket); _perMessageDeflate = ix::make_unique(); @@ -1242,6 +1245,23 @@ namespace ix bool WebSocketTransport::flushSendBuffer() { + auto start = std::chrono::steady_clock::now(); + + // timeoutMs tracks how long to wait before forcefully + // closing the socket when sending runs into a timeout. + std::chrono::seconds timeoutSecs(0); + if (_sendTimeoutSecs > 0) + { + timeoutSecs = std::chrono::seconds(_sendTimeoutSecs); + } + else if (_pingIntervalSecs > 0) + { + // If a pingInterval is set, use it as a timeout because if we cannot + // send out any data for pingInterval seconds, we may as well disconnet + // the client. + timeoutSecs = std::chrono::seconds(_pingIntervalSecs); + } + while (!isSendBufferEmpty() && !_requestInitCancellation) { // Wait with a 10ms timeout until the socket is ready to write. @@ -1261,6 +1281,20 @@ namespace ix return false; } } + else if (result == PollResultType::Timeout && timeoutSecs.count() > 0) + { + auto now = std::chrono::steady_clock::now(); + // Timeout error and exceeded the allocated timeout: Treat + // as abnormal close and use "Send Timeout" for the reason. + if (now > start + timeoutSecs) + { + closeSocketAndSwitchToClosedState(WebSocketCloseConstants::kAbnormalCloseCode, + WebSocketCloseConstants::kSendTimeoutMessage, + 0, + false); + return false; + } + } } return true; diff --git a/ixwebsocket/IXWebSocketTransport.h b/ixwebsocket/IXWebSocketTransport.h index c89e6015..f2391e15 100644 --- a/ixwebsocket/IXWebSocketTransport.h +++ b/ixwebsocket/IXWebSocketTransport.h @@ -88,7 +88,8 @@ namespace ix WebSocketInitResult connectToSocket(std::unique_ptr socket, int timeoutSecs, bool enablePerMessageDeflate, - HttpRequestPtr request = nullptr); + HttpRequestPtr request = nullptr, + int sendTimeoutSecs = -1); PollResult poll(); WebSocketSendInfo sendBinary(const IXWebSocketSendData& message, @@ -150,6 +151,12 @@ namespace ix // saying that a send is complete. This is the mode for server code. std::atomic _blockingSend; + // A configurable timeout for how long flushSendBuffer() may block + // before forcefully closing the client socket with a "Send timeout" + // message. This is useful when a client doesn't read from its socket + // and the server stalls on trying to send more data. + int _sendTimeoutSecs = -1; + // Buffer for reading from our socket. That buffer is never resized. std::vector _readbuf; diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index b5fe5fbe..4810e114 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -27,6 +27,7 @@ set (TEST_TARGET_NAMES IXWebSocketCloseTest IXWebSocketHostTest IXWebSocketIPv6Test + IXWebSocketSendTimeoutTest ) # Some unittest don't work on windows yet diff --git a/test/IXWebSocketSendTimeoutTest.cpp b/test/IXWebSocketSendTimeoutTest.cpp new file mode 100644 index 00000000..c994d507 --- /dev/null +++ b/test/IXWebSocketSendTimeoutTest.cpp @@ -0,0 +1,105 @@ +#include "IXTest.h" +#include "catch.hpp" +#include "ixwebsocket/IXWebSocketMessageType.h" +#include +#include +#include +#include + +using namespace ix; + +static std::atomic client_connected {false}; +static std::atomic client_closed {false}; + +TEST_CASE("SendTimeout") +{ + SECTION("Test send timeout kicking in") + { + // Create a server with a one second send timeout + int port = getFreePort(); + std::unique_ptr server = std::unique_ptr( + new ix::WebSocketServer(port, + "127.0.0.1", + SocketServer::kDefaultTcpBacklog, + SocketServer::kDefaultMaxConnections, + WebSocketServer::kDefaultHandShakeTimeoutSecs, + AF_INET, + /*pingIntervalSeconds=*/5, + /*sendTimeoutSeconds=*/1)); + + auto res = server->listen(); + REQUIRE(res.first); + + server->setOnConnectionCallback( + [](std::weak_ptr wws, std::shared_ptr cs) -> void + { + TLogger() << "Client connected!"; + auto ws = wws.lock(); + client_connected = true; + + // When the client sends a message, send it 50k messages back + // to quickly fill up the socket buffer and run into a send + // timeout. + ws->setOnMessageCallback( + [ws](const WebSocketMessagePtr& wsmptr) + { + if (wsmptr->type == WebSocketMessageType::Message) + { + auto i = 0; + while (++i < 50000) + { + auto r = ws->sendText("SPAM SPAM SPAM SPAM SPAM SPAM!"); + if (!r.success) + { + ws->close(); + break; + } + } + } + else if (wsmptr->type == WebSocketMessageType::Close) + { + TLogger() + << "SERVER: Client connection closed:" << wsmptr->closeInfo.reason; + client_closed = true; + } + }); + }); + + std::string url = "ws://127.0.0.1:" + std::to_string(port) + "/"; + ix::WebSocket client; + client.setUrl(url); + + client.setOnMessageCallback( + [&client](const ix::WebSocketMessagePtr& msg) + { + if (msg->type == ix::WebSocketMessageType::Open) + { + TLogger() << "CLIENT: Open"; + client.sendText("Hello"); + } + else if (msg->type == ix::WebSocketMessageType::Close) + { + TLogger() << "CLIENT: Close"; + } + else if (msg->type == ix::WebSocketMessageType::Message) + { + auto r = client.sendText("Hello, again!"); + + // Block the client thread after sending a message + // to make the socket buffers run full. + if (r.success) msleep(1000); + } + }); + + server->start(); + client.start(); + + // Wait for client to connect and be closed again. + while (!client_connected || !client_closed) + { + msleep(10); + } + + server->stop(); + } +}