Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
PYROBUSTA_VERSION := 0.1.0
PYROBUSTA_VERSION := 0.2.0
DEVICE ?= u0

SRC_DIR := src
Expand Down
4 changes: 1 addition & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,9 @@ A lightweight HTTP server library for MicroPython designed for constrained embed
- Bounded-copy memory footprint
- Finite-state-machine parser with linear sliding buffer
- Robust byte-stream handling
- Query parameter parsing with percent encoding support
- TLS support

## Current limitation
- Query parameter parsing is not yet implemented

# Prerequisites

## Setup virtual environment
Expand Down
Binary file modified dist/pyrobusta/bindings/socket_http.mpy
Binary file not shown.
Binary file modified dist/pyrobusta/protocol/http.mpy
Binary file not shown.
Binary file modified dist/pyrobusta/protocol/http_multipart.mpy
Binary file not shown.
Binary file modified dist/pyrobusta/utils/config.mpy
Binary file not shown.
25 changes: 24 additions & 1 deletion example/mem_usage/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,34 @@


@HttpEngine.route("/mem-usage", "GET")
def mem_usage(*_):
def mem_usage(http_ctx, _):
collect()
free = mem_free()
used = mem_alloc()
usage_percentage = 100 * used / (free + used)

if http_ctx.query:
value_format = http_ctx.get_url_encoded_query_param(
http_ctx.query, "format", "bytes"
)
if value_format not in ("%", "bytes"):
raise ValueError("invalid format")

selector = http_ctx.get_url_encoded_query_param(http_ctx.query, "key", "")
if selector == "free":
if value_format == "%":
free = 100 * free / (used + free)
return "text/plain", f"Free [{value_format}]: {free}\n"
if selector == "used":
if value_format == "%":
used = 100 * used / (used + free)
return "text/plain", f"Used [{value_format}]: {used}\n"
if selector == "total":
return "text/plain", f"Total [bytes]: {used + free}\n"

if selector:
raise ValueError("invalid key")

return "text/plain", (
f"Currently used: {usage_percentage:.2f}%\n"
f"Free [bytes]: {free}\n"
Expand Down
4 changes: 2 additions & 2 deletions example/mip_repo/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@ def append_package_files(dir, package_files, host_name, protocol):


@HttpEngine.route("/pyrobusta/package.json", "GET")
def self_serve_mip_package(headers, _):
def self_serve_mip_package(http_ctx, _):
package_files = {"version": config.PYROBUSTA_VERSION, "deps": [], "urls": []}
tls_enabled = config.get_config("tls").lower() == "true"
server_addr = headers["host"]
server_addr = http_ctx.headers["host"]
if ":" not in server_addr:
port = (
http_server.HttpServer.LISTEN_PORT_HTTPS
Expand Down
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"version": "0.1.0",
"version": "0.2.0",
"urls": [
[
"pyrobusta/transport/socket.mpy",
Expand Down
119 changes: 94 additions & 25 deletions src/pyrobusta/protocol/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,16 @@ class HttpEngine:
"headers",
"method",
"url",
"query",
"content_length_cnt",
"mp_boundary",
"mp_first_part",
"mp_last_part",
"mp_delimiter",
"mp_closing_delimiter",
)

ENDPOINTS = {}
ENDPOINTS = [] # (endpoint, callback, method)
RESP_HEADERS = (
200,
b"200 OK",
Expand Down Expand Up @@ -123,6 +125,7 @@ def __init__(self):
self.headers = {}
self.method = None
self.url = None
self.query = None
self.content_length_cnt = 0

# [Multipart state]
Expand All @@ -144,14 +147,16 @@ def register(
"""
endpoint = endpoint.encode(cls.ASCII)
method = method.encode(cls.ASCII)
if not endpoint in cls.ENDPOINTS:
cls.ENDPOINTS[endpoint] = {}
endpoint_exists = cls._get_callback(endpoint, method) is not None

if method not in cls.METHODS:
raise ValueError(f"method must be one of {cls.METHODS}")
cls.ENDPOINTS[endpoint][method] = callback
if endpoint_exists:
raise ValueError("endpoint exists")
cls.ENDPOINTS.append((endpoint, callback, method))

@staticmethod
def route(endpoint, method):
def route(endpoint: str, method: str):
"""
Decorator for registering endpoint callback functions.
"""
Expand All @@ -166,15 +171,66 @@ def decorator(func):
# Static helpers for parsing
# =========================================

@staticmethod
def percent_decode(s: str):
"""Decode percent-encoded input"""
out = []
i = 0
while i < len(s):
if s[i] == "%" and i + 2 < len(s):
out.append(chr(int(s[i + 1 : i + 3], 16)))
i += 3
else:
out.append(s[i])
i += 1
return "".join(out)

@staticmethod
def get_url_encoded_query_param(query: str, key: str, default: str = None):
"""
Parse query and return the value belonging to a key
according to x-www-form-urlencoded
:param query: query part
:param key: key to parse from the query
:param default: default value to return when key is not present
"""
idx_start = query.find(key + "=")
if idx_start != -1:
idx_end = -1
idx_end = query.find("&", idx_start)
if idx_start > -1:
if idx_end > -1:
return query[idx_start + len(key) + 1 : idx_end]
return query[idx_start + len(key) + 1 :]
if default is None:
raise KeyError()
return default

@staticmethod
def _lookup(tuple_, key):
idx = tuple_.index(key)
return tuple_[idx + 1]

@classmethod
def _get_callback(cls, endpoint, method):
for e in cls.ENDPOINTS:
if endpoint == e[0] and method == e[2]:
return e[1]

@classmethod
def _get_status(cls, status_code):
idx = cls.RESP_HEADERS.index(status_code)
return cls.RESP_HEADERS[idx + 1]
def _has_endpoint(cls, endpoint):
for e in cls.ENDPOINTS:
if endpoint == e[0]:
return True
return False

@classmethod
def _get_content_type(cls, extension):
idx = cls.CONTENT_TYPES.index(extension)
return cls.CONTENT_TYPES[idx + 1]
def _supported_methods(cls, endpoint):
supported_methods = []
for method in cls.METHODS:
if cls._get_callback(endpoint, method) is not None:
supported_methods.append(method)
return supported_methods

@classmethod
def _parse_headers(cls, raw_headers: memoryview) -> dict[str, str | int]:
Expand Down Expand Up @@ -263,7 +319,7 @@ def _write_response_head(self, tx, content_length: int = 0):
tx.consume()
tx.write(self.version)
tx.write(b" ")
tx.write(self._get_status(self.status_code))
tx.write(self._lookup(self.RESP_HEADERS, self.status_code))
if content_length is not None:
tx.write(b"\r\n")
tx.write(b"content-length: %s" % str(content_length).encode(self.ASCII))
Expand Down Expand Up @@ -364,7 +420,13 @@ def _parse_request_line_st(self, rx, tx):
self.on_client_error(tx, self.BAD_REQUEST_ERROR)
return
self.method = status_parts[0]
self.url = status_parts[1]
url_parts = status_parts[1].split(b"?", 1)
self.url = url_parts[0]
self.query = (
""
if len(url_parts) == 1
else self.percent_decode(url_parts[1].decode(self.ASCII))
)
self.version = status_parts[2]
if self.method not in self.METHODS:
self.on_method_not_allowed(tx)
Expand Down Expand Up @@ -398,13 +460,16 @@ def _route_request_st(self, _, tx):
State for routing requests
- supported ways: static resources, endpoint callback functions
"""
if self.url in self.ENDPOINTS and (
self.method in self.ENDPOINTS[self.url]
if self._has_endpoint(self.url) and (
self._get_callback(self.url, self.method) is not None
or self.method == self.OPTIONS
or (self.method == self.HEAD and self.GET in self.ENDPOINTS[self.url])
or (
self.method == self.HEAD
and self._get_callback(self.url, self.GET) is not None
)
):
if self.method == self.OPTIONS:
supported_methods = list(self.ENDPOINTS[self.url].keys())
supported_methods = self._supported_methods(self.url)
self._set_response_header(b"allow", b", ".join(supported_methods))
self.terminate(204, None)
self._write_response_head(tx, None)
Expand All @@ -421,12 +486,16 @@ def _route_request_st(self, _, tx):
else:
self.state = self._app_endpoint_st
return
if self.url in self.ENDPOINTS and self.method not in self.ENDPOINTS[self.url]:
supported_methods = list(self.ENDPOINTS[self.url].keys())

if (
self._has_endpoint(self.url)
and self._get_callback(self.method, self.url) is None
):
supported_methods = self._supported_methods(self.url)
self._set_response_header(b"allow", b", ".join(supported_methods))
self.on_method_not_allowed(tx)
return
if self.method == self.GET:
if self.method in (self.GET, self.HEAD):
resource = b"index.html" if not self.url else self.url
self.state = lambda _rx, _tx: self._send_file_st(_rx, _tx, resource)
return
Expand All @@ -443,10 +512,10 @@ def _recv_payload(self, rx, tx):
def _app_endpoint_st(self, rx, tx):
"""Process a request by registered callback functions"""
method = self.GET if self.method == self.HEAD else self.method
callback = self.ENDPOINTS[self.url][method]
callback = self._get_callback(self.url, method)
if self._has_payload():
self.state = None
dtype, data = callback(self.headers, bytes(rx.peek()))
dtype, data = callback(self, bytes(rx.peek()))
dtype = dtype.encode(self.ASCII)
else:
if not callable(callback):
Expand All @@ -455,7 +524,7 @@ def _app_endpoint_st(self, rx, tx):
_rx, _tx, callback.encode(HttpEngine.ASCII)
)
return
dtype, data = callback(self.headers, b"")
dtype, data = callback(self, b"")
dtype = dtype.encode(self.ASCII)
self._set_response_header(b"content-type", dtype)
if dtype == b"image/jpeg":
Expand Down Expand Up @@ -500,9 +569,9 @@ def _send_file_st(self, _, tx, web_resource: bytes):
norm_path = b"/".join(parts)

try:
content_type = self._get_content_type(extension)
content_type = self._lookup(self.CONTENT_TYPES, extension)
except ValueError:
content_type = self._get_content_type(b"raw")
content_type = self._lookup(self.CONTENT_TYPES, b"raw")
try:
self._set_response_header(
b"content-length", str(stat(norm_path)[6]).encode(HttpEngine.ASCII)
Expand Down
8 changes: 5 additions & 3 deletions src/pyrobusta/protocol/http_multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,10 @@ def _parse_complete_part_st(self, rx, tx):
except http.HeaderParsingError:
self.on_client_error(tx, http.HttpEngine.HEADER_ERROR)
return
callback = http.HttpEngine.ENDPOINTS[self.url][self.method]
callback = http.HttpEngine._get_callback(self.url, self.method)
# Process complete part
if not is_final:
callback(part_headers, part_body, first=self.mp_first_part, last=False)
callback(self, (part_headers, part_body))
if rx.peek(len(self.mp_delimiter)) != self.mp_delimiter:
self.on_client_error(tx, http.HttpEngine.MULTIPART_BOUNDARY_ERROR)
return
Expand All @@ -129,7 +129,8 @@ def _parse_complete_part_st(self, rx, tx):
):
self.on_client_error(tx, http.HttpEngine.CONTENT_LENGTH_ERROR)
return
dtype, data = callback(part_headers, part_body, first=self.mp_first_part, last=True)
self.mp_last_part = True
dtype, data = callback(self, (part_headers, part_body))
self.terminate(200, dtype.encode(http.HttpEngine.ASCII))
return self._generate_response(tx, data)

Expand All @@ -145,6 +146,7 @@ def apply_patches():
def new_init(self, *args, **kwargs):
orig_init(self, *args, **kwargs)
self.mp_first_part = True
self.mp_last_part = False
self.mp_delimiter = None
self.mp_closing_delimiter = None

Expand Down
4 changes: 2 additions & 2 deletions src/pyrobusta/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
Values can be encapsulated by single or double quotes.
"""

PYROBUSTA_VERSION = "0.1.0"
PYROBUSTA_VERSION = "0.2.0"
CONFIG_LOADED = False
CONFIG_LOCATION = "pyrobusta.env"
CONFIG_CACHE = [
Expand Down Expand Up @@ -36,7 +36,7 @@ def read_config(config=CONFIG_LOCATION):
with open(config, encoding="utf-8") as conf:
for line in conf.read().splitlines("\n"):
key = line.split("=")[0].strip()
if key.startswith("#"):
if key.startswith("#") or not line.strip():
continue
value = line.split("=")[1].strip().strip("'").strip('"')
if key and value:
Expand Down
18 changes: 9 additions & 9 deletions tests/functional/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,22 +70,22 @@ def response_generator():


@HttpEngine.route("/test/simple", "GET")
def simple_callback(headers, body):
if headers["accept"] == "text/plain":
def simple_callback(http_ctx, _):
if http_ctx.headers["accept"] == "text/plain":
return "text/plain", "Test response\n"
elif headers["accept"] == "application/json":
elif http_ctx.headers["accept"] == "application/json":
return "application/json", '{"response": "Test response"}'
raise ValueError("Unhandled content-type")


@HttpEngine.route("/test/multipart", "GET")
def multipart_callback(headers, body):
part_count = int(headers["x-part-count"])
def multipart_callback(http_ctx, _):
part_count = int(http_ctx.headers["x-part-count"])
return "multipart/form-data", ("text/plain", multipart_response(part_count))


@HttpEngine.route("/test/busy", "POST")
def busy_callback(headers, body):
def busy_callback(*_):
raise ServerBusyError()


Expand Down Expand Up @@ -211,19 +211,19 @@ def test_registration():
test_assert(
"simple endpoint registration",
simple_callback,
HttpEngine.ENDPOINTS[b"/test/simple"][b"GET"],
HttpEngine._get_callback(b"/test/simple", b"GET"),
)

test_assert(
"multipart endpoint registration",
multipart_callback,
HttpEngine.ENDPOINTS[b"/test/multipart"][b"GET"],
HttpEngine._get_callback(b"/test/multipart", b"GET"),
)

test_assert(
"busy endpoint registration",
busy_callback,
HttpEngine.ENDPOINTS[b"/test/busy"][b"POST"],
HttpEngine._get_callback(b"/test/busy", b"POST"),
)


Expand Down
Loading
Loading