Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/pyrobusta/bindings/socket_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from ..stream.buffer import MemoryPool, SlidingBuffer, BufferFullError
from ..transport.socket import SocketBase
from ..protocol.http import HttpEngine
from ..protocol.http import HttpEngine, ServerBusyError
from ..utils.config import get_config
from ..utils import logging

Expand Down Expand Up @@ -153,6 +153,10 @@ async def _run_state_machine(self):
self._engine.on_failure(self._send_buf, b"Buffer full")
await self._flush_response()
return
except ServerBusyError:
self._engine.on_busy(self._send_buf)
await self._flush_response()
return
except Exception as e: # pylint: disable=W0718
logging.warning(f"[SocketHttp] error in _run_state_machine: {e}")
self._engine.on_failure(self._send_buf, str(e).encode("ascii"))
Expand Down
6 changes: 6 additions & 0 deletions src/pyrobusta/protocol/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@ class HeaderParsingError(ValueError):
pass


class ServerBusyError(RuntimeError):
"""Exception for applications to indicate busy state"""

pass


class HttpEngine:
"""
HTTP protocol parser state machine
Expand Down
42 changes: 40 additions & 2 deletions tests/functional/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@

from pyrobusta.server import http_server
from pyrobusta.protocol import http_multipart
from pyrobusta.protocol.http import HttpEngine, enable_optional_features
from pyrobusta.protocol.http import (
HttpEngine,
enable_optional_features,
ServerBusyError,
)
from pyrobusta.utils import config

#################################################
Expand All @@ -20,7 +24,7 @@ def test_assert(name, actual, expected):
raise AssertionError(f"{actual} != {expected}")


async def send_request(request, tls):
async def send_request(request, tls=False):
port = (
http_server.HttpServer.LISTEN_PORT_HTTPS
if tls
Expand Down Expand Up @@ -80,6 +84,11 @@ def multipart_callback(headers, body):
return "multipart/form-data", ("text/plain", multipart_response(part_count))


@HttpEngine.route("/test/busy", "POST")
def busy_callback(headers, body):
raise ServerBusyError()


async def test_simple_response(tls_enabled):
setup_config(multipart=False, tls_enabled=tls_enabled)

Expand Down Expand Up @@ -164,6 +173,27 @@ async def test_multipart_response(tls_enabled):
await server.terminate()


async def test_server_busy():
setup_config()

server = http_server.HttpServer()
server_task = asyncio.create_task(server.run_server())
await asyncio.sleep_ms(100)

# Test: 1 part
plain_response = await send_request(
b"POST /test/busy HTTP/1.1\r\n" b"Host: localhost\r\n\r\n"
)
test_assert(
f"response is rejected by busy service with 503",
b"503 Service Unavailable" in plain_response,
True,
)

server_task.cancel()
await server.terminate()


#################################################
# Test methods
#################################################
Expand All @@ -190,6 +220,12 @@ def test_registration():
HttpEngine.ENDPOINTS[b"/test/multipart"][b"GET"],
)

test_assert(
"busy endpoint registration",
busy_callback,
HttpEngine.ENDPOINTS[b"/test/busy"][b"POST"],
)


def test_multipart_patches():
setup_config(multipart=True)
Expand All @@ -209,5 +245,7 @@ def test_main():
asyncio.run(test_multipart_response(tls_enabled=False))
asyncio.run(test_multipart_response(tls_enabled=True))

asyncio.run(test_server_busy())


test_main()
Loading