Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion README.new.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,30 @@ pip install -e .

```python
from prove_shared import MongoDBHandler, AsyncAuth, Status
from prove_shared.mongo_handler import requestItemProcessing
from prove_shared.database.mongo import requestItemProcessing
```

### Package layout

```text
prove-shared/
pyproject.toml
config.yaml
src/
prove_shared/
__init__.py
auth.py
file_utils.py
logger.py
objects.py
queue_manager.py
wikidata_utils.py
database/
__init__.py
interface.py # DataStore (ABC) — the contract
mongo.py # MongoDBHandler implementation
postgres.py # PostgreSQLHandler stub
orchestrator.py # get_database() + DatabaseOrchestrator
```

## Setup Instructions
Expand Down
27 changes: 27 additions & 0 deletions prove-api/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,30 @@ evidence_selection:
n_top_sentences: 5
score_threshold: 0
token_size: 512

# ----------------------------------------------------------------------------
# Database backend selection (read by prove_shared.database.get_database)
# ----------------------------------------------------------------------------
# primary: "mongo" | "postgres" — the DB that owns the data
# fallback: "none" | "mongo" | "postgres" — secondary backend (optional)
# mode: "single" | "dual-write" — "dual-write" mirrors every write
# to the fallback during migration
# auto_fallback_on_read: if true, failed reads on primary
# retry against the fallback (off
# by default — silent fallbacks
# hide real outages)
# ----------------------------------------------------------------------------
database:
primary: mongo
fallback: none
mode: single
auto_fallback_on_read: false

mongo:
connection_string: "mongodb://localhost:27017/"
max_retries: 3

# TODO: populate when the Postgres migration begins.
postgres:
dsn: "postgresql://localhost/prove"
max_retries: 3
124 changes: 54 additions & 70 deletions prove-api/custom_decorators.py
Original file line number Diff line number Diff line change
@@ -1,62 +1,44 @@
# @repo: api
# @description: Flask decorators for request logging (@log_request) and API key authentication (@api_required); includes StatsDBHandler for usage tracking
# @description: Flask decorators for request logging (@log_request) and API key authentication (@api_required). Usage-logging now routes through MongoDBHandler.log_usage() — no more ad-hoc StatsDBHandler subclass.
from datetime import datetime, timezone
from base64 import b64encode, b64decode
from base64 import b64decode
from functools import wraps
from flask import request
import threading
import time
from typing import Any, Union

from pymongo import MongoClient

try:
from utils_api import get_ip_location, logger
from local_secrets import SOURCE, API_KEY, PRIVATE_KEY
except ImportError:
from api.utils_api import get_ip_location, logger
from api.local_secrets import SOURCE, API_KEY, PRIVATE_KEY

from prove_shared.mongo_handler import MongoDBHandler
from prove_shared.database import get_database
from prove_shared.auth import AsyncAuth


class StatsDBHandler(MongoDBHandler):
Comment thread
thedeepaksengar marked this conversation as resolved.
def __init__(self, connection_string="mongodb://localhost:27017/", max_retries=3):
super().__init__(connection_string, max_retries)

def connect(self, max_retries: int, connection_string: str):
for attempt in range(self.max_retries):
try:
self.client = MongoClient(self.connection_string)
self.client.server_info()
self.db = self.client['service_usage']
self.usage_collection = self.db['usage']
print("Successfully connected to StatsDB")
return True
except Exception as e:
print(f"StatsDB connection attempt {attempt + 1} failed: {e}")
if attempt == self.max_retries - 1:
raise ConnectionError("Failed to connect to StatsDB") from e
time.sleep(5) # Wait before retry

def close(self):
"""Closes the MongoDB connection."""
if self.client:
self.client.close()
print("MongoDB connection closed")

def __enter__(self):
"""Enables use with 'with' statement."""
self.connect()
return self # Allows access to the instance in 'with' block

def __exit__(self, exc_type, exc_value, traceback):
"""Ensures the connection is closed when exiting 'with' block."""
self.close()
# ---------------------------------------------------------------------------
# Shared database handle
# ---------------------------------------------------------------------------
# One backend instance per process is enough. `get_database()` reads the
# app's config.yaml and returns whichever implementation is configured
# (Mongo today, Postgres later, or an orchestrator that writes to both
# during migration). The per-request `with StatsDBHandler()` pattern used
# previously paid a connect cost on every HTTP hit — this avoids that.
_db = get_database()


def log_request(func):
"""
Fire-and-forget usage logger for any API route.

Writes the request metadata to the production usage DB on a background
thread so the actual response latency is unaffected. Logging failures are
swallowed inside the handler — a usage-log hiccup must never surface to
the caller as a 500.
"""
@wraps(func)
def wrapper(*args, **kwargs):
method = request.method
Expand All @@ -77,11 +59,12 @@ def wrapper(*args, **kwargs):
end_time = time.monotonic()
elapsed_time = end_time - start_time

# Only log in the production environment (SOURCE is set per-deploy).
if SOURCE != 'server':
return response

threading.Thread(
target=log_usage_information,
target=_log_usage_information,
args=(timestamp, method, url, headers, body, elapsed_time),
daemon=True
).start()
Expand All @@ -90,43 +73,46 @@ def wrapper(*args, **kwargs):
return wrapper


def log_usage_information(
def _log_usage_information(
timestamp: str,
method: str,
url: str,
headers: dict[str, Any],
body: dict[str, Any],
elapsed_time: float
elapsed_time: float,
) -> None:
try:
with StatsDBHandler() as db:
ip = headers.pop("X-Real-Ip", None)
headers.pop("X-Forwarded-For", None)

if ip:
try:
headers["location"] = get_ip_location(ip)
except KeyError:
headers["X-Real-Ip"] = ip
logger.error(f"when retrieving location for {ip}")
except ConnectionError:
headers["X-Real-Ip"] = ip
logger.error("failed to retrieve location, check API")

db.usage_collection.insert_one({
"method": method,
"url": url,
"headers": headers,
"body": body,
"timestamp": timestamp,
"execution_time": elapsed_time
})
except ConnectionError as e:
print(f"Failed to log usage information from StatsDB: {e}")
return
"""
Build a usage record and hand it to the database handler.

The IP-geolocation enrichment can raise (KeyError on unknown IPs,
ConnectionError if the geo API is down). We handle those here so the
record is still written without location data rather than being dropped.
"""
ip = headers.pop("X-Real-Ip", None)
headers.pop("X-Forwarded-For", None)

if ip:
try:
headers["location"] = get_ip_location(ip)
except KeyError:
headers["X-Real-Ip"] = ip
logger.error(f"when retrieving location for {ip}")
except ConnectionError:
headers["X-Real-Ip"] = ip
logger.error("failed to retrieve location, check API")

_db.log_usage({
"method": method,
"url": url,
"headers": headers,
"body": body,
"timestamp": timestamp,
"execution_time": elapsed_time,
})


def api_required(func):
"""Reject requests that don't carry a valid AsyncAuth-signed API key."""
@wraps(func)
def decorator(*args, **kwargs):
if not request.json:
Expand All @@ -136,7 +122,5 @@ def decorator(*args, **kwargs):
api_key = b64decode(api_key)
if api_key is None or not AsyncAuth.is_valid(api_key):
return {"message": "Please provide a valid API key."}, 403
else:
return func(*args, **kwargs)
return func(*args, **kwargs)
return decorator

Loading