Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ dependencies = [
"typer>=0.12.0",
"questionary>=2.0.0",
"pathspec>=0.11.0",
"tomlkit>=0.13.0",
"tomli>=2.0.0; python_version < '3.11'",
]

Expand Down
67 changes: 64 additions & 3 deletions src/runpod_flash/core/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,23 @@
from typing import Optional

import runpod.cli.groups.config.functions as _runpod_config
import tomlkit

from runpod.cli.groups.config.functions import (
get_credentials,
set_credentials,
)

log = logging.getLogger(__name__)

# runpodctl writes top-level `apikey`/`apiurl` keys into the same config.toml
# that runpod-python uses for its `[default]` profile. We must preserve those
# (and any other unrelated content) when updating flash's api_key, so flash
# login does not clobber runpodctl's credentials. tomlkit round-trips comments,
# foreign keys, sibling profiles, and line endings, so we only mutate the one
# value we own.
_DEFAULT_SECTION = "default"
_API_KEY_FIELD = "api_key"

_OLD_XDG_PATH = Path.home() / ".config" / "runpod" / "credentials.toml"


Expand Down Expand Up @@ -50,7 +59,11 @@ def get_api_key() -> Optional[str]:


def save_api_key(api_key: str) -> Path:
"""Save API key to ~/.runpod/config.toml via runpod-python.
"""Save API key into the [default] section of ~/.runpod/config.toml.

Updates only flash's `[default].api_key` value, preserving any other
content in the file (notably runpodctl's top-level `apikey`/`apiurl`
keys and other profile sections).

Args:
api_key: The API key to save.
Expand All @@ -59,14 +72,62 @@ def save_api_key(api_key: str) -> Path:
Path to the credentials file.
"""
path = get_credentials_path()
set_credentials(api_key, overwrite=True)
path.parent.mkdir(parents=True, exist_ok=True)

# newline="" disables universal-newline translation so tomlkit sees (and
# preserves) the file's original line endings rather than collapsing CRLF.
existing = ""
if path.exists():
with path.open("r", encoding="utf-8", newline="") as f:
existing = f.read()
new_content = _upsert_default_api_key(existing, api_key)
with path.open("w", encoding="utf-8", newline="") as f:
f.write(new_content)

try:
os.chmod(path, 0o600)
except OSError:
pass
return path


def _upsert_default_api_key(content: str, api_key: str) -> str:
"""Update `[default].api_key` in TOML text, leaving the rest intact.

Parses with tomlkit so foreign top-level keys (runpodctl's
`apikey`/`apiurl`), sibling profile sections, comments, and the file's
original line endings are preserved. Only `[default].api_key` is mutated;
a missing `[default]` table is created.

If the existing file is malformed TOML it is already unloadable (runpodctl
cannot read it either), so we cannot preserve it. Rather than block login,
fall back to a fresh document and warn -- the discarded content held no
recoverable credentials.
"""
if content:
try:
doc = tomlkit.parse(content)
except tomlkit.exceptions.TOMLKitError:
log.warning(
"Existing credentials file is not valid TOML; replacing it "
"with a fresh [default] section. Unrelated content was lost.",
exc_info=True,
)
doc = tomlkit.document()
else:
doc = tomlkit.document()

section = doc.get(_DEFAULT_SECTION)
if isinstance(section, (tomlkit.items.Table, tomlkit.items.InlineTable)):
section[_API_KEY_FIELD] = api_key
else:
table = tomlkit.table()
table[_API_KEY_FIELD] = api_key
doc[_DEFAULT_SECTION] = table

return tomlkit.dumps(doc)


def check_and_migrate_legacy_credentials() -> None:
"""Check for credentials at old XDG path and migrate if needed.

Expand Down
182 changes: 182 additions & 0 deletions tests/unit/test_credentials.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
"""Unit tests for credential storage and retrieval."""

import logging
import os
import sys
from pathlib import Path
from unittest.mock import patch

if sys.version_info >= (3, 11):
import tomllib
else:
import tomli as tomllib

from runpod_flash.core.credentials import (
get_api_key,
get_credentials_path,
Expand Down Expand Up @@ -69,3 +76,178 @@ def test_sets_restrictive_permissions(self, isolate_credentials_file):
save_api_key("secret")
mode = oct(isolate_credentials_file.stat().st_mode & 0o777)
assert mode == "0o600"

def test_preserves_runpodctl_top_level_keys(self, isolate_credentials_file):
isolate_credentials_file.parent.mkdir(parents=True, exist_ok=True)
isolate_credentials_file.write_text(
"apikey = 'rpa_runpodctl_key'\n"
"apiurl = 'https://api.runpod.io/graphql'\n"
"\n"
"[default]\n"
'api_key = "old-flash-key"\n'
)
save_api_key("new-flash-key")
parsed = tomllib.loads(isolate_credentials_file.read_text())
assert parsed["apikey"] == "rpa_runpodctl_key"
assert parsed["apiurl"] == "https://api.runpod.io/graphql"
assert parsed["default"]["api_key"] == "new-flash-key"

def test_adds_default_section_when_missing(self, isolate_credentials_file):
isolate_credentials_file.parent.mkdir(parents=True, exist_ok=True)
isolate_credentials_file.write_text(
"apikey = 'rpa_runpodctl_key'\napiurl = 'https://api.runpod.io/graphql'\n"
)
save_api_key("flash-key")
text = isolate_credentials_file.read_text()
parsed = tomllib.loads(text)
assert parsed["apikey"] == "rpa_runpodctl_key"
assert parsed["apiurl"] == "https://api.runpod.io/graphql"
assert parsed["default"]["api_key"] == "flash-key"
Comment thread
deanq marked this conversation as resolved.

def test_preserves_other_profile_sections(self, isolate_credentials_file):
isolate_credentials_file.parent.mkdir(parents=True, exist_ok=True)
isolate_credentials_file.write_text(
"[default]\n"
'api_key = "old"\n'
"\n"
"[staging]\n"
'api_key = "staging-key"\n'
'extra = "preserved"\n'
)
save_api_key("new-default")
parsed = tomllib.loads(isolate_credentials_file.read_text())
assert parsed["default"]["api_key"] == "new-default"
assert parsed["staging"]["api_key"] == "staging-key"
assert parsed["staging"]["extra"] == "preserved"

def test_creates_file_with_only_default_when_missing(
self, isolate_credentials_file
):
save_api_key("first-key")
parsed = tomllib.loads(isolate_credentials_file.read_text())
assert parsed == {"default": {"api_key": "first-key"}}
Comment thread
deanq marked this conversation as resolved.

def test_preserves_inline_comment_on_other_section_header(
self, isolate_credentials_file
):
"""An inline comment on a later section header must not redirect the
update onto that section. Regression for the regex section-boundary bug.
"""
isolate_credentials_file.parent.mkdir(parents=True, exist_ok=True)
isolate_credentials_file.write_text(
"[default]\n"
'some_other_field = "x"\n'
"\n"
"[staging] # production environment\n"
'api_key = "staging-key"\n'
)
save_api_key("new-flash-key")
parsed = tomllib.loads(isolate_credentials_file.read_text())
assert parsed["default"]["api_key"] == "new-flash-key"
assert parsed["staging"]["api_key"] == "staging-key"

def test_updates_default_with_inline_comment_header(self, isolate_credentials_file):
"""`[default] # comment` must update in place, not append a duplicate
`[default]` table that tomllib would refuse to load.
"""
isolate_credentials_file.parent.mkdir(parents=True, exist_ok=True)
isolate_credentials_file.write_text(
'[default] # flash profile\napi_key = "old"\n'
)
save_api_key("new-key")
text = isolate_credentials_file.read_text()
parsed = tomllib.loads(text)
assert parsed["default"]["api_key"] == "new-key"
assert text.count("[default]") == 1

def test_handles_default_header_without_trailing_newline(
self, isolate_credentials_file
):
"""A `[default]` header with no trailing newline must not concatenate
the new key onto the header line.
"""
isolate_credentials_file.parent.mkdir(parents=True, exist_ok=True)
isolate_credentials_file.write_text("[default]")
save_api_key("new-key")
parsed = tomllib.loads(isolate_credentials_file.read_text())
assert parsed["default"]["api_key"] == "new-key"

def test_roundtrips_api_key_with_backslash_and_quote(
self, isolate_credentials_file
):
"""Keys containing backslash and double-quote must survive a write/read
round-trip through a real TOML parser.
"""
weird_key = r'a\b"c'
save_api_key(weird_key)
parsed = tomllib.loads(isolate_credentials_file.read_text())
assert parsed["default"]["api_key"] == weird_key

def test_roundtrips_api_key_with_control_char(self, isolate_credentials_file):
"""Keys containing control characters (e.g. a tab) must produce valid
TOML rather than a file the next load rejects.
"""
weird_key = "tok\ten\nwith-controls"
save_api_key(weird_key)
parsed = tomllib.loads(isolate_credentials_file.read_text())
assert parsed["default"]["api_key"] == weird_key

def test_preserves_crlf_line_endings(self, isolate_credentials_file):
"""A CRLF-edited config must not gain a stray LF-only line."""
isolate_credentials_file.parent.mkdir(parents=True, exist_ok=True)
isolate_credentials_file.write_bytes(
b"apikey = 'rpa_runpodctl_key'\r\n\r\n[default]\r\napi_key = \"old\"\r\n"
)
save_api_key("new-key")
raw = isolate_credentials_file.read_bytes()
assert b"\r\n" in raw
assert b"\n" not in raw.replace(b"\r\n", b"")
parsed = tomllib.loads(isolate_credentials_file.read_text())
assert parsed["default"]["api_key"] == "new-key"
assert parsed["apikey"] == "rpa_runpodctl_key"

def test_inserts_api_key_into_default_without_existing_key(
self, isolate_credentials_file
):
"""`[default]` with no `api_key` field gains one without disturbing
sibling fields or other sections.
"""
isolate_credentials_file.parent.mkdir(parents=True, exist_ok=True)
isolate_credentials_file.write_text(
'[default]\nfoo = "bar"\n\n[other]\nx = 1\n'
)
save_api_key("new-key")
parsed = tomllib.loads(isolate_credentials_file.read_text())
assert parsed["default"]["api_key"] == "new-key"
assert parsed["default"]["foo"] == "bar"
assert parsed["other"]["x"] == 1

def test_preserves_comments_and_formatting(self, isolate_credentials_file):
"""Comments and unrelated content survive the round-trip."""
isolate_credentials_file.parent.mkdir(parents=True, exist_ok=True)
isolate_credentials_file.write_text(
"# managed by runpodctl\n"
"apikey = 'rpa_key'\n"
"\n"
"[default]\n"
"# flash credentials\n"
'api_key = "old"\n'
)
save_api_key("new-key")
text = isolate_credentials_file.read_text()
assert "# managed by runpodctl" in text
assert "# flash credentials" in text

def test_recovers_from_corrupt_existing_file(
self, isolate_credentials_file, caplog
):
"""A malformed config (already unloadable) must not block login: fall
back to a fresh minimal document and warn rather than raise.
"""
isolate_credentials_file.parent.mkdir(parents=True, exist_ok=True)
isolate_credentials_file.write_text("not valid toml {{{{\n")
with caplog.at_level(logging.WARNING):
save_api_key("new-key")
parsed = tomllib.loads(isolate_credentials_file.read_text())
assert parsed == {"default": {"api_key": "new-key"}}
assert any(record.levelno == logging.WARNING for record in caplog.records)
Loading
Loading