diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile new file mode 100644 index 0000000..b08233e --- /dev/null +++ b/.devcontainer/Dockerfile @@ -0,0 +1,22 @@ +FROM ubuntu:24.04 + +ENV DEBIAN_FRONTEND=noninteractive + +RUN apt-get update && apt-get install -y --no-install-recommends \ + build-essential cmake git ca-certificates wget curl pkg-config \ + python3 python3-pip python3-venv python3-dev ca-certificates \ + libssl-dev libffi-dev && \ + rm -rf /var/lib/apt/lists/* + +# Install liboqs from source +WORKDIR /opt +RUN git clone --depth 1 https://github.com/open-quantum-safe/liboqs.git && \ + mkdir -p liboqs/build && cd liboqs/build && \ + cmake -DCMAKE_BUILD_TYPE=Release .. && \ + make -j"$(nproc)" && make install + +# Ensure pip is upgraded and install Python oqs wrapper +RUN python3 -m pip install --upgrade pip setuptools wheel && \ + python3 -m pip install oqs + +WORKDIR /workspace diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json new file mode 100644 index 0000000..8d408ef --- /dev/null +++ b/.devcontainer/devcontainer.json @@ -0,0 +1,11 @@ +{ + "name": "Mohawk Inference Devcontainer", + "build": { + "dockerfile": "Dockerfile" + }, + "workspaceFolder": "/workspace", + "settings": {}, + "extensions": [], + "forwardPorts": [8003], + "postCreateCommand": "./.devcontainer/post_create.sh" +} diff --git a/.devcontainer/post_create.sh b/.devcontainer/post_create.sh new file mode 100644 index 0000000..6f96305 --- /dev/null +++ b/.devcontainer/post_create.sh @@ -0,0 +1,36 @@ +#!/usr/bin/env bash +set -euo pipefail + +echo "Running devcontainer post-create: install build deps and liboqs" +# install system deps (attempt apt, then apk) +if command -v apt-get >/dev/null 2>&1; then + sudo apt-get update + sudo apt-get install -y build-essential cmake git python3-dev python3-pip pkg-config +elif command -v apk >/dev/null 2>&1; then + sudo apk add --no-cache build-base cmake git python3 python3-dev py3-pip pkgconfig +else + echo "Unknown package manager; please install build tools (cmake, make, git, python3-dev) manually" +fi + +CACHE_DIR="$HOME/.cache/liboqs" +mkdir -p "$CACHE_DIR" +if [ ! -d "$CACHE_DIR/liboqs" ]; then + git clone --depth 1 https://github.com/open-quantum-safe/liboqs.git "$CACHE_DIR/liboqs" +fi + +pushd "$CACHE_DIR/liboqs" +mkdir -p build && cd build +cmake -DCMAKE_BUILD_TYPE=Release .. +make -j"$(nproc)" +if command -v sudo >/dev/null 2>&1; then + sudo make install +else + make install +fi +popd + +# ensure pip and install oqs python package +python3 -m pip install --upgrade pip || true +python3 -m pip install oqs || true + +echo "post-create complete" diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..eb397ce --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,47 @@ +name: CI + +on: + push: + branches: [ main, feat/* ] + pull_request: + branches: [ main ] + workflow_dispatch: + inputs: + build_liboqs: + description: Build liboqs from source before running tests + required: false + default: false + type: boolean + +jobs: + test: + runs-on: ubuntu-latest + env: + OQS_INSTALL_PATH: /usr/local + steps: + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.12' + - name: Install system deps + run: | + sudo apt-get update + sudo apt-get install -y build-essential cmake libssl-dev pkg-config + - name: Optionally build liboqs + if: github.event_name == 'workflow_dispatch' && inputs.build_liboqs || vars.BUILD_LIBOQS == 'true' + run: | + git clone --depth 1 https://github.com/open-quantum-safe/liboqs.git /tmp/liboqs + mkdir -p /tmp/liboqs/build && cd /tmp/liboqs/build + cmake -DBUILD_SHARED_LIBS=ON -DCMAKE_INSTALL_PREFIX=/usr/local .. + make -j$(nproc) + sudo make install + sudo ldconfig + python -m pip install liboqs-python + - name: Install Python deps + run: | + python -m pip install --upgrade pip + pip install -r prototype/requirements.txt + - name: Run tests + run: | + pytest -q diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e1a2b55 --- /dev/null +++ b/.gitignore @@ -0,0 +1,10 @@ +.venv/ +__pycache__/ +*.pyc +.pytest_cache/ +.vscode/ +.env +dist/ +build/ +/.pytest_cache/ +*.egg-info/ \ No newline at end of file diff --git a/docs/ARCHITECTURE.md b/docs/ARCHITECTURE.md new file mode 100644 index 0000000..f41e393 --- /dev/null +++ b/docs/ARCHITECTURE.md @@ -0,0 +1,125 @@ +Mohawk Inference Engine — Architecture Spec + +Overview + +Goal: provide a production-grade inference engine that enables capabilities LM Studio does not: multi-device layer splitting, PQC-secured edge offload, and high-concurrency session management. This document describes the core subsystems, dataflows, APIs, security model, and implementation priorities for an MVP. + +1. Core concepts + +- Layer-splitting: partitioning a neural network at layer boundaries (or sub-layer blocks) so different partitions (slices) execute on different devices (GPU/NPU/CPU/edge). Each slice exposes a small runtime ABI for input/output activation tensors and metadata. +- Offload: the act of sending one or more slices to a remote device for execution. Offloads must preserve confidentiality/integrity of model IP (weights) and activations as required by policy. +- PQC-secured channel: post-quantum cryptography handshake + authenticated encryption for slice packages and RPC traffic. +- Session manager: long-lived controller that maps client sessions to slice placements, manages QoS, adaptive batching, autoscaling, and failure recovery. + +2. High-level architecture + +Components: +- Controller (central or local): plans partitioning, placement, and routes requests to workers. +- Worker runtime: lightweight process on each device that accepts slice packages, registers capabilities (memory, device type), and executes slices. +- Offload transport: secure RPC over TCP/QUIC with PQC handshake and integrity checks. +- Session Manager: receives client requests, handles session state, batching, and QoS rules. +- Scheduler: maps slices to workers, performs placement decisions using cost model and current telemetry. +- Persistence: key/value store for slice metadata, session state, and logs (can be local filesystem or etcd for distributed setups). + +3. Layer-splitting design + +3.1 Partitioning model +- Static split: for MVP, support deterministic splits at transformer block or attention/MLP block granularity. Input: model graph (ONNX, TorchScript), cost model, device inventory. Output: ordered list of slices with boundary tensor shapes and serialization descriptors. +- Dynamic split (future): runtime re-partitioning based on latency/throughput signals. + +3.2 Slice format +- Metadata: slice id, inputs/outputs shapes, parameter size, expected memory footprint, device hints, version, policy tags (private/public). +- Artifact: serialized weights in compact format (FP16/int8 quantized optional) + small runtime glue to map tensor ops. +- Transport container: authenticated envelope (PQC AEAD) + optional compression. + +3.3 Runtime ABI +- Execute(slice_id, input_tensor, trace_id) -> output_tensor, metrics +- Health(check) -> status +- Preload(slice_id) -> ack + +3.4 Scheduling and placement +- Cost model inputs: parameter size, compute FLOPs per-token, estimated activation sizes, device throughput and free memory, network latency. +- Heuristics for MVP: place compute-heavy contiguous slices on GPU if available; place small parameter slices on CPU to lower memory duplication; prefer colocated slices to reduce network hops. +- Backpressure: if a worker is loaded, controller routes slice to alternate worker or falls back to local execution. + +4. PQC-secured edge offload + +4.1 Security goals +- Confidentiality of slice weights when policy requires (IP protection). +- Integrity of slice artifacts and runtime RPCs. +- Forward-secure key exchange resistant to quantum-capable adversaries. + +4.2 Keyflows and handshakes +- Root authority: operator provides long-term signing key (classical/ECDSA) for worker identity; optionally use hardware TPM for key storage. +- Session handshake: use a PQC KEM (e.g., Kyber or later NIST standard) to establish ephemeral symmetric AEAD keys per connection. Steps: + 1. Controller/worker exchange identity-signed certificates (classical) and PQC KEM public values. + 2. Both sides derive AEAD keys via HKDF over KEM shared secret and transcript. + 3. Optionally request remote attestation token before accepting slices (attestation hooks, e.g., Intel SGX/SEV or MDS attestation APIs). + +4.3 Slice packaging & integrity +- Each slice package: {manifest, weights.blob, signature, version} +- Manifest contains policy tags; controller encrypts package with AEAD key and includes HMAC/signature for extra assurance. +- Workers verify signature + AEAD before load. + +4.4 Performance considerations +- PQC KEM handshake cost is paid per long-lived connection; reuse AEAD keys for multiple RPCs. +- For high-throughput edge fleets, pre-provision slice packages to workers via provisioning channel to avoid repeated KEM costs. + +5. Session manager + +5.1 API (gRPC/HTTP) +- StartSession(request {model, routingHints, qos, tenant}) -> session_id +- Infer(session_id, input, options {sync|async}) -> response stream or token +- EndSession(session_id) +- GetSessionStats(session_id) -> metrics + +5.2 Session lifecycle +- Session creation: controller allocates slices, populates placement plan, preloads prioritized slices on workers, returns session token. +- Execution path: client -> session manager -> controller splits request across slices -> workers execute in pipeline -> session manager aggregates outputs. +- Adaptive batching: session manager groups small inferences into micro-batches per slice based on configured latency budgets. + +5.3 QoS and isolation +- Per-session resource caps (max concurrency, token rate). +- Tenant isolation: per-tenant slice caching and optional model duplication flags. +- Fair queuing or priority queues for low-latency sessions. + +6. Telemetry & metrics +- Per-slice metrics: exec latency, memory usage, throughput, error rate. +- Per-worker metrics: GPU util, free memory, network RTT, connection counts. +- Per-session metrics: p50/p95/p99 latencies, batch sizes, tokens/sec. +- Emit via Prometheus metrics endpoint and structured traces (OpenTelemetry) for tracing across slices. + +7. Failure modes and fallbacks +- Worker failure: controller reroutes to alternate worker or triggers local fallback (single-node execution). Evict/restore policy for preloaded slices. +- Network partition: fall back to local execution when possible; if offload required, return graceful degradation messages to client. +- Mismatched versions: use manifest version checks to prevent executing incompatible slices. + +8. Interfaces & data formats +- Model ingestion: accept ONNX and TorchScript (MVP) with translator that enumerates layer boundaries. +- Slice artifact: gzipped protobuf or tar with manifest.json and weights.bin. +- RPC: gRPC over QUIC (preferred) or HTTP/2 with AEAD wrapper. + +9. Testing & benchmarks +- Unit tests: correctness of slice outputs vs baseline single-node for a suite of models. +- Integration tests: end-to-end run across two devices (GPU + CPU) validating activations and outputs. +- Load tests: simulate 1k concurrent sessions with synthetic clients, measure p95 latency and throughput. +- Security tests: verify PQC handshake, replay protection, and attestation flows. + +10. MVP milestones and deliverables +- Week 0–1: architecture doc, slice format, and prototype plan. (this doc) +- Week 1–2: implement controller + worker minimal runtime and static partitioner that accepts a small transformer and emits slices. +- Week 2–3: add PQC handshake, encrypted slice transport, and pre-provisioning flow. +- Week 3–4: session manager with adaptive batching and basic QoS; run 1k simulated sessions. +- Week 4–5: integration tests, telemetry dashboard, readme hero docs, and release prep. + +11. Open questions +- Target PQC primitives (Kyber, CRYSTALS-Kyber; choose current NIST-recommended variant). Decide whether to include hybrid classical+PQC key exchange. +- Attestation strategy for diverse edge hardware — what minimal attestation APIs should we support for MVP? +- Benchmark targets: supply representative hardware profiles to set realistic throughput/latency goals. + +Appendix: quick dataflow +1. `StartSession` -> controller computes split plan -> preloads slices to assigned workers (encrypted transfer). +2. Client sends `Infer` -> session manager pipelines activations across workers over secure channels. +3. Workers return outputs and metrics -> session manager aggregates and returns response. + +Next steps: implement the static partitioner and minimal worker runtime (Week 1 task). \ No newline at end of file diff --git a/docs/PQC_INTEGRATION.md b/docs/PQC_INTEGRATION.md new file mode 100644 index 0000000..dc8c8eb --- /dev/null +++ b/docs/PQC_INTEGRATION.md @@ -0,0 +1,45 @@ +liboqs (pyOQS) integration notes + +Goal: Replace the placeholder X25519-only `PQCAdapter` with a hybrid KEM based on liboqs (e.g., Kyber) + X25519. + +High level steps: + +1. Install native liboqs and Python bindings (pyOQS). + - On Ubuntu (example): + ```bash + sudo apt-get update + sudo apt-get install -y build-essential cmake libssl-dev pkg-config + # Build and install liboqs from source (follow liboqs README) + git clone --branch main https://github.com/open-quantum-safe/liboqs.git + cd liboqs + mkdir build && cd build + cmake -DCMAKE_INSTALL_PREFIX=/usr/local .. + make -j$(nproc) + sudo make install + + # Install the Python bindings that import as `oqs` + pip install liboqs-python + ``` + - Alternatively use your distribution's packages or a prepared devcontainer that installs liboqs. + - Set `OQS_INSTALL_PATH=/usr/local` when using a local source install so the binding can find the shared library. + +2. Update `prototype/crypto.py` to perform a proper KEM exchange during handshake: + - Controller: send X25519 pub + OQS pub to worker. + - Worker: encapsulate to controller's OQS pub -> return encapsulation ciphertext + worker OQS pub. + - Controller: decapsulate ciphertext to obtain OQS shared secret. + - Final symmetric AEAD key = HKDF(X25519_shared || OQS_shared) + - The current binding in this workspace exposes `oqs.KeyEncapsulation`, `generate_keypair()`, `encap_secret()`, and `decap_secret()`. + +3. Tests & validation: + - Run `pytest -q prototype/test_oqs_hybrid.py prototype/test_secure_hybrid_integration.py prototype/test_concurrency_smoke.py`. + - Use `prototype/test_secure_run.py` as a quick smoke script when you want a single-session end-to-end check. + - Ensure the worker `/handshake` returns `worker_oqs_pub_b64` and `worker_pub_b64` when liboqs is available. + +Notes: +- The repository already contains scaffolding in `prototype/crypto.py` to detect pyOQS at runtime and expose `get_oqs_public()`; complete integration requires invoking `kem.encapsulate()` and `kem.decapsulate()` where appropriate. +- Building liboqs on CI requires adding native build steps in the pipeline; consider a GitHub Actions matrix job with a prebuilt liboqs artifact or using a self-hosted runner. +- The CI workflow includes a manual `workflow_dispatch` trigger that can build liboqs from source when `build_liboqs` is enabled. + +If you want, I can: +- Implement the full handshake KEM flow (controller encapsulate/decapsulate and worker encapsulate) once you confirm installing `pyOQS` in the devcontainer/CI is acceptable, or +- Prepare a PR that adds devcontainer Dockerfile steps to install liboqs so we can run the full integration here. diff --git a/docs/SCOPE.md b/docs/SCOPE.md new file mode 100644 index 0000000..cbfbfd0 --- /dev/null +++ b/docs/SCOPE.md @@ -0,0 +1,18 @@ +Scope & Success Criteria + +Target: platform and infrastructure engineers, MLOps teams, and edge fleet operators who need production-grade inference beyond single-node setups. + +MVP capabilities: +- Multi-device layer splitting: demonstrate partitioning a medium-sized transformer across GPU and CPU with deterministic correctness and end-to-end inference. +- Secure edge offload: implement PQC-based encryption and integrity checks for offloaded model slices and communications. +- High-concurrency session management: support 1k+ concurrent lightweight sessions with per-session QoS and adaptive batching. + +Success metrics: +- Correctness: identical outputs (within numerical tolerance) compared to single-node baseline for partitioned runs. +- Performance: 2× throughput improvement for target hardware when split across devices (measured on prototype hardware), and median p95 latency within target SLA for 95% of sessions. +- Security: PQC handshake and slice integrity checks complete within acceptable overhead (<20% added latency in offload path) and keys/telemetry never expose raw weights. + +Out of scope for MVP: +- Full production orchestration (K8s operators) and UI consoles — focus is on core engine, APIs, and integrations. + +Next: architecture spec covering layer-splitting algorithm, PQC keyflows, and session manager APIs. \ No newline at end of file diff --git a/prototype/README_PROTOTYPE.md b/prototype/README_PROTOTYPE.md new file mode 100644 index 0000000..fc17383 --- /dev/null +++ b/prototype/README_PROTOTYPE.md @@ -0,0 +1,31 @@ +Prototype demo + +This prototype demonstrates a minimal multi-device layer-splitting demo using a toy model. It simulates two workers (FastAPI) that accept slice preload and execution. + +Quickstart: + +1. Install dependencies: + +```bash +python -m pip install -r prototype/requirements.txt +``` + +2. Start two workers in separate terminals (secure worker available): + +```bash +# insecure worker (no encryption) +python prototype/worker.py --port 8001 +# secure worker (handshake + AEAD) listens on a separate port +python prototype/worker_secure.py --port 8003 +``` + +3. Run the demo: + +```bash +python prototype/run_demo.py +``` + +Notes: +- This is a functional prototype illustrating partitioning, preload, and remote execution. It uses pickle-serialized weights and inputs for simplicity. +- A secure path using X25519 + optional liboqs hybrid KEM is scaffolded in [prototype/crypto.py](prototype/crypto.py) and [prototype/worker_secure.py](prototype/worker_secure.py). To enable full hybrid PQC tests, install native liboqs plus the Python binding and set `OQS_INSTALL_PATH=/usr/local` (see [docs/PQC_INTEGRATION.md](docs/PQC_INTEGRATION.md)). +- The in-process integration tests can be run with `pytest -q prototype/test_secure_hybrid_integration.py prototype/test_concurrency_smoke.py` once the environment is prepared. diff --git a/prototype/__init__.py b/prototype/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/prototype/controller.py b/prototype/controller.py new file mode 100644 index 0000000..83de5fa --- /dev/null +++ b/prototype/controller.py @@ -0,0 +1,46 @@ +import requests +import base64 +import pickle +from prototype.model_tools import ToyModel + +class Controller: + def __init__(self, workers): + # workers: list of urls + self.workers = workers + + def partition_model(self, model: ToyModel, num_slices=2): + L = len(model.weights) + per = max(1, L // num_slices) + slices = [] + for i in range(0, L, per): + start = i + end = min(L, i+per) + sub = model.slice(start, end) + slices.append((start, end, sub)) + return slices + + def preload_slices(self, slices): + # round-robin assign to workers + assigned = [] + for i, (start,end,sub) in enumerate(slices): + w = self.workers[i % len(self.workers)] + blob = sub.serialize() + b64 = base64.b64encode(blob).decode('ascii') + manifest = {"start": start, "end": end} + payload = {"slice_id": f"slice_{start}_{end}", "manifest": manifest, "weights_b64": b64} + r = requests.post(f"{w}/preload", json=payload, timeout=10) + r.raise_for_status() + assigned.append((payload['slice_id'], w)) + return assigned + + def run_distributed(self, assigned, x_blob): + # assigned: list of (slice_id, worker_url) in order + current = x_blob + for slice_id, w in assigned: + b64 = base64.b64encode(current).decode('ascii') + payload = {"slice_id": slice_id, "input_b64": b64} + r = requests.post(f"{w}/execute", json=payload, timeout=30) + r.raise_for_status() + out_b64 = r.json()['output_b64'] + current = base64.b64decode(out_b64) + return current diff --git a/prototype/controller_secure.py b/prototype/controller_secure.py new file mode 100644 index 0000000..ae46681 --- /dev/null +++ b/prototype/controller_secure.py @@ -0,0 +1,159 @@ +import requests +import base64 +import pickle +from prototype.model_tools import ToyModel +from prototype.crypto import PQCAdapter, AEAD, b64, ub64 +import threading +import time +import random + +class SecureController: + def __init__(self, workers): + self.workers = workers + self.keys = {} # worker_url -> AEAD + self.kems = {} # worker_url -> PQCAdapter (ephemeral keypair reused per worker) + # initialize per-worker locks to avoid races during handshake + self.kem_locks = {w: threading.Lock() for w in workers} + # attempt initial handshake with all workers to establish AEAD keys + for w in workers: + try: + self.handshake_with_worker(w) + except Exception: + # don't fail construction; handshake will be attempted lazily + pass + + def partition_model(self, model: ToyModel, num_slices=2): + L = len(model.weights) + per = max(1, L // num_slices) + slices = [] + for i in range(0, L, per): + start = i + end = min(L, i+per) + sub = model.slice(start, end) + slices.append((start, end, sub)) + return slices + + def handshake_with_worker(self, worker_url): + # ensure only one handshake happens concurrently per worker + if worker_url not in self.kem_locks: + self.kem_locks[worker_url] = threading.Lock() + lock = self.kem_locks[worker_url] + with lock: + # reuse or create KEM per worker to keep a stable shared key + if worker_url in self.kems: + kem = self.kems[worker_url] + else: + kem = PQCAdapter() + self.kems[worker_url] = kem + client_pub = kem.public_bytes() + # include optional OQS public bytes if available (scaffolding) + payload = {"client_pub_b64": b64(client_pub), "client_id": "controller"} + try: + oqs_pub = kem.get_oqs_public() + if oqs_pub: + payload["oqs_pub_b64"] = b64(oqs_pub) + except Exception: + pass + r = requests.post(f"{worker_url}/handshake", json=payload, timeout=5) + r.raise_for_status() + j = r.json() + worker_pub_b64 = j['worker_pub_b64'] + # optional worker OQS pub and encapsulation ct for hybrid KEM + worker_oqs_b64 = j.get('worker_oqs_pub_b64') + worker_oqs_ct_b64 = j.get('worker_oqs_ct_b64') + worker_pub = ub64(worker_pub_b64) + x25519_shared = kem.derive_shared(worker_pub) + # if worker returned an OQS encapsulation, decapsulate and derive hybrid key + final_key = None + if worker_oqs_ct_b64 and kem.oqs_supported: + try: + ct = ub64(worker_oqs_ct_b64) + oqs_shared = kem.decap(ct) + from prototype.crypto import derive_hybrid_key + final_key = derive_hybrid_key(x25519_shared, oqs_shared) + except Exception: + final_key = x25519_shared + else: + final_key = x25519_shared + self.keys[worker_url] = AEAD(final_key) + return True + + def preload_slices(self, slices, encrypt=False): + assigned = [] + for i, (start,end,sub) in enumerate(slices): + w = self.workers[i % len(self.workers)] + blob = sub.serialize() + manifest = {"start": start, "end": end} + slice_id = f"slice_{start}_{end}" + if encrypt: + if w not in self.keys: + self.handshake_with_worker(w) + aead = self.keys[w] + nonce, ct = aead.encrypt(blob) + payload = {"slice_id": slice_id, "manifest": manifest, "encrypted": True, + "weights_b64": b64(ct), "nonce_b64": b64(nonce)} + else: + payload = {"slice_id": slice_id, "manifest": manifest, "weights_b64": b64(blob)} + # retry with exponential backoff for transient failures + max_attempts = 3 + backoff_base = 0.05 + for attempt in range(1, max_attempts+1): + try: + r = requests.post(f"{w}/preload", json=payload, timeout=10) + r.raise_for_status() + break + except Exception as e: + if attempt == max_attempts: + raise + sleep_t = backoff_base * (2 ** (attempt - 1)) + random.uniform(0, backoff_base) + time.sleep(sleep_t) + assigned.append((slice_id, w)) + return assigned + + def run_distributed(self, assigned, x_blob, encrypt=False): + current = x_blob + for slice_id, w in assigned: + if encrypt: + aead = self.keys[w] + nonce, ct = aead.encrypt(current) + payload = {"slice_id": slice_id, "encrypted": True, "input_b64": b64(ct), "nonce_b64": b64(nonce)} + # retry execute with backoff for transient errors + max_attempts = 3 + backoff_base = 0.05 + for attempt in range(1, max_attempts+1): + try: + r = requests.post(f"{w}/execute", json=payload, timeout=30) + r.raise_for_status() + break + except Exception: + if attempt == max_attempts: + raise + sleep_t = backoff_base * (2 ** (attempt - 1)) + time.sleep(sleep_t) + else: + payload = {"slice_id": slice_id, "input_b64": base64.b64encode(current).decode('ascii')} + # non-encrypted execute also gets retries + max_attempts = 3 + backoff_base = 0.05 + for attempt in range(1, max_attempts+1): + try: + r = requests.post(f"{w}/execute", json=payload, timeout=30) + r.raise_for_status() + break + except Exception: + if attempt == max_attempts: + raise + sleep_t = backoff_base * (2 ** (attempt - 1)) + time.sleep(sleep_t) + r.raise_for_status() + j = r.json() + if j.get('encrypted'): + aead = self.keys[w] + nonce = ub64(j['nonce_b64']) + ct = ub64(j['output_b64']) + out = aead.decrypt(nonce, ct) + current = out + else: + out_b64 = j['output_b64'] + current = base64.b64decode(out_b64) + return current diff --git a/prototype/crypto.py b/prototype/crypto.py new file mode 100644 index 0000000..cb7a0ae --- /dev/null +++ b/prototype/crypto.py @@ -0,0 +1,167 @@ +from cryptography.hazmat.primitives.asymmetric import x25519 +from cryptography.hazmat.primitives.ciphers.aead import ChaCha20Poly1305 +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.kdf.hkdf import HKDF +import os +import base64 +from cryptography.hazmat.primitives import serialization + +# Try to detect liboqs / pyOQS availability. If present, we'll expose +# scaffolding for a PQC KEM; if not present, we gracefully fall back +# to the existing X25519-only DH flow. +OQS_AVAILABLE = False +_oqs = None +try: + import oqs as _oqs # type: ignore + OQS_AVAILABLE = True +except Exception: + OQS_AVAILABLE = False + + +class PQCAdapter: + """Hybrid PQC adapter scaffold. + + Current behaviour: + - Always performs an X25519 DH exchange to produce a shared secret. + - If liboqs/pyOQS is present on both peers, this class exposes + additional public bytes fields so a real KEM exchange can be + implemented later without changing the outer handshake shape. + + Note: proper KEM encapsulate/decapsulate requires extra ciphertext + to be exchanged. This file adds scaffolding so that future work can + implement a full liboqs hybrid KEM without large protocol changes. + """ + + def __init__(self, oqs_alg: str = 'Kyber512'): + self._priv = x25519.X25519PrivateKey.generate() + self.pub = self._priv.public_key() + self.oqs_supported = False + self.oqs_alg = oqs_alg + self.oqs_public = b'' + if OQS_AVAILABLE: + try: + kem_cls = getattr(_oqs, 'KeyEncapsulation', None) + if kem_cls is None: + kem_cls = getattr(_oqs, 'KEM', None) + if kem_cls is None: + raise RuntimeError('No OQS KEM class available') + self.kem = kem_cls(self.oqs_alg) + pub = self.kem.generate_keypair() + if isinstance(pub, tuple): + pub = pub[0] + self.oqs_public = pub + self.oqs_supported = True + except Exception: + self.kem = None + self.oqs_public = b'' + self.oqs_supported = False + + def public_bytes(self) -> bytes: + """Return the X25519 public bytes. For forward-compatibility we + also expose an optional OQS public blob via `get_oqs_public()`. + The current handshake uses only the X25519 bytes for key + derivation; OQS support is scaffolding for later hybrid KEM + steps. + """ + return self.pub.public_bytes( + encoding=serialization.Encoding.Raw, + format=serialization.PublicFormat.Raw, + ) + + def get_oqs_public(self) -> bytes: + return self.oqs_public + + def derive_shared(self, peer_public_bytes: bytes) -> bytes: + """Derive a symmetric AEAD key. Currently this uses X25519 DH + only (keeps existing behaviour). When OQS hybrid KEM is fully + implemented, concat/PRF of both secrets should be used here. + """ + peer_pub = x25519.X25519PublicKey.from_public_bytes(peer_public_bytes) + shared = self._priv.exchange(peer_pub) + # derive AEAD key from shared secret + hkdf = HKDF( + algorithm=hashes.SHA256(), + length=32, + salt=None, + info=b'mohawk-aead-key', + ) + key = hkdf.derive(shared) + return key + + # OQS helper wrappers: encapsulate/decapsulate when available + def encap(self, peer_oqs_pub: bytes): + """Encapsulate to `peer_oqs_pub` using the pyOQS KEM if available. + Returns (ct, shared) or raises RuntimeError if not supported. + """ + if not self.oqs_supported or not getattr(self, 'kem', None): + raise RuntimeError('OQS not available') + # Try common pyOQS method names defensively + try: + # pyOQS KeyEncapsulation API: kem.encap_secret(pub) or kem.encapsulate(pub) + if hasattr(self.kem, 'encap_secret'): + ct, ss = self.kem.encap_secret(peer_oqs_pub) + return ct, ss + if hasattr(self.kem, 'encapsulate'): + ct, ss = self.kem.encapsulate(peer_oqs_pub) + return ct, ss + if hasattr(self.kem, 'encap'): + ct, ss = self.kem.encap(peer_oqs_pub) + return ct, ss + except Exception as e: + raise RuntimeError('OQS encapsulation failed: %s' % e) + raise RuntimeError('OQS encapsulation not supported by this pyOQS build') + + def decap(self, ct: bytes): + """Decapsulate ciphertext `ct` using stored private key. Returns shared secret.""" + if not self.oqs_supported or not getattr(self, 'kem', None): + raise RuntimeError('OQS not available') + try: + if hasattr(self.kem, 'decap_secret'): + ss = self.kem.decap_secret(ct) + return ss + if hasattr(self.kem, 'decapsulate'): + ss = self.kem.decapsulate(ct) + return ss + if hasattr(self.kem, 'decap'): + ss = self.kem.decap(ct) + return ss + except Exception as e: + raise RuntimeError('OQS decapsulation failed: %s' % e) + raise RuntimeError('OQS decapsulation not supported by this pyOQS build') + + +def derive_hybrid_key(shared_x25519: bytes, shared_oqs: bytes) -> bytes: + """Derive a single AEAD key from two raw shared secrets (concatenate + and run HKDF). This produces a 32-byte AEAD key. + """ + combined = (shared_x25519 or b'') + (shared_oqs or b'') + hkdf = HKDF( + algorithm=hashes.SHA256(), + length=32, + salt=None, + info=b'mohawk-hybrid-aead-key', + ) + return hkdf.derive(combined) + + +class AEAD: + def __init__(self, key: bytes): + self.key = key + self.aead = ChaCha20Poly1305(key) + + def encrypt(self, plaintext: bytes, aad: bytes = b''): + nonce = os.urandom(12) + ct = self.aead.encrypt(nonce, plaintext, aad) + return nonce, ct + + def decrypt(self, nonce: bytes, ciphertext: bytes, aad: bytes = b''): + return self.aead.decrypt(nonce, ciphertext, aad) + + +# helpers +def b64(x: bytes) -> str: + return base64.b64encode(x).decode('ascii') + + +def ub64(s: str) -> bytes: + return base64.b64decode(s) diff --git a/prototype/integration_helpers.py b/prototype/integration_helpers.py new file mode 100644 index 0000000..14a07b6 --- /dev/null +++ b/prototype/integration_helpers.py @@ -0,0 +1,53 @@ +from urllib.parse import urlparse +import asyncio + +import requests +import httpx2 + +from prototype import worker_secure + + +def reset_worker_state() -> None: + worker_secure.slices.clear() + worker_secure.keys.clear() + with worker_secure.metrics_lock: + for key in worker_secure.metrics: + worker_secure.metrics[key] = 0 + + +def make_worker_client() -> httpx2.Client: + reset_worker_state() + transport = httpx2.ASGITransport(app=worker_secure.app) + return httpx2.Client(transport=transport, base_url="http://worker-inproc") + + +class _InProcessResponse: + def __init__(self, response, url: str): + self._response = response + self.status_code = response.status_code + self.text = response.text + self.url = url + + def json(self): + return self._response.json() + + def raise_for_status(self): + if self.status_code >= 400: + raise requests.HTTPError( + f"{self.status_code} error for {self.url}", + response=self._response, + ) + + +class InProcessWorkerTransport: + def __init__(self, client: httpx2.Client): + self.client = client + + def post(self, url, json=None, timeout=None, **kwargs): + path = urlparse(url).path or "/" + async def _post(): + async with httpx2.AsyncClient(transport=self.client._transport, base_url=self.client.base_url) as async_client: + return await async_client.post(path, json=json) + + response = asyncio.run(_post()) + return _InProcessResponse(response, url) \ No newline at end of file diff --git a/prototype/load_harness.py b/prototype/load_harness.py new file mode 100644 index 0000000..90d1467 --- /dev/null +++ b/prototype/load_harness.py @@ -0,0 +1,64 @@ +import time +import numpy as np +from prototype.model_tools import ToyModel +from prototype.session_manager import SessionManager +from concurrent.futures import ThreadPoolExecutor, as_completed + + +def run_session_sync(sm: SessionManager, model, session_idx, encrypt=False): + sid = sm.start_session(model, num_slices=2, encrypt=encrypt) + x = np.random.default_rng(session_idx).standard_normal((8,1)).astype('float32') + out = sm.infer(sid, x) + sm.end_session(sid) + return out + + +def run_load(workers, concurrency=20, total=100, encrypt=False): + sm = SessionManager(workers) + model = ToyModel([8,16,16,8], seed=42) + results = [] + start = time.time() + with ThreadPoolExecutor(max_workers=concurrency) as ex: + futures = [ex.submit(run_session_sync, sm, model, i, encrypt) for i in range(total)] + for f in as_completed(futures): + results.append(f.result()) + end = time.time() + print(f"Completed {total} sessions in {end-start:.2f}s") + return results + + +if __name__ == '__main__': + workers = ["http://127.0.0.1:8003", "http://127.0.0.1:8003"] + import requests, json + runs = [ + {'concurrency': 50, 'total': 200}, + {'concurrency': 100, 'total': 500}, + {'concurrency': 200, 'total': 1000}, + ] + all_agg = {} + for rconf in runs: + c = rconf['concurrency'] + t = rconf['total'] + print(f"Starting run total={t} concurrency={c}") + run_load(workers, concurrency=c, total=t, encrypt=True) + # fetch metrics from workers + agg = {} + for w in set(workers): + try: + resp = requests.get(f"{w}/metrics", timeout=5) + resp.raise_for_status() + m = resp.json() + print(f"metrics from {w}: {m}") + for k, v in m.items(): + agg[k] = agg.get(k, 0) + v + except Exception as e: + print(f"failed to fetch metrics from {w}: {e}") + print(f"aggregated metrics for run {t}: {agg}") + all_agg[f"run_{t}"] = agg + # persist a copy + try: + with open(f"/tmp/metrics_run_{t}.json", 'w') as fh: + json.dump(agg, fh) + except Exception as e: + print(f"failed to write metrics file: {e}") + print(f"all runs aggregated: {all_agg}") diff --git a/prototype/model_tools.py b/prototype/model_tools.py new file mode 100644 index 0000000..c11ec87 --- /dev/null +++ b/prototype/model_tools.py @@ -0,0 +1,40 @@ +import numpy as np +import pickle + +class ToyModel: + def __init__(self, layer_sizes, seed=0): + rng = np.random.default_rng(seed) + self.weights = [] + for i in range(len(layer_sizes)-1): + w = rng.standard_normal((layer_sizes[i+1], layer_sizes[i])).astype(np.float32) + b = rng.standard_normal((layer_sizes[i+1],)).astype(np.float32) + self.weights.append((w, b)) + + def forward(self, x): + out = x + for (w,b) in self.weights: + out = w @ out + b[:, None] + out = np.tanh(out) + return out + + def slice(self, start_layer, end_layer): + # returns a new ToyModel with subset of layers + sub = ToyModel.__new__(ToyModel) + sub.weights = self.weights[start_layer:end_layer] + return sub + + def serialize(self): + return pickle.dumps(self.weights) + + @staticmethod + def deserialize(blob): + m = ToyModel.__new__(ToyModel) + m.weights = pickle.loads(blob) + return m + + def apply(self, x): + out = x + for (w,b) in self.weights: + out = w @ out + b[:, None] + out = np.tanh(out) + return out diff --git a/prototype/requirements.txt b/prototype/requirements.txt new file mode 100644 index 0000000..f107339 --- /dev/null +++ b/prototype/requirements.txt @@ -0,0 +1,7 @@ +fastapi +uvicorn +numpy +requests +cryptography +pytest +httpx2 diff --git a/prototype/run_demo.py b/prototype/run_demo.py new file mode 100644 index 0000000..309c93b --- /dev/null +++ b/prototype/run_demo.py @@ -0,0 +1,43 @@ +import numpy as np +import pickle +import base64 +import time +from prototype.model_tools import ToyModel +from prototype.controller import Controller + +# config +worker_urls = ["http://127.0.0.1:8001", "http://127.0.0.1:8002"] + +def single_node_run(model, x): + return model.forward(x) + +def distributed_run(model, x): + c = Controller(worker_urls) + slices = c.partition_model(model, num_slices=2) + assigned = c.preload_slices(slices) + x_blob = pickle.dumps(x) + out_blob = c.run_distributed(assigned, x_blob) + out = pickle.loads(out_blob) + return out + +if __name__ == '__main__': + # build model + model = ToyModel([8,16,16,8], seed=42) + x = np.random.default_rng(1).standard_normal((8,1)).astype('float32') + + print("Running single-node baseline...") + baseline = single_node_run(model, x) + + print("Running distributed demo (requires two workers at :8001 and :8002)...") + t0 = time.time() + out = distributed_run(model, x) + t1 = time.time() + print(f"Distributed run time: {t1-t0:.3f}s") + + # compare + diff = np.max(np.abs(baseline - out)) + print(f"Max abs diff vs baseline: {diff}") + if diff < 1e-5: + print("SUCCESS: outputs match within tolerance") + else: + print("WARNING: outputs differ — check serialization/ordering") diff --git a/prototype/session_manager.py b/prototype/session_manager.py new file mode 100644 index 0000000..ca93a6a --- /dev/null +++ b/prototype/session_manager.py @@ -0,0 +1,28 @@ +import uuid +import pickle +from prototype.controller_secure import SecureController + +class SessionManager: + def __init__(self, workers): + self.controller = SecureController(workers) + self.sessions = {} + + def start_session(self, model, num_slices=2, encrypt=False): + session_id = str(uuid.uuid4()) + slices = self.controller.partition_model(model, num_slices=num_slices) + assigned = self.controller.preload_slices(slices, encrypt=encrypt) + self.sessions[session_id] = {"assigned": assigned, "encrypt": encrypt} + return session_id + + def infer(self, session_id, x): + s = self.sessions[session_id] + x_blob = pickle.dumps(x) + out_blob = self.controller.run_distributed(s['assigned'], x_blob, encrypt=s['encrypt']) + out = pickle.loads(out_blob) + return out + + def end_session(self, session_id): + if session_id in self.sessions: + del self.sessions[session_id] + return True + return False diff --git a/prototype/telemetry.py b/prototype/telemetry.py new file mode 100644 index 0000000..11c10b3 --- /dev/null +++ b/prototype/telemetry.py @@ -0,0 +1,62 @@ +import time +import inspect +from functools import wraps + + +class Telemetry: + def __init__(self, metrics_dict, lock): + self.metrics = metrics_dict + self.lock = lock + # histogram bucket boundaries in seconds + self.buckets = [0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1.0, 2.0, 5.0] + + def record(self, name_sum, name_count, duration): + # record sum and count in metrics dict + with self.lock: + self.metrics[name_sum] = self.metrics.get(name_sum, 0.0) + duration + self.metrics[name_count] = self.metrics.get(name_count, 0) + 1 + # also update histogram buckets for this metric prefix + try: + base = name_sum + if base.endswith('_sum'): + base = base[:-4] + hist_prefix = f"{base}_hist" + # find the appropriate bucket + for b in self.buckets: + key = f"{hist_prefix}_{b}" + if duration <= b: + self.metrics[key] = self.metrics.get(key, 0) + 1 + break + else: + # overflow bucket + key = f"{hist_prefix}_+Inf" + self.metrics[key] = self.metrics.get(key, 0) + 1 + except Exception: + pass + + def timed(self, name_sum, name_count): + def decorator(func): + if inspect.iscoroutinefunction(func): + async def async_wrapper(*args, **kwargs): + t0 = time.time() + try: + return await func(*args, **kwargs) + finally: + dt = time.time() - t0 + self.record(name_sum, name_count, dt) + + wraps(func)(async_wrapper) + return async_wrapper + else: + def sync_wrapper(*args, **kwargs): + t0 = time.time() + try: + return func(*args, **kwargs) + finally: + dt = time.time() - t0 + self.record(name_sum, name_count, dt) + + wraps(func)(sync_wrapper) + return sync_wrapper + + return decorator diff --git a/prototype/test_concurrency_smoke.py b/prototype/test_concurrency_smoke.py new file mode 100644 index 0000000..48cdb19 --- /dev/null +++ b/prototype/test_concurrency_smoke.py @@ -0,0 +1,22 @@ +import pytest +import numpy as np + +import prototype.controller_secure as controller_secure +from prototype.integration_helpers import InProcessWorkerTransport, make_worker_client, reset_worker_state +from prototype.load_harness import run_load + + +@pytest.fixture() +def inprocess_worker(monkeypatch): + client = make_worker_client() + transport = InProcessWorkerTransport(client) + monkeypatch.setattr(controller_secure.requests, 'post', transport.post) + yield client + reset_worker_state() + + +def test_concurrency_smoke(inprocess_worker): + workers = ['http://worker-inproc'] + res = run_load(workers, concurrency=4, total=8, encrypt=True) + assert len(res) == 8 + assert all(isinstance(item, np.ndarray) for item in res) diff --git a/prototype/test_oqs_hybrid.py b/prototype/test_oqs_hybrid.py new file mode 100644 index 0000000..11564fe --- /dev/null +++ b/prototype/test_oqs_hybrid.py @@ -0,0 +1,35 @@ +import sys +import pytest +from prototype.crypto import PQCAdapter, derive_hybrid_key, AEAD, OQS_AVAILABLE + + +def test_oqs_hybrid_encap_decap(): + if not OQS_AVAILABLE: + pytest.skip('oqs module not available') + # create two adapters (controller and worker) + c = PQCAdapter() + w = PQCAdapter() + # ensure both report oqs_supported; skip if pyOQS API not present + if not getattr(c, 'oqs_supported', False) or not getattr(w, 'oqs_supported', False): + pytest.skip('pyOQS KEM API not available in this environment') + # exchange oqs public keys + c_pub = c.get_oqs_public() + w_pub = w.get_oqs_public() + # controller encapsulates to worker's pub + ct, ss_c = c.encap(w_pub) + # worker decapsulates + ss_w = w.decap(ct) + assert ss_c == ss_w + # also derive X25519 shared + x_c = c.derive_shared(w.public_bytes()) + x_w = w.derive_shared(c.public_bytes()) + assert x_c == x_w + # derive hybrid key and verify AEAD + hybrid_c = derive_hybrid_key(x_c, ss_c) + hybrid_w = derive_hybrid_key(x_w, ss_w) + assert hybrid_c == hybrid_w + aead = AEAD(hybrid_c) + plaintext = b'test message' + nonce, ct = aead.encrypt(plaintext) + out = aead.decrypt(nonce, ct) + assert out == plaintext diff --git a/prototype/test_secure_hybrid_integration.py b/prototype/test_secure_hybrid_integration.py new file mode 100644 index 0000000..7e8f0e5 --- /dev/null +++ b/prototype/test_secure_hybrid_integration.py @@ -0,0 +1,38 @@ +import numpy as np +import pytest + +from prototype.crypto import OQS_AVAILABLE, PQCAdapter +from prototype.integration_helpers import InProcessWorkerTransport, make_worker_client, reset_worker_state +from prototype.model_tools import ToyModel +from prototype.session_manager import SessionManager +import prototype.controller_secure as controller_secure + + +def _hybrid_supported() -> bool: + return OQS_AVAILABLE and PQCAdapter().oqs_supported + + +@pytest.fixture() +def inprocess_worker(monkeypatch): + client = make_worker_client() + transport = InProcessWorkerTransport(client) + monkeypatch.setattr(controller_secure.requests, 'post', transport.post) + yield client + reset_worker_state() + + +def test_secure_hybrid_roundtrip_inprocess(inprocess_worker): + if not _hybrid_supported(): + pytest.skip('pyOQS hybrid KEM not available in this environment') + + workers = ['http://worker-inproc'] + sm = SessionManager(workers) + model = ToyModel([8, 16, 16, 8], seed=42) + x = np.random.default_rng(7).standard_normal((8, 1)).astype('float32') + + sid = sm.start_session(model, num_slices=2, encrypt=True) + out = sm.infer(sid, x) + sm.end_session(sid) + + baseline = model.forward(x) + assert np.allclose(out, baseline) diff --git a/prototype/test_secure_run.py b/prototype/test_secure_run.py new file mode 100644 index 0000000..5d8b303 --- /dev/null +++ b/prototype/test_secure_run.py @@ -0,0 +1,31 @@ +import numpy as np +import pytest + +import prototype.controller_secure as controller_secure +from prototype.integration_helpers import InProcessWorkerTransport, make_worker_client, reset_worker_state +from prototype.model_tools import ToyModel +from prototype.session_manager import SessionManager + + +@pytest.fixture() +def inprocess_worker(monkeypatch): + client = make_worker_client() + transport = InProcessWorkerTransport(client) + monkeypatch.setattr(controller_secure.requests, 'post', transport.post) + yield client + reset_worker_state() + + +def test_secure_run_roundtrip_inprocess(inprocess_worker): + workers = ['http://worker-inproc'] + sm = SessionManager(workers) + model = ToyModel([8, 16, 16, 8], seed=42) + + x = np.random.default_rng(1).standard_normal((8, 1)).astype('float32') + baseline = model.forward(x) + + sid = sm.start_session(model, num_slices=2, encrypt=True) + out = sm.infer(sid, x) + sm.end_session(sid) + + assert np.allclose(out, baseline) diff --git a/prototype/worker.py b/prototype/worker.py new file mode 100644 index 0000000..970949f --- /dev/null +++ b/prototype/worker.py @@ -0,0 +1,49 @@ +from fastapi import FastAPI, UploadFile, File, HTTPException +from pydantic import BaseModel +import uvicorn +import asyncio +import pickle +from typing import Dict +from prototype.model_tools import ToyModel +import base64 + +app = FastAPI() + +slices: Dict[str, ToyModel] = {} + +class PreloadRequest(BaseModel): + slice_id: str + manifest: dict + weights_b64: str + +class ExecRequest(BaseModel): + slice_id: str + input_b64: str + +@app.post("/preload") +async def preload(req: PreloadRequest): + try: + blob = base64.b64decode(req.weights_b64) + m = ToyModel.deserialize(blob) + slices[req.slice_id] = m + return {"status": "ok", "slice_id": req.slice_id} + except Exception as e: + raise HTTPException(status_code=400, detail=str(e)) + +@app.post("/execute") +async def execute(req: ExecRequest): + if req.slice_id not in slices: + raise HTTPException(status_code=404, detail="slice not found") + blob = base64.b64decode(req.input_b64) + x = pickle.loads(blob) + out = slices[req.slice_id].apply(x) + out_blob = pickle.dumps(out) + return {"output_b64": base64.b64encode(out_blob).decode('ascii')} + +if __name__ == '__main__': + import argparse + p = argparse.ArgumentParser() + p.add_argument('--host', default='127.0.0.1') + p.add_argument('--port', type=int, default=8000) + args = p.parse_args() + uvicorn.run(app, host=args.host, port=args.port) diff --git a/prototype/worker_secure.py b/prototype/worker_secure.py new file mode 100644 index 0000000..cc504c2 --- /dev/null +++ b/prototype/worker_secure.py @@ -0,0 +1,222 @@ +from fastapi import FastAPI, UploadFile, File, HTTPException +from pydantic import BaseModel +import uvicorn +import asyncio +import pickle +from typing import Dict +from prototype.model_tools import ToyModel +import base64 +from prototype.crypto import PQCAdapter, AEAD, b64, ub64 +from prototype.telemetry import Telemetry +import traceback +import threading +from fastapi.responses import JSONResponse + +app = FastAPI() + +slices: Dict[str, ToyModel] = {} +keys: Dict[str, AEAD] = {} # peer_pub_b64 -> AEAD + +# simple in-memory metrics +metrics = { + 'handshakes': 0, + 'preload_success': 0, + 'preload_fail': 0, + 'execute_success': 0, + 'execute_fail': 0, +} +metrics_lock = threading.Lock() +telemetry = Telemetry(metrics, metrics_lock) + +class HandshakeRequest(BaseModel): + client_pub_b64: str + client_id: str | None = None + oqs_pub_b64: str | None = None + +class PreloadRequest(BaseModel): + slice_id: str + manifest: dict + weights_b64: str + encrypted: bool = False + nonce_b64: str = None + +class ExecRequest(BaseModel): + slice_id: str + input_b64: str + encrypted: bool = False + nonce_b64: str = None + +@app.post("/handshake") +async def handshake(req: HandshakeRequest): + client_pub = ub64(req.client_pub_b64) + client_id = req.client_id or 'controller' + kem = PQCAdapter() + worker_pub = kem.public_bytes() + # if the controller provided an OQS pub, attempt hybrid KEM + controller_oqs_pub = None + shared_oqs = None + if getattr(req, 'oqs_pub_b64', None): + try: + controller_oqs_pub = ub64(req.oqs_pub_b64) + except Exception: + controller_oqs_pub = None + # always derive X25519 shared + x25519_key = kem.derive_shared(client_pub) + # if both sides support OQS, encapsulate to controller's OQS pub and + # derive a hybrid AEAD key + if controller_oqs_pub and kem.oqs_supported: + try: + ct, shared_oqs = kem.encap(controller_oqs_pub) + except Exception: + ct = None + shared_oqs = None + else: + ct = None + shared_oqs = None + # final AEAD key: hybrid if we have an OQS shared secret, else X25519-only + try: + from prototype.crypto import derive_hybrid_key + if shared_oqs: + final_key = derive_hybrid_key(x25519_key, shared_oqs) + else: + final_key = x25519_key + except Exception: + final_key = x25519_key + # store AEAD keyed by client id for stable lookup + keys[client_id] = AEAD(final_key) + with metrics_lock: + metrics['handshakes'] += 1 + # include worker-side OQS public bytes and encapsulation ct if available + resp = {"worker_pub_b64": b64(worker_pub)} + try: + oqs_pub = kem.get_oqs_public() + if oqs_pub: + resp['worker_oqs_pub_b64'] = b64(oqs_pub) + except Exception: + pass + if ct: + try: + resp['worker_oqs_ct_b64'] = b64(ct) + except Exception: + pass + return resp + +@app.post("/preload") +@telemetry.timed('preload_time_sum', 'preload_time_count') +async def preload(req: PreloadRequest): + try: + if req.encrypted: + # find AEAD by matching any key (simple demo: single client) + # use controller client id mapping + if 'controller' not in keys: + raise HTTPException(status_code=400, detail='no handshake for controller') + aead = keys['controller'] + nonce = ub64(req.nonce_b64) + ct = ub64(req.weights_b64) + blob = aead.decrypt(nonce, ct) + else: + blob = base64.b64decode(req.weights_b64) + m = ToyModel.deserialize(blob) + slices[req.slice_id] = m + with metrics_lock: + metrics['preload_success'] += 1 + return {"status": "ok", "slice_id": req.slice_id} + except Exception as e: + tb = traceback.format_exc() + print("preload error:\n", tb) + with metrics_lock: + metrics['preload_fail'] += 1 + raise HTTPException(status_code=400, detail=str(e)) + +@app.post("/execute") +@telemetry.timed('execute_time_sum', 'execute_time_count') +async def execute(req: ExecRequest): + if req.slice_id not in slices: + raise HTTPException(status_code=404, detail="slice not found") + try: + if req.encrypted: + if 'controller' not in keys: + raise HTTPException(status_code=400, detail='no handshake for controller') + aead = keys['controller'] + nonce = ub64(req.nonce_b64) + ct = ub64(req.input_b64) + blob = aead.decrypt(nonce, ct) + else: + blob = base64.b64decode(req.input_b64) + x = pickle.loads(blob) + out = slices[req.slice_id].apply(x) + out_blob = pickle.dumps(out) + # maybe encrypt response if request was encrypted + if req.encrypted: + nonce, ct = aead.encrypt(out_blob) + with metrics_lock: + metrics['execute_success'] += 1 + return {"encrypted": True, "nonce_b64": b64(nonce), "output_b64": b64(ct)} + else: + with metrics_lock: + metrics['execute_success'] += 1 + return {"output_b64": base64.b64encode(out_blob).decode('ascii')} + except Exception as e: + tb = traceback.format_exc() + print("execute error:\n", tb) + with metrics_lock: + metrics['execute_fail'] += 1 + raise HTTPException(status_code=400, detail=str(e)) + + +@app.get('/metrics') +async def get_metrics(): + # expose computed percentiles based on histogram buckets + with metrics_lock: + out = dict(metrics) + # compute percentiles if histogram buckets present + def compute_percentiles(prefix): + # build sorted buckets from metrics keys + hist_keys = [k for k in out.keys() if k.startswith(f"{prefix}_hist_")] + if not hist_keys: + return None + # extract bucket values and counts + buckets = [] + for k in hist_keys: + b = k.split('_')[-1] + cnt = out.get(k, 0) + try: + if b == '+Inf': + val = float('inf') + else: + val = float(b) + buckets.append((val, cnt)) + except Exception: + continue + buckets.sort(key=lambda x: x[0]) + total = sum(c for _, c in buckets) + if total == 0: + return None + # cumulative to find percentile + def percentile(p): + target = total * p + c = 0 + for val, cnt in buckets: + c += cnt + if c >= target: + return val + return buckets[-1][0] + + return {'p50': percentile(0.5), 'p95': percentile(0.95), 'p99': percentile(0.99)} + + # try common metric prefixes + for metric_prefix in ['preload_time', 'execute_time']: + ps = compute_percentiles(metric_prefix) + if ps: + out[f"{metric_prefix}_p50"] = ps['p50'] + out[f"{metric_prefix}_p95"] = ps['p95'] + out[f"{metric_prefix}_p99"] = ps['p99'] + return JSONResponse(content=out) + +if __name__ == '__main__': + import argparse + p = argparse.ArgumentParser() + p.add_argument('--host', default='127.0.0.1') + p.add_argument('--port', type=int, default=8000) + args = p.parse_args() + uvicorn.run(app, host=args.host, port=args.port)