Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 7 additions & 31 deletions src/pyrobusta/protocol/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,30 +94,6 @@ class HttpEngine:
505,
b"505 Version Not Supported",
)
CONTENT_TYPES = (
b"raw",
b"application/octet-stream",
b"html",
b"text/html",
b"css",
b"text/css",
b"js",
b"application/javascript",
b"json",
b"application/json",
b"ico",
b"image/x-icon",
b"jpeg",
b"image/jpeg",
b"jpg",
b"image/jpeg",
b"png",
b"image/png",
b"txt",
b"text/plain",
b"gif",
b"image/gif",
)

DELETE = b"DELETE"
GET = b"GET"
Expand Down Expand Up @@ -319,7 +295,10 @@ def _parse_headers(cls, raw_headers: memoryview) -> dict[str, str | int]:
value = int(value.strip())
else:
value = value.strip().decode("ascii")
headers[name] = value
if name not in headers and value:
headers[name] = value
elif value:
headers[name] += ", " + value # Combined field value
return headers

@staticmethod
Expand Down Expand Up @@ -412,7 +391,7 @@ def write_response_head(self, tx):
def set_response_body(
self,
body: bytes | str | dict | tuple | list,
content_type: bytes = b"text/plain",
content_type: str = "text/plain",
):
"""
Serialize and wrap the response body with a BytesIO
Expand All @@ -436,7 +415,7 @@ def set_response_body(
self.set_response_header(
b"content-length", str(len(body_encoded)).encode("ascii")
)
self.set_response_header(b"content-type", content_type)
self.set_response_header(b"content-type", content_type.encode("ascii"))
if self.method != self.HEAD:
self.resp_handler = BytesIO(body_encoded)

Expand Down Expand Up @@ -702,7 +681,6 @@ def _app_endpoint_st(self, rx):
dtype, data = callback(
self, bytes(rx.peek(self.headers["content-length"]))
)
dtype = dtype.encode("ascii")
else:
if not callable(callback):
# Handle as a file path
Expand All @@ -711,10 +689,8 @@ def _app_endpoint_st(self, rx):
)
return
dtype, data = callback(self, b"")
dtype = dtype.encode("ascii")
self.set_response_header(b"content-type", dtype)

if dtype.startswith(b"multipart/"):
if dtype.startswith("multipart/"):
self.state = lambda _rx: self._generate_multipart_response(_rx, data, dtype)
return

Expand Down
29 changes: 27 additions & 2 deletions src/pyrobusta/protocol/http_file_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,31 @@
from pyrobusta.protocol import http
from pyrobusta.utils.helpers import normalize_path, add_method

CONTENT_TYPES = (
b"raw",
b"application/octet-stream",
b"html",
b"text/html",
b"css",
b"text/css",
b"js",
b"application/javascript",
b"json",
b"application/json",
b"ico",
b"image/x-icon",
b"jpeg",
b"image/jpeg",
b"jpg",
b"image/jpeg",
b"png",
b"image/png",
b"txt",
b"text/plain",
b"gif",
b"image/gif",
)


def _send_file_st(self, _, file_path: bytes):
"""
Expand Down Expand Up @@ -38,9 +63,9 @@ def _send_file_st(self, _, file_path: bytes):
self.terminate(404, True)
return
try:
content_type = self._lookup(self.CONTENT_TYPES, extension)
content_type = self._lookup(CONTENT_TYPES, extension)
except ValueError:
content_type = self._lookup(self.CONTENT_TYPES, b"raw")
content_type = self._lookup(CONTENT_TYPES, b"raw")
try:
self.set_response_header(
b"content-length", str(stat(norm_path)[6]).encode("ascii")
Expand Down
9 changes: 5 additions & 4 deletions src/pyrobusta/protocol/http_multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from pyrobusta.utils.helpers import add_method


def _generate_multipart_response(self, _, callback: callable, dtype: bytes):
def _generate_multipart_response(self, _, callback: callable, dtype: str):
"""
Generate multipart response depening on the exact content type.
The callback function is called without arguments, and it must return bytes-like objects.
Expand All @@ -19,7 +19,9 @@ def _generate_multipart_response(self, _, callback: callable, dtype: bytes):
raise ValueError("Invalid response handler")
self.terminate(200, True)
boundary = self.MULTIPART_BOUNDARY
self.set_response_header(b"content-type", dtype + b"; boundary=" + boundary)
self.set_response_header(
b"content-type", dtype.encode("ascii") + b"; boundary=" + boundary
)
if self.method != self.HEAD:
self.resp_handler = self._multipart_wrapper_factory(callback, boundary)

Expand Down Expand Up @@ -135,9 +137,8 @@ def _parse_complete_part_st(self, rx):
raise http.InvalidContentLength()
self.mp_is_last = True
dtype, data = callback(self, (part_headers, part_body))
self.set_response_header(b"content-type", dtype.encode("ascii"))
self.terminate(200, True)
self.set_response_body(data)
self.set_response_body(data, dtype)


def apply_patches():
Expand Down
13 changes: 13 additions & 0 deletions tests/unit/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,19 @@ def test_header_parsing_error(self):
with self.assertRaises(self.http_module.InvalidHeaders):
self.engine._parse_headers(case)

def test_header_parsing_combined(self):
for case in (
(
b"field-name: value1\r\nfield-name: value2",
{"field-name": "value1, value2"},
),
(
b"field-name: \r\nfield-name: value1\r\nfield-name:\r\nfield-name: value2",
{"field-name": "value1, value2"},
),
):
self.assertEqual(self.engine._parse_headers(case[0]), case[1])

def test_routing_unsupported_method(self):
self.engine.state = self.engine._route_request_st
self.engine.url = b"/api/test"
Expand Down
Loading