diff --git a/Makefile b/Makefile index 4486799..f99940a 100644 --- a/Makefile +++ b/Makefile @@ -1,13 +1,15 @@ -PYROBUSTA_VERSION := 0.3.0 +PYROBUSTA_VERSION := v0.4.0 DEVICE ?= u0 SRC_DIR := src TEST_DIR := tests -EXAMPLE_DIR := example/mem_usage +EXAMPLE_DIR := example/mip_repo BUILD_DIR := build DIST_DIR := dist -PKG := pyrobusta TLS_DIR := tls +ASSETS_DIR := assets + +PKG := pyrobusta MICROPY_ROOT := external/micropython MPY_CROSS := $(MICROPY_ROOT)/mpy-cross/build/mpy-cross @@ -47,6 +49,11 @@ toolchain: # ----------------------------- .PHONY: build build: $(MPY_TARGETS) $(INIT_TARGETS) + @mkdir -p $(BUILD_DIR) + @if [ -d assets ]; then \ + echo "Copying assets/ -> $(BUILD_DIR)"; \ + cp -r assets $(BUILD_DIR)/${PKG}/; \ + fi # Compile .py -> .mpy $(BUILD_DIR)/%.mpy: $(SRC_DIR)/%.py @@ -66,6 +73,7 @@ $(BUILD_DIR)/%.py: $(SRC_DIR)/%.py .PHONY: deploy deploy: @echo "Uploading build/$(PKG) to device $(DEVICE)" + @mpremote $(DEVICE) soft-reset @mpremote $(DEVICE) mkdir :/lib || true @find $(BUILD_DIR)/$(PKG) | while read source; do \ rel=$${source#$(BUILD_DIR)/}; \ @@ -79,6 +87,7 @@ deploy: fi; \ sleep 1; \ done + @mpremote $(DEVICE) reset # ----------------------------- # Deploy custom configuration @@ -86,7 +95,20 @@ deploy: .PHONY: deploy-config deploy-config: @echo "Uploading pyrobusta.env" + @mpremote $(DEVICE) soft-reset @if [ -f pyrobusta.env ]; then mpremote $(DEVICE) cp pyrobusta.env :pyrobusta.env; fi + @mpremote $(DEVICE) reset + + +# ----------------------------- +# Deploy index page # TODO use install_www from assets module +# ----------------------------- +.PHONY: deploy-www +deploy-www: + @echo "Deploying /www" + @mpremote $(DEVICE) soft-reset + @mpremote $(DEVICE) run scripts/install_www.py + @mpremote $(DEVICE) reset # ----------------------------- # Full redeploy @@ -107,6 +129,9 @@ publish: @sed -E -i.bak 's/(PYROBUSTA_VERSION[[:space:]]*=[[:space:]]*)"[^"]*"/\1"$(PYROBUSTA_VERSION)"/' \ $(SRC_DIR)/pyrobusta/utils/config.py \ && rm -f 
$(SRC_DIR)/pyrobusta/utils/config.py.bak + @sed -E -i.bak 's/(PyRobusta[[:space:]]).+([[:space:]]Web Server)/\1$(PYROBUSTA_VERSION)\2/' \ + $(ASSETS_DIR)/www/*.html \ + && rm -f $(ASSETS_DIR)/www/*.html.bak $(MAKE) clean $(MAKE) build BUILD_DIR=$(DIST_DIR) scripts/update_package.bash $(DIST_DIR) package.json $(PYROBUSTA_VERSION) @@ -128,7 +153,7 @@ stage-example: @echo "Copying built package" @cp -r build/pyrobusta $(RUNTIME_DIR)/lib - @echo "Copying example files" + @echo "Copying example app" @cp $(EXAMPLE_DIR)/app.py $(RUNTIME_DIR)/ @cp $(EXAMPLE_DIR)/boot.py $(RUNTIME_DIR)/ @@ -152,16 +177,20 @@ run-unix: stage-example .PHONY: deploy-example deploy-example: @echo "Uploading boot.py" + @mpremote $(DEVICE) soft-reset mpremote $(DEVICE) cp $(EXAMPLE_DIR)/boot.py :boot.py + mpremote $(DEVICE) cp $(EXAMPLE_DIR)/app.py :app.py @echo "Uploading pyrobusta.env" @if [ -f pyrobusta.env ]; then mpremote $(DEVICE) cp pyrobusta.env :pyrobusta.env; fi + @mpremote $(DEVICE) reset # ----------------------------- # Run example directly # ----------------------------- .PHONY: run-device run-device: + @mpremote $(DEVICE) soft-reset mpremote $(DEVICE) run $(EXAMPLE_DIR)/app.py @@ -226,7 +255,9 @@ test-unix: TLS_DIR=$(TEST_RUNTIME) test-unix: stage-test tls-cert @cd $(TEST_RUNTIME); \ for test in test_*.py; do \ + echo "\n==================================="; \ echo "Running $$test"; \ + echo "==================================="; \ MICROPYPATH=":.frozen:lib" ../$(MICROPYTHON) $$(basename $$test) || exit 1; \ done @@ -234,12 +265,16 @@ test-unix: stage-test tls-cert # Run functional tests on device # ----------------------------- .PHONY: test-device -test-device: #clean-device upload +test-device: stage-test #clean-device upload + @mpremote $(DEVICE) soft-reset @cd $(TEST_RUNTIME); \ for test in test_*.py; do \ + echo "\n==================================="; \ echo "Running $$test"; \ + echo "==================================="; \ mpremote $(DEVICE) run $$(basename $$test) || exit 
1; \ done + @mpremote $(DEVICE) reset # ================================================ # Utilities for TLS @@ -272,8 +307,10 @@ tls-cert: # ----------------------------- .PHONY: deploy-cert deploy-cert: + @mpremote $(DEVICE) soft-reset @mpremote $(DEVICE) cp $(TLS_DIR)/key.der :key.der @mpremote $(DEVICE) cp $(TLS_DIR)/cert.der :cert.der + @mpremote $(DEVICE) reset # ================================================ # Cleanup @@ -305,4 +342,6 @@ clean: clean-build clean-runtime # ----------------------------- .PHONY: clean-device clean-device: + @mpremote $(DEVICE) soft-reset mpremote $(DEVICE) run scripts/clean_device.py + @mpremote $(DEVICE) reset diff --git a/assets/www/examples.html b/assets/www/examples.html new file mode 100644 index 0000000..095c675 --- /dev/null +++ b/assets/www/examples.html @@ -0,0 +1,157 @@ + + + + + + PyRobusta Home + + + + +

Getting Started

+ + ← Back + +

This page presents useful examples for configuring your server.

+ +
+ +

Server Configuration

+ +

+ +
+ +

Simple Server Application

+

The example below demonstrates how to set up a simple application exposed at /app.

+ + + +

Soft reset the device and upload app.py and boot.py with mpremote.

+ + +

Hard reset the device to start the application and connect over REPL.

+ + +

Use curl to test your application.

+ + + + + \ No newline at end of file diff --git a/assets/www/index.html b/assets/www/index.html new file mode 100644 index 0000000..2f49452 --- /dev/null +++ b/assets/www/index.html @@ -0,0 +1,55 @@ + + + + + + PyRobusta Home + + + + +

PyRobusta Home

+ +

The server is running correctly and is ready to serve content.

+ +
+ +

Available Resources

+ + + + + + \ No newline at end of file diff --git a/dist/pyrobusta/assets/www/examples.html b/dist/pyrobusta/assets/www/examples.html new file mode 100644 index 0000000..095c675 --- /dev/null +++ b/dist/pyrobusta/assets/www/examples.html @@ -0,0 +1,157 @@ + + + + + + PyRobusta Home + + + + +

Getting Started

+ + ← Back + +

This page presents useful examples for configuring your server.

+ +
+ +

Server Configuration

+ +

+ +
+ +

Simple Server Application

+

The example below demonstrates how to set up a simple application exposed at /app.

+ + + +

Soft reset the device and upload app.py and boot.py with mpremote.

+ + +

Hard reset the device to start the application and connect over REPL.

+ + +

Use curl to test your application.

+ + + + + \ No newline at end of file diff --git a/dist/pyrobusta/assets/www/index.html b/dist/pyrobusta/assets/www/index.html new file mode 100644 index 0000000..2f49452 --- /dev/null +++ b/dist/pyrobusta/assets/www/index.html @@ -0,0 +1,55 @@ + + + + + + PyRobusta Home + + + + +

PyRobusta Home

+ +

The server is running correctly and is ready to serve content.

+ +
+ +

Available Resources

+ + + + + + \ No newline at end of file diff --git a/dist/pyrobusta/bindings/socket_http.mpy b/dist/pyrobusta/bindings/socket_http.mpy index b56ceb7..6ee7b49 100644 Binary files a/dist/pyrobusta/bindings/socket_http.mpy and b/dist/pyrobusta/bindings/socket_http.mpy differ diff --git a/dist/pyrobusta/con/wifi.mpy b/dist/pyrobusta/con/wifi.mpy index 8f5f2bc..5e82314 100644 Binary files a/dist/pyrobusta/con/wifi.mpy and b/dist/pyrobusta/con/wifi.mpy differ diff --git a/dist/pyrobusta/protocol/http.mpy b/dist/pyrobusta/protocol/http.mpy index d2a04ba..e645c38 100644 Binary files a/dist/pyrobusta/protocol/http.mpy and b/dist/pyrobusta/protocol/http.mpy differ diff --git a/dist/pyrobusta/protocol/http_file_server.mpy b/dist/pyrobusta/protocol/http_file_server.mpy new file mode 100644 index 0000000..c6e8a74 Binary files /dev/null and b/dist/pyrobusta/protocol/http_file_server.mpy differ diff --git a/dist/pyrobusta/protocol/http_multipart.mpy b/dist/pyrobusta/protocol/http_multipart.mpy index 03fdbed..e6dbbce 100644 Binary files a/dist/pyrobusta/protocol/http_multipart.mpy and b/dist/pyrobusta/protocol/http_multipart.mpy differ diff --git a/dist/pyrobusta/server/http_server.mpy b/dist/pyrobusta/server/http_server.mpy index 467a75e..d798948 100644 Binary files a/dist/pyrobusta/server/http_server.mpy and b/dist/pyrobusta/server/http_server.mpy differ diff --git a/dist/pyrobusta/utils/assets.mpy b/dist/pyrobusta/utils/assets.mpy new file mode 100644 index 0000000..7093beb Binary files /dev/null and b/dist/pyrobusta/utils/assets.mpy differ diff --git a/dist/pyrobusta/utils/config.mpy b/dist/pyrobusta/utils/config.mpy index 5e9b55a..671b920 100644 Binary files a/dist/pyrobusta/utils/config.mpy and b/dist/pyrobusta/utils/config.mpy differ diff --git a/dist/pyrobusta/utils/helpers.mpy b/dist/pyrobusta/utils/helpers.mpy index 19f006a..5760dd1 100644 Binary files a/dist/pyrobusta/utils/helpers.mpy and b/dist/pyrobusta/utils/helpers.mpy differ diff --git a/docs/configuration.md 
b/docs/configuration.md index 8ab2b12..f6b63fb 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -9,7 +9,8 @@ to upload it to the root directory of the target device. | wifi_password | Password of the Wi-Fi network. When empty, Wi-Fi is not initalized by the built-in wifi.py module. | None | | http_multipart | Enable multipart HTTP requests/responses. | "False" | | http_mem_cap | Max memory cap (% × 0.01) of usable heap for HTTP request/response stream buffers. | 0.1 | -| http_served_paths | Space delimited list of filesystem paths allowed to be served through HTTP. | "/lib/pyrobusta" | +| http_served_paths | Space delimited list of filesystem paths allowed to be served through HTTP. | "/www /lib/pyrobusta" | +| http_serve_files | Enable/disable file serving. | "True" | | socket_max_con | Max number of socket connections of any enabled application server. | 2 | | tls | Enable/disable TLS. When turned on, cert.der/key.der must be installed at the root. | "False" | -| log_level | Can be one of: warning, info, debug. | "warning" | +| log_level | Can be one of: warning, info, debug. 
| "info" | diff --git a/example/demo_app/app.py b/example/demo_app/app.py new file mode 100644 index 0000000..4d2491b --- /dev/null +++ b/example/demo_app/app.py @@ -0,0 +1,33 @@ +import asyncio +from gc import mem_free, mem_alloc + +from pyrobusta.server import http_server +from pyrobusta.protocol.http import HttpEngine +from pyrobusta.utils import logging + + +@HttpEngine.route("/app", "GET") +def app(http_ctx, payload): + free = mem_free() + value_format = "bytes" + + if http_ctx.query: + value_format = http_ctx.get_url_encoded_query_param( + http_ctx.query, "format", default="bytes" + ) + if value_format not in ("%", "bytes"): + raise ValueError("invalid format") + + if value_format == "%": + free = round(100 * free / (free + mem_alloc()), 2) + + return "text/plain", (f"Free memory [{value_format}]: {free}\n") + + +def main(): + http_server.main() + try: + asyncio.get_event_loop().run_forever() + except Exception as e: + logging.warning(f"loop stopped: {e}") + asyncio.get_event_loop().close() diff --git a/example/demo_app/boot.py b/example/demo_app/boot.py new file mode 100644 index 0000000..a7c9c58 --- /dev/null +++ b/example/demo_app/boot.py @@ -0,0 +1,12 @@ +# This file is executed on every boot (including wake-boot from deepsleep) +import machine +from os import listdir + +from pyrobusta.con import wifi + +connected = wifi.initialize() +if connected and not machine.reset_cause() == machine.SOFT_RESET: + if "app.py" in listdir(): + import app + + app.main() diff --git a/example/mem_usage/app.py b/example/mem_usage/app.py index bccf0ae..b3ac746 100644 --- a/example/mem_usage/app.py +++ b/example/mem_usage/app.py @@ -23,11 +23,11 @@ def mem_usage(http_ctx, _): selector = http_ctx.get_url_encoded_query_param(http_ctx.query, "key", "") if selector == "free": if value_format == "%": - free = 100 * free / (used + free) + free = round(100 * free / (used + free),2) return "text/plain", f"Free [{value_format}]: {free}\n" if selector == "used": if value_format == "%": 
- used = 100 * used / (used + free) + used = round(100 * used / (used + free),2) return "text/plain", f"Used [{value_format}]: {used}\n" if selector == "total": return "text/plain", f"Total [bytes]: {used + free}\n" diff --git a/example/mem_usage/boot.py b/example/mem_usage/boot.py index d96f6f6..a7c9c58 100644 --- a/example/mem_usage/boot.py +++ b/example/mem_usage/boot.py @@ -1,4 +1,12 @@ # This file is executed on every boot (including wake-boot from deepsleep) +import machine +from os import listdir + from pyrobusta.con import wifi -wifi.initialize() +connected = wifi.initialize() +if connected and not machine.reset_cause() == machine.SOFT_RESET: + if "app.py" in listdir(): + import app + + app.main() diff --git a/example/mip_repo/app.py b/example/mip_repo/app.py index 5955b9a..2564e96 100644 --- a/example/mip_repo/app.py +++ b/example/mip_repo/app.py @@ -22,7 +22,7 @@ def append_package_files(dir, package_files, host_name, protocol): package_files["urls"].append( [ target_path, - f"{protocol}://{host_name}/{current_path}", + f"{protocol}://{host_name}/files/{current_path}", ] ) @@ -38,7 +38,8 @@ def self_serve_mip_package(http_ctx, _): if tls_enabled else http_server.HttpServer.LISTEN_PORT_HTTP ) - server_addr += f":{port}" + if not server_addr in (80, 443): + server_addr += f":{port}" protocol = "https" if tls_enabled else "http" diff --git a/example/mip_repo/boot.py b/example/mip_repo/boot.py index d96f6f6..a7c9c58 100644 --- a/example/mip_repo/boot.py +++ b/example/mip_repo/boot.py @@ -1,4 +1,12 @@ # This file is executed on every boot (including wake-boot from deepsleep) +import machine +from os import listdir + from pyrobusta.con import wifi -wifi.initialize() +connected = wifi.initialize() +if connected and not machine.reset_cause() == machine.SOFT_RESET: + if "app.py" in listdir(): + import app + + app.main() diff --git a/package.json b/package.json index ddfce0e..ad11182 100644 --- a/package.json +++ b/package.json @@ -1,5 +1,5 @@ { - "version": 
"0.3.0", + "version": "v0.4.0", "urls": [ [ "pyrobusta/transport/socket.mpy", @@ -9,6 +9,14 @@ "pyrobusta/transport/__init__.py", "github:szeka9/PyRobusta/dist/pyrobusta/transport/__init__.py" ], + [ + "pyrobusta/assets/www/index.html", + "github:szeka9/PyRobusta/dist/pyrobusta/assets/www/index.html" + ], + [ + "pyrobusta/assets/www/examples.html", + "github:szeka9/PyRobusta/dist/pyrobusta/assets/www/examples.html" + ], [ "pyrobusta/utils/helpers.mpy", "github:szeka9/PyRobusta/dist/pyrobusta/utils/helpers.mpy" @@ -25,6 +33,10 @@ "pyrobusta/utils/logging.mpy", "github:szeka9/PyRobusta/dist/pyrobusta/utils/logging.mpy" ], + [ + "pyrobusta/utils/assets.mpy", + "github:szeka9/PyRobusta/dist/pyrobusta/utils/assets.mpy" + ], [ "pyrobusta/protocol/http_multipart.mpy", "github:szeka9/PyRobusta/dist/pyrobusta/protocol/http_multipart.mpy" @@ -37,6 +49,10 @@ "pyrobusta/protocol/__init__.py", "github:szeka9/PyRobusta/dist/pyrobusta/protocol/__init__.py" ], + [ + "pyrobusta/protocol/http_file_server.mpy", + "github:szeka9/PyRobusta/dist/pyrobusta/protocol/http_file_server.mpy" + ], [ "pyrobusta/stream/__init__.py", "github:szeka9/PyRobusta/dist/pyrobusta/stream/__init__.py" diff --git a/scripts/install_www.py b/scripts/install_www.py new file mode 100644 index 0000000..fcc8913 --- /dev/null +++ b/scripts/install_www.py @@ -0,0 +1,3 @@ +import pyrobusta.utils.assets as assets + +assets.install_www() \ No newline at end of file diff --git a/src/pyrobusta/bindings/socket_http.py b/src/pyrobusta/bindings/socket_http.py index 799fea7..f646719 100644 --- a/src/pyrobusta/bindings/socket_http.py +++ b/src/pyrobusta/bindings/socket_http.py @@ -4,12 +4,11 @@ import asyncio from asyncio import sleep_ms # pylint: disable=E1101 -from gc import mem_free, collect +from gc import collect -from ..stream.buffer import MemoryPool, SlidingBuffer, BufferFullError +from ..stream.buffer import BufferFullError from ..transport.socket import SocketBase -from ..protocol.http import HttpEngine, 
ServerBusyError -from ..utils.config import get_config +from ..protocol.http import HttpEngine, ServerBusyError, HeaderParsingError from ..utils import logging @@ -19,75 +18,19 @@ class SocketHttp(SocketBase): buffer management and state machine parser. """ - # Constants for memory footprint - MEM_CAP = float( - get_config("http_mem_cap") - ) # Default memory cap (percentage / 100) of free heap - SEND_BUF_MIN_BYTES = 512 # Minimum buffer size for responses - SEND_BUF_MAX_BYTES = 4096 # Max buffer size for responses - RECV_BUF_MIN_BYTES = 512 # Minimum buffer size for requests - RECV_BUF_MAX_BYTES = 4096 # Max buffer size for requests - CONN_OVERHEAD = 1024 # Overhead per connection - MTU_SIZE = 1460 # TCP maximum transmission unit - - # Timing settings + MTU_SIZE = 1460 STATE_MACHINE_SLEEP_MS = 2 RESP_HANDLER_SLEEP_MS = 2 RECV_TIMEOUT_SECONDS = 10 - # Static buffer pools - initialized by init_pools() - RECV_POOL = None - SEND_POOL = None - - @staticmethod - def init_pools(max_sockets): - """ - Initialize pool of buffers for sending/receiving based on different profiles - """ - mem_available = mem_free() - con_limit = max(1, max_sockets) - usable = int(SocketHttp.MEM_CAP * mem_available) - is_low_memory = (usable / con_limit) < ( - SocketHttp.RECV_BUF_MAX_BYTES - + SocketHttp.SEND_BUF_MAX_BYTES - + SocketHttp.CONN_OVERHEAD - ) - if is_low_memory: - logging.warning( - __name__ + ".init_pools: low-memory mode with reduced buffer size" - ) - recv_size = ( - SocketHttp.RECV_BUF_MIN_BYTES - if is_low_memory - else SocketHttp.RECV_BUF_MAX_BYTES - ) - send_size = ( - SocketHttp.SEND_BUF_MIN_BYTES - if is_low_memory - else SocketHttp.SEND_BUF_MAX_BYTES - ) - per_conn = recv_size + send_size + SocketHttp.CONN_OVERHEAD - if usable < per_conn: - raise MemoryError( - ( - f"Insufficient memory: {mem_available // 1024} KB " - f"at {SocketHttp.MEM_CAP*100}% cap, " - f"at least {per_conn // 1024} KB required" - ) - ) - con_limit = min(usable // per_conn, con_limit) - 
logging.info((__name__ + f".init_pools: {con_limit} connection(s) allowed")) - SocketHttp.RECV_POOL = MemoryPool(recv_size, con_limit, wrapper=SlidingBuffer) - SocketHttp.SEND_POOL = MemoryPool(send_size, con_limit, wrapper=SlidingBuffer) - __slots__ = ("_engine", "_prev_state", "_recv_buf", "_send_buf") - def __init__(self, reader, writer): + def __init__(self, reader, writer, recv_buf, send_buf): super().__init__(reader, writer) self._engine = HttpEngine() self._prev_state = None - self._recv_buf = None - self._send_buf = None + self._recv_buf = recv_buf + self._send_buf = send_buf async def _flush_response(self): data = self._send_buf.peek() @@ -99,38 +42,41 @@ async def _flush_response(self): async def run(self): """ Handle socket connection with HTTP state machine parser. - - 1) reserve buffer - - 2) run state machine parser - - 3) release reserved buffers and terminate socket connection """ - await self._reserve_buffers() self._prev_state = None try: while self._engine.state is not None: await self._run_state_machine() await sleep_ms(SocketHttp.STATE_MACHINE_SLEEP_MS) - except Exception as e: # pylint: disable=W0718 - logging.warning(__name__ + f": error in run_web: {e}") finally: - if self._send_buf: - self._send_buf.consume() - SocketHttp.SEND_POOL.release(self._send_buf) - if self._recv_buf: - self._recv_buf.consume() - SocketHttp.RECV_POOL.release(self._recv_buf) await self.close() collect() - async def _reserve_buffers(self): - if SocketHttp.SEND_POOL is None or SocketHttp.RECV_POOL is None: - raise RuntimeError("Pools are ninitialized") - - while not self._recv_buf or not self._send_buf: - if not self._recv_buf: - self._recv_buf = SocketHttp.RECV_POOL.reserve() - if not self._send_buf: - self._send_buf = SocketHttp.SEND_POOL.reserve() - await sleep_ms(SocketHttp.STATE_MACHINE_SLEEP_MS) + async def _read_to_buf(self): + buf_free = self._recv_buf.capacity - self._recv_buf.size() + if not buf_free: + self._engine.on_buffer_full(self._send_buf) + await 
self._flush_response() + return 0 + try: + request = await self.read( + read_bytes=buf_free, + decoding=None, + timeout_seconds=SocketHttp.RECV_TIMEOUT_SECONDS, + ) + except asyncio.TimeoutError: + self._engine.on_timeout(self._send_buf) + await self._flush_response() + return 0 + except Exception as e: # pylint: disable=W0718 + self._engine.on_failure( + self._send_buf, b"Read error: " + str(e).encode("ascii") + ) + await self._flush_response() + return 0 + self._recv_buf.write(request) + logging.debug(__name__ + f"._read_to_buf: [{request}]") + return len(request) async def _run_state_machine(self): if self._prev_state == self._engine.state or self._prev_state is None: @@ -151,7 +97,11 @@ async def _run_state_machine(self): await self._flush_response() return except ServerBusyError: - self._engine.on_busy(self._send_buf) + self._engine.on_unavailable(self._send_buf) + await self._flush_response() + return + except HeaderParsingError: + self._engine.on_client_error(self._send_buf, b"Invalid headers") await self._flush_response() return except Exception as e: # pylint: disable=W0718 @@ -162,32 +112,6 @@ async def _run_state_machine(self): if self._engine.state is None and resp_handler is not None: await self._response_handler(resp_handler) - async def _read_to_buf(self): - buf_free = self._recv_buf.capacity - self._recv_buf.size() - if not buf_free: - self._engine.on_buffer_full(self._send_buf) - await self._flush_response() - return 0 - try: - request = await self.read( - read_bytes=buf_free, - decoding=None, - timeout_seconds=SocketHttp.RECV_TIMEOUT_SECONDS, - ) - except asyncio.TimeoutError: - self._engine.on_timeout(self._send_buf) - await self._flush_response() - return 0 - except Exception as e: # pylint: disable=W0718 - self._engine.on_failure( - self._send_buf, b"Read error: " + str(e).encode("ascii") - ) - await self._flush_response() - return 0 - self._recv_buf.write(request) - logging.debug(__name__ + f"._read_to_buf: [{request}]") - return len(request) 
- async def _response_handler(self, resp_handler): if "closure" == type(resp_handler).__name__: for is_finished in resp_handler(self._send_buf): diff --git a/src/pyrobusta/con/wifi.py b/src/pyrobusta/con/wifi.py index 8755f44..b0d85e0 100644 --- a/src/pyrobusta/con/wifi.py +++ b/src/pyrobusta/con/wifi.py @@ -2,7 +2,10 @@ Helpers for setting up Wi-Fi in station mode """ +from time import sleep + from network import WLAN, STA_IF + from ..utils.config import get_config from ..utils import logging @@ -13,19 +16,30 @@ def initialize(): """ ssid = get_config("wifi_ssid") password = get_config("wifi_password") + if not ssid or not password: - logging.warning(__name__ + ": missing SSID/password, skip initialization") - return + logging.warning(__name__ + ": missing SSID/password") + return False sta_if = WLAN(STA_IF) sta_if.active(True) - nets = sta_if.scan() - for net in nets: - if net[0].decode() == get_config("wifi_ssid"): - logging.info(__name__ + f": network {net[0]} found!") - sta_if.connect(net[0], get_config("wifi_password")) - logging.info(__name__ + f": connected, available at {sta_if.ifconfig()[0]}") - break + if sta_if.isconnected(): + logging.info(__name__ + f": already connected IP={sta_if.ifconfig()[0]}") + return True + + sta_if.connect(ssid, password) + + timeout = 30 + while timeout > 0: + if sta_if.isconnected(): + ip = sta_if.ifconfig()[0] + logging.info(__name__ + f": connected, IP={ip}") + return True + sleep(1) + timeout -= 1 + + logging.warning(__name__ + ": connection failed") + return False def get_address(): @@ -33,4 +47,6 @@ def get_address(): Get the address of the WLAN interface """ sta_if = WLAN(STA_IF) - return sta_if.ifconfig()[0] + if sta_if.isconnected(): + return sta_if.ifconfig()[0] + return None diff --git a/src/pyrobusta/protocol/http.py b/src/pyrobusta/protocol/http.py index 0de2a54..4171c4d 100644 --- a/src/pyrobusta/protocol/http.py +++ b/src/pyrobusta/protocol/http.py @@ -5,10 +5,8 @@ from json import dumps from io import BytesIO 
-from os import stat from ..utils.config import get_config -from ..utils.helpers import normalize_path class HeaderParsingError(ValueError): @@ -199,17 +197,20 @@ def get_url_encoded_query_param(query: str, key: str, default: str = None): :param key: key to parse from the query :param default: default value to return when key is not present """ - idx_start = query.find(key + "=") - if idx_start != -1: - idx_end = -1 - idx_end = query.find("&", idx_start) - if idx_start > -1: - if idx_end > -1: - return query[idx_start + len(key) + 1 : idx_end] - return query[idx_start + len(key) + 1 :] - if default is None: + if query.startswith(key + "="): + idx_start = 0 + elif (idx_start := query.find("&" + key + "=")) != -1: + idx_start += 1 + elif default is None: raise KeyError() - return default + else: + return default + + idx_end = -1 + idx_end = query.find("&", idx_start) + if idx_end > -1: + return query[idx_start + len(key) + 1 : idx_end] + return query[idx_start + len(key) + 1 :] @classmethod def is_norm_path_served(cls, path: str): @@ -262,13 +263,25 @@ def _parse_headers(cls, raw_headers: memoryview) -> dict[str, str | int]: headers = {} for line in header_lines: # pylint: disable=W0511 - # TODO: support for UTF-8 in field values (e.g filenames), can be board dependent if any(c > 127 for c in line): raise HeaderParsingError("Non-ASCII character") if b":" not in line: raise HeaderParsingError() name, value = line.split(b":", 1) + if not name: + raise HeaderParsingError("Empty header name") + for c in name: + if ( + 48 <= c <= 57 # 0-9 + or 65 <= c <= 90 # A-Z + or 97 <= c <= 122 # a-z + or c in (45, 95) # -_ + ): + continue + raise HeaderParsingError("Invalid header name") name = name.strip().lower().decode(cls.ASCII) + if any((c < 32 and c != 9) or c == 127 for c in value): + raise HeaderParsingError("Invalid header value") if name == cls.CONTENT_LENGTH: value = int(value.strip()) else: @@ -277,18 +290,35 @@ def _parse_headers(cls, raw_headers: memoryview) -> 
dict[str, str | int]: return headers @staticmethod - def _is_multipart(headers: dict) -> str: + def _get_mp_boundary(headers: dict) -> str: """Determine from the headers if a request is multipart, and return the boundary value""" content_type = headers.get("content-type") - if content_type and content_type.lower().startswith("multipart/form-data"): - parts = content_type.split(";") - for part in parts[1:]: - if "=" in part: - key, value = part.strip().split("=", 1) - if key.strip().lower() == "boundary": - boundary = value.strip().strip('"') - return boundary if boundary else None - return None + if not content_type or not content_type.lower().startswith( + "multipart/form-data" + ): + return None + + parts = content_type.split(";") + for part in parts[1:]: + if "=" not in part: + continue + key, value = part.strip().split("=", 1) + + if key.strip().lower() != "boundary": + continue + value = value.strip() + + if value.startswith('"'): + if len(value) < 2 or not value.endswith('"'): + raise HeaderParsingError() + value = value[1:-1] + elif value.endswith('"'): + raise HeaderParsingError() + + if not value: + raise HeaderParsingError() + return value + raise HeaderParsingError() @classmethod def _parse_body_part(cls, part: memoryview) -> tuple[dict, bytes]: @@ -417,7 +447,7 @@ def on_failure(self, tx, info: bytes): self._write_response_head(tx, len(info)) tx.write(info) - def on_busy(self, tx): + def on_unavailable(self, tx): """Terminate state machine and write 503 response""" self.terminate(503) self._write_response_head(tx) @@ -508,10 +538,13 @@ def _route_request_st(self, _, tx): if self.method == self.HEAD: self.on_client_error(tx, self.BAD_REQUEST_ERROR) return - if mp_boundary := self._is_multipart(self.headers): + if mp_boundary := self._get_mp_boundary(self.headers): self.mp_boundary = mp_boundary.encode(self.ASCII) self.state = self._start_multipart_parser_st elif self._is_chunked(): + if self.CONTENT_LENGTH in self.headers: + self.on_client_error(tx, 
self.BAD_REQUEST_ERROR) + return self.state = self._recv_chunked_size_st else: self.state = self._recv_payload_st @@ -528,8 +561,7 @@ def _route_request_st(self, _, tx): self.on_method_not_allowed(tx) return if self.method in (self.GET, self.HEAD): - resource = b"index.html" if not self.url else self.url - self.state = lambda _rx, _tx: self._send_file_st(_rx, _tx, resource) + self.state = lambda _rx, _tx: self._send_file_st(_rx, _tx, self.url) return self.on_missing_resource(tx) @@ -584,61 +616,28 @@ def _app_endpoint_st(self, rx, tx): dtype = dtype.encode(self.ASCII) self._set_response_header(b"content-type", dtype) - if dtype in (b"multipart/x-mixed-replace", b"multipart/form-data"): - part_content_type = data[0] - callback = data[1] - if type(callback).__name__ not in ("function", "closure"): - self.on_failure(tx, b"Invalid response handler") - return - self.terminate(200, dtype) - boundary = self.MULTIPART_BOUNDARY - self._set_response_header( - b"content-type", dtype + b"; boundary=" + boundary + if dtype.startswith(b"multipart/"): + self.state = lambda _rx, _tx: self._generate_multipart_response( + _rx, _tx, data, dtype ) - self._write_response_head(tx, None) - if self.method != self.HEAD: - return self._multipart_wrapper_factory( - callback, part_content_type.encode(self.ASCII), boundary - ) return + self.terminate(200, dtype) return self._generate_response(tx, data) - def _send_file_st(self, _, tx, web_resource: bytes): - """State for returning a static resource""" - extension = web_resource.rsplit(b".", 1)[-1] - norm_path = normalize_path(web_resource.decode(self.ASCII)) - is_path_served = self.is_norm_path_served(norm_path) - if not is_path_served: - try: - stat(norm_path) - self.on_forbidden(tx) - return - except OSError: - self.on_missing_resource(tx) - return - try: - content_type = self._lookup(self.CONTENT_TYPES, extension) - except ValueError: - content_type = self._lookup(self.CONTENT_TYPES, b"raw") - try: - self._set_response_header( - 
b"content-length", str(stat(norm_path)[6]).encode(HttpEngine.ASCII) - ) - self.terminate(200, content_type) - self._write_response_head(tx, None) - if self.method != self.HEAD: - return open(norm_path, "rb") - return - except OSError: - self.on_missing_resource(tx) + def _send_file_st(self, _, tx, web_resource: bytes): # pylint: disable=W0613 + """State for returning a static resource - disabled""" + self.on_unavailable(tx) def _start_multipart_parser_st(self, rx, tx): # pylint: disable=W0613 - self.on_failure(tx, b"Multipart handling is disabled") + """Initial state for processing multipart requests""" + self.on_unavailable(tx) - @staticmethod - def _multipart_wrapper_factory(callback, content_type: bytes, boundary: bytes): - pass + def _generate_multipart_response( + self, rx, tx, callback, dtype + ): # pylint: disable=W0613 + """Generate multipart response depening on the exact content type""" + self.on_unavailable(tx) def enable_optional_features(): @@ -649,3 +648,8 @@ def enable_optional_features(): from pyrobusta.protocol import http_multipart http_multipart.apply_patches() + + if get_config("http_serve_files").lower() == "true": + from pyrobusta.protocol import http_file_server + + http_file_server.apply_patches() diff --git a/src/pyrobusta/protocol/http_file_server.py b/src/pyrobusta/protocol/http_file_server.py new file mode 100644 index 0000000..1e1acc4 --- /dev/null +++ b/src/pyrobusta/protocol/http_file_server.py @@ -0,0 +1,57 @@ +""" +State machine extension for file serving. 
+""" + +# pylint: disable=W0212,R0401 + +from os import stat + +from pyrobusta.protocol import http +from pyrobusta.utils.helpers import normalize_path, add_method + + +def _send_file_st(self, _, tx, web_resource: bytes): + """State for returning a static resource""" + if self.url == b"/files": + web_resource = "/" + elif self.url.startswith(b"/files/"): + web_resource = web_resource[7:] + elif self.url == b"/": + web_resource = b"/www/index.html" + else: + web_resource = b"/www" + web_resource + + extension = web_resource.rsplit(b".", 1)[-1] + norm_path = normalize_path(web_resource.decode(self.ASCII)) + is_path_served = self.is_norm_path_served(norm_path) + if not is_path_served: + try: + stat(norm_path) + self.on_forbidden(tx) + return + except OSError: + self.on_missing_resource(tx) + return + try: + content_type = self._lookup(self.CONTENT_TYPES, extension) + except ValueError: + content_type = self._lookup(self.CONTENT_TYPES, b"raw") + try: + self._set_response_header( + b"content-length", str(stat(norm_path)[6]).encode(http.HttpEngine.ASCII) + ) + self.terminate(200, content_type) + self._write_response_head(tx, None) + if self.method != self.HEAD: + return open(norm_path, "rb") + return + except OSError: + self.on_missing_resource(tx) + + +def apply_patches(): + """ + Apply patches to class attributes for file serving. + """ + + add_method(http.HttpEngine, _send_file_st) diff --git a/src/pyrobusta/protocol/http_multipart.py b/src/pyrobusta/protocol/http_multipart.py index dc95393..e58074a 100644 --- a/src/pyrobusta/protocol/http_multipart.py +++ b/src/pyrobusta/protocol/http_multipart.py @@ -5,24 +5,23 @@ # pylint: disable=W0212,R0401 from pyrobusta.protocol import http +from pyrobusta.utils.helpers import add_method -def add_method(cls, func, method_type="instance"): - """ - Helper to extend web.WebEngine with - additional methods and states. 
- """ - if method_type == "instance": - setattr(cls, func.__name__, func) - elif method_type == "static": - setattr(cls, func.__name__, staticmethod(func)) - elif method_type == "class": - setattr(cls, func.__name__, classmethod(func)) - else: - raise ValueError("Invalid type") +def _generate_multipart_response(self, _, tx, callback, dtype): + """Generate multipart response depening on the exact content type""" + if type(callback).__name__ not in ("function", "closure"): + self.on_failure(tx, b"Invalid response handler") + return + self.terminate(200, dtype) + boundary = self.MULTIPART_BOUNDARY + self._set_response_header(b"content-type", dtype + b"; boundary=" + boundary) + self._write_response_head(tx, None) + if self.method != self.HEAD: + return self._multipart_wrapper_factory(callback, boundary) -def _multipart_wrapper_factory(callback, content_type: bytes, boundary: bytes): +def _multipart_wrapper_factory(callback, boundary: bytes): """ Factory method for creating closures that write multipart responses :param callback: function without arguments, must return bytes-like objects @@ -30,8 +29,7 @@ def _multipart_wrapper_factory(callback, content_type: bytes, boundary: bytes): :param boundary: boundary value :return closure: closure to invoke for response generation """ - boundary = b"--" + boundary - content_type_header = b"content-type: %s\r\n\r\n" % content_type + delimiter = b"--" + boundary def _multipart_wrapper(tx): """ @@ -41,13 +39,16 @@ def _multipart_wrapper(tx): :return bool: true if the stream is completed """ while True: - tx.write(boundary) - part_body = callback() - if not part_body: + tx.write(delimiter) + part = callback() + if not part: tx.write(b"--") yield True + content_type, part_body = part tx.write(b"\r\n") - tx.write(content_type_header) + tx.write(b"content-type:") + tx.write(content_type.encode("ascii")) + tx.write(b"\r\n\r\n") written = 0 while written < len(part_body): to_write = tx.capacity - tx.size() @@ -139,9 +140,7 @@ def 
apply_patches(): """ Apply patches to class attributes for multipart parsing. """ - cls = http.HttpEngine - - orig_init = cls.__init__ + orig_init = http.HttpEngine.__init__ def new_init(self, *args, **kwargs): orig_init(self, *args, **kwargs) @@ -150,8 +149,9 @@ def new_init(self, *args, **kwargs): self.mp_delimiter = None self.mp_last_delimiter = None - cls.__init__ = new_init + http.HttpEngine.__init__ = new_init + add_method(http.HttpEngine, _generate_multipart_response) add_method(http.HttpEngine, _multipart_wrapper_factory, "static") add_method(http.HttpEngine, _start_multipart_parser_st) add_method(http.HttpEngine, _parse_boundary_st) diff --git a/src/pyrobusta/server/http_server.py b/src/pyrobusta/server/http_server.py index e0cecff..6761469 100644 --- a/src/pyrobusta/server/http_server.py +++ b/src/pyrobusta/server/http_server.py @@ -2,120 +2,212 @@ Socket server application """ -import gc +from gc import collect, mem_free from asyncio import sleep_ms, start_server, run # pylint: disable=E1101 from time import ticks_ms, ticks_diff from ..protocol import http from ..bindings.socket_http import SocketHttp +from ..stream.buffer import MemoryPool, SlidingBuffer from ..utils.config import get_config +from ..utils.helpers import normalize_path from ..utils import logging class HttpServer: """ Socket server class, handling global config (timeout, port, max connections etc.), - and managing active sockets. + and managing active clients. 
""" - __slots__ = ["_host", "_max_sockets", "_port", "_timeout", "_server"] + __slots__ = ["_host", "_port", "_server", "_max_clients"] - ACTIVE_SOCKETS = [] + ACTIVE_CLIENTS = [] + + # --------------- + # Server settings + # --------------- CON_ACCEPT_TIMEOUT_MS = 5000 # Timeout value for accepting new connection CON_ACCEPT_SLEEP_MS = ( 100 # Duration of sleep between attempts to accept new connection ) - MAX_SOCKETS = int(get_config("socket_max_con")) - SOCKET_TIMEOUT_SEC = 30 - LISTEN_PORT_HTTP = 8080 - LISTEN_PORT_HTTPS = 4443 - TLS_CERT_PATH = "cert.der" - TLS_KEY_PATH = "key.der" + LISTEN_PORT_HTTP = 80 + LISTEN_PORT_HTTPS = 443 + TLS_CERT_PATH = "/cert.der" + TLS_KEY_PATH = "/key.der" + CON_TIMEOUT_S = 30 + + # ----------------------------------------- + # Constants for controlled memory footprint + # ----------------------------------------- + + MEM_CAP = float(get_config("http_mem_cap")) # Default memory cap (percentage / 100) + SEND_BUF_MIN_BYTES = 512 # Minimum buffer size for responses + SEND_BUF_MAX_BYTES = 4096 # Max buffer size for responses + RECV_BUF_MIN_BYTES = 512 # Minimum buffer size for requests + RECV_BUF_MAX_BYTES = 4096 # Max buffer size for requests + CON_OVERHEAD_BYTES = 1024 # Overhead per connection + + # ------------------------------------------ + # Buffer pools - initialized by init_pools() + # ------------------------------------------ + + RECV_POOL = None + SEND_POOL = None + + @classmethod + def _init_pools(cls, max_clients): + """ + Initialize pool of buffers for sending/receiving based on different profiles + """ + mem_available = mem_free() + con_limit = max_clients + usable = int(cls.MEM_CAP * mem_available) + is_low_memory = (usable / con_limit) < ( + cls.RECV_BUF_MAX_BYTES + cls.SEND_BUF_MAX_BYTES + cls.CON_OVERHEAD_BYTES + ) + if is_low_memory: + logging.warning( + __name__ + ".init_pools: low-memory mode with reduced buffer size" + ) + recv_size = cls.RECV_BUF_MIN_BYTES if is_low_memory else cls.RECV_BUF_MAX_BYTES + 
send_size = cls.SEND_BUF_MIN_BYTES if is_low_memory else cls.SEND_BUF_MAX_BYTES + per_con = recv_size + send_size + cls.CON_OVERHEAD_BYTES + if usable < per_con: + raise MemoryError( + ( + f"Insufficient memory: {mem_available // 1024} KB " + f"at {cls.MEM_CAP*100}% cap, " + f"at least {per_con // 1024} KB required" + ) + ) + con_limit = min(usable // per_con, con_limit) + logging.info((__name__ + f".init_pools: {con_limit} connection(s) allowed")) + cls.RECV_POOL = MemoryPool(recv_size, con_limit, wrapper=SlidingBuffer) + cls.SEND_POOL = MemoryPool(send_size, con_limit, wrapper=SlidingBuffer) @classmethod - async def drop_client(cls, socket): - """Remove socket from active list""" - if socket not in cls.ACTIVE_SOCKETS: + async def _drop_client(cls, client): + """Remove client from active list""" + if client not in cls.ACTIVE_CLIENTS: return - logging.debug(__name__ + f": {socket.id} dropped") - await socket.close() - cls.ACTIVE_SOCKETS.remove(socket) - del socket - gc.collect() + logging.debug(__name__ + f": {client.id} dropped") + await client.close() + cls.ACTIVE_CLIENTS.remove(client) + del client + collect() + + # ---------------- + # Instance methods + # ---------------- def __init__(self): self._host = "0.0.0.0" - self._max_sockets = max(1, HttpServer.MAX_SOCKETS) self._port = ( HttpServer.LISTEN_PORT_HTTPS if get_config("tls").lower() == "true" else HttpServer.LISTEN_PORT_HTTP ) - self._timeout = HttpServer.SOCKET_TIMEOUT_SEC self._server = None + self._max_clients = 0 - async def can_handle_new_socket(self): + async def can_handle_new_client(self): """ Decide if the new socket can be handled. Evict closed/inactive sockets if needed. 
:return is_acceptable: true/false """ - gc.collect() + collect() con_timestamp = ticks_ms() while ticks_diff(ticks_ms(), con_timestamp) < self.CON_ACCEPT_TIMEOUT_MS: - if len(self.ACTIVE_SOCKETS) < self._max_sockets: + if len(self.ACTIVE_CLIENTS) < self._max_clients: return True # Attempt to evict inactive clients - for socket in self.ACTIVE_SOCKETS: - socket_inactive = int(ticks_diff(ticks_ms(), socket.last_event) * 0.001) - if not socket.connected or socket_inactive > self._timeout: + for client in self.ACTIVE_CLIENTS: + client_inactive = int(ticks_diff(ticks_ms(), client.last_event) * 0.001) + if not client.connected or client_inactive > self.CON_TIMEOUT_S: logging.debug( ( - __name__ + f": evicted {socket.id} " - f"timeout: {self._timeout - socket_inactive}s" + __name__ + f": evicted {client.id} " + f"timeout: {self.CON_TIMEOUT_S - client_inactive}s" ) ) - await self.drop_client(socket) - return True + await self._drop_client(client) await sleep_ms(self.CON_ACCEPT_SLEEP_MS) return False - async def accept_http(self, reader, writer): + async def _reserve_buffers(self): + if self.SEND_POOL is None or self.RECV_POOL is None: + raise RuntimeError("Pools are uninitialized") + + recv_buf = None + send_buf = None + + while not recv_buf or not send_buf: + if not recv_buf: + recv_buf = self.RECV_POOL.reserve() + if not send_buf: + send_buf = self.SEND_POOL.reserve() + await sleep_ms(self.CON_ACCEPT_SLEEP_MS) + + return recv_buf, send_buf + + async def _accept_socket(self, reader, writer): """ Handle incoming socket connection for HTTP. 
- creates SocketHttp object """ - if not await self.can_handle_new_socket(): + if not await self.can_handle_new_client(): logging.debug(__name__ + ": cannot accept new client") writer.close() await writer.wait_closed() return - new_client = SocketHttp(reader, writer) - logging.debug(__name__ + f": accept {new_client.id}") - self.ACTIVE_SOCKETS.append(new_client) - await new_client.run() - - async def run_server(self): + try: + recv_buf, send_buf = await self._reserve_buffers() + new_client = SocketHttp(reader, writer, recv_buf, send_buf) + logging.debug(__name__ + f": accept {new_client.id}") + self.ACTIVE_CLIENTS.append(new_client) + await new_client.run() + except Exception as e: # pylint: disable=W0718 + logging.warning(__name__ + f": error in run(): {e}") + finally: + if send_buf: + send_buf.consume() + self.SEND_POOL.release(send_buf) + if recv_buf: + recv_buf.consume() + self.RECV_POOL.release(recv_buf) + collect() + + async def start_socket_server(self): """ Start asyncio socket server on the specified port. 
""" try: - gc.collect() + collect() http.enable_optional_features() - logging.debug(f"Registered endpoints: {http.HttpEngine.ENDPOINTS}") - SocketHttp.init_pools(self._max_sockets) + logging.debug( + __name__ + f"registered endpoints: {http.HttpEngine.ENDPOINTS}" + ) + self._max_clients = int(get_config("socket_max_con")) + self._init_pools(self._max_clients) ssl_ctx = None + if get_config("tls").lower() == "true": import ssl ssl_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) - ssl_ctx.load_cert_chain(self.TLS_CERT_PATH, self.TLS_KEY_PATH) + ssl_ctx.load_cert_chain( + normalize_path(self.TLS_CERT_PATH), + normalize_path(self.TLS_KEY_PATH), + ) + self._server = await start_server( - self.accept_http, + self._accept_socket, self._host, self._port, - backlog=self._max_sockets, + backlog=max(1, self._max_clients), ssl=ssl_ctx, ) logging.info(__name__ + ": started") @@ -124,20 +216,20 @@ async def run_server(self): async def terminate(self): """ - Terminate HTTP server and close sockets + Terminate HTTP server and drop clients """ logging.info(__name__ + ": terminated") - while self.ACTIVE_SOCKETS: - await self.drop_client(self.ACTIVE_SOCKETS[0]) + while self.ACTIVE_CLIENTS: + await self._drop_client(self.ACTIVE_CLIENTS[0]) if self._server: self._server.close() await self._server.wait_closed() self._server = None - gc.collect() + collect() def main(): """ - Start socket server async task. + Start HTTP server async task. """ - run(HttpServer().run_server()) + run(HttpServer().start_socket_server()) diff --git a/src/pyrobusta/utils/assets.py b/src/pyrobusta/utils/assets.py new file mode 100644 index 0000000..c5f5a12 --- /dev/null +++ b/src/pyrobusta/utils/assets.py @@ -0,0 +1,79 @@ +""" +Helper functions to install assets. +""" + +from os import mkdir, listdir, stat + +from .helpers import normalize_path + +FS_ITER_ABS = 0 +FS_ITER_REL = 1 + +FS_ITER_FILE = 0 +FS_ITER_DIR = 1 + + +def copy_file(src_path, dst_path): + """ + Copy a file from a to a destination path. 
+ """ + with open(src_path, "rb") as src: + with open(dst_path, "wb") as dst: + while True: + chunk = src.read(512) + if not chunk: + break + dst.write(chunk) + + +def iterate_fs(root, iter_mode=FS_ITER_FILE, path_mode=FS_ITER_ABS): + """ + Iterate over all files or directories and yield + resulting paths either as absolute or relative paths. + :param dir: directory in which to iterate + :iter_mode int: iterate over files (FS_ITER_FILE=0) or directories (FS_ITER_DIR=1) + :path_mode int: yield absolute paths (FS_ITER_ABS=0) ro relative paths (FS_ITER_REL=1) + """ + dirs = [root] + while dirs: + current_directory = dirs.pop(0) + for name in listdir(current_directory): + if current_directory == "/": + current_path = "/" + name + else: + current_path = current_directory + "/" + name + st = stat(current_path) + fs_mode = st[0] + if fs_mode & 0x4000: # directory bit set + dirs.append(current_path) + if iter_mode == FS_ITER_DIR: + if path_mode == FS_ITER_REL: + yield current_path[len(root) + 1 :] + else: + yield current_path + else: + continue + if iter_mode == FS_ITER_FILE: + if path_mode == FS_ITER_REL: + yield current_path[len(root) + 1 :] + else: + yield current_path + + +def install_www(): + """ + Install default web server assets under /www. 
+ """ + source_dir = normalize_path("/lib/pyrobusta/assets/www") + target_dir = normalize_path("/www") + if "www" not in listdir(): + mkdir(target_dir) + + for asset_dir in iterate_fs(source_dir, FS_ITER_DIR, FS_ITER_ABS): + mkdir(asset_dir) + + for asset in iterate_fs(source_dir, FS_ITER_FILE, FS_ITER_REL): + copy_file( + source_dir + "/" + asset, + target_dir + "/" + asset, + ) diff --git a/src/pyrobusta/utils/config.py b/src/pyrobusta/utils/config.py index 35dc2ea..86dbf44 100644 --- a/src/pyrobusta/utils/config.py +++ b/src/pyrobusta/utils/config.py @@ -6,7 +6,7 @@ from .helpers import normalize_path -PYROBUSTA_VERSION = "0.3.0" +PYROBUSTA_VERSION = "v0.4.0" CONFIG_LOADED = False CONFIG_LOCATION = "pyrobusta.env" CONFIG_CACHE = [ @@ -19,13 +19,15 @@ "http_mem_cap", 0.1, "http_served_paths", - "/lib/pyrobusta", + "/www /lib/pyrobusta", + "http_serve_files", + "True", "socket_max_con", 2, "tls", "False", "log_level", - "warning", + "info", ] @@ -46,11 +48,12 @@ def read_config(config=CONFIG_LOCATION): try: with open(config, encoding="utf-8") as conf: for line in conf: - line = line.rstrip("\r\n") - key = line.split("=")[0].strip() - if key.startswith("#") or not line.strip(): + line = line.rstrip("\r\n").split("#")[0] + if not line.strip(): continue - value = line.split("=")[1].strip().strip("'").strip('"') + parts = line.split("=") + key = parts[0].strip() + value = parts[1].strip().strip("'").strip('"') if key and value: value = normalize(key, value) if ( diff --git a/src/pyrobusta/utils/helpers.py b/src/pyrobusta/utils/helpers.py index cc0d9d3..08f647a 100644 --- a/src/pyrobusta/utils/helpers.py +++ b/src/pyrobusta/utils/helpers.py @@ -25,3 +25,18 @@ def normalize_path(path: str): return cwd + normalized return cwd + "/" + normalized return cwd + + +def add_method(cls, func, method_type="instance"): + """ + Helper to patch/extend classes with + additional methods and states. 
+ """ + if method_type == "instance": + setattr(cls, func.__name__, func) + elif method_type == "static": + setattr(cls, func.__name__, staticmethod(func)) + elif method_type == "class": + setattr(cls, func.__name__, classmethod(func)) + else: + raise ValueError("Invalid type") diff --git a/tests/.pylintrc b/tests/.pylintrc index 1824775..5cf251e 100644 --- a/tests/.pylintrc +++ b/tests/.pylintrc @@ -3,4 +3,5 @@ disable=W0212, C0114, C0115, C0116, - R0904 + R0904, + R0902 diff --git a/tests/functional/test_http.py b/tests/functional/test_http.py index 231f0c4..8800db0 100644 --- a/tests/functional/test_http.py +++ b/tests/functional/test_http.py @@ -1,11 +1,11 @@ import asyncio import ssl import json +import gc -from os import getcwd, mkdir +from os import mkdir, remove, rmdir from pyrobusta.server import http_server -from pyrobusta.protocol import http_multipart from pyrobusta.protocol.http import ( HttpEngine, enable_optional_features, @@ -19,6 +19,15 @@ ################################################# +def garbage_collect(coroutine): + async def decorated(*args, **kwargs): + gc.collect() + await coroutine(*args, **kwargs) + gc.collect() + + return decorated + + def test_assert(name, actual, expected): print(f"Test {name}: ", end="") if actual == expected: @@ -55,19 +64,6 @@ async def send_request(request, tls=False): return response -def multipart_response(num_responses): - i = 0 - - def response_generator(): - nonlocal i - i += 1 - if i > num_responses: - return None - return b"Response %s" % i - - return response_generator - - ################################################# # Test driver ################################################# @@ -82,12 +78,6 @@ def simple_callback(http_ctx, _): raise ValueError("Unhandled content-type") -@HttpEngine.route("/test/multipart", "GET") -def multipart_callback(http_ctx, _): - part_count = int(http_ctx.headers["x-part-count"]) - return "multipart/form-data", ("text/plain", multipart_response(part_count)) - - 
@HttpEngine.route("/test/busy", "POST") def busy_callback(*_): raise ServerBusyError() @@ -108,13 +98,14 @@ async def start_server(): Start an HTTP server as a background task """ server = http_server.HttpServer() - server_task = asyncio.create_task(server.run_server()) + server_task = asyncio.create_task(server.start_socket_server()) await asyncio.sleep_ms(100) return server, server_task +@garbage_collect async def test_simple_response(tls_enabled): - setup_config(multipart=False, tls_enabled=tls_enabled) + setup_config(tls_enabled=tls_enabled) server, server_task = await start_server() # Test: text/plain @@ -157,38 +148,7 @@ async def test_simple_response(tls_enabled): await server.terminate() -async def test_multipart_response(tls_enabled): - setup_config(multipart=True, tls_enabled=tls_enabled) - server, server_task = await start_server() - - # Test: 1 part - plain_response = await send_request( - b"GET /test/multipart HTTP/1.1\r\n" - b"Host: localhost\r\nX-Part-Count: 1\r\n\r\n", - tls_enabled, - ) - test_assert( - f"http{"s" if tls_enabled else ""} response contains 1 part", - b"Response 1" in plain_response, - True, - ) - - # Test: 10 parts - plain_response = await send_request( - b"GET /test/multipart HTTP/1.1\r\n" - b"Host: localhost\r\nX-Part-Count: 10\r\n\r\n", - tls_enabled, - ) - test_assert( - f"http{"s" if tls_enabled else ""} response contains 10 parts", - [b"Response %s" % i in plain_response for i in range(1, 11)], - [True] * 10, - ) - - server_task.cancel() - await server.terminate() - - +@garbage_collect async def test_server_busy(): setup_config() server, server_task = await start_server() @@ -206,6 +166,7 @@ async def test_server_busy(): await server.terminate() +@garbage_collect async def test_chunked_transfer_encoding(): setup_config() create_chunked_app_endpoint("/test/chunked") @@ -233,25 +194,33 @@ async def test_chunked_transfer_encoding(): await server.terminate() +@garbage_collect async def test_fs_access_control(): - 
setup_config(served_paths="/www") + setup_config(served_paths="/www/allowed") server, server_task = await start_server() + workdir_root = normalize_path("/www") + try: + mkdir(workdir_root) + except: + pass # Index page under /www -> accepted - workdir = normalize_path("/www") - index_html = normalize_path("/www/index.html") - mkdir(workdir) - with open(index_html, "w") as f: + allowed_workdir = normalize_path("/www/allowed") + allowed_index_html = normalize_path("/www/allowed/index.html") + mkdir(allowed_workdir) + with open(allowed_index_html, "w") as f: f.write("PyRobusta Home") # Index page under / -> rejected - index_html = normalize_path("/index.html") - with open(index_html, "w") as f: + rejected_workdir = normalize_path("/www/rejected") + rejected_index_html = normalize_path("/www/rejected/index.html") + mkdir(rejected_workdir) + with open(rejected_index_html, "w") as f: f.write("PyRobusta Home") # Case #1: /www/index.html response = await send_request( - (b"GET /www/index.html HTTP/1.1\r\n" b"Host: localhost\r\n\r\n") + (b"GET /allowed/index.html HTTP/1.1\r\n" b"Host: localhost\r\n\r\n") ) response_body = response.split(b"\r\n\r\n")[1] @@ -263,7 +232,7 @@ async def test_fs_access_control(): # Case #2: /index.html response = await send_request( - (b"GET /index.html HTTP/1.1\r\n" b"Host: localhost\r\n\r\n") + (b"GET /rejected/index.html HTTP/1.1\r\n" b"Host: localhost\r\n\r\n") ) test_assert( @@ -272,6 +241,10 @@ async def test_fs_access_control(): True, ) + remove(allowed_index_html) + remove(rejected_index_html) + rmdir(allowed_workdir) + rmdir(rejected_workdir) server_task.cancel() await server.terminate() @@ -281,9 +254,12 @@ async def test_fs_access_control(): ################################################# -def setup_config(multipart=False, tls_enabled=False, served_paths=""): - config_idx = config.CONFIG_CACHE.index("http_multipart") - config.CONFIG_CACHE[config_idx + 1] = str(multipart) +def setup_config(tls_enabled=False, served_paths=""): + 
http_server.HttpServer.LISTEN_PORT_HTTP = 8080 + http_server.HttpServer.LISTEN_PORT_HTTPS = 4443 + + config_idx = config.CONFIG_CACHE.index("log_level") + config.CONFIG_CACHE[config_idx + 1] = str("warning") config_idx = config.CONFIG_CACHE.index("tls") config.CONFIG_CACHE[config_idx + 1] = str(tls_enabled) config_idx = config.CONFIG_CACHE.index("http_served_paths") @@ -300,12 +276,6 @@ def test_registration(): HttpEngine._get_callback(b"/test/simple", b"GET"), ) - test_assert( - "multipart endpoint registration", - multipart_callback, - HttpEngine._get_callback(b"/test/multipart", b"GET"), - ) - test_assert( "busy endpoint registration", busy_callback, @@ -313,24 +283,11 @@ def test_registration(): ) -def test_multipart_patches(): - setup_config(multipart=True) - test_assert( - "multipart state machine patches", - http_multipart._start_multipart_parser_st, - HttpEngine._start_multipart_parser_st, - ) - - def test_main(): test_registration() asyncio.run(test_simple_response(tls_enabled=False)) asyncio.run(test_simple_response(tls_enabled=True)) - test_multipart_patches() - asyncio.run(test_multipart_response(tls_enabled=False)) - asyncio.run(test_multipart_response(tls_enabled=True)) - asyncio.run(test_server_busy()) asyncio.run(test_chunked_transfer_encoding()) asyncio.run(test_fs_access_control()) diff --git a/tests/functional/test_http_multipart.py b/tests/functional/test_http_multipart.py new file mode 100644 index 0000000..5463a59 --- /dev/null +++ b/tests/functional/test_http_multipart.py @@ -0,0 +1,172 @@ +import asyncio +import ssl +import gc + +from pyrobusta.server import http_server +from pyrobusta.protocol import http_multipart +from pyrobusta.protocol.http import ( + HttpEngine, + enable_optional_features, +) +from pyrobusta.utils import config + +################################################# +# Test helpers +################################################# + + +def garbage_collect(coroutine): + async def decorated(*args, **kwargs): + gc.collect() 
+ await coroutine(*args, **kwargs) + gc.collect() + + return decorated + + +def test_assert(name, actual, expected): + print(f"Test {name}: ", end="") + if actual == expected: + print("OK") + else: + print("Fail") + raise AssertionError(f"{actual} != {expected}") + + +async def send_request(request, tls=False): + port = ( + http_server.HttpServer.LISTEN_PORT_HTTPS + if tls + else http_server.HttpServer.LISTEN_PORT_HTTP + ) + + ctx = None + if tls: + # Disable certificate verification due to self-signed cert + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ctx.verify_mode = ssl.CERT_NONE + + reader, writer = await asyncio.open_connection("127.0.0.1", port, ssl=ctx) + writer.write(request) + await writer.drain() + + to_read = True + response = b"" + while to_read: + response_part = await reader.read(1024) + response += response_part + to_read = len(response_part) + writer.close() + return response + + +def multipart_response(num_responses): + i = 0 + + def response_generator(): + nonlocal i + i += 1 + if i > num_responses: + return None + return "text/plain", b"Response %s" % i + + return response_generator + + +################################################# +# Test driver +################################################# + + +@HttpEngine.route("/test/multipart", "GET") +def multipart_callback(http_ctx, _): + part_count = int(http_ctx.headers["x-part-count"]) + return "multipart/form-data", multipart_response(part_count) + + +async def start_server(): + """ + Start an HTTP server as a background task + """ + server = http_server.HttpServer() + server_task = asyncio.create_task(server.start_socket_server()) + await asyncio.sleep_ms(100) + return server, server_task + + +@garbage_collect +async def test_multipart_response(tls_enabled): + setup_config(tls_enabled=tls_enabled) + server, server_task = await start_server() + + # Test: 1 part + plain_response = await send_request( + b"GET /test/multipart HTTP/1.1\r\n" + b"Host: localhost\r\nX-Part-Count: 1\r\n\r\n", + 
tls_enabled, + ) + test_assert( + f"http{"s" if tls_enabled else ""} response contains 1 part", + b"Response 1" in plain_response, + True, + ) + + # Test: 10 parts + plain_response = await send_request( + b"GET /test/multipart HTTP/1.1\r\n" + b"Host: localhost\r\nX-Part-Count: 10\r\n\r\n", + tls_enabled, + ) + test_assert( + f"http{"s" if tls_enabled else ""} response contains 10 parts", + [b"Response %s" % i in plain_response for i in range(1, 11)], + [True] * 10, + ) + + server_task.cancel() + await server.terminate() + + +################################################# +# Test methods +################################################# + + +def setup_config(tls_enabled=False): + http_server.HttpServer.LISTEN_PORT_HTTP = 8080 + http_server.HttpServer.LISTEN_PORT_HTTPS = 4443 + + config_idx = config.CONFIG_CACHE.index("log_level") + config.CONFIG_CACHE[config_idx + 1] = str("warning") + config_idx = config.CONFIG_CACHE.index("http_multipart") + config.CONFIG_CACHE[config_idx + 1] = "True" + config_idx = config.CONFIG_CACHE.index("tls") + config.CONFIG_CACHE[config_idx + 1] = str(tls_enabled) + enable_optional_features() + + +def test_registration(): + test_assert( + "multipart endpoint registration", + multipart_callback, + HttpEngine._get_callback(b"/test/multipart", b"GET"), + ) + + +def test_multipart_patches(): + setup_config() + test_assert( + "multipart state machine patches", + http_multipart._start_multipart_parser_st, + HttpEngine._start_multipart_parser_st, + ) + + +def test_main(): + test_registration() + test_multipart_patches() + asyncio.run(test_multipart_response(tls_enabled=False)) + asyncio.run(test_multipart_response(tls_enabled=True)) + + +test_main() diff --git a/tests/unit/test_helpers.py b/tests/unit/test_helpers.py index 6886fcc..10ee8a8 100644 --- a/tests/unit/test_helpers.py +++ b/tests/unit/test_helpers.py @@ -26,6 +26,7 @@ def test_path_normalization_virtual_root(self): cwd = getcwd() for case in ( ("", ""), + ("/", f"{cwd}"), 
("/path/to/resource", f"{cwd}/path/to/resource"), ("/path/to/resource/", f"{cwd}/path/to/resource"), ("///path///to///resource///", f"{cwd}/path/to/resource"), @@ -45,6 +46,7 @@ def test_path_normalization_host_root(self, _): """ for case in ( ("", ""), + ("/", "/"), ("/path/to/resource", "/path/to/resource"), ("/path/to/resource/", "/path/to/resource"), ("///path///to///resource///", "/path/to/resource"), diff --git a/tests/unit/test_http.py b/tests/unit/test_http.py index ee04f02..9afec41 100644 --- a/tests/unit/test_http.py +++ b/tests/unit/test_http.py @@ -30,28 +30,26 @@ def setUp(self): }, ) self.patcher.start() - - for key, value in self.config.items(): - self.set_mock_config(key, value) + self.set_mock_config() # Load your web and buffer modules self.helpers_module = load_module("pyrobusta/utils/helpers.py") buffer_module = load_module("pyrobusta/stream/buffer.py") - web_module = load_module("pyrobusta/protocol/http.py") - web_module.enable_optional_features() + self.web_module = load_module("pyrobusta/protocol/http.py") + self.web_module.enable_optional_features() - self.engine = web_module.HttpEngine() + self.engine = self.web_module.HttpEngine() self.rx = buffer_module.SlidingBuffer(bytearray(1024)) self.tx = buffer_module.SlidingBuffer(bytearray(1024)) def tearDown(self): self.patcher.stop() - def set_mock_config(self, key, value): + def set_mock_config(self): def side_effect(input_arg, *_, **__): - if input_arg == key: - return value - raise ValueError(f"Unexpected argument: {input_arg}") + if input_arg in self.config: + return self.config[input_arg] + raise ValueError(f"Unexpected config key: {input_arg}") self.mock_utils_config.get_config.side_effect = side_effect @@ -63,7 +61,7 @@ class TestWebStateMachine(TestWebStateMachineBase): @classmethod def setUpClass(cls): - cls.config = {} + cls.config = {"http_multipart": "False", "http_serve_files": "False"} def test_status_parsing_valid(self): request = b"GET /index.html HTTP/1.1\r\nContent-Length:10" 
@@ -149,6 +147,18 @@ def test_header_parsing_incomplete_header(self): self.assertEqual(self.engine.status_code, 400) self.assertEqual(self.engine.state, None) + def test_header_parsing_error(self): + for case in ( + b"", + b":", + b": value", + b" leading-space: value", + b"space in header name: value", + b"new-line-in-header:\nvalue", + ): + with self.assertRaises(self.web_module.HeaderParsingError): + self.engine._parse_headers(case) + def test_routing_unsupported_method(self): self.engine.state = self.engine._route_request_st self.engine.url = b"/api/test" @@ -312,6 +322,26 @@ def test_empty_or_missing_url_encoded_query_parameter(self): with self.assertRaises(KeyError): self.engine.get_url_encoded_query_param(self.engine.query, "param3") + def test_overlapping_url_encoded_query_parameter(self): + request = b"GET /api/test?data=value1&ta=value2&a=value3 HTTP/1.1\r\n" + + for i in range(len(request)): + self.rx.write(request[i : i + 1]) + self.engine.state(self.rx, self.tx) + + self.assertEqual( + self.engine.get_url_encoded_query_param(self.engine.query, "data"), + "value1", + ) + self.assertEqual( + self.engine.get_url_encoded_query_param(self.engine.query, "ta"), + "value2", + ) + self.assertEqual( + self.engine.get_url_encoded_query_param(self.engine.query, "a"), + "value3", + ) + def test_chunked_transfer_encoding_valid(self): self.engine.url = b"/api/test" self.engine.method = b"GET" @@ -383,7 +413,7 @@ def test_chunked_transfer_encoding_chunk_incomplete(self): self.assertEqual(self.engine.state, self.engine._recv_chunk_st) def test_path_serving_list(self): - self.set_mock_config("http_served_paths", "/path/to/dir1 /path/to/dir2") + self.config["http_served_paths"] = "/path/to/dir1 /path/to/dir2" self.assertEqual(self.engine.is_norm_path_served(""), False) self.assertEqual(self.engine.is_norm_path_served("/"), False) self.assertEqual(self.engine.is_norm_path_served("/path/to/dir1"), True) @@ -395,13 +425,13 @@ def test_path_serving_list(self): 
self.assertEqual(self.engine.is_norm_path_served("/path/to"), False) def test_path_serving_root(self): - self.set_mock_config("http_served_paths", "/") + self.config["http_served_paths"] = "/" self.assertEqual(self.engine.is_norm_path_served(""), True) self.assertEqual(self.engine.is_norm_path_served("/"), True) self.assertEqual(self.engine.is_norm_path_served("/path/to/served"), True) def test_path_serving_none(self): - self.set_mock_config("http_served_paths", "") + self.config["http_served_paths"] = "" self.assertEqual(self.engine.is_norm_path_served(""), False) self.assertEqual(self.engine.is_norm_path_served("/"), False) self.assertEqual(self.engine.is_norm_path_served("/path/to/served"), False) @@ -414,14 +444,19 @@ class TestMultipartStateMachine(TestWebStateMachineBase): @classmethod def setUpClass(cls): - cls.config = {"http_multipart": "True"} + cls.config = {"http_multipart": "True", "http_serve_files": "True"} def test_multipart_parser(self): for case in [ + ({}, None), ( {"content-type": 'multipart/form-data; boundary ="test-boundary"'}, "test-boundary", ), + ( + {"content-type": 'multipart/form-data; boundary =" test-boundary "'}, + " test-boundary ", + ), ( {"content-type": "multipart/form-data ;boundary= test-boundary "}, "test-boundary", @@ -432,16 +467,18 @@ def test_multipart_parser(self): ), ]: with self.subTest(headers=case[0], expected=case[1]): - self.assertEqual(self.engine._is_multipart(case[0]), case[1]) + self.assertEqual(self.engine._get_mp_boundary(case[0]), case[1]) for case in [ - {}, {"content-type": "multipart/form-data"}, {"content-type": 'multipart/form-data;boundary=""'}, {"content-type": "multipart/form-data;boundary=\r\n"}, + {"content-type": 'multipart/form-data;boundary="missing-quote'}, + {"content-type": 'multipart/form-data;boundary=missing-quote"'}, ]: - with self.subTest(headers=case, expected=None): - self.assertEqual(self.engine._is_multipart(case), None) + with self.subTest(headers=case): + with 
self.assertRaises(self.web_module.HeaderParsingError): + self.engine._get_mp_boundary(case) def test_multipart_receiver_valid(self): self.engine.state = self.engine._start_multipart_parser_st