diff --git a/src/pyrobusta/protocol/http.py b/src/pyrobusta/protocol/http.py index 93d55ab..ccf4e2b 100644 --- a/src/pyrobusta/protocol/http.py +++ b/src/pyrobusta/protocol/http.py @@ -94,30 +94,6 @@ class HttpEngine: 505, b"505 Version Not Supported", ) - CONTENT_TYPES = ( - b"raw", - b"application/octet-stream", - b"html", - b"text/html", - b"css", - b"text/css", - b"js", - b"application/javascript", - b"json", - b"application/json", - b"ico", - b"image/x-icon", - b"jpeg", - b"image/jpeg", - b"jpg", - b"image/jpeg", - b"png", - b"image/png", - b"txt", - b"text/plain", - b"gif", - b"image/gif", - ) DELETE = b"DELETE" GET = b"GET" @@ -319,7 +295,10 @@ def _parse_headers(cls, raw_headers: memoryview) -> dict[str, str | int]: value = int(value.strip()) else: value = value.strip().decode("ascii") - headers[name] = value + if name not in headers and value: + headers[name] = value + elif value: + headers[name] += ", " + value # Combined field value return headers @staticmethod @@ -412,7 +391,7 @@ def write_response_head(self, tx): def set_response_body( self, body: bytes | str | dict | tuple | list, - content_type: bytes = b"text/plain", + content_type: str = "text/plain", ): """ Serialize and wrap the response body with a BytesIO @@ -436,7 +415,7 @@ def set_response_body( self.set_response_header( b"content-length", str(len(body_encoded)).encode("ascii") ) - self.set_response_header(b"content-type", content_type) + self.set_response_header(b"content-type", content_type.encode("ascii")) if self.method != self.HEAD: self.resp_handler = BytesIO(body_encoded) @@ -702,7 +681,6 @@ def _app_endpoint_st(self, rx): dtype, data = callback( self, bytes(rx.peek(self.headers["content-length"])) ) - dtype = dtype.encode("ascii") else: if not callable(callback): # Handle as a file path @@ -711,10 +689,8 @@ def _app_endpoint_st(self, rx): ) return dtype, data = callback(self, b"") - dtype = dtype.encode("ascii") - self.set_response_header(b"content-type", dtype) - if dtype.startswith(b"multipart/"): + if dtype.startswith("multipart/"): self.state = lambda _rx: self._generate_multipart_response(_rx, data, dtype) return diff --git a/src/pyrobusta/protocol/http_file_server.py b/src/pyrobusta/protocol/http_file_server.py index f128d37..a8c278b 100644 --- a/src/pyrobusta/protocol/http_file_server.py +++ b/src/pyrobusta/protocol/http_file_server.py @@ -9,6 +9,31 @@ from pyrobusta.protocol import http from pyrobusta.utils.helpers import normalize_path, add_method +CONTENT_TYPES = ( + b"raw", + b"application/octet-stream", + b"html", + b"text/html", + b"css", + b"text/css", + b"js", + b"application/javascript", + b"json", + b"application/json", + b"ico", + b"image/x-icon", + b"jpeg", + b"image/jpeg", + b"jpg", + b"image/jpeg", + b"png", + b"image/png", + b"txt", + b"text/plain", + b"gif", + b"image/gif", +) + def _send_file_st(self, _, file_path: bytes): """ @@ -38,9 +63,9 @@ def _send_file_st(self, _, file_path: bytes): self.terminate(404, True) return try: - content_type = self._lookup(self.CONTENT_TYPES, extension) + content_type = self._lookup(CONTENT_TYPES, extension) except ValueError: - content_type = self._lookup(self.CONTENT_TYPES, b"raw") + content_type = self._lookup(CONTENT_TYPES, b"raw") try: self.set_response_header( b"content-length", str(stat(norm_path)[6]).encode("ascii") diff --git a/src/pyrobusta/protocol/http_multipart.py b/src/pyrobusta/protocol/http_multipart.py index fbc0862..7e8cb5d 100644 --- a/src/pyrobusta/protocol/http_multipart.py +++ b/src/pyrobusta/protocol/http_multipart.py @@ -8,7 +8,7 @@ from pyrobusta.utils.helpers import add_method -def _generate_multipart_response(self, _, callback: callable, dtype: bytes): +def _generate_multipart_response(self, _, callback: callable, dtype: str): """ Generate multipart response depening on the exact content type. The callback function is called without arguments, and it must return bytes-like objects. @@ -19,7 +19,9 @@ def _generate_multipart_response(self, _, callback: callable, dtype: bytes): raise ValueError("Invalid response handler") self.terminate(200, True) boundary = self.MULTIPART_BOUNDARY - self.set_response_header(b"content-type", dtype + b"; boundary=" + boundary) + self.set_response_header( + b"content-type", dtype.encode("ascii") + b"; boundary=" + boundary + ) if self.method != self.HEAD: self.resp_handler = self._multipart_wrapper_factory(callback, boundary) @@ -135,9 +137,8 @@ def _parse_complete_part_st(self, rx): raise http.InvalidContentLength() self.mp_is_last = True dtype, data = callback(self, (part_headers, part_body)) - self.set_response_header(b"content-type", dtype.encode("ascii")) self.terminate(200, True) - self.set_response_body(data) + self.set_response_body(data, dtype) def apply_patches(): diff --git a/tests/unit/test_http.py b/tests/unit/test_http.py index 949bc2f..96be22d 100644 --- a/tests/unit/test_http.py +++ b/tests/unit/test_http.py @@ -170,6 +170,19 @@ def test_header_parsing_error(self): with self.assertRaises(self.http_module.InvalidHeaders): self.engine._parse_headers(case) + def test_header_parsing_combined(self): + for case in ( + ( + b"field-name: value1\r\nfield-name: value2", + {"field-name": "value1, value2"}, + ), + ( + b"field-name: \r\nfield-name: value1\r\nfield-name:\r\nfield-name: value2", + {"field-name": "value1, value2"}, + ), + ): + self.assertEqual(self.engine._parse_headers(case[0]), case[1]) + def test_routing_unsupported_method(self): self.engine.state = self.engine._route_request_st self.engine.url = b"/api/test"