Skip to content

Commit 83d8013

Browse files
committed
Address PR feedback and improve install flow
1 parent bae4448 commit 83d8013

3 files changed

Lines changed: 212 additions & 14 deletions

File tree

src/specify_cli/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1086,7 +1086,12 @@ def extension_add(
10861086
console.print(f" {REINSTALL_COMMAND}")
10871087
raise typer.Exit(1)
10881088

1089-
# Enforce install_allowed policy
1089+
# If a different approved source exists, use it instead of prompting.
1090+
installable_info = catalog.get_installable_extension_info(resolved_id)
1091+
if installable_info is not None:
1092+
ext_info = installable_info
1093+
1094+
# Enforce install_allowed policy only when no approved source exists.
10901095
if not ext_info.get("_install_allowed", True):
10911096
catalog_name = ext_info.get("_catalog_name", "community")
10921097
console.print()

src/specify_cli/extensions.py

Lines changed: 58 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2112,11 +2112,19 @@ def _catalog_entry_to_dict(self, entry: CatalogEntry) -> Dict[str, Any]:
21122112

21132113
def approve_catalog_install(self, catalog_name: str) -> CatalogEntry:
21142114
"""Persist install permission for a catalog while preserving the stack."""
2115-
active_catalogs = self.get_active_catalogs()
2115+
config_path = self.project_root / ".specify" / self.CONFIG_FILENAME
2116+
2117+
# Base the update on the project-level config if it exists
2118+
if config_path.exists():
2119+
base_catalogs = self._load_catalog_config(config_path) or []
2120+
else:
2121+
# Otherwise, preserve the currently active stack so user-level catalogs remain available.
2122+
base_catalogs = self.get_active_catalogs()
2123+
21162124
updated_catalogs: List[Dict[str, Any]] = []
21172125
approved_entry: Optional[CatalogEntry] = None
21182126

2119-
for entry in active_catalogs:
2127+
for entry in base_catalogs:
21202128
if entry.name == catalog_name:
21212129
entry = self._entry(
21222130
url=entry.url,
@@ -2128,6 +2136,22 @@ def approve_catalog_install(self, catalog_name: str) -> CatalogEntry:
21282136
approved_entry = entry
21292137
updated_catalogs.append(self._catalog_entry_to_dict(entry))
21302138

2139+
# If the catalog wasn't found in the base (e.g., a custom user-level catalog),
2140+
# we pull it from the active catalogs and append it to the project stack.
2141+
if approved_entry is None:
2142+
for entry in self.get_active_catalogs():
2143+
if entry.name == catalog_name:
2144+
entry = self._entry(
2145+
url=entry.url,
2146+
name=entry.name,
2147+
priority=entry.priority,
2148+
install_allowed=True,
2149+
description=entry.description,
2150+
)
2151+
approved_entry = entry
2152+
updated_catalogs.append(self._catalog_entry_to_dict(entry))
2153+
break
2154+
21312155
if approved_entry is None:
21322156
raise ValidationError(
21332157
f"Catalog '{catalog_name}' is not active and cannot be approved"
@@ -2542,6 +2566,36 @@ def get_extension_info(self, extension_id: str) -> Optional[Dict[str, Any]]:
25422566
return ext_data
25432567
return None
25442568

2569+
def get_installable_extension_info(self, extension_id: str) -> Optional[Dict[str, Any]]:
2570+
"""Return the first installable source for an extension, if any.
2571+
2572+
This checks the active catalogs in priority order and returns the
2573+
highest-priority source that is actually allowed to install. It is
2574+
used by the add flow to avoid prompting for approval when a usable
2575+
approved source already exists.
2576+
"""
2577+
for catalog_entry in self.get_active_catalogs():
2578+
try:
2579+
catalog_data = self._fetch_single_catalog(catalog_entry, force_refresh=False)
2580+
except ExtensionError:
2581+
continue
2582+
2583+
ext_data = catalog_data.get("extensions", {}).get(extension_id)
2584+
if not isinstance(ext_data, dict):
2585+
continue
2586+
2587+
if not catalog_entry.install_allowed:
2588+
continue
2589+
2590+
return {
2591+
**ext_data,
2592+
"id": extension_id,
2593+
"_catalog_name": catalog_entry.name,
2594+
"_install_allowed": catalog_entry.install_allowed,
2595+
}
2596+
2597+
return None
2598+
25452599
def download_extension(
25462600
self, extension_id: str, target_dir: Optional[Path] = None
25472601
) -> Path:
@@ -2559,8 +2613,8 @@ def download_extension(
25592613
"""
25602614
import urllib.error
25612615

2562-
# Get extension info from catalog
2563-
ext_info = self.get_extension_info(extension_id)
2616+
# Get the best installable source first, then fall back to the merged view.
2617+
ext_info = self.get_installable_extension_info(extension_id) or self.get_extension_info(extension_id)
25642618
if not ext_info:
25652619
raise ExtensionError(f"Extension '{extension_id}' not found in catalog")
25662620

tests/test_extensions.py

Lines changed: 148 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3623,6 +3623,7 @@ def fake_open(req, timeout=None):
36233623
}
36243624

36253625
with patch.object(catalog, "get_extension_info", return_value=ext_info), \
3626+
patch.object(catalog, "get_installable_extension_info", return_value=ext_info), \
36263627
patch("specify_cli.authentication.http.urllib.request.build_opener", return_value=mock_opener):
36273628
catalog.download_extension("test-ext", target_dir=temp_dir)
36283629

@@ -3669,6 +3670,7 @@ def fake_open(req, timeout=None):
36693670
}
36703671

36713672
with patch.object(catalog, "get_extension_info", return_value=ext_info), \
3673+
patch.object(catalog, "get_installable_extension_info", return_value=ext_info), \
36723674
patch("specify_cli.authentication.http.urllib.request.build_opener", return_value=mock_opener):
36733675
catalog.download_extension("test-ext", target_dir=temp_dir)
36743676

@@ -4191,6 +4193,58 @@ def test_approve_catalog_install_preserves_active_stack(self, temp_dir):
41914193
assert parsed["catalogs"][1]["name"] == "community"
41924194
assert parsed["catalogs"][1]["install_allowed"] is True
41934195

4196+
def test_approve_catalog_install_preserves_user_level_active_catalogs(self, temp_dir):
4197+
"""Approving a catalog should preserve the full active stack when no project config exists."""
4198+
import yaml as yaml_module
4199+
from unittest.mock import patch
4200+
4201+
project_dir = self._make_project(temp_dir)
4202+
home_dir = temp_dir / "home"
4203+
specify_home = home_dir / ".specify"
4204+
specify_home.mkdir(parents=True)
4205+
with (specify_home / "extension-catalogs.yml").open("w", encoding="utf-8") as f:
4206+
yaml_module.safe_dump(
4207+
{
4208+
"catalogs": [
4209+
{
4210+
"name": "alpha",
4211+
"url": "https://alpha.example.com/catalog.json",
4212+
"priority": 1,
4213+
"install_allowed": False,
4214+
},
4215+
{
4216+
"name": "community",
4217+
"url": ExtensionCatalog.COMMUNITY_CATALOG_URL,
4218+
"priority": 2,
4219+
"install_allowed": False,
4220+
},
4221+
{
4222+
"name": "beta",
4223+
"url": "https://beta.example.com/catalog.json",
4224+
"priority": 3,
4225+
"install_allowed": True,
4226+
},
4227+
]
4228+
},
4229+
f,
4230+
sort_keys=False,
4231+
allow_unicode=True,
4232+
)
4233+
4234+
catalog = ExtensionCatalog(project_dir)
4235+
4236+
with patch("specify_cli.extensions.Path.home", return_value=home_dir):
4237+
approved = catalog.approve_catalog_install("community")
4238+
4239+
config_path = project_dir / ".specify" / "extension-catalogs.yml"
4240+
parsed = yaml_module.safe_load(config_path.read_text(encoding="utf-8"))
4241+
4242+
assert approved.name == "community"
4243+
assert [entry["name"] for entry in parsed["catalogs"]] == ["alpha", "community", "beta"]
4244+
assert parsed["catalogs"][0]["install_allowed"] is False
4245+
assert parsed["catalogs"][1]["install_allowed"] is True
4246+
assert parsed["catalogs"][2]["install_allowed"] is True
4247+
41944248
def test_approve_catalog_install_rejects_symlinked_specify_dir(self, temp_dir):
41954249
"""Approval writes fail closed when .specify resolves outside the project root."""
41964250
project_dir = self._make_project(temp_dir)
@@ -4720,6 +4774,7 @@ def test_add_by_display_name_uses_resolved_id_for_download(self, tmp_path):
47204774
# Mock catalog that returns extension by display name
47214775
mock_catalog = MagicMock()
47224776
mock_catalog.get_extension_info.return_value = None # ID lookup fails
4777+
mock_catalog.get_installable_extension_info.return_value = None # Installable lookup fails
47234778
mock_catalog.search.return_value = [
47244779
{
47254780
"id": "acme-jira-integration",
@@ -4822,14 +4877,20 @@ def test_add_blocked_extension_approval_updates_project_catalog_config(self, tmp
48224877
commands=[],
48234878
)
48244879

4825-
with patch("specify_cli.extensions.ExtensionCatalog.get_extension_info", return_value={
4880+
with (
4881+
patch("specify_cli.extensions.ExtensionCatalog.get_extension_info", return_value={
48264882
"id": "security-review",
48274883
"name": "Security Review",
48284884
"version": "1.0.0",
48294885
"description": "Security review extension",
48304886
"_catalog_name": "community",
48314887
"_install_allowed": False,
4832-
}), patch("specify_cli.extensions.ExtensionCatalog.download_extension", return_value=zip_path), patch("specify_cli.extensions.ExtensionManager.install_from_zip", return_value=mock_manifest), patch("typer.confirm", return_value=True), patch.object(Path, "cwd", return_value=project_dir):
4888+
}),
4889+
patch("specify_cli.extensions.ExtensionCatalog.download_extension", return_value=zip_path),
4890+
patch("specify_cli.extensions.ExtensionManager.install_from_zip", return_value=mock_manifest),
4891+
patch("typer.confirm", return_value=True),
4892+
patch.object(Path, "cwd", return_value=project_dir),
4893+
):
48334894
result = runner.invoke(
48344895
app,
48354896
["extension", "add", "security-review"],
@@ -4864,14 +4925,19 @@ def record_status(*args, **kwargs):
48644925
call_order.append("spinner")
48654926
return MagicMock()
48664927

4867-
with patch("specify_cli.extensions.ExtensionCatalog.get_extension_info", return_value={
4928+
with (
4929+
patch("specify_cli.extensions.ExtensionCatalog.get_extension_info", return_value={
48684930
"id": "security-review",
48694931
"name": "Security Review",
48704932
"version": "1.0.0",
48714933
"description": "Security review extension",
48724934
"_catalog_name": "community",
48734935
"_install_allowed": False,
4874-
}), patch("typer.confirm", side_effect=lambda *a, **kw: (call_order.append("confirm"), False)[-1]), patch("specify_cli.console.status", side_effect=record_status), patch.object(Path, "cwd", return_value=project_dir):
4936+
}),
4937+
patch("typer.confirm", side_effect=lambda *a, **kw: (call_order.append("confirm"), False)[-1]),
4938+
patch("specify_cli.console.status", side_effect=record_status),
4939+
patch.object(Path, "cwd", return_value=project_dir),
4940+
):
48754941
result = runner.invoke(
48764942
app,
48774943
["extension", "add", "security-review"],
@@ -4894,14 +4960,18 @@ def test_add_blocked_extension_cancel_leaves_config_unchanged(self, tmp_path):
48944960
project_dir.mkdir()
48954961
(project_dir / ".specify").mkdir()
48964962

4897-
with patch("specify_cli.extensions.ExtensionCatalog.get_extension_info", return_value={
4963+
with (
4964+
patch("specify_cli.extensions.ExtensionCatalog.get_extension_info", return_value={
48984965
"id": "security-review",
48994966
"name": "Security Review",
49004967
"version": "1.0.0",
49014968
"description": "Security review extension",
49024969
"_catalog_name": "community",
49034970
"_install_allowed": False,
4904-
}), patch("typer.confirm", return_value=False), patch.object(Path, "cwd", return_value=project_dir):
4971+
}),
4972+
patch("typer.confirm", return_value=False),
4973+
patch.object(Path, "cwd", return_value=project_dir),
4974+
):
49054975
result = runner.invoke(
49064976
app,
49074977
["extension", "add", "security-review"],
@@ -4939,14 +5009,80 @@ def test_add_approved_catalog_skips_approval_prompt(self, tmp_path):
49395009
def unexpected_confirm(*args, **kwargs):
49405010
raise AssertionError("Approval prompt should not run for approved catalogs")
49415011

4942-
with patch("specify_cli.extensions.ExtensionCatalog.get_extension_info", return_value={
5012+
with (
5013+
patch("specify_cli.extensions.ExtensionCatalog.get_extension_info", return_value={
49435014
"id": "security-review",
49445015
"name": "Security Review",
49455016
"version": "1.0.0",
49465017
"description": "Security review extension",
49475018
"_catalog_name": "default",
49485019
"_install_allowed": True,
4949-
}), patch("specify_cli.extensions.ExtensionCatalog.download_extension", return_value=zip_path), patch("specify_cli.extensions.ExtensionManager.install_from_zip", return_value=mock_manifest), patch("typer.confirm", side_effect=unexpected_confirm), patch("specify_cli.console.status", return_value=contextlib.nullcontext()), patch.object(Path, "cwd", return_value=project_dir):
5020+
}),
5021+
patch("specify_cli.extensions.ExtensionCatalog.download_extension", return_value=zip_path),
5022+
patch("specify_cli.extensions.ExtensionManager.install_from_zip", return_value=mock_manifest),
5023+
patch("typer.confirm", side_effect=unexpected_confirm),
5024+
patch("specify_cli.console.status", return_value=contextlib.nullcontext()),
5025+
patch.object(Path, "cwd", return_value=project_dir),
5026+
):
5027+
result = runner.invoke(
5028+
app,
5029+
["extension", "add", "security-review"],
5030+
catch_exceptions=True,
5031+
)
5032+
5033+
assert result.exit_code == 0, result.output
5034+
assert "Catalog Approval Required" not in result.output
5035+
5036+
def test_add_prefers_approved_source_over_blocked_duplicate(self, tmp_path):
5037+
"""If the same extension exists in approved and blocked catalogs, the add flow should skip approval."""
5038+
from typer.testing import CliRunner
5039+
from unittest.mock import patch
5040+
from types import SimpleNamespace
5041+
from specify_cli import app
5042+
import contextlib
5043+
5044+
runner = CliRunner()
5045+
project_dir = tmp_path / "test-project"
5046+
project_dir.mkdir()
5047+
(project_dir / ".specify").mkdir()
5048+
5049+
zip_path = tmp_path / "approved-duplicate.zip"
5050+
zip_path.write_bytes(b"fake-zip")
5051+
mock_manifest = SimpleNamespace(
5052+
id="security-review",
5053+
name="Security Review",
5054+
version="1.0.0",
5055+
description="Security review extension",
5056+
warnings=[],
5057+
commands=[],
5058+
)
5059+
5060+
def unexpected_confirm(*args, **kwargs):
5061+
raise AssertionError("Approval prompt should not run when an approved source exists")
5062+
5063+
with (
5064+
patch("specify_cli.extensions.ExtensionCatalog.get_extension_info", return_value={
5065+
"id": "security-review",
5066+
"name": "Security Review",
5067+
"version": "1.0.0",
5068+
"description": "Security review extension",
5069+
"_catalog_name": "community",
5070+
"_install_allowed": False,
5071+
}),
5072+
patch("specify_cli.extensions.ExtensionCatalog.get_installable_extension_info", return_value={
5073+
"id": "security-review",
5074+
"name": "Security Review",
5075+
"version": "1.0.0",
5076+
"description": "Security review extension",
5077+
"_catalog_name": "default",
5078+
"_install_allowed": True,
5079+
}),
5080+
patch("specify_cli.extensions.ExtensionCatalog.download_extension", return_value=zip_path),
5081+
patch("specify_cli.extensions.ExtensionManager.install_from_zip", return_value=mock_manifest),
5082+
patch("typer.confirm", side_effect=unexpected_confirm),
5083+
patch("specify_cli.console.status", return_value=contextlib.nullcontext()),
5084+
patch.object(Path, "cwd", return_value=project_dir),
5085+
):
49505086
result = runner.invoke(
49515087
app,
49525088
["extension", "add", "security-review"],
@@ -4971,7 +5107,10 @@ def test_add_not_found_still_reports_missing_extension(self, tmp_path):
49715107
mock_catalog.get_extension_info.return_value = None
49725108
mock_catalog.search.return_value = []
49735109

4974-
with patch("specify_cli.extensions.ExtensionCatalog", return_value=mock_catalog), patch.object(Path, "cwd", return_value=project_dir):
5110+
with (
5111+
patch("specify_cli.extensions.ExtensionCatalog", return_value=mock_catalog),
5112+
patch.object(Path, "cwd", return_value=project_dir),
5113+
):
49755114
result = runner.invoke(
49765115
app,
49775116
["extension", "add", "does-not-exist"],

0 commit comments

Comments
 (0)