From d464e4beb423881e49313aa432474d9e71ff0460 Mon Sep 17 00:00:00 2001 From: szeka9 Date: Tue, 24 Mar 2026 23:39:55 +0100 Subject: [PATCH] Add new exception for server busy state Allow applications to indicate busy state by raising ServerBusyError, resulting in 503. --- src/pyrobusta/bindings/socket_http.py | 6 +++- src/pyrobusta/protocol/http.py | 6 ++++ tests/functional/test_http.py | 42 +++++++++++++++++++++++++-- 3 files changed, 51 insertions(+), 3 deletions(-) diff --git a/src/pyrobusta/bindings/socket_http.py b/src/pyrobusta/bindings/socket_http.py index d1c51ec..955ba3b 100644 --- a/src/pyrobusta/bindings/socket_http.py +++ b/src/pyrobusta/bindings/socket_http.py @@ -8,7 +8,7 @@ from ..stream.buffer import MemoryPool, SlidingBuffer, BufferFullError from ..transport.socket import SocketBase -from ..protocol.http import HttpEngine +from ..protocol.http import HttpEngine, ServerBusyError from ..utils.config import get_config from ..utils import logging @@ -153,6 +153,10 @@ async def _run_state_machine(self): self._engine.on_failure(self._send_buf, b"Buffer full") await self._flush_response() return + except ServerBusyError: + self._engine.on_busy(self._send_buf) + await self._flush_response() + return except Exception as e: # pylint: disable=W0718 logging.warning(f"[SocketHttp] error in _run_state_machine: {e}") self._engine.on_failure(self._send_buf, str(e).encode("ascii")) diff --git a/src/pyrobusta/protocol/http.py b/src/pyrobusta/protocol/http.py index eb324a1..ddd9127 100644 --- a/src/pyrobusta/protocol/http.py +++ b/src/pyrobusta/protocol/http.py @@ -16,6 +16,12 @@ class HeaderParsingError(ValueError): pass +class ServerBusyError(RuntimeError): + """Exception for applications to indicate busy state""" + + pass + + class HttpEngine: """ HTTP protocol parser state machine diff --git a/tests/functional/test_http.py b/tests/functional/test_http.py index a491d3e..76b084a 100644 --- a/tests/functional/test_http.py +++ b/tests/functional/test_http.py @@ -3,7 +3,11 @@ from pyrobusta.server import http_server from pyrobusta.protocol import http_multipart -from pyrobusta.protocol.http import HttpEngine, enable_optional_features +from pyrobusta.protocol.http import ( + HttpEngine, + enable_optional_features, + ServerBusyError, +) from pyrobusta.utils import config ################################################# @@ -20,7 +24,7 @@ def test_assert(name, actual, expected): raise AssertionError(f"{actual} != {expected}") -async def send_request(request, tls): +async def send_request(request, tls=False): port = ( http_server.HttpServer.LISTEN_PORT_HTTPS if tls @@ -80,6 +84,11 @@ def multipart_callback(headers, body): return "multipart/form-data", ("text/plain", multipart_response(part_count)) +@HttpEngine.route("/test/busy", "POST") +def busy_callback(headers, body): + raise ServerBusyError() + + async def test_simple_response(tls_enabled): setup_config(multipart=False, tls_enabled=tls_enabled) @@ -164,6 +173,27 @@ async def test_multipart_response(tls_enabled): await server.terminate() +async def test_server_busy(): + setup_config() + + server = http_server.HttpServer() + server_task = asyncio.create_task(server.run_server()) + await asyncio.sleep_ms(100) + + # Test: 1 part + plain_response = await send_request( + b"POST /test/busy HTTP/1.1\r\n" b"Host: localhost\r\n\r\n" + ) + test_assert( + f"response is rejected by busy service with 503", + b"503 Service Unavailable" in plain_response, + True, + ) + + server_task.cancel() + await server.terminate() + + ################################################# # Test methods ################################################# @@ -190,6 +220,12 @@ def test_registration(): HttpEngine.ENDPOINTS[b"/test/multipart"][b"GET"], ) + test_assert( + "busy endpoint registration", + busy_callback, + HttpEngine.ENDPOINTS[b"/test/busy"][b"POST"], + ) + def test_multipart_patches(): setup_config(multipart=True) @@ -209,5 +245,7 @@ def test_main(): asyncio.run(test_multipart_response(tls_enabled=False)) asyncio.run(test_multipart_response(tls_enabled=True)) + asyncio.run(test_server_busy()) + test_main()