diff --git a/Makefile b/Makefile
index 4486799..f99940a 100644
--- a/Makefile
+++ b/Makefile
@@ -1,13 +1,15 @@
-PYROBUSTA_VERSION := 0.3.0
+PYROBUSTA_VERSION := v0.4.0
DEVICE ?= u0
SRC_DIR := src
TEST_DIR := tests
-EXAMPLE_DIR := example/mem_usage
+EXAMPLE_DIR := example/mip_repo
BUILD_DIR := build
DIST_DIR := dist
-PKG := pyrobusta
TLS_DIR := tls
+ASSETS_DIR := assets
+
+PKG := pyrobusta
MICROPY_ROOT := external/micropython
MPY_CROSS := $(MICROPY_ROOT)/mpy-cross/build/mpy-cross
@@ -47,6 +49,11 @@ toolchain:
# -----------------------------
.PHONY: build
build: $(MPY_TARGETS) $(INIT_TARGETS)
+ @mkdir -p $(BUILD_DIR)
+ @if [ -d assets ]; then \
+ echo "Copying assets/ -> $(BUILD_DIR)"; \
+ cp -r assets $(BUILD_DIR)/${PKG}/; \
+ fi
# Compile .py -> .mpy
$(BUILD_DIR)/%.mpy: $(SRC_DIR)/%.py
@@ -66,6 +73,7 @@ $(BUILD_DIR)/%.py: $(SRC_DIR)/%.py
.PHONY: deploy
deploy:
@echo "Uploading build/$(PKG) to device $(DEVICE)"
+ @mpremote $(DEVICE) soft-reset
@mpremote $(DEVICE) mkdir :/lib || true
@find $(BUILD_DIR)/$(PKG) | while read source; do \
rel=$${source#$(BUILD_DIR)/}; \
@@ -79,6 +87,7 @@ deploy:
fi; \
sleep 1; \
done
+ @mpremote $(DEVICE) reset
# -----------------------------
# Deploy custom configuration
@@ -86,7 +95,20 @@ deploy:
.PHONY: deploy-config
deploy-config:
@echo "Uploading pyrobusta.env"
+ @mpremote $(DEVICE) soft-reset
@if [ -f pyrobusta.env ]; then mpremote $(DEVICE) cp pyrobusta.env :pyrobusta.env; fi
+ @mpremote $(DEVICE) reset
+
+
+# -----------------------------
+# Deploy index page # TODO use install_www from assets module
+# -----------------------------
+.PHONY: deploy-www
+deploy-www:
+ @echo "Deploying /www"
+ @mpremote $(DEVICE) soft-reset
+ @mpremote $(DEVICE) run scripts/install_www.py
+ @mpremote $(DEVICE) reset
# -----------------------------
# Full redeploy
@@ -107,6 +129,9 @@ publish:
@sed -E -i.bak 's/(PYROBUSTA_VERSION[[:space:]]*=[[:space:]]*)"[^"]*"/\1"$(PYROBUSTA_VERSION)"/' \
$(SRC_DIR)/pyrobusta/utils/config.py \
&& rm -f $(SRC_DIR)/pyrobusta/utils/config.py.bak
+ @sed -E -i.bak 's/(PyRobusta[[:space:]]).+([[:space:]]Web Server)/\1$(PYROBUSTA_VERSION)\2/' \
+ $(ASSETS_DIR)/www/*.html \
+ && rm -f $(ASSETS_DIR)/www/*.html.bak
$(MAKE) clean
$(MAKE) build BUILD_DIR=$(DIST_DIR)
scripts/update_package.bash $(DIST_DIR) package.json $(PYROBUSTA_VERSION)
@@ -128,7 +153,7 @@ stage-example:
@echo "Copying built package"
@cp -r build/pyrobusta $(RUNTIME_DIR)/lib
- @echo "Copying example files"
+ @echo "Copying example app"
@cp $(EXAMPLE_DIR)/app.py $(RUNTIME_DIR)/
@cp $(EXAMPLE_DIR)/boot.py $(RUNTIME_DIR)/
@@ -152,16 +177,20 @@ run-unix: stage-example
.PHONY: deploy-example
deploy-example:
@echo "Uploading boot.py"
+ @mpremote $(DEVICE) soft-reset
mpremote $(DEVICE) cp $(EXAMPLE_DIR)/boot.py :boot.py
+ mpremote $(DEVICE) cp $(EXAMPLE_DIR)/app.py :app.py
@echo "Uploading pyrobusta.env"
@if [ -f pyrobusta.env ]; then mpremote $(DEVICE) cp pyrobusta.env :pyrobusta.env; fi
+ @mpremote $(DEVICE) reset
# -----------------------------
# Run example directly
# -----------------------------
.PHONY: run-device
run-device:
+ @mpremote $(DEVICE) soft-reset
mpremote $(DEVICE) run $(EXAMPLE_DIR)/app.py
@@ -226,7 +255,9 @@ test-unix: TLS_DIR=$(TEST_RUNTIME)
test-unix: stage-test tls-cert
@cd $(TEST_RUNTIME); \
for test in test_*.py; do \
+ echo "\n==================================="; \
echo "Running $$test"; \
+ echo "==================================="; \
MICROPYPATH=":.frozen:lib" ../$(MICROPYTHON) $$(basename $$test) || exit 1; \
done
@@ -234,12 +265,16 @@ test-unix: stage-test tls-cert
# Run functional tests on device
# -----------------------------
.PHONY: test-device
-test-device: #clean-device upload
+test-device: stage-test #clean-device upload
+ @mpremote $(DEVICE) soft-reset
@cd $(TEST_RUNTIME); \
for test in test_*.py; do \
+ echo "\n==================================="; \
echo "Running $$test"; \
+ echo "==================================="; \
mpremote $(DEVICE) run $$(basename $$test) || exit 1; \
done
+ @mpremote $(DEVICE) reset
# ================================================
# Utilities for TLS
@@ -272,8 +307,10 @@ tls-cert:
# -----------------------------
.PHONY: deploy-cert
deploy-cert:
+ @mpremote $(DEVICE) soft-reset
@mpremote $(DEVICE) cp $(TLS_DIR)/key.der :key.der
@mpremote $(DEVICE) cp $(TLS_DIR)/cert.der :cert.der
+ @mpremote $(DEVICE) reset
# ================================================
# Cleanup
@@ -305,4 +342,6 @@ clean: clean-build clean-runtime
# -----------------------------
.PHONY: clean-device
clean-device:
+ @mpremote $(DEVICE) soft-reset
mpremote $(DEVICE) run scripts/clean_device.py
+ @mpremote $(DEVICE) reset
diff --git a/assets/www/examples.html b/assets/www/examples.html
new file mode 100644
index 0000000..095c675
--- /dev/null
+++ b/assets/www/examples.html
@@ -0,0 +1,157 @@
+
+
+
+
+
+ PyRobusta Home
+
+
+
+
+ Getting Started
+
+ ← Back
+
+ This page presents useful examples to configure your server.
+
+
+
+ Server configuration
+
+
+
+
+
+ Simple Server Application
+ The below example demonstrates how to set up a simple application, exposed at /app.
+
+
+
+ Soft reset the device and upload app.py and boot.py with mpremote.
+
+
+ Hard reset the device to start the application and connect over REPL.
+
+
+ Use curl to test your application.
+
+
+
+
+
\ No newline at end of file
diff --git a/assets/www/index.html b/assets/www/index.html
new file mode 100644
index 0000000..2f49452
--- /dev/null
+++ b/assets/www/index.html
@@ -0,0 +1,55 @@
+
+
+
+
+
+ PyRobusta Home
+
+
+
+
+ PyRobusta Home
+
+ The server is running correctly and is ready to serve content.
+
+
+
+ Available Resources
+
+
+
+
+
+
\ No newline at end of file
diff --git a/dist/pyrobusta/assets/www/examples.html b/dist/pyrobusta/assets/www/examples.html
new file mode 100644
index 0000000..095c675
--- /dev/null
+++ b/dist/pyrobusta/assets/www/examples.html
@@ -0,0 +1,157 @@
+
+
+
+
+
+ PyRobusta Home
+
+
+
+
+ Getting Started
+
+ ← Back
+
+ This page presents useful examples to configure your server.
+
+
+
+ Server configuration
+
+
+
+
+
+ Simple Server Application
+ The below example demonstrates how to set up a simple application, exposed at /app.
+
+
+
+ Soft reset the device and upload app.py and boot.py with mpremote.
+
+
+ Hard reset the device to start the application and connect over REPL.
+
+
+ Use curl to test your application.
+
+
+
+
+
\ No newline at end of file
diff --git a/dist/pyrobusta/assets/www/index.html b/dist/pyrobusta/assets/www/index.html
new file mode 100644
index 0000000..2f49452
--- /dev/null
+++ b/dist/pyrobusta/assets/www/index.html
@@ -0,0 +1,55 @@
+
+
+
+
+
+ PyRobusta Home
+
+
+
+
+ PyRobusta Home
+
+ The server is running correctly and is ready to serve content.
+
+
+
+ Available Resources
+
+
+
+
+
+
\ No newline at end of file
diff --git a/dist/pyrobusta/bindings/socket_http.mpy b/dist/pyrobusta/bindings/socket_http.mpy
index b56ceb7..6ee7b49 100644
Binary files a/dist/pyrobusta/bindings/socket_http.mpy and b/dist/pyrobusta/bindings/socket_http.mpy differ
diff --git a/dist/pyrobusta/con/wifi.mpy b/dist/pyrobusta/con/wifi.mpy
index 8f5f2bc..5e82314 100644
Binary files a/dist/pyrobusta/con/wifi.mpy and b/dist/pyrobusta/con/wifi.mpy differ
diff --git a/dist/pyrobusta/protocol/http.mpy b/dist/pyrobusta/protocol/http.mpy
index d2a04ba..e645c38 100644
Binary files a/dist/pyrobusta/protocol/http.mpy and b/dist/pyrobusta/protocol/http.mpy differ
diff --git a/dist/pyrobusta/protocol/http_file_server.mpy b/dist/pyrobusta/protocol/http_file_server.mpy
new file mode 100644
index 0000000..c6e8a74
Binary files /dev/null and b/dist/pyrobusta/protocol/http_file_server.mpy differ
diff --git a/dist/pyrobusta/protocol/http_multipart.mpy b/dist/pyrobusta/protocol/http_multipart.mpy
index 03fdbed..e6dbbce 100644
Binary files a/dist/pyrobusta/protocol/http_multipart.mpy and b/dist/pyrobusta/protocol/http_multipart.mpy differ
diff --git a/dist/pyrobusta/server/http_server.mpy b/dist/pyrobusta/server/http_server.mpy
index 467a75e..d798948 100644
Binary files a/dist/pyrobusta/server/http_server.mpy and b/dist/pyrobusta/server/http_server.mpy differ
diff --git a/dist/pyrobusta/utils/assets.mpy b/dist/pyrobusta/utils/assets.mpy
new file mode 100644
index 0000000..7093beb
Binary files /dev/null and b/dist/pyrobusta/utils/assets.mpy differ
diff --git a/dist/pyrobusta/utils/config.mpy b/dist/pyrobusta/utils/config.mpy
index 5e9b55a..671b920 100644
Binary files a/dist/pyrobusta/utils/config.mpy and b/dist/pyrobusta/utils/config.mpy differ
diff --git a/dist/pyrobusta/utils/helpers.mpy b/dist/pyrobusta/utils/helpers.mpy
index 19f006a..5760dd1 100644
Binary files a/dist/pyrobusta/utils/helpers.mpy and b/dist/pyrobusta/utils/helpers.mpy differ
diff --git a/docs/configuration.md b/docs/configuration.md
index 8ab2b12..f6b63fb 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -9,7 +9,8 @@ to upload it to the root directory of the target device.
| wifi_password | Password of the Wi-Fi network. When empty, Wi-Fi is not initalized by the built-in wifi.py module. | None |
| http_multipart | Enable multipart HTTP requests/responses. | "False" |
| http_mem_cap | Max memory cap (% × 0.01) of usable heap for HTTP request/response stream buffers. | 0.1 |
-| http_served_paths | Space delimited list of filesystem paths allowed to be served through HTTP. | "/lib/pyrobusta" |
+| http_served_paths | Space delimited list of filesystem paths allowed to be served through HTTP. | "/www /lib/pyrobusta" |
+| http_serve_files | Enable/disable file serving. | "True" |
| socket_max_con | Max number of socket connections of any enabled application server. | 2 |
| tls | Enable/disable TLS. When turned on, cert.der/key.der must be installed at the root. | "False" |
-| log_level | Can be one of: warning, info, debug. | "warning" |
+| log_level | Can be one of: warning, info, debug. | "info" |
diff --git a/example/demo_app/app.py b/example/demo_app/app.py
new file mode 100644
index 0000000..4d2491b
--- /dev/null
+++ b/example/demo_app/app.py
@@ -0,0 +1,33 @@
+import asyncio
+from gc import mem_free, mem_alloc
+
+from pyrobusta.server import http_server
+from pyrobusta.protocol.http import HttpEngine
+from pyrobusta.utils import logging
+
+
+@HttpEngine.route("/app", "GET")
+def app(http_ctx, payload):
+ free = mem_free()
+ value_format = "bytes"
+
+ if http_ctx.query:
+ value_format = http_ctx.get_url_encoded_query_param(
+ http_ctx.query, "format", default="bytes"
+ )
+ if value_format not in ("%", "bytes"):
+ raise ValueError("invalid format")
+
+ if value_format == "%":
+ free = round(100 * free / (free + mem_alloc()), 2)
+
+ return "text/plain", (f"Free memory [{value_format}]: {free}\n")
+
+
+def main():
+ http_server.main()
+ try:
+ asyncio.get_event_loop().run_forever()
+ except Exception as e:
+ logging.warning(f"loop stopped: {e}")
+ asyncio.get_event_loop().close()
diff --git a/example/demo_app/boot.py b/example/demo_app/boot.py
new file mode 100644
index 0000000..a7c9c58
--- /dev/null
+++ b/example/demo_app/boot.py
@@ -0,0 +1,12 @@
+# This file is executed on every boot (including wake-boot from deepsleep)
+import machine
+from os import listdir
+
+from pyrobusta.con import wifi
+
+connected = wifi.initialize()
+if connected and not machine.reset_cause() == machine.SOFT_RESET:
+ if "app.py" in listdir():
+ import app
+
+ app.main()
diff --git a/example/mem_usage/app.py b/example/mem_usage/app.py
index bccf0ae..b3ac746 100644
--- a/example/mem_usage/app.py
+++ b/example/mem_usage/app.py
@@ -23,11 +23,11 @@ def mem_usage(http_ctx, _):
selector = http_ctx.get_url_encoded_query_param(http_ctx.query, "key", "")
if selector == "free":
if value_format == "%":
- free = 100 * free / (used + free)
+ free = round(100 * free / (used + free),2)
return "text/plain", f"Free [{value_format}]: {free}\n"
if selector == "used":
if value_format == "%":
- used = 100 * used / (used + free)
+ used = round(100 * used / (used + free),2)
return "text/plain", f"Used [{value_format}]: {used}\n"
if selector == "total":
return "text/plain", f"Total [bytes]: {used + free}\n"
diff --git a/example/mem_usage/boot.py b/example/mem_usage/boot.py
index d96f6f6..a7c9c58 100644
--- a/example/mem_usage/boot.py
+++ b/example/mem_usage/boot.py
@@ -1,4 +1,12 @@
# This file is executed on every boot (including wake-boot from deepsleep)
+import machine
+from os import listdir
+
from pyrobusta.con import wifi
-wifi.initialize()
+connected = wifi.initialize()
+if connected and not machine.reset_cause() == machine.SOFT_RESET:
+ if "app.py" in listdir():
+ import app
+
+ app.main()
diff --git a/example/mip_repo/app.py b/example/mip_repo/app.py
index 5955b9a..2564e96 100644
--- a/example/mip_repo/app.py
+++ b/example/mip_repo/app.py
@@ -22,7 +22,7 @@ def append_package_files(dir, package_files, host_name, protocol):
package_files["urls"].append(
[
target_path,
- f"{protocol}://{host_name}/{current_path}",
+ f"{protocol}://{host_name}/files/{current_path}",
]
)
@@ -38,7 +38,8 @@ def self_serve_mip_package(http_ctx, _):
if tls_enabled
else http_server.HttpServer.LISTEN_PORT_HTTP
)
- server_addr += f":{port}"
+ if not server_addr in (80, 443):
+ server_addr += f":{port}"
protocol = "https" if tls_enabled else "http"
diff --git a/example/mip_repo/boot.py b/example/mip_repo/boot.py
index d96f6f6..a7c9c58 100644
--- a/example/mip_repo/boot.py
+++ b/example/mip_repo/boot.py
@@ -1,4 +1,12 @@
# This file is executed on every boot (including wake-boot from deepsleep)
+import machine
+from os import listdir
+
from pyrobusta.con import wifi
-wifi.initialize()
+connected = wifi.initialize()
+if connected and not machine.reset_cause() == machine.SOFT_RESET:
+ if "app.py" in listdir():
+ import app
+
+ app.main()
diff --git a/package.json b/package.json
index ddfce0e..ad11182 100644
--- a/package.json
+++ b/package.json
@@ -1,5 +1,5 @@
{
- "version": "0.3.0",
+ "version": "v0.4.0",
"urls": [
[
"pyrobusta/transport/socket.mpy",
@@ -9,6 +9,14 @@
"pyrobusta/transport/__init__.py",
"github:szeka9/PyRobusta/dist/pyrobusta/transport/__init__.py"
],
+ [
+ "pyrobusta/assets/www/index.html",
+ "github:szeka9/PyRobusta/dist/pyrobusta/assets/www/index.html"
+ ],
+ [
+ "pyrobusta/assets/www/examples.html",
+ "github:szeka9/PyRobusta/dist/pyrobusta/assets/www/examples.html"
+ ],
[
"pyrobusta/utils/helpers.mpy",
"github:szeka9/PyRobusta/dist/pyrobusta/utils/helpers.mpy"
@@ -25,6 +33,10 @@
"pyrobusta/utils/logging.mpy",
"github:szeka9/PyRobusta/dist/pyrobusta/utils/logging.mpy"
],
+ [
+ "pyrobusta/utils/assets.mpy",
+ "github:szeka9/PyRobusta/dist/pyrobusta/utils/assets.mpy"
+ ],
[
"pyrobusta/protocol/http_multipart.mpy",
"github:szeka9/PyRobusta/dist/pyrobusta/protocol/http_multipart.mpy"
@@ -37,6 +49,10 @@
"pyrobusta/protocol/__init__.py",
"github:szeka9/PyRobusta/dist/pyrobusta/protocol/__init__.py"
],
+ [
+ "pyrobusta/protocol/http_file_server.mpy",
+ "github:szeka9/PyRobusta/dist/pyrobusta/protocol/http_file_server.mpy"
+ ],
[
"pyrobusta/stream/__init__.py",
"github:szeka9/PyRobusta/dist/pyrobusta/stream/__init__.py"
diff --git a/scripts/install_www.py b/scripts/install_www.py
new file mode 100644
index 0000000..fcc8913
--- /dev/null
+++ b/scripts/install_www.py
@@ -0,0 +1,3 @@
+import pyrobusta.utils.assets as assets
+
+assets.install_www()
\ No newline at end of file
diff --git a/src/pyrobusta/bindings/socket_http.py b/src/pyrobusta/bindings/socket_http.py
index 799fea7..f646719 100644
--- a/src/pyrobusta/bindings/socket_http.py
+++ b/src/pyrobusta/bindings/socket_http.py
@@ -4,12 +4,11 @@
import asyncio
from asyncio import sleep_ms # pylint: disable=E1101
-from gc import mem_free, collect
+from gc import collect
-from ..stream.buffer import MemoryPool, SlidingBuffer, BufferFullError
+from ..stream.buffer import BufferFullError
from ..transport.socket import SocketBase
-from ..protocol.http import HttpEngine, ServerBusyError
-from ..utils.config import get_config
+from ..protocol.http import HttpEngine, ServerBusyError, HeaderParsingError
from ..utils import logging
@@ -19,75 +18,19 @@ class SocketHttp(SocketBase):
buffer management and state machine parser.
"""
- # Constants for memory footprint
- MEM_CAP = float(
- get_config("http_mem_cap")
- ) # Default memory cap (percentage / 100) of free heap
- SEND_BUF_MIN_BYTES = 512 # Minimum buffer size for responses
- SEND_BUF_MAX_BYTES = 4096 # Max buffer size for responses
- RECV_BUF_MIN_BYTES = 512 # Minimum buffer size for requests
- RECV_BUF_MAX_BYTES = 4096 # Max buffer size for requests
- CONN_OVERHEAD = 1024 # Overhead per connection
- MTU_SIZE = 1460 # TCP maximum transmission unit
-
- # Timing settings
+ MTU_SIZE = 1460
STATE_MACHINE_SLEEP_MS = 2
RESP_HANDLER_SLEEP_MS = 2
RECV_TIMEOUT_SECONDS = 10
- # Static buffer pools - initialized by init_pools()
- RECV_POOL = None
- SEND_POOL = None
-
- @staticmethod
- def init_pools(max_sockets):
- """
- Initialize pool of buffers for sending/receiving based on different profiles
- """
- mem_available = mem_free()
- con_limit = max(1, max_sockets)
- usable = int(SocketHttp.MEM_CAP * mem_available)
- is_low_memory = (usable / con_limit) < (
- SocketHttp.RECV_BUF_MAX_BYTES
- + SocketHttp.SEND_BUF_MAX_BYTES
- + SocketHttp.CONN_OVERHEAD
- )
- if is_low_memory:
- logging.warning(
- __name__ + ".init_pools: low-memory mode with reduced buffer size"
- )
- recv_size = (
- SocketHttp.RECV_BUF_MIN_BYTES
- if is_low_memory
- else SocketHttp.RECV_BUF_MAX_BYTES
- )
- send_size = (
- SocketHttp.SEND_BUF_MIN_BYTES
- if is_low_memory
- else SocketHttp.SEND_BUF_MAX_BYTES
- )
- per_conn = recv_size + send_size + SocketHttp.CONN_OVERHEAD
- if usable < per_conn:
- raise MemoryError(
- (
- f"Insufficient memory: {mem_available // 1024} KB "
- f"at {SocketHttp.MEM_CAP*100}% cap, "
- f"at least {per_conn // 1024} KB required"
- )
- )
- con_limit = min(usable // per_conn, con_limit)
- logging.info((__name__ + f".init_pools: {con_limit} connection(s) allowed"))
- SocketHttp.RECV_POOL = MemoryPool(recv_size, con_limit, wrapper=SlidingBuffer)
- SocketHttp.SEND_POOL = MemoryPool(send_size, con_limit, wrapper=SlidingBuffer)
-
__slots__ = ("_engine", "_prev_state", "_recv_buf", "_send_buf")
- def __init__(self, reader, writer):
+ def __init__(self, reader, writer, recv_buf, send_buf):
super().__init__(reader, writer)
self._engine = HttpEngine()
self._prev_state = None
- self._recv_buf = None
- self._send_buf = None
+ self._recv_buf = recv_buf
+ self._send_buf = send_buf
async def _flush_response(self):
data = self._send_buf.peek()
@@ -99,38 +42,41 @@ async def _flush_response(self):
async def run(self):
"""
Handle socket connection with HTTP state machine parser.
- - 1) reserve buffer
- - 2) run state machine parser
- - 3) release reserved buffers and terminate socket connection
"""
- await self._reserve_buffers()
self._prev_state = None
try:
while self._engine.state is not None:
await self._run_state_machine()
await sleep_ms(SocketHttp.STATE_MACHINE_SLEEP_MS)
- except Exception as e: # pylint: disable=W0718
- logging.warning(__name__ + f": error in run_web: {e}")
finally:
- if self._send_buf:
- self._send_buf.consume()
- SocketHttp.SEND_POOL.release(self._send_buf)
- if self._recv_buf:
- self._recv_buf.consume()
- SocketHttp.RECV_POOL.release(self._recv_buf)
await self.close()
collect()
- async def _reserve_buffers(self):
- if SocketHttp.SEND_POOL is None or SocketHttp.RECV_POOL is None:
- raise RuntimeError("Pools are ninitialized")
-
- while not self._recv_buf or not self._send_buf:
- if not self._recv_buf:
- self._recv_buf = SocketHttp.RECV_POOL.reserve()
- if not self._send_buf:
- self._send_buf = SocketHttp.SEND_POOL.reserve()
- await sleep_ms(SocketHttp.STATE_MACHINE_SLEEP_MS)
+ async def _read_to_buf(self):
+ buf_free = self._recv_buf.capacity - self._recv_buf.size()
+ if not buf_free:
+ self._engine.on_buffer_full(self._send_buf)
+ await self._flush_response()
+ return 0
+ try:
+ request = await self.read(
+ read_bytes=buf_free,
+ decoding=None,
+ timeout_seconds=SocketHttp.RECV_TIMEOUT_SECONDS,
+ )
+ except asyncio.TimeoutError:
+ self._engine.on_timeout(self._send_buf)
+ await self._flush_response()
+ return 0
+ except Exception as e: # pylint: disable=W0718
+ self._engine.on_failure(
+ self._send_buf, b"Read error: " + str(e).encode("ascii")
+ )
+ await self._flush_response()
+ return 0
+ self._recv_buf.write(request)
+ logging.debug(__name__ + f"._read_to_buf: [{request}]")
+ return len(request)
async def _run_state_machine(self):
if self._prev_state == self._engine.state or self._prev_state is None:
@@ -151,7 +97,11 @@ async def _run_state_machine(self):
await self._flush_response()
return
except ServerBusyError:
- self._engine.on_busy(self._send_buf)
+ self._engine.on_unavailable(self._send_buf)
+ await self._flush_response()
+ return
+ except HeaderParsingError:
+ self._engine.on_client_error(self._send_buf, b"Invalid headers")
await self._flush_response()
return
except Exception as e: # pylint: disable=W0718
@@ -162,32 +112,6 @@ async def _run_state_machine(self):
if self._engine.state is None and resp_handler is not None:
await self._response_handler(resp_handler)
- async def _read_to_buf(self):
- buf_free = self._recv_buf.capacity - self._recv_buf.size()
- if not buf_free:
- self._engine.on_buffer_full(self._send_buf)
- await self._flush_response()
- return 0
- try:
- request = await self.read(
- read_bytes=buf_free,
- decoding=None,
- timeout_seconds=SocketHttp.RECV_TIMEOUT_SECONDS,
- )
- except asyncio.TimeoutError:
- self._engine.on_timeout(self._send_buf)
- await self._flush_response()
- return 0
- except Exception as e: # pylint: disable=W0718
- self._engine.on_failure(
- self._send_buf, b"Read error: " + str(e).encode("ascii")
- )
- await self._flush_response()
- return 0
- self._recv_buf.write(request)
- logging.debug(__name__ + f"._read_to_buf: [{request}]")
- return len(request)
-
async def _response_handler(self, resp_handler):
if "closure" == type(resp_handler).__name__:
for is_finished in resp_handler(self._send_buf):
diff --git a/src/pyrobusta/con/wifi.py b/src/pyrobusta/con/wifi.py
index 8755f44..b0d85e0 100644
--- a/src/pyrobusta/con/wifi.py
+++ b/src/pyrobusta/con/wifi.py
@@ -2,7 +2,10 @@
Helpers for setting up Wi-Fi in station mode
"""
+from time import sleep
+
from network import WLAN, STA_IF
+
from ..utils.config import get_config
from ..utils import logging
@@ -13,19 +16,30 @@ def initialize():
"""
ssid = get_config("wifi_ssid")
password = get_config("wifi_password")
+
if not ssid or not password:
- logging.warning(__name__ + ": missing SSID/password, skip initialization")
- return
+ logging.warning(__name__ + ": missing SSID/password")
+ return False
sta_if = WLAN(STA_IF)
sta_if.active(True)
- nets = sta_if.scan()
- for net in nets:
- if net[0].decode() == get_config("wifi_ssid"):
- logging.info(__name__ + f": network {net[0]} found!")
- sta_if.connect(net[0], get_config("wifi_password"))
- logging.info(__name__ + f": connected, available at {sta_if.ifconfig()[0]}")
- break
+ if sta_if.isconnected():
+ logging.info(__name__ + f": already connected IP={sta_if.ifconfig()[0]}")
+ return True
+
+ sta_if.connect(ssid, password)
+
+ timeout = 30
+ while timeout > 0:
+ if sta_if.isconnected():
+ ip = sta_if.ifconfig()[0]
+ logging.info(__name__ + f": connected, IP={ip}")
+ return True
+ sleep(1)
+ timeout -= 1
+
+ logging.warning(__name__ + ": connection failed")
+ return False
def get_address():
@@ -33,4 +47,6 @@ def get_address():
Get the address of the WLAN interface
"""
sta_if = WLAN(STA_IF)
- return sta_if.ifconfig()[0]
+ if sta_if.isconnected():
+ return sta_if.ifconfig()[0]
+ return None
diff --git a/src/pyrobusta/protocol/http.py b/src/pyrobusta/protocol/http.py
index 0de2a54..4171c4d 100644
--- a/src/pyrobusta/protocol/http.py
+++ b/src/pyrobusta/protocol/http.py
@@ -5,10 +5,8 @@
from json import dumps
from io import BytesIO
-from os import stat
from ..utils.config import get_config
-from ..utils.helpers import normalize_path
class HeaderParsingError(ValueError):
@@ -199,17 +197,20 @@ def get_url_encoded_query_param(query: str, key: str, default: str = None):
:param key: key to parse from the query
:param default: default value to return when key is not present
"""
- idx_start = query.find(key + "=")
- if idx_start != -1:
- idx_end = -1
- idx_end = query.find("&", idx_start)
- if idx_start > -1:
- if idx_end > -1:
- return query[idx_start + len(key) + 1 : idx_end]
- return query[idx_start + len(key) + 1 :]
- if default is None:
+ if query.startswith(key + "="):
+ idx_start = 0
+ elif (idx_start := query.find("&" + key + "=")) != -1:
+ idx_start += 1
+ elif default is None:
raise KeyError()
- return default
+ else:
+ return default
+
+ idx_end = -1
+ idx_end = query.find("&", idx_start)
+ if idx_end > -1:
+ return query[idx_start + len(key) + 1 : idx_end]
+ return query[idx_start + len(key) + 1 :]
@classmethod
def is_norm_path_served(cls, path: str):
@@ -262,13 +263,25 @@ def _parse_headers(cls, raw_headers: memoryview) -> dict[str, str | int]:
headers = {}
for line in header_lines:
# pylint: disable=W0511
- # TODO: support for UTF-8 in field values (e.g filenames), can be board dependent
if any(c > 127 for c in line):
raise HeaderParsingError("Non-ASCII character")
if b":" not in line:
raise HeaderParsingError()
name, value = line.split(b":", 1)
+ if not name:
+ raise HeaderParsingError("Empty header name")
+ for c in name:
+ if (
+ 48 <= c <= 57 # 0-9
+ or 65 <= c <= 90 # A-Z
+ or 97 <= c <= 122 # a-z
+ or c in (45, 95) # -_
+ ):
+ continue
+ raise HeaderParsingError("Invalid header name")
name = name.strip().lower().decode(cls.ASCII)
+ if any((c < 32 and c != 9) or c == 127 for c in value):
+ raise HeaderParsingError("Invalid header value")
if name == cls.CONTENT_LENGTH:
value = int(value.strip())
else:
@@ -277,18 +290,35 @@ def _parse_headers(cls, raw_headers: memoryview) -> dict[str, str | int]:
return headers
@staticmethod
- def _is_multipart(headers: dict) -> str:
+ def _get_mp_boundary(headers: dict) -> str:
"""Determine from the headers if a request is multipart, and return the boundary value"""
content_type = headers.get("content-type")
- if content_type and content_type.lower().startswith("multipart/form-data"):
- parts = content_type.split(";")
- for part in parts[1:]:
- if "=" in part:
- key, value = part.strip().split("=", 1)
- if key.strip().lower() == "boundary":
- boundary = value.strip().strip('"')
- return boundary if boundary else None
- return None
+ if not content_type or not content_type.lower().startswith(
+ "multipart/form-data"
+ ):
+ return None
+
+ parts = content_type.split(";")
+ for part in parts[1:]:
+ if "=" not in part:
+ continue
+ key, value = part.strip().split("=", 1)
+
+ if key.strip().lower() != "boundary":
+ continue
+ value = value.strip()
+
+ if value.startswith('"'):
+ if len(value) < 2 or not value.endswith('"'):
+ raise HeaderParsingError()
+ value = value[1:-1]
+ elif value.endswith('"'):
+ raise HeaderParsingError()
+
+ if not value:
+ raise HeaderParsingError()
+ return value
+ raise HeaderParsingError()
@classmethod
def _parse_body_part(cls, part: memoryview) -> tuple[dict, bytes]:
@@ -417,7 +447,7 @@ def on_failure(self, tx, info: bytes):
self._write_response_head(tx, len(info))
tx.write(info)
- def on_busy(self, tx):
+ def on_unavailable(self, tx):
"""Terminate state machine and write 503 response"""
self.terminate(503)
self._write_response_head(tx)
@@ -508,10 +538,13 @@ def _route_request_st(self, _, tx):
if self.method == self.HEAD:
self.on_client_error(tx, self.BAD_REQUEST_ERROR)
return
- if mp_boundary := self._is_multipart(self.headers):
+ if mp_boundary := self._get_mp_boundary(self.headers):
self.mp_boundary = mp_boundary.encode(self.ASCII)
self.state = self._start_multipart_parser_st
elif self._is_chunked():
+ if self.CONTENT_LENGTH in self.headers:
+ self.on_client_error(tx, self.BAD_REQUEST_ERROR)
+ return
self.state = self._recv_chunked_size_st
else:
self.state = self._recv_payload_st
@@ -528,8 +561,7 @@ def _route_request_st(self, _, tx):
self.on_method_not_allowed(tx)
return
if self.method in (self.GET, self.HEAD):
- resource = b"index.html" if not self.url else self.url
- self.state = lambda _rx, _tx: self._send_file_st(_rx, _tx, resource)
+ self.state = lambda _rx, _tx: self._send_file_st(_rx, _tx, self.url)
return
self.on_missing_resource(tx)
@@ -584,61 +616,28 @@ def _app_endpoint_st(self, rx, tx):
dtype = dtype.encode(self.ASCII)
self._set_response_header(b"content-type", dtype)
- if dtype in (b"multipart/x-mixed-replace", b"multipart/form-data"):
- part_content_type = data[0]
- callback = data[1]
- if type(callback).__name__ not in ("function", "closure"):
- self.on_failure(tx, b"Invalid response handler")
- return
- self.terminate(200, dtype)
- boundary = self.MULTIPART_BOUNDARY
- self._set_response_header(
- b"content-type", dtype + b"; boundary=" + boundary
+ if dtype.startswith(b"multipart/"):
+ self.state = lambda _rx, _tx: self._generate_multipart_response(
+ _rx, _tx, data, dtype
)
- self._write_response_head(tx, None)
- if self.method != self.HEAD:
- return self._multipart_wrapper_factory(
- callback, part_content_type.encode(self.ASCII), boundary
- )
return
+
self.terminate(200, dtype)
return self._generate_response(tx, data)
- def _send_file_st(self, _, tx, web_resource: bytes):
- """State for returning a static resource"""
- extension = web_resource.rsplit(b".", 1)[-1]
- norm_path = normalize_path(web_resource.decode(self.ASCII))
- is_path_served = self.is_norm_path_served(norm_path)
- if not is_path_served:
- try:
- stat(norm_path)
- self.on_forbidden(tx)
- return
- except OSError:
- self.on_missing_resource(tx)
- return
- try:
- content_type = self._lookup(self.CONTENT_TYPES, extension)
- except ValueError:
- content_type = self._lookup(self.CONTENT_TYPES, b"raw")
- try:
- self._set_response_header(
- b"content-length", str(stat(norm_path)[6]).encode(HttpEngine.ASCII)
- )
- self.terminate(200, content_type)
- self._write_response_head(tx, None)
- if self.method != self.HEAD:
- return open(norm_path, "rb")
- return
- except OSError:
- self.on_missing_resource(tx)
+ def _send_file_st(self, _, tx, web_resource: bytes): # pylint: disable=W0613
+ """State for returning a static resource - disabled"""
+ self.on_unavailable(tx)
def _start_multipart_parser_st(self, rx, tx): # pylint: disable=W0613
- self.on_failure(tx, b"Multipart handling is disabled")
+ """Initial state for processing multipart requests"""
+ self.on_unavailable(tx)
- @staticmethod
- def _multipart_wrapper_factory(callback, content_type: bytes, boundary: bytes):
- pass
+ def _generate_multipart_response(
+ self, rx, tx, callback, dtype
+ ): # pylint: disable=W0613
+ """Generate multipart response depening on the exact content type"""
+ self.on_unavailable(tx)
def enable_optional_features():
@@ -649,3 +648,8 @@ def enable_optional_features():
from pyrobusta.protocol import http_multipart
http_multipart.apply_patches()
+
+ if get_config("http_serve_files").lower() == "true":
+ from pyrobusta.protocol import http_file_server
+
+ http_file_server.apply_patches()
diff --git a/src/pyrobusta/protocol/http_file_server.py b/src/pyrobusta/protocol/http_file_server.py
new file mode 100644
index 0000000..1e1acc4
--- /dev/null
+++ b/src/pyrobusta/protocol/http_file_server.py
@@ -0,0 +1,57 @@
+"""
+State machine extension for file serving.
+"""
+
+# pylint: disable=W0212,R0401
+
+from os import stat
+
+from pyrobusta.protocol import http
+from pyrobusta.utils.helpers import normalize_path, add_method
+
+
+def _send_file_st(self, _, tx, web_resource: bytes):
+ """State for returning a static resource"""
+ if self.url == b"/files":
+ web_resource = "/"
+ elif self.url.startswith(b"/files/"):
+ web_resource = web_resource[7:]
+ elif self.url == b"/":
+ web_resource = b"/www/index.html"
+ else:
+ web_resource = b"/www" + web_resource
+
+ extension = web_resource.rsplit(b".", 1)[-1]
+ norm_path = normalize_path(web_resource.decode(self.ASCII))
+ is_path_served = self.is_norm_path_served(norm_path)
+ if not is_path_served:
+ try:
+ stat(norm_path)
+ self.on_forbidden(tx)
+ return
+ except OSError:
+ self.on_missing_resource(tx)
+ return
+ try:
+ content_type = self._lookup(self.CONTENT_TYPES, extension)
+ except ValueError:
+ content_type = self._lookup(self.CONTENT_TYPES, b"raw")
+ try:
+ self._set_response_header(
+ b"content-length", str(stat(norm_path)[6]).encode(http.HttpEngine.ASCII)
+ )
+ self.terminate(200, content_type)
+ self._write_response_head(tx, None)
+ if self.method != self.HEAD:
+ return open(norm_path, "rb")
+ return
+ except OSError:
+ self.on_missing_resource(tx)
+
+
+def apply_patches():
+ """
+ Apply patches to class attributes for file serving.
+ """
+
+ add_method(http.HttpEngine, _send_file_st)
diff --git a/src/pyrobusta/protocol/http_multipart.py b/src/pyrobusta/protocol/http_multipart.py
index dc95393..e58074a 100644
--- a/src/pyrobusta/protocol/http_multipart.py
+++ b/src/pyrobusta/protocol/http_multipart.py
@@ -5,24 +5,23 @@
# pylint: disable=W0212,R0401
from pyrobusta.protocol import http
+from pyrobusta.utils.helpers import add_method
-def add_method(cls, func, method_type="instance"):
- """
- Helper to extend web.WebEngine with
- additional methods and states.
- """
- if method_type == "instance":
- setattr(cls, func.__name__, func)
- elif method_type == "static":
- setattr(cls, func.__name__, staticmethod(func))
- elif method_type == "class":
- setattr(cls, func.__name__, classmethod(func))
- else:
- raise ValueError("Invalid type")
+def _generate_multipart_response(self, _, tx, callback, dtype):
+ """Generate multipart response depening on the exact content type"""
+ if type(callback).__name__ not in ("function", "closure"):
+ self.on_failure(tx, b"Invalid response handler")
+ return
+ self.terminate(200, dtype)
+ boundary = self.MULTIPART_BOUNDARY
+ self._set_response_header(b"content-type", dtype + b"; boundary=" + boundary)
+ self._write_response_head(tx, None)
+ if self.method != self.HEAD:
+ return self._multipart_wrapper_factory(callback, boundary)
-def _multipart_wrapper_factory(callback, content_type: bytes, boundary: bytes):
+def _multipart_wrapper_factory(callback, boundary: bytes):
"""
Factory method for creating closures that write multipart responses
:param callback: function without arguments, must return bytes-like objects
@@ -30,8 +29,7 @@ def _multipart_wrapper_factory(callback, content_type: bytes, boundary: bytes):
:param boundary: boundary value
:return closure: closure to invoke for response generation
"""
- boundary = b"--" + boundary
- content_type_header = b"content-type: %s\r\n\r\n" % content_type
+ delimiter = b"--" + boundary
def _multipart_wrapper(tx):
"""
@@ -41,13 +39,16 @@ def _multipart_wrapper(tx):
:return bool: true if the stream is completed
"""
while True:
- tx.write(boundary)
- part_body = callback()
- if not part_body:
+ tx.write(delimiter)
+ part = callback()
+ if not part:
tx.write(b"--")
yield True
+ content_type, part_body = part
tx.write(b"\r\n")
- tx.write(content_type_header)
+ tx.write(b"content-type:")
+ tx.write(content_type.encode("ascii"))
+ tx.write(b"\r\n\r\n")
written = 0
while written < len(part_body):
to_write = tx.capacity - tx.size()
@@ -139,9 +140,7 @@ def apply_patches():
"""
Apply patches to class attributes for multipart parsing.
"""
- cls = http.HttpEngine
-
- orig_init = cls.__init__
+ orig_init = http.HttpEngine.__init__
def new_init(self, *args, **kwargs):
orig_init(self, *args, **kwargs)
@@ -150,8 +149,9 @@ def new_init(self, *args, **kwargs):
self.mp_delimiter = None
self.mp_last_delimiter = None
- cls.__init__ = new_init
+ http.HttpEngine.__init__ = new_init
+ add_method(http.HttpEngine, _generate_multipart_response)
add_method(http.HttpEngine, _multipart_wrapper_factory, "static")
add_method(http.HttpEngine, _start_multipart_parser_st)
add_method(http.HttpEngine, _parse_boundary_st)
diff --git a/src/pyrobusta/server/http_server.py b/src/pyrobusta/server/http_server.py
index e0cecff..6761469 100644
--- a/src/pyrobusta/server/http_server.py
+++ b/src/pyrobusta/server/http_server.py
@@ -2,120 +2,212 @@
Socket server application
"""
-import gc
+from gc import collect, mem_free
from asyncio import sleep_ms, start_server, run # pylint: disable=E1101
from time import ticks_ms, ticks_diff
from ..protocol import http
from ..bindings.socket_http import SocketHttp
+from ..stream.buffer import MemoryPool, SlidingBuffer
from ..utils.config import get_config
+from ..utils.helpers import normalize_path
from ..utils import logging
class HttpServer:
"""
Socket server class, handling global config (timeout, port, max connections etc.),
- and managing active sockets.
+ and managing active clients.
"""
- __slots__ = ["_host", "_max_sockets", "_port", "_timeout", "_server"]
+ __slots__ = ["_host", "_port", "_server", "_max_clients"]
- ACTIVE_SOCKETS = []
+ ACTIVE_CLIENTS = []
+
+ # ---------------
+ # Server settings
+ # ---------------
CON_ACCEPT_TIMEOUT_MS = 5000 # Timeout value for accepting new connection
CON_ACCEPT_SLEEP_MS = (
100 # Duration of sleep between attempts to accept new connection
)
- MAX_SOCKETS = int(get_config("socket_max_con"))
- SOCKET_TIMEOUT_SEC = 30
- LISTEN_PORT_HTTP = 8080
- LISTEN_PORT_HTTPS = 4443
- TLS_CERT_PATH = "cert.der"
- TLS_KEY_PATH = "key.der"
+ LISTEN_PORT_HTTP = 80
+ LISTEN_PORT_HTTPS = 443
+ TLS_CERT_PATH = "/cert.der"
+ TLS_KEY_PATH = "/key.der"
+ CON_TIMEOUT_S = 30
+
+ # -----------------------------------------
+ # Constants for controlled memory footprint
+ # -----------------------------------------
+
+ MEM_CAP = float(get_config("http_mem_cap")) # Default memory cap (percentage / 100)
+ SEND_BUF_MIN_BYTES = 512 # Minimum buffer size for responses
+ SEND_BUF_MAX_BYTES = 4096 # Max buffer size for responses
+ RECV_BUF_MIN_BYTES = 512 # Minimum buffer size for requests
+ RECV_BUF_MAX_BYTES = 4096 # Max buffer size for requests
+ CON_OVERHEAD_BYTES = 1024 # Overhead per connection
+
+ # ------------------------------------------
+ # Buffer pools - initialized by init_pools()
+ # ------------------------------------------
+
+ RECV_POOL = None
+ SEND_POOL = None
+
+ @classmethod
+ def _init_pools(cls, max_clients):
+ """
+ Initialize pool of buffers for sending/receiving based on different profiles
+ """
+ mem_available = mem_free()
+ con_limit = max_clients
+ usable = int(cls.MEM_CAP * mem_available)
+ is_low_memory = (usable / con_limit) < (
+ cls.RECV_BUF_MAX_BYTES + cls.SEND_BUF_MAX_BYTES + cls.CON_OVERHEAD_BYTES
+ )
+ if is_low_memory:
+ logging.warning(
+ __name__ + ".init_pools: low-memory mode with reduced buffer size"
+ )
+ recv_size = cls.RECV_BUF_MIN_BYTES if is_low_memory else cls.RECV_BUF_MAX_BYTES
+ send_size = cls.SEND_BUF_MIN_BYTES if is_low_memory else cls.SEND_BUF_MAX_BYTES
+ per_con = recv_size + send_size + cls.CON_OVERHEAD_BYTES
+ if usable < per_con:
+ raise MemoryError(
+ (
+ f"Insufficient memory: {mem_available // 1024} KB "
+ f"at {cls.MEM_CAP*100}% cap, "
+ f"at least {per_con // 1024} KB required"
+ )
+ )
+ con_limit = min(usable // per_con, con_limit)
+ logging.info((__name__ + f".init_pools: {con_limit} connection(s) allowed"))
+ cls.RECV_POOL = MemoryPool(recv_size, con_limit, wrapper=SlidingBuffer)
+ cls.SEND_POOL = MemoryPool(send_size, con_limit, wrapper=SlidingBuffer)
@classmethod
- async def drop_client(cls, socket):
- """Remove socket from active list"""
- if socket not in cls.ACTIVE_SOCKETS:
+ async def _drop_client(cls, client):
+ """Remove client from active list"""
+ if client not in cls.ACTIVE_CLIENTS:
return
- logging.debug(__name__ + f": {socket.id} dropped")
- await socket.close()
- cls.ACTIVE_SOCKETS.remove(socket)
- del socket
- gc.collect()
+ logging.debug(__name__ + f": {client.id} dropped")
+ await client.close()
+ cls.ACTIVE_CLIENTS.remove(client)
+ del client
+ collect()
+
+ # ----------------
+ # Instance methods
+ # ----------------
def __init__(self):
self._host = "0.0.0.0"
- self._max_sockets = max(1, HttpServer.MAX_SOCKETS)
self._port = (
HttpServer.LISTEN_PORT_HTTPS
if get_config("tls").lower() == "true"
else HttpServer.LISTEN_PORT_HTTP
)
- self._timeout = HttpServer.SOCKET_TIMEOUT_SEC
self._server = None
+ self._max_clients = 0
- async def can_handle_new_socket(self):
+ async def can_handle_new_client(self):
"""
Decide if the new socket can be handled.
Evict closed/inactive sockets if needed.
:return is_acceptable: true/false
"""
- gc.collect()
+ collect()
con_timestamp = ticks_ms()
while ticks_diff(ticks_ms(), con_timestamp) < self.CON_ACCEPT_TIMEOUT_MS:
- if len(self.ACTIVE_SOCKETS) < self._max_sockets:
+ if len(self.ACTIVE_CLIENTS) < self._max_clients:
return True
# Attempt to evict inactive clients
- for socket in self.ACTIVE_SOCKETS:
- socket_inactive = int(ticks_diff(ticks_ms(), socket.last_event) * 0.001)
- if not socket.connected or socket_inactive > self._timeout:
+ for client in self.ACTIVE_CLIENTS:
+ client_inactive = int(ticks_diff(ticks_ms(), client.last_event) * 0.001)
+ if not client.connected or client_inactive > self.CON_TIMEOUT_S:
logging.debug(
(
- __name__ + f": evicted {socket.id} "
- f"timeout: {self._timeout - socket_inactive}s"
+ __name__ + f": evicted {client.id} "
+ f"timeout: {self.CON_TIMEOUT_S - client_inactive}s"
)
)
- await self.drop_client(socket)
- return True
+ await self._drop_client(client)
await sleep_ms(self.CON_ACCEPT_SLEEP_MS)
return False
- async def accept_http(self, reader, writer):
+ async def _reserve_buffers(self):
+ if self.SEND_POOL is None or self.RECV_POOL is None:
+ raise RuntimeError("Pools are uninitialized")
+
+ recv_buf = None
+ send_buf = None
+
+ while not recv_buf or not send_buf:
+ if not recv_buf:
+ recv_buf = self.RECV_POOL.reserve()
+ if not send_buf:
+ send_buf = self.SEND_POOL.reserve()
+ await sleep_ms(self.CON_ACCEPT_SLEEP_MS)
+
+ return recv_buf, send_buf
+
+ async def _accept_socket(self, reader, writer):
"""
Handle incoming socket connection for HTTP.
- creates SocketHttp object
"""
- if not await self.can_handle_new_socket():
+ if not await self.can_handle_new_client():
logging.debug(__name__ + ": cannot accept new client")
writer.close()
await writer.wait_closed()
return
- new_client = SocketHttp(reader, writer)
- logging.debug(__name__ + f": accept {new_client.id}")
- self.ACTIVE_SOCKETS.append(new_client)
- await new_client.run()
-
- async def run_server(self):
+ try:
+ recv_buf, send_buf = await self._reserve_buffers()
+ new_client = SocketHttp(reader, writer, recv_buf, send_buf)
+ logging.debug(__name__ + f": accept {new_client.id}")
+ self.ACTIVE_CLIENTS.append(new_client)
+ await new_client.run()
+ except Exception as e: # pylint: disable=W0718
+ logging.warning(__name__ + f": error in run(): {e}")
+ finally:
+ if send_buf:
+ send_buf.consume()
+ self.SEND_POOL.release(send_buf)
+ if recv_buf:
+ recv_buf.consume()
+ self.RECV_POOL.release(recv_buf)
+ collect()
+
+ async def start_socket_server(self):
"""
Start asyncio socket server on the specified port.
"""
try:
- gc.collect()
+ collect()
http.enable_optional_features()
- logging.debug(f"Registered endpoints: {http.HttpEngine.ENDPOINTS}")
- SocketHttp.init_pools(self._max_sockets)
+ logging.debug(
+ __name__ + f"registered endpoints: {http.HttpEngine.ENDPOINTS}"
+ )
+ self._max_clients = int(get_config("socket_max_con"))
+ self._init_pools(self._max_clients)
ssl_ctx = None
+
if get_config("tls").lower() == "true":
import ssl
ssl_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
- ssl_ctx.load_cert_chain(self.TLS_CERT_PATH, self.TLS_KEY_PATH)
+ ssl_ctx.load_cert_chain(
+ normalize_path(self.TLS_CERT_PATH),
+ normalize_path(self.TLS_KEY_PATH),
+ )
+
self._server = await start_server(
- self.accept_http,
+ self._accept_socket,
self._host,
self._port,
- backlog=self._max_sockets,
+ backlog=max(1, self._max_clients),
ssl=ssl_ctx,
)
logging.info(__name__ + ": started")
@@ -124,20 +216,20 @@ async def run_server(self):
async def terminate(self):
"""
- Terminate HTTP server and close sockets
+ Terminate HTTP server and drop clients
"""
logging.info(__name__ + ": terminated")
- while self.ACTIVE_SOCKETS:
- await self.drop_client(self.ACTIVE_SOCKETS[0])
+ while self.ACTIVE_CLIENTS:
+ await self._drop_client(self.ACTIVE_CLIENTS[0])
if self._server:
self._server.close()
await self._server.wait_closed()
self._server = None
- gc.collect()
+ collect()
def main():
"""
- Start socket server async task.
+ Start HTTP server async task.
"""
- run(HttpServer().run_server())
+ run(HttpServer().start_socket_server())
diff --git a/src/pyrobusta/utils/assets.py b/src/pyrobusta/utils/assets.py
new file mode 100644
index 0000000..c5f5a12
--- /dev/null
+++ b/src/pyrobusta/utils/assets.py
@@ -0,0 +1,79 @@
+"""
+Helper functions to install assets.
+"""
+
+from os import mkdir, listdir, stat
+
+from .helpers import normalize_path
+
+FS_ITER_ABS = 0
+FS_ITER_REL = 1
+
+FS_ITER_FILE = 0
+FS_ITER_DIR = 1
+
+
+def copy_file(src_path, dst_path):
+ """
+ Copy a file from a to a destination path.
+ """
+ with open(src_path, "rb") as src:
+ with open(dst_path, "wb") as dst:
+ while True:
+ chunk = src.read(512)
+ if not chunk:
+ break
+ dst.write(chunk)
+
+
+def iterate_fs(root, iter_mode=FS_ITER_FILE, path_mode=FS_ITER_ABS):
+ """
+ Iterate over all files or directories and yield
+ resulting paths either as absolute or relative paths.
+ :param dir: directory in which to iterate
+ :iter_mode int: iterate over files (FS_ITER_FILE=0) or directories (FS_ITER_DIR=1)
+ :path_mode int: yield absolute paths (FS_ITER_ABS=0) ro relative paths (FS_ITER_REL=1)
+ """
+ dirs = [root]
+ while dirs:
+ current_directory = dirs.pop(0)
+ for name in listdir(current_directory):
+ if current_directory == "/":
+ current_path = "/" + name
+ else:
+ current_path = current_directory + "/" + name
+ st = stat(current_path)
+ fs_mode = st[0]
+ if fs_mode & 0x4000: # directory bit set
+ dirs.append(current_path)
+ if iter_mode == FS_ITER_DIR:
+ if path_mode == FS_ITER_REL:
+ yield current_path[len(root) + 1 :]
+ else:
+ yield current_path
+ else:
+ continue
+ if iter_mode == FS_ITER_FILE:
+ if path_mode == FS_ITER_REL:
+ yield current_path[len(root) + 1 :]
+ else:
+ yield current_path
+
+
+def install_www():
+ """
+ Install default web server assets under /www.
+ """
+ source_dir = normalize_path("/lib/pyrobusta/assets/www")
+ target_dir = normalize_path("/www")
+ if "www" not in listdir():
+ mkdir(target_dir)
+
+ for asset_dir in iterate_fs(source_dir, FS_ITER_DIR, FS_ITER_ABS):
+ mkdir(asset_dir)
+
+ for asset in iterate_fs(source_dir, FS_ITER_FILE, FS_ITER_REL):
+ copy_file(
+ source_dir + "/" + asset,
+ target_dir + "/" + asset,
+ )
diff --git a/src/pyrobusta/utils/config.py b/src/pyrobusta/utils/config.py
index 35dc2ea..86dbf44 100644
--- a/src/pyrobusta/utils/config.py
+++ b/src/pyrobusta/utils/config.py
@@ -6,7 +6,7 @@
from .helpers import normalize_path
-PYROBUSTA_VERSION = "0.3.0"
+PYROBUSTA_VERSION = "v0.4.0"
CONFIG_LOADED = False
CONFIG_LOCATION = "pyrobusta.env"
CONFIG_CACHE = [
@@ -19,13 +19,15 @@
"http_mem_cap",
0.1,
"http_served_paths",
- "/lib/pyrobusta",
+ "/www /lib/pyrobusta",
+ "http_serve_files",
+ "True",
"socket_max_con",
2,
"tls",
"False",
"log_level",
- "warning",
+ "info",
]
@@ -46,11 +48,12 @@ def read_config(config=CONFIG_LOCATION):
try:
with open(config, encoding="utf-8") as conf:
for line in conf:
- line = line.rstrip("\r\n")
- key = line.split("=")[0].strip()
- if key.startswith("#") or not line.strip():
+ line = line.rstrip("\r\n").split("#")[0]
+ if not line.strip():
continue
- value = line.split("=")[1].strip().strip("'").strip('"')
+ parts = line.split("=")
+ key = parts[0].strip()
+ value = parts[1].strip().strip("'").strip('"')
if key and value:
value = normalize(key, value)
if (
diff --git a/src/pyrobusta/utils/helpers.py b/src/pyrobusta/utils/helpers.py
index cc0d9d3..08f647a 100644
--- a/src/pyrobusta/utils/helpers.py
+++ b/src/pyrobusta/utils/helpers.py
@@ -25,3 +25,18 @@ def normalize_path(path: str):
return cwd + normalized
return cwd + "/" + normalized
return cwd
+
+
+def add_method(cls, func, method_type="instance"):
+ """
+ Helper to patch/extend classes with
+ additional methods and states.
+ """
+ if method_type == "instance":
+ setattr(cls, func.__name__, func)
+ elif method_type == "static":
+ setattr(cls, func.__name__, staticmethod(func))
+ elif method_type == "class":
+ setattr(cls, func.__name__, classmethod(func))
+ else:
+ raise ValueError("Invalid type")
diff --git a/tests/.pylintrc b/tests/.pylintrc
index 1824775..5cf251e 100644
--- a/tests/.pylintrc
+++ b/tests/.pylintrc
@@ -3,4 +3,5 @@ disable=W0212,
C0114,
C0115,
C0116,
- R0904
+ R0904,
+ R0902
diff --git a/tests/functional/test_http.py b/tests/functional/test_http.py
index 231f0c4..8800db0 100644
--- a/tests/functional/test_http.py
+++ b/tests/functional/test_http.py
@@ -1,11 +1,11 @@
import asyncio
import ssl
import json
+import gc
-from os import getcwd, mkdir
+from os import mkdir, remove, rmdir
from pyrobusta.server import http_server
-from pyrobusta.protocol import http_multipart
from pyrobusta.protocol.http import (
HttpEngine,
enable_optional_features,
@@ -19,6 +19,15 @@
#################################################
+def garbage_collect(coroutine):
+ async def decorated(*args, **kwargs):
+ gc.collect()
+ await coroutine(*args, **kwargs)
+ gc.collect()
+
+ return decorated
+
+
def test_assert(name, actual, expected):
print(f"Test {name}: ", end="")
if actual == expected:
@@ -55,19 +64,6 @@ async def send_request(request, tls=False):
return response
-def multipart_response(num_responses):
- i = 0
-
- def response_generator():
- nonlocal i
- i += 1
- if i > num_responses:
- return None
- return b"Response %s" % i
-
- return response_generator
-
-
#################################################
# Test driver
#################################################
@@ -82,12 +78,6 @@ def simple_callback(http_ctx, _):
raise ValueError("Unhandled content-type")
-@HttpEngine.route("/test/multipart", "GET")
-def multipart_callback(http_ctx, _):
- part_count = int(http_ctx.headers["x-part-count"])
- return "multipart/form-data", ("text/plain", multipart_response(part_count))
-
-
@HttpEngine.route("/test/busy", "POST")
def busy_callback(*_):
raise ServerBusyError()
@@ -108,13 +98,14 @@ async def start_server():
Start an HTTP server as a background task
"""
server = http_server.HttpServer()
- server_task = asyncio.create_task(server.run_server())
+ server_task = asyncio.create_task(server.start_socket_server())
await asyncio.sleep_ms(100)
return server, server_task
+@garbage_collect
async def test_simple_response(tls_enabled):
- setup_config(multipart=False, tls_enabled=tls_enabled)
+ setup_config(tls_enabled=tls_enabled)
server, server_task = await start_server()
# Test: text/plain
@@ -157,38 +148,7 @@ async def test_simple_response(tls_enabled):
await server.terminate()
-async def test_multipart_response(tls_enabled):
- setup_config(multipart=True, tls_enabled=tls_enabled)
- server, server_task = await start_server()
-
- # Test: 1 part
- plain_response = await send_request(
- b"GET /test/multipart HTTP/1.1\r\n"
- b"Host: localhost\r\nX-Part-Count: 1\r\n\r\n",
- tls_enabled,
- )
- test_assert(
- f"http{"s" if tls_enabled else ""} response contains 1 part",
- b"Response 1" in plain_response,
- True,
- )
-
- # Test: 10 parts
- plain_response = await send_request(
- b"GET /test/multipart HTTP/1.1\r\n"
- b"Host: localhost\r\nX-Part-Count: 10\r\n\r\n",
- tls_enabled,
- )
- test_assert(
- f"http{"s" if tls_enabled else ""} response contains 10 parts",
- [b"Response %s" % i in plain_response for i in range(1, 11)],
- [True] * 10,
- )
-
- server_task.cancel()
- await server.terminate()
-
-
+@garbage_collect
async def test_server_busy():
setup_config()
server, server_task = await start_server()
@@ -206,6 +166,7 @@ async def test_server_busy():
await server.terminate()
+@garbage_collect
async def test_chunked_transfer_encoding():
setup_config()
create_chunked_app_endpoint("/test/chunked")
@@ -233,25 +194,33 @@ async def test_chunked_transfer_encoding():
await server.terminate()
+@garbage_collect
async def test_fs_access_control():
- setup_config(served_paths="/www")
+ setup_config(served_paths="/www/allowed")
server, server_task = await start_server()
+ workdir_root = normalize_path("/www")
+ try:
+ mkdir(workdir_root)
+ except:
+ pass
# Index page under /www -> accepted
- workdir = normalize_path("/www")
- index_html = normalize_path("/www/index.html")
- mkdir(workdir)
- with open(index_html, "w") as f:
+ allowed_workdir = normalize_path("/www/allowed")
+ allowed_index_html = normalize_path("/www/allowed/index.html")
+ mkdir(allowed_workdir)
+ with open(allowed_index_html, "w") as f:
f.write("PyRobusta Home")
# Index page under / -> rejected
- index_html = normalize_path("/index.html")
- with open(index_html, "w") as f:
+ rejected_workdir = normalize_path("/www/rejected")
+ rejected_index_html = normalize_path("/www/rejected/index.html")
+ mkdir(rejected_workdir)
+ with open(rejected_index_html, "w") as f:
f.write("PyRobusta Home")
# Case #1: /www/index.html
response = await send_request(
- (b"GET /www/index.html HTTP/1.1\r\n" b"Host: localhost\r\n\r\n")
+ (b"GET /allowed/index.html HTTP/1.1\r\n" b"Host: localhost\r\n\r\n")
)
response_body = response.split(b"\r\n\r\n")[1]
@@ -263,7 +232,7 @@ async def test_fs_access_control():
# Case #2: /index.html
response = await send_request(
- (b"GET /index.html HTTP/1.1\r\n" b"Host: localhost\r\n\r\n")
+ (b"GET /rejected/index.html HTTP/1.1\r\n" b"Host: localhost\r\n\r\n")
)
test_assert(
@@ -272,6 +241,10 @@ async def test_fs_access_control():
True,
)
+ remove(allowed_index_html)
+ remove(rejected_index_html)
+ rmdir(allowed_workdir)
+ rmdir(rejected_workdir)
server_task.cancel()
await server.terminate()
@@ -281,9 +254,12 @@ async def test_fs_access_control():
#################################################
-def setup_config(multipart=False, tls_enabled=False, served_paths=""):
- config_idx = config.CONFIG_CACHE.index("http_multipart")
- config.CONFIG_CACHE[config_idx + 1] = str(multipart)
+def setup_config(tls_enabled=False, served_paths=""):
+ http_server.HttpServer.LISTEN_PORT_HTTP = 8080
+ http_server.HttpServer.LISTEN_PORT_HTTPS = 4443
+
+ config_idx = config.CONFIG_CACHE.index("log_level")
+ config.CONFIG_CACHE[config_idx + 1] = str("warning")
config_idx = config.CONFIG_CACHE.index("tls")
config.CONFIG_CACHE[config_idx + 1] = str(tls_enabled)
config_idx = config.CONFIG_CACHE.index("http_served_paths")
@@ -300,12 +276,6 @@ def test_registration():
HttpEngine._get_callback(b"/test/simple", b"GET"),
)
- test_assert(
- "multipart endpoint registration",
- multipart_callback,
- HttpEngine._get_callback(b"/test/multipart", b"GET"),
- )
-
test_assert(
"busy endpoint registration",
busy_callback,
@@ -313,24 +283,11 @@ def test_registration():
)
-def test_multipart_patches():
- setup_config(multipart=True)
- test_assert(
- "multipart state machine patches",
- http_multipart._start_multipart_parser_st,
- HttpEngine._start_multipart_parser_st,
- )
-
-
def test_main():
test_registration()
asyncio.run(test_simple_response(tls_enabled=False))
asyncio.run(test_simple_response(tls_enabled=True))
- test_multipart_patches()
- asyncio.run(test_multipart_response(tls_enabled=False))
- asyncio.run(test_multipart_response(tls_enabled=True))
-
asyncio.run(test_server_busy())
asyncio.run(test_chunked_transfer_encoding())
asyncio.run(test_fs_access_control())
diff --git a/tests/functional/test_http_multipart.py b/tests/functional/test_http_multipart.py
new file mode 100644
index 0000000..5463a59
--- /dev/null
+++ b/tests/functional/test_http_multipart.py
@@ -0,0 +1,172 @@
+import asyncio
+import ssl
+import gc
+
+from pyrobusta.server import http_server
+from pyrobusta.protocol import http_multipart
+from pyrobusta.protocol.http import (
+ HttpEngine,
+ enable_optional_features,
+)
+from pyrobusta.utils import config
+
+#################################################
+# Test helpers
+#################################################
+
+
+def garbage_collect(coroutine):
+ async def decorated(*args, **kwargs):
+ gc.collect()
+ await coroutine(*args, **kwargs)
+ gc.collect()
+
+ return decorated
+
+
+def test_assert(name, actual, expected):
+ print(f"Test {name}: ", end="")
+ if actual == expected:
+ print("OK")
+ else:
+ print("Fail")
+ raise AssertionError(f"{actual} != {expected}")
+
+
+async def send_request(request, tls=False):
+ port = (
+ http_server.HttpServer.LISTEN_PORT_HTTPS
+ if tls
+ else http_server.HttpServer.LISTEN_PORT_HTTP
+ )
+
+ ctx = None
+ if tls:
+ # Disable certificate verification due to self-signed cert
+ ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
+ ctx.verify_mode = ssl.CERT_NONE
+
+ reader, writer = await asyncio.open_connection("127.0.0.1", port, ssl=ctx)
+ writer.write(request)
+ await writer.drain()
+
+ to_read = True
+ response = b""
+ while to_read:
+ response_part = await reader.read(1024)
+ response += response_part
+ to_read = len(response_part)
+ writer.close()
+ return response
+
+
+def multipart_response(num_responses):
+ i = 0
+
+ def response_generator():
+ nonlocal i
+ i += 1
+ if i > num_responses:
+ return None
+ return "text/plain", b"Response %s" % i
+
+ return response_generator
+
+
+#################################################
+# Test driver
+#################################################
+
+
+@HttpEngine.route("/test/multipart", "GET")
+def multipart_callback(http_ctx, _):
+ part_count = int(http_ctx.headers["x-part-count"])
+ return "multipart/form-data", multipart_response(part_count)
+
+
+async def start_server():
+ """
+ Start an HTTP server as a background task
+ """
+ server = http_server.HttpServer()
+ server_task = asyncio.create_task(server.start_socket_server())
+ await asyncio.sleep_ms(100)
+ return server, server_task
+
+
+@garbage_collect
+async def test_multipart_response(tls_enabled):
+ setup_config(tls_enabled=tls_enabled)
+ server, server_task = await start_server()
+
+ # Test: 1 part
+ plain_response = await send_request(
+ b"GET /test/multipart HTTP/1.1\r\n"
+ b"Host: localhost\r\nX-Part-Count: 1\r\n\r\n",
+ tls_enabled,
+ )
+ test_assert(
+ f"http{"s" if tls_enabled else ""} response contains 1 part",
+ b"Response 1" in plain_response,
+ True,
+ )
+
+ # Test: 10 parts
+ plain_response = await send_request(
+ b"GET /test/multipart HTTP/1.1\r\n"
+ b"Host: localhost\r\nX-Part-Count: 10\r\n\r\n",
+ tls_enabled,
+ )
+ test_assert(
+ f"http{"s" if tls_enabled else ""} response contains 10 parts",
+ [b"Response %s" % i in plain_response for i in range(1, 11)],
+ [True] * 10,
+ )
+
+ server_task.cancel()
+ await server.terminate()
+
+
+#################################################
+# Test methods
+#################################################
+
+
+def setup_config(tls_enabled=False):
+ http_server.HttpServer.LISTEN_PORT_HTTP = 8080
+ http_server.HttpServer.LISTEN_PORT_HTTPS = 4443
+
+ config_idx = config.CONFIG_CACHE.index("log_level")
+ config.CONFIG_CACHE[config_idx + 1] = str("warning")
+ config_idx = config.CONFIG_CACHE.index("http_multipart")
+ config.CONFIG_CACHE[config_idx + 1] = "True"
+ config_idx = config.CONFIG_CACHE.index("tls")
+ config.CONFIG_CACHE[config_idx + 1] = str(tls_enabled)
+ enable_optional_features()
+
+
+def test_registration():
+ test_assert(
+ "multipart endpoint registration",
+ multipart_callback,
+ HttpEngine._get_callback(b"/test/multipart", b"GET"),
+ )
+
+
+def test_multipart_patches():
+ setup_config()
+ test_assert(
+ "multipart state machine patches",
+ http_multipart._start_multipart_parser_st,
+ HttpEngine._start_multipart_parser_st,
+ )
+
+
+def test_main():
+ test_registration()
+ test_multipart_patches()
+ asyncio.run(test_multipart_response(tls_enabled=False))
+ asyncio.run(test_multipart_response(tls_enabled=True))
+
+
+test_main()
diff --git a/tests/unit/test_helpers.py b/tests/unit/test_helpers.py
index 6886fcc..10ee8a8 100644
--- a/tests/unit/test_helpers.py
+++ b/tests/unit/test_helpers.py
@@ -26,6 +26,7 @@ def test_path_normalization_virtual_root(self):
cwd = getcwd()
for case in (
("", ""),
+ ("/", f"{cwd}"),
("/path/to/resource", f"{cwd}/path/to/resource"),
("/path/to/resource/", f"{cwd}/path/to/resource"),
("///path///to///resource///", f"{cwd}/path/to/resource"),
@@ -45,6 +46,7 @@ def test_path_normalization_host_root(self, _):
"""
for case in (
("", ""),
+ ("/", "/"),
("/path/to/resource", "/path/to/resource"),
("/path/to/resource/", "/path/to/resource"),
("///path///to///resource///", "/path/to/resource"),
diff --git a/tests/unit/test_http.py b/tests/unit/test_http.py
index ee04f02..9afec41 100644
--- a/tests/unit/test_http.py
+++ b/tests/unit/test_http.py
@@ -30,28 +30,26 @@ def setUp(self):
},
)
self.patcher.start()
-
- for key, value in self.config.items():
- self.set_mock_config(key, value)
+ self.set_mock_config()
# Load your web and buffer modules
self.helpers_module = load_module("pyrobusta/utils/helpers.py")
buffer_module = load_module("pyrobusta/stream/buffer.py")
- web_module = load_module("pyrobusta/protocol/http.py")
- web_module.enable_optional_features()
+ self.web_module = load_module("pyrobusta/protocol/http.py")
+ self.web_module.enable_optional_features()
- self.engine = web_module.HttpEngine()
+ self.engine = self.web_module.HttpEngine()
self.rx = buffer_module.SlidingBuffer(bytearray(1024))
self.tx = buffer_module.SlidingBuffer(bytearray(1024))
def tearDown(self):
self.patcher.stop()
- def set_mock_config(self, key, value):
+ def set_mock_config(self):
def side_effect(input_arg, *_, **__):
- if input_arg == key:
- return value
- raise ValueError(f"Unexpected argument: {input_arg}")
+ if input_arg in self.config:
+ return self.config[input_arg]
+ raise ValueError(f"Unexpected config key: {input_arg}")
self.mock_utils_config.get_config.side_effect = side_effect
@@ -63,7 +61,7 @@ class TestWebStateMachine(TestWebStateMachineBase):
@classmethod
def setUpClass(cls):
- cls.config = {}
+ cls.config = {"http_multipart": "False", "http_serve_files": "False"}
def test_status_parsing_valid(self):
request = b"GET /index.html HTTP/1.1\r\nContent-Length:10"
@@ -149,6 +147,18 @@ def test_header_parsing_incomplete_header(self):
self.assertEqual(self.engine.status_code, 400)
self.assertEqual(self.engine.state, None)
+ def test_header_parsing_error(self):
+ for case in (
+ b"",
+ b":",
+ b": value",
+ b" leading-space: value",
+ b"space in header name: value",
+ b"new-line-in-header:\nvalue",
+ ):
+ with self.assertRaises(self.web_module.HeaderParsingError):
+ self.engine._parse_headers(case)
+
def test_routing_unsupported_method(self):
self.engine.state = self.engine._route_request_st
self.engine.url = b"/api/test"
@@ -312,6 +322,26 @@ def test_empty_or_missing_url_encoded_query_parameter(self):
with self.assertRaises(KeyError):
self.engine.get_url_encoded_query_param(self.engine.query, "param3")
+ def test_overlapping_url_encoded_query_parameter(self):
+ request = b"GET /api/test?data=value1&ta=value2&a=value3 HTTP/1.1\r\n"
+
+ for i in range(len(request)):
+ self.rx.write(request[i : i + 1])
+ self.engine.state(self.rx, self.tx)
+
+ self.assertEqual(
+ self.engine.get_url_encoded_query_param(self.engine.query, "data"),
+ "value1",
+ )
+ self.assertEqual(
+ self.engine.get_url_encoded_query_param(self.engine.query, "ta"),
+ "value2",
+ )
+ self.assertEqual(
+ self.engine.get_url_encoded_query_param(self.engine.query, "a"),
+ "value3",
+ )
+
def test_chunked_transfer_encoding_valid(self):
self.engine.url = b"/api/test"
self.engine.method = b"GET"
@@ -383,7 +413,7 @@ def test_chunked_transfer_encoding_chunk_incomplete(self):
self.assertEqual(self.engine.state, self.engine._recv_chunk_st)
def test_path_serving_list(self):
- self.set_mock_config("http_served_paths", "/path/to/dir1 /path/to/dir2")
+ self.config["http_served_paths"] = "/path/to/dir1 /path/to/dir2"
self.assertEqual(self.engine.is_norm_path_served(""), False)
self.assertEqual(self.engine.is_norm_path_served("/"), False)
self.assertEqual(self.engine.is_norm_path_served("/path/to/dir1"), True)
@@ -395,13 +425,13 @@ def test_path_serving_list(self):
self.assertEqual(self.engine.is_norm_path_served("/path/to"), False)
def test_path_serving_root(self):
- self.set_mock_config("http_served_paths", "/")
+ self.config["http_served_paths"] = "/"
self.assertEqual(self.engine.is_norm_path_served(""), True)
self.assertEqual(self.engine.is_norm_path_served("/"), True)
self.assertEqual(self.engine.is_norm_path_served("/path/to/served"), True)
def test_path_serving_none(self):
- self.set_mock_config("http_served_paths", "")
+ self.config["http_served_paths"] = ""
self.assertEqual(self.engine.is_norm_path_served(""), False)
self.assertEqual(self.engine.is_norm_path_served("/"), False)
self.assertEqual(self.engine.is_norm_path_served("/path/to/served"), False)
@@ -414,14 +444,19 @@ class TestMultipartStateMachine(TestWebStateMachineBase):
@classmethod
def setUpClass(cls):
- cls.config = {"http_multipart": "True"}
+ cls.config = {"http_multipart": "True", "http_serve_files": "True"}
def test_multipart_parser(self):
for case in [
+ ({}, None),
(
{"content-type": 'multipart/form-data; boundary ="test-boundary"'},
"test-boundary",
),
+ (
+ {"content-type": 'multipart/form-data; boundary =" test-boundary "'},
+ " test-boundary ",
+ ),
(
{"content-type": "multipart/form-data ;boundary= test-boundary "},
"test-boundary",
@@ -432,16 +467,18 @@ def test_multipart_parser(self):
),
]:
with self.subTest(headers=case[0], expected=case[1]):
- self.assertEqual(self.engine._is_multipart(case[0]), case[1])
+ self.assertEqual(self.engine._get_mp_boundary(case[0]), case[1])
for case in [
- {},
{"content-type": "multipart/form-data"},
{"content-type": 'multipart/form-data;boundary=""'},
{"content-type": "multipart/form-data;boundary=\r\n"},
+ {"content-type": 'multipart/form-data;boundary="missing-quote'},
+ {"content-type": 'multipart/form-data;boundary=missing-quote"'},
]:
- with self.subTest(headers=case, expected=None):
- self.assertEqual(self.engine._is_multipart(case), None)
+ with self.subTest(headers=case):
+ with self.assertRaises(self.web_module.HeaderParsingError):
+ self.engine._get_mp_boundary(case)
def test_multipart_receiver_valid(self):
self.engine.state = self.engine._start_multipart_parser_st