diff --git a/Makefile b/Makefile index 5da1e69..016ada9 100644 --- a/Makefile +++ b/Makefile @@ -101,7 +101,7 @@ deploy-config: # ----------------------------- -# Deploy index page # TODO use install_www from assets module +# Deploy index page # ----------------------------- .PHONY: deploy-www deploy-www: diff --git a/README.md b/README.md index de54e56..17b1a3c 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,7 @@ A lightweight HTTP server library for MicroPython designed for constrained embed - Finite-state-machine parser with linear sliding buffer - Robust byte-stream handling - Query parameter parsing with percent encoding support +- Persistent connections (set by **connection: keep-alive**) - TLS support # Installation diff --git a/assets/www/examples.html b/assets/www/examples.html index e1a2ca6..95a1a4d 100644 --- a/assets/www/examples.html +++ b/assets/www/examples.html @@ -92,7 +92,7 @@

Simple Server Application

if value_format == "%": free = round(100 * free / (free + mem_alloc()), 2) - return "text/plain", (f"Free memory [{value_format}]: {free}\n") + return "text/plain", f"Free memory [{value_format}]: {free}\n" async def run_server(): server = HttpServer() diff --git a/docs/dimensioning/http_dimensioning.md b/docs/dimensioning/http_dimensioning.md index e2a91bd..daf2d33 100644 --- a/docs/dimensioning/http_dimensioning.md +++ b/docs/dimensioning/http_dimensioning.md @@ -28,19 +28,19 @@ with no active network traffic. | id | http_mem_cap | http_multipart | socket_max_con | tls | footprint_bytes | | --- | --- | --- | --- | --- | --- | -| base | 0.05 | False | 1 | False | 38448 | -| low_mem_cap_001 | 0.0127 | False | 1 | False | 38448 | -| low_mem_cap_002 | 0.0253 | False | 2 | False | 39664 | -| low_mem_cap_003 | 0.0505 | False | 4 | False | 42096 | -| high_mem_cap_001 | 0.0568 | False | 1 | False | 45616 | -| high_mem_cap_002 | 0.114 | False | 2 | False | 54000 | -| high_mem_cap_003 | 0.228 | False | 4 | False | 70768 | -| multipart_001 | 0.0127 | True | 1 | False | 40704 | -| multipart_002 | 0.0253 | True | 2 | False | 41920 | -| multipart_003 | 0.0505 | True | 4 | False | 44352 | -| tls_001 | 0.0127 | False | 1 | True | 41120 | -| tls_002 | 0.0253 | False | 2 | True | 42336 | -| tls_003 | 0.0505 | False | 4 | True | 44768 | +| base | 0.05 | False | 1 | False | 38512 | +| low_mem_cap_001 | 0.0127 | False | 1 | False | 38512 | +| low_mem_cap_002 | 0.0253 | False | 2 | False | 39728 | +| low_mem_cap_003 | 0.0505 | False | 4 | False | 42112 | +| high_mem_cap_001 | 0.0568 | False | 1 | False | 45680 | +| high_mem_cap_002 | 0.114 | False | 2 | False | 54064 | +| high_mem_cap_003 | 0.228 | False | 4 | False | 70832 | +| multipart_001 | 0.0127 | True | 1 | False | 40608 | +| multipart_002 | 0.0253 | True | 2 | False | 41824 | +| multipart_003 | 0.0505 | True | 4 | False | 44256 | +| tls_001 | 0.0127 | False | 1 | True | 41184 | +| tls_002 | 0.0253 | False | 2 | True | 42400 | +| tls_003 | 0.0505 | False | 4 | True | 44832 | ## Heap usage under network traffic ![image info](./img/esp32_c3/base.png) @@ -60,9 +60,9 @@ with no active network traffic. | high_mem_cap_001 | 0.00111 | False | 1 | False | 45520 | | high_mem_cap_002 | 0.00222 | False | 2 | False | 53904 | | high_mem_cap_003 | 0.00443 | False | 4 | False | 70672 | -| multipart_001 | 0.000247 | True | 1 | False | 40656 | -| multipart_002 | 0.000493 | True | 2 | False | 41872 | -| multipart_003 | 0.000985 | True | 4 | False | 44304 | +| multipart_001 | 0.000247 | True | 1 | False | 40528 | +| multipart_002 | 0.000493 | True | 2 | False | 41744 | +| multipart_003 | 0.000985 | True | 4 | False | 44176 | | tls_001 | 0.000247 | False | 1 | True | 40736 | | tls_002 | 0.000493 | False | 2 | True | 41952 | | tls_003 | 0.000985 | False | 4 | True | 44384 | diff --git a/docs/dimensioning/img/esp32_c3/base.png b/docs/dimensioning/img/esp32_c3/base.png index 75fa126..1686010 100644 Binary files a/docs/dimensioning/img/esp32_c3/base.png and b/docs/dimensioning/img/esp32_c3/base.png differ diff --git a/docs/dimensioning/img/esp32_c3/high_mem_cap_001.png b/docs/dimensioning/img/esp32_c3/high_mem_cap_001.png index 9ae01ff..a4bf50e 100644 Binary files a/docs/dimensioning/img/esp32_c3/high_mem_cap_001.png and b/docs/dimensioning/img/esp32_c3/high_mem_cap_001.png differ diff --git a/docs/dimensioning/img/esp32_c3/high_mem_cap_002.png b/docs/dimensioning/img/esp32_c3/high_mem_cap_002.png index e0af9fa..c57115b 100644 Binary files a/docs/dimensioning/img/esp32_c3/high_mem_cap_002.png and b/docs/dimensioning/img/esp32_c3/high_mem_cap_002.png differ diff --git a/docs/dimensioning/img/esp32_c3/high_mem_cap_003.png b/docs/dimensioning/img/esp32_c3/high_mem_cap_003.png index a666ee7..6f68f16 100644 Binary files a/docs/dimensioning/img/esp32_c3/high_mem_cap_003.png and b/docs/dimensioning/img/esp32_c3/high_mem_cap_003.png differ diff --git a/docs/dimensioning/img/esp32_c3/low_mem_cap_001.png b/docs/dimensioning/img/esp32_c3/low_mem_cap_001.png index a9285ae..ebb7766 100644 Binary files a/docs/dimensioning/img/esp32_c3/low_mem_cap_001.png and b/docs/dimensioning/img/esp32_c3/low_mem_cap_001.png differ diff --git a/docs/dimensioning/img/esp32_c3/low_mem_cap_002.png b/docs/dimensioning/img/esp32_c3/low_mem_cap_002.png index 35fbf73..54fb0e2 100644 Binary files a/docs/dimensioning/img/esp32_c3/low_mem_cap_002.png and b/docs/dimensioning/img/esp32_c3/low_mem_cap_002.png differ diff --git a/docs/dimensioning/img/esp32_c3/low_mem_cap_003.png b/docs/dimensioning/img/esp32_c3/low_mem_cap_003.png index 1c58314..96edeb7 100644 Binary files a/docs/dimensioning/img/esp32_c3/low_mem_cap_003.png and b/docs/dimensioning/img/esp32_c3/low_mem_cap_003.png differ diff --git a/docs/dimensioning/img/esp32_c3/multipart_001.png b/docs/dimensioning/img/esp32_c3/multipart_001.png index 08e0f1a..1e249e5 100644 Binary files a/docs/dimensioning/img/esp32_c3/multipart_001.png and b/docs/dimensioning/img/esp32_c3/multipart_001.png differ diff --git a/docs/dimensioning/img/esp32_c3/multipart_002.png b/docs/dimensioning/img/esp32_c3/multipart_002.png index fd54765..0456481 100644 Binary files a/docs/dimensioning/img/esp32_c3/multipart_002.png and b/docs/dimensioning/img/esp32_c3/multipart_002.png differ diff --git a/docs/dimensioning/img/esp32_c3/multipart_003.png b/docs/dimensioning/img/esp32_c3/multipart_003.png index 046c0a3..2b4fcf9 100644 Binary files a/docs/dimensioning/img/esp32_c3/multipart_003.png and b/docs/dimensioning/img/esp32_c3/multipart_003.png differ diff --git a/docs/dimensioning/img/esp32_c3/tls_001.png b/docs/dimensioning/img/esp32_c3/tls_001.png index 9d933c0..bb465f3 100644 Binary files a/docs/dimensioning/img/esp32_c3/tls_001.png and b/docs/dimensioning/img/esp32_c3/tls_001.png differ diff --git a/docs/dimensioning/img/esp32_c3/tls_002.png b/docs/dimensioning/img/esp32_c3/tls_002.png index a69e17d..0755d09 100644 Binary files a/docs/dimensioning/img/esp32_c3/tls_002.png and b/docs/dimensioning/img/esp32_c3/tls_002.png differ diff --git a/docs/dimensioning/img/esp32_c3/tls_003.png b/docs/dimensioning/img/esp32_c3/tls_003.png index 5c6b128..603ace0 100644 Binary files a/docs/dimensioning/img/esp32_c3/tls_003.png and b/docs/dimensioning/img/esp32_c3/tls_003.png differ diff --git a/docs/dimensioning/img/esp32_s3/base.png b/docs/dimensioning/img/esp32_s3/base.png index b814774..858f288 100644 Binary files a/docs/dimensioning/img/esp32_s3/base.png and b/docs/dimensioning/img/esp32_s3/base.png differ diff --git a/docs/dimensioning/img/esp32_s3/high_mem_cap_001.png b/docs/dimensioning/img/esp32_s3/high_mem_cap_001.png index 81887d2..24d589f 100644 Binary files a/docs/dimensioning/img/esp32_s3/high_mem_cap_001.png and b/docs/dimensioning/img/esp32_s3/high_mem_cap_001.png differ diff --git a/docs/dimensioning/img/esp32_s3/high_mem_cap_002.png b/docs/dimensioning/img/esp32_s3/high_mem_cap_002.png index 664932a..9312ab8 100644 Binary files a/docs/dimensioning/img/esp32_s3/high_mem_cap_002.png and b/docs/dimensioning/img/esp32_s3/high_mem_cap_002.png differ diff --git a/docs/dimensioning/img/esp32_s3/high_mem_cap_003.png b/docs/dimensioning/img/esp32_s3/high_mem_cap_003.png index c81a08a..38c74e1 100644 Binary files a/docs/dimensioning/img/esp32_s3/high_mem_cap_003.png and b/docs/dimensioning/img/esp32_s3/high_mem_cap_003.png differ diff --git a/docs/dimensioning/img/esp32_s3/low_mem_cap_001.png b/docs/dimensioning/img/esp32_s3/low_mem_cap_001.png index fc48588..06d7e08 100644 Binary files a/docs/dimensioning/img/esp32_s3/low_mem_cap_001.png and b/docs/dimensioning/img/esp32_s3/low_mem_cap_001.png differ diff --git a/docs/dimensioning/img/esp32_s3/low_mem_cap_002.png b/docs/dimensioning/img/esp32_s3/low_mem_cap_002.png index 45dd764..563b6bd 100644 Binary files a/docs/dimensioning/img/esp32_s3/low_mem_cap_002.png and b/docs/dimensioning/img/esp32_s3/low_mem_cap_002.png differ diff --git a/docs/dimensioning/img/esp32_s3/low_mem_cap_003.png b/docs/dimensioning/img/esp32_s3/low_mem_cap_003.png index a1f9643..a9d5f31 100644 Binary files a/docs/dimensioning/img/esp32_s3/low_mem_cap_003.png and b/docs/dimensioning/img/esp32_s3/low_mem_cap_003.png differ diff --git a/docs/dimensioning/img/esp32_s3/multipart_001.png b/docs/dimensioning/img/esp32_s3/multipart_001.png index b3b9817..25b13a5 100644 Binary files a/docs/dimensioning/img/esp32_s3/multipart_001.png and b/docs/dimensioning/img/esp32_s3/multipart_001.png differ diff --git a/docs/dimensioning/img/esp32_s3/multipart_002.png b/docs/dimensioning/img/esp32_s3/multipart_002.png index 6e04077..fddf31e 100644 Binary files a/docs/dimensioning/img/esp32_s3/multipart_002.png and b/docs/dimensioning/img/esp32_s3/multipart_002.png differ diff --git a/docs/dimensioning/img/esp32_s3/multipart_003.png b/docs/dimensioning/img/esp32_s3/multipart_003.png index a799ddc..3ffe4fe 100644 Binary files a/docs/dimensioning/img/esp32_s3/multipart_003.png and b/docs/dimensioning/img/esp32_s3/multipart_003.png differ diff --git a/docs/dimensioning/img/esp32_s3/tls_001.png b/docs/dimensioning/img/esp32_s3/tls_001.png index af9ea15..190a2a3 100644 Binary files a/docs/dimensioning/img/esp32_s3/tls_001.png and b/docs/dimensioning/img/esp32_s3/tls_001.png differ diff --git a/docs/dimensioning/img/esp32_s3/tls_002.png b/docs/dimensioning/img/esp32_s3/tls_002.png index 91aefdc..6b0fc45 100644 Binary files a/docs/dimensioning/img/esp32_s3/tls_002.png and b/docs/dimensioning/img/esp32_s3/tls_002.png differ diff --git a/docs/dimensioning/img/esp32_s3/tls_003.png b/docs/dimensioning/img/esp32_s3/tls_003.png index fb13661..cbfc67d 100644 Binary files a/docs/dimensioning/img/esp32_s3/tls_003.png and b/docs/dimensioning/img/esp32_s3/tls_003.png differ diff --git a/example/demo_app/app.py b/example/demo_app/app.py index 28ba282..937b89f 100644 --- a/example/demo_app/app.py +++ b/example/demo_app/app.py @@ -20,7 +20,7 @@ def app(http_ctx, payload): if value_format == "%": free = round(100 * free / (free + mem_alloc()), 2) - return "text/plain", (f"Free memory [{value_format}]: {free}\n") + return "text/plain", f"Free memory [{value_format}]: {free}\n" async def main(): diff --git a/src/pyrobusta/bindings/http_connection.py b/src/pyrobusta/bindings/http_connection.py index 533dfc3..0d9b79f 100644 --- a/src/pyrobusta/bindings/http_connection.py +++ b/src/pyrobusta/bindings/http_connection.py @@ -7,7 +7,7 @@ from ..stream.buffer import BufferFullError from ..transport.connection import BaseConnection -from ..protocol.http import HttpEngine, ServerBusyError, HeaderParsingError +from ..protocol.http import HttpEngine from ..utils import logging @@ -19,7 +19,6 @@ class HttpConnection(BaseConnection): MTU_SIZE = 1460 STATE_MACHINE_SLEEP_MS = 2 - RESP_HANDLER_SLEEP_MS = 2 RECV_TIMEOUT_SECONDS = 10 __slots__ = ("_engine", "_prev_state", "_recv_buf", "_send_buf") @@ -36,9 +35,12 @@ async def run(self): Handle socket connection with HTTP state machine parser. """ self._prev_state = None - while self._engine.state is not None: + while not self._engine.is_terminated(): await self._run_state_machine() await sleep_ms(self.STATE_MACHINE_SLEEP_MS) + if self._engine.is_terminated() and self._engine.do_keep_alive(): + self._engine.reset() + self._prev_state = None async def _flush_response(self): data = self._send_buf.peek() @@ -49,67 +51,48 @@ async def _flush_response(self): async def _read_to_buf(self): buf_free = self._recv_buf.capacity - self._recv_buf.size() if not buf_free: - self._engine.on_buffer_full(self._send_buf) - await self._flush_response() - return 0 - try: - request = await self.read( - read_bytes=buf_free, - decoding=None, - timeout_seconds=self.RECV_TIMEOUT_SECONDS, - ) - except asyncio.TimeoutError: - self._engine.on_timeout(self._send_buf) - await self._flush_response() - return 0 - except Exception as e: # pylint: disable=W0718 - self._engine.on_failure( - self._send_buf, b"Read error: " + str(e).encode("ascii") - ) - await self._flush_response() - return 0 + raise BufferFullError() + request = await self.read( + read_bytes=buf_free, + decoding=None, + timeout_seconds=self.RECV_TIMEOUT_SECONDS, + ) self._recv_buf.write(request) logging.debug(__name__ + f"._read_to_buf: [{request}]") return len(request) async def _run_state_machine(self): - if self._prev_state == self._engine.state or self._prev_state is None: - num_read = await self._read_to_buf() - if not num_read: - # Reject incomplete request - self._engine.on_client_error( - self._send_buf, self._engine.BAD_REQUEST_ERROR - ) - await self._flush_response() - return - try: - resp_handler = None - while self._engine.state is not None: - self._prev_state = self._engine.state - resp_handler = self._engine.state(self._recv_buf, self._send_buf) - if not self._send_buf.size(): - break - await self._flush_response() - await sleep_ms(self.STATE_MACHINE_SLEEP_MS) - except BufferFullError: - self._engine.on_failure(self._send_buf, b"Buffer full") - await self._flush_response() - return - except ServerBusyError: - self._engine.on_unavailable(self._send_buf) - await self._flush_response() - return - except HeaderParsingError: - self._engine.on_client_error(self._send_buf, self._engine.HEADER_ERROR) - await self._flush_response() - return - except Exception as e: # pylint: disable=W0718 - logging.warning(__name__ + f"._run_state_machine: {e}") - self._engine.on_failure(self._send_buf, str(e).encode("ascii")) + # [1] read request + if self._prev_state == self._engine.state or ( + self._prev_state is None and not self._recv_buf.size() + ): + try: + num_read = await self._read_to_buf() + if not num_read: + self._engine.abort(400) + self._engine.set_response_body(b"Incomplete request") + except BufferFullError: + self._engine.abort(413) + except asyncio.TimeoutError: + self._engine.abort(408) + except Exception as e: # pylint: disable=W0718 + self._engine.abort(500) + self._engine.set_response_body(b"Read error: " + str(e).encode("ascii")) + + # [2] process request by state machine + for _ in self._engine.run(self._recv_buf): + if self._prev_state == self._engine.state: + # No state transition occurred, read more data + break + self._prev_state = self._engine.state + await sleep_ms(self.STATE_MACHINE_SLEEP_MS) + + # [3] write response + if self._engine.is_started() and self._engine.is_terminated(): + self._engine.write_response_head(self._send_buf) await self._flush_response() - return - if self._engine.state is None and resp_handler is not None: - await self._response_handler(resp_handler) + if self._engine.resp_handler is not None: + await self._response_handler(self._engine.resp_handler) async def _response_handler(self, resp_handler): if "closure" == type(resp_handler).__name__: @@ -117,22 +100,20 @@ async def _response_handler(self, resp_handler): await self._flush_response() if is_finished: break - await sleep_ms(self.RESP_HANDLER_SLEEP_MS) + await sleep_ms(self.STATE_MACHINE_SLEEP_MS) elif type(resp_handler).__name__ in ("FileIO", "BytesIO"): - with resp_handler as rh: + try: while True: view = self._send_buf.writable_view() - num_read = rh.readinto(view) + num_read = resp_handler.readinto(view) if not num_read: break self._send_buf.commit(num_read) await self._flush_response() - await sleep_ms(self.RESP_HANDLER_SLEEP_MS) + await sleep_ms(self.STATE_MACHINE_SLEEP_MS) + finally: + resp_handler.close() else: - self._engine.on_failure( - self._send_buf, - f"Invalid response handler {type(resp_handler).__name__}".encode( - "ascii" - ), + raise RuntimeError( + f"Invalid response handler {type(resp_handler).__name__}" ) - await self._flush_response() diff --git a/src/pyrobusta/connectivity/wifi.py b/src/pyrobusta/connectivity/wifi.py index bc27030..f907add 100644 --- a/src/pyrobusta/connectivity/wifi.py +++ b/src/pyrobusta/connectivity/wifi.py @@ -1,5 +1,5 @@ """ -Helpers for setting up Wi-Fi in station mode +Helpers for setting up Wi-Fi in station mode. """ from time import sleep @@ -12,7 +12,7 @@ def initialize(): """ - Initialize WLAN interface in station mode + Initialize WLAN station interface. """ ssid = get_config(CONF_WIFI_SSID) password = get_config(CONF_WIFI_PASSWORD) @@ -44,7 +44,7 @@ def initialize(): def get_address(): """ - Get the address of the WLAN interface + Get the IP address of the WLAN station. """ sta_if = WLAN(STA_IF) if sta_if.isconnected(): diff --git a/src/pyrobusta/protocol/http.py b/src/pyrobusta/protocol/http.py index fbd6a45..93d55ab 100644 --- a/src/pyrobusta/protocol/http.py +++ b/src/pyrobusta/protocol/http.py @@ -1,5 +1,5 @@ """ -Module is responsible for webserver state machine, +This module is responsible HTTP protocol parsing with partial guarantees on RFC compliance. """ @@ -12,32 +12,49 @@ CONF_HTTP_MULTIPART, CONF_HTTP_SERVE_FILES, ) +from ..utils import logging +from ..stream.buffer import BufferFullError -class HeaderParsingError(ValueError): - """Exception for errors occurring while parsing HTTP/MIME headers""" +class InvalidHeaders(ValueError): + """Exception for errors occurring while parsing HTTP/MIME headers.""" pass -class ServerBusyError(RuntimeError): - """Exception for applications to indicate busy state""" +class InvalidContentLength(ValueError): + """Exception for content-length related erros.""" + + pass + + +class MalformedRequest(ValueError): + """Exception for malformed requests.""" pass class HttpEngine: """ - HTTP protocol parser state machine - - provides an adapter/routing layer - - supports multipart request and response handling - - resolves static resources by returning a stream objects (FileIO) + HTTP protocol parser state machine and middleware. + - each instance represents a connection, allowing a request to be parsed + through a state machine + - provides an adapter/routing layer for applications + through registered endpoints (see also: register(), route()) + - supports percent encoded URLs and query parameters (x-www-form-urlencoded) + - allows applications to set response attributes (headers, status code) + + Feature flags (configured in pyrobusta.env) + - http_serve_files: serve files stored on the device + - http_multipart: support for multipart requests/responses """ __slots__ = ( "state", "status_code", "resp_headers", + "resp_handler", + "aborted", "version", "headers", "method", @@ -102,9 +119,6 @@ class HttpEngine: b"image/gif", ) - ASCII = "ascii" - CONTENT_LENGTH = "content-length" - DELETE = b"DELETE" GET = b"GET" HEAD = b"HEAD" @@ -114,18 +128,13 @@ class HttpEngine: METHODS = (DELETE, GET, HEAD, OPTIONS, POST, PUT) SUPPORTED_VERSIONS = (b"HTTP/1.1", b"HTTP/1.0") - MULTIPART_BOUNDARY = b"pyrobusta-boundary" - - CONTENT_LENGTH_ERROR = b"content length mismatch" - HEADER_ERROR = b"Invalid headers" - MULTIPART_BOUNDARY_ERROR = b"Invalid multipart boundary" - BAD_REQUEST_ERROR = b"Bad request" - def __init__(self): # [State machine] - self.state = self._parse_request_line_st + self.state = self._start_parser self.status_code = None self.resp_headers = [] + self.resp_handler = None + self.aborted = False # [Recived request] self.version = None @@ -139,22 +148,38 @@ def __init__(self): # [Multipart state] self.mp_boundary = None + def reset(self): + """ + Reset internal state to reuse a state machine object. + """ + self.state = self._start_parser + self.status_code = None + self.resp_headers.clear() + self.resp_handler = None + self.aborted = False + self.version = None + self.headers.clear() + self.method = None + self.url = None + self.query = None + self.content_len_cnt = 0 + self.recv_chunk_size = 0 + self.mp_boundary = None + # ========================================= # Methods/decorators for routing # ========================================= @classmethod - def register( - cls, endpoint: str, callback: object | str, method: str = "GET" - ) -> None: + def register(cls, endpoint: str, callback: callable, method: str = "GET") -> None: """ - Register an endpoint with a callback function - :param endpoint: name of the endpoint - :param callback: callback function + Register an endpoint with a callback function or file. + :param endpoint: URL path to be routed e.g. "/app/resource" + :param callback: callback function or file path :param method: HTTP method name """ - endpoint = endpoint.encode(cls.ASCII) - method = method.encode(cls.ASCII) + endpoint = endpoint.encode("ascii") + method = method.encode("ascii") endpoint_exists = cls._get_callback(endpoint, method) is not None if method not in cls.METHODS: @@ -167,6 +192,8 @@ def register( def route(endpoint: str, method: str): """ Decorator for registering endpoint callback functions. + :param endpoint: URL path to be routed e.g. "/app/resource" + :param method: HTTP method name """ def decorator(func): @@ -181,7 +208,9 @@ def decorator(func): @staticmethod def percent_decode(s: str): - """Decode percent-encoded input""" + """ + Decode percent-encoded input. + """ out = [] i = 0 while i < len(s): @@ -196,8 +225,8 @@ def percent_decode(s: str): @staticmethod def get_url_encoded_query_param(query: str, key: str, default: str = None): """ - Parse query and return the value belonging to a key - according to x-www-form-urlencoded + Parse a query and return the value belonging to a key + according to the x-www-form-urlencoded format. :param query: query part :param key: key to parse from the query :param default: default value to return when key is not present @@ -220,7 +249,7 @@ def get_url_encoded_query_param(query: str, key: str, default: str = None): @classmethod def is_norm_path_served(cls, path: str): """ - Returns true if a normalized path is configured to be served + Returns true if a normalized path is configured to be served. """ served_paths = get_config(CONF_HTTP_SERVED_PATHS) parts = path.split("/") @@ -238,20 +267,20 @@ def _lookup(tuple_, key): return tuple_[idx + 1] @classmethod - def _get_callback(cls, endpoint, method): + def _get_callback(cls, endpoint, method: bytes): for e in cls.ENDPOINTS: if endpoint == e[0] and method == e[2]: return e[1] @classmethod - def _has_endpoint(cls, endpoint): + def _has_endpoint(cls, endpoint: bytes): for e in cls.ENDPOINTS: if endpoint == e[0]: return True return False @classmethod - def _supported_methods(cls, endpoint): + def _supported_methods(cls, endpoint: bytes): supported_methods = [] for method in cls.METHODS: if cls._get_callback(endpoint, method) is not None: @@ -261,20 +290,19 @@ def _supported_methods(cls, endpoint): @classmethod def _parse_headers(cls, raw_headers: memoryview) -> dict[str, str | int]: """ - Basic parser to extract HTTP/MIME headers - :param raw_headers: headers + Basic parser to extract HTTP/MIME headers. """ header_lines = bytes(raw_headers).split(b"\r\n") headers = {} for line in header_lines: # pylint: disable=W0511 if any(c > 127 for c in line): - raise HeaderParsingError("Non-ASCII character") + raise InvalidHeaders("Non-ASCII character") if b":" not in line: - raise HeaderParsingError() + raise InvalidHeaders() name, value = line.split(b":", 1) if not name: - raise HeaderParsingError("Empty header name") + raise InvalidHeaders("Empty header name") for c in name: if ( 48 <= c <= 57 # 0-9 @@ -283,20 +311,23 @@ def _parse_headers(cls, raw_headers: memoryview) -> dict[str, str | int]: or c in (45, 95) # -_ ): continue - raise HeaderParsingError("Invalid header name") - name = name.strip().lower().decode(cls.ASCII) + raise InvalidHeaders("Invalid header name") + name = name.strip().lower().decode("ascii") if any((c < 32 and c != 9) or c == 127 for c in value): - raise HeaderParsingError("Invalid header value") - if name == cls.CONTENT_LENGTH: + raise InvalidHeaders("Invalid header value") + if name == "content-length": value = int(value.strip()) else: - value = value.strip().decode(cls.ASCII) + value = value.strip().decode("ascii") headers[name] = value return headers @staticmethod def _get_mp_boundary(headers: dict) -> str: - """Determine from the headers if a request is multipart, and return the boundary value""" + """ + Determine from the headers if a request is multipart, + and return the boundary value. + """ content_type = headers.get("content-type") if not content_type or not content_type.lower().startswith( "multipart/form-data" @@ -315,26 +346,28 @@ def _get_mp_boundary(headers: dict) -> str: if value.startswith('"'): if len(value) < 2 or not value.endswith('"'): - raise HeaderParsingError() + raise InvalidHeaders() value = value[1:-1] elif value.endswith('"'): - raise HeaderParsingError() + raise InvalidHeaders() if not value: - raise HeaderParsingError() + raise InvalidHeaders() return value - raise HeaderParsingError() + raise InvalidHeaders() @classmethod def _parse_body_part(cls, part: memoryview) -> tuple[dict, bytes]: - """Parse part headers and body and return them as a tuple""" + """ + Parse part headers and body and return them as a tuple. + """ blank_idx = -1 for i in range(len(part) - 3): if part[i : i + 4] == b"\r\n\r\n": blank_idx = i break if blank_idx == -1: - raise HeaderParsingError() + raise InvalidHeaders() headers = cls._parse_headers(part[:blank_idx]) body = part[blank_idx + 4 :] return headers, body @@ -343,7 +376,12 @@ def _parse_body_part(cls, part: memoryview) -> tuple[dict, bytes]: # Helpers for state machine termination # ========================================= - def _set_response_header(self, key, value): + def set_response_header(self, key: bytes, value: bytes): + """ + Set a response header by key and value. + :param key: HTTP header key + :param value: HTTP header value + """ if ( key in self.resp_headers and (index := self.resp_headers.index(key) % 2) == 0 @@ -353,31 +391,15 @@ def _set_response_header(self, key, value): self.resp_headers.append(key) self.resp_headers.append(value) - def terminate(self, status_code: int, content_type: bytes = b"text/plain"): + def write_response_head(self, tx): """ - Terminate state machine with status code and response content-type - :param status_code: HTTP status code - :param content_type: content-type of the response + Write response status and header to an output buffer. + :param tx: response buffer """ - self.state = None - self.status_code = status_code - if content_type: - self._set_response_header(b"content-type", content_type) - self._set_response_header(b"connection", b"close") - - def _write_response_head(self, tx, content_length: int = 0): - """ - Write response status & header to the output, - with optional content-length value - """ - # Discard already accumulated content (e.g. 500 response on unexpected errors) - tx.consume() + tx.consume() # Discard already accumulated content, required on abrupt errors tx.write(self.version) tx.write(b" ") tx.write(self._lookup(self.RESP_HEADERS, self.status_code)) - if content_length is not None: - tx.write(b"\r\n") - tx.write(b"content-length: %s" % str(content_length).encode(self.ASCII)) for i in range(0, len(self.resp_headers), 2): key = self.resp_headers[i] value = self.resp_headers[i + 1] @@ -387,143 +409,202 @@ def _write_response_head(self, tx, content_length: int = 0): tx.write(value) tx.write(b"\r\n\r\n") - def _generate_response(self, tx, body: bytes | str | dict | tuple | list): + def set_response_body( + self, + body: bytes | str | dict | tuple | list, + content_type: bytes = b"text/plain", + ): """ - Write the complete response to the output, including status - and headers. Return a BytesIO object if the content length - exceeds the remaining buffer capacity, to delegate the writing - of the response body to the transport layer. + Serialize and wrap the response body with a BytesIO + object, stored by the resp_handler member. resp_handler + can be used for writing the body by the transport layer. + This method also updates the content-type and content-length + headers. + :param body: body to be sent in the response + :param content_type: content-type of the body """ - if not body: - self._write_response_head(tx, 0) - body_encoded = b"" - elif isinstance(body, (bytes, bytearray, memoryview)): - self._write_response_head(tx, len(body)) + if body is None: + return + if isinstance(body, (bytes, bytearray, memoryview)): body_encoded = body elif isinstance(body, str): body_encoded = body.encode() - self._write_response_head(tx, len(body_encoded)) elif isinstance(body, (dict, tuple, list)): body_encoded = dumps(body).encode() - self._write_response_head(tx, len(body_encoded)) else: - self.on_failure(tx, b"Unhandled body type") - return + raise ValueError("Unhandled body type") + self.set_response_header( + b"content-length", str(len(body_encoded)).encode("ascii") + ) + self.set_response_header(b"content-type", content_type) if self.method != self.HEAD: - if len(body_encoded) > tx.capacity - tx.size(): - return BytesIO(body_encoded) - tx.write(body_encoded) - - def on_client_error(self, tx, info: bytes): - """Terminate state machine and write 400 response""" - self.terminate(400) - response = info - self._write_response_head(tx, len(response)) - tx.write(response) - - def on_forbidden(self, tx): - """Terminate state machine and write 403 response""" - self.terminate(403) - self._write_response_head(tx) - - def on_missing_resource(self, tx): - """Terminate state machine and write 404 response""" - self.terminate(404) - self._write_response_head(tx) - - def on_method_not_allowed(self, tx): - """Terminate state machine and write 405 response""" - self.terminate(405) - self._write_response_head(tx) - - def on_timeout(self, tx): - """Terminate state machine and write 408 response""" - self.terminate(408) - self._write_response_head(tx) - - def on_buffer_full(self, tx): - """Terminate state machine and write 413 response""" - self.terminate(413) - self._write_response_head(tx) - - def on_failure(self, tx, info: bytes): - """Terminate state machine and write 500 response""" - self.terminate(500) - self._write_response_head(tx, len(info)) - tx.write(info) - - def on_unavailable(self, tx): - """Terminate state machine and write 503 response""" - self.terminate(503) - self._write_response_head(tx) + self.resp_handler = BytesIO(body_encoded) + + def do_keep_alive(self): + """ + Determine if the connection should be kept alive + depending on the HTTP version and headers sent in the request. + """ + if self.aborted: + return False + + connection_tokens = [ + token.strip().lower() + for token in self.headers.get("connection", "").split(",") + ] + return (self.version == b"HTTP/1.0" and "keep-alive" in connection_tokens) or ( + self.version == b"HTTP/1.1" and "close" not in connection_tokens + ) + + def terminate(self, status_code: int, request_complete: bool = False): + """ + Regular state machine termination with a specific status code. + :param status_code: HTTP status code + :param request_complete: true if the complete request is processed + """ + self.state = None + self.status_code = status_code + + if self.version == b"HTTP/1.0" and self.do_keep_alive() and request_complete: + self.set_response_header(b"connection", b"keep-alive") + elif ( + self.version == b"HTTP/1.1" + and not self.do_keep_alive() + and not request_complete + ): + self.set_response_header(b"connection", b"close") + + def abort(self, status_code: int): + """ + Abort state machine due to runtime errors. + Reset any header or response body set earlier. + :param status_code: HTTP status code + """ + self.aborted = True + self.resp_headers = [] + if type(self.resp_handler).__name__ in ("FileIO", "BytesIO"): + self.resp_handler.close() + self.resp_handler = None + self.terminate(status_code, False) - def on_unsupported_version(self, tx): - """Terminate state machine and write 505 response""" - self.terminate(505) - self._write_response_head(tx) + def is_started(self): + """ + Returns true if the state machine has received any input. + """ + return self.state != self._start_parser # pylint: disable=W0143 + + def is_terminated(self): + """ + Returns true if the state machine is terminated. + """ + return self.state is None and self.status_code + + def run(self, rx): + """ + Run the state machine with request buffers provided. + Unlike individual states, this method does not raise an exception. + This method yields on every state transition allowing the calling side + to flush the response buffer. + """ + if self.is_terminated(): + return + try: + while not self.is_terminated(): + self.state(rx) + yield + except BufferFullError: + self.abort(500) + self.set_response_body(b"Buffer full") + except InvalidHeaders: + self.abort(400) + self.set_response_body(b"Invalid headers") + except InvalidContentLength: + self.abort(400) + self.set_response_body(b"Content length mismatch") + except MalformedRequest: + self.abort(400) + self.set_response_body(b"Malformed request") + except Exception as e: # pylint: disable=W0718 + logging.warning(__name__ + f"._run_state_machine: {e}") + self.abort(500) + self.set_response_body(str(e).encode("ascii")) + + # ======================================== + # Helpers for routing, state machine logic + # ======================================== + + def is_chunked(self): + """ + Determines if the request has a payload with chunked transfer-encoding. + """ + return self.headers.get("transfer-encoding") == "chunked" + + def has_payload(self): + """ + Determines if the request has a body. + """ + return ( + "content-length" in self.headers and self.headers["content-length"] > 0 + ) or self.is_chunked() # ================================================================================ # Parser states - # - all states must handle rx and tx buffer arguments for reading and writing data + # - all states must handle rx buffer argument for reading request data # - mandatory methods/attributes of rx: find(), peek(), consume(), size() - # - mandatory methods/attributes of tx: capacity, consume(), write(), size() - # - rx/tx reference implementation: SlidingBuffer (pyrobusta.stream.buffer) + # - reference implementation: SlidingBuffer (pyrobusta.stream.buffer) # ================================================================================ - def _parse_request_line_st(self, rx, tx): - """State for parsing the request line""" + def _start_parser(self, rx): + """ + Initial state. + """ + if rx.size(): + self.state = self._parse_request_line_st + + def _parse_request_line_st(self, rx): + """ + Parse the request line. + """ status_line_sep = rx.find(b"\r\n") if status_line_sep == -1: return status_parts = bytes(rx.peek(status_line_sep)).split() if len(status_parts) != 3: - self.on_client_error(tx, self.BAD_REQUEST_ERROR) - return + raise MalformedRequest() self.method = status_parts[0] url_parts = status_parts[1].split(b"?", 1) self.url = url_parts[0] self.query = ( "" if len(url_parts) == 1 - else self.percent_decode(url_parts[1].decode(self.ASCII)) + else self.percent_decode(url_parts[1].decode("ascii")) ) self.version = status_parts[2] if self.method not in self.METHODS: - self.on_method_not_allowed(tx) + self.terminate(405) return if self.version not in self.SUPPORTED_VERSIONS: - self.on_unsupported_version(tx) + self.terminate(505) return rx.consume(status_line_sep + 2) self.state = self._parse_headers_st - def _parse_headers_st(self, rx, tx): - """State for parsing headers""" + def _parse_headers_st(self, rx): + """ + Parse HTTP headers. + """ if (blank_idx := rx.find(b"\r\n\r\n")) == -1: return - try: - self.headers = self._parse_headers(rx.peek(blank_idx)) - if self.version == b"HTTP/1.1" and "host" not in self.headers: - raise HeaderParsingError() - except HeaderParsingError: - self.on_client_error(tx, self.HEADER_ERROR) - return + self.headers = self._parse_headers(rx.peek(blank_idx)) + if self.version == b"HTTP/1.1" and "host" not in self.headers: + raise InvalidHeaders() rx.consume(blank_idx + 4) self.state = self._route_request_st - def _is_chunked(self): - return self.headers.get("transfer-encoding") == "chunked" - - def _has_payload(self): - return ( - self.CONTENT_LENGTH in self.headers - and self.headers[self.CONTENT_LENGTH] > 0 - ) or self._is_chunked() - - def _route_request_st(self, _, tx): + def _route_request_st(self, _): """ - State for routing requests - - supported ways: static resources, endpoint callback functions + Route requests based on registered endpoints. + If no endpoint is registered, fall back to file serving. """ if self._has_endpoint(self.url) and ( self._get_callback(self.url, self.method) is not None @@ -535,114 +616,132 @@ def _route_request_st(self, _, tx): ): if self.method == self.OPTIONS: supported_methods = self._supported_methods(self.url) - self._set_response_header(b"allow", b", ".join(supported_methods)) - self.terminate(204, None) - self._write_response_head(tx, None) + self.set_response_header(b"allow", b", ".join(supported_methods)) + self.terminate(204, True) return - if self._has_payload(): + if self.has_payload(): if self.method == self.HEAD: - self.on_client_error(tx, self.BAD_REQUEST_ERROR) - return + raise MalformedRequest() if mp_boundary := self._get_mp_boundary(self.headers): - self.mp_boundary = mp_boundary.encode(self.ASCII) + # Request body is multipart + self.mp_boundary = mp_boundary.encode("ascii") self.state = self._start_multipart_parser_st - elif self._is_chunked(): - if self.CONTENT_LENGTH in self.headers: - self.on_client_error(tx, self.BAD_REQUEST_ERROR) - return - self.state = self._recv_chunked_size_st + elif self.is_chunked(): + # Request body is chunked + if "content-length" in self.headers: + raise MalformedRequest() + self.state = self._recv_chunk_size_st else: self.state = self._recv_payload_st else: self.state = self._app_endpoint_st return + # Request does not have a registered endpoint if ( self._has_endpoint(self.url) and self._get_callback(self.method, self.url) is None ): supported_methods = self._supported_methods(self.url) - self._set_response_header(b"allow", b", ".join(supported_methods)) - self.on_method_not_allowed(tx) + self.set_response_header(b"allow", b", ".join(supported_methods)) + self.terminate(405) return + # Fallback: serve file if self.method in (self.GET, self.HEAD): - self.state = lambda _rx, _tx: self._send_file_st(_rx, _tx, self.url) + self.state = lambda _rx: self._send_file_st(_rx, self.url) return - self.on_missing_resource(tx) + self.terminate(404) - def _recv_chunked_size_st(self, rx, _): + def _recv_chunk_size_st(self, rx): + """ + State for determining the chunk size (transfer-encoding: chunked). + """ if (blank_idx := rx.find(b"\r\n")) == -1: return self.recv_chunk_size = int(bytes(rx.peek(blank_idx)), 16) + if self.recv_chunk_size < 0: + raise InvalidContentLength() rx.consume(blank_idx + 2) self.state = self._recv_chunk_st - def _recv_chunk_st(self, rx, tx): + def _recv_chunk_st(self, rx): + """ + State for receiving a complete chunk (transfer-encoding: chunked). + """ if self.recv_chunk_size + 2 > rx.size(): return if self.recv_chunk_size + 2 <= rx.size(): - if rx.peek()[self.recv_chunk_size : self.recv_chunk_size + 2] != b"\r\n": - self.on_client_error(tx, self.CONTENT_LENGTH_ERROR) - return + if rx.peek(self.recv_chunk_size + 2)[-2:] != b"\r\n": + raise InvalidContentLength() self.state = self._app_endpoint_st - def _recv_payload_st(self, rx, tx): - if self.headers[self.CONTENT_LENGTH] > rx.size(): - return - if self.headers[self.CONTENT_LENGTH] < rx.size(): - self.on_client_error(tx, self.CONTENT_LENGTH_ERROR) + def _recv_payload_st(self, rx): + """ + State for receiving the request body. + """ + if self.headers["content-length"] > rx.size(): return self.state = self._app_endpoint_st - def _app_endpoint_st(self, rx, tx): - """Process a request by registered callback functions""" + def _app_endpoint_st(self, rx): + """ + Process a request by registered callback functions. + """ method = self.GET if self.method == self.HEAD else self.method callback = self._get_callback(self.url, method) - if self._has_payload(): - if self._is_chunked(): + if self.has_payload(): + if self.is_chunked(): if self.recv_chunk_size: callback(self, bytes(rx.peek(self.recv_chunk_size))) rx.consume(self.recv_chunk_size + 2) - self.state = self._recv_chunked_size_st + self.state = self._recv_chunk_size_st return dtype, data = callback(self, bytes(rx.peek(self.recv_chunk_size))) rx.consume(self.recv_chunk_size + 2) else: - dtype, data = callback(self, bytes(rx.peek())) - dtype = dtype.encode(self.ASCII) + dtype, data = callback( + self, bytes(rx.peek(self.headers["content-length"])) + ) + dtype = dtype.encode("ascii") else: if not callable(callback): - # Handle as a static resource - self.state = lambda _rx, _tx: self._send_file_st( - _rx, _tx, callback.encode(HttpEngine.ASCII) + # Handle as a file path + self.state = lambda _rx: self._send_file_st( + _rx, callback.encode("ascii") ) return dtype, data = callback(self, b"") - dtype = dtype.encode(self.ASCII) - self._set_response_header(b"content-type", dtype) + dtype = dtype.encode("ascii") + self.set_response_header(b"content-type", dtype) if dtype.startswith(b"multipart/"): - self.state = lambda _rx, _tx: self._generate_multipart_response( - _rx, _tx, data, dtype - ) + self.state = lambda _rx: self._generate_multipart_response(_rx, data, dtype) return - self.terminate(200, dtype) - return self._generate_response(tx, data) + if not self.is_terminated(): + self.terminate(200, True) + self.set_response_body(data, content_type=dtype) - def _send_file_st(self, _, tx, web_resource: bytes): # pylint: disable=W0613 - """State for returning a static resource - disabled""" - self.on_unavailable(tx) + def _send_file_st(self, _, path: bytes): # pylint: disable=W0613 + """ + State for returning including a file in the response body (disabled). + :param path: path to the resource + """ + self.terminate(503, True) - def _start_multipart_parser_st(self, rx, tx): # pylint: disable=W0613 - """Initial state for processing multipart requests""" - self.on_unavailable(tx) + def _start_multipart_parser_st(self, rx): # pylint: disable=W0613 + """ + Initial state for processing multipart requests (disabled). + """ + self.terminate(503) def _generate_multipart_response( - self, rx, tx, callback, dtype + self, rx, callback, dtype ): # pylint: disable=W0613 - """Generate multipart response depening on the exact content type""" - self.on_unavailable(tx) + """ + Generate multipart response depening on the exact content type (disabled). + """ + self.terminate(503, True) def enable_optional_features(): diff --git a/src/pyrobusta/protocol/http_file_server.py b/src/pyrobusta/protocol/http_file_server.py index 1e1acc4..f128d37 100644 --- a/src/pyrobusta/protocol/http_file_server.py +++ b/src/pyrobusta/protocol/http_file_server.py @@ -10,43 +10,48 @@ from pyrobusta.utils.helpers import normalize_path, add_method -def _send_file_st(self, _, tx, web_resource: bytes): - """State for returning a static resource""" +def _send_file_st(self, _, file_path: bytes): + """ + State for returning a file. By default, /www is prepended to the path. + Alternatively, ready any file from the root when the path starts with /files + if it is configured in http_served_paths. + :param file_path: path to the file (unnormalized) + """ if self.url == b"/files": - web_resource = "/" + file_path = "/" elif self.url.startswith(b"/files/"): - web_resource = web_resource[7:] + file_path = file_path[7:] elif self.url == b"/": - web_resource = b"/www/index.html" + file_path = b"/www/index.html" else: - web_resource = b"/www" + web_resource + file_path = b"/www" + file_path - extension = web_resource.rsplit(b".", 1)[-1] - norm_path = normalize_path(web_resource.decode(self.ASCII)) + extension = file_path.rsplit(b".", 1)[-1] + norm_path = normalize_path(file_path.decode("ascii")) is_path_served = self.is_norm_path_served(norm_path) if not is_path_served: try: stat(norm_path) - self.on_forbidden(tx) + self.terminate(403, True) return except OSError: - self.on_missing_resource(tx) + self.terminate(404, True) return try: content_type = self._lookup(self.CONTENT_TYPES, extension) except ValueError: content_type = self._lookup(self.CONTENT_TYPES, b"raw") try: - self._set_response_header( - b"content-length", str(stat(norm_path)[6]).encode(http.HttpEngine.ASCII) + self.set_response_header( + b"content-length", str(stat(norm_path)[6]).encode("ascii") ) - self.terminate(200, content_type) - self._write_response_head(tx, None) + self.set_response_header(b"content-type", content_type) + self.terminate(200, True) if self.method != self.HEAD: - return open(norm_path, "rb") + self.resp_handler = open(norm_path, "rb") # pylint: disable=R1732 return except OSError: - self.on_missing_resource(tx) + self.terminate(404, True) def apply_patches(): diff --git a/src/pyrobusta/protocol/http_multipart.py b/src/pyrobusta/protocol/http_multipart.py index e58074a..fbc0862 100644 --- a/src/pyrobusta/protocol/http_multipart.py +++ b/src/pyrobusta/protocol/http_multipart.py @@ -8,24 +8,27 @@ from pyrobusta.utils.helpers import add_method -def _generate_multipart_response(self, _, tx, callback, dtype): - """Generate multipart response depening on the exact content type""" +def _generate_multipart_response(self, _, callback: callable, dtype: bytes): + """ + Generate multipart response depening on the exact content type. + The callback function is called without arguments, and it must return bytes-like objects. + :param callback: function for part generation, each call generates a separate part + :param dtype: exact multipart content-type (multipart/*) + """ if type(callback).__name__ not in ("function", "closure"): - self.on_failure(tx, b"Invalid response handler") - return - self.terminate(200, dtype) + raise ValueError("Invalid response handler") + self.terminate(200, True) boundary = self.MULTIPART_BOUNDARY - self._set_response_header(b"content-type", dtype + b"; boundary=" + boundary) - self._write_response_head(tx, None) + self.set_response_header(b"content-type", dtype + b"; boundary=" + boundary) if self.method != self.HEAD: - return self._multipart_wrapper_factory(callback, boundary) + self.resp_handler = self._multipart_wrapper_factory(callback, boundary) -def _multipart_wrapper_factory(callback, boundary: bytes): +def _multipart_wrapper_factory(callback: callable, boundary: bytes): """ - Factory method for creating closures that write multipart responses - :param callback: function without arguments, must return bytes-like objects - :param content_type: content type of body parts + Factory method for creating closures that write multipart responses. + The callback function is called without arguments, and it must return bytes-like objects. + :param callback: function for part generation, each call generates a separate part :param boundary: boundary value :return closure: closure to invoke for response generation """ @@ -33,7 +36,7 @@ def _multipart_wrapper_factory(callback, boundary: bytes): def _multipart_wrapper(tx): """ - Write multipart data generated from a callback's return value + Write multipart data generated from a callback's return value. - if insufficient buffer space is available, the generator yields control so the caller can flush or drain the buffer :return bool: true if the stream is completed @@ -62,25 +65,27 @@ def _multipart_wrapper(tx): return _multipart_wrapper -def _start_multipart_parser_st(self, rx, tx): - """Initial state for processing multipart requests""" - if not http.HttpEngine.CONTENT_LENGTH in self.headers: - self.on_client_error(tx, http.HttpEngine.CONTENT_LENGTH_ERROR) - return +def _start_multipart_parser_st(self, rx): + """ + Initial state for processing multipart requests. + """ + if not "content-length" in self.headers: + raise http.InvalidContentLength() if (start_delimiter := rx.find(b"\r\n")) == -1: return self.mp_delimiter = b"--" + self.mp_boundary + b"\r\n" self.mp_last_delimiter = b"--" + self.mp_boundary + b"--" if rx.peek(start_delimiter + 2) != self.mp_delimiter: - self.on_client_error(tx, http.HttpEngine.MULTIPART_BOUNDARY_ERROR) - return + raise http.MalformedRequest() rx.consume(start_delimiter + 2) self.content_len_cnt += start_delimiter + 2 self.state = self._parse_boundary_st -def _parse_boundary_st(self, rx, _): - """State for parsing multipart boundary delimiter""" +def _parse_boundary_st(self, rx): + """ + State for parsing multipart boundary delimiter. + """ if ( rx.find(b"\r\n" + self.mp_delimiter) == -1 and rx.find(b"\r\n" + self.mp_last_delimiter) == -1 @@ -89,51 +94,50 @@ def _parse_boundary_st(self, rx, _): self.state = self._parse_complete_part_st -def _parse_complete_part_st(self, rx, tx): +def _parse_complete_part_st(self, rx): """ - State for processing complete parts in a multipart request + State for processing complete parts in a multipart request. - registered callback is required to process parts """ next_delimiter = rx.find(b"\r\n--" + self.mp_boundary) part = rx.peek(next_delimiter) rx.consume(next_delimiter + 2) # Consume leading CRLF self.content_len_cnt += next_delimiter + 2 - is_final = rx.peek(len(self.mp_last_delimiter)) == self.mp_last_delimiter + is_final = ( + rx.size() >= len(self.mp_last_delimiter) + and rx.peek(len(self.mp_last_delimiter)) == self.mp_last_delimiter + ) + # Validate part and content-length - if self.headers[http.HttpEngine.CONTENT_LENGTH] < self.content_len_cnt: - self.on_client_error(tx, http.HttpEngine.CONTENT_LENGTH_ERROR) - return - try: - part_headers, part_body = http.HttpEngine._parse_body_part(part) - except http.HeaderParsingError: - self.on_client_error(tx, http.HttpEngine.HEADER_ERROR) - return + if self.headers["content-length"] < self.content_len_cnt: + raise http.InvalidContentLength() + part_headers, part_body = http.HttpEngine._parse_body_part(part) callback = http.HttpEngine._get_callback(self.url, self.method) + # Process complete part if not is_final: callback(self, (part_headers, part_body)) if rx.peek(len(self.mp_delimiter)) != self.mp_delimiter: - self.on_client_error(tx, http.HttpEngine.MULTIPART_BOUNDARY_ERROR) - return + raise http.MalformedRequest() rx.consume(len(self.mp_delimiter)) self.content_len_cnt += len(self.mp_delimiter) self.mp_is_first = False self.state = self._parse_boundary_st return + # Process last part rx.consume(len(self.mp_last_delimiter)) self.content_len_cnt += len(self.mp_last_delimiter) if ( - self.headers[http.HttpEngine.CONTENT_LENGTH] != self.content_len_cnt - and self.content_len_cnt + rx.size() - != self.headers[http.HttpEngine.CONTENT_LENGTH] + self.headers["content-length"] != self.content_len_cnt + and self.content_len_cnt + rx.size() < self.headers["content-length"] ): - self.on_client_error(tx, http.HttpEngine.CONTENT_LENGTH_ERROR) - return + raise http.InvalidContentLength() self.mp_is_last = True dtype, data = callback(self, (part_headers, part_body)) - self.terminate(200, dtype.encode(http.HttpEngine.ASCII)) - return self._generate_response(tx, data) + self.set_response_header(b"content-type", dtype.encode("ascii")) + self.terminate(200, True) + self.set_response_body(data) def apply_patches(): @@ -150,6 +154,7 @@ def new_init(self, *args, **kwargs): self.mp_last_delimiter = None http.HttpEngine.__init__ = new_init + http.HttpEngine.MULTIPART_BOUNDARY = b"pyrobusta-boundary" add_method(http.HttpEngine, _generate_multipart_response) add_method(http.HttpEngine, _multipart_wrapper_factory, "static") diff --git a/src/pyrobusta/server/http_server.py b/src/pyrobusta/server/http_server.py index 6e7d4d5..29b7cc3 100644 --- a/src/pyrobusta/server/http_server.py +++ b/src/pyrobusta/server/http_server.py @@ -65,7 +65,8 @@ class HttpServer: @classmethod def _init_pools(cls, max_clients): """ - Initialize pool of buffers for sending/receiving based on different profiles + Initialize pool of buffers for sending/receiving based on different profiles. + :param max_clients: maximum number of HTTP clients """ mem_available = mem_free() + mem_alloc() con_limit = max_clients @@ -95,7 +96,9 @@ def _init_pools(cls, max_clients): @classmethod async def _drop_client(cls, client): - """Remove client from active list""" + """ + Remove client from the list of active clients. + """ if client not in cls.ACTIVE_CLIENTS: return logging.debug(__name__ + f": {client.id} dropped") @@ -141,6 +144,9 @@ async def can_handle_new_client(self): return False async def _reserve_buffers(self): + """ + Reserve and return request and response buffers. + """ if self.SEND_POOL is None or self.RECV_POOL is None: raise RuntimeError("Pools are uninitialized") @@ -160,6 +166,8 @@ async def _accept_socket(self, reader, writer): """ Handle incoming socket connection for HTTP. - creates HttpConnection object + :param reader: asyncio StreamReader + :param reader: asyncio StreamWriter """ if not await self.can_handle_new_client(): logging.debug(__name__ + ": cannot accept new client") @@ -224,7 +232,7 @@ async def start_socket_server(self): async def terminate(self): """ - Terminate HTTP server and drop clients + Terminate HTTP server and drop clients. """ logging.info(__name__ + ": terminated") while self.ACTIVE_CLIENTS: diff --git a/src/pyrobusta/stream/buffer.py b/src/pyrobusta/stream/buffer.py index 812a1f0..b5e81f0 100644 --- a/src/pyrobusta/stream/buffer.py +++ b/src/pyrobusta/stream/buffer.py @@ -19,7 +19,7 @@ class MemoryPool: def __init__(self, block_size, block_count, wrapper=None): """ - Initialize memory pool + Initialize memory pool. :param block_size: size of each memory block in bytes :param block_count: number of reservable memory blocks :param wrapper: wrapper class (abstraction layer) to access the memory, e.g. SlidingBuffer @@ -66,7 +66,7 @@ class SlidingBuffer: - Incremental consumption by advancing 'start' - Incremental writes by advancing 'end' - Automatic in-place compaction when additional space is - required and unused bytes exist before 'start' + required and unused bytes exist before 'start' - Bounded memory usage; no dynamic reallocation """ @@ -80,25 +80,33 @@ def __init__(self, buffer: bytearray | memoryview): self.capacity = len(buffer) def size(self) -> int: - """Determine the window size""" + """ + Determine the window size. + """ return self._end - self._start def writable(self) -> int: - """Determine the writeable size of the buffer""" + """ + Determine the writeable size of the buffer. + """ return self.capacity - self._end def readable_view(self) -> memoryview: - """Return a memoryview to the readable region of the buffer (window)""" + """ + Return a memoryview to the readable region of the buffer (window). + """ return self._mv[self._start : self._end] def writable_view(self) -> memoryview: - """Return a memoryview to the writeable region of the buffer""" + """ + Return a memoryview to the writeable region of the buffer. + """ return self._mv[self._end : self.capacity] def _compact(self): """ Compact the buffer by shifting the active - window to the beginning of the bytearray + window to the beginning of the bytearray. """ if self._start == 0: return @@ -112,7 +120,7 @@ def _compact(self): def peek(self, n=None) -> memoryview: """ Return the first n bytes from the window, - return the entire window when n is undefined + return the entire window when n is undefined. """ if n is None: n = self.size() @@ -121,7 +129,9 @@ def peek(self, n=None) -> memoryview: return self._mv[self._start : self._start + n] def write(self, data: bytes): - """Write new data into the writable region and advance the 'end' index""" + """ + Write new data into the writable region and advance the 'end' index. + """ if not isinstance(data, (bytes, bytearray, memoryview)): raise TypeError("write() expects bytes or bytearray") needed = len(data) @@ -135,7 +145,9 @@ def write(self, data: bytes): self._end += needed def consume(self, n: int = None): - """Discard the first n bytes of the window by advancing the 'start' index""" + """ + Discard the first n bytes of the window by advancing the 'start' index. + """ if n is None: n = self.size() if n > self.size(): @@ -148,7 +160,7 @@ def consume(self, n: int = None): def prepare(self, n: int): """ Check if the writeable region is larger or equal to n, - otherwise attempt to compact the buffer + otherwise attempt to compact the buffer. """ if n > self.capacity: raise ValueError("Capacity exceeded") @@ -159,13 +171,17 @@ def prepare(self, n: int): raise ValueError("Capacity exceeded") def commit(self, n): - """Increase the window size by n bytes by incrementing the 'end' index""" + """ + Increase the window size by n bytes by incrementing the 'end' index. + """ if self._end + n > self.capacity: raise ValueError("Capacity exceeded") self._end += n def find(self, term: bytes) -> int: - """Find and return the index of a search term in the current window""" + """ + Find and return the index of a search term in the current window. + """ for i in range(self._start, self._end - len(term) + 1): if self._mv[i : i + len(term)] == term: return i - self._start diff --git a/src/pyrobusta/transport/connection.py b/src/pyrobusta/transport/connection.py index ce516e4..aa5d7bd 100644 --- a/src/pyrobusta/transport/connection.py +++ b/src/pyrobusta/transport/connection.py @@ -19,8 +19,8 @@ class BaseConnection: def __init__(self, reader, writer): """ Base class for connection handling. - :param reader: async reader stream object - :param writer: async writer stream object + :param reader: asyncio StreamReader + :param writer: asyncio StreamWriter """ client_info = writer.get_extra_info("peername") self.id = str(client_info[0]) + ":" + str(client_info[1]) @@ -43,7 +43,7 @@ async def read( :param read_bytes: number of bytes to read :param decoding: decoding to use (optional), bytes are returned by default :param timeout_seconds: an exception is raised if exceeded, 0 means waiting indefinitely - :return data: holds bytes or decoded string read from the socket + :return data: holds bytes or decoded string read by the stream reader """ if not self.connected: raise OSError(f"{self.id} already closed") diff --git a/src/pyrobusta/utils/helpers.py b/src/pyrobusta/utils/helpers.py index 08f647a..b713309 100644 --- a/src/pyrobusta/utils/helpers.py +++ b/src/pyrobusta/utils/helpers.py @@ -6,7 +6,9 @@ def normalize_path(path: str): - """Normalize a path string to resolve file and directory paths""" + """ + Normalize a path string to resolve file and directory paths. + """ if not path: return "" parts = [] @@ -27,10 +29,11 @@ def normalize_path(path: str): return cwd -def add_method(cls, func, method_type="instance"): +def add_method(cls, func: callable, method_type="instance"): """ - Helper to patch/extend classes with - additional methods and states. + Helper to patch/extend classes with additional methods and states. + :param func: function to add + :param method_type: type of the method (instance, static, class) """ if method_type == "instance": setattr(cls, func.__name__, func) diff --git a/src/pyrobusta/utils/logging.py b/src/pyrobusta/utils/logging.py index 6c58f66..9f67f21 100644 --- a/src/pyrobusta/utils/logging.py +++ b/src/pyrobusta/utils/logging.py @@ -11,7 +11,7 @@ def current_log_level(): """ - Determine current log level from the config + Determine current log level from the config. """ current = get_config(CONF_LOG_LEVEL) if current == "debug": @@ -25,7 +25,7 @@ def current_log_level(): def warning(log): """ - Print warning messages + Print warning messages. """ if current_log_level() >= _LOG_LEVEL_WARNING: print(f"[WARN] {log}") @@ -33,7 +33,7 @@ def warning(log): def info(log): """ - Print info messages + Print info messages. """ if current_log_level() >= _LOG_LEVEL_INFO: print(f"[INFO] {log}") @@ -41,7 +41,7 @@ def info(log): def debug(log): """ - Print debug messages + Print debug messages. """ if current_log_level() >= _LOG_LEVEL_DEBUG: print(f"[DEBUG] {log}") diff --git a/tests/functional/test_http.py b/tests/functional/test_http.py index 68d5577..7618c9c 100644 --- a/tests/functional/test_http.py +++ b/tests/functional/test_http.py @@ -6,11 +6,7 @@ from os import mkdir, remove, rmdir from pyrobusta.server import http_server -from pyrobusta.protocol.http import ( - HttpEngine, - enable_optional_features, - ServerBusyError, -) +from pyrobusta.protocol.http import HttpEngine, enable_optional_features from pyrobusta.utils.config import ( CONF_HTTP_SERVED_PATHS, CONF_TLS, @@ -85,8 +81,9 @@ def simple_callback(http_ctx, _): @HttpEngine.route("/test/busy", "POST") -def busy_callback(*_): - raise ServerBusyError() +def busy_callback(http_ctx, _): + http_ctx.terminate(503) + return "text/plain", "Unavailable" def create_chunked_app_endpoint(endpoint): @@ -117,7 +114,9 @@ async def test_simple_response(tls_enabled): # Test: text/plain plain_response = await send_request( b"GET /test/simple HTTP/1.1\r\n" - b"Host: localhost\r\nAccept:text/plain\r\n" + b"Host: localhost\r\n" + b"Connection: close\r\n" + b"Accept:text/plain\r\n" b"\r\n", tls_enabled, ) @@ -135,7 +134,9 @@ async def test_simple_response(tls_enabled): # Test: application/json json_response = await send_request( b"GET /test/simple HTTP/1.1\r\n" - b"Host: localhost\r\nAccept:application/json\r\n" + b"Host: localhost\r\n" + b"Connection: close\r\n" + b"Accept: application/json\r\n" b"\r\n", tls_enabled, ) @@ -160,7 +161,9 @@ async def test_server_busy(): server, server_task = await start_server() plain_response = await send_request( - b"POST /test/busy HTTP/1.1\r\n" b"Host: localhost\r\n\r\n" + b"POST /test/busy HTTP/1.1\r\n" + b"Connection:close\r\n" + b"Host: localhost\r\n\r\n" ) test_assert( f"response is rejected by busy service with 503", @@ -179,15 +182,14 @@ async def test_chunked_transfer_encoding(): server, server_task = await start_server() json_response = await send_request( - ( - b"POST /test/chunked HTTP/1.1\r\n" - b"Host: localhost\r\n" - b"Transfer-Encoding: chunked\r\n\r\n" - b"14\r\nchunking\r\ntest\r\ncase\r\n" - b"E\r\nchunking\r\ntest\r\n" - b"8\r\nchunking\r\n" - b"0\r\n\r\n" - ) + b"POST /test/chunked HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Connection: close\r\n" + b"Transfer-Encoding: chunked\r\n\r\n" + b"14\r\nchunking\r\ntest\r\ncase\r\n" + b"E\r\nchunking\r\ntest\r\n" + b"8\r\nchunking\r\n" + b"0\r\n\r\n" ) response_body = json.loads(json_response.split(b"\r\n\r\n")[1]) test_assert( @@ -226,7 +228,9 @@ async def test_fs_access_control(): # Case #1: /www/index.html response = await send_request( - (b"GET /allowed/index.html HTTP/1.1\r\n" b"Host: localhost\r\n\r\n") + b"GET /allowed/index.html HTTP/1.1\r\n" + b"Connection: close\r\n" + b"Host: localhost\r\n\r\n" ) response_body = response.split(b"\r\n\r\n")[1] @@ -238,7 +242,9 @@ async def test_fs_access_control(): # Case #2: /index.html response = await send_request( - (b"GET /rejected/index.html HTTP/1.1\r\n" b"Host: localhost\r\n\r\n") + b"GET /rejected/index.html HTTP/1.1\r\n" + b"Connection: close\r\n" + b"Host: localhost\r\n\r\n" ) test_assert( @@ -255,6 +261,90 @@ async def test_fs_access_control(): await server.terminate() +@garbage_collect +async def test_keepalive(): + setup_config() + server, server_task = await start_server() + + # ---------------------------------- + # Case 1: all requests are processed + # ---------------------------------- + plain_responses = await send_request( + b"GET /test/simple HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Accept:text/plain\r\n" + b"\r\n" + b"GET /test/simple HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Accept:text/plain\r\n" + b"\r\n" + b"GET /test/simple HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Connection: close\r\n" + b"Accept:text/plain\r\n" + b"\r\n" + ) + + test_assert( + f"contains all responses (connection: keep-alive)", + plain_responses.count(b"HTTP/1.1 200 OK"), + 3, + ) + + # ------------------------------------------------------------------- + # Case 2: close connection after the second request (invalid framing) + # ------------------------------------------------------------------- + plain_responses = await send_request( + b"GET /test/simple HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Accept:text/plain\r\n" + b"\r\n" + b"GET /test/simple HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"" + b"Accept:text/plain\r\n" + b"\r\n" + b"GET /test/simple HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Accept:text/plain\r\n" + b"\r\n" + ) + + test_assert( + f"contains two responses (connection: keep-alive, invalid framing)", + plain_responses.count(b"HTTP/1.1"), + 2, + ) + + # ------------------------------------------------ + # Case 3: close connection after the first request + # ------------------------------------------------ + plain_response = await send_request( + b"GET /test/simple HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Connection: close\r\n" + b"Accept:text/plain\r\n" + b"\r\n" + b"GET /test/simple HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Accept:text/plain\r\n" + b"\r\n" + b"GET /test/simple HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Accept:text/plain\r\n" + b"\r\n" + ) + + test_assert( + f"contains single response (connection: close)", + plain_response.count(b"HTTP/1.1 200 OK"), + 1, + ) + + server_task.cancel() + await server.terminate() + + ################################################# # Test methods ################################################# @@ -294,6 +384,7 @@ def test_main(): asyncio.run(test_server_busy()) asyncio.run(test_chunked_transfer_encoding()) asyncio.run(test_fs_access_control()) + asyncio.run(test_keepalive()) test_main() diff --git a/tests/functional/test_http_multipart.py b/tests/functional/test_http_multipart.py index 1765243..e52f665 100644 --- a/tests/functional/test_http_multipart.py +++ b/tests/functional/test_http_multipart.py @@ -91,7 +91,7 @@ def multipart_callback(http_ctx, _): async def start_server(): """ - Start an HTTP server as a background task + Start an HTTP server as a background task. """ server = http_server.HttpServer() server_task = asyncio.create_task(server.start_socket_server()) @@ -107,9 +107,12 @@ async def test_multipart_response(tls_enabled): # Test: 1 part plain_response = await send_request( b"GET /test/multipart HTTP/1.1\r\n" - b"Host: localhost\r\nX-Part-Count: 1\r\n\r\n", + b"Host: localhost\r\n" + b"Connection: close\r\n" + b"X-Part-Count: 1\r\n\r\n", tls_enabled, ) + test_assert( f"http{"s" if tls_enabled else ""} response contains 1 part", b"Response 1" in plain_response, @@ -119,7 +122,9 @@ async def test_multipart_response(tls_enabled): # Test: 10 parts plain_response = await send_request( b"GET /test/multipart HTTP/1.1\r\n" - b"Host: localhost\r\nX-Part-Count: 10\r\n\r\n", + b"Host: localhost\r\n" + b"Connection: close\r\n" + b"X-Part-Count: 10\r\n\r\n", tls_enabled, ) test_assert( diff --git a/tests/system/http_dimensioning/test.py b/tests/system/http_dimensioning/test.py index b191975..aa48adb 100644 --- a/tests/system/http_dimensioning/test.py +++ b/tests/system/http_dimensioning/test.py @@ -189,7 +189,9 @@ def measure_footprint(config, device_ip): port = 4443 if config["tls"] == True else 8080 try: usage = requests.get( - f"{proto}://{device_ip}:{port}/mem/current", verify=False + f"{proto}://{device_ip}:{port}/mem/current", + verify=False, + headers={"Connection": "close"}, ).text print(f"Measured: {usage}") except: @@ -225,6 +227,7 @@ def worker(): f"{base_url}/index.html", verify=False, timeout=5, + headers={"Connection": "close"}, ) resp.raise_for_status() @@ -254,6 +257,7 @@ def worker(): f"{base_url}/mem/time-series", verify=False, timeout=5, + headers={"Connection": "close"}, ).json() print(f"Measured: {usage}") diff --git a/tests/unit/test_buffer.py b/tests/unit/test_buffer.py index f004855..75c7d1e 100644 --- a/tests/unit/test_buffer.py +++ b/tests/unit/test_buffer.py @@ -5,7 +5,7 @@ class BufferTestBase(unittest.TestCase): """ - Tests for stream.buffer module + Tests for stream.buffer module. """ buffer_type = bytearray diff --git a/tests/unit/test_http.py b/tests/unit/test_http.py index 31d393d..949bc2f 100644 --- a/tests/unit/test_http.py +++ b/tests/unit/test_http.py @@ -83,7 +83,7 @@ def test_status_parsing_valid(self): for i in range(len(request)): self.rx.write(request[i : i + 1]) - self.engine.state(self.rx, self.tx) + self.engine.state(self.rx) self.assertEqual(self.engine.method, b"GET") self.assertEqual(self.engine.url, b"/index.html") @@ -96,7 +96,7 @@ def test_status_parsing_incomplete_line(self): for i in range(len(request)): self.rx.write(request[i : i + 1]) - self.engine.state(self.rx, self.tx) + self.engine.state(self.rx) if self.engine.state is None: break @@ -110,7 +110,7 @@ def test_status_parsing_unsupported_method(self): for i in range(len(request)): self.rx.write(request[i : i + 1]) - self.engine.state(self.rx, self.tx) + self.engine.state(self.rx) if self.engine.state is None: break @@ -125,7 +125,7 @@ def test_status_parsing_unsupported_version(self): for i in range(len(request)): self.rx.write(request[i : i + 1]) - self.engine.state(self.rx, self.tx) + self.engine.state(self.rx) if self.engine.state is None: break @@ -141,7 +141,7 @@ def test_header_parsing_valid(self): for i in range(len(request)): self.rx.write(request[i : i + 1]) - self.engine.state(self.rx, self.tx) + self.engine.state(self.rx) self.assertDictEqual( {"content-length": 10, "content-type": "application/json"}, @@ -153,14 +153,10 @@ def test_header_parsing_valid(self): def test_header_parsing_incomplete_header(self): request = b"GET /index.html HTTP/1.1\r\nContent-Type\r\n\r\n" - for i in range(len(request)): - self.rx.write(request[i : i + 1]) - self.engine.state(self.rx, self.tx) - if self.engine.state is None: - break - - self.assertEqual(self.engine.status_code, 400) - self.assertEqual(self.engine.state, None) + with self.assertRaises(self.http_module.InvalidHeaders): + for i in range(len(request)): + self.rx.write(request[i : i + 1]) + self.engine.state(self.rx) def test_header_parsing_error(self): for case in ( @@ -171,7 +167,7 @@ def test_header_parsing_error(self): b"space in header name: value", b"new-line-in-header:\nvalue", ): - with self.assertRaises(self.http_module.HeaderParsingError): + with self.assertRaises(self.http_module.InvalidHeaders): self.engine._parse_headers(case) def test_routing_unsupported_method(self): @@ -183,7 +179,7 @@ def test_routing_unsupported_method(self): test_callback = mock.Mock() self.engine.register("/api/test", test_callback, "POST") - self.engine.state(self.rx, self.tx) + self.engine.state(self.rx) self.assertEqual(self.engine.status_code, 405) self.assertEqual(self.engine.state, None) @@ -201,7 +197,7 @@ def test_routing_options_method(self): self.engine.register("/api/test", test_callback, "POST") self.engine.register("/api/test", test_callback, "PUT") - self.engine.state(self.rx, self.tx) + self.engine.state(self.rx) self.assertEqual(self.engine.status_code, 204) self.assertEqual(self.engine.state, None) @@ -220,15 +216,15 @@ def test_routing_get_method(self): self.engine.register("/api/test", test_callback, "GET") while self.engine.state is not None: - self.engine.state(self.rx, self.tx) + self.engine.state(self.rx) self.assertEqual(self.engine.status_code, 200) self.assertEqual(self.engine.state, None) - self.assertNotEqual( - self.tx.find(b"content-length: " + str(len(test_response)).encode("ascii")), - -1, + self.assertEqual( + int(self.engine._lookup(self.engine.resp_headers, b"content-length")), + len(test_response), ) - self.assertNotEqual(self.tx.find(test_response), -1) + self.assertEqual(self.engine.resp_handler.read(), test_response) def test_routing_head_method(self): self.engine.state = self.engine._route_request_st @@ -242,22 +238,22 @@ def test_routing_head_method(self): self.engine.register("/api/test", test_callback, "GET") while self.engine.state is not None: - self.engine.state(self.rx, self.tx) + self.engine.state(self.rx) self.assertEqual(self.engine.status_code, 200) self.assertEqual(self.engine.state, None) - self.assertNotEqual( - self.tx.find(b"content-length: " + str(len(test_response)).encode("ascii")), - -1, + self.assertEqual( + int(self.engine._lookup(self.engine.resp_headers, b"content-length")), + len(test_response), ) - self.assertEqual(self.tx.find(test_response), -1) + self.assertEqual(self.engine.resp_handler, None) def test_simple_query_parameter(self): request = b"GET /api/test?param HTTP/1.1\r\n" for i in range(len(request)): self.rx.write(request[i : i + 1]) - self.engine.state(self.rx, self.tx) + self.engine.state(self.rx) self.assertEqual(self.engine.query, "param") @@ -275,7 +271,7 @@ def pct_encode(b): for i in range(len(request)): self.rx.write(request[i : i + 1]) - self.engine.state(self.rx, self.tx) + self.engine.state(self.rx) self.assertEqual(self.engine.query, f"safe_chars.{unsafe_chars}") @@ -284,7 +280,7 @@ def test_single_url_encoded_query_parameter(self): for i in range(len(request)): self.rx.write(request[i : i + 1]) - self.engine.state(self.rx, self.tx) + self.engine.state(self.rx) self.assertEqual( self.engine.get_url_encoded_query_param(self.engine.query, "param"), "value" @@ -297,7 +293,7 @@ def test_multiple_url_encoded_query_parameter(self): for i in range(len(request)): self.rx.write(request[i : i + 1]) - self.engine.state(self.rx, self.tx) + self.engine.state(self.rx) self.assertEqual( self.engine.get_url_encoded_query_param(self.engine.query, "param1"), @@ -317,7 +313,7 @@ def test_empty_or_missing_url_encoded_query_parameter(self): for i in range(len(request)): self.rx.write(request[i : i + 1]) - self.engine.state(self.rx, self.tx) + self.engine.state(self.rx) self.assertEqual( self.engine.get_url_encoded_query_param(self.engine.query, "param1"), @@ -342,7 +338,7 @@ def test_overlapping_url_encoded_query_parameter(self): for i in range(len(request)): self.rx.write(request[i : i + 1]) - self.engine.state(self.rx, self.tx) + self.engine.state(self.rx) self.assertEqual( self.engine.get_url_encoded_query_param(self.engine.query, "data"), @@ -362,7 +358,7 @@ def test_chunked_transfer_encoding_valid(self): self.engine.method = b"GET" self.engine.version = b"HTTP/1.1" self.engine.headers["transfer-encoding"] = "chunked" - self.engine.state = self.engine._recv_chunked_size_st + self.engine.state = self.engine._recv_chunk_size_st test_callback = mock.Mock(return_value=("text/plain", "OK")) self.engine.register("/api/test", test_callback, "GET") @@ -375,10 +371,10 @@ def test_chunked_transfer_encoding_valid(self): ): for i in range(len(chunk)): self.rx.write(chunk[i : i + 1]) - self.engine.state(self.rx, self.tx) + self.engine.state(self.rx) self.assertEqual(self.engine.state, self.engine._app_endpoint_st) - self.engine.state(self.rx, self.tx) + self.engine.state(self.rx) size_delimiter = chunk.find(b"\r\n") test_callback.assert_called_with( self.engine, chunk[size_delimiter + 2 : -2] @@ -392,27 +388,23 @@ def test_chunked_transfer_encoding_invalid_chunk_size_smaller(self): self.engine.method = b"GET" self.engine.version = b"HTTP/1.1" self.engine.headers["transfer-encoding"] = "chunked" - self.engine.state = self.engine._recv_chunked_size_st + self.engine.state = self.engine._recv_chunk_size_st test_callback = mock.Mock(return_value=("text/plain", "OK")) self.engine.register("/api/test", test_callback, "GET") - chunk = b"2\r\nchunking\r\n" - for i in range(len(chunk)): - self.rx.write(chunk[i : i + 1]) - self.engine.state(self.rx, self.tx) - if self.engine.state is None: - break - - self.assertEqual(self.engine.status_code, 400) - self.assertEqual(self.engine.state, None) + with self.assertRaises(self.http_module.InvalidContentLength): + chunk = b"2\r\nchunking\r\n" + for i in range(len(chunk)): + self.rx.write(chunk[i : i + 1]) + self.engine.state(self.rx) def test_chunked_transfer_encoding_chunk_incomplete(self): self.engine.url = b"/api/test" self.engine.method = b"GET" self.engine.version = b"HTTP/1.1" self.engine.headers["transfer-encoding"] = "chunked" - self.engine.state = self.engine._recv_chunked_size_st + self.engine.state = self.engine._recv_chunk_size_st test_callback = mock.Mock(return_value=("text/plain", "OK")) self.engine.register("/api/test", test_callback, "GET") @@ -420,7 +412,7 @@ def test_chunked_transfer_encoding_chunk_incomplete(self): chunk = b"FF\r\nchunking\r\n" for i in range(len(chunk)): self.rx.write(chunk[i : i + 1]) - self.engine.state(self.rx, self.tx) + self.engine.state(self.rx) if self.engine.state is None: break @@ -510,7 +502,7 @@ def test_multipart_parser(self): {"content-type": 'multipart/form-data;boundary=missing-quote"'}, ]: with self.subTest(headers=case): - with self.assertRaises(self.http_module.HeaderParsingError): + with self.assertRaises(self.http_module.InvalidHeaders): self.engine._get_mp_boundary(case) def test_multipart_receiver_valid(self): @@ -521,7 +513,7 @@ def test_multipart_receiver_valid(self): for i in range(len(body_part)): self.rx.write(body_part[i : i + 1]) - self.engine.state(self.rx, self.tx) + self.engine.state(self.rx) self.assertEqual(self.engine.state, self.engine._parse_boundary_st) self.assertEqual(self.rx.peek(), b"Content-Type:text/plain") @@ -533,15 +525,10 @@ def test_multipart_receiver_boundary_mismatch(self): self.engine.mp_boundary = b"test-boundary" body_part = b"--test-boundary-delimiter\r\nContent-Type:text/plain" - for i in range(len(body_part)): - self.rx.write(body_part[i : i + 1]) - self.engine.state(self.rx, self.tx) - if self.engine.state is None: - break - - self.assertEqual(self.engine.state, None) - self.assertEqual(self.engine.status_code, 400) - self.assertEqual(self.rx.peek(), b"--test-boundary-delimiter\r\n") + with self.assertRaises(self.http_module.MalformedRequest): + for i in range(len(body_part)): + self.rx.write(body_part[i : i + 1]) + self.engine.state(self.rx) def test_multipart_receiver_complete_part(self): self.engine.state = self.engine._parse_boundary_st @@ -566,13 +553,13 @@ def test_multipart_receiver_complete_part(self): for i in range(len(body_part)): self.assertEqual(self.engine.state, self.engine._parse_boundary_st) self.rx.write(body_part[i : i + 1]) - self.engine.state(self.rx, self.tx) + self.engine.state(self.rx) self.assertEqual(self.engine.state, self.engine._parse_complete_part_st) self.assertEqual(self.rx.peek(), body_part) self.assertEqual(self.engine.mp_is_first, True) - self.engine.state(self.rx, self.tx) + self.engine.state(self.rx) self.assertEqual(self.engine.state, self.engine._parse_boundary_st) test_callback.assert_called_once_with( @@ -611,12 +598,12 @@ def test_multipart_receiver_last_part(self): for i in range(len(body_part)): self.assertEqual(self.engine.state, self.engine._parse_boundary_st) self.rx.write(body_part[i : i + 1]) - self.engine.state(self.rx, self.tx) + self.engine.state(self.rx) self.assertEqual(self.engine.state, self.engine._parse_complete_part_st) self.assertEqual(self.rx.peek(), body_part) - self.engine.state(self.rx, self.tx) + self.engine.state(self.rx) self.assertEqual(self.engine.state, None) self.assertEqual(self.engine.status_code, 200) @@ -666,7 +653,7 @@ def test_file_serving_missing_file(self, *_): self.engine.version = b"HTTP/1.1" self.engine.state = self.engine._send_file_st - self.engine.state(self.rx, self.tx, self.engine.url) + self.engine.state(self.rx, self.engine.url) self.assertEqual(self.engine.status_code, 404) self.assertEqual(self.engine.state, None) @@ -680,10 +667,10 @@ def test_file_serving_root(self, *_): file_content = "index content" with patch("builtins.open", mock_open(read_data=file_content)) as m: - response_generator = self.engine.state(self.rx, self.tx, self.engine.url) + self.engine.state(self.rx, self.engine.url) m.assert_called_once_with("/www/index.html", "rb") - self.assertEqual(response_generator.read(), file_content) + self.assertEqual(self.engine.resp_handler.read(), file_content) self.assertEqual(self.engine.status_code, 200) self.assertEqual(self.engine.state, None) @@ -696,10 +683,10 @@ def test_file_serving_files_endpoint(self, *_): file_content = "data" with patch("builtins.open", mock_open(read_data=file_content)) as m: - response_generator = self.engine.state(self.rx, self.tx, self.engine.url) + self.engine.state(self.rx, self.engine.url) m.assert_called_once_with("/www/scripts.js", "rb") - self.assertEqual(response_generator.read(), file_content) + self.assertEqual(self.engine.resp_handler.read(), file_content) self.assertEqual(self.engine.status_code, 200) self.assertEqual(self.engine.state, None) @@ -712,10 +699,10 @@ def test_file_serving_known_content_type(self, *_): file_content = "data" with patch("builtins.open", mock_open(read_data=file_content)) as m: - response_generator = self.engine.state(self.rx, self.tx, self.engine.url) + self.engine.state(self.rx, self.engine.url) m.assert_called_once_with("/www/scripts.js", "rb") - self.assertEqual(response_generator.read(), file_content) + self.assertEqual(self.engine.resp_handler.read(), file_content) self.assertEqual( self.engine._lookup(self.engine.resp_headers, b"content-type"), b"application/javascript", @@ -732,10 +719,10 @@ def test_file_serving_fallback_content_type(self, *_): file_content = "data" with patch("builtins.open", mock_open(read_data=file_content)) as m: - response_generator = self.engine.state(self.rx, self.tx, self.engine.url) + self.engine.state(self.rx, self.engine.url) m.assert_called_once_with("/www/scripts.unknown", "rb") - self.assertEqual(response_generator.read(), file_content) + self.assertEqual(self.engine.resp_handler.read(), file_content) self.assertEqual( self.engine._lookup(self.engine.resp_headers, b"content-type"), b"application/octet-stream", @@ -752,10 +739,10 @@ def test_file_serving_unserved_content_rejected(self, *_): file_content = "data" with patch("builtins.open", mock_open(read_data=file_content)) as m: - response_generator = self.engine.state(self.rx, self.tx, self.engine.url) + self.engine.state(self.rx, self.engine.url) m.assert_not_called() - self.assertEqual(response_generator, None) + self.assertEqual(self.engine.resp_handler, None) self.assertEqual(self.engine.status_code, 403) self.assertEqual(self.engine.state, None)