diff --git a/Makefile b/Makefile index 4d1b045..4486799 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -PYROBUSTA_VERSION := 0.2.0 +PYROBUSTA_VERSION := 0.3.0 DEVICE ?= u0 SRC_DIR := src @@ -66,14 +66,16 @@ $(BUILD_DIR)/%.py: $(SRC_DIR)/%.py .PHONY: deploy deploy: @echo "Uploading build/$(PKG) to device $(DEVICE)" + @mpremote $(DEVICE) mkdir :/lib || true @find $(BUILD_DIR)/$(PKG) | while read source; do \ rel=$${source#$(BUILD_DIR)/}; \ + remote="/lib/$${rel}"; \ if [ -d "$$source" ]; then \ - mpremote $(DEVICE) mkdir "$$rel" || true; \ + mpremote $(DEVICE) mkdir "$$remote" || true; \ elif [ -f "$$source" ]; then \ - echo "Uploading $$rel"; \ - mpremote $(DEVICE) rm "$$rel" || true; \ - mpremote $(DEVICE) cp "$$source" ":$$rel"; \ + echo "Uploading $$remote"; \ + mpremote $(DEVICE) rm ":$$remote" || true; \ + mpremote $(DEVICE) cp "$$source" ":$$remote"; \ fi; \ sleep 1; \ done @@ -121,10 +123,10 @@ publish: stage-example: @echo "Preparing unix runtime in $(RUNTIME_DIR)" @rm -rf $(RUNTIME_DIR) - @mkdir -p $(RUNTIME_DIR) + @mkdir -p $(RUNTIME_DIR)/lib @echo "Copying built package" - @cp -r build/pyrobusta $(RUNTIME_DIR)/ + @cp -r build/pyrobusta $(RUNTIME_DIR)/lib @echo "Copying example files" @cp $(EXAMPLE_DIR)/app.py $(RUNTIME_DIR)/ @@ -142,7 +144,7 @@ stage-example: .PHONY: run-unix run-unix: stage-example @echo "Running example with unix micropython" - cd $(RUNTIME_DIR) && ../$(MICROPYTHON) app.py + cd $(RUNTIME_DIR) && MICROPYPATH=":.frozen:lib" ../$(MICROPYTHON) app.py # ----------------------------- # Deploy example app @@ -211,9 +213,9 @@ static-checkers: pylint black .PHONY: stage-test stage-test: @rm -rf $(TEST_RUNTIME) - @mkdir -p $(TEST_RUNTIME) + @mkdir -p $(TEST_RUNTIME)/lib - @cp -r build/pyrobusta $(TEST_RUNTIME)/ + @cp -r build/pyrobusta $(TEST_RUNTIME)/lib @cp tests/functional/*.py $(TEST_RUNTIME)/ # ----------------------------- @@ -225,7 +227,7 @@ test-unix: stage-test tls-cert @cd $(TEST_RUNTIME); \ for test in test_*.py; do \ echo "Running $$test"; \ - ../$(MICROPYTHON) $$(basename $$test) || exit 1; \ + MICROPYPATH=":.frozen:lib" ../$(MICROPYTHON) $$(basename $$test) || exit 1; \ done # ----------------------------- diff --git a/README.md b/README.md index 1156afc..0a14064 100644 --- a/README.md +++ b/README.md @@ -6,6 +6,7 @@ A lightweight HTTP server library for MicroPython designed for constrained embed - Routing decorators - Fixed-size, configurable request/response buffers - Multipart request and response handling +- Chunked transfer decoding for streamed request bodies - Bounded-copy memory footprint - Finite-state-machine parser with linear sliding buffer - Robust byte-stream handling diff --git a/dist/pyrobusta/bindings/socket_http.mpy b/dist/pyrobusta/bindings/socket_http.mpy index 95d8afa..b56ceb7 100644 Binary files a/dist/pyrobusta/bindings/socket_http.mpy and b/dist/pyrobusta/bindings/socket_http.mpy differ diff --git a/dist/pyrobusta/con/wifi.mpy b/dist/pyrobusta/con/wifi.mpy index 881acff..8f5f2bc 100644 Binary files a/dist/pyrobusta/con/wifi.mpy and b/dist/pyrobusta/con/wifi.mpy differ diff --git a/dist/pyrobusta/protocol/http.mpy b/dist/pyrobusta/protocol/http.mpy index 49e8176..d2a04ba 100644 Binary files a/dist/pyrobusta/protocol/http.mpy and b/dist/pyrobusta/protocol/http.mpy differ diff --git a/dist/pyrobusta/protocol/http_multipart.mpy b/dist/pyrobusta/protocol/http_multipart.mpy index f6fac86..03fdbed 100644 Binary files a/dist/pyrobusta/protocol/http_multipart.mpy and b/dist/pyrobusta/protocol/http_multipart.mpy differ diff --git a/dist/pyrobusta/server/http_server.mpy b/dist/pyrobusta/server/http_server.mpy index 4dc378b..467a75e 100644 Binary files a/dist/pyrobusta/server/http_server.mpy and b/dist/pyrobusta/server/http_server.mpy differ diff --git a/dist/pyrobusta/stream/buffer.mpy b/dist/pyrobusta/stream/buffer.mpy index c17d5ac..1117c23 100644 Binary files a/dist/pyrobusta/stream/buffer.mpy and b/dist/pyrobusta/stream/buffer.mpy differ diff --git a/dist/pyrobusta/transport/socket.mpy b/dist/pyrobusta/transport/socket.mpy index 18a1dfd..3389bfc 100644 Binary files a/dist/pyrobusta/transport/socket.mpy and b/dist/pyrobusta/transport/socket.mpy differ diff --git a/dist/pyrobusta/utils/config.mpy b/dist/pyrobusta/utils/config.mpy index 9f08f56..5e9b55a 100644 Binary files a/dist/pyrobusta/utils/config.mpy and b/dist/pyrobusta/utils/config.mpy differ diff --git a/dist/pyrobusta/utils/helpers.mpy b/dist/pyrobusta/utils/helpers.mpy new file mode 100644 index 0000000..19f006a Binary files /dev/null and b/dist/pyrobusta/utils/helpers.mpy differ diff --git a/docs/configuration.md b/docs/configuration.md index 5b595a7..8ab2b12 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -9,7 +9,7 @@ to upload it to the root directory of the target device. | wifi_password | Password of the Wi-Fi network. When empty, Wi-Fi is not initalized by the built-in wifi.py module. | None | | http_multipart | Enable multipart HTTP requests/responses. | "False" | | http_mem_cap | Max memory cap (% × 0.01) of usable heap for HTTP request/response stream buffers. | 0.1 | -| http_served_paths | Space delimited list of filesystem paths allowed to be served through HTTP. | "pyrobusta lib" | +| http_served_paths | Space delimited list of filesystem paths allowed to be served through HTTP. | "/lib/pyrobusta" | | socket_max_con | Max number of socket connections of any enabled application server. | 2 | | tls | Enable/disable TLS. When turned on, cert.der/key.der must be installed at the root. | "False" | | log_level | Can be one of: warning, info, debug. | "warning" | diff --git a/package.json b/package.json index 30a9089..ddfce0e 100644 --- a/package.json +++ b/package.json @@ -1,5 +1,5 @@ { - "version": "0.2.0", + "version": "0.3.0", "urls": [ [ "pyrobusta/transport/socket.mpy", @@ -9,6 +9,10 @@ "pyrobusta/transport/__init__.py", "github:szeka9/PyRobusta/dist/pyrobusta/transport/__init__.py" ], + [ + "pyrobusta/utils/helpers.mpy", + "github:szeka9/PyRobusta/dist/pyrobusta/utils/helpers.mpy" + ], [ "pyrobusta/utils/__init__.py", "github:szeka9/PyRobusta/dist/pyrobusta/utils/__init__.py" diff --git a/src/pyrobusta/bindings/socket_http.py b/src/pyrobusta/bindings/socket_http.py index 955ba3b..799fea7 100644 --- a/src/pyrobusta/bindings/socket_http.py +++ b/src/pyrobusta/bindings/socket_http.py @@ -54,10 +54,7 @@ def init_pools(max_sockets): ) if is_low_memory: logging.warning( - ( - "[SocketHttp.init_pools] low-memory mode with reduced buffer size, " - "decrease max_clients to use larger buffers" - ) + __name__ + ".init_pools: low-memory mode with reduced buffer size" ) recv_size = ( SocketHttp.RECV_BUF_MIN_BYTES @@ -73,13 +70,13 @@ def init_pools(max_sockets): if usable < per_conn: raise MemoryError( ( - f"Insufficient memory for webserver: {mem_available // 1024} KB " + f"Insufficient memory: {mem_available // 1024} KB " f"at {SocketHttp.MEM_CAP*100}% cap, " f"at least {per_conn // 1024} KB required" ) ) con_limit = min(usable // per_conn, con_limit) - logging.info((f"[SocketHttp.init_pools] {con_limit} connection(s) allowed")) + logging.info((__name__ + f".init_pools: {con_limit} connection(s) allowed")) SocketHttp.RECV_POOL = MemoryPool(recv_size, con_limit, wrapper=SlidingBuffer) SocketHttp.SEND_POOL = MemoryPool(send_size, con_limit, wrapper=SlidingBuffer) @@ -113,7 +110,7 @@ async def run(self): await self._run_state_machine() await sleep_ms(SocketHttp.STATE_MACHINE_SLEEP_MS) except Exception as e: # pylint: disable=W0718 - logging.warning(f"[SocketHttp] error in run_web: {e}") + logging.warning(__name__ + f": error in run_web: {e}") finally: if self._send_buf: self._send_buf.consume() @@ -126,7 +123,7 @@ async def run(self): async def _reserve_buffers(self): if SocketHttp.SEND_POOL is None or SocketHttp.RECV_POOL is None: - raise RuntimeError("Buffer pools are uninitialized") + raise RuntimeError("Pools are ninitialized") while not self._recv_buf or not self._send_buf: if not self._recv_buf: @@ -158,7 +155,7 @@ async def _run_state_machine(self): await self._flush_response() return except Exception as e: # pylint: disable=W0718 - logging.warning(f"[SocketHttp] error in _run_state_machine: {e}") + logging.warning(__name__ + f"._run_state_machine: {e}") self._engine.on_failure(self._send_buf, str(e).encode("ascii")) await self._flush_response() return @@ -188,7 +185,7 @@ async def _read_to_buf(self): await self._flush_response() return 0 self._recv_buf.write(request) - logging.debug(f"[SocketHttp._read_to_buf] read new message chunk: {request}") + logging.debug(__name__ + f"._read_to_buf: [{request}]") return len(request) async def _response_handler(self, resp_handler): diff --git a/src/pyrobusta/con/wifi.py b/src/pyrobusta/con/wifi.py index 5a2f170..8755f44 100644 --- a/src/pyrobusta/con/wifi.py +++ b/src/pyrobusta/con/wifi.py @@ -14,7 +14,7 @@ def initialize(): ssid = get_config("wifi_ssid") password = get_config("wifi_password") if not ssid or not password: - logging.warning("[Wi-Fi] Missing SSID/password, skip Wi-Fi initialization") + logging.warning(__name__ + ": missing SSID/password, skip initialization") return sta_if = WLAN(STA_IF) @@ -22,9 +22,9 @@ def initialize(): nets = sta_if.scan() for net in nets: if net[0].decode() == get_config("wifi_ssid"): - logging.info(f"[Wi-Fi] Network {net[0]} found!") + logging.info(__name__ + f": network {net[0]} found!") sta_if.connect(net[0], get_config("wifi_password")) - logging.info("[Wi-Fi] WLAN connection succeeded!") + logging.info(__name__ + f": connected, available at {sta_if.ifconfig()[0]}") break diff --git a/src/pyrobusta/protocol/http.py b/src/pyrobusta/protocol/http.py index f72cb1a..0de2a54 100644 --- a/src/pyrobusta/protocol/http.py +++ b/src/pyrobusta/protocol/http.py @@ -8,6 +8,7 @@ from os import stat from ..utils.config import get_config +from ..utils.helpers import normalize_path class HeaderParsingError(ValueError): @@ -33,18 +34,19 @@ class HttpEngine: __slots__ = ( "state", "status_code", - "response_headers", + "resp_headers", "version", "headers", "method", "url", "query", - "content_length_cnt", + "content_len_cnt", + "recv_chunk_size", "mp_boundary", - "mp_first_part", - "mp_last_part", + "mp_is_first", + "mp_is_last", "mp_delimiter", - "mp_closing_delimiter", + "mp_last_delimiter", ) ENDPOINTS = [] # (endpoint, callback, method) @@ -55,6 +57,8 @@ class HttpEngine: b"204 No Content", 400, b"400 Bad Request", + 403, + b"403 Forbidden", 404, b"404 Not Found", 405, @@ -109,7 +113,7 @@ class HttpEngine: MULTIPART_BOUNDARY = b"pyrobusta-boundary" - CONTENT_LENGTH_ERROR = b"Content-Length mismatch" + CONTENT_LENGTH_ERROR = b"content length mismatch" HEADER_ERROR = b"Invalid headers" MULTIPART_BOUNDARY_ERROR = b"Invalid multipart boundary" BAD_REQUEST_ERROR = b"Bad request" @@ -118,7 +122,7 @@ def __init__(self): # [State machine] self.state = self._parse_request_line_st self.status_code = None - self.response_headers = [] + self.resp_headers = [] # [Recived request] self.version = None @@ -126,7 +130,8 @@ def __init__(self): self.method = None self.url = None self.query = None - self.content_length_cnt = 0 + self.content_len_cnt = 0 + self.recv_chunk_size = 0 # [Multipart state] self.mp_boundary = None @@ -206,6 +211,21 @@ def get_url_encoded_query_param(query: str, key: str, default: str = None): raise KeyError() return default + @classmethod + def is_norm_path_served(cls, path: str): + """ + Returns true if a normalized path is configured to be served + """ + served_paths = set(get_config("http_served_paths").split()) + parts = path.split("/") + for i, _ in enumerate(parts): + current_path = "/".join(parts[: i + 1]) + if not current_path: + current_path = "/" + if current_path in served_paths: + return True + return False + @staticmethod def _lookup(tuple_, key): idx = tuple_.index(key) @@ -290,13 +310,13 @@ def _parse_body_part(cls, part: memoryview) -> tuple[dict, bytes]: def _set_response_header(self, key, value): if ( - key in self.response_headers - and (index := self.response_headers.index(key) % 2) == 0 + key in self.resp_headers + and (index := self.resp_headers.index(key) % 2) == 0 ): - self.response_headers[index + 1] = value + self.resp_headers[index + 1] = value else: - self.response_headers.append(key) - self.response_headers.append(value) + self.resp_headers.append(key) + self.resp_headers.append(value) def terminate(self, status_code: int, content_type: bytes = b"text/plain"): """ @@ -323,9 +343,9 @@ def _write_response_head(self, tx, content_length: int = 0): if content_length is not None: tx.write(b"\r\n") tx.write(b"content-length: %s" % str(content_length).encode(self.ASCII)) - for i in range(0, len(self.response_headers), 2): - key = self.response_headers[i] - value = self.response_headers[i + 1] + for i in range(0, len(self.resp_headers), 2): + key = self.resp_headers[i] + value = self.resp_headers[i + 1] tx.write(b"\r\n") tx.write(key) tx.write(b": ") @@ -366,6 +386,11 @@ def on_client_error(self, tx, info: bytes): self._write_response_head(tx, len(response)) tx.write(response) + def on_forbidden(self, tx): + """Terminate state machine and write 403 response""" + self.terminate(403) + self._write_response_head(tx) + def on_missing_resource(self, tx): """Terminate state machine and write 404 response""" self.terminate(404) @@ -443,17 +468,22 @@ def _parse_headers_st(self, rx, tx): return try: self.headers = self._parse_headers(rx.peek(blank_idx)) + if self.version == b"HTTP/1.1" and "host" not in self.headers: + raise HeaderParsingError() except HeaderParsingError: self.on_client_error(tx, self.HEADER_ERROR) return rx.consume(blank_idx + 4) self.state = self._route_request_st + def _is_chunked(self): + return self.headers.get("transfer-encoding") == "chunked" + def _has_payload(self): return ( self.CONTENT_LENGTH in self.headers and self.headers[self.CONTENT_LENGTH] > 0 - ) + ) or self._is_chunked() def _route_request_st(self, _, tx): """ @@ -481,8 +511,10 @@ def _route_request_st(self, _, tx): if mp_boundary := self._is_multipart(self.headers): self.mp_boundary = mp_boundary.encode(self.ASCII) self.state = self._start_multipart_parser_st + elif self._is_chunked(): + self.state = self._recv_chunked_size_st else: - self.state = self._recv_payload + self.state = self._recv_payload_st else: self.state = self._app_endpoint_st return @@ -501,7 +533,23 @@ def _route_request_st(self, _, tx): return self.on_missing_resource(tx) - def _recv_payload(self, rx, tx): + def _recv_chunked_size_st(self, rx, _): + if (blank_idx := rx.find(b"\r\n")) == -1: + return + self.recv_chunk_size = int(bytes(rx.peek(blank_idx)), 16) + rx.consume(blank_idx + 2) + self.state = self._recv_chunk_st + + def _recv_chunk_st(self, rx, tx): + if self.recv_chunk_size + 2 > rx.size(): + return + if self.recv_chunk_size + 2 <= rx.size(): + if rx.peek()[self.recv_chunk_size : self.recv_chunk_size + 2] != b"\r\n": + self.on_client_error(tx, self.CONTENT_LENGTH_ERROR) + return + self.state = self._app_endpoint_st + + def _recv_payload_st(self, rx, tx): if self.headers[self.CONTENT_LENGTH] > rx.size(): return if self.headers[self.CONTENT_LENGTH] < rx.size(): @@ -514,8 +562,16 @@ def _app_endpoint_st(self, rx, tx): method = self.GET if self.method == self.HEAD else self.method callback = self._get_callback(self.url, method) if self._has_payload(): - self.state = None - dtype, data = callback(self, bytes(rx.peek())) + if self._is_chunked(): + if self.recv_chunk_size: + callback(self, bytes(rx.peek(self.recv_chunk_size))) + rx.consume(self.recv_chunk_size + 2) + self.state = self._recv_chunked_size_st + return + dtype, data = callback(self, bytes(rx.peek(self.recv_chunk_size))) + rx.consume(self.recv_chunk_size + 2) + else: + dtype, data = callback(self, bytes(rx.peek())) dtype = dtype.encode(self.ASCII) else: if not callable(callback): @@ -527,9 +583,7 @@ def _app_endpoint_st(self, rx, tx): dtype, data = callback(self, b"") dtype = dtype.encode(self.ASCII) self._set_response_header(b"content-type", dtype) - if dtype == b"image/jpeg": - self.terminate(200, dtype) - return self._generate_response(tx, data) + if dtype in (b"multipart/x-mixed-replace", b"multipart/form-data"): part_content_type = data[0] callback = data[1] @@ -552,22 +606,17 @@ def _app_endpoint_st(self, rx, tx): def _send_file_st(self, _, tx, web_resource: bytes): """State for returning a static resource""" - # Normalize path - parts = [] - for p in web_resource.split(b"/"): - if p in (b".", b""): - continue - if p == b"..": - if parts: - parts.pop() - else: - parts.append(p) - if parts[0].decode(self.ASCII) not in get_config("http_served_paths").split(): - self.on_missing_resource(tx) - return extension = web_resource.rsplit(b".", 1)[-1] - norm_path = b"/".join(parts) - + norm_path = normalize_path(web_resource.decode(self.ASCII)) + is_path_served = self.is_norm_path_served(norm_path) + if not is_path_served: + try: + stat(norm_path) + self.on_forbidden(tx) + return + except OSError: + self.on_missing_resource(tx) + return try: content_type = self._lookup(self.CONTENT_TYPES, extension) except ValueError: diff --git a/src/pyrobusta/protocol/http_multipart.py b/src/pyrobusta/protocol/http_multipart.py index a8e6e96..dc95393 100644 --- a/src/pyrobusta/protocol/http_multipart.py +++ b/src/pyrobusta/protocol/http_multipart.py @@ -14,9 +14,9 @@ def add_method(cls, func, method_type="instance"): """ if method_type == "instance": setattr(cls, func.__name__, func) - elif method_type == "staticmethod": + elif method_type == "static": setattr(cls, func.__name__, staticmethod(func)) - elif method_type == "classmethod": + elif method_type == "class": setattr(cls, func.__name__, classmethod(func)) else: raise ValueError("Invalid type") @@ -69,12 +69,12 @@ def _start_multipart_parser_st(self, rx, tx): if (start_delimiter := rx.find(b"\r\n")) == -1: return self.mp_delimiter = b"--" + self.mp_boundary + b"\r\n" - self.mp_closing_delimiter = b"--" + self.mp_boundary + b"--" + self.mp_last_delimiter = b"--" + self.mp_boundary + b"--" if rx.peek(start_delimiter + 2) != self.mp_delimiter: self.on_client_error(tx, http.HttpEngine.MULTIPART_BOUNDARY_ERROR) return rx.consume(start_delimiter + 2) - self.content_length_cnt += start_delimiter + 2 + self.content_len_cnt += start_delimiter + 2 self.state = self._parse_boundary_st @@ -82,7 +82,7 @@ def _parse_boundary_st(self, rx, _): """State for parsing multipart boundary delimiter""" if ( rx.find(b"\r\n" + self.mp_delimiter) == -1 - and rx.find(b"\r\n" + self.mp_closing_delimiter) == -1 + and rx.find(b"\r\n" + self.mp_last_delimiter) == -1 ): return self.state = self._parse_complete_part_st @@ -96,10 +96,10 @@ def _parse_complete_part_st(self, rx, tx): next_delimiter = rx.find(b"\r\n--" + self.mp_boundary) part = rx.peek(next_delimiter) rx.consume(next_delimiter + 2) # Consume leading CRLF - self.content_length_cnt += next_delimiter + 2 - is_final = rx.peek(len(self.mp_closing_delimiter)) == self.mp_closing_delimiter + self.content_len_cnt += next_delimiter + 2 + is_final = rx.peek(len(self.mp_last_delimiter)) == self.mp_last_delimiter # Validate part and content-length - if self.headers[http.HttpEngine.CONTENT_LENGTH] < self.content_length_cnt: + if self.headers[http.HttpEngine.CONTENT_LENGTH] < self.content_len_cnt: self.on_client_error(tx, http.HttpEngine.CONTENT_LENGTH_ERROR) return try: @@ -115,21 +115,21 @@ def _parse_complete_part_st(self, rx, tx): self.on_client_error(tx, http.HttpEngine.MULTIPART_BOUNDARY_ERROR) return rx.consume(len(self.mp_delimiter)) - self.content_length_cnt += len(self.mp_delimiter) - self.mp_first_part = False + self.content_len_cnt += len(self.mp_delimiter) + self.mp_is_first = False self.state = self._parse_boundary_st return # Process last part - rx.consume(len(self.mp_closing_delimiter)) - self.content_length_cnt += len(self.mp_closing_delimiter) + rx.consume(len(self.mp_last_delimiter)) + self.content_len_cnt += len(self.mp_last_delimiter) if ( - self.headers[http.HttpEngine.CONTENT_LENGTH] != self.content_length_cnt - and self.content_length_cnt + rx.size() + self.headers[http.HttpEngine.CONTENT_LENGTH] != self.content_len_cnt + and self.content_len_cnt + rx.size() != self.headers[http.HttpEngine.CONTENT_LENGTH] ): self.on_client_error(tx, http.HttpEngine.CONTENT_LENGTH_ERROR) return - self.mp_last_part = True + self.mp_is_last = True dtype, data = callback(self, (part_headers, part_body)) self.terminate(200, dtype.encode(http.HttpEngine.ASCII)) return self._generate_response(tx, data) @@ -145,14 +145,14 @@ def apply_patches(): def new_init(self, *args, **kwargs): orig_init(self, *args, **kwargs) - self.mp_first_part = True - self.mp_last_part = False + self.mp_is_first = True + self.mp_is_last = False self.mp_delimiter = None - self.mp_closing_delimiter = None + self.mp_last_delimiter = None cls.__init__ = new_init - add_method(http.HttpEngine, _multipart_wrapper_factory, "staticmethod") + add_method(http.HttpEngine, _multipart_wrapper_factory, "static") add_method(http.HttpEngine, _start_multipart_parser_st) add_method(http.HttpEngine, _parse_boundary_st) add_method(http.HttpEngine, _parse_complete_part_st) diff --git a/src/pyrobusta/server/http_server.py b/src/pyrobusta/server/http_server.py index 811b3cd..e0cecff 100644 --- a/src/pyrobusta/server/http_server.py +++ b/src/pyrobusta/server/http_server.py @@ -2,9 +2,8 @@ Socket server application """ -from asyncio import sleep_ms, start_server, run # pylint: disable=E1101 import gc -import ssl +from asyncio import sleep_ms, start_server, run # pylint: disable=E1101 from time import ticks_ms, ticks_diff from ..protocol import http @@ -38,7 +37,7 @@ async def drop_client(cls, socket): """Remove socket from active list""" if socket not in cls.ACTIVE_SOCKETS: return - logging.debug(f"[HttpServer] {socket.id} dropped") + logging.debug(__name__ + f": {socket.id} dropped") await socket.close() cls.ACTIVE_SOCKETS.remove(socket) del socket @@ -72,8 +71,8 @@ async def can_handle_new_socket(self): if not socket.connected or socket_inactive > self._timeout: logging.debug( ( - f"[HttpSever] evicted {socket.id} " - f"timeout:{self._timeout - socket_inactive}s" + __name__ + f": evicted {socket.id} " + f"timeout: {self._timeout - socket_inactive}s" ) ) await self.drop_client(socket) @@ -87,13 +86,13 @@ async def accept_http(self, reader, writer): - creates SocketHttp object """ if not await self.can_handle_new_socket(): - logging.debug("[HttpSever] cannot accept new client") + logging.debug(__name__ + ": cannot accept new client") writer.close() await writer.wait_closed() return new_client = SocketHttp(reader, writer) - logging.debug(f"[HttpSever] new client: {new_client.id}") + logging.debug(__name__ + f": accept {new_client.id}") self.ACTIVE_SOCKETS.append(new_client) await new_client.run() @@ -108,6 +107,8 @@ async def run_server(self): SocketHttp.init_pools(self._max_sockets) ssl_ctx = None if get_config("tls").lower() == "true": + import ssl + ssl_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) ssl_ctx.load_cert_chain(self.TLS_CERT_PATH, self.TLS_KEY_PATH) self._server = await start_server( @@ -117,15 +118,15 @@ async def run_server(self): backlog=self._max_sockets, ssl=ssl_ctx, ) - logging.info("[HttpSever] Started") + logging.info(__name__ + ": started") except MemoryError as e: - logging.warning(f"[HttpSever] Memory allocation failed: {e}") + logging.warning(__name__ + f": allocation failed - {e}") async def terminate(self): """ Terminate HTTP server and close sockets """ - logging.info("[HttpSever] Terminated") + logging.info(__name__ + ": terminated") while self.ACTIVE_SOCKETS: await self.drop_client(self.ACTIVE_SOCKETS[0]) if self._server: diff --git a/src/pyrobusta/stream/buffer.py b/src/pyrobusta/stream/buffer.py index 8ef43e9..812a1f0 100644 --- a/src/pyrobusta/stream/buffer.py +++ b/src/pyrobusta/stream/buffer.py @@ -151,17 +151,17 @@ def prepare(self, n: int): otherwise attempt to compact the buffer """ if n > self.capacity: - raise ValueError("Requested size exceeds capacity") + raise ValueError("Capacity exceeded") if n > self.writable(): self._compact() if n > self.writable(): - raise ValueError("Buffer full") + raise ValueError("Capacity exceeded") def commit(self, n): """Increase the window size by n bytes by incrementing the 'end' index""" if self._end + n > self.capacity: - raise ValueError("Buffer full") + raise ValueError("Capacity exceeded") self._end += n def find(self, term: bytes) -> int: diff --git a/src/pyrobusta/transport/socket.py b/src/pyrobusta/transport/socket.py index e4d7a3a..2d8dc63 100644 --- a/src/pyrobusta/transport/socket.py +++ b/src/pyrobusta/transport/socket.py @@ -36,7 +36,7 @@ async def read(self, read_bytes, decoding="utf8", timeout_seconds=0): - read_error is set to true upon timeout or other exception - data holds bytes or decoded string read from the socket """ - logging.debug(f"[SocketBase] read from {self.id}") + logging.debug(__name__ + f": read from {self.id}") self.last_event = ticks_ms() if timeout_seconds: request = await asyncio.wait_for( @@ -52,10 +52,10 @@ async def close(self): """ Async socket close method """ - logging.debug(f"[SocketBase] close connection: {self.id}") + logging.debug(__name__ + f": close connection: {self.id}") try: self.writer.close() await self.writer.wait_closed() except OSError as e: - logging.warning(f"[SocketBase] Error while closing {self.id}: {e}") + logging.warning(__name__ + f": error while closing {self.id}: {e}") self.connected = False diff --git a/src/pyrobusta/utils/config.py b/src/pyrobusta/utils/config.py index ea8e017..35dc2ea 100644 --- a/src/pyrobusta/utils/config.py +++ b/src/pyrobusta/utils/config.py @@ -4,7 +4,9 @@ Values can be encapsulated by single or double quotes. """ -PYROBUSTA_VERSION = "0.2.0" +from .helpers import normalize_path + +PYROBUSTA_VERSION = "0.3.0" CONFIG_LOADED = False CONFIG_LOCATION = "pyrobusta.env" CONFIG_CACHE = [ @@ -17,7 +19,7 @@ "http_mem_cap", 0.1, "http_served_paths", - "pyrobusta lib package.json", + "/lib/pyrobusta", "socket_max_con", 2, "tls", @@ -27,6 +29,15 @@ ] +def normalize(key, value): + """ + Normalize a configuration value depending on the key. + """ + if key == "http_served_paths": + return " ".join([normalize_path(p) for p in value.split()]) + return value + + def read_config(config=CONFIG_LOCATION): """ Read configuration from a file and update CONFIG_CACHE. @@ -34,12 +45,14 @@ def read_config(config=CONFIG_LOCATION): """ try: with open(config, encoding="utf-8") as conf: - for line in conf.read().splitlines("\n"): + for line in conf: + line = line.rstrip("\r\n") key = line.split("=")[0].strip() if key.startswith("#") or not line.strip(): continue value = line.split("=")[1].strip().strip("'").strip('"') if key and value: + value = normalize(key, value) if ( key in CONFIG_CACHE and (conf_idx := CONFIG_CACHE.index(key)) % 2 == 0 diff --git a/src/pyrobusta/utils/helpers.py b/src/pyrobusta/utils/helpers.py new file mode 100644 index 0000000..cc0d9d3 --- /dev/null +++ b/src/pyrobusta/utils/helpers.py @@ -0,0 +1,27 @@ +""" +Helper methods +""" + +from os import getcwd + + +def normalize_path(path: str): + """Normalize a path string to resolve file and directory paths""" + if not path: + return "" + parts = [] + for p in path.split("/"): + if p in (".", ""): + continue + if p == "..": + if parts: + parts.pop() + else: + parts.append(p) + normalized = "/".join(parts) + cwd = getcwd() + if normalized: + if cwd.endswith("/"): + return cwd + normalized + return cwd + "/" + normalized + return cwd diff --git a/tests/.pylintrc b/tests/.pylintrc index 7bff1de..1824775 100644 --- a/tests/.pylintrc +++ b/tests/.pylintrc @@ -2,4 +2,5 @@ disable=W0212, C0114, C0115, - C0116 + C0116, + R0904 diff --git a/tests/functional/test_http.py b/tests/functional/test_http.py index 611a0d2..231f0c4 100644 --- a/tests/functional/test_http.py +++ b/tests/functional/test_http.py @@ -1,5 +1,8 @@ import asyncio import ssl +import json + +from os import getcwd, mkdir from pyrobusta.server import http_server from pyrobusta.protocol import http_multipart @@ -9,6 +12,7 @@ ServerBusyError, ) from pyrobusta.utils import config +from pyrobusta.utils.helpers import normalize_path ################################################# # Test helpers @@ -89,13 +93,29 @@ def busy_callback(*_): raise ServerBusyError() -async def test_simple_response(tls_enabled): - setup_config(multipart=False, tls_enabled=tls_enabled) +def create_chunked_app_endpoint(endpoint): + recv_chunks = [] - # start server as background task + @HttpEngine.route(endpoint, "POST") + def chunked_callback(http_ctx, chunk): + if not chunk: # Received terminating chunk + return "application/json", recv_chunks + recv_chunks.append(chunk.decode("utf8")) + + +async def start_server(): + """ + Start an HTTP server as a background task + """ server = http_server.HttpServer() server_task = asyncio.create_task(server.run_server()) await asyncio.sleep_ms(100) + return server, server_task + + +async def test_simple_response(tls_enabled): + setup_config(multipart=False, tls_enabled=tls_enabled) + server, server_task = await start_server() # Test: text/plain plain_response = await send_request( @@ -139,11 +159,7 @@ async def test_simple_response(tls_enabled): async def test_multipart_response(tls_enabled): setup_config(multipart=True, tls_enabled=tls_enabled) - - # start server as background task - server = http_server.HttpServer() - server_task = asyncio.create_task(server.run_server()) - await asyncio.sleep_ms(100) + server, server_task = await start_server() # Test: 1 part plain_response = await send_request( @@ -175,12 +191,8 @@ async def test_multipart_response(tls_enabled): async def test_server_busy(): setup_config() + server, server_task = await start_server() - server = http_server.HttpServer() - server_task = asyncio.create_task(server.run_server()) - await asyncio.sleep_ms(100) - - # Test: 1 part plain_response = await send_request( b"POST /test/busy HTTP/1.1\r\n" b"Host: localhost\r\n\r\n" ) @@ -194,16 +206,90 @@ async def test_server_busy(): await server.terminate() +async def test_chunked_transfer_encoding(): + setup_config() + create_chunked_app_endpoint("/test/chunked") + server, server_task = await start_server() + + json_response = await send_request( + ( + b"POST /test/chunked HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Transfer-Encoding: chunked\r\n\r\n" + b"14\r\nchunking\r\ntest\r\ncase\r\n" + b"E\r\nchunking\r\ntest\r\n" + b"8\r\nchunking\r\n" + b"0\r\n\r\n" + ) + ) + response_body = json.loads(json_response.split(b"\r\n\r\n")[1]) + test_assert( + f"chunked transfer encoding - all chunks are received", + response_body, + ["chunking\r\ntest\r\ncase", "chunking\r\ntest", "chunking"], + ) + + server_task.cancel() + await server.terminate() + + +async def test_fs_access_control(): + setup_config(served_paths="/www") + server, server_task = await start_server() + + # Index page under /www -> accepted + workdir = normalize_path("/www") + index_html = normalize_path("/www/index.html") + mkdir(workdir) + with open(index_html, "w") as f: + f.write("PyRobusta Home") + + # Index page under / -> rejected + index_html = normalize_path("/index.html") + with open(index_html, "w") as f: + f.write("PyRobusta Home") + + # Case #1: /www/index.html + response = await send_request( + (b"GET /www/index.html HTTP/1.1\r\n" b"Host: localhost\r\n\r\n") + ) + + response_body = response.split(b"\r\n\r\n")[1] + test_assert( + f"test FS access control - index page loaded", + response_body, + b"PyRobusta Home", + ) + + # Case #2: /index.html + response = await send_request( + (b"GET /index.html HTTP/1.1\r\n" b"Host: localhost\r\n\r\n") + ) + + test_assert( + f"test FS access control - index page rejected", + response.startswith(b"HTTP/1.1 403 Forbidden"), + True, + ) + + server_task.cancel() + await server.terminate() + + ################################################# # Test methods ################################################# -def setup_config(multipart=False, tls_enabled=False): +def setup_config(multipart=False, tls_enabled=False, served_paths=""): config_idx = config.CONFIG_CACHE.index("http_multipart") config.CONFIG_CACHE[config_idx + 1] = str(multipart) config_idx = config.CONFIG_CACHE.index("tls") config.CONFIG_CACHE[config_idx + 1] = str(tls_enabled) + config_idx = config.CONFIG_CACHE.index("http_served_paths") + config.CONFIG_CACHE[config_idx + 1] = config.normalize( + "http_served_paths", served_paths + ) enable_optional_features() @@ -246,6 +332,8 @@ def test_main(): asyncio.run(test_multipart_response(tls_enabled=True)) asyncio.run(test_server_busy()) + asyncio.run(test_chunked_transfer_encoding()) + asyncio.run(test_fs_access_control()) test_main() diff --git a/tests/unit/test_helpers.py b/tests/unit/test_helpers.py new file mode 100644 index 0000000..6886fcc --- /dev/null +++ b/tests/unit/test_helpers.py @@ -0,0 +1,56 @@ +import unittest +from unittest.mock import patch +from os import getcwd + +from .utils import load_module + + +class TestHelpers(unittest.TestCase): + """ + Base class for stat machine tests. + """ + + @classmethod + def setUpClass(cls): + cls.config = {} + + def setUp(self): + self.helpers_module = load_module("pyrobusta/utils/helpers.py") + + def test_path_normalization_virtual_root(self): + """ + Test lexical path normalization in a Unix-port environment + with a virtual root. Simulates the situation where the process + working directory acts as a virtual filesystem root. + """ + cwd = getcwd() + for case in ( + ("", ""), + ("/path/to/resource", f"{cwd}/path/to/resource"), + ("/path/to/resource/", f"{cwd}/path/to/resource"), + ("///path///to///resource///", f"{cwd}/path/to/resource"), + ("/path/../to/resource", f"{cwd}/to/resource"), + ("/path/./to/resource", f"{cwd}/path/to/resource"), + ("/path/../../resource", f"{cwd}/resource"), + ("/path/../../resource/..", f"{cwd}"), + ): + self.assertEqual(self.helpers_module.normalize_path(case[0]), case[1]) + + @patch("pyrobusta.utils.helpers.getcwd", return_value="/") + def test_path_normalization_host_root(self, _): + """ + Test lexical path normalization assuming the working directory + is the device root ("/"). This simulates the target device environment + where all paths are rooted at "/". + """ + for case in ( + ("", ""), + ("/path/to/resource", "/path/to/resource"), + ("/path/to/resource/", "/path/to/resource"), + ("///path///to///resource///", "/path/to/resource"), + ("/path/../to/resource", "/to/resource"), + ("/path/./to/resource", "/path/to/resource"), + ("/path/../../resource", "/resource"), + ("/path/../../resource/..", "/"), + ): + self.assertEqual(self.helpers_module.normalize_path(case[0]), case[1]) diff --git a/tests/unit/test_http.py b/tests/unit/test_http.py index b0c8655..ee04f02 100644 --- a/tests/unit/test_http.py +++ b/tests/unit/test_http.py @@ -35,6 +35,7 @@ def setUp(self): self.set_mock_config(key, value) # Load your web and buffer modules + self.helpers_module = load_module("pyrobusta/utils/helpers.py") buffer_module = load_module("pyrobusta/stream/buffer.py") web_module = load_module("pyrobusta/protocol/http.py") web_module.enable_optional_features() @@ -62,7 +63,7 @@ class TestWebStateMachine(TestWebStateMachineBase): @classmethod def setUpClass(cls): - cls.config = {"http_multipart": "False"} + cls.config = {} def test_status_parsing_valid(self): request = b"GET /index.html HTTP/1.1\r\nContent-Length:10" @@ -161,8 +162,8 @@ def test_routing_unsupported_method(self): self.assertEqual(self.engine.status_code, 405) self.assertEqual(self.engine.state, None) - self.assertIn(b"allow", self.engine.response_headers) - self.assertIn(b"POST", self.engine.response_headers) + self.assertIn(b"allow", self.engine.resp_headers) + self.assertIn(b"POST", self.engine.resp_headers) def test_routing_options_method(self): self.engine.state = self.engine._route_request_st @@ -179,8 +180,8 @@ def test_routing_options_method(self): self.assertEqual(self.engine.status_code, 204) self.assertEqual(self.engine.state, None) - self.assertIn(b"allow", self.engine.response_headers) - self.assertIn(b"GET, POST, PUT", self.engine.response_headers) + self.assertIn(b"allow", self.engine.resp_headers) + self.assertIn(b"GET, POST, PUT", self.engine.resp_headers) def test_routing_get_method(self): self.engine.state = self.engine._route_request_st @@ -311,6 +312,100 @@ def test_empty_or_missing_url_encoded_query_parameter(self): with self.assertRaises(KeyError): self.engine.get_url_encoded_query_param(self.engine.query, "param3") + def test_chunked_transfer_encoding_valid(self): + self.engine.url = b"/api/test" + self.engine.method = b"GET" + self.engine.version = b"HTTP/1.1" + self.engine.headers["transfer-encoding"] = "chunked" + self.engine.state = self.engine._recv_chunked_size_st + + test_callback = mock.Mock(return_value=("text/plain", "OK")) + self.engine.register("/api/test", test_callback, "GET") + + for chunk in ( + b"14\r\nchunking\r\ntest\r\ncase\r\n", + b"E\r\nchunking\r\ntest\r\n", + b"8\r\nchunking\r\n", + b"0\r\n\r\n", + ): + for i in range(len(chunk)): + self.rx.write(chunk[i : i + 1]) + self.engine.state(self.rx, self.tx) + + self.assertEqual(self.engine.state, self.engine._app_endpoint_st) + self.engine.state(self.rx, self.tx) + size_delimiter = chunk.find(b"\r\n") + test_callback.assert_called_with( + self.engine, chunk[size_delimiter + 2 : -2] + ) + + self.assertEqual(self.engine.status_code, 200) + self.assertEqual(self.engine.state, None) + + def test_chunked_transfer_encoding_invalid_chunk_size_smaller(self): + self.engine.url = b"/api/test" + self.engine.method = b"GET" + self.engine.version = b"HTTP/1.1" + self.engine.headers["transfer-encoding"] = "chunked" + self.engine.state = self.engine._recv_chunked_size_st + + test_callback = mock.Mock(return_value=("text/plain", "OK")) + self.engine.register("/api/test", test_callback, "GET") + + chunk = b"2\r\nchunking\r\n" + for i in range(len(chunk)): + self.rx.write(chunk[i : i + 1]) + self.engine.state(self.rx, self.tx) + if self.engine.state is None: + break + + self.assertEqual(self.engine.status_code, 400) + self.assertEqual(self.engine.state, None) + + def test_chunked_transfer_encoding_chunk_incomplete(self): + self.engine.url = b"/api/test" + self.engine.method = b"GET" + self.engine.version = b"HTTP/1.1" + self.engine.headers["transfer-encoding"] = "chunked" + self.engine.state = self.engine._recv_chunked_size_st + + test_callback = mock.Mock(return_value=("text/plain", "OK")) + self.engine.register("/api/test", test_callback, "GET") + + chunk = b"FF\r\nchunking\r\n" + for i in range(len(chunk)): + self.rx.write(chunk[i : i + 1]) + self.engine.state(self.rx, self.tx) + if self.engine.state is None: + break + + self.assertEqual(self.engine.status_code, None) + self.assertEqual(self.engine.state, self.engine._recv_chunk_st) + + def test_path_serving_list(self): + self.set_mock_config("http_served_paths", "/path/to/dir1 /path/to/dir2") + self.assertEqual(self.engine.is_norm_path_served(""), False) + self.assertEqual(self.engine.is_norm_path_served("/"), False) + self.assertEqual(self.engine.is_norm_path_served("/path/to/dir1"), True) + self.assertEqual(self.engine.is_norm_path_served("/path/to/dir2"), True) + self.assertEqual(self.engine.is_norm_path_served("/path/to/dir12"), False) + self.assertEqual(self.engine.is_norm_path_served("/path/to/dir1/file"), True) + self.assertEqual(self.engine.is_norm_path_served("/path/to/dir2/file"), True) + self.assertEqual(self.engine.is_norm_path_served("/path/to/other"), False) + self.assertEqual(self.engine.is_norm_path_served("/path/to"), False) + + def test_path_serving_root(self): + self.set_mock_config("http_served_paths", "/") + self.assertEqual(self.engine.is_norm_path_served(""), True) + self.assertEqual(self.engine.is_norm_path_served("/"), True) + self.assertEqual(self.engine.is_norm_path_served("/path/to/served"), True) + + def test_path_serving_none(self): + self.set_mock_config("http_served_paths", "") + self.assertEqual(self.engine.is_norm_path_served(""), False) + self.assertEqual(self.engine.is_norm_path_served("/"), False) + self.assertEqual(self.engine.is_norm_path_served("/path/to/served"), False) + class TestMultipartStateMachine(TestWebStateMachineBase): """ @@ -389,7 +484,7 @@ def test_multipart_receiver_complete_part(self): self.engine.headers["content-length"] = 1000 self.engine.mp_boundary = b"test-boundary" self.engine.mp_delimiter = b"--test-boundary\r\n" - self.engine.mp_closing_delimiter = b"--test-boundary--" + self.engine.mp_last_delimiter = b"--test-boundary--" body_part = ( b'Content-Disposition:form-data;name="file-chunk";filename="upload.txt"\r\n' @@ -405,7 +500,7 @@ def test_multipart_receiver_complete_part(self): self.assertEqual(self.engine.state, self.engine._parse_complete_part_st) self.assertEqual(self.rx.peek(), body_part) - self.assertEqual(self.engine.mp_first_part, True) + self.assertEqual(self.engine.mp_is_first, True) self.engine.state(self.rx, self.tx) @@ -420,8 +515,8 @@ def test_multipart_receiver_complete_part(self): b"Upload content", ), ) - self.assertEqual(self.engine.mp_first_part, False) - self.assertEqual(self.engine.mp_last_part, False) + self.assertEqual(self.engine.mp_is_first, False) + self.assertEqual(self.engine.mp_is_last, False) def test_multipart_receiver_last_part(self): self.engine.state = self.engine._parse_boundary_st @@ -431,7 +526,7 @@ def test_multipart_receiver_last_part(self): self.engine.headers["content-length"] = 131 self.engine.mp_boundary = b"test-boundary" self.engine.mp_delimiter = b"--test-boundary\r\n" - self.engine.mp_closing_delimiter = b"--test-boundary--" + self.engine.mp_last_delimiter = b"--test-boundary--" test_callback = mock.Mock(return_value=("text/plain", "OK")) self.engine.register("/api/test", test_callback) @@ -465,8 +560,8 @@ def test_multipart_receiver_last_part(self): b"Upload content", ), ) - self.assertEqual(self.engine.mp_first_part, True) - self.assertEqual(self.engine.mp_last_part, True) + self.assertEqual(self.engine.mp_is_first, True) + self.assertEqual(self.engine.mp_is_last, True) if __name__ == "__main__":