diff --git a/README.md b/README.md index 17b1a3c..3e2a4d7 100644 --- a/README.md +++ b/README.md @@ -1,28 +1,41 @@ # PyRobusta -A lightweight HTTP server library for MicroPython designed for constrained embedded systems. +PyRobusta is a memory-conscious HTTP/1.1 server library built for embedded devices where heap usage, connection reliability, and stream processing efficiency matter. PyRobusta offers robust keep-alive connection management and efficient byte-stream processing while maintaining a predictable memory footprint. -## HTTP features -- Routing decorators -- Fixed-size, configurable request/response buffers +## HTTP Features +- Routing decorators and wildcard-based URL matching - Multipart request and response handling -- Chunked transfer decoding for streamed request bodies -- Bounded-copy memory footprint -- Finite-state-machine parser with linear sliding buffer -- Robust byte-stream handling +- Support for chunked encoding and streaming payloads - Query parameter parsing with percent encoding support -- Persistent connections (set by **connection: keep-alive**) +- Built-in API for uploading, downloading, and deleting files stored on the server +- Persistent connection handling via the `Connection: keep-alive` header +- HTTP/1.0 and HTTP/1.1 support - TLS support +## Design Principles + +- Predictable memory usage through fixed-size stream buffers +- Incremental byte-stream processing with bounded memory overhead +- State-machine-driven request parsing for extensibility and protocol correctness +- Reliable connection handling with keep-alive, timeouts, and transport error recovery +- Designed specifically for MicroPython and memory-constrained embedded environments + +## Project Status + +PyRobusta is under active development. The public API is not yet considered +stable and may change between releases. + +Starting with v1.0.0, backwards compatibility will be maintained within each major version. 
Any backwards-incompatible changes introduced before then are clearly documented in the release notes. + # Installation Install PyRobusta on your MicroPython-enabled device using the mip package manager. -A minimum of 40 KB free heap is required. However, for better usability and stability,\ -devices with more SRAM are strongly recommended. The ESP32-C3 SuperMini is a good\ +A minimum of 40 KB free heap is required. However, for better usability and stability, +devices with more SRAM are strongly recommended. The ESP32-C3 SuperMini is a good entry-level option, providing a comfortable amount of free memory after installation. -If you haven’t already set up your environment, follow the [setup guide](./docs/setup.md) to install\ +If you haven’t already set up your environment, follow the [setup guide](./docs/setup.md) to install mpremote and connect your device to Wi-Fi. @@ -48,14 +61,46 @@ async def main(): asyncio.run(main()) ``` -# Access the Application +# Verify the Installation -Open a web browser and enter your device’s IP address in the address bar.\ -You should see the default homepage. Refer to the included documentation\ -for details on supported use cases and advanced features. +Open a web browser and enter your device’s IP address in the address bar. + +If the server is running correctly, the default homepage will be displayed. +Refer to the documentation for configuration options, routing, streaming +payloads, and advanced HTTP features. 
![image info](./docs/img/home_page.png) +## Sample Application + +```python +import asyncio +from gc import mem_free, mem_alloc, collect + +import pyrobusta.server.http_server as http_server +from pyrobusta.protocol.http import HttpEngine + +@HttpEngine.route("/mem-usage", "GET") +def mem_usage(http_ctx, _): + collect() + free = mem_free() + used = mem_alloc() + usage_percentage = 100 * used / (free + used) + return "text/plain", ( + f"Currently used: {usage_percentage:.2f}%\n" + f"Free [bytes]: {free}\n" + f"Used [bytes]: {used}\n" + f"Total [bytes]: {used + free}\n" + ) + +async def main(): + server = http_server.HttpServer() + asyncio.create_task(server.start_socket_server()) + while True: + await asyncio.sleep(1) + +asyncio.run(main()) +``` # Configuration and Optimization @@ -65,5 +110,5 @@ To fine-tune heap usage and optimize performance, see: # Development -Check the provided development guide to create and deploy custom builds\ +Check the provided development guide to create and deploy custom builds to your device: [development guide](./docs/development.md) diff --git a/docs/api.md b/docs/api.md new file mode 100644 index 0000000..e2b6138 --- /dev/null +++ b/docs/api.md @@ -0,0 +1,59 @@ +*** + +# File Management Endpoint (`/files`) + +This endpoint provides file management capabilities, allowing clients to upload, retrieve, and manage files through various HTTP methods. `http_files_api` must be set to `True` in pyrobusta.env to enable this API. + +## Summary + +| Method | Path | Description | +| :------- | :------------------- | :---------- | +| `GET` | `/files/{path}` | Lists or retrieves metadata about files. | +| `PUT` | `/files/{file path}` | Uploads or overwrites a file at the specified path. | +| `POST` | `/files` | Uploads multiple files in multipart/form-data. | +| `DELETE` | `/files/{file path}` | Delete a file at the specified path. | + +--- + +## Endpoint Details + +### 1. 
File Retrieval/Listing (`GET /files/{path}`) + +This endpoint allows general file system interaction, enabling operations such as listing directory contents and retrieving metadata as well as downloading files. + +* **Method:** `GET` +* **Path:** `/files/{path}` +* **Success Response:** 200 OK. + +### 2. File Upload / Overwrite (`PUT /files/{file path}`) + +This method is used to upload a file or overwrite an existing file at a specific path. +The upload path is restricted to /www/user_data. + +* **Method:** `PUT` +* **Path:** `/files/{file path}` +* **Body:** Raw file content (e.g., binary data). +* **Success Response:** 201 Created. +* **Notes:** `transfer-encoding: chunked` is supported. + +### 3. File Upload (`POST /files`) + +This method handles general file uploads, designed for uploading multiple files with per-file chunking supported. Only multipart/form-data is accepted as a content type. + +The upload path is restricted to /www/user_data, however, content-disposition headers only have to specify the file name, /www/user_data is prepended by default. + +`http_multipart` must be set to `True` in the configuration to use this endpoint. + +* **Method:** `POST` +* **Path:** `/files` +* **Body:** File content encapsulated in multipart/form-data. +* **Success Response:** 201 Created. + +### 4. File Delete (`DELETE /files/{file path}`) + +This method is used to delete a file at a specific path. +The path is restricted to /www/user_data. + +* **Method:** `DELETE` +* **Path:** `/files/{file path}` +* **Success Response:** 204 No Content. diff --git a/docs/configuration.md b/docs/configuration.md index e7a9df6..428ad33 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -3,16 +3,16 @@ Configuration can be overridden in pyrobusta.env, in .env format. Create pyrobusta.env in the project root, and run ```make deploy-config``` to upload it to the root directory of the target device. 
-| Name | Description | Default | -|-------------------|-------------------------------------------------------------------------------------------------------|-------------------------------| -| wifi_ssid | Name of the Wi-Fi network. When empty, Wi-Fi is not initalized by the built-in wifi.py module. | None | -| wifi_password | Password of the Wi-Fi network. When empty, Wi-Fi is not initalized by the built-in wifi.py module. | None | -| http_port | Port number for HTTP. | 80 | -| https_port | Port number for HTTPS. | 443 | -| http_multipart | Enable multipart HTTP requests/responses. | False | -| http_mem_cap | Max memory cap (% × 0.01) of usable heap for HTTP request/response stream buffers. | 0.1 | -| http_served_paths | Space delimited list of filesystem paths allowed to be served through HTTP. | "/www /lib/pyrobusta" | -| http_serve_files | Enable/disable file serving. | True | -| socket_max_con | Max number of socket connections of any enabled application server. | 2 | -| tls | Enable/disable TLS. When turned on, cert.der/key.der must be installed at the root. | False | -| log_level | Can be one of: warning, info, debug. | "info" | +| Name | Description | Default | +| :---------------- | :---------- | :------ | +| wifi_ssid | Name of the Wi-Fi network. When empty, Wi-Fi is not initialized by the built-in wifi.py module. | None | +| wifi_password | Password of the Wi-Fi network. When empty, Wi-Fi is not initialized by the built-in wifi.py module. | None | +| http_port | Port number for HTTP. | 80 | +| https_port | Port number for HTTPS. | 443 | +| http_multipart | Enable multipart HTTP requests/responses. | False | +| http_mem_cap | Max memory cap (% × 0.01) of usable heap for HTTP request/response stream buffers. | 0.1 | +| http_served_paths | Space delimited list of filesystem paths allowed to be served through HTTP. 
| "/www /lib/pyrobusta" | +| http_files_api | Enables or disables the file management API endpoint (/files), allowing to upload, download, and list files. | False | +| socket_max_con | Max number of socket connections of any enabled application server. | 2 | +| tls | Enables or disables TLS. When turned on, cert.der/key.der must be installed at the root. | False | +| log_level | Can be one of: warning, info, debug. | "info" | diff --git a/docs/dimensioning/http_dimensioning.md b/docs/dimensioning/http_dimensioning.md index 83ebeb3..fe28566 100644 --- a/docs/dimensioning/http_dimensioning.md +++ b/docs/dimensioning/http_dimensioning.md @@ -10,7 +10,7 @@ of parameters relative to a defined baseline configuration. socket_max_con=1 http_mem_cap=0.05 http_multipart=False -http_serve_files=True +http_files_api=False tls=False http_port=8080 https_port=4443 diff --git a/example/boot.py b/example/boot.py new file mode 100644 index 0000000..e286346 --- /dev/null +++ b/example/boot.py @@ -0,0 +1,13 @@ +# This file is executed on every boot (including wake-boot from deepsleep) +import asyncio +import machine +from os import listdir + +from pyrobusta.connectivity import wifi + +connected = wifi.initialize() +if connected and not machine.reset_cause() == machine.SOFT_RESET: + if "app.py" in listdir(): + import app + + asyncio.run(app.main()) diff --git a/example/demo_app/app.py b/example/demo_app/app.py index 4e605fa..6817989 100644 --- a/example/demo_app/app.py +++ b/example/demo_app/app.py @@ -56,3 +56,7 @@ async def main(): asyncio.create_task(server.start_socket_server()) while True: await asyncio.sleep(1) + + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/example/demo_app/boot.py b/example/demo_app/boot.py deleted file mode 100644 index e286346..0000000 --- a/example/demo_app/boot.py +++ /dev/null @@ -1,13 +0,0 @@ -# This file is executed on every boot (including wake-boot from deepsleep) -import asyncio -import machine -from os 
import listdir - -from pyrobusta.connectivity import wifi - -connected = wifi.initialize() -if connected and not machine.reset_cause() == machine.SOFT_RESET: - if "app.py" in listdir(): - import app - - asyncio.run(app.main()) diff --git a/example/demo_app/boot.py b/example/demo_app/boot.py new file mode 120000 index 0000000..3345b8c --- /dev/null +++ b/example/demo_app/boot.py @@ -0,0 +1 @@ +../boot.py \ No newline at end of file diff --git a/example/mem_usage/app.py b/example/mem_usage/app.py index 108fb5f..bfde590 100644 --- a/example/mem_usage/app.py +++ b/example/mem_usage/app.py @@ -47,3 +47,7 @@ async def main(): asyncio.create_task(server.start_socket_server()) while True: await asyncio.sleep(1) + + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/example/mem_usage/boot.py b/example/mem_usage/boot.py deleted file mode 100644 index e286346..0000000 --- a/example/mem_usage/boot.py +++ /dev/null @@ -1,13 +0,0 @@ -# This file is executed on every boot (including wake-boot from deepsleep) -import asyncio -import machine -from os import listdir - -from pyrobusta.connectivity import wifi - -connected = wifi.initialize() -if connected and not machine.reset_cause() == machine.SOFT_RESET: - if "app.py" in listdir(): - import app - - asyncio.run(app.main()) diff --git a/example/mem_usage/boot.py b/example/mem_usage/boot.py new file mode 120000 index 0000000..3345b8c --- /dev/null +++ b/example/mem_usage/boot.py @@ -0,0 +1 @@ +../boot.py \ No newline at end of file diff --git a/example/mip_repo/app.py b/example/mip_repo/app.py index 421d9d5..282921d 100644 --- a/example/mip_repo/app.py +++ b/example/mip_repo/app.py @@ -46,3 +46,7 @@ async def main(): asyncio.create_task(server.start_socket_server()) while True: await asyncio.sleep(1) + + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/example/mip_repo/boot.py b/example/mip_repo/boot.py deleted file mode 100644 index e286346..0000000 
--- a/example/mip_repo/boot.py +++ /dev/null @@ -1,13 +0,0 @@ -# This file is executed on every boot (including wake-boot from deepsleep) -import asyncio -import machine -from os import listdir - -from pyrobusta.connectivity import wifi - -connected = wifi.initialize() -if connected and not machine.reset_cause() == machine.SOFT_RESET: - if "app.py" in listdir(): - import app - - asyncio.run(app.main()) diff --git a/example/mip_repo/boot.py b/example/mip_repo/boot.py new file mode 120000 index 0000000..3345b8c --- /dev/null +++ b/example/mip_repo/boot.py @@ -0,0 +1 @@ +../boot.py \ No newline at end of file diff --git a/src/pyrobusta/bindings/http_connection.py b/src/pyrobusta/bindings/http_connection.py index ece96aa..8577832 100644 --- a/src/pyrobusta/bindings/http_connection.py +++ b/src/pyrobusta/bindings/http_connection.py @@ -96,11 +96,21 @@ async def _run_state_machine(self): async def _response_handler(self, resp_handler): if "closure" == type(resp_handler).__name__: - for is_finished in resp_handler(self._send_buf): - await self._flush_response() - if is_finished: - break - await sleep_ms(self.STATE_MACHINE_SLEEP_MS) + if self._engine.get_response_header(b"transfer-encoding") == b"chunked": + for is_finished in resp_handler(self._send_buf): + await self.write(b"%x\r\n" % self._send_buf.size()) + await self._flush_response() + await self.write(b"\r\n") + if is_finished: + await self.write(b"0\r\n\r\n") + break + await sleep_ms(self.STATE_MACHINE_SLEEP_MS) + else: + for is_finished in resp_handler(self._send_buf): + await self._flush_response() + if is_finished: + break + await sleep_ms(self.STATE_MACHINE_SLEEP_MS) elif type(resp_handler).__name__ in ("FileIO", "BytesIO"): try: while True: diff --git a/src/pyrobusta/protocol/http.py b/src/pyrobusta/protocol/http.py index 21e6486..b14fead 100644 --- a/src/pyrobusta/protocol/http.py +++ b/src/pyrobusta/protocol/http.py @@ -5,13 +5,15 @@ from json import dumps from io import BytesIO +from os import stat from 
..utils.config import ( get_config, CONF_HTTP_MULTIPART, - CONF_HTTP_SERVE_FILES, + CONF_HTTP_FILES_API, + CONF_HTTP_SERVED_PATHS, ) -from ..utils import logging +from ..utils import logging, helpers from ..stream.buffer import BufferFullError @@ -44,11 +46,13 @@ class HttpEngine: - allows applications to set response attributes (headers, status code) Feature flags (configured in pyrobusta.env) - - http_serve_files: serve files stored on the device + - http_files_api: serve files at the /files endpoint, with support for uploads, + removal and directory listing - http_multipart: support for multipart requests/responses """ __slots__ = ( + "id", "state", "status_code", "resp_headers", @@ -73,6 +77,8 @@ class HttpEngine: RESP_HEADERS = ( 200, b"200 OK", + 201, + b"201 Created", 204, b"204 No Content", 400, @@ -95,6 +101,31 @@ class HttpEngine: b"505 Version Not Supported", ) + CONTENT_TYPES = ( + b"raw", + b"application/octet-stream", + b"html", + b"text/html", + b"css", + b"text/css", + b"js", + b"application/javascript", + b"json", + b"application/json", + b"ico", + b"image/x-icon", + b"jpeg", + b"image/jpeg", + b"jpg", + b"image/jpeg", + b"png", + b"image/png", + b"txt", + b"text/plain", + b"gif", + b"image/gif", + ) + DELETE = b"DELETE" GET = b"GET" HEAD = b"HEAD" @@ -103,9 +134,19 @@ class HttpEngine: PUT = b"PUT" METHODS = (DELETE, GET, HEAD, OPTIONS, POST, PUT) SUPPORTED_VERSIONS = (b"HTTP/1.1", b"HTTP/1.0") + SESSION_COUNTER = 0 + + @classmethod + def new_session_id(cls): + """ + Create a new unique ID for the HTTP session. + """ + cls.SESSION_COUNTER = (cls.SESSION_COUNTER + 1) & 0xFFFFFFFF + return cls.SESSION_COUNTER def __init__(self): # [State machine] + self.id = self.new_session_id() self.state = self._start_parser self.status_code = None self.resp_headers = [] @@ -129,6 +170,7 @@ def reset(self): """ Reset internal state to reuse a state machine object. 
""" + self.id = self.new_session_id() self.state = self._start_parser self.status_code = None self.resp_headers.clear() @@ -151,9 +193,9 @@ def reset(self): @classmethod def register(cls, endpoint: str, callback: callable, method: str = "GET") -> None: """ - Register an endpoint with a callback function or file. + Register an endpoint with a callback function. :param endpoint: URL path to be routed e.g. "/app/resource" - :param callback: callback function or file path + :param callback: callback function :param method: HTTP method name """ endpoint = endpoint.encode("ascii") @@ -166,6 +208,19 @@ def register(cls, endpoint: str, callback: callable, method: str = "GET") -> Non raise ValueError("endpoint exists") cls.ENDPOINTS.append((endpoint, callback, method)) + @classmethod + def deregister(cls, endpoint: str, method: str) -> None: + """ + Deregister an endpoint. + :param endpoint: URL path to be routed e.g. "/app/resource" + :param method: HTTP method name + """ + endpoint = endpoint.encode("ascii") + method = method.encode("ascii") + + if callback := cls._get_callback(endpoint, method): + cls.ENDPOINTS.remove((endpoint, callback, method)) + @staticmethod def route(endpoint: str, method: str): """ @@ -229,7 +284,9 @@ def _is_matching_url_path(path: bytes, pattern: bytes) -> bool: """ Match a URL path against a pattern that can contain wildcard segments e.g. /path/{wildcard}/resource where {wildcard} matches any non-empty - string in that segment. + string in that segment. /path/to/{wildcard:path} matches multiple path + segments, only allowed for trailing segments. + (e.g. 
"/{wildcard:path}/resource" is forbidden) """ if path == pattern: return True @@ -253,6 +310,8 @@ def _is_matching_url_path(path: bytes, pattern: bytes) -> bool: and len(path_seg) > 0 ): return False + if pat_seg.endswith(b":path}"): + return True i = ni + 1 j = nj + 1 return i >= n and j >= m @@ -383,13 +442,24 @@ def set_response_header(self, key: bytes, value: bytes): """ if ( key in self.resp_headers - and (index := self.resp_headers.index(key) % 2) == 0 + and (index := self.resp_headers.index(key)) % 2 == 0 ): self.resp_headers[index + 1] = value else: self.resp_headers.append(key) self.resp_headers.append(value) + def get_response_header(self, key: bytes): + """ + Get a response header by key. + :param key: HTTP header key + """ + if ( + key in self.resp_headers + and (index := self.resp_headers.index(key)) % 2 == 0 + ): + return self.resp_headers[index + 1] + def write_response_head(self, tx): """ Write response status and header to an output buffer. @@ -500,7 +570,7 @@ def is_terminated(self): def run(self, rx): """ - Run the state machine with request buffers provided. + Run the state machine, consuming the content of a request buffer (rx). Unlike individual states, this method does not raise an exception. This method yields on every state transition allowing the calling side to flush the response buffer. 
@@ -648,7 +718,7 @@ def _route_request_st(self, _): return # Fallback: serve file if self.method in (self.GET, self.HEAD): - self.state = lambda _rx: self._send_file_st(_rx, self.url) + self.state = self._fs_retrieve_st return self.terminate(404) @@ -696,39 +766,71 @@ def _app_endpoint_st(self, rx): rx.consume(self.recv_chunk_size + 2) self.state = self._recv_chunk_size_st return - dtype, data = callback(self, bytes(rx.peek(self.recv_chunk_size))) + callback_response = callback(self, b"") rx.consume(self.recv_chunk_size + 2) else: - dtype, data = callback( + callback_response = callback( self, bytes(rx.peek(self.headers["content-length"])) ) else: - if not callable(callback): - # Handle as a file path - self.state = lambda _rx: self._send_file_st( - _rx, callback.encode("ascii") - ) - return - dtype, data = callback(self, b"") + callback_response = callback(self, b"") + + if not self.is_terminated(): + self.terminate(200, True) + if callback_response is None: + return + + dtype, data = callback_response if dtype.startswith("multipart/"): self.state = lambda _rx: self._generate_multipart_response(_rx, data, dtype) return - if not self.is_terminated(): - self.terminate(200, True) self.set_response_body(data, content_type=dtype) - def _send_file_st(self, _, path: bytes): # pylint: disable=W0613 + def _fs_retrieve_st(self, _): """ - State for returning including a file in the response body (disabled). - :param path: path to the resource + State for retrieving a file under /www. + /www is prepended to the path by default. 
""" - self.terminate(503, True) + if self.url == b"/": + target_path = "/www/index.html" + else: + target_path = "/www" + self.url.decode("ascii") + + norm_path = helpers.normalize_path(target_path) + is_path_served = helpers.is_norm_path_served( + norm_path, get_config(CONF_HTTP_SERVED_PATHS) + ) + + try: + if not is_path_served: + stat(norm_path) + self.terminate(403, True) + return + + try: + extension = target_path.rsplit(".", 1)[-1] + content_type = self._lookup( + self.CONTENT_TYPES, extension.encode("ascii") + ) + except ValueError: + content_type = self._lookup(self.CONTENT_TYPES, b"raw") + + self.set_response_header( + b"content-length", str(stat(norm_path)[6]).encode("ascii") + ) + self.set_response_header(b"content-type", content_type) + self.terminate(200, True) + if self.method != self.HEAD: + self.resp_handler = open(norm_path, "rb") # pylint: disable=R1732 + return + except OSError: + self.terminate(404, True) def _start_multipart_parser_st(self, rx): # pylint: disable=W0613 """ - Initial state for processing multipart requests (disabled). + Initial state for processing multipart requests (placeholder). """ self.terminate(503) @@ -736,7 +838,7 @@ def _generate_multipart_response( self, rx, callback, dtype ): # pylint: disable=W0613 """ - Generate multipart response depening on the exact content type (disabled). + Generate multipart response depening on the exact content type (placeholder). """ self.terminate(503, True) @@ -750,7 +852,7 @@ def enable_optional_features(): http_multipart.apply_patches() - if get_config(CONF_HTTP_SERVE_FILES): + if get_config(CONF_HTTP_FILES_API): from pyrobusta.protocol import http_file_server http_file_server.apply_patches() diff --git a/src/pyrobusta/protocol/http_file_server.py b/src/pyrobusta/protocol/http_file_server.py index 767fb5e..c051e13 100644 --- a/src/pyrobusta/protocol/http_file_server.py +++ b/src/pyrobusta/protocol/http_file_server.py @@ -1,102 +1,299 @@ """ -State machine extension for file serving. 
+Module for extended file serving features, registered at the /files endpoint. """ # pylint: disable=W0212,R0401 -from os import stat +from os import stat, listdir, rmdir, remove, rename, mkdir +from json import dumps from pyrobusta.protocol import http -from pyrobusta.utils.helpers import normalize_path, add_method +from pyrobusta.utils.helpers import ( + normalize_path, + is_norm_path_served, + is_file_path_valid, + is_path_segment_valid, +) +from pyrobusta.utils.assets import iterate_fs, FS_ITER_FILE from ..utils.config import ( get_config, CONF_HTTP_SERVED_PATHS, ) -CONTENT_TYPES = ( - b"raw", - b"application/octet-stream", - b"html", - b"text/html", - b"css", - b"text/css", - b"js", - b"application/javascript", - b"json", - b"application/json", - b"ico", - b"image/x-icon", - b"jpeg", - b"image/jpeg", - b"jpg", - b"image/jpeg", - b"png", - b"image/png", - b"txt", - b"text/plain", - b"gif", - b"image/gif", -) +_UPLOAD_ROOT = normalize_path("/www/user_data") +_TMP_DIR = normalize_path("/tmp") -def is_norm_path_served(path: str): - """ - Returns true if a normalized path is configured to be served. - """ - served_paths = get_config(CONF_HTTP_SERVED_PATHS) - parts = path.split("/") - for i, _ in enumerate(parts): - current_path = "/".join(parts[: i + 1]) - if not current_path: - current_path = "/" - if current_path in served_paths: - return True - return False +################################################# +# CRUD methods +################################################# -def _send_file_st(self, _, file_path: bytes): +def fs_retrieve(http_ctx, _): """ - State for returning a file. By default, /www is prepended to the path. - Alternatively, ready any file from the root when the path starts with /files - if it is configured in http_served_paths. - :param file_path: path to the file (unnormalized) + State for retrieving a file or a directory structure. + The http_served_paths configuration controls which files/directories + can be retrieved. 
""" - if self.url == b"/files": - file_path = "/" - elif self.url.startswith(b"/files/"): - file_path = file_path[7:] - elif self.url == b"/": - file_path = b"/www/index.html" - else: - file_path = b"/www" + file_path + target_path = http_ctx.url[len(b"/files") :].decode("ascii") + norm_path = normalize_path(target_path) + is_path_served = is_norm_path_served(norm_path, get_config(CONF_HTTP_SERVED_PATHS)) - extension = file_path.rsplit(b".", 1)[-1] - norm_path = normalize_path(file_path.decode("ascii")) - is_path_served = self.is_norm_path_served(norm_path) - if not is_path_served: - try: + try: + if not is_path_served: stat(norm_path) - self.terminate(403, True) - return - except OSError: - self.terminate(404, True) + http_ctx.terminate(403, True) + return "text/plain", "Forbidden" + + # Retrieve directory structure + if stat(norm_path)[0] & 0x4000: + http_ctx.set_response_header(b"content-type", b"application/json") + http_ctx.set_response_header(b"transfer-encoding", b"chunked") + http_ctx.terminate(200, True) + http_ctx.resp_handler = _traverse_dir_factory(norm_path) return - try: - content_type = self._lookup(CONTENT_TYPES, extension) - except ValueError: - content_type = self._lookup(CONTENT_TYPES, b"raw") - try: - self.set_response_header( + + # Retrieve file + try: + extension = target_path.rsplit(".", 1)[-1] + content_type = http_ctx._lookup( + http_ctx.CONTENT_TYPES, extension.encode("ascii") + ) + except ValueError: + content_type = http_ctx._lookup(http_ctx.CONTENT_TYPES, b"raw") + + http_ctx.set_response_header( b"content-length", str(stat(norm_path)[6]).encode("ascii") ) - self.set_response_header(b"content-type", content_type) - self.terminate(200, True) - if self.method != self.HEAD: - self.resp_handler = open(norm_path, "rb") # pylint: disable=R1732 - return + http_ctx.set_response_header(b"content-type", content_type) + http_ctx.terminate(200, True) + if http_ctx.method != http_ctx.HEAD: + http_ctx.resp_handler = open(norm_path, "rb") # pylint: 
disable=R1732 + except OSError: + http_ctx.terminate(404, True) + return "text/plain", "Not found" + + +def delete_file(http_ctx, _): + """ + Callback function for handling delete operations. The URL path + must point to the exact file or directory path under _UPLOAD_ROOT. + Only empty directories can be deleted. + """ + + fs_path = normalize_path(http_ctx.url.decode("ascii")[6:]) + + try: + if not fs_path.startswith(_UPLOAD_ROOT): + stat(fs_path) + http_ctx.terminate(403, True) + return "text/plain", "Forbidden" + + # Delete directory structure + if stat(fs_path)[0] & 0x4000: + if listdir(fs_path): + http_ctx.terminate(400, True) + return "text/plain", "Directory not empty" + rmdir(fs_path) + http_ctx.terminate(204, True) + return "text/plain", "Deleted" + + # Delete file + remove(fs_path) + http_ctx.terminate(204, True) + return "text/plain", "Deleted" + except OSError: + http_ctx.terminate(404, True) + return "text/plain", "Not found" + + +def upload_file(http_ctx, payload: bytes): + """ + Callback function for handling single file uploads, supporting chunked transfer encoding. + Uploads are saved to _UPLOAD_ROOT, with the name determined by the URL path. 
+ """ + content_type = http_ctx.headers.get("content-type") + if content_type and content_type.lower().startswith("multipart/"): + http_ctx.terminate(400) + return "text/plain", "Bad request" + + is_chunked = http_ctx.headers.get("transfer-encoding") == "chunked" + + if is_chunked: + url_path = http_ctx.url.decode("ascii") + file_name_idx = url_path.rfind("/") + 1 + if not file_name_idx: + http_ctx.terminate(400) + return "text/plain", "Bad request" + file_path = normalize_path( + _TMP_DIR + "/" + f"{url_path[file_name_idx:]}.{http_ctx.id}" + ) + else: + file_path = normalize_path(http_ctx.url.decode("ascii")[6:]) + + if not is_file_path_valid(file_path): + http_ctx.terminate(400) + return "text/plain", "Invalid or missing filename" + + try: + if not file_path.startswith(_UPLOAD_ROOT) and not file_path.startswith( + _TMP_DIR + ): + http_ctx.terminate(403, True) + return "text/plain", "Forbidden" + + if is_chunked: + if not payload: # Last chunk received, finalize upload + rename(file_path, normalize_path(http_ctx.url.decode("ascii")[6:])) + else: + with open(file_path, "ab") as f: + f.write(payload) + else: + with open(file_path, "wb") as f: + f.write(payload) + + http_ctx.terminate(201, True) + return "text/plain", "OK" except OSError: - self.terminate(404, True) + http_ctx.terminate(404, True) + return "text/plain", "Not found" + + +def bulk_upload_file(http_ctx, payload: tuple): + """ + Callback function for handling bulk file uploads (Content-Type: multipart/form-data) + This callback is invoked on every part. Every file is saved to _UPLOAD_ROOT, with + the name determined by the content disposition header. When two parts specify the + same file name, the content of the second part is appended to the first part. + Split files to multiple parts for chunking large files to avoid HTTP 413 errors. 
+ """ + content_type = http_ctx.headers.get("content-type") + if not content_type or not content_type.lower().startswith("multipart/form-data"): + http_ctx.terminate(400) + return "text/plain", "Bad request" + + part_headers, part_body = payload + + try: + filename = get_filename(part_headers) + except ValueError: + http_ctx.terminate(415) + return "text/plain", "Invalid content disposition" + + if not is_path_segment_valid(filename): + http_ctx.terminate(400) + return "text/plain", "Invalid or missing filename" + + # Clean stale partial uploads + if http_ctx.mp_is_first: + for file in listdir(_TMP_DIR): + if file.endswith(f".{http_ctx.id}"): + remove(_TMP_DIR + "/" + file) + + # TODO: support X-Upload-Directory; pylint: disable=W0511 + target_path = normalize_path(_TMP_DIR + "/" + f"{filename}.{http_ctx.id}") + with open(target_path, "ab") as f: + f.write(part_body) + + # Finalize uploads + if http_ctx.mp_is_last: + suffix = f".{http_ctx.id}" + for file in listdir(_TMP_DIR): + if file.endswith(suffix): + rename(_TMP_DIR + "/" + file, _UPLOAD_ROOT + "/" + file[: -len(suffix)]) + + http_ctx.terminate(201, True) + return "text/plain", "OK" + + +################################################# +# Helper functions +################################################# + + +def get_filename(part_headers: dict): + """ + Get filename field from content-disposition headers. 
+ :param part_headers: headers of an individual part + """ + cd = part_headers.get("content-disposition", "") + filename = None + + if cd[: min(max(cd.find(";"), 0), len(cd))].strip() != "form-data": + raise ValueError() + + f_start = cd.find(";") + 1 + f_end = cd.find(";", f_start) + + while f_start < len(cd): + f_end = len(cd) if f_end == -1 else f_end + parameter = cd[f_start:f_end].split("=") + if len(parameter) == 2 and parameter[0].strip() == "filename": + filename = parameter[1].strip().strip("'").strip('"') + f_start = f_end + 1 + f_end = cd.find(";", f_start) + + return filename + + +def _traverse_dir_factory(path): + """ + Factory method for creating a response handler closure + for directory content traversal. + :param path: normalized path to the directory to traverse + """ + + def _traverse_dir(tx): + """ + Traverse a directory and produce a JSON-formatted + response of the directory contents. + :param tx: response buffer + """ + tx.write(b"[") + for i, it in enumerate(iterate_fs(path, FS_ITER_FILE)): + if i != 0: + tx.write(b",") + + file_stat = stat(it) + obj = dumps( + { + "path": it, + "size": str(file_stat[6]), + "created": str(file_stat[9]), + } + ).encode("ascii") + + written = 0 + while written < len(obj): + to_write = tx.capacity - tx.size() + if not to_write: + raise BufferError() + tx.write(obj[written : written + to_write]) + written += to_write + yield False + tx.write(b"]\r\n") + yield True + + return _traverse_dir + + +def setup_directories(): + """ + Set up the required directories for file uploads. 
+ """ + for http_dir in (_UPLOAD_ROOT, _TMP_DIR): + base_dir = normalize_path("/") + sub_dirs = http_dir[len(base_dir) :].lstrip("/") + + for subdir in sub_dirs.split("/"): + current_dir = base_dir + "/" + subdir + if not subdir in listdir(base_dir): + mkdir(current_dir) + base_dir = current_dir + + for file in listdir(_TMP_DIR): + remove(_TMP_DIR + "/" + file) def apply_patches(): @@ -104,5 +301,14 @@ def apply_patches(): Apply patches to class attributes for file serving. """ - add_method(http.HttpEngine, _send_file_st) - add_method(http.HttpEngine, is_norm_path_served, "static") + http.HttpEngine.deregister("/files/{fs_path:path}", "GET") + http.HttpEngine.deregister("/files/{fs_path:path}", "DELETE") + http.HttpEngine.deregister("/files/{fs_path:path}", "PUT") + http.HttpEngine.deregister("/files", "POST") + + http.HttpEngine.register("/files/{fs_path:path}", fs_retrieve, "GET") + http.HttpEngine.register("/files/{fs_path:path}", delete_file, "DELETE") + http.HttpEngine.register("/files/{fs_path:path}", upload_file, "PUT") + http.HttpEngine.register("/files", bulk_upload_file, "POST") + + setup_directories() diff --git a/src/pyrobusta/protocol/http_multipart.py b/src/pyrobusta/protocol/http_multipart.py index 7e8cb5d..966c00e 100644 --- a/src/pyrobusta/protocol/http_multipart.py +++ b/src/pyrobusta/protocol/http_multipart.py @@ -137,7 +137,9 @@ def _parse_complete_part_st(self, rx): raise http.InvalidContentLength() self.mp_is_last = True dtype, data = callback(self, (part_headers, part_body)) - self.terminate(200, True) + + if not self.is_terminated(): + self.terminate(200, True) self.set_response_body(data, dtype) diff --git a/src/pyrobusta/utils/config.py b/src/pyrobusta/utils/config.py index 40b54c9..037ec28 100644 --- a/src/pyrobusta/utils/config.py +++ b/src/pyrobusta/utils/config.py @@ -28,7 +28,7 @@ def const(n): # pylint: disable=C0116 CONF_HTTP_MULTIPART = const(4) CONF_HTTP_MEM_CAP = const(5) CONF_HTTP_SERVED_PATHS = const(6) -CONF_HTTP_SERVE_FILES = 
const(7) +CONF_HTTP_FILES_API = const(7) CONF_SOCKET_MAX_CON = const(8) CONF_TLS = const(9) CONF_LOG_LEVEL = const(10) @@ -52,8 +52,8 @@ def const(n): # pylint: disable=C0116 0.1, CONF_HTTP_SERVED_PATHS, ["/www", "/lib/pyrobusta"], - CONF_HTTP_SERVE_FILES, - True, + CONF_HTTP_FILES_API, + False, CONF_SOCKET_MAX_CON, 2, CONF_TLS, @@ -70,7 +70,7 @@ def parse_config(key, value): """ Normalize a configuration value depending on the key. """ - if key in (CONF_HTTP_MULTIPART, CONF_HTTP_SERVE_FILES, CONF_TLS): + if key in (CONF_HTTP_MULTIPART, CONF_HTTP_FILES_API, CONF_TLS): return value.lower() == "true" if key in (CONF_HTTP_PORT, CONF_HTTPS_PORT, CONF_SOCKET_MAX_CON): return int(value) diff --git a/src/pyrobusta/utils/helpers.py b/src/pyrobusta/utils/helpers.py index b713309..e90f282 100644 --- a/src/pyrobusta/utils/helpers.py +++ b/src/pyrobusta/utils/helpers.py @@ -29,6 +29,56 @@ def normalize_path(path: str): return cwd +def is_norm_path_served(path: str, served_paths: list): + """ + Returns true if a normalized path is configured to be served. + :param path: path to check + :param served_paths: list of paths configured to be served + """ + parts = path.split("/") + for i, _ in enumerate(parts): + current_path = "/".join(parts[: i + 1]) + if not current_path: + current_path = "/" + if current_path in served_paths: + return True + return False + + +def is_file_path_valid(file_path: str): + """ + Returns true if an absolute file path is valid. + """ + if file_path[0] != "/": + return False + segment_start = 1 + + while True: + next_segment = file_path.find("/", segment_start + 1) + 1 + if not next_segment: + return is_path_segment_valid(file_path[segment_start:]) + if not is_path_segment_valid(file_path[segment_start : next_segment - 1]): + return False + segment_start = next_segment + + +def is_path_segment_valid(filename: str): + """ + Returns true if a filename is valid. 
+ """ + if ( + not filename + or len(filename) > 32 + or not all( + ("A" <= c <= "Z") or ("a" <= c <= "z") or ("0" <= c <= "9") or c in "._-" + for c in filename + ) + or filename in (".", "..") + ): + return False + return True + + def add_method(cls, func: callable, method_type="instance"): """ Helper to patch/extend classes with additional methods and states. diff --git a/tests/.pylintrc b/tests/.pylintrc index 5cf251e..10e5469 100644 --- a/tests/.pylintrc +++ b/tests/.pylintrc @@ -4,4 +4,5 @@ disable=W0212, C0115, C0116, R0904, - R0902 + R0902, + R0801 diff --git a/tests/functional/test_http.py b/tests/functional/test_http.py index 7618c9c..bcce7c4 100644 --- a/tests/functional/test_http.py +++ b/tests/functional/test_http.py @@ -3,12 +3,13 @@ import json import gc -from os import mkdir, remove, rmdir +from os import mkdir, remove, rmdir, stat, listdir from pyrobusta.server import http_server -from pyrobusta.protocol.http import HttpEngine, enable_optional_features +from pyrobusta.protocol.http import HttpEngine from pyrobusta.utils.config import ( CONF_HTTP_SERVED_PATHS, + CONF_HTTP_FILES_API, CONF_TLS, CONF_LOG_LEVEL, _CONFIG_CACHE, @@ -66,6 +67,30 @@ async def send_request(request, tls=False): return response +def fmkdir(path: str): + try: + mkdir(path) + except OSError: + pass + + +def delete_path(path): + for name in listdir(path): + if path == "/": + full = "/" + name + else: + full = path + "/" + name + + try: + remove(full) + except OSError: + delete_path(full) + try: + rmdir(full) + except OSError: + pass + + ################################################# # Test driver ################################################# @@ -111,48 +136,49 @@ async def test_simple_response(tls_enabled): setup_config(tls_enabled=tls_enabled) server, server_task = await start_server() - # Test: text/plain - plain_response = await send_request( - b"GET /test/simple HTTP/1.1\r\n" - b"Host: localhost\r\n" - b"Connection: close\r\n" - b"Accept:text/plain\r\n" - b"\r\n", - 
tls_enabled, - ) - test_assert( - f"http{"s" if tls_enabled else ""} response contains text/plain header", - b"text/plain" in plain_response, - True, - ) - test_assert( - f"http{"s" if tls_enabled else ""} response contains text/plain body", - b"Test response" in plain_response, - True, - ) - - # Test: application/json - json_response = await send_request( - b"GET /test/simple HTTP/1.1\r\n" - b"Host: localhost\r\n" - b"Connection: close\r\n" - b"Accept: application/json\r\n" - b"\r\n", - tls_enabled, - ) - test_assert( - f"http{"s" if tls_enabled else ""} response contains application/json header", - b"application/json" in json_response, - True, - ) - test_assert( - f"http{"s" if tls_enabled else ""} response contains application/json body", - b'{"response": "Test response"}' in json_response, - True, - ) - - server_task.cancel() - await server.terminate() + try: + # Test: text/plain + plain_response = await send_request( + b"GET /test/simple HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Connection: close\r\n" + b"Accept:text/plain\r\n" + b"\r\n", + tls_enabled, + ) + test_assert( + f"http{"s" if tls_enabled else ""} response contains text/plain header", + b"text/plain" in plain_response, + True, + ) + test_assert( + f"http{"s" if tls_enabled else ""} response contains text/plain body", + b"Test response" in plain_response, + True, + ) + + # Test: application/json + json_response = await send_request( + b"GET /test/simple HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Connection: close\r\n" + b"Accept: application/json\r\n" + b"\r\n", + tls_enabled, + ) + test_assert( + f"http{"s" if tls_enabled else ""} response contains application/json header", + b"application/json" in json_response, + True, + ) + test_assert( + f"http{"s" if tls_enabled else ""} response contains application/json body", + b'{"response": "Test response"}' in json_response, + True, + ) + finally: + server_task.cancel() + await server.terminate() @garbage_collect @@ -160,19 +186,20 @@ async def 
test_server_busy(): setup_config() server, server_task = await start_server() - plain_response = await send_request( - b"POST /test/busy HTTP/1.1\r\n" - b"Connection:close\r\n" - b"Host: localhost\r\n\r\n" - ) - test_assert( - f"response is rejected by busy service with 503", - b"503 Service Unavailable" in plain_response, - True, - ) - - server_task.cancel() - await server.terminate() + try: + plain_response = await send_request( + b"POST /test/busy HTTP/1.1\r\n" + b"Connection:close\r\n" + b"Host: localhost\r\n\r\n" + ) + test_assert( + f"response is rejected by busy service with 503", + b"503 Service Unavailable" in plain_response, + True, + ) + finally: + server_task.cancel() + await server.terminate() @garbage_collect @@ -181,84 +208,144 @@ async def test_chunked_transfer_encoding(): create_chunked_app_endpoint("/test/chunked") server, server_task = await start_server() - json_response = await send_request( - b"POST /test/chunked HTTP/1.1\r\n" - b"Host: localhost\r\n" - b"Connection: close\r\n" - b"Transfer-Encoding: chunked\r\n\r\n" - b"14\r\nchunking\r\ntest\r\ncase\r\n" - b"E\r\nchunking\r\ntest\r\n" - b"8\r\nchunking\r\n" - b"0\r\n\r\n" - ) - response_body = json.loads(json_response.split(b"\r\n\r\n")[1]) - test_assert( - f"chunked transfer encoding - all chunks are received", - response_body, - ["chunking\r\ntest\r\ncase", "chunking\r\ntest", "chunking"], - ) - - server_task.cancel() - await server.terminate() + try: + json_response = await send_request( + b"POST /test/chunked HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Connection: close\r\n" + b"Transfer-Encoding: chunked\r\n\r\n" + b"14\r\nchunking\r\ntest\r\ncase\r\n" + b"E\r\nchunking\r\ntest\r\n" + b"8\r\nchunking\r\n" + b"0\r\n\r\n" + ) + response_body = json.loads(json_response.split(b"\r\n\r\n")[1]) + test_assert( + f"chunked transfer encoding - all chunks are received", + response_body, + ["chunking\r\ntest\r\ncase", "chunking\r\ntest", "chunking"], + ) + finally: + server_task.cancel() + await 
server.terminate() @garbage_collect async def test_fs_access_control(): - setup_config(served_paths="/www/allowed") + setup_config(served_paths="/www/test/allowed") server, server_task = await start_server() - workdir_root = normalize_path("/www") - try: - mkdir(workdir_root) - except: - pass - - # Index page under /www -> accepted - allowed_workdir = normalize_path("/www/allowed") - allowed_index_html = normalize_path("/www/allowed/index.html") - mkdir(allowed_workdir) + www_root = normalize_path("/www") + test_root = normalize_path("/www/test") + fmkdir(www_root) + fmkdir(test_root) + + # Index page under /test -> accepted + allowed_workdir = normalize_path("/www/test/allowed") + allowed_index_html = normalize_path("/www/test/allowed/index.html") + fmkdir(allowed_workdir) with open(allowed_index_html, "w") as f: f.write("PyRobusta Home") # Index page under / -> rejected - rejected_workdir = normalize_path("/www/rejected") - rejected_index_html = normalize_path("/www/rejected/index.html") - mkdir(rejected_workdir) + rejected_workdir = normalize_path("/www/test/rejected") + rejected_index_html = normalize_path("/www/test/rejected/index.html") + fmkdir(rejected_workdir) with open(rejected_index_html, "w") as f: f.write("PyRobusta Home") - # Case #1: /www/index.html - response = await send_request( - b"GET /allowed/index.html HTTP/1.1\r\n" - b"Connection: close\r\n" - b"Host: localhost\r\n\r\n" - ) + try: + # Case #1: /test/allowed/index.html + response = await send_request( + b"GET /test/allowed/index.html HTTP/1.1\r\n" + b"Connection: close\r\n" + b"Host: localhost\r\n\r\n" + ) + + response_body = response.split(b"\r\n\r\n")[1] + test_assert( + f"test FS access control - index page loaded", + response_body, + b"PyRobusta Home", + ) + + # Case #2: /test/rejected/index.html + response = await send_request( + b"GET /test/rejected/index.html HTTP/1.1\r\n" + b"Connection: close\r\n" + b"Host: localhost\r\n\r\n" + ) + + test_assert( + f"test FS access control - index 
page rejected", + response.startswith(b"HTTP/1.1 403 Forbidden"), + True, + ) + finally: + delete_path(test_root) + server_task.cancel() + await server.terminate() - response_body = response.split(b"\r\n\r\n")[1] - test_assert( - f"test FS access control - index page loaded", - response_body, - b"PyRobusta Home", - ) - # Case #2: /index.html - response = await send_request( - b"GET /rejected/index.html HTTP/1.1\r\n" - b"Connection: close\r\n" - b"Host: localhost\r\n\r\n" - ) +@garbage_collect +async def test_fs_path_traversal(): + setup_config(served_paths="/test", files_api_enabled=True) + server, server_task = await start_server() + test_root = normalize_path("/test") + styles_dir = normalize_path("/test/style") + fmkdir(test_root) + fmkdir(styles_dir) - test_assert( - f"test FS access control - index page rejected", - response.startswith(b"HTTP/1.1 403 Forbidden"), - True, - ) + index_html = normalize_path("/test/index.html") + styles_css = normalize_path("/test/style/styles.css") - remove(allowed_index_html) - remove(rejected_index_html) - rmdir(allowed_workdir) - rmdir(rejected_workdir) - server_task.cancel() - await server.terminate() + with open(index_html, "w") as f: + f.write("PyRobusta Home") + with open(styles_css, "w") as f: + f.write("/* This is the main stylesheet */") + + try: + # Test case + response = await send_request( + b"GET /files/test HTTP/1.1\r\n" + b"Connection: close\r\n" + b"Host: localhost\r\n\r\n" + ) + + # Decode chunked transfer encoding + response_body = response.split(b"\r\n\r\n")[1] + response_body_decoded = b"" + start = 0 + + while start < len(response_body): + cursor = response_body.index(b"\r\n", start) + chunk_size = int(response_body[start:cursor], 16) + if chunk_size == 0: + break + chunk_start = cursor + 2 + chunk_end = chunk_start + chunk_size + response_body_decoded += response_body[chunk_start:chunk_end] + start = chunk_end + 2 + + test_assert( + f"test FS path traversal - JSON chunks received", + 
json.loads(response_body_decoded), + [ + { + "path": index_html, + "created": str(stat(index_html)[9]), + "size": str(stat(index_html)[6]), + }, + { + "path": styles_css, + "created": str(stat(styles_css)[9]), + "size": str(stat(styles_css)[6]), + }, + ], + ) + finally: + delete_path(test_root) + server_task.cancel() + await server.terminate() @garbage_collect @@ -266,83 +353,84 @@ async def test_keepalive(): setup_config() server, server_task = await start_server() - # ---------------------------------- - # Case 1: all requests are processed - # ---------------------------------- - plain_responses = await send_request( - b"GET /test/simple HTTP/1.1\r\n" - b"Host: localhost\r\n" - b"Accept:text/plain\r\n" - b"\r\n" - b"GET /test/simple HTTP/1.1\r\n" - b"Host: localhost\r\n" - b"Accept:text/plain\r\n" - b"\r\n" - b"GET /test/simple HTTP/1.1\r\n" - b"Host: localhost\r\n" - b"Connection: close\r\n" - b"Accept:text/plain\r\n" - b"\r\n" - ) - - test_assert( - f"contains all responses (connection: keep-alive)", - plain_responses.count(b"HTTP/1.1 200 OK"), - 3, - ) - - # ------------------------------------------------------------------- - # Case 2: close connection after the second request (invalid framing) - # ------------------------------------------------------------------- - plain_responses = await send_request( - b"GET /test/simple HTTP/1.1\r\n" - b"Host: localhost\r\n" - b"Accept:text/plain\r\n" - b"\r\n" - b"GET /test/simple HTTP/1.1\r\n" - b"Host: localhost\r\n" - b"" - b"Accept:text/plain\r\n" - b"\r\n" - b"GET /test/simple HTTP/1.1\r\n" - b"Host: localhost\r\n" - b"Accept:text/plain\r\n" - b"\r\n" - ) - - test_assert( - f"contains two responses (connection: keep-alive, invalid framing)", - plain_responses.count(b"HTTP/1.1"), - 2, - ) - - # ------------------------------------------------ - # Case 3: close connection after the first request - # ------------------------------------------------ - plain_response = await send_request( - b"GET /test/simple 
HTTP/1.1\r\n" - b"Host: localhost\r\n" - b"Connection: close\r\n" - b"Accept:text/plain\r\n" - b"\r\n" - b"GET /test/simple HTTP/1.1\r\n" - b"Host: localhost\r\n" - b"Accept:text/plain\r\n" - b"\r\n" - b"GET /test/simple HTTP/1.1\r\n" - b"Host: localhost\r\n" - b"Accept:text/plain\r\n" - b"\r\n" - ) - - test_assert( - f"contains single response (connection: close)", - plain_response.count(b"HTTP/1.1 200 OK"), - 1, - ) - - server_task.cancel() - await server.terminate() + try: + # ---------------------------------- + # Case 1: all requests are processed + # ---------------------------------- + plain_responses = await send_request( + b"GET /test/simple HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Accept:text/plain\r\n" + b"\r\n" + b"GET /test/simple HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Accept:text/plain\r\n" + b"\r\n" + b"GET /test/simple HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Connection: close\r\n" + b"Accept:text/plain\r\n" + b"\r\n" + ) + + test_assert( + f"contains all responses (connection: keep-alive)", + plain_responses.count(b"HTTP/1.1 200 OK"), + 3, + ) + + # ------------------------------------------------------------------- + # Case 2: close connection after the second request (invalid framing) + # ------------------------------------------------------------------- + plain_responses = await send_request( + b"GET /test/simple HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Accept:text/plain\r\n" + b"\r\n" + b"GET /test/simple HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"" + b"Accept:text/plain\r\n" + b"\r\n" + b"GET /test/simple HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Accept:text/plain\r\n" + b"\r\n" + ) + + test_assert( + f"contains two responses (connection: keep-alive, invalid framing)", + plain_responses.count(b"HTTP/1.1"), + 2, + ) + + # ------------------------------------------------ + # Case 3: close connection after the first request + # ------------------------------------------------ + plain_response = await send_request( + b"GET /test/simple 
HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Connection: close\r\n" + b"Accept:text/plain\r\n" + b"\r\n" + b"GET /test/simple HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Accept:text/plain\r\n" + b"\r\n" + b"GET /test/simple HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Accept:text/plain\r\n" + b"\r\n" + ) + + test_assert( + f"contains single response (connection: close)", + plain_response.count(b"HTTP/1.1 200 OK"), + 1, + ) + finally: + server_task.cancel() + await server.terminate() ################################################# @@ -350,7 +438,7 @@ async def test_keepalive(): ################################################# -def setup_config(tls_enabled=False, served_paths=""): +def setup_config(tls_enabled=False, served_paths="", files_api_enabled=False): http_server.HttpServer.LISTEN_PORT_HTTP = 8080 http_server.HttpServer.LISTEN_PORT_HTTPS = 4443 @@ -359,7 +447,7 @@ def setup_config(tls_enabled=False, served_paths=""): _CONFIG_CACHE[2 * CONF_HTTP_SERVED_PATHS + 1] = parse_config( CONF_HTTP_SERVED_PATHS, served_paths ) - enable_optional_features() + _CONFIG_CACHE[2 * CONF_HTTP_FILES_API + 1] = files_api_enabled def test_registration(): @@ -384,6 +472,7 @@ def test_main(): asyncio.run(test_server_busy()) asyncio.run(test_chunked_transfer_encoding()) asyncio.run(test_fs_access_control()) + asyncio.run(test_fs_path_traversal()) asyncio.run(test_keepalive()) diff --git a/tests/system/http_dimensioning/test.py b/tests/system/http_dimensioning/test.py index aa48adb..562440c 100644 --- a/tests/system/http_dimensioning/test.py +++ b/tests/system/http_dimensioning/test.py @@ -23,7 +23,7 @@ "socket_max_con": 1, "http_mem_cap": 0.05, "http_multipart": False, - "http_serve_files": True, + "http_files_api": False, "tls": False, "http_port": 8080, "https_port": 4443, @@ -343,7 +343,7 @@ def test_config_delta(device_name, device_ip, base_config, config_delta={}): "https_port", "http_served_paths", "log_level", - "http_serve_files", + "http_files_api", }, ) diff --git 
a/tests/unit/http_base.py b/tests/unit/http_base.py new file mode 100644 index 0000000..3f20c37 --- /dev/null +++ b/tests/unit/http_base.py @@ -0,0 +1,73 @@ +import os +import sys +import unittest + +from unittest.mock import patch, mock_open +from tests.unit.test_buffer import load_module + + +class TestHttpBase(unittest.TestCase): + """ + Base class for HTTP file server module. + """ + + @classmethod + def setUpClass(cls): + cls.base_config = {} + cls.cwd = os.getcwd() + + def setUp(self): + # ------------------------------- + # Patch current working directory + # ------------------------------- + self.helpers_module = load_module("pyrobusta/utils/helpers.py") + self.cwd_patcher = patch.object( + self.helpers_module, "getcwd", return_value=self.cwd + ) + self.cwd_patcher.start() + self.addCleanup(self.cwd_patcher.stop) + + # ------------------- + # Patch config module + # ------------------- + self.config = dict(self.base_config) + self.config_module = load_module("pyrobusta/utils/config.py") + self.module_patcher = patch.dict( + sys.modules, + {"pyrobusta.utils.config": self.config_module}, + ) + self.module_patcher.start() + self.addCleanup(self.module_patcher.stop) + + def open_side_effect(*args, **kwargs): + data = "\n".join(f"{k}={v}" for k, v in self.config.items()) + return mock_open(read_data=data)(*args, **kwargs) + + self.open_patcher = patch.object( + self.config_module, + "open", + side_effect=open_side_effect, + ) + self.open_patcher.start() + self.addCleanup(self.open_patcher.stop) + + # ------------------------------------------------ + # Load remaining modules, enable optional features + # ------------------------------------------------ + self.http_module = load_module("pyrobusta/protocol/http.py") + self.fs_module = load_module("pyrobusta/protocol/http_file_server.py") + self.multipart_module = load_module("pyrobusta/protocol/http_multipart.py") + + self.fs_patcher = patch.object(self.fs_module, "setup_directories") + self.fs_patcher.start() + 
self.addCleanup(self.fs_patcher.stop) + + self.http_module.enable_optional_features() + self.engine = self.http_module.HttpEngine() + + # -------------------- + # HTTP engine, buffers + # -------------------- + buffer_module = load_module("pyrobusta/stream/buffer.py") + self.rx = buffer_module.SlidingBuffer(bytearray(1024)) + self.tx = buffer_module.SlidingBuffer(bytearray(1024)) diff --git a/tests/unit/test_helpers.py b/tests/unit/test_helpers.py index 767c1e4..2d85271 100644 --- a/tests/unit/test_helpers.py +++ b/tests/unit/test_helpers.py @@ -56,3 +56,75 @@ def test_path_normalization_host_root(self, _): ("/path/../../resource/..", "/"), ): self.assertEqual(self.helpers_module.normalize_path(case[0]), case[1]) + + @patch("pyrobusta.utils.helpers.getcwd", return_value="/") + def test_path_serving_list(self, _): + served_paths = ["/path/to/dir1", "/path/to/dir2"] + + for case in ( + ("", False), + ("/", False), + ("/path/to/dir1", True), + ("/path/to/dir2", True), + ("/path/to/dir12", False), + ("/path/to/dir1/file", True), + ("/path/to/dir2/file", True), + ("/path/to/other", False), + ("/path/to", False), + ): + self.assertEqual( + self.helpers_module.is_norm_path_served(case[0], served_paths), case[1] + ) + + @patch("pyrobusta.utils.helpers.getcwd", return_value="/") + def test_path_serving_root(self, _): + served_paths = ["/"] + + for case in ( + ("", True), + ("/", True), + ("/path/to/served", True), + ): + self.assertEqual( + self.helpers_module.is_norm_path_served(case[0], served_paths), case[1] + ) + + @patch("pyrobusta.utils.helpers.getcwd", return_value="/") + def test_path_serving_none(self, _): + served_paths = [] + + for case in ( + ("", False), + ("/", False), + ("/path/to/served", False), + ): + self.assertEqual( + self.helpers_module.is_norm_path_served(case[0], served_paths), case[1] + ) + + def test_path_segment_validation(self): + valid_segments = ["file", "dir1", "dir-2", "dir_3", "file.ext", "a"] + invalid_segments = [ + "", + ".", + "..", + 
"dir/segment", + "dir\\segment", + "/dir/segment/file", + ] + + for segment in valid_segments: + self.assertTrue(self.helpers_module.is_path_segment_valid(segment)) + + for segment in invalid_segments: + self.assertFalse(self.helpers_module.is_path_segment_valid(segment)) + + def test_file_path_validation(self): + valid_paths = ["/file", "/dir1/file", "/dir-2/file", "/dir_3/file"] + invalid_paths = ["file", "dir1/file", "/dir\\segment/file"] + + for path in valid_paths: + self.assertTrue(self.helpers_module.is_file_path_valid(path)) + + for path in invalid_paths: + self.assertFalse(self.helpers_module.is_file_path_valid(path)) diff --git a/tests/unit/test_http.py b/tests/unit/test_http.py index aa39e43..8283b84 100644 --- a/tests/unit/test_http.py +++ b/tests/unit/test_http.py @@ -1,81 +1,105 @@ import os -import sys import unittest from unittest import mock from unittest.mock import patch, mock_open -from .utils import load_module, fake_stat +from .utils import stat_factory +from .http_base import TestHttpBase -class TestWebStateMachineBase(unittest.TestCase): +class TestWebStateMachineHelpers(TestHttpBase): """ - Base class for state machine tests. + Tests for state machine helper functions. 
""" @classmethod def setUpClass(cls): - cls.base_config = {} + cls.base_config = {"http_multipart": "False", "http_files_api": "False"} cls.cwd = os.getcwd() - def setUp(self): - # ------------------------------- - # Patch current working directory - # ------------------------------- - self.helpers_module = load_module("pyrobusta/utils/helpers.py") - self.cwd_patcher = patch.object( - self.helpers_module, "getcwd", return_value=self.cwd + def test_response_header_setter(self): + self.engine.url = b"/test" + self.engine.method = b"GET" + self.engine.version = b"HTTP/1.1" + + self.engine.set_response_header(b"content-type", b"application/json") + self.engine.set_response_header(b"transfer-encoding", b"chunked") + + self.assertEqual( + self.engine.get_response_header(b"content-type"), b"application/json" + ) + self.assertEqual( + self.engine.get_response_header(b"transfer-encoding"), b"chunked" ) - self.cwd_patcher.start() - self.addCleanup(self.cwd_patcher.stop) - - # ------------------- - # Patch config module - # ------------------- - self.config = dict(self.base_config) - self.config_module = load_module("pyrobusta/utils/config.py") - self.module_patcher = patch.dict( - sys.modules, - {"pyrobusta.utils.config": self.config_module}, + + def test_response_header_setter_override(self): + self.engine.url = b"/test" + self.engine.method = b"GET" + self.engine.version = b"HTTP/1.1" + + self.engine.set_response_header(b"content-type", b"application/json") + self.engine.set_response_header(b"connection", b"keep-alive") + + self.engine.set_response_header(b"content-type", b"text/plain") + self.engine.set_response_header(b"connection", b"close") + + self.assertEqual( + self.engine.get_response_header(b"content-type"), b"text/plain" ) - self.module_patcher.start() - self.addCleanup(self.module_patcher.stop) + self.assertEqual(self.engine.get_response_header(b"connection"), b"close") + + def test_generate_response_head(self): + self.engine.url = b"/test" + self.engine.method 
= b"GET" + self.engine.version = b"HTTP/1.1" + + self.engine.set_response_header(b"content-type", b"application/json") + self.engine.set_response_header(b"transfer-encoding", b"chunked") - def open_side_effect(*args, **kwargs): - data = "\n".join(f"{k}={v}" for k, v in self.config.items()) - return mock_open(read_data=data)(*args, **kwargs) + self.engine.terminate(200) - self.open_patcher = patch.object( - self.config_module, - "open", - side_effect=open_side_effect, + self.engine.write_response_head(self.tx) + self.assertEqual(bytes(self.tx.peek()).find(b"HTTP/1.1 200 OK\r\n"), 0) + self.assertNotEqual( + bytes(self.tx.peek()).find(b"\r\ncontent-type: application/json\r\n"), -1 ) - self.open_patcher.start() - self.addCleanup(self.open_patcher.stop) + self.assertNotEqual( + bytes(self.tx.peek()).find(b"\r\ntransfer-encoding: chunked\r\n"), -1 + ) + + def test_generate_response_head_override(self): + self.engine.url = b"/test" + self.engine.method = b"GET" + self.engine.version = b"HTTP/1.1" - # ------------------------------------------------ - # Load remaining modules, enable optional features - # ------------------------------------------------ - self.http_module = load_module("pyrobusta/protocol/http.py") - self.http_module.enable_optional_features() - self.engine = self.http_module.HttpEngine() + self.engine.set_response_header(b"content-type", b"application/json") + self.engine.set_response_header(b"date", b"Tue, 29 Feb 2026 15:32:12 GMT") + self.engine.terminate(200) - # -------------------- - # HTTP engine, buffers - # -------------------- - buffer_module = load_module("pyrobusta/stream/buffer.py") - self.rx = buffer_module.SlidingBuffer(bytearray(1024)) - self.tx = buffer_module.SlidingBuffer(bytearray(1024)) + self.engine.set_response_header(b"content-type", b"text/plain") + self.engine.set_response_header(b"date", b"Tue, 29 Feb 2026 15:32:13 GMT") + self.engine.terminate(400) + + self.engine.write_response_head(self.tx) + 
self.assertEqual(bytes(self.tx.peek()).find(b"HTTP/1.1 400 Bad Request\r\n"), 0) + self.assertNotEqual( + bytes(self.tx.peek()).find(b"\r\ncontent-type: text/plain\r\n"), -1 + ) + self.assertNotEqual( + bytes(self.tx.peek()).find(b"\r\ndate: Tue, 29 Feb 2026 15:32:13 GMT\r\n"), + -1, + ) -class TestWebStateMachine(TestWebStateMachineBase): +class TestWebStateMachine(TestHttpBase): """ Tests for the core functionality of the state machine. """ @classmethod def setUpClass(cls): - cls.base_config = {"http_multipart": "False", "http_serve_files": "False"} + cls.base_config = {"http_multipart": "False", "http_files_api": "False"} cls.cwd = os.getcwd() def test_status_parsing_valid(self): @@ -375,6 +399,12 @@ def test_url_path_matching(self): (b"/path/to/specific/resource", b"/path/to/{wildcard}/resource"), (b"anything", b"{wildcard}"), (b"path/to/resource", b"path/to/{wildcard}"), + (b"path/to/resource", b"path/to/{wildcard:path}"), + (b"path/to/resource/", b"path/to/{wildcard:path}"), + ( + b"path/to/resource/subresource/subsubresource", + b"path/to/{wildcard:path}", + ), ): self.assertEqual(self.engine._is_matching_url_path(case[0], case[1]), True) @@ -382,9 +412,12 @@ def test_url_path_matching_mismatch(self): for case in ( (b"", b"/"), (b"", b"{wildcard}"), + (b"", b"{wildcard:path}"), (b"/path/to/resource/subresource", b"/path/to/resource"), (b"/path/to/", b"/path/to/{wildcard}"), + (b"/path/to/", b"/path/to/{wildcard:path}"), (b"/to/resource", b"{wildcard}/to/resource"), + (b"path/to/resource/subresource/subsubresource", b"path/to/{wildcard}"), ): self.assertEqual(self.engine._is_matching_url_path(case[0], case[1]), False) @@ -455,208 +488,7 @@ def test_chunked_transfer_encoding_chunk_incomplete(self): self.assertEqual(self.engine.state, self.engine._recv_chunk_st) -class TestWebHelpers(TestWebStateMachineBase): - """ - Tests for helper functions. 
- """ - - @classmethod - def setUpClass(cls): - cls.base_config = {"http_multipart": "False", "http_serve_files": "True"} - # Simplify file-open assertions by treating resources - # as if they are installed at the root (/) rather than - # relative to the current working directory. - cls.cwd = "/" - - def test_path_serving_list(self): - self.config["http_served_paths"] = "/path/to/dir1 /path/to/dir2" - self.config_module.read_config() - self.assertEqual(self.engine.is_norm_path_served(""), False) - self.assertEqual(self.engine.is_norm_path_served("/"), False) - self.assertEqual(self.engine.is_norm_path_served("/path/to/dir1"), True) - self.assertEqual(self.engine.is_norm_path_served("/path/to/dir2"), True) - self.assertEqual(self.engine.is_norm_path_served("/path/to/dir12"), False) - self.assertEqual(self.engine.is_norm_path_served("/path/to/dir1/file"), True) - self.assertEqual(self.engine.is_norm_path_served("/path/to/dir2/file"), True) - self.assertEqual(self.engine.is_norm_path_served("/path/to/other"), False) - self.assertEqual(self.engine.is_norm_path_served("/path/to"), False) - - def test_path_serving_root(self): - self.config["http_served_paths"] = "/" - self.config_module.read_config() - self.assertEqual(self.engine.is_norm_path_served(""), True) - self.assertEqual(self.engine.is_norm_path_served("/"), True) - self.assertEqual(self.engine.is_norm_path_served("/path/to/served"), True) - - def test_path_serving_none(self): - self.config["http_served_paths"] = "" - self.config_module.read_config() - self.assertEqual(self.engine.is_norm_path_served(""), False) - self.assertEqual(self.engine.is_norm_path_served("/"), False) - self.assertEqual(self.engine.is_norm_path_served("/path/to/served"), False) - - -class TestMultipartStateMachine(TestWebStateMachineBase): - """ - Tests for multipart handling. 
- """ - - @classmethod - def setUpClass(cls): - cls.base_config = {"http_multipart": "True", "http_serve_files": "False"} - cls.cwd = os.getcwd() - - def test_multipart_parser(self): - for case in [ - ({}, None), - ( - {"content-type": 'multipart/form-data; boundary ="test-boundary"'}, - "test-boundary", - ), - ( - {"content-type": 'multipart/form-data; boundary =" test-boundary "'}, - " test-boundary ", - ), - ( - {"content-type": "multipart/form-data ;boundary= test-boundary "}, - "test-boundary", - ), - ( - {"content-type": "multipart/form-data;boundary=a test boundary "}, - "a test boundary", - ), - ]: - with self.subTest(headers=case[0], expected=case[1]): - self.assertEqual(self.engine._get_mp_boundary(case[0]), case[1]) - - for case in [ - {"content-type": "multipart/form-data"}, - {"content-type": 'multipart/form-data;boundary=""'}, - {"content-type": "multipart/form-data;boundary=\r\n"}, - {"content-type": 'multipart/form-data;boundary="missing-quote'}, - {"content-type": 'multipart/form-data;boundary=missing-quote"'}, - ]: - with self.subTest(headers=case): - with self.assertRaises(self.http_module.InvalidHeaders): - self.engine._get_mp_boundary(case) - - def test_multipart_receiver_valid(self): - self.engine.state = self.engine._start_multipart_parser_st - self.engine.headers["content-length"] = 100 - self.engine.mp_boundary = b"test-boundary" - body_part = b"--test-boundary\r\nContent-Type:text/plain" - - for i in range(len(body_part)): - self.rx.write(body_part[i : i + 1]) - self.engine.state(self.rx) - - self.assertEqual(self.engine.state, self.engine._parse_boundary_st) - self.assertEqual(self.rx.peek(), b"Content-Type:text/plain") - - def test_multipart_receiver_boundary_mismatch(self): - self.engine.state = self.engine._start_multipart_parser_st - self.engine.version = b"HTTP/1.1" - self.engine.headers["content-length"] = 100 - self.engine.mp_boundary = b"test-boundary" - body_part = b"--test-boundary-delimiter\r\nContent-Type:text/plain" - - with 
self.assertRaises(self.http_module.MalformedRequest): - for i in range(len(body_part)): - self.rx.write(body_part[i : i + 1]) - self.engine.state(self.rx) - - def test_multipart_receiver_complete_part(self): - self.engine.state = self.engine._parse_boundary_st - self.engine.url = b"/api/test" - self.engine.method = b"GET" - - test_callback = mock.Mock() - self.engine.register("/api/test", test_callback) - - self.engine.headers["content-length"] = 1000 - self.engine.mp_boundary = b"test-boundary" - self.engine.mp_delimiter = b"--test-boundary\r\n" - self.engine.mp_last_delimiter = b"--test-boundary--" - - body_part = ( - b'Content-Disposition:form-data;name="file-chunk";filename="upload.txt"\r\n' - b"Content-Type:text/plain\r\n\r\n" - b"Upload content\r\n" - b"--test-boundary\r\n" - ) - - for i in range(len(body_part)): - self.assertEqual(self.engine.state, self.engine._parse_boundary_st) - self.rx.write(body_part[i : i + 1]) - self.engine.state(self.rx) - - self.assertEqual(self.engine.state, self.engine._parse_complete_part_st) - self.assertEqual(self.rx.peek(), body_part) - self.assertEqual(self.engine.mp_is_first, True) - - self.engine.state(self.rx) - - self.assertEqual(self.engine.state, self.engine._parse_boundary_st) - test_callback.assert_called_once_with( - self.engine, - ( - { - "content-disposition": 'form-data;name="file-chunk";filename="upload.txt"', - "content-type": "text/plain", - }, - b"Upload content", - ), - ) - self.assertEqual(self.engine.mp_is_first, False) - self.assertEqual(self.engine.mp_is_last, False) - - def test_multipart_receiver_last_part(self): - self.engine.state = self.engine._parse_boundary_st - self.engine.url = b"/api/test" - self.engine.method = b"GET" - self.engine.version = b"HTTP/1.1" - self.engine.headers["content-length"] = 131 - self.engine.mp_boundary = b"test-boundary" - self.engine.mp_delimiter = b"--test-boundary\r\n" - self.engine.mp_last_delimiter = b"--test-boundary--" - - test_callback = 
mock.Mock(return_value=("text/plain", "OK")) - self.engine.register("/api/test", test_callback) - - body_part = ( - b'Content-Disposition:form-data;name="file-chunk";filename="upload.txt"\r\n' - b"Content-Type:text/plain\r\n\r\n" - b"Upload content\r\n" - b"--test-boundary--" - ) - - for i in range(len(body_part)): - self.assertEqual(self.engine.state, self.engine._parse_boundary_st) - self.rx.write(body_part[i : i + 1]) - self.engine.state(self.rx) - - self.assertEqual(self.engine.state, self.engine._parse_complete_part_st) - self.assertEqual(self.rx.peek(), body_part) - - self.engine.state(self.rx) - - self.assertEqual(self.engine.state, None) - self.assertEqual(self.engine.status_code, 200) - test_callback.assert_called_once_with( - self.engine, - ( - { - "content-disposition": 'form-data;name="file-chunk";filename="upload.txt"', - "content-type": "text/plain", - }, - b"Upload content", - ), - ) - self.assertEqual(self.engine.mp_is_first, True) - self.assertEqual(self.engine.mp_is_last, True) - - -class TestFileServerStateMachine(TestWebStateMachineBase): +class TestFileServingStateMachine(TestHttpBase): """ Tests for file serving. 
""" @@ -665,7 +497,9 @@ class TestFileServerStateMachine(TestWebStateMachineBase): def setUpClass(cls): cls.base_config = { "http_multipart": "False", - "http_serve_files": "True", + # Only built-in file serving is tested here, + # so disable /files endpoint to avoid conflicts + "http_files_api": "False", "http_served_paths": "/www", } # Simplify file-open assertions by treating resources @@ -674,67 +508,70 @@ def setUpClass(cls): cls.cwd = "/" @staticmethod - def patch_all(f): - @patch("pyrobusta.protocol.http_file_server.stat", fake_stat) - def decorated(*args, **kwargs): - return f(*args, **kwargs) + def patch_os_stat(stat_is_file=True): + def patched(f): + @patch("pyrobusta.protocol.http.stat", stat_factory(stat_is_file)) + def decorated(*args, **kwargs): + return f(*args, **kwargs) - return decorated + return decorated - @patch_all - def test_file_serving_missing_file(self, *_): - self.engine.url = b"/files/www/nonexistent.js" - self.engine.method = b"GET" - self.engine.version = b"HTTP/1.1" - self.engine.state = self.engine._send_file_st - - self.engine.state(self.rx, self.engine.url) + return patched - self.assertEqual(self.engine.status_code, 404) - self.assertEqual(self.engine.state, None) - - @patch_all + @patch_os_stat() def test_file_serving_root(self, *_): self.engine.url = b"/" self.engine.method = b"GET" self.engine.version = b"HTTP/1.1" - self.engine.state = self.engine._send_file_st + self.engine.state = self.engine._fs_retrieve_st file_content = "index content" with patch("builtins.open", mock_open(read_data=file_content)) as m: - self.engine.state(self.rx, self.engine.url) + self.engine.state(self.rx) m.assert_called_once_with("/www/index.html", "rb") self.assertEqual(self.engine.resp_handler.read(), file_content) self.assertEqual(self.engine.status_code, 200) self.assertEqual(self.engine.state, None) - @patch_all - def test_file_serving_files_endpoint(self, *_): - self.engine.url = b"/files/www/scripts.js" + @patch_os_stat() + def 
test_file_serving_subdir(self, *_): + self.engine.url = b"/style/styles.css" self.engine.method = b"GET" self.engine.version = b"HTTP/1.1" - self.engine.state = self.engine._send_file_st - file_content = "data" + self.engine.state = self.engine._fs_retrieve_st + file_content = "/* Main styelsheet */" with patch("builtins.open", mock_open(read_data=file_content)) as m: - self.engine.state(self.rx, self.engine.url) - m.assert_called_once_with("/www/scripts.js", "rb") + self.engine.state(self.rx) + m.assert_called_once_with("/www/style/styles.css", "rb") self.assertEqual(self.engine.resp_handler.read(), file_content) self.assertEqual(self.engine.status_code, 200) self.assertEqual(self.engine.state, None) - @patch_all + @patch_os_stat() + def test_file_serving_missing_file(self, *_): + self.engine.url = b"/nonexistent.js" + self.engine.method = b"GET" + self.engine.version = b"HTTP/1.1" + self.engine.state = self.engine._fs_retrieve_st + + self.engine.state(self.rx) + + self.assertEqual(self.engine.status_code, 404) + self.assertEqual(self.engine.state, None) + + @patch_os_stat() def test_file_serving_known_content_type(self, *_): self.engine.url = b"/scripts.js" self.engine.method = b"GET" self.engine.version = b"HTTP/1.1" - self.engine.state = self.engine._send_file_st + self.engine.state = self.engine._fs_retrieve_st file_content = "data" with patch("builtins.open", mock_open(read_data=file_content)) as m: - self.engine.state(self.rx, self.engine.url) + self.engine.state(self.rx) m.assert_called_once_with("/www/scripts.js", "rb") self.assertEqual(self.engine.resp_handler.read(), file_content) @@ -745,16 +582,16 @@ def test_file_serving_known_content_type(self, *_): self.assertEqual(self.engine.status_code, 200) self.assertEqual(self.engine.state, None) - @patch_all + @patch_os_stat() def test_file_serving_fallback_content_type(self, *_): self.engine.url = b"/scripts.unknown" self.engine.method = b"GET" self.engine.version = b"HTTP/1.1" - self.engine.state = 
self.engine._send_file_st + self.engine.state = self.engine._fs_retrieve_st file_content = "data" with patch("builtins.open", mock_open(read_data=file_content)) as m: - self.engine.state(self.rx, self.engine.url) + self.engine.state(self.rx) m.assert_called_once_with("/www/scripts.unknown", "rb") self.assertEqual(self.engine.resp_handler.read(), file_content) @@ -765,22 +602,6 @@ def test_file_serving_fallback_content_type(self, *_): self.assertEqual(self.engine.status_code, 200) self.assertEqual(self.engine.state, None) - @patch_all - def test_file_serving_unserved_content_rejected(self, *_): - self.engine.url = b"/files/unserved/script.js" - self.engine.method = b"GET" - self.engine.version = b"HTTP/1.1" - self.engine.state = self.engine._send_file_st - file_content = "data" - - with patch("builtins.open", mock_open(read_data=file_content)) as m: - self.engine.state(self.rx, self.engine.url) - m.assert_not_called() - - self.assertEqual(self.engine.resp_handler, None) - self.assertEqual(self.engine.status_code, 403) - self.assertEqual(self.engine.state, None) - if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/tests/unit/test_http_file_server.py b/tests/unit/test_http_file_server.py new file mode 100644 index 0000000..a6402ff --- /dev/null +++ b/tests/unit/test_http_file_server.py @@ -0,0 +1,599 @@ +import json + +from unittest.mock import patch, mock_open, call + +from .utils import stat_factory +from .http_base import TestHttpBase + + +def patch_os_stat(stat_is_file=True): + def patched(f): + @patch("pyrobusta.protocol.http_file_server.stat", stat_factory(stat_is_file)) + def decorated(*args, **kwargs): + return f(*args, **kwargs) + + return decorated + + return patched + + +class TestFileServerRetrieve(TestHttpBase): + """ + Tests for GET /files/. 
+ """ + + @classmethod + def setUpClass(cls): + cls.base_config = { + "http_multipart": "False", + "http_files_api": "True", + "http_served_paths": "/www", + } + # Simplify file-open assertions by treating resources + # as if they are installed at the root (/) rather than + # relative to the current working directory. + cls.cwd = "/" + + @patch_os_stat() + def test_file_serving_missing_file(self, *_): + self.engine.url = b"/files/www/nonexistent.js" + self.engine.method = b"GET" + self.engine.version = b"HTTP/1.1" + self.engine.state = self.engine._app_endpoint_st + + self.engine.state(self.rx) + + self.assertEqual(self.engine.status_code, 404) + self.assertEqual(self.engine.state, None) + + @patch_os_stat() + def test_file_serving_files_endpoint(self, *_): + self.engine.url = b"/files/www/scripts.js" + self.engine.method = b"GET" + self.engine.version = b"HTTP/1.1" + self.engine.state = self.engine._app_endpoint_st + file_content = "data" + + with patch("builtins.open", mock_open(read_data=file_content)) as m: + self.engine.state(self.rx) + m.assert_called_once_with("/www/scripts.js", "rb") + + self.assertEqual(self.engine.resp_handler.read(), file_content) + self.assertEqual(self.engine.status_code, 200) + self.assertEqual(self.engine.state, None) + + @patch_os_stat() + def test_file_serving_known_content_type(self, *_): + self.engine.url = b"/files/www/scripts.js" + self.engine.method = b"GET" + self.engine.version = b"HTTP/1.1" + self.engine.state = self.engine._app_endpoint_st + file_content = "data" + + with patch("builtins.open", mock_open(read_data=file_content)) as m: + self.engine.state(self.rx) + m.assert_called_once_with("/www/scripts.js", "rb") + + self.assertEqual(self.engine.resp_handler.read(), file_content) + self.assertEqual( + self.engine._lookup(self.engine.resp_headers, b"content-type"), + b"application/javascript", + ) + self.assertEqual(self.engine.status_code, 200) + self.assertEqual(self.engine.state, None) + + @patch_os_stat() + def 
test_file_serving_fallback_content_type(self, *_): + self.engine.url = b"/files/www/scripts.unknown" + self.engine.method = b"GET" + self.engine.version = b"HTTP/1.1" + self.engine.state = self.engine._app_endpoint_st + file_content = "data" + + with patch("builtins.open", mock_open(read_data=file_content)) as m: + self.engine.state(self.rx) + m.assert_called_once_with("/www/scripts.unknown", "rb") + + self.assertEqual(self.engine.resp_handler.read(), file_content) + self.assertEqual( + self.engine._lookup(self.engine.resp_headers, b"content-type"), + b"application/octet-stream", + ) + self.assertEqual(self.engine.status_code, 200) + self.assertEqual(self.engine.state, None) + + @patch_os_stat() + def test_file_serving_unserved_content_rejected(self, *_): + self.engine.url = b"/files/unserved/script.js" + self.engine.method = b"GET" + self.engine.version = b"HTTP/1.1" + self.engine.state = self.engine._app_endpoint_st + file_content = "data" + + with patch("builtins.open", mock_open(read_data=file_content)) as m: + self.engine.state(self.rx) + m.assert_not_called() + + self.assertNotEqual(self.engine.resp_handler.read(), file_content) + self.assertEqual(self.engine.status_code, 403) + self.assertEqual(self.engine.state, None) + + @patch_os_stat(stat_is_file=False) + def test_file_serving_directory_path(self): + self.engine.url = b"/files/www" + self.engine.method = b"GET" + self.engine.version = b"HTTP/1.1" + self.engine.state = self.engine._app_endpoint_st + + self.engine.state(self.rx) + + self.assertNotEqual(self.engine.resp_handler, None) + self.assertEqual( + self.engine._lookup(self.engine.resp_headers, b"content-type"), + b"application/json", + ) + self.assertEqual(self.engine.status_code, 200) + self.assertEqual(self.engine.state, None) + + @patch_os_stat() + def test_file_serving_directory_traversal(self): + directory_content = ["file1", "file2", "file3"] + with patch( + "pyrobusta.protocol.http_file_server.iterate_fs", + lambda *_: directory_content, + ): 
+ traversal_function = self.fs_module._traverse_dir_factory("/www") + for is_finished in traversal_function(self.tx): + if is_finished: + break + + response = json.loads(bytes(self.tx.peek())) + response_files = [it["path"] for it in response] + self.assertListEqual(response_files, directory_content) + + +class TestFileServerDelete(TestHttpBase): + """ + Tests for DELETE /files/. + """ + + @classmethod + def setUpClass(cls): + cls.base_config = { + "http_multipart": "False", + "http_files_api": "True", + "http_served_paths": "/www", + } + # Simplify file-open assertions by treating resources + # as if they are installed at the root (/) rather than + # relative to the current working directory. + cls.cwd = "/" + + @patch_os_stat() + def test_file_serving_missing_file(self, *_): + self.engine.url = b"/files/www/user_data/nonexistent.js" + self.engine.method = b"DELETE" + self.engine.version = b"HTTP/1.1" + self.engine.state = self.engine._app_endpoint_st + + self.engine.state(self.rx) + + self.assertEqual(self.engine.status_code, 404) + self.assertEqual(self.engine.state, None) + + @patch_os_stat() + def test_file_serving_non_user_data_rejected(self, *_): + self.engine.url = b"/files/www/index.html" + self.engine.method = b"DELETE" + self.engine.version = b"HTTP/1.1" + self.engine.state = self.engine._app_endpoint_st + + self.engine.state(self.rx) + + self.assertEqual(self.engine.status_code, 403) + self.assertEqual(self.engine.state, None) + + @patch_os_stat() + def test_file_serving_user_data_deleted(self, *_): + self.engine.url = b"/files/www/user_data/user_content.json" + self.engine.method = b"DELETE" + self.engine.version = b"HTTP/1.1" + self.engine.state = self.engine._app_endpoint_st + + with patch("pyrobusta.protocol.http_file_server.remove") as m: + self.engine.state(self.rx) + m.assert_called_once_with("/www/user_data/user_content.json") + + self.assertEqual(self.engine.status_code, 204) + self.assertEqual(self.engine.state, None) + + 
@patch_os_stat(stat_is_file=True) + def test_file_serving_user_directory_deleted(self, *_): + self.engine.url = b"/files/www/user_data/user_dir" + self.engine.method = b"DELETE" + self.engine.version = b"HTTP/1.1" + self.engine.state = self.engine._app_endpoint_st + + with patch("pyrobusta.protocol.http_file_server.remove") as m: + self.engine.state(self.rx) + m.assert_called_once_with("/www/user_data/user_dir") + + self.assertEqual(self.engine.status_code, 204) + self.assertEqual(self.engine.state, None) + + +class TestFileServerUpload(TestHttpBase): + """ + Tests for POST /files/. + """ + + @classmethod + def setUpClass(cls): + cls.base_config = { + "http_multipart": "False", + "http_files_api": "True", + "http_served_paths": "/www", + } + # Simplify file-open assertions by treating resources + # as if they are installed at the root (/) rather than + # relative to the current working directory. + cls.cwd = "/" + + def test_file_serving_complete_file_upload(self, *_): + self.engine.url = b"/files/www/user_data/complete.txt" + self.engine.method = b"PUT" + self.engine.version = b"HTTP/1.1" + + self.engine.headers["content-length"] = 28 + self.engine.headers["content-type"] = "application/octet-stream" + + self.engine.state = self.engine._app_endpoint_st + body_part = b"File uploaded for testing.\r\n" + self.rx.write(body_part) + + with patch("builtins.open", mock_open(read_data=body_part)) as open_mock: + while self.engine.state is not None: + self.engine.state(self.rx) + + open_mock.assert_called_once_with("/www/user_data/complete.txt", "wb") + + self.assertEqual(self.engine.status_code, 201) + + def test_file_serving_complete_file_invalid_name(self, *_): + self.engine.url = b"/files/www/user_data/$test.txt" + self.engine.method = b"PUT" + self.engine.version = b"HTTP/1.1" + + self.engine.headers["content-length"] = 28 + self.engine.headers["content-type"] = "application/octet-stream" + + self.engine.state = self.engine._app_endpoint_st + body_part = b"File 
uploaded for testing.\r\n" + self.rx.write(body_part) + + while self.engine.state is not None: + self.engine.state(self.rx) + + self.assertEqual(self.engine.status_code, 400) + + def test_file_serving_chunked_file_upload(self, *_): + self.engine.url = b"/files/www/user_data/chunked.txt" + self.engine.method = b"PUT" + self.engine.version = b"HTTP/1.1" + + self.engine.headers["transfer-encoding"] = "chunked" + self.engine.headers["content-type"] = "application/octet-stream" + + self.engine.state = self.engine._recv_chunk_size_st + body_part = ( + b"14\r\nchunking\r\ntest\r\ncase\r\n" + b"E\r\nchunking\r\ntest\r\n" + b"8\r\nchunking\r\n" + b"0\r\n\r\n" + ) + self.rx.write(body_part) + + m = mock_open() + with patch("builtins.open", m) as open_mock, patch( + "pyrobusta.protocol.http_file_server.rename" + ) as rename_mock: + while self.engine.state is not None: + self.engine.state(self.rx) + + open_mock.assert_has_calls( + [ + call(f"/tmp/chunked.txt.{self.engine.id}", "ab"), + call(f"/tmp/chunked.txt.{self.engine.id}", "ab"), + call(f"/tmp/chunked.txt.{self.engine.id}", "ab"), + ], + any_order=True, + ) + + rename_mock.assert_called_once_with( + f"/tmp/chunked.txt.{self.engine.id}", + "/www/user_data/chunked.txt", + ) + + handle = m() + + handle.write.assert_has_calls( + [ + call(b"chunking\r\ntest\r\ncase"), + call(b"chunking\r\ntest"), + call(b"chunking"), + ] + ) + + self.assertEqual(handle.write.call_count, 3) + + self.assertEqual(self.engine.status_code, 201) + + +class TestFileServerBulkUpload(TestHttpBase): + """ + Tests for POST /files/. + """ + + @classmethod + def setUpClass(cls): + cls.base_config = { + # Bulk upload requires multipart handling + "http_multipart": "True", + "http_files_api": "True", + "http_served_paths": "/www", + } + # Simplify file-open assertions by treating resources + # as if they are installed at the root (/) rather than + # relative to the current working directory. 
+ cls.cwd = "/" + + def test_file_serving_single_file_upload(self, *_): + self.engine.url = b"/files" + self.engine.method = b"POST" + self.engine.version = b"HTTP/1.1" + + self.engine.headers["content-length"] = 151 + self.engine.headers["content-type"] = "multipart/form-data" + + self.engine.mp_boundary = b"test-boundary" + self.engine.mp_delimiter = b"--test-boundary\r\n" + self.engine.mp_last_delimiter = b"--test-boundary--" + + self.engine.state = self.engine._start_multipart_parser_st + body_part = ( + b"--test-boundary\r\n" + b'Content-Disposition:form-data;name="complete-file";filename="upload.txt"\r\n' + b"Content-Type:text/plain\r\n\r\n" + b"Upload content\r\n" + b"--test-boundary--" + ) + self.rx.write(body_part) + + with patch( + "pyrobusta.protocol.http_file_server.listdir", + lambda _: [f"upload.txt.{self.engine.id}"], + ), patch("pyrobusta.protocol.http_file_server.remove") as remove_mock, patch( + "builtins.open" + ) as open_mock, patch( + "pyrobusta.protocol.http_file_server.rename" + ) as rename_mock: + while self.engine.state is not None: + self.engine.state(self.rx) + + # stale upload is removed + remove_mock.assert_called_once_with(f"/tmp/upload.txt.{self.engine.id}") + + # file is opened in append mode to allow for chunked uploads, even if the complete + # file is sent in a single part + open_mock.assert_called_once_with(f"/tmp/upload.txt.{self.engine.id}", "ab") + + # file is renamed to final destination after upload completion + rename_mock.assert_called_once_with( + f"/tmp/upload.txt.{self.engine.id}", "/www/user_data/upload.txt" + ) + + self.assertEqual(remove_mock.call_count, 1) + self.assertEqual(open_mock.call_count, 1) + self.assertEqual(rename_mock.call_count, 1) + + self.assertEqual(self.engine.status_code, 201) + + def test_file_serving_multiple_file_upload(self, *_): + self.engine.url = b"/files" + self.engine.method = b"POST" + self.engine.version = b"HTTP/1.1" + + self.engine.headers["content-length"] = 287 + 
self.engine.headers["content-type"] = "multipart/form-data" + + self.engine.mp_boundary = b"test-boundary" + self.engine.mp_delimiter = b"--test-boundary\r\n" + self.engine.mp_last_delimiter = b"--test-boundary--" + + self.engine.state = self.engine._start_multipart_parser_st + body_part = ( + b"--test-boundary\r\n" + b'Content-Disposition:form-data;name="complete-file";filename="upload1.txt"\r\n' + b"Content-Type:text/plain\r\n\r\n" + b"Upload content\r\n" + b"--test-boundary\r\n" + b'Content-Disposition:form-data;name="complete-file";filename="upload2.txt"\r\n' + b"Content-Type:text/plain\r\n\r\n" + b"Upload content\r\n" + b"--test-boundary--" + ) + self.rx.write(body_part) + + with patch( + "pyrobusta.protocol.http_file_server.listdir", + lambda _: [ + f"upload1.txt.{self.engine.id}", + f"upload2.txt.{self.engine.id}", + ], + ), patch("pyrobusta.protocol.http_file_server.remove") as remove_mock, patch( + "builtins.open" + ) as open_mock, patch( + "pyrobusta.protocol.http_file_server.rename" + ) as rename_mock: + while self.engine.state is not None: + self.engine.state(self.rx) + + remove_mock_calls = [ + call(f"/tmp/upload1.txt.{self.engine.id}"), + call(f"/tmp/upload2.txt.{self.engine.id}"), + ] + open_mock_calls = [ + call(f"/tmp/upload1.txt.{self.engine.id}", "ab"), + call(f"/tmp/upload2.txt.{self.engine.id}", "ab"), + ] + rename_mock_calls = [ + call( + f"/tmp/upload1.txt.{self.engine.id}", "/www/user_data/upload1.txt" + ), + call( + f"/tmp/upload2.txt.{self.engine.id}", "/www/user_data/upload2.txt" + ), + ] + + remove_mock.assert_has_calls(remove_mock_calls) + open_mock.assert_has_calls(open_mock_calls, any_order=True) + rename_mock.assert_has_calls(rename_mock_calls) + + self.assertEqual(remove_mock.call_count, 2) + self.assertEqual(open_mock.call_count, 2) + self.assertEqual(rename_mock.call_count, 2) + + self.assertEqual(self.engine.status_code, 201) + + def test_file_serving_single_file_multiple_parts_upload(self, *_): + self.engine.url = b"/files" + 
self.engine.method = b"POST" + self.engine.version = b"HTTP/1.1" + + self.engine.headers["content-length"] = 285 + self.engine.headers["content-type"] = "multipart/form-data" + + self.engine.mp_boundary = b"test-boundary" + self.engine.mp_delimiter = b"--test-boundary\r\n" + self.engine.mp_last_delimiter = b"--test-boundary--" + + self.engine.state = self.engine._start_multipart_parser_st + body_part = ( + b"--test-boundary\r\n" + b'Content-Disposition:form-data;name="file-chunk-1";filename="upload1.txt"\r\n' + b"Content-Type:text/plain\r\n\r\n" + b"Upload content\r\n" + b"--test-boundary\r\n" + b'Content-Disposition:form-data;name="file-chunk-2";filename="upload1.txt"\r\n' + b"Content-Type:text/plain\r\n\r\n" + b"Upload content\r\n" + b"--test-boundary--" + ) + self.rx.write(body_part) + + with patch( + "pyrobusta.protocol.http_file_server.listdir", + lambda _: [ + f"upload1.txt.{self.engine.id}", + ], + ), patch("pyrobusta.protocol.http_file_server.remove") as remove_mock, patch( + "builtins.open" + ) as open_mock, patch( + "pyrobusta.protocol.http_file_server.rename" + ) as rename_mock: + while self.engine.state is not None: + self.engine.state(self.rx) + + open_mock_calls = [ + call(f"/tmp/upload1.txt.{self.engine.id}", "ab"), + call(f"/tmp/upload1.txt.{self.engine.id}", "ab"), + ] + + remove_mock.assert_called_once_with(f"/tmp/upload1.txt.{self.engine.id}") + open_mock.assert_has_calls(open_mock_calls, any_order=True) + rename_mock.assert_called_once_with( + f"/tmp/upload1.txt.{self.engine.id}", "/www/user_data/upload1.txt" + ) + + self.assertEqual(remove_mock.call_count, 1) + self.assertEqual(open_mock.call_count, 2) + self.assertEqual(rename_mock.call_count, 1) + + self.assertEqual(self.engine.status_code, 201) + + def test_file_serving_multiple_file_chunked_upload(self, *_): + self.engine.url = b"/files" + self.engine.method = b"POST" + self.engine.version = b"HTTP/1.1" + + self.engine.headers["content-length"] = 548 + self.engine.headers["content-type"] = 
"multipart/form-data" + + self.engine.mp_boundary = b"test-boundary" + self.engine.mp_delimiter = b"--test-boundary\r\n" + self.engine.mp_last_delimiter = b"--test-boundary--" + + self.engine.state = self.engine._start_multipart_parser_st + body_part = ( + b"--test-boundary\r\n" + b'Content-Disposition:form-data;name="file-chunk-1";filename="upload1.txt"\r\n' + b"Content-Type:text/plain\r\n\r\n" + b"Upload content #1\r\n" + b"--test-boundary\r\n" + b'Content-Disposition:form-data;name="file-chunk-2";filename="upload2.txt"\r\n' + b"Content-Type:text/plain\r\n\r\n" + b"Upload content #1\r\n" + b"--test-boundary\r\n" + b'Content-Disposition:form-data;name="file-chunk-3";filename="upload1.txt"\r\n' + b"Content-Type:text/plain\r\n\r\n" + b"Upload content #2\r\n" + b"--test-boundary\r\n" + b'Content-Disposition:form-data;name="file-chunk-4";filename="upload2.txt"\r\n' + b"Content-Type:text/plain\r\n\r\n" + b"Upload content #2\r\n" + b"--test-boundary--" + ) + self.rx.write(body_part) + + with patch( + "pyrobusta.protocol.http_file_server.listdir", + lambda _: [ + f"upload1.txt.{self.engine.id}", + f"upload2.txt.{self.engine.id}", + ], + ), patch("pyrobusta.protocol.http_file_server.remove") as remove_mock, patch( + "builtins.open" + ) as open_mock, patch( + "pyrobusta.protocol.http_file_server.rename" + ) as rename_mock: + while self.engine.state is not None: + self.engine.state(self.rx) + + remove_mock_calls = [ + call(f"/tmp/upload1.txt.{self.engine.id}"), + call(f"/tmp/upload2.txt.{self.engine.id}"), + ] + open_mock_calls = [ + call(f"/tmp/upload1.txt.{self.engine.id}", "ab"), + call(f"/tmp/upload2.txt.{self.engine.id}", "ab"), + call(f"/tmp/upload1.txt.{self.engine.id}", "ab"), + call(f"/tmp/upload2.txt.{self.engine.id}", "ab"), + ] + rename_mock_calls = [ + call( + f"/tmp/upload1.txt.{self.engine.id}", "/www/user_data/upload1.txt" + ), + call( + f"/tmp/upload2.txt.{self.engine.id}", "/www/user_data/upload2.txt" + ), + ] + + 
remove_mock.assert_has_calls(remove_mock_calls) + open_mock.assert_has_calls(open_mock_calls, any_order=True) + rename_mock.assert_has_calls(rename_mock_calls) + + self.assertEqual(remove_mock.call_count, 2) + self.assertEqual(open_mock.call_count, 4) + self.assertEqual(rename_mock.call_count, 2) + + self.assertEqual(self.engine.status_code, 201) diff --git a/tests/unit/test_http_multipart.py b/tests/unit/test_http_multipart.py new file mode 100644 index 0000000..2108f12 --- /dev/null +++ b/tests/unit/test_http_multipart.py @@ -0,0 +1,170 @@ +import os +import unittest + +from unittest import mock + +from .http_base import TestHttpBase + + +class TestMultipartStateMachine(TestHttpBase): + """ + Tests for multipart handling. + """ + + @classmethod + def setUpClass(cls): + cls.base_config = {"http_multipart": "True", "http_files_api": "False"} + cls.cwd = os.getcwd() + + def test_multipart_parser(self): + for case in [ + ({}, None), + ( + {"content-type": 'multipart/form-data; boundary ="test-boundary"'}, + "test-boundary", + ), + ( + {"content-type": 'multipart/form-data; boundary =" test-boundary "'}, + " test-boundary ", + ), + ( + {"content-type": "multipart/form-data ;boundary= test-boundary "}, + "test-boundary", + ), + ( + {"content-type": "multipart/form-data;boundary=a test boundary "}, + "a test boundary", + ), + ]: + with self.subTest(headers=case[0], expected=case[1]): + self.assertEqual(self.engine._get_mp_boundary(case[0]), case[1]) + + for case in [ + {"content-type": "multipart/form-data"}, + {"content-type": 'multipart/form-data;boundary=""'}, + {"content-type": "multipart/form-data;boundary=\r\n"}, + {"content-type": 'multipart/form-data;boundary="missing-quote'}, + {"content-type": 'multipart/form-data;boundary=missing-quote"'}, + ]: + with self.subTest(headers=case): + with self.assertRaises(self.http_module.InvalidHeaders): + self.engine._get_mp_boundary(case) + + def test_multipart_receiver_valid(self): + self.engine.state = 
self.engine._start_multipart_parser_st + self.engine.headers["content-length"] = 100 + self.engine.mp_boundary = b"test-boundary" + body_part = b"--test-boundary\r\nContent-Type:text/plain" + + for i in range(len(body_part)): + self.rx.write(body_part[i : i + 1]) + self.engine.state(self.rx) + + self.assertEqual(self.engine.state, self.engine._parse_boundary_st) + self.assertEqual(self.rx.peek(), b"Content-Type:text/plain") + + def test_multipart_receiver_boundary_mismatch(self): + self.engine.state = self.engine._start_multipart_parser_st + self.engine.version = b"HTTP/1.1" + self.engine.headers["content-length"] = 100 + self.engine.mp_boundary = b"test-boundary" + body_part = b"--test-boundary-delimiter\r\nContent-Type:text/plain" + + with self.assertRaises(self.http_module.MalformedRequest): + for i in range(len(body_part)): + self.rx.write(body_part[i : i + 1]) + self.engine.state(self.rx) + + def test_multipart_receiver_complete_part(self): + self.engine.state = self.engine._parse_boundary_st + self.engine.url = b"/api/test" + self.engine.method = b"GET" + + test_callback = mock.Mock() + self.engine.register("/api/test", test_callback) + + self.engine.headers["content-length"] = 1000 + self.engine.mp_boundary = b"test-boundary" + self.engine.mp_delimiter = b"--test-boundary\r\n" + self.engine.mp_last_delimiter = b"--test-boundary--" + + body_part = ( + b'Content-Disposition:form-data;name="file-chunk";filename="upload.txt"\r\n' + b"Content-Type:text/plain\r\n\r\n" + b"Upload content\r\n" + b"--test-boundary\r\n" + ) + + for i in range(len(body_part)): + self.assertEqual(self.engine.state, self.engine._parse_boundary_st) + self.rx.write(body_part[i : i + 1]) + self.engine.state(self.rx) + + self.assertEqual(self.engine.state, self.engine._parse_complete_part_st) + self.assertEqual(self.rx.peek(), body_part) + self.assertEqual(self.engine.mp_is_first, True) + + self.engine.state(self.rx) + + self.assertEqual(self.engine.state, self.engine._parse_boundary_st) + 
test_callback.assert_called_once_with( + self.engine, + ( + { + "content-disposition": 'form-data;name="file-chunk";filename="upload.txt"', + "content-type": "text/plain", + }, + b"Upload content", + ), + ) + self.assertEqual(self.engine.mp_is_first, False) + self.assertEqual(self.engine.mp_is_last, False) + + def test_multipart_receiver_last_part(self): + self.engine.state = self.engine._parse_boundary_st + self.engine.url = b"/api/test" + self.engine.method = b"GET" + self.engine.version = b"HTTP/1.1" + self.engine.headers["content-length"] = 131 + self.engine.mp_boundary = b"test-boundary" + self.engine.mp_delimiter = b"--test-boundary\r\n" + self.engine.mp_last_delimiter = b"--test-boundary--" + + test_callback = mock.Mock(return_value=("text/plain", "OK")) + self.engine.register("/api/test", test_callback) + + body_part = ( + b'Content-Disposition:form-data;name="file-chunk";filename="upload.txt"\r\n' + b"Content-Type:text/plain\r\n\r\n" + b"Upload content\r\n" + b"--test-boundary--" + ) + + for i in range(len(body_part)): + self.assertEqual(self.engine.state, self.engine._parse_boundary_st) + self.rx.write(body_part[i : i + 1]) + self.engine.state(self.rx) + + self.assertEqual(self.engine.state, self.engine._parse_complete_part_st) + self.assertEqual(self.rx.peek(), body_part) + + self.engine.state(self.rx) + + self.assertEqual(self.engine.state, None) + self.assertEqual(self.engine.status_code, 200) + test_callback.assert_called_once_with( + self.engine, + ( + { + "content-disposition": 'form-data;name="file-chunk";filename="upload.txt"', + "content-type": "text/plain", + }, + b"Upload content", + ), + ) + self.assertEqual(self.engine.mp_is_first, True) + self.assertEqual(self.engine.mp_is_last, True) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/tests/unit/utils.py b/tests/unit/utils.py index b9b04f8..9b01f13 100644 --- a/tests/unit/utils.py +++ b/tests/unit/utils.py @@ -29,18 +29,22 @@ def load_module(relative_path): return mod 
def stat_factory(is_file, size=1024):
    """
    Build a stand-in for ``os.stat`` that always reports the same kind of entry.

    Args:
        is_file: When truthy, the fake reports a regular file with 0o644
            permissions; otherwise it reports a directory with 0o755
            permissions.
        size: Value reported as ``st_size``. Defaults to 1024.

    Returns:
        A one-argument callable that ignores its argument and returns an
        ``os.stat_result`` built from fixed values and the current time.
    """
    # File-type bits OR-ed with permission bits. The parentheses make the
    # grouping explicit: ``|`` binds tighter than the conditional expression.
    st_mode = (0o100000 | 0o644) if is_file else (0o040000 | 0o755)

    def fake_stat(_path):
        # Sample the clock once so the three timestamps always agree, even
        # when the call straddles a second boundary.
        now = int(time.time())
        return os.stat_result(
            (
                st_mode,  # st_mode
                12345678,  # st_ino
                2049,  # st_dev
                1,  # st_nlink
                1000,  # st_uid
                1000,  # st_gid
                size,  # st_size
                now,  # st_atime
                now,  # st_mtime
                now,  # st_ctime
            )
        )

    return fake_stat