Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 9 additions & 46 deletions src/runpod_flash/cli/commands/login.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import asyncio
import datetime as dt
from typing import Optional

import typer
from rich.console import Console
Expand All @@ -15,20 +13,8 @@

console = Console()

POLL_INTERVAL_SECONDS = 2.0
DEFAULT_TIMEOUT_SECONDS = 600.0


def _parse_expires_at(value: Optional[str]) -> Optional[dt.datetime]:
if not value:
return None
try:
return dt.datetime.fromisoformat(value.replace("Z", "+00:00"))
except ValueError:
return None


async def _login(open_browser: bool, timeout_seconds: float) -> None:
async def _login(open_browser: bool) -> None:
async with RunpodGraphQLClient(require_api_key=False) as client:
request = await client.create_flash_auth_request()
request_id = request.get("id")
Expand All @@ -45,46 +31,23 @@ async def _login(open_browser: bool, timeout_seconds: float) -> None:
if open_browser:
typer.launch(auth_url)

expires_at = _parse_expires_at(request.get("expiresAt"))
deadline = dt.datetime.now(dt.timezone.utc) + dt.timedelta(
seconds=timeout_seconds
)
if expires_at and expires_at < deadline:
deadline = expires_at

with console.status("[dim]Waiting for authorization...[/dim]"):
while True:
status_payload = await client.get_flash_auth_request_status(request_id)
status = status_payload.get("status")
api_key = status_payload.get("apiKey")

if api_key and status in {"APPROVED", "CONSUMED"}:
check_and_migrate_legacy_credentials()
path = save_api_key(api_key)
console.print(
f"[green]Logged in.[/green] Credentials saved to [dim]{path}[/dim]"
)
console.print()
return

if status in {"DENIED", "EXPIRED", "CONSUMED"}:
raise RuntimeError(f"login failed: {status.lower()}")
api_key = console.input("Paste the API key shown after authorization: ").strip()
Comment thread
KAJdev marked this conversation as resolved.
Comment thread
KAJdev marked this conversation as resolved.
Comment thread
KAJdev marked this conversation as resolved.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

console.input(...) echoes the pasted key to the terminal, so the credential lands in scrollback, screen recordings, and shoulder-surfing range. This is the exact exposure this PR set out to close, so it shouldn't ship echoed. Use hidden input:

api_key = console.input("Paste the API key shown after authorization: ", password=True).strip()

(or getpass.getpass). This was raised earlier in the review and the thread was resolved, but the code is unchanged at HEAD.


if dt.datetime.now(dt.timezone.utc) >= deadline:
raise RuntimeError("login timed out")
if not api_key:
raise RuntimeError("no api key provided")
Comment thread
KAJdev marked this conversation as resolved.

await asyncio.sleep(POLL_INTERVAL_SECONDS)
check_and_migrate_legacy_credentials()
path = save_api_key(api_key)
console.print(f"[green]Logged in.[/green] Credentials saved to [dim]{path}[/dim]")
Comment on lines +36 to +41
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any non-empty string is written to the credentials file and the user sees "Logged in." — a typo or stray paste becomes a deferred silent failure that only surfaces as a 401 on the next command, far from the cause. At minimum check the format (e.g. rpa_ prefix) before persisting; ideally fire one cheap authenticated call (e.g. myself/whoami) and only print "Logged in." once the key is confirmed usable. Fail loudly with an actionable message otherwise.

console.print()


def login_command(
no_open: bool = typer.Option(False, "--no-open", help="do not open the browser"),
timeout: float = typer.Option(
DEFAULT_TIMEOUT_SECONDS, "--timeout", help="max wait time in seconds"
),
):
"""Authenticate and save a Runpod API key for flash."""
try:
asyncio.run(_login(open_browser=not no_open, timeout_seconds=timeout))
asyncio.run(_login(open_browser=not no_open))
except RuntimeError as exc:
print_error(console, str(exc))
raise typer.Exit(code=1)
14 changes: 0 additions & 14 deletions src/runpod_flash/core/api/runpod.py
Original file line number Diff line number Diff line change
Expand Up @@ -860,20 +860,6 @@ async def create_flash_auth_request(self) -> Dict[str, Any]:
result = await self._execute_graphql(mutation)
return result.get("createFlashAuthRequest", {})

async def get_flash_auth_request_status(self, request_id: str) -> Dict[str, Any]:
query = """
query flashAuthRequestStatus($flashAuthRequestId: String!) {
flashAuthRequestStatus(flashAuthRequestId: $flashAuthRequestId) {
id
status
expiresAt
apiKey
}
}
"""
result = await self._execute_graphql(query, {"flashAuthRequestId": request_id})
return result.get("flashAuthRequestStatus", {})

async def close(self):
"""Close the HTTP session."""
if self.session and not self.session.closed:
Expand Down
62 changes: 62 additions & 0 deletions src/runpod_flash/core/resources/request_logs.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import logging
import re
from dataclasses import dataclass
Expand Down Expand Up @@ -34,6 +35,67 @@ class QBRequestLogBatch:
ready_worker_ids: List[str] = field(default_factory=list)


@dataclass
class SSEEvent:
id: str
data: dict[str, Any]


@dataclass
class LogEvent:
source: str
line: str
ts: str


def parse_sse_event(data: str) -> Optional[SSEEvent]:
"""
Parses an SSE line into a dictionary
"""
if not data:
return None

try:
event_id_line, data_line = filter(bool, data.split("\n"))
event_id = event_id_line.split(":", 1)[1].strip()
data_json = data_line.split(":", 1)[1].strip()
data = json.loads(data_json)
return SSEEvent(id=event_id, data=data)
except Exception as e:
log.error("Failed to parse SSE event: %s", e)
return None
Comment thread
KAJdev marked this conversation as resolved.
Comment on lines +51 to +66
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Two problems in this function:

  1. Contract mismatch — drops 100% of events. event_id_line, data_line = filter(bool, data.split("\n")) assumes a multi-line id:\ndata: frame, but the only caller (stream_pod_logs) feeds it one line at a time via response.aiter_lines(). A single line can't unpack into two targets, so this raises ValueError on essentially every line; the broad except swallows it and returns None, so the stream yields nothing. SSE frames must be accumulated across the blank-line delimiter before parsing.

  2. Docstring is wrong. "Parses an SSE line into a dictionary" — it returns an SSEEvent dataclass (not a dict) from a multi-line block (not a line). Please correct it and document the None-on-failure behavior per the project docstring standard (args/returns/raises).

Comment on lines +64 to +66
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

except Exception here (and in parse_log_event at line 75) swallows structurally different failures — the always-present ValueError from the unpack bug is indistinguishable in the logs from a genuine json.JSONDecodeError or a one-off malformed line. Per the project convention (no broad except, never swallow silently, errors must be actionable), narrow to the expected types and log the offending payload:

except (ValueError, IndexError, json.JSONDecodeError) as e:
    log.error("Failed to parse SSE event %r: %s", data, e)
    return None



def parse_log_event(data: dict[str, Any]) -> Optional[LogEvent]:
"""
Parses a log event from a dictionary
"""
try:
return LogEvent(source=data["source"], line=data["line"], ts=data["ts"])
except Exception as e:
log.error("Failed to parse log event: %s", e)
return None


async def stream_pod_logs(pod_id: str, tail: int = 0):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SSEEvent, LogEvent, parse_sse_event, parse_log_event, and stream_pod_logs are unrelated to the login security fix this PR describes. Verified across the whole branch: nothing outside this file references any of these symbols, and stream_pod_logs itself has no caller — so this is dead code, with no tests, and (per the other comments) non-functional as written. Recommend pulling the entire block out into its own PR behind tests and a real caller. Bundling it here also enlarges the security review surface for no benefit.

"""
Streams logs from pod using SSE
"""
Comment thread
KAJdev marked this conversation as resolved.
if tail < 0:
raise ValueError("tail must be greater than 0")
Comment thread
KAJdev marked this conversation as resolved.
Comment on lines +84 to +85
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The guard rejects tail < 0 but the message says "tail must be greater than 0" — yet tail=0 is the default and is allowed. The message should read "tail must be greater than or equal to 0" (or "must not be negative") to match the predicate.


url = f"{RUNPOD_HAPI_URL}/v1/pod/{pod_id}/logs?stream=true&tail={tail}"

async with get_authenticated_httpx_client() as client:
async with client.get(url) as response:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

get_authenticated_httpx_client() returns an httpx.AsyncClient, whose .get(url) returns a coroutine resolving to a buffered Response — not an async context manager. async with client.get(url) as response: raises TypeError/AttributeError on first call, and aiter_lines() requires a stream-opened response anyway. Use the streaming API and surface HTTP errors:

async with client.stream("GET", url) as response:
    response.raise_for_status()
    async for line in response.aiter_lines():
        ...

Without raise_for_status(), a 401/404/500 (e.g. expired key — directly relevant to this PR) silently yields an empty stream. Every other call site in this file correctly uses await client.get(...).

async for line in response.aiter_lines():
event = parse_sse_event(line)
if event:
log_event = parse_log_event(event.data)
if log_event:
yield log_event
Comment thread
KAJdev marked this conversation as resolved.

Comment thread
KAJdev marked this conversation as resolved.

class QBRequestLogFetcher:
def __init__(
self,
Expand Down
74 changes: 25 additions & 49 deletions tests/unit/test_login.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,6 @@

import pytest

from runpod_flash.cli.commands.login import _parse_expires_at


class TestParseExpiresAt:
def test_iso_format(self):
result = _parse_expires_at("2026-03-01T12:00:00Z")
assert result is not None
assert result.year == 2026

def test_none_input(self):
assert _parse_expires_at(None) is None

def test_empty_string(self):
assert _parse_expires_at("") is None

def test_invalid_string(self):
assert _parse_expires_at("not-a-date") is None


class TestGraphQLClientNoKeyForLogin:
"""Login mutations must not send stored credentials."""
Expand Down Expand Up @@ -61,16 +43,13 @@ def test_require_api_key_true_loads_key(self):
assert client.api_key == "loaded-key"


def _make_mock_client(**status_return):
def _make_mock_client():
"""Build an AsyncMock that works as an async context manager."""
client = AsyncMock()
client.create_flash_auth_request.return_value = {
"id": "req-123",
"expiresAt": None,
}
client.get_flash_auth_request_status.return_value = status_return
# _login uses `async with RunpodGraphQLClient(...) as client:`,
# so __aenter__ must return the same mock instance
client.__aenter__.return_value = client
return client

Expand All @@ -85,42 +64,39 @@ def _get_login_fn():


class TestLoginFlow:
async def test_login_denied(self):
mock_client = _make_mock_client(status="DENIED", apiKey=None)
async def test_login_saves_pasted_key(self, isolate_credentials_file):
mock_client = _make_mock_client()
_login = _get_login_fn()

with patch(
"runpod_flash.cli.commands.login.RunpodGraphQLClient",
return_value=mock_client,
with (
patch(
"runpod_flash.cli.commands.login.RunpodGraphQLClient",
return_value=mock_client,
),
patch("runpod_flash.cli.commands.login.console") as mock_console,
):
with pytest.raises(RuntimeError, match="login failed: denied"):
await _login(open_browser=False, timeout_seconds=5)

async def test_login_approved_saves_key(self, isolate_credentials_file):
mock_client = _make_mock_client(status="APPROVED", apiKey="fresh-api-key")
_login = _get_login_fn()

with patch(
"runpod_flash.cli.commands.login.RunpodGraphQLClient",
return_value=mock_client,
):
await _login(open_browser=False, timeout_seconds=5)
mock_console.input.return_value = "pasted-api-key"
await _login(open_browser=False)
assert isolate_credentials_file.exists()
assert "fresh-api-key" in isolate_credentials_file.read_text()
assert "pasted-api-key" in isolate_credentials_file.read_text()

async def test_login_expired(self):
mock_client = _make_mock_client(status="EXPIRED", apiKey=None)
async def test_login_empty_key_raises(self):
mock_client = _make_mock_client()
_login = _get_login_fn()

with patch(
"runpod_flash.cli.commands.login.RunpodGraphQLClient",
return_value=mock_client,
with (
patch(
"runpod_flash.cli.commands.login.RunpodGraphQLClient",
return_value=mock_client,
),
patch("runpod_flash.cli.commands.login.console") as mock_console,
):
with pytest.raises(RuntimeError, match="login failed: expired"):
await _login(open_browser=False, timeout_seconds=5)
mock_console.input.return_value = " "
with pytest.raises(RuntimeError, match="no api key provided"):
await _login(open_browser=False)

async def test_no_request_id_raises(self):
mock_client = _make_mock_client(status="APPROVED", apiKey="key")
mock_client = _make_mock_client()
mock_client.create_flash_auth_request.return_value = {}
_login = _get_login_fn()

Expand All @@ -129,4 +105,4 @@ async def test_no_request_id_raises(self):
return_value=mock_client,
):
with pytest.raises(RuntimeError, match="auth request failed"):
await _login(open_browser=False, timeout_seconds=5)
await _login(open_browser=False)
Loading
Loading