diff --git a/Makefile b/Makefile index c609406..4d1b045 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -PYROBUSTA_VERSION := 0.1.0 +PYROBUSTA_VERSION := 0.2.0 DEVICE ?= u0 SRC_DIR := src diff --git a/README.md b/README.md index 19338f9..1156afc 100644 --- a/README.md +++ b/README.md @@ -9,11 +9,9 @@ A lightweight HTTP server library for MicroPython designed for constrained embed - Bounded-copy memory footprint - Finite-state-machine parser with linear sliding buffer - Robust byte-stream handling +- Query parameter parsing with percent encoding support - TLS support -## Current limitation -- Query parameter parsing is not yet implemented - # Prerequisites ## Setup virtual environment diff --git a/dist/pyrobusta/bindings/socket_http.mpy b/dist/pyrobusta/bindings/socket_http.mpy index 40f6659..95d8afa 100644 Binary files a/dist/pyrobusta/bindings/socket_http.mpy and b/dist/pyrobusta/bindings/socket_http.mpy differ diff --git a/dist/pyrobusta/protocol/http.mpy b/dist/pyrobusta/protocol/http.mpy index e7f78f5..49e8176 100644 Binary files a/dist/pyrobusta/protocol/http.mpy and b/dist/pyrobusta/protocol/http.mpy differ diff --git a/dist/pyrobusta/protocol/http_multipart.mpy b/dist/pyrobusta/protocol/http_multipart.mpy index b7c3609..f6fac86 100644 Binary files a/dist/pyrobusta/protocol/http_multipart.mpy and b/dist/pyrobusta/protocol/http_multipart.mpy differ diff --git a/dist/pyrobusta/utils/config.mpy b/dist/pyrobusta/utils/config.mpy index 3b74366..9f08f56 100644 Binary files a/dist/pyrobusta/utils/config.mpy and b/dist/pyrobusta/utils/config.mpy differ diff --git a/example/mem_usage/app.py b/example/mem_usage/app.py index 4b1b6d6..bccf0ae 100644 --- a/example/mem_usage/app.py +++ b/example/mem_usage/app.py @@ -7,11 +7,34 @@ @HttpEngine.route("/mem-usage", "GET") -def mem_usage(*_): +def mem_usage(http_ctx, _): collect() free = mem_free() used = mem_alloc() usage_percentage = 100 * used / (free + used) + + if http_ctx.query: + value_format = 
http_ctx.get_url_encoded_query_param( + http_ctx.query, "format", "bytes" + ) + if value_format not in ("%", "bytes"): + raise ValueError("invalid format") + + selector = http_ctx.get_url_encoded_query_param(http_ctx.query, "key", "") + if selector == "free": + if value_format == "%": + free = 100 * free / (used + free) + return "text/plain", f"Free [{value_format}]: {free}\n" + if selector == "used": + if value_format == "%": + used = 100 * used / (used + free) + return "text/plain", f"Used [{value_format}]: {used}\n" + if selector == "total": + return "text/plain", f"Total [bytes]: {used + free}\n" + + if selector: + raise ValueError("invalid key") + return "text/plain", ( f"Currently used: {usage_percentage:.2f}%\n" f"Free [bytes]: {free}\n" diff --git a/example/mip_repo/app.py b/example/mip_repo/app.py index e637e59..5955b9a 100644 --- a/example/mip_repo/app.py +++ b/example/mip_repo/app.py @@ -28,10 +28,10 @@ def append_package_files(dir, package_files, host_name, protocol): @HttpEngine.route("/pyrobusta/package.json", "GET") -def self_serve_mip_package(headers, _): +def self_serve_mip_package(http_ctx, _): package_files = {"version": config.PYROBUSTA_VERSION, "deps": [], "urls": []} tls_enabled = config.get_config("tls").lower() == "true" - server_addr = headers["host"] + server_addr = http_ctx.headers["host"] if ":" not in server_addr: port = ( http_server.HttpServer.LISTEN_PORT_HTTPS diff --git a/package.json b/package.json index 86caf4b..30a9089 100644 --- a/package.json +++ b/package.json @@ -1,5 +1,5 @@ { - "version": "0.1.0", + "version": "0.2.0", "urls": [ [ "pyrobusta/transport/socket.mpy", diff --git a/src/pyrobusta/protocol/http.py b/src/pyrobusta/protocol/http.py index ddd9127..f72cb1a 100644 --- a/src/pyrobusta/protocol/http.py +++ b/src/pyrobusta/protocol/http.py @@ -38,14 +38,16 @@ class HttpEngine: "headers", "method", "url", + "query", "content_length_cnt", "mp_boundary", "mp_first_part", + "mp_last_part", "mp_delimiter", 
"mp_closing_delimiter", ) - ENDPOINTS = {} + ENDPOINTS = [] # (endpoint, callback, method) RESP_HEADERS = ( 200, b"200 OK", @@ -123,6 +125,7 @@ def __init__(self): self.headers = {} self.method = None self.url = None + self.query = None self.content_length_cnt = 0 # [Multipart state] @@ -144,14 +147,16 @@ def register( """ endpoint = endpoint.encode(cls.ASCII) method = method.encode(cls.ASCII) - if not endpoint in cls.ENDPOINTS: - cls.ENDPOINTS[endpoint] = {} + endpoint_exists = cls._get_callback(endpoint, method) is not None + if method not in cls.METHODS: raise ValueError(f"method must be one of {cls.METHODS}") - cls.ENDPOINTS[endpoint][method] = callback + if endpoint_exists: + raise ValueError("endpoint exists") + cls.ENDPOINTS.append((endpoint, callback, method)) @staticmethod - def route(endpoint, method): + def route(endpoint: str, method: str): """ Decorator for registering endpoint callback functions. """ @@ -166,15 +171,66 @@ def decorator(func): # Static helpers for parsing # ========================================= + @staticmethod + def percent_decode(s: str): + """Decode percent-encoded input""" + out = [] + i = 0 + while i < len(s): + if s[i] == "%" and i + 2 < len(s): + out.append(chr(int(s[i + 1 : i + 3], 16))) + i += 3 + else: + out.append(s[i]) + i += 1 + return "".join(out) + + @staticmethod + def get_url_encoded_query_param(query: str, key: str, default: str = None): + """ + Parse query and return the value belonging to a key + according to x-www-form-urlencoded + :param query: query part + :param key: key to parse from the query + :param default: default value to return when key is not present + """ + idx_start = query.find(key + "=") + if idx_start != -1: + idx_end = -1 + idx_end = query.find("&", idx_start) + if idx_start > -1: + if idx_end > -1: + return query[idx_start + len(key) + 1 : idx_end] + return query[idx_start + len(key) + 1 :] + if default is None: + raise KeyError() + return default + + @staticmethod + def _lookup(tuple_, key): 
+ idx = tuple_.index(key) + return tuple_[idx + 1] + + @classmethod + def _get_callback(cls, endpoint, method): + for e in cls.ENDPOINTS: + if endpoint == e[0] and method == e[2]: + return e[1] + @classmethod - def _get_status(cls, status_code): - idx = cls.RESP_HEADERS.index(status_code) - return cls.RESP_HEADERS[idx + 1] + def _has_endpoint(cls, endpoint): + for e in cls.ENDPOINTS: + if endpoint == e[0]: + return True + return False @classmethod - def _get_content_type(cls, extension): - idx = cls.CONTENT_TYPES.index(extension) - return cls.CONTENT_TYPES[idx + 1] + def _supported_methods(cls, endpoint): + supported_methods = [] + for method in cls.METHODS: + if cls._get_callback(endpoint, method) is not None: + supported_methods.append(method) + return supported_methods @classmethod def _parse_headers(cls, raw_headers: memoryview) -> dict[str, str | int]: @@ -263,7 +319,7 @@ def _write_response_head(self, tx, content_length: int = 0): tx.consume() tx.write(self.version) tx.write(b" ") - tx.write(self._get_status(self.status_code)) + tx.write(self._lookup(self.RESP_HEADERS, self.status_code)) if content_length is not None: tx.write(b"\r\n") tx.write(b"content-length: %s" % str(content_length).encode(self.ASCII)) @@ -364,7 +420,13 @@ def _parse_request_line_st(self, rx, tx): self.on_client_error(tx, self.BAD_REQUEST_ERROR) return self.method = status_parts[0] - self.url = status_parts[1] + url_parts = status_parts[1].split(b"?", 1) + self.url = url_parts[0] + self.query = ( + "" + if len(url_parts) == 1 + else self.percent_decode(url_parts[1].decode(self.ASCII)) + ) self.version = status_parts[2] if self.method not in self.METHODS: self.on_method_not_allowed(tx) @@ -398,13 +460,16 @@ def _route_request_st(self, _, tx): State for routing requests - supported ways: static resources, endpoint callback functions """ - if self.url in self.ENDPOINTS and ( - self.method in self.ENDPOINTS[self.url] + if self._has_endpoint(self.url) and ( + self._get_callback(self.url, 
self.method) is not None or self.method == self.OPTIONS - or (self.method == self.HEAD and self.GET in self.ENDPOINTS[self.url]) + or ( + self.method == self.HEAD + and self._get_callback(self.url, self.GET) is not None + ) ): if self.method == self.OPTIONS: - supported_methods = list(self.ENDPOINTS[self.url].keys()) + supported_methods = self._supported_methods(self.url) self._set_response_header(b"allow", b", ".join(supported_methods)) self.terminate(204, None) self._write_response_head(tx, None) @@ -421,12 +486,16 @@ def _route_request_st(self, _, tx): else: self.state = self._app_endpoint_st return - if self.url in self.ENDPOINTS and self.method not in self.ENDPOINTS[self.url]: - supported_methods = list(self.ENDPOINTS[self.url].keys()) + + if ( + self._has_endpoint(self.url) + and self._get_callback(self.method, self.url) is None + ): + supported_methods = self._supported_methods(self.url) self._set_response_header(b"allow", b", ".join(supported_methods)) self.on_method_not_allowed(tx) return - if self.method == self.GET: + if self.method in (self.GET, self.HEAD): resource = b"index.html" if not self.url else self.url self.state = lambda _rx, _tx: self._send_file_st(_rx, _tx, resource) return @@ -443,10 +512,10 @@ def _recv_payload(self, rx, tx): def _app_endpoint_st(self, rx, tx): """Process a request by registered callback functions""" method = self.GET if self.method == self.HEAD else self.method - callback = self.ENDPOINTS[self.url][method] + callback = self._get_callback(self.url, method) if self._has_payload(): self.state = None - dtype, data = callback(self.headers, bytes(rx.peek())) + dtype, data = callback(self, bytes(rx.peek())) dtype = dtype.encode(self.ASCII) else: if not callable(callback): @@ -455,7 +524,7 @@ def _app_endpoint_st(self, rx, tx): _rx, _tx, callback.encode(HttpEngine.ASCII) ) return - dtype, data = callback(self.headers, b"") + dtype, data = callback(self, b"") dtype = dtype.encode(self.ASCII) 
self._set_response_header(b"content-type", dtype) if dtype == b"image/jpeg": @@ -500,9 +569,9 @@ def _send_file_st(self, _, tx, web_resource: bytes): norm_path = b"/".join(parts) try: - content_type = self._get_content_type(extension) + content_type = self._lookup(self.CONTENT_TYPES, extension) except ValueError: - content_type = self._get_content_type(b"raw") + content_type = self._lookup(self.CONTENT_TYPES, b"raw") try: self._set_response_header( b"content-length", str(stat(norm_path)[6]).encode(HttpEngine.ASCII) diff --git a/src/pyrobusta/protocol/http_multipart.py b/src/pyrobusta/protocol/http_multipart.py index 75cde2f..a8e6e96 100644 --- a/src/pyrobusta/protocol/http_multipart.py +++ b/src/pyrobusta/protocol/http_multipart.py @@ -107,10 +107,10 @@ def _parse_complete_part_st(self, rx, tx): except http.HeaderParsingError: self.on_client_error(tx, http.HttpEngine.HEADER_ERROR) return - callback = http.HttpEngine.ENDPOINTS[self.url][self.method] + callback = http.HttpEngine._get_callback(self.url, self.method) # Process complete part if not is_final: - callback(part_headers, part_body, first=self.mp_first_part, last=False) + callback(self, (part_headers, part_body)) if rx.peek(len(self.mp_delimiter)) != self.mp_delimiter: self.on_client_error(tx, http.HttpEngine.MULTIPART_BOUNDARY_ERROR) return @@ -129,7 +129,8 @@ def _parse_complete_part_st(self, rx, tx): ): self.on_client_error(tx, http.HttpEngine.CONTENT_LENGTH_ERROR) return - dtype, data = callback(part_headers, part_body, first=self.mp_first_part, last=True) + self.mp_last_part = True + dtype, data = callback(self, (part_headers, part_body)) self.terminate(200, dtype.encode(http.HttpEngine.ASCII)) return self._generate_response(tx, data) @@ -145,6 +146,7 @@ def apply_patches(): def new_init(self, *args, **kwargs): orig_init(self, *args, **kwargs) self.mp_first_part = True + self.mp_last_part = False self.mp_delimiter = None self.mp_closing_delimiter = None diff --git a/src/pyrobusta/utils/config.py 
b/src/pyrobusta/utils/config.py index 331dd0b..ea8e017 100644 --- a/src/pyrobusta/utils/config.py +++ b/src/pyrobusta/utils/config.py @@ -4,7 +4,7 @@ Values can be encapsulated by single or double quotes. """ -PYROBUSTA_VERSION = "0.1.0" +PYROBUSTA_VERSION = "0.2.0" CONFIG_LOADED = False CONFIG_LOCATION = "pyrobusta.env" CONFIG_CACHE = [ @@ -36,7 +36,7 @@ def read_config(config=CONFIG_LOCATION): with open(config, encoding="utf-8") as conf: for line in conf.read().splitlines("\n"): key = line.split("=")[0].strip() - if key.startswith("#"): + if key.startswith("#") or not line.strip(): continue value = line.split("=")[1].strip().strip("'").strip('"') if key and value: diff --git a/tests/functional/test_http.py b/tests/functional/test_http.py index 76b084a..611a0d2 100644 --- a/tests/functional/test_http.py +++ b/tests/functional/test_http.py @@ -70,22 +70,22 @@ def response_generator(): @HttpEngine.route("/test/simple", "GET") -def simple_callback(headers, body): - if headers["accept"] == "text/plain": +def simple_callback(http_ctx, _): + if http_ctx.headers["accept"] == "text/plain": return "text/plain", "Test response\n" - elif headers["accept"] == "application/json": + elif http_ctx.headers["accept"] == "application/json": return "application/json", '{"response": "Test response"}' raise ValueError("Unhandled content-type") @HttpEngine.route("/test/multipart", "GET") -def multipart_callback(headers, body): - part_count = int(headers["x-part-count"]) +def multipart_callback(http_ctx, _): + part_count = int(http_ctx.headers["x-part-count"]) return "multipart/form-data", ("text/plain", multipart_response(part_count)) @HttpEngine.route("/test/busy", "POST") -def busy_callback(headers, body): +def busy_callback(*_): raise ServerBusyError() @@ -211,19 +211,19 @@ def test_registration(): test_assert( "simple endpoint registration", simple_callback, - HttpEngine.ENDPOINTS[b"/test/simple"][b"GET"], + HttpEngine._get_callback(b"/test/simple", b"GET"), ) test_assert( 
"multipart endpoint registration", multipart_callback, - HttpEngine.ENDPOINTS[b"/test/multipart"][b"GET"], + HttpEngine._get_callback(b"/test/multipart", b"GET"), ) test_assert( "busy endpoint registration", busy_callback, - HttpEngine.ENDPOINTS[b"/test/busy"][b"POST"], + HttpEngine._get_callback(b"/test/busy", b"POST"), ) diff --git a/tests/unit/test_http.py b/tests/unit/test_http.py index b4774cd..b0c8655 100644 --- a/tests/unit/test_http.py +++ b/tests/unit/test_http.py @@ -6,14 +6,14 @@ from .utils import load_module -class TestWebStateMachine(unittest.TestCase): +class TestWebStateMachineBase(unittest.TestCase): """ - Tests for the core functionality of the state machine. + Base class for state machine tests. """ @classmethod def setUpClass(cls): - cls.config = {"http_multipart": "False"} + cls.config = {} def setUp(self): # Create mock modules @@ -54,6 +54,16 @@ def side_effect(input_arg, *_, **__): self.mock_utils_config.get_config.side_effect = side_effect + +class TestWebStateMachine(TestWebStateMachineBase): + """ + Tests for the core functionality of the state machine. 
+ """ + + @classmethod + def setUpClass(cls): + cls.config = {"http_multipart": "False"} + def test_status_parsing_valid(self): request = b"GET /index.html HTTP/1.1\r\nContent-Length:10" @@ -216,8 +226,96 @@ def test_routing_head_method(self): ) self.assertEqual(self.tx.find(test_response), -1) + def test_simple_query_parameter(self): + request = b"GET /api/test?param HTTP/1.1\r\n" + + for i in range(len(request)): + self.rx.write(request[i : i + 1]) + self.engine.state(self.rx, self.tx) + + self.assertEqual(self.engine.query, "param") + + def test_pct_encoded_query_parameter(self): + def pct_encode(b): + out = [] + for c in b: + out.append(f"%{ord(c):02X}") + return "".join(out) + + unsafe_chars = ":/?#[]@!$&'()*+,;=% " + request = b"GET /api/test?safe_chars.%s HTTP/1.1\r\n" % pct_encode( + unsafe_chars + ).encode("ascii") + + for i in range(len(request)): + self.rx.write(request[i : i + 1]) + self.engine.state(self.rx, self.tx) + + self.assertEqual(self.engine.query, f"safe_chars.{unsafe_chars}") + + def test_single_url_encoded_query_parameter(self): + request = b"GET /api/test?param=value HTTP/1.1\r\n" + + for i in range(len(request)): + self.rx.write(request[i : i + 1]) + self.engine.state(self.rx, self.tx) + + self.assertEqual( + self.engine.get_url_encoded_query_param(self.engine.query, "param"), "value" + ) + + def test_multiple_url_encoded_query_parameter(self): + request = ( + b"GET /api/test?param1=value1&param2=value2&param3=value3 HTTP/1.1\r\n" + ) + + for i in range(len(request)): + self.rx.write(request[i : i + 1]) + self.engine.state(self.rx, self.tx) + + self.assertEqual( + self.engine.get_url_encoded_query_param(self.engine.query, "param1"), + "value1", + ) + self.assertEqual( + self.engine.get_url_encoded_query_param(self.engine.query, "param2"), + "value2", + ) + self.assertEqual( + self.engine.get_url_encoded_query_param(self.engine.query, "param3"), + "value3", + ) + + def test_empty_or_missing_url_encoded_query_parameter(self): + request = b"GET 
/api/test?param1=&param2= HTTP/1.1\r\n" -class TestMultipartStateMachine(TestWebStateMachine): + for i in range(len(request)): + self.rx.write(request[i : i + 1]) + self.engine.state(self.rx, self.tx) + + self.assertEqual( + self.engine.get_url_encoded_query_param(self.engine.query, "param1"), + "", + ) + self.assertEqual( + self.engine.get_url_encoded_query_param(self.engine.query, "param2"), + "", + ) + self.assertEqual( + self.engine.get_url_encoded_query_param( + self.engine.query, "param3", "default" + ), + "default", + ) + + with self.assertRaises(KeyError): + self.engine.get_url_encoded_query_param(self.engine.query, "param3") + + +class TestMultipartStateMachine(TestWebStateMachineBase): + """ + Tests for multipart handling + """ @classmethod def setUpClass(cls): @@ -307,19 +405,23 @@ def test_multipart_receiver_complete_part(self): self.assertEqual(self.engine.state, self.engine._parse_complete_part_st) self.assertEqual(self.rx.peek(), body_part) + self.assertEqual(self.engine.mp_first_part, True) self.engine.state(self.rx, self.tx) self.assertEqual(self.engine.state, self.engine._parse_boundary_st) test_callback.assert_called_once_with( - { - "content-disposition": 'form-data;name="file-chunk";filename="upload.txt"', - "content-type": "text/plain", - }, - b"Upload content", - first=True, - last=False, + self.engine, + ( + { + "content-disposition": 'form-data;name="file-chunk";filename="upload.txt"', + "content-type": "text/plain", + }, + b"Upload content", + ), ) + self.assertEqual(self.engine.mp_first_part, False) + self.assertEqual(self.engine.mp_last_part, False) def test_multipart_receiver_last_part(self): self.engine.state = self.engine._parse_boundary_st @@ -354,14 +456,17 @@ def test_multipart_receiver_last_part(self): self.assertEqual(self.engine.state, None) self.assertEqual(self.engine.status_code, 200) test_callback.assert_called_once_with( - { - "content-disposition": 'form-data;name="file-chunk";filename="upload.txt"', - "content-type": 
"text/plain", - }, - b"Upload content", - first=True, - last=True, + self.engine, + ( + { + "content-disposition": 'form-data;name="file-chunk";filename="upload.txt"', + "content-type": "text/plain", + }, + b"Upload content", + ), ) + self.assertEqual(self.engine.mp_first_part, True) + self.assertEqual(self.engine.mp_last_part, True) if __name__ == "__main__":