diff --git a/Makefile b/Makefile index f99940a..843e4e1 100644 --- a/Makefile +++ b/Makefile @@ -152,6 +152,7 @@ stage-example: @echo "Copying built package" @cp -r build/pyrobusta $(RUNTIME_DIR)/lib + @cp -r build/pyrobusta/assets/www $(RUNTIME_DIR)/ @echo "Copying example app" @cp $(EXAMPLE_DIR)/app.py $(RUNTIME_DIR)/ @@ -162,6 +163,8 @@ stage-example: @cp $(TLS_DIR)/key.der $(RUNTIME_DIR)/ @if [ -f pyrobusta.env ]; then cp pyrobusta.env $(RUNTIME_DIR)/; fi + @echo "http_port=8080" >> $(RUNTIME_DIR)/pyrobusta.env + @echo "https_port=4443" >> $(RUNTIME_DIR)/pyrobusta.env # ----------------------------- # Run example locally with unix micropython diff --git a/assets/www/examples.html b/assets/www/examples.html index 095c675..7c6cbef 100644 --- a/assets/www/examples.html +++ b/assets/www/examples.html @@ -104,6 +104,8 @@

Simple Server Application

asyncio.get_event_loop().close()

Use curl to test your application.

diff --git a/docs/configuration.md b/docs/configuration.md index f6b63fb..e7a9df6 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -7,10 +7,12 @@ to upload it to the root directory of the target device. |-------------------|-------------------------------------------------------------------------------------------------------|-------------------------------| | wifi_ssid | Name of the Wi-Fi network. When empty, Wi-Fi is not initalized by the built-in wifi.py module. | None | | wifi_password | Password of the Wi-Fi network. When empty, Wi-Fi is not initalized by the built-in wifi.py module. | None | -| http_multipart | Enable multipart HTTP requests/responses. | "False" | +| http_port | Port number for HTTP. | 80 | +| https_port | Port number for HTTPS. | 443 | +| http_multipart | Enable multipart HTTP requests/responses. | False | | http_mem_cap | Max memory cap (% × 0.01) of usable heap for HTTP request/response stream buffers. | 0.1 | | http_served_paths | Space delimited list of filesystem paths allowed to be served through HTTP. | "/www /lib/pyrobusta" | -| http_serve_files | Enable/disable file serving. | "True" | +| http_serve_files | Enable/disable file serving. | True | | socket_max_con | Max number of socket connections of any enabled application server. | 2 | -| tls | Enable/disable TLS. When turned on, cert.der/key.der must be installed at the root. | "False" | +| tls | Enable/disable TLS. When turned on, cert.der/key.der must be installed at the root. | False | | log_level | Can be one of: warning, info, debug. | "info" | diff --git a/example/mip_repo/app.py b/example/mip_repo/app.py index 2564e96..80f5625 100644 --- a/example/mip_repo/app.py +++ b/example/mip_repo/app.py @@ -3,26 +3,20 @@ import pyrobusta.server.http_server as http_server from pyrobusta.protocol.http import HttpEngine -from pyrobusta.utils import logging, config +from pyrobusta.utils import logging, config, assets, helpers def append_package_files(dir, package_files, host_name, protocol): """ Construct package file list recursively. """ - for name in os.listdir(dir): - current_path = f"{dir}/{name}" - st = os.stat(current_path) - mode = st[0] - if mode & 0x4000: # directory bit set - append_package_files(current_path, package_files, host_name, protocol) - continue - - target_path = current_path[4:] if current_path.startswith("lib/") else current_path + dir = helpers.normalize_path(dir) + + for asset in assets.iterate_fs(dir): package_files["urls"].append( [ - target_path, - f"{protocol}://{host_name}/files/{current_path}", + asset, + f"{protocol}://{host_name}/files" + asset, ] ) @@ -30,7 +24,7 @@ def append_package_files(dir, package_files, host_name, protocol): @HttpEngine.route("/pyrobusta/package.json", "GET") def self_serve_mip_package(http_ctx, _): package_files = {"version": config.PYROBUSTA_VERSION, "deps": [], "urls": []} - tls_enabled = config.get_config("tls").lower() == "true" + tls_enabled = config.get_config(config.CONF_TLS) server_addr = http_ctx.headers["host"] if ":" not in server_addr: port = ( @@ -44,8 +38,7 @@ def self_serve_mip_package(http_ctx, _): protocol = "https" if tls_enabled else "http" logging.debug(f"[mip_repo] server_addr: {server_addr}") - root = "pyrobusta" if "pyrobusta" in os.listdir() else "lib/pyrobusta" - append_package_files(root, package_files, server_addr, protocol) + append_package_files("/lib/pyrobusta", package_files, server_addr, protocol) return "application/json", package_files diff --git a/src/pyrobusta/con/wifi.py b/src/pyrobusta/con/wifi.py index b0d85e0..bc27030 100644 --- a/src/pyrobusta/con/wifi.py +++ b/src/pyrobusta/con/wifi.py @@ -6,7 +6,7 @@ from network import WLAN, STA_IF -from ..utils.config import get_config +from ..utils.config import get_config, CONF_WIFI_SSID, CONF_WIFI_PASSWORD from ..utils import logging @@ -14,8 +14,8 @@ def initialize(): """ Initialize WLAN interface in station mode """ - ssid = get_config("wifi_ssid") - password = get_config("wifi_password") + ssid = get_config(CONF_WIFI_SSID) + password = get_config(CONF_WIFI_PASSWORD) if not ssid or not password: logging.warning(__name__ + ": missing SSID/password") diff --git a/src/pyrobusta/protocol/http.py b/src/pyrobusta/protocol/http.py index 4171c4d..fbd6a45 100644 --- a/src/pyrobusta/protocol/http.py +++ b/src/pyrobusta/protocol/http.py @@ -6,7 +6,12 @@ from json import dumps from io import BytesIO -from ..utils.config import get_config +from ..utils.config import ( + get_config, + CONF_HTTP_SERVED_PATHS, + CONF_HTTP_MULTIPART, + CONF_HTTP_SERVE_FILES, +) class HeaderParsingError(ValueError): @@ -217,7 +222,7 @@ def is_norm_path_served(cls, path: str): """ Returns true if a normalized path is configured to be served """ - served_paths = set(get_config("http_served_paths").split()) + served_paths = get_config(CONF_HTTP_SERVED_PATHS) parts = path.split("/") for i, _ in enumerate(parts): current_path = "/".join(parts[: i + 1]) @@ -644,12 +649,12 @@ def enable_optional_features(): """ Enable related optional features, set in the config. """ - if get_config("http_multipart").lower() == "true": + if get_config(CONF_HTTP_MULTIPART): from pyrobusta.protocol import http_multipart http_multipart.apply_patches() - if get_config("http_serve_files").lower() == "true": + if get_config(CONF_HTTP_SERVE_FILES): from pyrobusta.protocol import http_file_server http_file_server.apply_patches() diff --git a/src/pyrobusta/server/http_server.py b/src/pyrobusta/server/http_server.py index 6761469..6b9a1c3 100644 --- a/src/pyrobusta/server/http_server.py +++ b/src/pyrobusta/server/http_server.py @@ -9,7 +9,14 @@ from ..protocol import http from ..bindings.socket_http import SocketHttp from ..stream.buffer import MemoryPool, SlidingBuffer -from ..utils.config import get_config +from ..utils.config import ( + get_config, + CONF_HTTP_PORT, + CONF_HTTPS_PORT, + CONF_HTTP_MEM_CAP, + CONF_TLS, + CONF_SOCKET_MAX_CON, +) from ..utils.helpers import normalize_path from ..utils import logging @@ -31,8 +38,8 @@ class HttpServer: CON_ACCEPT_SLEEP_MS = ( 100 # Duration of sleep between attempts to accept new connection ) - LISTEN_PORT_HTTP = 80 - LISTEN_PORT_HTTPS = 443 + LISTEN_PORT_HTTP = get_config(CONF_HTTP_PORT) + LISTEN_PORT_HTTPS = get_config(CONF_HTTPS_PORT) TLS_CERT_PATH = "/cert.der" TLS_KEY_PATH = "/key.der" CON_TIMEOUT_S = 30 @@ -41,7 +48,7 @@ class HttpServer: # Constants for controlled memory footprint # ----------------------------------------- - MEM_CAP = float(get_config("http_mem_cap")) # Default memory cap (percentage / 100) + MEM_CAP = get_config(CONF_HTTP_MEM_CAP) # Default memory cap (percentage / 100) SEND_BUF_MIN_BYTES = 512 # Minimum buffer size for responses SEND_BUF_MAX_BYTES = 4096 # Max buffer size for responses RECV_BUF_MIN_BYTES = 512 # Minimum buffer size for requests @@ -105,7 +112,7 @@ def __init__(self): self._host = "0.0.0.0" self._port = ( HttpServer.LISTEN_PORT_HTTPS - if get_config("tls").lower() == "true" + if get_config(CONF_TLS) else HttpServer.LISTEN_PORT_HTTP ) self._server = None @@ -190,11 +197,11 @@ async def start_socket_server(self): logging.debug( __name__ + f"registered endpoints: {http.HttpEngine.ENDPOINTS}" ) - self._max_clients = int(get_config("socket_max_con")) + self._max_clients = get_config(CONF_SOCKET_MAX_CON) self._init_pools(self._max_clients) ssl_ctx = None - if get_config("tls").lower() == "true": + if get_config(CONF_TLS): import ssl ssl_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) diff --git a/src/pyrobusta/utils/config.py b/src/pyrobusta/utils/config.py index 86dbf44..4632640 100644 --- a/src/pyrobusta/utils/config.py +++ b/src/pyrobusta/utils/config.py @@ -4,39 +4,82 @@ Values can be encapsulated by single or double quotes. """ +try: + from micropython import const +except ImportError: + + def const(n): # pylint: disable=C0116 + return n + + from .helpers import normalize_path PYROBUSTA_VERSION = "v0.4.0" -CONFIG_LOADED = False CONFIG_LOCATION = "pyrobusta.env" -CONFIG_CACHE = [ - "wifi_ssid", + +# ------------------------------------------- +# Global runtime configuration keys. +# Provide these keys when using get_config(). +# ------------------------------------------- +CONF_WIFI_SSID = const(0) +CONF_WIFI_PASSWORD = const(1) +CONF_HTTP_PORT = const(2) +CONF_HTTPS_PORT = const(3) +CONF_HTTP_MULTIPART = const(4) +CONF_HTTP_MEM_CAP = const(5) +CONF_HTTP_SERVED_PATHS = const(6) +CONF_HTTP_SERVE_FILES = const(7) +CONF_SOCKET_MAX_CON = const(8) +CONF_TLS = const(9) +CONF_LOG_LEVEL = const(10) + +# ------------------- +# Configuration state +# ------------------- +_CONFIG_LOADED = False +_CONFIG_CACHE = [ + CONF_WIFI_SSID, None, - "wifi_password", + CONF_WIFI_PASSWORD, None, - "http_multipart", - "False", - "http_mem_cap", + CONF_HTTP_PORT, + 80, + CONF_HTTPS_PORT, + 443, + CONF_HTTP_MULTIPART, + False, + CONF_HTTP_MEM_CAP, 0.1, - "http_served_paths", - "/www /lib/pyrobusta", - "http_serve_files", - "True", - "socket_max_con", + CONF_HTTP_SERVED_PATHS, + ["/www", "/lib/pyrobusta"], + CONF_HTTP_SERVE_FILES, + True, + CONF_SOCKET_MAX_CON, 2, - "tls", - "False", - "log_level", + CONF_TLS, + False, + CONF_LOG_LEVEL, "info", ] -def normalize(key, value): +# -------------- +# Public helpers +# -------------- +def parse_config(key, value): """ Normalize a configuration value depending on the key. """ - if key == "http_served_paths": - return " ".join([normalize_path(p) for p in value.split()]) + if key in (CONF_HTTP_MULTIPART, CONF_HTTP_SERVE_FILES, CONF_TLS): + return value.lower() == "true" + if key in (CONF_HTTP_PORT, CONF_HTTPS_PORT, CONF_SOCKET_MAX_CON): + return int(value) + if key == CONF_HTTP_MEM_CAP: + return float(value) + if key == CONF_HTTP_SERVED_PATHS: + return [normalize_path(p) for p in value.split()] + if key not in (CONF_WIFI_SSID, CONF_WIFI_PASSWORD): + return value.lower() return value @@ -52,18 +95,22 @@ def read_config(config=CONFIG_LOCATION): if not line.strip(): continue parts = line.split("=") - key = parts[0].strip() + key_name = "CONF_" + parts[0].strip().upper() + if key_name in globals(): + key = globals()[key_name] + else: + key = len(_CONFIG_CACHE) // 2 + 1 + globals()[key_name] = key value = parts[1].strip().strip("'").strip('"') - if key and value: - value = normalize(key, value) - if ( - key in CONFIG_CACHE - and (conf_idx := CONFIG_CACHE.index(key)) % 2 == 0 - ): - CONFIG_CACHE[conf_idx + 1] = value - else: - CONFIG_CACHE.append(key) - CONFIG_CACHE.append(value) + value = parse_config(key, value) + if ( + key in _CONFIG_CACHE + and (conf_idx := _CONFIG_CACHE.index(key)) % 2 == 0 + ): + _CONFIG_CACHE[conf_idx + 1] = value + else: + _CONFIG_CACHE.append(key) + _CONFIG_CACHE.append(value) except OSError: pass @@ -74,15 +121,8 @@ def get_config(key): The cache is reloaded if the key is missing or the value is set to None. """ - global CONFIG_LOADED # pylint: disable=W0603 - if key not in CONFIG_CACHE or not CONFIG_LOADED: - read_config() - CONFIG_LOADED = True - try: - conf_idx = CONFIG_CACHE.index(key) - except IndexError: - return None - if CONFIG_CACHE[conf_idx + 1] is None: + global _CONFIG_LOADED # pylint: disable=W0603 + if _CONFIG_CACHE[2 * key + 1] is None or not _CONFIG_LOADED: read_config() - conf_idx = CONFIG_CACHE.index(key) - return CONFIG_CACHE[conf_idx + 1] + _CONFIG_LOADED = True + return _CONFIG_CACHE[2 * key + 1] diff --git a/src/pyrobusta/utils/logging.py b/src/pyrobusta/utils/logging.py index 406bd82..6c58f66 100644 --- a/src/pyrobusta/utils/logging.py +++ b/src/pyrobusta/utils/logging.py @@ -2,7 +2,7 @@ Config-based logging module for different log levels """ -from .config import get_config +from .config import get_config, CONF_LOG_LEVEL _LOG_LEVEL_WARNING = 0 _LOG_LEVEL_INFO = 1 @@ -13,7 +13,7 @@ def current_log_level(): """ Determine current log level from the config """ - current = get_config("log_level").lower() + current = get_config(CONF_LOG_LEVEL) if current == "debug": return _LOG_LEVEL_DEBUG if current == "info": diff --git a/tests/functional/test_http.py b/tests/functional/test_http.py index 8800db0..68d5577 100644 --- a/tests/functional/test_http.py +++ b/tests/functional/test_http.py @@ -11,7 +11,13 @@ enable_optional_features, ServerBusyError, ) -from pyrobusta.utils import config +from pyrobusta.utils.config import ( + CONF_HTTP_SERVED_PATHS, + CONF_TLS, + CONF_LOG_LEVEL, + _CONFIG_CACHE, + parse_config, +) from pyrobusta.utils.helpers import normalize_path ################################################# @@ -258,13 +264,10 @@ def setup_config(tls_enabled=False, served_paths=""): http_server.HttpServer.LISTEN_PORT_HTTP = 8080 http_server.HttpServer.LISTEN_PORT_HTTPS = 4443 - config_idx = config.CONFIG_CACHE.index("log_level") - config.CONFIG_CACHE[config_idx + 1] = str("warning") - config_idx = config.CONFIG_CACHE.index("tls") - config.CONFIG_CACHE[config_idx + 1] = str(tls_enabled) - config_idx = config.CONFIG_CACHE.index("http_served_paths") - config.CONFIG_CACHE[config_idx + 1] = config.normalize( - "http_served_paths", served_paths + _CONFIG_CACHE[2 * CONF_LOG_LEVEL + 1] = "warning" + _CONFIG_CACHE[2 * CONF_TLS + 1] = tls_enabled + _CONFIG_CACHE[2 * CONF_HTTP_SERVED_PATHS + 1] = parse_config( + CONF_HTTP_SERVED_PATHS, served_paths ) enable_optional_features() diff --git a/tests/functional/test_http_multipart.py b/tests/functional/test_http_multipart.py index 5463a59..1765243 100644 --- a/tests/functional/test_http_multipart.py +++ b/tests/functional/test_http_multipart.py @@ -8,7 +8,12 @@ HttpEngine, enable_optional_features, ) -from pyrobusta.utils import config +from pyrobusta.utils.config import ( + CONF_TLS, + CONF_LOG_LEVEL, + CONF_HTTP_MULTIPART, + _CONFIG_CACHE, +) ################################################# # Test helpers @@ -136,12 +141,9 @@ def setup_config(tls_enabled=False): http_server.HttpServer.LISTEN_PORT_HTTP = 8080 http_server.HttpServer.LISTEN_PORT_HTTPS = 4443 - config_idx = config.CONFIG_CACHE.index("log_level") - config.CONFIG_CACHE[config_idx + 1] = str("warning") - config_idx = config.CONFIG_CACHE.index("http_multipart") - config.CONFIG_CACHE[config_idx + 1] = "True" - config_idx = config.CONFIG_CACHE.index("tls") - config.CONFIG_CACHE[config_idx + 1] = str(tls_enabled) + _CONFIG_CACHE[2 * CONF_LOG_LEVEL + 1] = "warning" + _CONFIG_CACHE[2 * CONF_TLS + 1] = tls_enabled + _CONFIG_CACHE[2 * CONF_HTTP_MULTIPART + 1] = True enable_optional_features() diff --git a/tests/unit/test_helpers.py b/tests/unit/test_helpers.py index 10ee8a8..767c1e4 100644 --- a/tests/unit/test_helpers.py +++ b/tests/unit/test_helpers.py @@ -7,7 +7,7 @@ class TestHelpers(unittest.TestCase): """ - Base class for stat machine tests. + Base class for helper functions. """ @classmethod diff --git a/tests/unit/test_http.py b/tests/unit/test_http.py index 9afec41..31d393d 100644 --- a/tests/unit/test_http.py +++ b/tests/unit/test_http.py @@ -1,58 +1,72 @@ +import os import sys import unittest + from unittest import mock -from unittest.mock import MagicMock, patch +from unittest.mock import patch, mock_open -from .utils import load_module +from .utils import load_module, fake_stat class TestWebStateMachineBase(unittest.TestCase): """ - Base class for stat machine tests. + Base class for state machine tests. """ @classmethod def setUpClass(cls): - cls.config = {} + cls.base_config = {} + cls.cwd = os.getcwd() def setUp(self): - # Create mock modules - self.mock_utils = MagicMock() - self.mock_utils_config = MagicMock() - self.mock_utils_config.get_config = MagicMock() - self.mock_utils.config = self.mock_utils_config - - self.patcher = patch.dict( + # ------------------------------- + # Patch current working directory + # ------------------------------- + self.helpers_module = load_module("pyrobusta/utils/helpers.py") + self.cwd_patcher = patch.object( + self.helpers_module, "getcwd", return_value=self.cwd + ) + self.cwd_patcher.start() + self.addCleanup(self.cwd_patcher.stop) + + # ------------------- + # Patch config module + # ------------------- + self.config = dict(self.base_config) + self.config_module = load_module("pyrobusta/utils/config.py") + self.module_patcher = patch.dict( sys.modules, - { - "pyrobusta.utils": self.mock_utils, - "pyrobusta.utils.config": self.mock_utils_config, - }, + {"pyrobusta.utils.config": self.config_module}, ) - self.patcher.start() - self.set_mock_config() + self.module_patcher.start() + self.addCleanup(self.module_patcher.stop) - # Load your web and buffer modules - self.helpers_module = load_module("pyrobusta/utils/helpers.py") - buffer_module = load_module("pyrobusta/stream/buffer.py") - self.web_module = load_module("pyrobusta/protocol/http.py") - self.web_module.enable_optional_features() + def open_side_effect(*args, **kwargs): + data = "\n".join(f"{k}={v}" for k, v in self.config.items()) + return mock_open(read_data=data)(*args, **kwargs) - self.engine = self.web_module.HttpEngine() + self.open_patcher = patch.object( + self.config_module, + "open", + side_effect=open_side_effect, + ) + self.open_patcher.start() + self.addCleanup(self.open_patcher.stop) + + # ------------------------------------------------ + # Load remaining modules, enable optional features + # ------------------------------------------------ + self.http_module = load_module("pyrobusta/protocol/http.py") + self.http_module.enable_optional_features() + self.engine = self.http_module.HttpEngine() + + # -------------------- + # HTTP engine, buffers + # -------------------- + buffer_module = load_module("pyrobusta/stream/buffer.py") self.rx = buffer_module.SlidingBuffer(bytearray(1024)) self.tx = buffer_module.SlidingBuffer(bytearray(1024)) - def tearDown(self): - self.patcher.stop() - - def set_mock_config(self): - def side_effect(input_arg, *_, **__): - if input_arg in self.config: - return self.config[input_arg] - raise ValueError(f"Unexpected config key: {input_arg}") - - self.mock_utils_config.get_config.side_effect = side_effect - class TestWebStateMachine(TestWebStateMachineBase): """ @@ -61,7 +75,8 @@ class TestWebStateMachine(TestWebStateMachineBase): @classmethod def setUpClass(cls): - cls.config = {"http_multipart": "False", "http_serve_files": "False"} + cls.base_config = {"http_multipart": "False", "http_serve_files": "False"} + cls.cwd = os.getcwd() def test_status_parsing_valid(self): request = b"GET /index.html HTTP/1.1\r\nContent-Length:10" @@ -156,7 +171,7 @@ def test_header_parsing_error(self): b"space in header name: value", b"new-line-in-header:\nvalue", ): - with self.assertRaises(self.web_module.HeaderParsingError): + with self.assertRaises(self.http_module.HeaderParsingError): self.engine._parse_headers(case) def test_routing_unsupported_method(self): @@ -412,8 +427,23 @@ def test_chunked_transfer_encoding_chunk_incomplete(self): self.assertEqual(self.engine.status_code, None) self.assertEqual(self.engine.state, self.engine._recv_chunk_st) + +class TestWebHelpers(TestWebStateMachineBase): + """ + Tests for helper functions. + """ + + @classmethod + def setUpClass(cls): + cls.base_config = {"http_multipart": "False", "http_serve_files": "False"} + # Simplify file-open assertions by treating resources + # as if they are installed at the root (/) rather than + # relative to the current working directory. + cls.cwd = "/" + def test_path_serving_list(self): self.config["http_served_paths"] = "/path/to/dir1 /path/to/dir2" + self.config_module.read_config() self.assertEqual(self.engine.is_norm_path_served(""), False) self.assertEqual(self.engine.is_norm_path_served("/"), False) self.assertEqual(self.engine.is_norm_path_served("/path/to/dir1"), True) @@ -426,12 +456,14 @@ def test_path_serving_list(self): def test_path_serving_root(self): self.config["http_served_paths"] = "/" + self.config_module.read_config() self.assertEqual(self.engine.is_norm_path_served(""), True) self.assertEqual(self.engine.is_norm_path_served("/"), True) self.assertEqual(self.engine.is_norm_path_served("/path/to/served"), True) def test_path_serving_none(self): self.config["http_served_paths"] = "" + self.config_module.read_config() self.assertEqual(self.engine.is_norm_path_served(""), False) self.assertEqual(self.engine.is_norm_path_served("/"), False) self.assertEqual(self.engine.is_norm_path_served("/path/to/served"), False) @@ -439,12 +471,13 @@ def test_path_serving_none(self): class TestMultipartStateMachine(TestWebStateMachineBase): """ - Tests for multipart handling + Tests for multipart handling. """ @classmethod def setUpClass(cls): - cls.config = {"http_multipart": "True", "http_serve_files": "True"} + cls.base_config = {"http_multipart": "True", "http_serve_files": "False"} + cls.cwd = os.getcwd() def test_multipart_parser(self): for case in [ @@ -477,7 +510,7 @@ def test_multipart_parser(self): {"content-type": 'multipart/form-data;boundary=missing-quote"'}, ]: with self.subTest(headers=case): - with self.assertRaises(self.web_module.HeaderParsingError): + with self.assertRaises(self.http_module.HeaderParsingError): self.engine._get_mp_boundary(case) def test_multipart_receiver_valid(self): @@ -601,5 +634,131 @@ def test_multipart_receiver_last_part(self): self.assertEqual(self.engine.mp_is_last, True) +class TestFileServerStateMachine(TestWebStateMachineBase): + """ + Tests for file serving. + """ + + @classmethod + def setUpClass(cls): + cls.base_config = { + "http_multipart": "False", + "http_serve_files": "True", + "http_served_paths": "/www", + } + # Simplify file-open assertions by treating resources + # as if they are installed at the root (/) rather than + # relative to the current working directory. + cls.cwd = "/" + + @staticmethod + def patch_all(f): + @patch("pyrobusta.protocol.http_file_server.stat", fake_stat) + def decorated(*args, **kwargs): + return f(*args, **kwargs) + + return decorated + + @patch_all + def test_file_serving_missing_file(self, *_): + self.engine.url = b"/files/www/nonexistent.js" + self.engine.method = b"GET" + self.engine.version = b"HTTP/1.1" + self.engine.state = self.engine._send_file_st + + self.engine.state(self.rx, self.tx, self.engine.url) + + self.assertEqual(self.engine.status_code, 404) + self.assertEqual(self.engine.state, None) + + @patch_all + def test_file_serving_root(self, *_): + self.engine.url = b"/" + self.engine.method = b"GET" + self.engine.version = b"HTTP/1.1" + self.engine.state = self.engine._send_file_st + file_content = "index content" + + with patch("builtins.open", mock_open(read_data=file_content)) as m: + response_generator = self.engine.state(self.rx, self.tx, self.engine.url) + m.assert_called_once_with("/www/index.html", "rb") + + self.assertEqual(response_generator.read(), file_content) + self.assertEqual(self.engine.status_code, 200) + self.assertEqual(self.engine.state, None) + + @patch_all + def test_file_serving_files_endpoint(self, *_): + self.engine.url = b"/files/www/scripts.js" + self.engine.method = b"GET" + self.engine.version = b"HTTP/1.1" + self.engine.state = self.engine._send_file_st + file_content = "data" + + with patch("builtins.open", mock_open(read_data=file_content)) as m: + response_generator = self.engine.state(self.rx, self.tx, self.engine.url) + m.assert_called_once_with("/www/scripts.js", "rb") + + self.assertEqual(response_generator.read(), file_content) + self.assertEqual(self.engine.status_code, 200) + self.assertEqual(self.engine.state, None) + + @patch_all + def test_file_serving_known_content_type(self, *_): + self.engine.url = b"/scripts.js" + self.engine.method = b"GET" + self.engine.version = b"HTTP/1.1" + self.engine.state = self.engine._send_file_st + file_content = "data" + + with patch("builtins.open", mock_open(read_data=file_content)) as m: + response_generator = self.engine.state(self.rx, self.tx, self.engine.url) + m.assert_called_once_with("/www/scripts.js", "rb") + + self.assertEqual(response_generator.read(), file_content) + self.assertEqual( + self.engine._lookup(self.engine.resp_headers, b"content-type"), + b"application/javascript", + ) + self.assertEqual(self.engine.status_code, 200) + self.assertEqual(self.engine.state, None) + + @patch_all + def test_file_serving_fallback_content_type(self, *_): + self.engine.url = b"/scripts.unknown" + self.engine.method = b"GET" + self.engine.version = b"HTTP/1.1" + self.engine.state = self.engine._send_file_st + file_content = "data" + + with patch("builtins.open", mock_open(read_data=file_content)) as m: + response_generator = self.engine.state(self.rx, self.tx, self.engine.url) + m.assert_called_once_with("/www/scripts.unknown", "rb") + + self.assertEqual(response_generator.read(), file_content) + self.assertEqual( + self.engine._lookup(self.engine.resp_headers, b"content-type"), + b"application/octet-stream", + ) + self.assertEqual(self.engine.status_code, 200) + self.assertEqual(self.engine.state, None) + + @patch_all + def test_file_serving_unserved_content_rejected(self, *_): + self.engine.url = b"/files/unserved/script.js" + self.engine.method = b"GET" + self.engine.version = b"HTTP/1.1" + self.engine.state = self.engine._send_file_st + file_content = "data" + + with patch("builtins.open", mock_open(read_data=file_content)) as m: + response_generator = self.engine.state(self.rx, self.tx, self.engine.url) + m.assert_not_called() + + self.assertEqual(response_generator, None) + self.assertEqual(self.engine.status_code, 403) + self.assertEqual(self.engine.state, None) + + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/tests/unit/utils.py b/tests/unit/utils.py index fdec51c..b9b04f8 100644 --- a/tests/unit/utils.py +++ b/tests/unit/utils.py @@ -4,6 +4,8 @@ import sys import importlib.util +import time +import os from pathlib import Path @@ -25,3 +27,20 @@ def load_module(relative_path): sys.modules[module_name] = mod spec.loader.exec_module(mod) return mod + + +def fake_stat(size=1024): + return os.stat_result( + ( + 0o100644, # st_mode (regular file, 644 perms) + 12345678, # st_ino + 2049, # st_dev + 1, # st_nlink + 1000, # st_uid + 1000, # st_gid + size, # st_size + int(time.time()), # st_atime + int(time.time()), # st_mtime + int(time.time()), # st_ctime + ) + )