diff --git a/Makefile b/Makefile
index 5da1e69..016ada9 100644
--- a/Makefile
+++ b/Makefile
@@ -101,7 +101,7 @@ deploy-config:
# -----------------------------
-# Deploy index page # TODO use install_www from assets module
+# Deploy index page
# -----------------------------
.PHONY: deploy-www
deploy-www:
diff --git a/README.md b/README.md
index de54e56..17b1a3c 100644
--- a/README.md
+++ b/README.md
@@ -11,6 +11,7 @@ A lightweight HTTP server library for MicroPython designed for constrained embed
- Finite-state-machine parser with linear sliding buffer
- Robust byte-stream handling
- Query parameter parsing with percent encoding support
+- Persistent connections (set by **connection: keep-alive**)
- TLS support
# Installation
diff --git a/assets/www/examples.html b/assets/www/examples.html
index e1a2ca6..95a1a4d 100644
--- a/assets/www/examples.html
+++ b/assets/www/examples.html
@@ -92,7 +92,7 @@
Simple Server Application
if value_format == "%":
free = round(100 * free / (free + mem_alloc()), 2)
- return "text/plain", (f"Free memory [{value_format}]: {free}\n")
+ return "text/plain", f"Free memory [{value_format}]: {free}\n"
async def run_server():
server = HttpServer()
diff --git a/docs/dimensioning/http_dimensioning.md b/docs/dimensioning/http_dimensioning.md
index e2a91bd..daf2d33 100644
--- a/docs/dimensioning/http_dimensioning.md
+++ b/docs/dimensioning/http_dimensioning.md
@@ -28,19 +28,19 @@ with no active network traffic.
| id | http_mem_cap | http_multipart | socket_max_con | tls | footprint_bytes |
| --- | --- | --- | --- | --- | --- |
-| base | 0.05 | False | 1 | False | 38448 |
-| low_mem_cap_001 | 0.0127 | False | 1 | False | 38448 |
-| low_mem_cap_002 | 0.0253 | False | 2 | False | 39664 |
-| low_mem_cap_003 | 0.0505 | False | 4 | False | 42096 |
-| high_mem_cap_001 | 0.0568 | False | 1 | False | 45616 |
-| high_mem_cap_002 | 0.114 | False | 2 | False | 54000 |
-| high_mem_cap_003 | 0.228 | False | 4 | False | 70768 |
-| multipart_001 | 0.0127 | True | 1 | False | 40704 |
-| multipart_002 | 0.0253 | True | 2 | False | 41920 |
-| multipart_003 | 0.0505 | True | 4 | False | 44352 |
-| tls_001 | 0.0127 | False | 1 | True | 41120 |
-| tls_002 | 0.0253 | False | 2 | True | 42336 |
-| tls_003 | 0.0505 | False | 4 | True | 44768 |
+| base | 0.05 | False | 1 | False | 38512 |
+| low_mem_cap_001 | 0.0127 | False | 1 | False | 38512 |
+| low_mem_cap_002 | 0.0253 | False | 2 | False | 39728 |
+| low_mem_cap_003 | 0.0505 | False | 4 | False | 42112 |
+| high_mem_cap_001 | 0.0568 | False | 1 | False | 45680 |
+| high_mem_cap_002 | 0.114 | False | 2 | False | 54064 |
+| high_mem_cap_003 | 0.228 | False | 4 | False | 70832 |
+| multipart_001 | 0.0127 | True | 1 | False | 40608 |
+| multipart_002 | 0.0253 | True | 2 | False | 41824 |
+| multipart_003 | 0.0505 | True | 4 | False | 44256 |
+| tls_001 | 0.0127 | False | 1 | True | 41184 |
+| tls_002 | 0.0253 | False | 2 | True | 42400 |
+| tls_003 | 0.0505 | False | 4 | True | 44832 |
## Heap usage under network traffic

@@ -60,9 +60,9 @@ with no active network traffic.
| high_mem_cap_001 | 0.00111 | False | 1 | False | 45520 |
| high_mem_cap_002 | 0.00222 | False | 2 | False | 53904 |
| high_mem_cap_003 | 0.00443 | False | 4 | False | 70672 |
-| multipart_001 | 0.000247 | True | 1 | False | 40656 |
-| multipart_002 | 0.000493 | True | 2 | False | 41872 |
-| multipart_003 | 0.000985 | True | 4 | False | 44304 |
+| multipart_001 | 0.000247 | True | 1 | False | 40528 |
+| multipart_002 | 0.000493 | True | 2 | False | 41744 |
+| multipart_003 | 0.000985 | True | 4 | False | 44176 |
| tls_001 | 0.000247 | False | 1 | True | 40736 |
| tls_002 | 0.000493 | False | 2 | True | 41952 |
| tls_003 | 0.000985 | False | 4 | True | 44384 |
diff --git a/docs/dimensioning/img/esp32_c3/base.png b/docs/dimensioning/img/esp32_c3/base.png
index 75fa126..1686010 100644
Binary files a/docs/dimensioning/img/esp32_c3/base.png and b/docs/dimensioning/img/esp32_c3/base.png differ
diff --git a/docs/dimensioning/img/esp32_c3/high_mem_cap_001.png b/docs/dimensioning/img/esp32_c3/high_mem_cap_001.png
index 9ae01ff..a4bf50e 100644
Binary files a/docs/dimensioning/img/esp32_c3/high_mem_cap_001.png and b/docs/dimensioning/img/esp32_c3/high_mem_cap_001.png differ
diff --git a/docs/dimensioning/img/esp32_c3/high_mem_cap_002.png b/docs/dimensioning/img/esp32_c3/high_mem_cap_002.png
index e0af9fa..c57115b 100644
Binary files a/docs/dimensioning/img/esp32_c3/high_mem_cap_002.png and b/docs/dimensioning/img/esp32_c3/high_mem_cap_002.png differ
diff --git a/docs/dimensioning/img/esp32_c3/high_mem_cap_003.png b/docs/dimensioning/img/esp32_c3/high_mem_cap_003.png
index a666ee7..6f68f16 100644
Binary files a/docs/dimensioning/img/esp32_c3/high_mem_cap_003.png and b/docs/dimensioning/img/esp32_c3/high_mem_cap_003.png differ
diff --git a/docs/dimensioning/img/esp32_c3/low_mem_cap_001.png b/docs/dimensioning/img/esp32_c3/low_mem_cap_001.png
index a9285ae..ebb7766 100644
Binary files a/docs/dimensioning/img/esp32_c3/low_mem_cap_001.png and b/docs/dimensioning/img/esp32_c3/low_mem_cap_001.png differ
diff --git a/docs/dimensioning/img/esp32_c3/low_mem_cap_002.png b/docs/dimensioning/img/esp32_c3/low_mem_cap_002.png
index 35fbf73..54fb0e2 100644
Binary files a/docs/dimensioning/img/esp32_c3/low_mem_cap_002.png and b/docs/dimensioning/img/esp32_c3/low_mem_cap_002.png differ
diff --git a/docs/dimensioning/img/esp32_c3/low_mem_cap_003.png b/docs/dimensioning/img/esp32_c3/low_mem_cap_003.png
index 1c58314..96edeb7 100644
Binary files a/docs/dimensioning/img/esp32_c3/low_mem_cap_003.png and b/docs/dimensioning/img/esp32_c3/low_mem_cap_003.png differ
diff --git a/docs/dimensioning/img/esp32_c3/multipart_001.png b/docs/dimensioning/img/esp32_c3/multipart_001.png
index 08e0f1a..1e249e5 100644
Binary files a/docs/dimensioning/img/esp32_c3/multipart_001.png and b/docs/dimensioning/img/esp32_c3/multipart_001.png differ
diff --git a/docs/dimensioning/img/esp32_c3/multipart_002.png b/docs/dimensioning/img/esp32_c3/multipart_002.png
index fd54765..0456481 100644
Binary files a/docs/dimensioning/img/esp32_c3/multipart_002.png and b/docs/dimensioning/img/esp32_c3/multipart_002.png differ
diff --git a/docs/dimensioning/img/esp32_c3/multipart_003.png b/docs/dimensioning/img/esp32_c3/multipart_003.png
index 046c0a3..2b4fcf9 100644
Binary files a/docs/dimensioning/img/esp32_c3/multipart_003.png and b/docs/dimensioning/img/esp32_c3/multipart_003.png differ
diff --git a/docs/dimensioning/img/esp32_c3/tls_001.png b/docs/dimensioning/img/esp32_c3/tls_001.png
index 9d933c0..bb465f3 100644
Binary files a/docs/dimensioning/img/esp32_c3/tls_001.png and b/docs/dimensioning/img/esp32_c3/tls_001.png differ
diff --git a/docs/dimensioning/img/esp32_c3/tls_002.png b/docs/dimensioning/img/esp32_c3/tls_002.png
index a69e17d..0755d09 100644
Binary files a/docs/dimensioning/img/esp32_c3/tls_002.png and b/docs/dimensioning/img/esp32_c3/tls_002.png differ
diff --git a/docs/dimensioning/img/esp32_c3/tls_003.png b/docs/dimensioning/img/esp32_c3/tls_003.png
index 5c6b128..603ace0 100644
Binary files a/docs/dimensioning/img/esp32_c3/tls_003.png and b/docs/dimensioning/img/esp32_c3/tls_003.png differ
diff --git a/docs/dimensioning/img/esp32_s3/base.png b/docs/dimensioning/img/esp32_s3/base.png
index b814774..858f288 100644
Binary files a/docs/dimensioning/img/esp32_s3/base.png and b/docs/dimensioning/img/esp32_s3/base.png differ
diff --git a/docs/dimensioning/img/esp32_s3/high_mem_cap_001.png b/docs/dimensioning/img/esp32_s3/high_mem_cap_001.png
index 81887d2..24d589f 100644
Binary files a/docs/dimensioning/img/esp32_s3/high_mem_cap_001.png and b/docs/dimensioning/img/esp32_s3/high_mem_cap_001.png differ
diff --git a/docs/dimensioning/img/esp32_s3/high_mem_cap_002.png b/docs/dimensioning/img/esp32_s3/high_mem_cap_002.png
index 664932a..9312ab8 100644
Binary files a/docs/dimensioning/img/esp32_s3/high_mem_cap_002.png and b/docs/dimensioning/img/esp32_s3/high_mem_cap_002.png differ
diff --git a/docs/dimensioning/img/esp32_s3/high_mem_cap_003.png b/docs/dimensioning/img/esp32_s3/high_mem_cap_003.png
index c81a08a..38c74e1 100644
Binary files a/docs/dimensioning/img/esp32_s3/high_mem_cap_003.png and b/docs/dimensioning/img/esp32_s3/high_mem_cap_003.png differ
diff --git a/docs/dimensioning/img/esp32_s3/low_mem_cap_001.png b/docs/dimensioning/img/esp32_s3/low_mem_cap_001.png
index fc48588..06d7e08 100644
Binary files a/docs/dimensioning/img/esp32_s3/low_mem_cap_001.png and b/docs/dimensioning/img/esp32_s3/low_mem_cap_001.png differ
diff --git a/docs/dimensioning/img/esp32_s3/low_mem_cap_002.png b/docs/dimensioning/img/esp32_s3/low_mem_cap_002.png
index 45dd764..563b6bd 100644
Binary files a/docs/dimensioning/img/esp32_s3/low_mem_cap_002.png and b/docs/dimensioning/img/esp32_s3/low_mem_cap_002.png differ
diff --git a/docs/dimensioning/img/esp32_s3/low_mem_cap_003.png b/docs/dimensioning/img/esp32_s3/low_mem_cap_003.png
index a1f9643..a9d5f31 100644
Binary files a/docs/dimensioning/img/esp32_s3/low_mem_cap_003.png and b/docs/dimensioning/img/esp32_s3/low_mem_cap_003.png differ
diff --git a/docs/dimensioning/img/esp32_s3/multipart_001.png b/docs/dimensioning/img/esp32_s3/multipart_001.png
index b3b9817..25b13a5 100644
Binary files a/docs/dimensioning/img/esp32_s3/multipart_001.png and b/docs/dimensioning/img/esp32_s3/multipart_001.png differ
diff --git a/docs/dimensioning/img/esp32_s3/multipart_002.png b/docs/dimensioning/img/esp32_s3/multipart_002.png
index 6e04077..fddf31e 100644
Binary files a/docs/dimensioning/img/esp32_s3/multipart_002.png and b/docs/dimensioning/img/esp32_s3/multipart_002.png differ
diff --git a/docs/dimensioning/img/esp32_s3/multipart_003.png b/docs/dimensioning/img/esp32_s3/multipart_003.png
index a799ddc..3ffe4fe 100644
Binary files a/docs/dimensioning/img/esp32_s3/multipart_003.png and b/docs/dimensioning/img/esp32_s3/multipart_003.png differ
diff --git a/docs/dimensioning/img/esp32_s3/tls_001.png b/docs/dimensioning/img/esp32_s3/tls_001.png
index af9ea15..190a2a3 100644
Binary files a/docs/dimensioning/img/esp32_s3/tls_001.png and b/docs/dimensioning/img/esp32_s3/tls_001.png differ
diff --git a/docs/dimensioning/img/esp32_s3/tls_002.png b/docs/dimensioning/img/esp32_s3/tls_002.png
index 91aefdc..6b0fc45 100644
Binary files a/docs/dimensioning/img/esp32_s3/tls_002.png and b/docs/dimensioning/img/esp32_s3/tls_002.png differ
diff --git a/docs/dimensioning/img/esp32_s3/tls_003.png b/docs/dimensioning/img/esp32_s3/tls_003.png
index fb13661..cbfc67d 100644
Binary files a/docs/dimensioning/img/esp32_s3/tls_003.png and b/docs/dimensioning/img/esp32_s3/tls_003.png differ
diff --git a/example/demo_app/app.py b/example/demo_app/app.py
index 28ba282..937b89f 100644
--- a/example/demo_app/app.py
+++ b/example/demo_app/app.py
@@ -20,7 +20,7 @@ def app(http_ctx, payload):
if value_format == "%":
free = round(100 * free / (free + mem_alloc()), 2)
- return "text/plain", (f"Free memory [{value_format}]: {free}\n")
+ return "text/plain", f"Free memory [{value_format}]: {free}\n"
async def main():
diff --git a/src/pyrobusta/bindings/http_connection.py b/src/pyrobusta/bindings/http_connection.py
index 533dfc3..0d9b79f 100644
--- a/src/pyrobusta/bindings/http_connection.py
+++ b/src/pyrobusta/bindings/http_connection.py
@@ -7,7 +7,7 @@
from ..stream.buffer import BufferFullError
from ..transport.connection import BaseConnection
-from ..protocol.http import HttpEngine, ServerBusyError, HeaderParsingError
+from ..protocol.http import HttpEngine
from ..utils import logging
@@ -19,7 +19,6 @@ class HttpConnection(BaseConnection):
MTU_SIZE = 1460
STATE_MACHINE_SLEEP_MS = 2
- RESP_HANDLER_SLEEP_MS = 2
RECV_TIMEOUT_SECONDS = 10
__slots__ = ("_engine", "_prev_state", "_recv_buf", "_send_buf")
@@ -36,9 +35,12 @@ async def run(self):
Handle socket connection with HTTP state machine parser.
"""
self._prev_state = None
- while self._engine.state is not None:
+ while not self._engine.is_terminated():
await self._run_state_machine()
await sleep_ms(self.STATE_MACHINE_SLEEP_MS)
+ if self._engine.is_terminated() and self._engine.do_keep_alive():
+ self._engine.reset()
+ self._prev_state = None
async def _flush_response(self):
data = self._send_buf.peek()
@@ -49,67 +51,48 @@ async def _flush_response(self):
async def _read_to_buf(self):
buf_free = self._recv_buf.capacity - self._recv_buf.size()
if not buf_free:
- self._engine.on_buffer_full(self._send_buf)
- await self._flush_response()
- return 0
- try:
- request = await self.read(
- read_bytes=buf_free,
- decoding=None,
- timeout_seconds=self.RECV_TIMEOUT_SECONDS,
- )
- except asyncio.TimeoutError:
- self._engine.on_timeout(self._send_buf)
- await self._flush_response()
- return 0
- except Exception as e: # pylint: disable=W0718
- self._engine.on_failure(
- self._send_buf, b"Read error: " + str(e).encode("ascii")
- )
- await self._flush_response()
- return 0
+ raise BufferFullError()
+ request = await self.read(
+ read_bytes=buf_free,
+ decoding=None,
+ timeout_seconds=self.RECV_TIMEOUT_SECONDS,
+ )
self._recv_buf.write(request)
logging.debug(__name__ + f"._read_to_buf: [{request}]")
return len(request)
async def _run_state_machine(self):
- if self._prev_state == self._engine.state or self._prev_state is None:
- num_read = await self._read_to_buf()
- if not num_read:
- # Reject incomplete request
- self._engine.on_client_error(
- self._send_buf, self._engine.BAD_REQUEST_ERROR
- )
- await self._flush_response()
- return
- try:
- resp_handler = None
- while self._engine.state is not None:
- self._prev_state = self._engine.state
- resp_handler = self._engine.state(self._recv_buf, self._send_buf)
- if not self._send_buf.size():
- break
- await self._flush_response()
- await sleep_ms(self.STATE_MACHINE_SLEEP_MS)
- except BufferFullError:
- self._engine.on_failure(self._send_buf, b"Buffer full")
- await self._flush_response()
- return
- except ServerBusyError:
- self._engine.on_unavailable(self._send_buf)
- await self._flush_response()
- return
- except HeaderParsingError:
- self._engine.on_client_error(self._send_buf, self._engine.HEADER_ERROR)
- await self._flush_response()
- return
- except Exception as e: # pylint: disable=W0718
- logging.warning(__name__ + f"._run_state_machine: {e}")
- self._engine.on_failure(self._send_buf, str(e).encode("ascii"))
+ # [1] read request
+ if self._prev_state == self._engine.state or (
+ self._prev_state is None and not self._recv_buf.size()
+ ):
+ try:
+ num_read = await self._read_to_buf()
+ if not num_read:
+ self._engine.abort(400)
+ self._engine.set_response_body(b"Incomplete request")
+ except BufferFullError:
+ self._engine.abort(413)
+ except asyncio.TimeoutError:
+ self._engine.abort(408)
+ except Exception as e: # pylint: disable=W0718
+ self._engine.abort(500)
+ self._engine.set_response_body(b"Read error: " + str(e).encode("ascii"))
+
+ # [2] process request by state machine
+ for _ in self._engine.run(self._recv_buf):
+ if self._prev_state == self._engine.state:
+ # No state transition occurred, read more data
+ break
+ self._prev_state = self._engine.state
+ await sleep_ms(self.STATE_MACHINE_SLEEP_MS)
+
+ # [3] write response
+ if self._engine.is_started() and self._engine.is_terminated():
+ self._engine.write_response_head(self._send_buf)
await self._flush_response()
- return
- if self._engine.state is None and resp_handler is not None:
- await self._response_handler(resp_handler)
+ if self._engine.resp_handler is not None:
+ await self._response_handler(self._engine.resp_handler)
async def _response_handler(self, resp_handler):
if "closure" == type(resp_handler).__name__:
@@ -117,22 +100,20 @@ async def _response_handler(self, resp_handler):
await self._flush_response()
if is_finished:
break
- await sleep_ms(self.RESP_HANDLER_SLEEP_MS)
+ await sleep_ms(self.STATE_MACHINE_SLEEP_MS)
elif type(resp_handler).__name__ in ("FileIO", "BytesIO"):
- with resp_handler as rh:
+ try:
while True:
view = self._send_buf.writable_view()
- num_read = rh.readinto(view)
+ num_read = resp_handler.readinto(view)
if not num_read:
break
self._send_buf.commit(num_read)
await self._flush_response()
- await sleep_ms(self.RESP_HANDLER_SLEEP_MS)
+ await sleep_ms(self.STATE_MACHINE_SLEEP_MS)
+ finally:
+ resp_handler.close()
else:
- self._engine.on_failure(
- self._send_buf,
- f"Invalid response handler {type(resp_handler).__name__}".encode(
- "ascii"
- ),
+ raise RuntimeError(
+ f"Invalid response handler {type(resp_handler).__name__}"
)
- await self._flush_response()
diff --git a/src/pyrobusta/connectivity/wifi.py b/src/pyrobusta/connectivity/wifi.py
index bc27030..f907add 100644
--- a/src/pyrobusta/connectivity/wifi.py
+++ b/src/pyrobusta/connectivity/wifi.py
@@ -1,5 +1,5 @@
"""
-Helpers for setting up Wi-Fi in station mode
+Helpers for setting up Wi-Fi in station mode.
"""
from time import sleep
@@ -12,7 +12,7 @@
def initialize():
"""
- Initialize WLAN interface in station mode
+ Initialize WLAN station interface.
"""
ssid = get_config(CONF_WIFI_SSID)
password = get_config(CONF_WIFI_PASSWORD)
@@ -44,7 +44,7 @@ def initialize():
def get_address():
"""
- Get the address of the WLAN interface
+ Get the IP address of the WLAN station.
"""
sta_if = WLAN(STA_IF)
if sta_if.isconnected():
diff --git a/src/pyrobusta/protocol/http.py b/src/pyrobusta/protocol/http.py
index fbd6a45..93d55ab 100644
--- a/src/pyrobusta/protocol/http.py
+++ b/src/pyrobusta/protocol/http.py
@@ -1,5 +1,5 @@
"""
-Module is responsible for webserver state machine,
+This module is responsible HTTP protocol parsing
with partial guarantees on RFC compliance.
"""
@@ -12,32 +12,49 @@
CONF_HTTP_MULTIPART,
CONF_HTTP_SERVE_FILES,
)
+from ..utils import logging
+from ..stream.buffer import BufferFullError
-class HeaderParsingError(ValueError):
- """Exception for errors occurring while parsing HTTP/MIME headers"""
+class InvalidHeaders(ValueError):
+ """Exception for errors occurring while parsing HTTP/MIME headers."""
pass
-class ServerBusyError(RuntimeError):
- """Exception for applications to indicate busy state"""
+class InvalidContentLength(ValueError):
+ """Exception for content-length related erros."""
+
+ pass
+
+
+class MalformedRequest(ValueError):
+ """Exception for malformed requests."""
pass
class HttpEngine:
"""
- HTTP protocol parser state machine
- - provides an adapter/routing layer
- - supports multipart request and response handling
- - resolves static resources by returning a stream objects (FileIO)
+ HTTP protocol parser state machine and middleware.
+ - each instance represents a connection, allowing a request to be parsed
+ through a state machine
+ - provides an adapter/routing layer for applications
+ through registered endpoints (see also: register(), route())
+ - supports percent encoded URLs and query parameters (x-www-form-urlencoded)
+ - allows applications to set response attributes (headers, status code)
+
+ Feature flags (configured in pyrobusta.env)
+ - http_serve_files: serve files stored on the device
+ - http_multipart: support for multipart requests/responses
"""
__slots__ = (
"state",
"status_code",
"resp_headers",
+ "resp_handler",
+ "aborted",
"version",
"headers",
"method",
@@ -102,9 +119,6 @@ class HttpEngine:
b"image/gif",
)
- ASCII = "ascii"
- CONTENT_LENGTH = "content-length"
-
DELETE = b"DELETE"
GET = b"GET"
HEAD = b"HEAD"
@@ -114,18 +128,13 @@ class HttpEngine:
METHODS = (DELETE, GET, HEAD, OPTIONS, POST, PUT)
SUPPORTED_VERSIONS = (b"HTTP/1.1", b"HTTP/1.0")
- MULTIPART_BOUNDARY = b"pyrobusta-boundary"
-
- CONTENT_LENGTH_ERROR = b"content length mismatch"
- HEADER_ERROR = b"Invalid headers"
- MULTIPART_BOUNDARY_ERROR = b"Invalid multipart boundary"
- BAD_REQUEST_ERROR = b"Bad request"
-
def __init__(self):
# [State machine]
- self.state = self._parse_request_line_st
+ self.state = self._start_parser
self.status_code = None
self.resp_headers = []
+ self.resp_handler = None
+ self.aborted = False
# [Recived request]
self.version = None
@@ -139,22 +148,38 @@ def __init__(self):
# [Multipart state]
self.mp_boundary = None
+ def reset(self):
+ """
+ Reset internal state to reuse a state machine object.
+ """
+ self.state = self._start_parser
+ self.status_code = None
+ self.resp_headers.clear()
+ self.resp_handler = None
+ self.aborted = False
+ self.version = None
+ self.headers.clear()
+ self.method = None
+ self.url = None
+ self.query = None
+ self.content_len_cnt = 0
+ self.recv_chunk_size = 0
+ self.mp_boundary = None
+
# =========================================
# Methods/decorators for routing
# =========================================
@classmethod
- def register(
- cls, endpoint: str, callback: object | str, method: str = "GET"
- ) -> None:
+ def register(cls, endpoint: str, callback: callable, method: str = "GET") -> None:
"""
- Register an endpoint with a callback function
- :param endpoint: name of the endpoint
- :param callback: callback function
+ Register an endpoint with a callback function or file.
+ :param endpoint: URL path to be routed e.g. "/app/resource"
+ :param callback: callback function or file path
:param method: HTTP method name
"""
- endpoint = endpoint.encode(cls.ASCII)
- method = method.encode(cls.ASCII)
+ endpoint = endpoint.encode("ascii")
+ method = method.encode("ascii")
endpoint_exists = cls._get_callback(endpoint, method) is not None
if method not in cls.METHODS:
@@ -167,6 +192,8 @@ def register(
def route(endpoint: str, method: str):
"""
Decorator for registering endpoint callback functions.
+ :param endpoint: URL path to be routed e.g. "/app/resource"
+ :param method: HTTP method name
"""
def decorator(func):
@@ -181,7 +208,9 @@ def decorator(func):
@staticmethod
def percent_decode(s: str):
- """Decode percent-encoded input"""
+ """
+ Decode percent-encoded input.
+ """
out = []
i = 0
while i < len(s):
@@ -196,8 +225,8 @@ def percent_decode(s: str):
@staticmethod
def get_url_encoded_query_param(query: str, key: str, default: str = None):
"""
- Parse query and return the value belonging to a key
- according to x-www-form-urlencoded
+ Parse a query and return the value belonging to a key
+ according to the x-www-form-urlencoded format.
:param query: query part
:param key: key to parse from the query
:param default: default value to return when key is not present
@@ -220,7 +249,7 @@ def get_url_encoded_query_param(query: str, key: str, default: str = None):
@classmethod
def is_norm_path_served(cls, path: str):
"""
- Returns true if a normalized path is configured to be served
+ Returns true if a normalized path is configured to be served.
"""
served_paths = get_config(CONF_HTTP_SERVED_PATHS)
parts = path.split("/")
@@ -238,20 +267,20 @@ def _lookup(tuple_, key):
return tuple_[idx + 1]
@classmethod
- def _get_callback(cls, endpoint, method):
+ def _get_callback(cls, endpoint, method: bytes):
for e in cls.ENDPOINTS:
if endpoint == e[0] and method == e[2]:
return e[1]
@classmethod
- def _has_endpoint(cls, endpoint):
+ def _has_endpoint(cls, endpoint: bytes):
for e in cls.ENDPOINTS:
if endpoint == e[0]:
return True
return False
@classmethod
- def _supported_methods(cls, endpoint):
+ def _supported_methods(cls, endpoint: bytes):
supported_methods = []
for method in cls.METHODS:
if cls._get_callback(endpoint, method) is not None:
@@ -261,20 +290,19 @@ def _supported_methods(cls, endpoint):
@classmethod
def _parse_headers(cls, raw_headers: memoryview) -> dict[str, str | int]:
"""
- Basic parser to extract HTTP/MIME headers
- :param raw_headers: headers
+ Basic parser to extract HTTP/MIME headers.
"""
header_lines = bytes(raw_headers).split(b"\r\n")
headers = {}
for line in header_lines:
# pylint: disable=W0511
if any(c > 127 for c in line):
- raise HeaderParsingError("Non-ASCII character")
+ raise InvalidHeaders("Non-ASCII character")
if b":" not in line:
- raise HeaderParsingError()
+ raise InvalidHeaders()
name, value = line.split(b":", 1)
if not name:
- raise HeaderParsingError("Empty header name")
+ raise InvalidHeaders("Empty header name")
for c in name:
if (
48 <= c <= 57 # 0-9
@@ -283,20 +311,23 @@ def _parse_headers(cls, raw_headers: memoryview) -> dict[str, str | int]:
or c in (45, 95) # -_
):
continue
- raise HeaderParsingError("Invalid header name")
- name = name.strip().lower().decode(cls.ASCII)
+ raise InvalidHeaders("Invalid header name")
+ name = name.strip().lower().decode("ascii")
if any((c < 32 and c != 9) or c == 127 for c in value):
- raise HeaderParsingError("Invalid header value")
- if name == cls.CONTENT_LENGTH:
+ raise InvalidHeaders("Invalid header value")
+ if name == "content-length":
value = int(value.strip())
else:
- value = value.strip().decode(cls.ASCII)
+ value = value.strip().decode("ascii")
headers[name] = value
return headers
@staticmethod
def _get_mp_boundary(headers: dict) -> str:
- """Determine from the headers if a request is multipart, and return the boundary value"""
+ """
+ Determine from the headers if a request is multipart,
+ and return the boundary value.
+ """
content_type = headers.get("content-type")
if not content_type or not content_type.lower().startswith(
"multipart/form-data"
@@ -315,26 +346,28 @@ def _get_mp_boundary(headers: dict) -> str:
if value.startswith('"'):
if len(value) < 2 or not value.endswith('"'):
- raise HeaderParsingError()
+ raise InvalidHeaders()
value = value[1:-1]
elif value.endswith('"'):
- raise HeaderParsingError()
+ raise InvalidHeaders()
if not value:
- raise HeaderParsingError()
+ raise InvalidHeaders()
return value
- raise HeaderParsingError()
+ raise InvalidHeaders()
@classmethod
def _parse_body_part(cls, part: memoryview) -> tuple[dict, bytes]:
- """Parse part headers and body and return them as a tuple"""
+ """
+ Parse part headers and body and return them as a tuple.
+ """
blank_idx = -1
for i in range(len(part) - 3):
if part[i : i + 4] == b"\r\n\r\n":
blank_idx = i
break
if blank_idx == -1:
- raise HeaderParsingError()
+ raise InvalidHeaders()
headers = cls._parse_headers(part[:blank_idx])
body = part[blank_idx + 4 :]
return headers, body
@@ -343,7 +376,12 @@ def _parse_body_part(cls, part: memoryview) -> tuple[dict, bytes]:
# Helpers for state machine termination
# =========================================
- def _set_response_header(self, key, value):
+ def set_response_header(self, key: bytes, value: bytes):
+ """
+ Set a response header by key and value.
+ :param key: HTTP header key
+ :param value: HTTP header value
+ """
if (
key in self.resp_headers
and (index := self.resp_headers.index(key) % 2) == 0
@@ -353,31 +391,15 @@ def _set_response_header(self, key, value):
self.resp_headers.append(key)
self.resp_headers.append(value)
- def terminate(self, status_code: int, content_type: bytes = b"text/plain"):
+ def write_response_head(self, tx):
"""
- Terminate state machine with status code and response content-type
- :param status_code: HTTP status code
- :param content_type: content-type of the response
+ Write response status and header to an output buffer.
+ :param tx: response buffer
"""
- self.state = None
- self.status_code = status_code
- if content_type:
- self._set_response_header(b"content-type", content_type)
- self._set_response_header(b"connection", b"close")
-
- def _write_response_head(self, tx, content_length: int = 0):
- """
- Write response status & header to the output,
- with optional content-length value
- """
- # Discard already accumulated content (e.g. 500 response on unexpected errors)
- tx.consume()
+ tx.consume() # Discard already accumulated content, required on abrupt errors
tx.write(self.version)
tx.write(b" ")
tx.write(self._lookup(self.RESP_HEADERS, self.status_code))
- if content_length is not None:
- tx.write(b"\r\n")
- tx.write(b"content-length: %s" % str(content_length).encode(self.ASCII))
for i in range(0, len(self.resp_headers), 2):
key = self.resp_headers[i]
value = self.resp_headers[i + 1]
@@ -387,143 +409,202 @@ def _write_response_head(self, tx, content_length: int = 0):
tx.write(value)
tx.write(b"\r\n\r\n")
- def _generate_response(self, tx, body: bytes | str | dict | tuple | list):
+ def set_response_body(
+ self,
+ body: bytes | str | dict | tuple | list,
+ content_type: bytes = b"text/plain",
+ ):
"""
- Write the complete response to the output, including status
- and headers. Return a BytesIO object if the content length
- exceeds the remaining buffer capacity, to delegate the writing
- of the response body to the transport layer.
+ Serialize and wrap the response body with a BytesIO
+ object, stored by the resp_handler member. resp_handler
+ can be used for writing the body by the transport layer.
+ This method also updates the content-type and content-length
+ headers.
+ :param body: body to be sent in the response
+ :param content_type: content-type of the body
"""
- if not body:
- self._write_response_head(tx, 0)
- body_encoded = b""
- elif isinstance(body, (bytes, bytearray, memoryview)):
- self._write_response_head(tx, len(body))
+ if body is None:
+ return
+ if isinstance(body, (bytes, bytearray, memoryview)):
body_encoded = body
elif isinstance(body, str):
body_encoded = body.encode()
- self._write_response_head(tx, len(body_encoded))
elif isinstance(body, (dict, tuple, list)):
body_encoded = dumps(body).encode()
- self._write_response_head(tx, len(body_encoded))
else:
- self.on_failure(tx, b"Unhandled body type")
- return
+ raise ValueError("Unhandled body type")
+ self.set_response_header(
+ b"content-length", str(len(body_encoded)).encode("ascii")
+ )
+ self.set_response_header(b"content-type", content_type)
if self.method != self.HEAD:
- if len(body_encoded) > tx.capacity - tx.size():
- return BytesIO(body_encoded)
- tx.write(body_encoded)
-
- def on_client_error(self, tx, info: bytes):
- """Terminate state machine and write 400 response"""
- self.terminate(400)
- response = info
- self._write_response_head(tx, len(response))
- tx.write(response)
-
- def on_forbidden(self, tx):
- """Terminate state machine and write 403 response"""
- self.terminate(403)
- self._write_response_head(tx)
-
- def on_missing_resource(self, tx):
- """Terminate state machine and write 404 response"""
- self.terminate(404)
- self._write_response_head(tx)
-
- def on_method_not_allowed(self, tx):
- """Terminate state machine and write 405 response"""
- self.terminate(405)
- self._write_response_head(tx)
-
- def on_timeout(self, tx):
- """Terminate state machine and write 408 response"""
- self.terminate(408)
- self._write_response_head(tx)
-
- def on_buffer_full(self, tx):
- """Terminate state machine and write 413 response"""
- self.terminate(413)
- self._write_response_head(tx)
-
- def on_failure(self, tx, info: bytes):
- """Terminate state machine and write 500 response"""
- self.terminate(500)
- self._write_response_head(tx, len(info))
- tx.write(info)
-
- def on_unavailable(self, tx):
- """Terminate state machine and write 503 response"""
- self.terminate(503)
- self._write_response_head(tx)
+ self.resp_handler = BytesIO(body_encoded)
+
+ def do_keep_alive(self):
+ """
+ Determine if the connection should be kept alive
+ depending on the HTTP version and headers sent in the request.
+ """
+ if self.aborted:
+ return False
+
+ connection_tokens = [
+ token.strip().lower()
+ for token in self.headers.get("connection", "").split(",")
+ ]
+ return (self.version == b"HTTP/1.0" and "keep-alive" in connection_tokens) or (
+ self.version == b"HTTP/1.1" and "close" not in connection_tokens
+ )
+
+ def terminate(self, status_code: int, request_complete: bool = False):
+ """
+ Regular state machine termination with a specific status code.
+ :param status_code: HTTP status code
+ :param request_complete: true if the complete request is processed
+ """
+ self.state = None
+ self.status_code = status_code
+
+ if self.version == b"HTTP/1.0" and self.do_keep_alive() and request_complete:
+ self.set_response_header(b"connection", b"keep-alive")
+ elif (
+ self.version == b"HTTP/1.1"
+ and not self.do_keep_alive()
+ and not request_complete
+ ):
+ self.set_response_header(b"connection", b"close")
+
+ def abort(self, status_code: int):
+ """
+ Abort state machine due to runtime errors.
+ Reset any header or response body set earlier.
+ :param status_code: HTTP status code
+ """
+ self.aborted = True
+ self.resp_headers = []
+ if type(self.resp_handler).__name__ in ("FileIO", "BytesIO"):
+ self.resp_handler.close()
+ self.resp_handler = None
+ self.terminate(status_code, False)
- def on_unsupported_version(self, tx):
- """Terminate state machine and write 505 response"""
- self.terminate(505)
- self._write_response_head(tx)
+ def is_started(self):
+ """
+ Returns true if the state machine has received any input.
+ """
+ return self.state != self._start_parser # pylint: disable=W0143
+
+ def is_terminated(self):
+ """
+ Returns true if the state machine is terminated.
+ """
+ return self.state is None and self.status_code
+
+ def run(self, rx):
+ """
+ Run the state machine with request buffers provided.
+ Unlike individual states, this method does not raise an exception.
+ This method yields on every state transition allowing the calling side
+ to flush the response buffer.
+ """
+ if self.is_terminated():
+ return
+ try:
+ while not self.is_terminated():
+ self.state(rx)
+ yield
+ except BufferFullError:
+ self.abort(500)
+ self.set_response_body(b"Buffer full")
+ except InvalidHeaders:
+ self.abort(400)
+ self.set_response_body(b"Invalid headers")
+ except InvalidContentLength:
+ self.abort(400)
+ self.set_response_body(b"Content length mismatch")
+ except MalformedRequest:
+ self.abort(400)
+ self.set_response_body(b"Malformed request")
+ except Exception as e: # pylint: disable=W0718
+ logging.warning(__name__ + f"._run_state_machine: {e}")
+ self.abort(500)
+ self.set_response_body(str(e).encode("ascii"))
+
+ # ========================================
+ # Helpers for routing, state machine logic
+ # ========================================
+
+ def is_chunked(self):
+ """
+ Determines if the request has a payload with chunked transfer-encoding.
+ """
+ return self.headers.get("transfer-encoding") == "chunked"
+
+ def has_payload(self):
+ """
+ Determines if the request has a body.
+ """
+ return (
+ "content-length" in self.headers and self.headers["content-length"] > 0
+ ) or self.is_chunked()
# ================================================================================
# Parser states
- # - all states must handle rx and tx buffer arguments for reading and writing data
+ # - all states must handle rx buffer argument for reading request data
# - mandatory methods/attributes of rx: find(), peek(), consume(), size()
- # - mandatory methods/attributes of tx: capacity, consume(), write(), size()
- # - rx/tx reference implementation: SlidingBuffer (pyrobusta.stream.buffer)
+ # - reference implementation: SlidingBuffer (pyrobusta.stream.buffer)
# ================================================================================
- def _parse_request_line_st(self, rx, tx):
- """State for parsing the request line"""
+ def _start_parser(self, rx):
+ """
+ Initial state.
+ """
+ if rx.size():
+ self.state = self._parse_request_line_st
+
+ def _parse_request_line_st(self, rx):
+ """
+ Parse the request line.
+ """
status_line_sep = rx.find(b"\r\n")
if status_line_sep == -1:
return
status_parts = bytes(rx.peek(status_line_sep)).split()
if len(status_parts) != 3:
- self.on_client_error(tx, self.BAD_REQUEST_ERROR)
- return
+ raise MalformedRequest()
self.method = status_parts[0]
url_parts = status_parts[1].split(b"?", 1)
self.url = url_parts[0]
self.query = (
""
if len(url_parts) == 1
- else self.percent_decode(url_parts[1].decode(self.ASCII))
+ else self.percent_decode(url_parts[1].decode("ascii"))
)
self.version = status_parts[2]
if self.method not in self.METHODS:
- self.on_method_not_allowed(tx)
+ self.terminate(405)
return
if self.version not in self.SUPPORTED_VERSIONS:
- self.on_unsupported_version(tx)
+ self.terminate(505)
return
rx.consume(status_line_sep + 2)
self.state = self._parse_headers_st
- def _parse_headers_st(self, rx, tx):
- """State for parsing headers"""
+ def _parse_headers_st(self, rx):
+ """
+ Parse HTTP headers.
+ """
if (blank_idx := rx.find(b"\r\n\r\n")) == -1:
return
- try:
- self.headers = self._parse_headers(rx.peek(blank_idx))
- if self.version == b"HTTP/1.1" and "host" not in self.headers:
- raise HeaderParsingError()
- except HeaderParsingError:
- self.on_client_error(tx, self.HEADER_ERROR)
- return
+ self.headers = self._parse_headers(rx.peek(blank_idx))
+ if self.version == b"HTTP/1.1" and "host" not in self.headers:
+ raise InvalidHeaders()
rx.consume(blank_idx + 4)
self.state = self._route_request_st
- def _is_chunked(self):
- return self.headers.get("transfer-encoding") == "chunked"
-
- def _has_payload(self):
- return (
- self.CONTENT_LENGTH in self.headers
- and self.headers[self.CONTENT_LENGTH] > 0
- ) or self._is_chunked()
-
- def _route_request_st(self, _, tx):
+ def _route_request_st(self, _):
"""
- State for routing requests
- - supported ways: static resources, endpoint callback functions
+ Route requests based on registered endpoints.
+ If no endpoint is registered, fall back to file serving.
"""
if self._has_endpoint(self.url) and (
self._get_callback(self.url, self.method) is not None
@@ -535,114 +616,132 @@ def _route_request_st(self, _, tx):
):
if self.method == self.OPTIONS:
supported_methods = self._supported_methods(self.url)
- self._set_response_header(b"allow", b", ".join(supported_methods))
- self.terminate(204, None)
- self._write_response_head(tx, None)
+ self.set_response_header(b"allow", b", ".join(supported_methods))
+ self.terminate(204, True)
return
- if self._has_payload():
+ if self.has_payload():
if self.method == self.HEAD:
- self.on_client_error(tx, self.BAD_REQUEST_ERROR)
- return
+ raise MalformedRequest()
if mp_boundary := self._get_mp_boundary(self.headers):
- self.mp_boundary = mp_boundary.encode(self.ASCII)
+ # Request body is multipart
+ self.mp_boundary = mp_boundary.encode("ascii")
self.state = self._start_multipart_parser_st
- elif self._is_chunked():
- if self.CONTENT_LENGTH in self.headers:
- self.on_client_error(tx, self.BAD_REQUEST_ERROR)
- return
- self.state = self._recv_chunked_size_st
+ elif self.is_chunked():
+ # Request body is chunked
+ if "content-length" in self.headers:
+ raise MalformedRequest()
+ self.state = self._recv_chunk_size_st
else:
self.state = self._recv_payload_st
else:
self.state = self._app_endpoint_st
return
+ # Request does not have a registered endpoint
if (
self._has_endpoint(self.url)
and self._get_callback(self.method, self.url) is None
):
supported_methods = self._supported_methods(self.url)
- self._set_response_header(b"allow", b", ".join(supported_methods))
- self.on_method_not_allowed(tx)
+ self.set_response_header(b"allow", b", ".join(supported_methods))
+ self.terminate(405)
return
+ # Fallback: serve file
if self.method in (self.GET, self.HEAD):
- self.state = lambda _rx, _tx: self._send_file_st(_rx, _tx, self.url)
+ self.state = lambda _rx: self._send_file_st(_rx, self.url)
return
- self.on_missing_resource(tx)
+ self.terminate(404)
- def _recv_chunked_size_st(self, rx, _):
+ def _recv_chunk_size_st(self, rx):
+ """
+ State for determining the chunk size (transfer-encoding: chunked).
+ """
if (blank_idx := rx.find(b"\r\n")) == -1:
return
self.recv_chunk_size = int(bytes(rx.peek(blank_idx)), 16)
+ if self.recv_chunk_size < 0:
+ raise InvalidContentLength()
rx.consume(blank_idx + 2)
self.state = self._recv_chunk_st
- def _recv_chunk_st(self, rx, tx):
+ def _recv_chunk_st(self, rx):
+ """
+ State for receiving a complete chunk (transfer-encoding: chunked).
+ """
if self.recv_chunk_size + 2 > rx.size():
return
if self.recv_chunk_size + 2 <= rx.size():
- if rx.peek()[self.recv_chunk_size : self.recv_chunk_size + 2] != b"\r\n":
- self.on_client_error(tx, self.CONTENT_LENGTH_ERROR)
- return
+ if rx.peek(self.recv_chunk_size + 2)[-2:] != b"\r\n":
+ raise InvalidContentLength()
self.state = self._app_endpoint_st
- def _recv_payload_st(self, rx, tx):
- if self.headers[self.CONTENT_LENGTH] > rx.size():
- return
- if self.headers[self.CONTENT_LENGTH] < rx.size():
- self.on_client_error(tx, self.CONTENT_LENGTH_ERROR)
+ def _recv_payload_st(self, rx):
+ """
+ State for receiving the request body.
+ """
+ if self.headers["content-length"] > rx.size():
return
self.state = self._app_endpoint_st
- def _app_endpoint_st(self, rx, tx):
- """Process a request by registered callback functions"""
+ def _app_endpoint_st(self, rx):
+ """
+ Process a request by registered callback functions.
+ """
method = self.GET if self.method == self.HEAD else self.method
callback = self._get_callback(self.url, method)
- if self._has_payload():
- if self._is_chunked():
+ if self.has_payload():
+ if self.is_chunked():
if self.recv_chunk_size:
callback(self, bytes(rx.peek(self.recv_chunk_size)))
rx.consume(self.recv_chunk_size + 2)
- self.state = self._recv_chunked_size_st
+ self.state = self._recv_chunk_size_st
return
dtype, data = callback(self, bytes(rx.peek(self.recv_chunk_size)))
rx.consume(self.recv_chunk_size + 2)
else:
- dtype, data = callback(self, bytes(rx.peek()))
- dtype = dtype.encode(self.ASCII)
+ dtype, data = callback(
+ self, bytes(rx.peek(self.headers["content-length"]))
+ )
+ dtype = dtype.encode("ascii")
else:
if not callable(callback):
- # Handle as a static resource
- self.state = lambda _rx, _tx: self._send_file_st(
- _rx, _tx, callback.encode(HttpEngine.ASCII)
+ # Handle as a file path
+ self.state = lambda _rx: self._send_file_st(
+ _rx, callback.encode("ascii")
)
return
dtype, data = callback(self, b"")
- dtype = dtype.encode(self.ASCII)
- self._set_response_header(b"content-type", dtype)
+ dtype = dtype.encode("ascii")
+ self.set_response_header(b"content-type", dtype)
if dtype.startswith(b"multipart/"):
- self.state = lambda _rx, _tx: self._generate_multipart_response(
- _rx, _tx, data, dtype
- )
+ self.state = lambda _rx: self._generate_multipart_response(_rx, data, dtype)
return
- self.terminate(200, dtype)
- return self._generate_response(tx, data)
+ if not self.is_terminated():
+ self.terminate(200, True)
+ self.set_response_body(data, content_type=dtype)
- def _send_file_st(self, _, tx, web_resource: bytes): # pylint: disable=W0613
- """State for returning a static resource - disabled"""
- self.on_unavailable(tx)
+ def _send_file_st(self, _, path: bytes): # pylint: disable=W0613
+ """
+ State for returning including a file in the response body (disabled).
+ :param path: path to the resource
+ """
+ self.terminate(503, True)
- def _start_multipart_parser_st(self, rx, tx): # pylint: disable=W0613
- """Initial state for processing multipart requests"""
- self.on_unavailable(tx)
+ def _start_multipart_parser_st(self, rx): # pylint: disable=W0613
+ """
+ Initial state for processing multipart requests (disabled).
+ """
+ self.terminate(503)
def _generate_multipart_response(
- self, rx, tx, callback, dtype
+ self, rx, callback, dtype
): # pylint: disable=W0613
- """Generate multipart response depening on the exact content type"""
- self.on_unavailable(tx)
+ """
+ Generate multipart response depening on the exact content type (disabled).
+ """
+ self.terminate(503, True)
def enable_optional_features():
diff --git a/src/pyrobusta/protocol/http_file_server.py b/src/pyrobusta/protocol/http_file_server.py
index 1e1acc4..f128d37 100644
--- a/src/pyrobusta/protocol/http_file_server.py
+++ b/src/pyrobusta/protocol/http_file_server.py
@@ -10,43 +10,48 @@
from pyrobusta.utils.helpers import normalize_path, add_method
-def _send_file_st(self, _, tx, web_resource: bytes):
- """State for returning a static resource"""
+def _send_file_st(self, _, file_path: bytes):
+ """
+ State for returning a file. By default, /www is prepended to the path.
+ Alternatively, ready any file from the root when the path starts with /files
+ if it is configured in http_served_paths.
+ :param file_path: path to the file (unnormalized)
+ """
if self.url == b"/files":
- web_resource = "/"
+ file_path = "/"
elif self.url.startswith(b"/files/"):
- web_resource = web_resource[7:]
+ file_path = file_path[7:]
elif self.url == b"/":
- web_resource = b"/www/index.html"
+ file_path = b"/www/index.html"
else:
- web_resource = b"/www" + web_resource
+ file_path = b"/www" + file_path
- extension = web_resource.rsplit(b".", 1)[-1]
- norm_path = normalize_path(web_resource.decode(self.ASCII))
+ extension = file_path.rsplit(b".", 1)[-1]
+ norm_path = normalize_path(file_path.decode("ascii"))
is_path_served = self.is_norm_path_served(norm_path)
if not is_path_served:
try:
stat(norm_path)
- self.on_forbidden(tx)
+ self.terminate(403, True)
return
except OSError:
- self.on_missing_resource(tx)
+ self.terminate(404, True)
return
try:
content_type = self._lookup(self.CONTENT_TYPES, extension)
except ValueError:
content_type = self._lookup(self.CONTENT_TYPES, b"raw")
try:
- self._set_response_header(
- b"content-length", str(stat(norm_path)[6]).encode(http.HttpEngine.ASCII)
+ self.set_response_header(
+ b"content-length", str(stat(norm_path)[6]).encode("ascii")
)
- self.terminate(200, content_type)
- self._write_response_head(tx, None)
+ self.set_response_header(b"content-type", content_type)
+ self.terminate(200, True)
if self.method != self.HEAD:
- return open(norm_path, "rb")
+ self.resp_handler = open(norm_path, "rb") # pylint: disable=R1732
return
except OSError:
- self.on_missing_resource(tx)
+ self.terminate(404, True)
def apply_patches():
diff --git a/src/pyrobusta/protocol/http_multipart.py b/src/pyrobusta/protocol/http_multipart.py
index e58074a..fbc0862 100644
--- a/src/pyrobusta/protocol/http_multipart.py
+++ b/src/pyrobusta/protocol/http_multipart.py
@@ -8,24 +8,27 @@
from pyrobusta.utils.helpers import add_method
-def _generate_multipart_response(self, _, tx, callback, dtype):
- """Generate multipart response depening on the exact content type"""
+def _generate_multipart_response(self, _, callback: callable, dtype: bytes):
+ """
+ Generate multipart response depening on the exact content type.
+ The callback function is called without arguments, and it must return bytes-like objects.
+ :param callback: function for part generation, each call generates a separate part
+ :param dtype: exact multipart content-type (multipart/*)
+ """
if type(callback).__name__ not in ("function", "closure"):
- self.on_failure(tx, b"Invalid response handler")
- return
- self.terminate(200, dtype)
+ raise ValueError("Invalid response handler")
+ self.terminate(200, True)
boundary = self.MULTIPART_BOUNDARY
- self._set_response_header(b"content-type", dtype + b"; boundary=" + boundary)
- self._write_response_head(tx, None)
+ self.set_response_header(b"content-type", dtype + b"; boundary=" + boundary)
if self.method != self.HEAD:
- return self._multipart_wrapper_factory(callback, boundary)
+ self.resp_handler = self._multipart_wrapper_factory(callback, boundary)
-def _multipart_wrapper_factory(callback, boundary: bytes):
+def _multipart_wrapper_factory(callback: callable, boundary: bytes):
"""
- Factory method for creating closures that write multipart responses
- :param callback: function without arguments, must return bytes-like objects
- :param content_type: content type of body parts
+ Factory method for creating closures that write multipart responses.
+ The callback function is called without arguments, and it must return bytes-like objects.
+ :param callback: function for part generation, each call generates a separate part
:param boundary: boundary value
:return closure: closure to invoke for response generation
"""
@@ -33,7 +36,7 @@ def _multipart_wrapper_factory(callback, boundary: bytes):
def _multipart_wrapper(tx):
"""
- Write multipart data generated from a callback's return value
+ Write multipart data generated from a callback's return value.
- if insufficient buffer space is available, the generator yields control so
the caller can flush or drain the buffer
:return bool: true if the stream is completed
@@ -62,25 +65,27 @@ def _multipart_wrapper(tx):
return _multipart_wrapper
-def _start_multipart_parser_st(self, rx, tx):
- """Initial state for processing multipart requests"""
- if not http.HttpEngine.CONTENT_LENGTH in self.headers:
- self.on_client_error(tx, http.HttpEngine.CONTENT_LENGTH_ERROR)
- return
+def _start_multipart_parser_st(self, rx):
+ """
+ Initial state for processing multipart requests.
+ """
+ if not "content-length" in self.headers:
+ raise http.InvalidContentLength()
if (start_delimiter := rx.find(b"\r\n")) == -1:
return
self.mp_delimiter = b"--" + self.mp_boundary + b"\r\n"
self.mp_last_delimiter = b"--" + self.mp_boundary + b"--"
if rx.peek(start_delimiter + 2) != self.mp_delimiter:
- self.on_client_error(tx, http.HttpEngine.MULTIPART_BOUNDARY_ERROR)
- return
+ raise http.MalformedRequest()
rx.consume(start_delimiter + 2)
self.content_len_cnt += start_delimiter + 2
self.state = self._parse_boundary_st
-def _parse_boundary_st(self, rx, _):
- """State for parsing multipart boundary delimiter"""
+def _parse_boundary_st(self, rx):
+ """
+ State for parsing multipart boundary delimiter.
+ """
if (
rx.find(b"\r\n" + self.mp_delimiter) == -1
and rx.find(b"\r\n" + self.mp_last_delimiter) == -1
@@ -89,51 +94,50 @@ def _parse_boundary_st(self, rx, _):
self.state = self._parse_complete_part_st
-def _parse_complete_part_st(self, rx, tx):
+def _parse_complete_part_st(self, rx):
"""
- State for processing complete parts in a multipart request
+ State for processing complete parts in a multipart request.
- registered callback is required to process parts
"""
next_delimiter = rx.find(b"\r\n--" + self.mp_boundary)
part = rx.peek(next_delimiter)
rx.consume(next_delimiter + 2) # Consume leading CRLF
self.content_len_cnt += next_delimiter + 2
- is_final = rx.peek(len(self.mp_last_delimiter)) == self.mp_last_delimiter
+ is_final = (
+ rx.size() >= len(self.mp_last_delimiter)
+ and rx.peek(len(self.mp_last_delimiter)) == self.mp_last_delimiter
+ )
+
# Validate part and content-length
- if self.headers[http.HttpEngine.CONTENT_LENGTH] < self.content_len_cnt:
- self.on_client_error(tx, http.HttpEngine.CONTENT_LENGTH_ERROR)
- return
- try:
- part_headers, part_body = http.HttpEngine._parse_body_part(part)
- except http.HeaderParsingError:
- self.on_client_error(tx, http.HttpEngine.HEADER_ERROR)
- return
+ if self.headers["content-length"] < self.content_len_cnt:
+ raise http.InvalidContentLength()
+ part_headers, part_body = http.HttpEngine._parse_body_part(part)
callback = http.HttpEngine._get_callback(self.url, self.method)
+
# Process complete part
if not is_final:
callback(self, (part_headers, part_body))
if rx.peek(len(self.mp_delimiter)) != self.mp_delimiter:
- self.on_client_error(tx, http.HttpEngine.MULTIPART_BOUNDARY_ERROR)
- return
+ raise http.MalformedRequest()
rx.consume(len(self.mp_delimiter))
self.content_len_cnt += len(self.mp_delimiter)
self.mp_is_first = False
self.state = self._parse_boundary_st
return
+
# Process last part
rx.consume(len(self.mp_last_delimiter))
self.content_len_cnt += len(self.mp_last_delimiter)
if (
- self.headers[http.HttpEngine.CONTENT_LENGTH] != self.content_len_cnt
- and self.content_len_cnt + rx.size()
- != self.headers[http.HttpEngine.CONTENT_LENGTH]
+ self.headers["content-length"] != self.content_len_cnt
+ and self.content_len_cnt + rx.size() < self.headers["content-length"]
):
- self.on_client_error(tx, http.HttpEngine.CONTENT_LENGTH_ERROR)
- return
+ raise http.InvalidContentLength()
self.mp_is_last = True
dtype, data = callback(self, (part_headers, part_body))
- self.terminate(200, dtype.encode(http.HttpEngine.ASCII))
- return self._generate_response(tx, data)
+ self.set_response_header(b"content-type", dtype.encode("ascii"))
+ self.terminate(200, True)
+ self.set_response_body(data)
def apply_patches():
@@ -150,6 +154,7 @@ def new_init(self, *args, **kwargs):
self.mp_last_delimiter = None
http.HttpEngine.__init__ = new_init
+ http.HttpEngine.MULTIPART_BOUNDARY = b"pyrobusta-boundary"
add_method(http.HttpEngine, _generate_multipart_response)
add_method(http.HttpEngine, _multipart_wrapper_factory, "static")
diff --git a/src/pyrobusta/server/http_server.py b/src/pyrobusta/server/http_server.py
index 6e7d4d5..29b7cc3 100644
--- a/src/pyrobusta/server/http_server.py
+++ b/src/pyrobusta/server/http_server.py
@@ -65,7 +65,8 @@ class HttpServer:
@classmethod
def _init_pools(cls, max_clients):
"""
- Initialize pool of buffers for sending/receiving based on different profiles
+ Initialize pool of buffers for sending/receiving based on different profiles.
+ :param max_clients: maximum number of HTTP clients
"""
mem_available = mem_free() + mem_alloc()
con_limit = max_clients
@@ -95,7 +96,9 @@ def _init_pools(cls, max_clients):
@classmethod
async def _drop_client(cls, client):
- """Remove client from active list"""
+ """
+ Remove client from the list of active clients.
+ """
if client not in cls.ACTIVE_CLIENTS:
return
logging.debug(__name__ + f": {client.id} dropped")
@@ -141,6 +144,9 @@ async def can_handle_new_client(self):
return False
async def _reserve_buffers(self):
+ """
+ Reserve and return request and response buffers.
+ """
if self.SEND_POOL is None or self.RECV_POOL is None:
raise RuntimeError("Pools are uninitialized")
@@ -160,6 +166,8 @@ async def _accept_socket(self, reader, writer):
"""
Handle incoming socket connection for HTTP.
- creates HttpConnection object
+ :param reader: asyncio StreamReader
+ :param reader: asyncio StreamWriter
"""
if not await self.can_handle_new_client():
logging.debug(__name__ + ": cannot accept new client")
@@ -224,7 +232,7 @@ async def start_socket_server(self):
async def terminate(self):
"""
- Terminate HTTP server and drop clients
+ Terminate HTTP server and drop clients.
"""
logging.info(__name__ + ": terminated")
while self.ACTIVE_CLIENTS:
diff --git a/src/pyrobusta/stream/buffer.py b/src/pyrobusta/stream/buffer.py
index 812a1f0..b5e81f0 100644
--- a/src/pyrobusta/stream/buffer.py
+++ b/src/pyrobusta/stream/buffer.py
@@ -19,7 +19,7 @@ class MemoryPool:
def __init__(self, block_size, block_count, wrapper=None):
"""
- Initialize memory pool
+ Initialize memory pool.
:param block_size: size of each memory block in bytes
:param block_count: number of reservable memory blocks
:param wrapper: wrapper class (abstraction layer) to access the memory, e.g. SlidingBuffer
@@ -66,7 +66,7 @@ class SlidingBuffer:
- Incremental consumption by advancing 'start'
- Incremental writes by advancing 'end'
- Automatic in-place compaction when additional space is
- required and unused bytes exist before 'start'
+ required and unused bytes exist before 'start'
- Bounded memory usage; no dynamic reallocation
"""
@@ -80,25 +80,33 @@ def __init__(self, buffer: bytearray | memoryview):
self.capacity = len(buffer)
def size(self) -> int:
- """Determine the window size"""
+ """
+ Determine the window size.
+ """
return self._end - self._start
def writable(self) -> int:
- """Determine the writeable size of the buffer"""
+ """
+ Determine the writeable size of the buffer.
+ """
return self.capacity - self._end
def readable_view(self) -> memoryview:
- """Return a memoryview to the readable region of the buffer (window)"""
+ """
+ Return a memoryview to the readable region of the buffer (window).
+ """
return self._mv[self._start : self._end]
def writable_view(self) -> memoryview:
- """Return a memoryview to the writeable region of the buffer"""
+ """
+ Return a memoryview to the writeable region of the buffer.
+ """
return self._mv[self._end : self.capacity]
def _compact(self):
"""
Compact the buffer by shifting the active
- window to the beginning of the bytearray
+ window to the beginning of the bytearray.
"""
if self._start == 0:
return
@@ -112,7 +120,7 @@ def _compact(self):
def peek(self, n=None) -> memoryview:
"""
Return the first n bytes from the window,
- return the entire window when n is undefined
+ return the entire window when n is undefined.
"""
if n is None:
n = self.size()
@@ -121,7 +129,9 @@ def peek(self, n=None) -> memoryview:
return self._mv[self._start : self._start + n]
def write(self, data: bytes):
- """Write new data into the writable region and advance the 'end' index"""
+ """
+ Write new data into the writable region and advance the 'end' index.
+ """
if not isinstance(data, (bytes, bytearray, memoryview)):
raise TypeError("write() expects bytes or bytearray")
needed = len(data)
@@ -135,7 +145,9 @@ def write(self, data: bytes):
self._end += needed
def consume(self, n: int = None):
- """Discard the first n bytes of the window by advancing the 'start' index"""
+ """
+ Discard the first n bytes of the window by advancing the 'start' index.
+ """
if n is None:
n = self.size()
if n > self.size():
@@ -148,7 +160,7 @@ def consume(self, n: int = None):
def prepare(self, n: int):
"""
Check if the writeable region is larger or equal to n,
- otherwise attempt to compact the buffer
+ otherwise attempt to compact the buffer.
"""
if n > self.capacity:
raise ValueError("Capacity exceeded")
@@ -159,13 +171,17 @@ def prepare(self, n: int):
raise ValueError("Capacity exceeded")
def commit(self, n):
- """Increase the window size by n bytes by incrementing the 'end' index"""
+ """
+ Increase the window size by n bytes by incrementing the 'end' index.
+ """
if self._end + n > self.capacity:
raise ValueError("Capacity exceeded")
self._end += n
def find(self, term: bytes) -> int:
- """Find and return the index of a search term in the current window"""
+ """
+ Find and return the index of a search term in the current window.
+ """
for i in range(self._start, self._end - len(term) + 1):
if self._mv[i : i + len(term)] == term:
return i - self._start
diff --git a/src/pyrobusta/transport/connection.py b/src/pyrobusta/transport/connection.py
index ce516e4..aa5d7bd 100644
--- a/src/pyrobusta/transport/connection.py
+++ b/src/pyrobusta/transport/connection.py
@@ -19,8 +19,8 @@ class BaseConnection:
def __init__(self, reader, writer):
"""
Base class for connection handling.
- :param reader: async reader stream object
- :param writer: async writer stream object
+ :param reader: asyncio StreamReader
+ :param writer: asyncio StreamWriter
"""
client_info = writer.get_extra_info("peername")
self.id = str(client_info[0]) + ":" + str(client_info[1])
@@ -43,7 +43,7 @@ async def read(
:param read_bytes: number of bytes to read
:param decoding: decoding to use (optional), bytes are returned by default
:param timeout_seconds: an exception is raised if exceeded, 0 means waiting indefinitely
- :return data: holds bytes or decoded string read from the socket
+ :return data: holds bytes or decoded string read by the stream reader
"""
if not self.connected:
raise OSError(f"{self.id} already closed")
diff --git a/src/pyrobusta/utils/helpers.py b/src/pyrobusta/utils/helpers.py
index 08f647a..b713309 100644
--- a/src/pyrobusta/utils/helpers.py
+++ b/src/pyrobusta/utils/helpers.py
@@ -6,7 +6,9 @@
def normalize_path(path: str):
- """Normalize a path string to resolve file and directory paths"""
+ """
+ Normalize a path string to resolve file and directory paths.
+ """
if not path:
return ""
parts = []
@@ -27,10 +29,11 @@ def normalize_path(path: str):
return cwd
-def add_method(cls, func, method_type="instance"):
+def add_method(cls, func: callable, method_type="instance"):
"""
- Helper to patch/extend classes with
- additional methods and states.
+ Helper to patch/extend classes with additional methods and states.
+ :param func: function to add
+ :param method_type: type of the method (instance, static, class)
"""
if method_type == "instance":
setattr(cls, func.__name__, func)
diff --git a/src/pyrobusta/utils/logging.py b/src/pyrobusta/utils/logging.py
index 6c58f66..9f67f21 100644
--- a/src/pyrobusta/utils/logging.py
+++ b/src/pyrobusta/utils/logging.py
@@ -11,7 +11,7 @@
def current_log_level():
"""
- Determine current log level from the config
+ Determine current log level from the config.
"""
current = get_config(CONF_LOG_LEVEL)
if current == "debug":
@@ -25,7 +25,7 @@ def current_log_level():
def warning(log):
"""
- Print warning messages
+ Print warning messages.
"""
if current_log_level() >= _LOG_LEVEL_WARNING:
print(f"[WARN] {log}")
@@ -33,7 +33,7 @@ def warning(log):
def info(log):
"""
- Print info messages
+ Print info messages.
"""
if current_log_level() >= _LOG_LEVEL_INFO:
print(f"[INFO] {log}")
@@ -41,7 +41,7 @@ def info(log):
def debug(log):
"""
- Print debug messages
+ Print debug messages.
"""
if current_log_level() >= _LOG_LEVEL_DEBUG:
print(f"[DEBUG] {log}")
diff --git a/tests/functional/test_http.py b/tests/functional/test_http.py
index 68d5577..7618c9c 100644
--- a/tests/functional/test_http.py
+++ b/tests/functional/test_http.py
@@ -6,11 +6,7 @@
from os import mkdir, remove, rmdir
from pyrobusta.server import http_server
-from pyrobusta.protocol.http import (
- HttpEngine,
- enable_optional_features,
- ServerBusyError,
-)
+from pyrobusta.protocol.http import HttpEngine, enable_optional_features
from pyrobusta.utils.config import (
CONF_HTTP_SERVED_PATHS,
CONF_TLS,
@@ -85,8 +81,9 @@ def simple_callback(http_ctx, _):
@HttpEngine.route("/test/busy", "POST")
-def busy_callback(*_):
- raise ServerBusyError()
+def busy_callback(http_ctx, _):
+ http_ctx.terminate(503)
+ return "text/plain", "Unavailable"
def create_chunked_app_endpoint(endpoint):
@@ -117,7 +114,9 @@ async def test_simple_response(tls_enabled):
# Test: text/plain
plain_response = await send_request(
b"GET /test/simple HTTP/1.1\r\n"
- b"Host: localhost\r\nAccept:text/plain\r\n"
+ b"Host: localhost\r\n"
+ b"Connection: close\r\n"
+ b"Accept:text/plain\r\n"
b"\r\n",
tls_enabled,
)
@@ -135,7 +134,9 @@ async def test_simple_response(tls_enabled):
# Test: application/json
json_response = await send_request(
b"GET /test/simple HTTP/1.1\r\n"
- b"Host: localhost\r\nAccept:application/json\r\n"
+ b"Host: localhost\r\n"
+ b"Connection: close\r\n"
+ b"Accept: application/json\r\n"
b"\r\n",
tls_enabled,
)
@@ -160,7 +161,9 @@ async def test_server_busy():
server, server_task = await start_server()
plain_response = await send_request(
- b"POST /test/busy HTTP/1.1\r\n" b"Host: localhost\r\n\r\n"
+ b"POST /test/busy HTTP/1.1\r\n"
+ b"Connection:close\r\n"
+ b"Host: localhost\r\n\r\n"
)
test_assert(
f"response is rejected by busy service with 503",
@@ -179,15 +182,14 @@ async def test_chunked_transfer_encoding():
server, server_task = await start_server()
json_response = await send_request(
- (
- b"POST /test/chunked HTTP/1.1\r\n"
- b"Host: localhost\r\n"
- b"Transfer-Encoding: chunked\r\n\r\n"
- b"14\r\nchunking\r\ntest\r\ncase\r\n"
- b"E\r\nchunking\r\ntest\r\n"
- b"8\r\nchunking\r\n"
- b"0\r\n\r\n"
- )
+ b"POST /test/chunked HTTP/1.1\r\n"
+ b"Host: localhost\r\n"
+ b"Connection: close\r\n"
+ b"Transfer-Encoding: chunked\r\n\r\n"
+ b"14\r\nchunking\r\ntest\r\ncase\r\n"
+ b"E\r\nchunking\r\ntest\r\n"
+ b"8\r\nchunking\r\n"
+ b"0\r\n\r\n"
)
response_body = json.loads(json_response.split(b"\r\n\r\n")[1])
test_assert(
@@ -226,7 +228,9 @@ async def test_fs_access_control():
# Case #1: /www/index.html
response = await send_request(
- (b"GET /allowed/index.html HTTP/1.1\r\n" b"Host: localhost\r\n\r\n")
+ b"GET /allowed/index.html HTTP/1.1\r\n"
+ b"Connection: close\r\n"
+ b"Host: localhost\r\n\r\n"
)
response_body = response.split(b"\r\n\r\n")[1]
@@ -238,7 +242,9 @@ async def test_fs_access_control():
# Case #2: /index.html
response = await send_request(
- (b"GET /rejected/index.html HTTP/1.1\r\n" b"Host: localhost\r\n\r\n")
+ b"GET /rejected/index.html HTTP/1.1\r\n"
+ b"Connection: close\r\n"
+ b"Host: localhost\r\n\r\n"
)
test_assert(
@@ -255,6 +261,90 @@ async def test_fs_access_control():
await server.terminate()
+@garbage_collect
+async def test_keepalive():
+ setup_config()
+ server, server_task = await start_server()
+
+ # ----------------------------------
+ # Case 1: all requests are processed
+ # ----------------------------------
+ plain_responses = await send_request(
+ b"GET /test/simple HTTP/1.1\r\n"
+ b"Host: localhost\r\n"
+ b"Accept:text/plain\r\n"
+ b"\r\n"
+ b"GET /test/simple HTTP/1.1\r\n"
+ b"Host: localhost\r\n"
+ b"Accept:text/plain\r\n"
+ b"\r\n"
+ b"GET /test/simple HTTP/1.1\r\n"
+ b"Host: localhost\r\n"
+ b"Connection: close\r\n"
+ b"Accept:text/plain\r\n"
+ b"\r\n"
+ )
+
+ test_assert(
+ f"contains all responses (connection: keep-alive)",
+ plain_responses.count(b"HTTP/1.1 200 OK"),
+ 3,
+ )
+
+ # -------------------------------------------------------------------
+ # Case 2: close connection after the second request (invalid framing)
+ # -------------------------------------------------------------------
+ plain_responses = await send_request(
+ b"GET /test/simple HTTP/1.1\r\n"
+ b"Host: localhost\r\n"
+ b"Accept:text/plain\r\n"
+ b"\r\n"
+ b"GET /test/simple HTTP/1.1\r\n"
+ b"Host: localhost\r\n"
+ b""
+ b"Accept:text/plain\r\n"
+ b"\r\n"
+ b"GET /test/simple HTTP/1.1\r\n"
+ b"Host: localhost\r\n"
+ b"Accept:text/plain\r\n"
+ b"\r\n"
+ )
+
+ test_assert(
+ f"contains two responses (connection: keep-alive, invalid framing)",
+ plain_responses.count(b"HTTP/1.1"),
+ 2,
+ )
+
+ # ------------------------------------------------
+ # Case 3: close connection after the first request
+ # ------------------------------------------------
+ plain_response = await send_request(
+ b"GET /test/simple HTTP/1.1\r\n"
+ b"Host: localhost\r\n"
+ b"Connection: close\r\n"
+ b"Accept:text/plain\r\n"
+ b"\r\n"
+ b"GET /test/simple HTTP/1.1\r\n"
+ b"Host: localhost\r\n"
+ b"Accept:text/plain\r\n"
+ b"\r\n"
+ b"GET /test/simple HTTP/1.1\r\n"
+ b"Host: localhost\r\n"
+ b"Accept:text/plain\r\n"
+ b"\r\n"
+ )
+
+ test_assert(
+ f"contains single response (connection: close)",
+ plain_response.count(b"HTTP/1.1 200 OK"),
+ 1,
+ )
+
+ server_task.cancel()
+ await server.terminate()
+
+
#################################################
# Test methods
#################################################
@@ -294,6 +384,7 @@ def test_main():
asyncio.run(test_server_busy())
asyncio.run(test_chunked_transfer_encoding())
asyncio.run(test_fs_access_control())
+ asyncio.run(test_keepalive())
test_main()
diff --git a/tests/functional/test_http_multipart.py b/tests/functional/test_http_multipart.py
index 1765243..e52f665 100644
--- a/tests/functional/test_http_multipart.py
+++ b/tests/functional/test_http_multipart.py
@@ -91,7 +91,7 @@ def multipart_callback(http_ctx, _):
async def start_server():
"""
- Start an HTTP server as a background task
+ Start an HTTP server as a background task.
"""
server = http_server.HttpServer()
server_task = asyncio.create_task(server.start_socket_server())
@@ -107,9 +107,12 @@ async def test_multipart_response(tls_enabled):
# Test: 1 part
plain_response = await send_request(
b"GET /test/multipart HTTP/1.1\r\n"
- b"Host: localhost\r\nX-Part-Count: 1\r\n\r\n",
+ b"Host: localhost\r\n"
+ b"Connection: close\r\n"
+ b"X-Part-Count: 1\r\n\r\n",
tls_enabled,
)
+
test_assert(
f"http{"s" if tls_enabled else ""} response contains 1 part",
b"Response 1" in plain_response,
@@ -119,7 +122,9 @@ async def test_multipart_response(tls_enabled):
# Test: 10 parts
plain_response = await send_request(
b"GET /test/multipart HTTP/1.1\r\n"
- b"Host: localhost\r\nX-Part-Count: 10\r\n\r\n",
+ b"Host: localhost\r\n"
+ b"Connection: close\r\n"
+ b"X-Part-Count: 10\r\n\r\n",
tls_enabled,
)
test_assert(
diff --git a/tests/system/http_dimensioning/test.py b/tests/system/http_dimensioning/test.py
index b191975..aa48adb 100644
--- a/tests/system/http_dimensioning/test.py
+++ b/tests/system/http_dimensioning/test.py
@@ -189,7 +189,9 @@ def measure_footprint(config, device_ip):
port = 4443 if config["tls"] == True else 8080
try:
usage = requests.get(
- f"{proto}://{device_ip}:{port}/mem/current", verify=False
+ f"{proto}://{device_ip}:{port}/mem/current",
+ verify=False,
+ headers={"Connection": "close"},
).text
print(f"Measured: {usage}")
except:
@@ -225,6 +227,7 @@ def worker():
f"{base_url}/index.html",
verify=False,
timeout=5,
+ headers={"Connection": "close"},
)
resp.raise_for_status()
@@ -254,6 +257,7 @@ def worker():
f"{base_url}/mem/time-series",
verify=False,
timeout=5,
+ headers={"Connection": "close"},
).json()
print(f"Measured: {usage}")
diff --git a/tests/unit/test_buffer.py b/tests/unit/test_buffer.py
index f004855..75c7d1e 100644
--- a/tests/unit/test_buffer.py
+++ b/tests/unit/test_buffer.py
@@ -5,7 +5,7 @@
class BufferTestBase(unittest.TestCase):
"""
- Tests for stream.buffer module
+ Tests for stream.buffer module.
"""
buffer_type = bytearray
diff --git a/tests/unit/test_http.py b/tests/unit/test_http.py
index 31d393d..949bc2f 100644
--- a/tests/unit/test_http.py
+++ b/tests/unit/test_http.py
@@ -83,7 +83,7 @@ def test_status_parsing_valid(self):
for i in range(len(request)):
self.rx.write(request[i : i + 1])
- self.engine.state(self.rx, self.tx)
+ self.engine.state(self.rx)
self.assertEqual(self.engine.method, b"GET")
self.assertEqual(self.engine.url, b"/index.html")
@@ -96,7 +96,7 @@ def test_status_parsing_incomplete_line(self):
for i in range(len(request)):
self.rx.write(request[i : i + 1])
- self.engine.state(self.rx, self.tx)
+ self.engine.state(self.rx)
if self.engine.state is None:
break
@@ -110,7 +110,7 @@ def test_status_parsing_unsupported_method(self):
for i in range(len(request)):
self.rx.write(request[i : i + 1])
- self.engine.state(self.rx, self.tx)
+ self.engine.state(self.rx)
if self.engine.state is None:
break
@@ -125,7 +125,7 @@ def test_status_parsing_unsupported_version(self):
for i in range(len(request)):
self.rx.write(request[i : i + 1])
- self.engine.state(self.rx, self.tx)
+ self.engine.state(self.rx)
if self.engine.state is None:
break
@@ -141,7 +141,7 @@ def test_header_parsing_valid(self):
for i in range(len(request)):
self.rx.write(request[i : i + 1])
- self.engine.state(self.rx, self.tx)
+ self.engine.state(self.rx)
self.assertDictEqual(
{"content-length": 10, "content-type": "application/json"},
@@ -153,14 +153,10 @@ def test_header_parsing_valid(self):
def test_header_parsing_incomplete_header(self):
request = b"GET /index.html HTTP/1.1\r\nContent-Type\r\n\r\n"
- for i in range(len(request)):
- self.rx.write(request[i : i + 1])
- self.engine.state(self.rx, self.tx)
- if self.engine.state is None:
- break
-
- self.assertEqual(self.engine.status_code, 400)
- self.assertEqual(self.engine.state, None)
+ with self.assertRaises(self.http_module.InvalidHeaders):
+ for i in range(len(request)):
+ self.rx.write(request[i : i + 1])
+ self.engine.state(self.rx)
def test_header_parsing_error(self):
for case in (
@@ -171,7 +167,7 @@ def test_header_parsing_error(self):
b"space in header name: value",
b"new-line-in-header:\nvalue",
):
- with self.assertRaises(self.http_module.HeaderParsingError):
+ with self.assertRaises(self.http_module.InvalidHeaders):
self.engine._parse_headers(case)
def test_routing_unsupported_method(self):
@@ -183,7 +179,7 @@ def test_routing_unsupported_method(self):
test_callback = mock.Mock()
self.engine.register("/api/test", test_callback, "POST")
- self.engine.state(self.rx, self.tx)
+ self.engine.state(self.rx)
self.assertEqual(self.engine.status_code, 405)
self.assertEqual(self.engine.state, None)
@@ -201,7 +197,7 @@ def test_routing_options_method(self):
self.engine.register("/api/test", test_callback, "POST")
self.engine.register("/api/test", test_callback, "PUT")
- self.engine.state(self.rx, self.tx)
+ self.engine.state(self.rx)
self.assertEqual(self.engine.status_code, 204)
self.assertEqual(self.engine.state, None)
@@ -220,15 +216,15 @@ def test_routing_get_method(self):
self.engine.register("/api/test", test_callback, "GET")
while self.engine.state is not None:
- self.engine.state(self.rx, self.tx)
+ self.engine.state(self.rx)
self.assertEqual(self.engine.status_code, 200)
self.assertEqual(self.engine.state, None)
- self.assertNotEqual(
- self.tx.find(b"content-length: " + str(len(test_response)).encode("ascii")),
- -1,
+ self.assertEqual(
+ int(self.engine._lookup(self.engine.resp_headers, b"content-length")),
+ len(test_response),
)
- self.assertNotEqual(self.tx.find(test_response), -1)
+ self.assertEqual(self.engine.resp_handler.read(), test_response)
def test_routing_head_method(self):
self.engine.state = self.engine._route_request_st
@@ -242,22 +238,22 @@ def test_routing_head_method(self):
self.engine.register("/api/test", test_callback, "GET")
while self.engine.state is not None:
- self.engine.state(self.rx, self.tx)
+ self.engine.state(self.rx)
self.assertEqual(self.engine.status_code, 200)
self.assertEqual(self.engine.state, None)
- self.assertNotEqual(
- self.tx.find(b"content-length: " + str(len(test_response)).encode("ascii")),
- -1,
+ self.assertEqual(
+ int(self.engine._lookup(self.engine.resp_headers, b"content-length")),
+ len(test_response),
)
- self.assertEqual(self.tx.find(test_response), -1)
+ self.assertEqual(self.engine.resp_handler, None)
def test_simple_query_parameter(self):
request = b"GET /api/test?param HTTP/1.1\r\n"
for i in range(len(request)):
self.rx.write(request[i : i + 1])
- self.engine.state(self.rx, self.tx)
+ self.engine.state(self.rx)
self.assertEqual(self.engine.query, "param")
@@ -275,7 +271,7 @@ def pct_encode(b):
for i in range(len(request)):
self.rx.write(request[i : i + 1])
- self.engine.state(self.rx, self.tx)
+ self.engine.state(self.rx)
self.assertEqual(self.engine.query, f"safe_chars.{unsafe_chars}")
@@ -284,7 +280,7 @@ def test_single_url_encoded_query_parameter(self):
for i in range(len(request)):
self.rx.write(request[i : i + 1])
- self.engine.state(self.rx, self.tx)
+ self.engine.state(self.rx)
self.assertEqual(
self.engine.get_url_encoded_query_param(self.engine.query, "param"), "value"
@@ -297,7 +293,7 @@ def test_multiple_url_encoded_query_parameter(self):
for i in range(len(request)):
self.rx.write(request[i : i + 1])
- self.engine.state(self.rx, self.tx)
+ self.engine.state(self.rx)
self.assertEqual(
self.engine.get_url_encoded_query_param(self.engine.query, "param1"),
@@ -317,7 +313,7 @@ def test_empty_or_missing_url_encoded_query_parameter(self):
for i in range(len(request)):
self.rx.write(request[i : i + 1])
- self.engine.state(self.rx, self.tx)
+ self.engine.state(self.rx)
self.assertEqual(
self.engine.get_url_encoded_query_param(self.engine.query, "param1"),
@@ -342,7 +338,7 @@ def test_overlapping_url_encoded_query_parameter(self):
for i in range(len(request)):
self.rx.write(request[i : i + 1])
- self.engine.state(self.rx, self.tx)
+ self.engine.state(self.rx)
self.assertEqual(
self.engine.get_url_encoded_query_param(self.engine.query, "data"),
@@ -362,7 +358,7 @@ def test_chunked_transfer_encoding_valid(self):
self.engine.method = b"GET"
self.engine.version = b"HTTP/1.1"
self.engine.headers["transfer-encoding"] = "chunked"
- self.engine.state = self.engine._recv_chunked_size_st
+ self.engine.state = self.engine._recv_chunk_size_st
test_callback = mock.Mock(return_value=("text/plain", "OK"))
self.engine.register("/api/test", test_callback, "GET")
@@ -375,10 +371,10 @@ def test_chunked_transfer_encoding_valid(self):
):
for i in range(len(chunk)):
self.rx.write(chunk[i : i + 1])
- self.engine.state(self.rx, self.tx)
+ self.engine.state(self.rx)
self.assertEqual(self.engine.state, self.engine._app_endpoint_st)
- self.engine.state(self.rx, self.tx)
+ self.engine.state(self.rx)
size_delimiter = chunk.find(b"\r\n")
test_callback.assert_called_with(
self.engine, chunk[size_delimiter + 2 : -2]
@@ -392,27 +388,23 @@ def test_chunked_transfer_encoding_invalid_chunk_size_smaller(self):
self.engine.method = b"GET"
self.engine.version = b"HTTP/1.1"
self.engine.headers["transfer-encoding"] = "chunked"
- self.engine.state = self.engine._recv_chunked_size_st
+ self.engine.state = self.engine._recv_chunk_size_st
test_callback = mock.Mock(return_value=("text/plain", "OK"))
self.engine.register("/api/test", test_callback, "GET")
- chunk = b"2\r\nchunking\r\n"
- for i in range(len(chunk)):
- self.rx.write(chunk[i : i + 1])
- self.engine.state(self.rx, self.tx)
- if self.engine.state is None:
- break
-
- self.assertEqual(self.engine.status_code, 400)
- self.assertEqual(self.engine.state, None)
+ with self.assertRaises(self.http_module.InvalidContentLength):
+ chunk = b"2\r\nchunking\r\n"
+ for i in range(len(chunk)):
+ self.rx.write(chunk[i : i + 1])
+ self.engine.state(self.rx)
def test_chunked_transfer_encoding_chunk_incomplete(self):
self.engine.url = b"/api/test"
self.engine.method = b"GET"
self.engine.version = b"HTTP/1.1"
self.engine.headers["transfer-encoding"] = "chunked"
- self.engine.state = self.engine._recv_chunked_size_st
+ self.engine.state = self.engine._recv_chunk_size_st
test_callback = mock.Mock(return_value=("text/plain", "OK"))
self.engine.register("/api/test", test_callback, "GET")
@@ -420,7 +412,7 @@ def test_chunked_transfer_encoding_chunk_incomplete(self):
chunk = b"FF\r\nchunking\r\n"
for i in range(len(chunk)):
self.rx.write(chunk[i : i + 1])
- self.engine.state(self.rx, self.tx)
+ self.engine.state(self.rx)
if self.engine.state is None:
break
@@ -510,7 +502,7 @@ def test_multipart_parser(self):
{"content-type": 'multipart/form-data;boundary=missing-quote"'},
]:
with self.subTest(headers=case):
- with self.assertRaises(self.http_module.HeaderParsingError):
+ with self.assertRaises(self.http_module.InvalidHeaders):
self.engine._get_mp_boundary(case)
def test_multipart_receiver_valid(self):
@@ -521,7 +513,7 @@ def test_multipart_receiver_valid(self):
for i in range(len(body_part)):
self.rx.write(body_part[i : i + 1])
- self.engine.state(self.rx, self.tx)
+ self.engine.state(self.rx)
self.assertEqual(self.engine.state, self.engine._parse_boundary_st)
self.assertEqual(self.rx.peek(), b"Content-Type:text/plain")
@@ -533,15 +525,10 @@ def test_multipart_receiver_boundary_mismatch(self):
self.engine.mp_boundary = b"test-boundary"
body_part = b"--test-boundary-delimiter\r\nContent-Type:text/plain"
- for i in range(len(body_part)):
- self.rx.write(body_part[i : i + 1])
- self.engine.state(self.rx, self.tx)
- if self.engine.state is None:
- break
-
- self.assertEqual(self.engine.state, None)
- self.assertEqual(self.engine.status_code, 400)
- self.assertEqual(self.rx.peek(), b"--test-boundary-delimiter\r\n")
+ with self.assertRaises(self.http_module.MalformedRequest):
+ for i in range(len(body_part)):
+ self.rx.write(body_part[i : i + 1])
+ self.engine.state(self.rx)
def test_multipart_receiver_complete_part(self):
self.engine.state = self.engine._parse_boundary_st
@@ -566,13 +553,13 @@ def test_multipart_receiver_complete_part(self):
for i in range(len(body_part)):
self.assertEqual(self.engine.state, self.engine._parse_boundary_st)
self.rx.write(body_part[i : i + 1])
- self.engine.state(self.rx, self.tx)
+ self.engine.state(self.rx)
self.assertEqual(self.engine.state, self.engine._parse_complete_part_st)
self.assertEqual(self.rx.peek(), body_part)
self.assertEqual(self.engine.mp_is_first, True)
- self.engine.state(self.rx, self.tx)
+ self.engine.state(self.rx)
self.assertEqual(self.engine.state, self.engine._parse_boundary_st)
test_callback.assert_called_once_with(
@@ -611,12 +598,12 @@ def test_multipart_receiver_last_part(self):
for i in range(len(body_part)):
self.assertEqual(self.engine.state, self.engine._parse_boundary_st)
self.rx.write(body_part[i : i + 1])
- self.engine.state(self.rx, self.tx)
+ self.engine.state(self.rx)
self.assertEqual(self.engine.state, self.engine._parse_complete_part_st)
self.assertEqual(self.rx.peek(), body_part)
- self.engine.state(self.rx, self.tx)
+ self.engine.state(self.rx)
self.assertEqual(self.engine.state, None)
self.assertEqual(self.engine.status_code, 200)
@@ -666,7 +653,7 @@ def test_file_serving_missing_file(self, *_):
self.engine.version = b"HTTP/1.1"
self.engine.state = self.engine._send_file_st
- self.engine.state(self.rx, self.tx, self.engine.url)
+ self.engine.state(self.rx, self.engine.url)
self.assertEqual(self.engine.status_code, 404)
self.assertEqual(self.engine.state, None)
@@ -680,10 +667,10 @@ def test_file_serving_root(self, *_):
file_content = "index content"
with patch("builtins.open", mock_open(read_data=file_content)) as m:
- response_generator = self.engine.state(self.rx, self.tx, self.engine.url)
+ self.engine.state(self.rx, self.engine.url)
m.assert_called_once_with("/www/index.html", "rb")
- self.assertEqual(response_generator.read(), file_content)
+ self.assertEqual(self.engine.resp_handler.read(), file_content)
self.assertEqual(self.engine.status_code, 200)
self.assertEqual(self.engine.state, None)
@@ -696,10 +683,10 @@ def test_file_serving_files_endpoint(self, *_):
file_content = "data"
with patch("builtins.open", mock_open(read_data=file_content)) as m:
- response_generator = self.engine.state(self.rx, self.tx, self.engine.url)
+ self.engine.state(self.rx, self.engine.url)
m.assert_called_once_with("/www/scripts.js", "rb")
- self.assertEqual(response_generator.read(), file_content)
+ self.assertEqual(self.engine.resp_handler.read(), file_content)
self.assertEqual(self.engine.status_code, 200)
self.assertEqual(self.engine.state, None)
@@ -712,10 +699,10 @@ def test_file_serving_known_content_type(self, *_):
file_content = "data"
with patch("builtins.open", mock_open(read_data=file_content)) as m:
- response_generator = self.engine.state(self.rx, self.tx, self.engine.url)
+ self.engine.state(self.rx, self.engine.url)
m.assert_called_once_with("/www/scripts.js", "rb")
- self.assertEqual(response_generator.read(), file_content)
+ self.assertEqual(self.engine.resp_handler.read(), file_content)
self.assertEqual(
self.engine._lookup(self.engine.resp_headers, b"content-type"),
b"application/javascript",
@@ -732,10 +719,10 @@ def test_file_serving_fallback_content_type(self, *_):
file_content = "data"
with patch("builtins.open", mock_open(read_data=file_content)) as m:
- response_generator = self.engine.state(self.rx, self.tx, self.engine.url)
+ self.engine.state(self.rx, self.engine.url)
m.assert_called_once_with("/www/scripts.unknown", "rb")
- self.assertEqual(response_generator.read(), file_content)
+ self.assertEqual(self.engine.resp_handler.read(), file_content)
self.assertEqual(
self.engine._lookup(self.engine.resp_headers, b"content-type"),
b"application/octet-stream",
@@ -752,10 +739,10 @@ def test_file_serving_unserved_content_rejected(self, *_):
file_content = "data"
with patch("builtins.open", mock_open(read_data=file_content)) as m:
- response_generator = self.engine.state(self.rx, self.tx, self.engine.url)
+ self.engine.state(self.rx, self.engine.url)
m.assert_not_called()
- self.assertEqual(response_generator, None)
+ self.assertEqual(self.engine.resp_handler, None)
self.assertEqual(self.engine.status_code, 403)
self.assertEqual(self.engine.state, None)