diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 8c810e7..752f935 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -26,6 +26,8 @@ jobs: steps: - name: 💾 Check out repository uses: actions/checkout@v6 + with: + fetch-depth: 0 - name: 🪪 Configure git identity for tests run: | diff --git a/copier_python/__main__.py b/copier_python/__main__.py new file mode 100644 index 0000000..d4cd231 --- /dev/null +++ b/copier_python/__main__.py @@ -0,0 +1,140 @@ +from __future__ import annotations + +import os +import shutil +from dataclasses import dataclass +from enum import Enum +from typing import TYPE_CHECKING, Annotated + +from rich import print # noqa: A004 +from rich.console import Console +from rich.panel import Panel +from rich.table import Table +from rich.text import Text +from typer import Argument, Context, Exit, Option, Typer + +from .repo import RepoTarget +from .update import UpdateAction + +if TYPE_CHECKING: + from collections.abc import Sequence + + +console = Console() + + +class Args: + Repos = Annotated[ + list[str], + Argument( + help=( + "Repositories to update. Accepts gh:user/repo" + " or github.com/user/repo." + ) + ), + ] + DryRun = Annotated[ + bool, Option("--dry-run", "-n", help="Skip push and PR creation.") + ] + + +cli = Typer( + help="copier-python utilities", + add_completion=False, + no_args_is_help=True, + pretty_exceptions_enable=False, +) + + +@cli.callback() +def setup(ctx: Context) -> None: + pass + + +class UpdateStatus(Enum): + UPDATED = "bold green" + CURRENT = "bold blue" + FAILED = "bold red" + + @property + def formatted_name(self) -> Text: + return Text(f"{self.name:<7}", style=self.value) + + +@dataclass +class UpdateResult: + repo: RepoTarget + status: UpdateStatus + exception: Exception | None = None + pr_url: str | None = None + + +@cli.command() +def update( + repos: Args.Repos, + *, + dry_run: Args.DryRun = False, + branch: Annotated[str, Option(help="Branch name to create.")] = "updates", +) -> None: + """Apply copier-python template updates to one or more downstream repos.""" + results = [] + + repo_targets = {(target := RepoTarget(repo)).url: target for repo in repos} + for target in repo_targets.values(): + try: + pr_url = UpdateAction(target, branch=branch, dry_run=dry_run)() + if pr_url: + results.append( + UpdateResult(target, UpdateStatus.UPDATED, pr_url=pr_url) + ) + else: + results.append(UpdateResult(target, UpdateStatus.CURRENT)) + except Exception as exc: # noqa: BLE001, PERF203 + console.print_exception() + results.append( + UpdateResult(target, status=UpdateStatus.FAILED, exception=exc) + ) + + _print_summary(results, dry_run=dry_run) + if any(r.status == UpdateStatus.FAILED for r in results): + raise Exit(1) + + +def _print_summary(results: Sequence[UpdateResult], *, dry_run: bool) -> None: + if not results: + return + grid = Table.grid(padding=(0, 1), expand=True) + grid.add_column() + grid.add_column() + grid.add_column() + for result in sorted(results, key=lambda r: r.repo.github_repo): + grid.add_row( + result.status.formatted_name, + Text.from_markup( + f"[link={result.repo.url}]{result.repo.github_repo}[/link]", + style="bold", + ), + Text(str(result.pr_url or result.exception)), + ) + title = Text("Update Results", style="bold") + if dry_run: + title += Text(" (dry run)", style="green") + print( + Panel(grid, title=title, title_align="left", expand=False, padding=1) + ) + + +def setup_env() -> None: + os.environ.pop("VIRTUAL_ENV", default=None) + os.environ["TERMINAL_WIDTH"] = str( + min(shutil.get_terminal_size().columns, 100) + ) + + +def main() -> None: + setup_env() + cli() + + +if __name__ == "__main__": + main() diff --git a/copier_python/repo.py b/copier_python/repo.py new file mode 100644 index 0000000..e7a31d1 --- /dev/null +++ b/copier_python/repo.py @@ -0,0 +1,156 @@ +from __future__ import annotations + +import os +import re +import subprocess +import tempfile +from contextlib import contextmanager, suppress +from dataclasses import InitVar, dataclass, field +from functools import cached_property +from pathlib import Path +from typing import TYPE_CHECKING, Any, ClassVar + +import yaml +from rich import print # noqa: A004 +from rich.panel import Panel +from rich.text import Text +from typing_extensions import Self + +if TYPE_CHECKING: + from collections.abc import Generator, Sequence + + +@dataclass(unsafe_hash=True) +class RepoTarget: + arg: InitVar[str | Path] + github_repo: str = field(init=False) + + _GITHUB_REGEXES: ClassVar[Sequence[str]] = [ + r"(\.git)?/?$", + r"^gh\:", + r"^(https?://)?github.com/", + r"^((git\+)?ssh://)?([a-zA-Z0-9]+)@github.com:", + ] + + def __post_init__(self, arg: str | Path) -> None: + arg = str(arg).strip() + if ((path := Path(arg)).is_dir()) and (root := self.repo_root(path)): + arg = root + for regex in self._GITHUB_REGEXES: + arg = re.sub(regex, "", arg) + if not re.fullmatch(r"[\w-]+\/[\w-]+", arg): + raise ValueError(arg) + self.github_repo = arg + + def repo_root(self, path: Path) -> str | None: + with suppress(subprocess.CalledProcessError): + return subprocess.check_output( + ["git", "config", "--local", "remote.origin.url"], + text=True, + cwd=path, + ).strip() + return None + + @cached_property + def name(self) -> str: + return self.github_repo.split("/")[-1] + + @property + def url(self) -> str: + return f"https://github.com/{self.github_repo}" + + @cached_property + def push_url(self) -> str: + return f"git@github.com:{self.github_repo}.git" + + +@dataclass +class RepoWorktree: + path: Path + repo: RepoTarget + branch: str + + @classmethod + @contextmanager + def clone( + cls, repo: RepoTarget, branch: str + ) -> Generator[Self, None, None]: + with tempfile.TemporaryDirectory() as td: + repo_dir = Path(td) / "worktree" + cls.run_in(["git", "clone", repo.url, str(repo_dir)], repo=repo) + for cmd in ( + [ + *["git", "remote", "set-url", "--push", "origin"], + repo.push_url, + ], + ["poe", "setup"], + ["git", "checkout", "-b", branch], + ): + cls.run_in(cmd, repo=repo, cwd=repo_dir) + yield cls(path=repo_dir, repo=repo, branch=branch) + + @classmethod + def run_in( + cls, cmd: list[str], *, repo: RepoTarget, **kwargs: Any + ) -> subprocess.CompletedProcess[str]: + kwargs.setdefault("check", True) + kwargs.setdefault("text", True) + cmd_text = Text(" ".join(cmd), style="color(153)") + txt = Text.assemble( + Text(f"{repo.github_repo} => ", style="bold"), + cmd_text, + ) + print(Panel(txt, expand=False, border_style="white dim")) + return subprocess.run(cmd, **kwargs) # noqa: S603 PLW1510 + + def run( + self, cmd: list[str], **kwargs: Any + ) -> subprocess.CompletedProcess[str]: + kwargs.setdefault("cwd", self.path) + return self.run_in(cmd, repo=self.repo, **kwargs) + + def git_status(self) -> list[str]: + return self.run( + ["git", "status", "--porcelain"], capture_output=True + ).stdout.splitlines() + + @staticmethod + def has_conflicts(status: list[str]) -> bool: + conflict_codes = {"UU", "AA", "DD", "AU", "UA", "DU", "UD"} + return any(line[:2] in conflict_codes for line in status) + + @property + def template_ref(self) -> str: + return yaml.safe_load( # type: ignore[no-any-return] + (self.path / ".copier-answers.yml").read_text() + ).get("_commit") + + def shell(self) -> None: + self.run([os.environ.get("SHELL", "/bin/bash")], check=False) + + def open_pr(self, title: str, body: str) -> str: + result = self.run( + [ + "gh", + "pr", + "create", + "--title", + title, + "--body", + body, + "--head", + self.branch, + ], + capture_output=True, + check=False, + ) + if result.returncode == 0: + return result.stdout.strip() + view = self.run( + ["gh", "pr", "view", "--json", "url", "--jq", ".url"], + capture_output=True, + check=False, + ) + if view.returncode == 0: + return view.stdout.strip() + raise RuntimeError(f"gh pr create failed: {result.stderr.strip()}") diff --git a/copier_python/update.py b/copier_python/update.py new file mode 100644 index 0000000..6717c52 --- /dev/null +++ b/copier_python/update.py @@ -0,0 +1,80 @@ +from __future__ import annotations + +import json +import os +import subprocess +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from .repo import RepoWorktree + +if TYPE_CHECKING: + from .repo import RepoTarget + + +@dataclass +class UpdateAction: + repo: RepoTarget + branch: str + dry_run: bool = False + + def __call__(self) -> str | None: + with RepoWorktree.clone(self.repo, branch=self.branch) as worktree: + return self._run(worktree) + + def _run(self, repo: RepoWorktree) -> str | None: + """Run copier update in the repo worktree.""" + copier_status = json.loads( + repo.run( + ["copier", "check-update", "--output-format", "json"], + capture_output=True, + ).stdout.strip() + ) + start_ref = "v" + copier_status["current_version"] + end_ref = "v" + copier_status["latest_version"] + if not copier_status.get("update_available", False): + return None + repo.run(["copier", "update", "-l"]) + + status = repo.git_status() + if not status: + return None + + if repo.has_conflicts(status): + print( # noqa: T201 + "Conflicts detected." + " Resolve them and exit the shell to continue." + ) + repo.shell() + status = repo.git_status() + if repo.has_conflicts(status): + raise RuntimeError("Conflicts remain, aborting") + + try: + repo.run(["uv", "run", "poe", "lt"]) + except subprocess.CalledProcessError: + print( # noqa: T201 + "Lint/test failed. Fix errors and exit the shell to continue." + ) + repo.shell() + + title = "Apply template updates" + body = "" + if start_ref and end_ref and start_ref != end_ref: + ref_range = f"{start_ref}...{end_ref}" + body = os.linesep.join( + ( + f"Applied updates from template: {ref_range}", + f"{repo.repo.url}/compare/{ref_range}", + ) + ) + repo.run(["git", "add", "-A"]) + repo.run(["git", "commit", "-m", f"{title}\n\n{body}".strip()]) + + if self.dry_run: + return None + + repo.run( + ["git", "push", "-u", "origin", repo.branch, "--force-with-lease"] + ) + return repo.open_pr(title, body) diff --git a/pyproject.toml b/pyproject.toml index 55c683f..a644f54 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,13 @@ classifiers = [ "Typing :: Typed", ] keywords = [] -dependencies = [] +dependencies = [ + "pyyaml>=6", + "typer>=0.9", +] + +[project.scripts] +copier-python = "copier_python.__main__:main" [project.urls] Homepage = "https://smkent.github.io/copier-python" @@ -50,6 +56,7 @@ dev = [ "ruff>=0.15", "syrupy>=5", "ty>=0.0.37", + "typing-extensions>=4", ] docs = [ "zensical>=0.0.42", diff --git a/tests/conftest.py b/tests/conftest.py index 84b358d..71b9a14 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,13 +1,20 @@ -"""Shared test fixtures.""" +from __future__ import annotations +import subprocess import warnings -from collections.abc import Callable from pathlib import Path -from typing import Any +from typing import TYPE_CHECKING, Any import copier import pytest +from copier_python.__main__ import setup_env + +from .utils import DisallowCallable + +if TYPE_CHECKING: + from collections.abc import Callable, Iterator + TEMPLATE_ROOT = Path(__file__).parent.parent DEFAULT_DATA: dict[str, Any] = { @@ -26,23 +33,46 @@ } +@pytest.fixture(scope="session", autouse=True) +def ensure_env() -> None: + setup_env() + + +@pytest.fixture(autouse=True) +def disallow_subprocess( + request: pytest.FixtureRequest, +) -> Iterator[DisallowCallable]: + with DisallowCallable(request, subprocess.Popen, "__init__")() as mock: + yield mock + + @pytest.fixture -def render_template(tmp_path: Path) -> Callable[..., Path]: - def _render(**kwargs: Any) -> Path: +def allow_subprocess(disallow_subprocess: DisallowCallable) -> Iterator[None]: + with disallow_subprocess.pause(): + yield + + +@pytest.fixture +def render_template( + tmp_path: Path, disallow_subprocess: DisallowCallable +) -> Callable[..., Path]: + def _render(*, vcs_ref: str = "HEAD", **kwargs: Any) -> Path: + worktree = tmp_path / "project" with warnings.catch_warnings(): warnings.filterwarnings( "ignore", category=copier.errors.DirtyLocalWarning, ) - copier.run_copy( - src_path=str(TEMPLATE_ROOT), - dst_path=str(tmp_path), - data={**DEFAULT_DATA, **(kwargs or {})}, - vcs_ref="HEAD", - defaults=True, - overwrite=True, - unsafe=False, - ) - return tmp_path + with disallow_subprocess.pause(): + copier.run_copy( + src_path=str(TEMPLATE_ROOT), + dst_path=str(worktree), + data={**DEFAULT_DATA, **(kwargs or {})}, + vcs_ref=vcs_ref, + defaults=True, + overwrite=True, + unsafe=False, + ) + return worktree return _render diff --git a/tests/template/__init__.py b/tests/template/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/__snapshots__/test_agents.ambr b/tests/template/__snapshots__/test_agents.ambr similarity index 100% rename from tests/__snapshots__/test_agents.ambr rename to tests/template/__snapshots__/test_agents.ambr diff --git a/tests/__snapshots__/test_compose.ambr b/tests/template/__snapshots__/test_compose.ambr similarity index 100% rename from tests/__snapshots__/test_compose.ambr rename to tests/template/__snapshots__/test_compose.ambr diff --git a/tests/__snapshots__/test_contributing.ambr b/tests/template/__snapshots__/test_contributing.ambr similarity index 100% rename from tests/__snapshots__/test_contributing.ambr rename to tests/template/__snapshots__/test_contributing.ambr diff --git a/tests/__snapshots__/test_dockerfile.ambr b/tests/template/__snapshots__/test_dockerfile.ambr similarity index 100% rename from tests/__snapshots__/test_dockerfile.ambr rename to tests/template/__snapshots__/test_dockerfile.ambr diff --git a/tests/__snapshots__/test_docs.ambr b/tests/template/__snapshots__/test_docs.ambr similarity index 100% rename from tests/__snapshots__/test_docs.ambr rename to tests/template/__snapshots__/test_docs.ambr diff --git a/tests/__snapshots__/test_license.ambr b/tests/template/__snapshots__/test_license.ambr similarity index 100% rename from tests/__snapshots__/test_license.ambr rename to tests/template/__snapshots__/test_license.ambr diff --git a/tests/__snapshots__/test_pyproject.ambr b/tests/template/__snapshots__/test_pyproject.ambr similarity index 100% rename from tests/__snapshots__/test_pyproject.ambr rename to tests/template/__snapshots__/test_pyproject.ambr diff --git a/tests/__snapshots__/test_readme.ambr b/tests/template/__snapshots__/test_readme.ambr similarity index 100% rename from tests/__snapshots__/test_readme.ambr rename to tests/template/__snapshots__/test_readme.ambr diff --git a/tests/__snapshots__/test_workflows.ambr b/tests/template/__snapshots__/test_workflows.ambr similarity index 100% rename from tests/__snapshots__/test_workflows.ambr rename to tests/template/__snapshots__/test_workflows.ambr diff --git a/tests/__snapshots__/test_zensical.ambr b/tests/template/__snapshots__/test_zensical.ambr similarity index 100% rename from tests/__snapshots__/test_zensical.ambr rename to tests/template/__snapshots__/test_zensical.ambr diff --git a/tests/test_agents.py b/tests/template/test_agents.py similarity index 100% rename from tests/test_agents.py rename to tests/template/test_agents.py diff --git a/tests/test_compose.py b/tests/template/test_compose.py similarity index 100% rename from tests/test_compose.py rename to tests/template/test_compose.py diff --git a/tests/test_contributing.py b/tests/template/test_contributing.py similarity index 100% rename from tests/test_contributing.py rename to tests/template/test_contributing.py diff --git a/tests/test_dockerfile.py b/tests/template/test_dockerfile.py similarity index 100% rename from tests/test_dockerfile.py rename to tests/template/test_dockerfile.py diff --git a/tests/test_docs.py b/tests/template/test_docs.py similarity index 100% rename from tests/test_docs.py rename to tests/template/test_docs.py diff --git a/tests/test_license.py b/tests/template/test_license.py similarity index 100% rename from tests/test_license.py rename to tests/template/test_license.py diff --git a/tests/test_pyproject.py b/tests/template/test_pyproject.py similarity index 100% rename from tests/test_pyproject.py rename to tests/template/test_pyproject.py diff --git a/tests/test_readme.py b/tests/template/test_readme.py similarity index 100% rename from tests/test_readme.py rename to tests/template/test_readme.py diff --git a/tests/test_structure.py b/tests/template/test_structure.py similarity index 100% rename from tests/test_structure.py rename to tests/template/test_structure.py diff --git a/tests/test_template.py b/tests/template/test_template.py similarity index 90% rename from tests/test_template.py rename to tests/template/test_template.py index ac0225d..07be68b 100644 --- a/tests/test_template.py +++ b/tests/template/test_template.py @@ -4,7 +4,10 @@ from collections.abc import Callable from pathlib import Path +import pytest + +@pytest.mark.usefixtures("allow_subprocess") def test_template_render_lint_test( render_template: Callable[..., Path], ) -> None: diff --git a/tests/test_version.py b/tests/template/test_version.py similarity index 100% rename from tests/test_version.py rename to tests/template/test_version.py diff --git a/tests/test_workflows.py b/tests/template/test_workflows.py similarity index 100% rename from tests/test_workflows.py rename to tests/template/test_workflows.py diff --git a/tests/test_zensical.py b/tests/template/test_zensical.py similarity index 100% rename from tests/test_zensical.py rename to tests/template/test_zensical.py diff --git a/tests/test_update.py b/tests/test_update.py new file mode 100644 index 0000000..cc66ed2 --- /dev/null +++ b/tests/test_update.py @@ -0,0 +1,429 @@ +"""Tests for the copier-python update management command.""" + +from __future__ import annotations + +import json +import os +import subprocess +import sys +import tempfile +from contextlib import contextmanager +from dataclasses import dataclass, field +from pathlib import Path +from types import SimpleNamespace +from typing import TYPE_CHECKING, Any +from unittest.mock import MagicMock, _Call, call, patch + +import pytest +import yaml + +from copier_python.__main__ import main, update + +if TYPE_CHECKING: + from collections.abc import Callable, Iterator, Sequence + + from .utils import DisallowCallable + + +DEFAULT_BRANCH_NAME = "fhqwhgads" + + +@pytest.fixture(params=["v0.6.0"]) +def start_ref(request: pytest.FixtureRequest) -> str: + return request.param + + +@pytest.fixture(params=["v0.8.0"]) +def end_ref(request: pytest.FixtureRequest) -> str: + return request.param + + +@pytest.fixture +def mock_shell() -> Path: + shell = "/bin/false" + os.environ["SHELL"] = shell + return Path(shell) + + +@pytest.fixture +def origin(tmp_path: Path, disallow_subprocess: DisallowCallable) -> Path: + origin_path = tmp_path / "origin.git" + with disallow_subprocess.pause(): + subprocess.run( # noqa: S603 + [ + "git", + "init", + "--bare", + str(origin_path), + "-b", + DEFAULT_BRANCH_NAME, + ], + check=True, + capture_output=True, + ) + return origin_path + + +@pytest.fixture +def create_project( + disallow_subprocess: DisallowCallable, + render_template: Callable[..., Path], + origin: Path, +) -> Callable[..., Path]: + def _create( + vcs_ref: str, + postcreate: Callable[[Path], None] | None = None, + **kwargs: Any, + ) -> Path: + project = render_template(vcs_ref=vcs_ref) + for cmd in ( + ["git", "init", "-b", DEFAULT_BRANCH_NAME], + ["git", "remote", "add", "origin", str(origin)], + ["git", "add", "."], + ["git", "commit", "-m", "init from template"], + ["git", "push", "-u", "origin", DEFAULT_BRANCH_NAME], + ): + with disallow_subprocess.pause(): + subprocess.run( # noqa: S603 + cmd, cwd=project, check=True, capture_output=True + ) + if postcreate: + postcreate(project) + return project + + return _create + + +@pytest.fixture +def mock_temp_dir(tmp_path: Path) -> Iterator[Path]: + with patch.object( + tempfile.TemporaryDirectory, "__enter__", return_value=tmp_path + ): + yield tmp_path + + +@pytest.fixture +def mock_worktree(mock_temp_dir: Path) -> Path: + wt = mock_temp_dir / "mock_worktree" + wt.mkdir(exist_ok=True) + + with Path.open(wt / ".copier-answers.yml", "w") as f: + yaml.dump({"_commit": "deadbeef"}, f) + return wt + + +@pytest.fixture( + params=( + pytest.param(True, id="dry_run"), + pytest.param(False, id="live"), + ) +) +def dry_run(request: pytest.FixtureRequest) -> bool: + return request.param + + +@dataclass +class ExpectRun: + disallow_subprocess: DisallowCallable + mock_shell: Path + origin: Path + expected_calls: list[_Call] = field(default_factory=list, init=False) + run_mocks: list[tuple[Callable[Sequence[str], bool], object]] = field( + default_factory=list + ) + + @contextmanager + def patch( + self, + *, + mock: bool = True, + has_any: bool = False, + shell_callback: Callable[[], None] | None = None, + ) -> Iterator[MagicMock]: + + subp_run = subprocess.run + + def _run(cmd: Sequence[str], *args: Any, **kwargs: Any) -> Any: + for match_mock, mock_value in self.run_mocks: + if match_mock(cmd): + return mock_value + if cmd[0] == str(self.mock_shell): + if shell_callback: + return shell_callback() + return MagicMock() + if mock: + if tuple(cmd[:2]) == ("copier", "check-update"): + return SimpleNamespace( + stdout=json.dumps( + { + "update_available": False, + "current_version": "11.38.0", + "latest_version": "11.38.0", + } + ) + ) + return MagicMock() + cmd = [ + ( + str(self.origin) + if arg.startswith( + ("https://github.com/", "git@github.com:") + ) + else arg + ) + for arg in cmd + ] + if cmd[0] == "gh": + cmd = ["echo", *cmd] + with self.disallow_subprocess.pause(): + return subp_run(cmd, *args, **kwargs) + + with patch.object(subprocess, "run", side_effect=_run) as mock_run: + yield mock_run + if has_any: + mock_run.assert_has_calls(self.expected_calls) + else: + assert mock_run.call_args_list == self.expected_calls + + def expect(self, cmd: Sequence[str], **kwargs: Any) -> None: + kwargs.setdefault("check", True) + kwargs.setdefault("text", True) + self.expected_calls.append(call(list(cmd), **kwargs)) + + +@pytest.fixture +def expect_run( + disallow_subprocess: DisallowCallable, mock_shell: Path, origin: Path +) -> ExpectRun: + return ExpectRun( + disallow_subprocess=disallow_subprocess, + mock_shell=mock_shell, + origin=origin, + ) + + +def test_main_help(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(sys, "argv", ["copier-python", "--help"]) + with pytest.raises(SystemExit, match=str(0)): + main() + + +def test_main_update_with_project_current( + mock_temp_dir: Path, + expect_run: ExpectRun, + create_project: Callable[..., Path], + end_ref: str, + *, + dry_run: bool, +) -> None: + create_project(vcs_ref=end_ref) + worktree = mock_temp_dir / "worktree" + expect_run.expect( + ["git", "clone", "https://github.com/ness/pkfire", str(worktree)] + ) + for cmd in [ + [ + *["git", "remote", "set-url", "--push", "origin"], + "git@github.com:ness/pkfire.git", + ], + [*["poe", "setup"]], + [*["git", "checkout", "-b", "updates"]], + ]: + expect_run.expect(cmd, cwd=worktree) + expect_run.expect( + ["copier", "check-update", "--output-format", "json"], + cwd=worktree, + capture_output=True, + ) + with expect_run.patch(mock=False): + update(["gh:ness/pkfire"], dry_run=dry_run) + + +def test_main_update_error( + disallow_subprocess: DisallowCallable, + mock_shell: Path, + mock_temp_dir: Path, + expect_run: ExpectRun, + create_project: Callable[..., Path], + start_ref: str, + end_ref: str, +) -> None: + def _postcreate(project: Path) -> None: + (project / "AGENTS.md").write_text("Overwritten file, no newline") + for cmd in ( + ["git", "add", "."], + ["git", "commit", "-m", "Add conflicting change"], + ["git", "push"], + ): + with disallow_subprocess.pause(): + subprocess.run( # noqa: S603 + cmd, cwd=project, check=True, capture_output=True + ) + + create_project(vcs_ref=start_ref, postcreate=_postcreate) + worktree = mock_temp_dir / "worktree" + commit_message = os.linesep.join( + ( + "Apply template updates", + "", + (f"Applied updates from template: {start_ref}...{end_ref}"), + ( + f"https://github.com/ness/pkfire/compare/" + f"{start_ref}...{end_ref}" + ), + ) + ) + + def _shell() -> None: + with disallow_subprocess.pause(): + for cmd in ( + ["git", "checkout", "--ours", "AGENTS.md"], + ["git", "reset", "AGENTS.md"], + ): + subprocess.check_call(cmd, cwd=worktree) # noqa: S603 + + expect_run.expect( + ["git", "clone", "https://github.com/ness/pkfire", str(worktree)] + ) + for cmd in [ + [ + *["git", "remote", "set-url", "--push", "origin"], + "git@github.com:ness/pkfire.git", + ], + [*["poe", "setup"]], + [*["git", "checkout", "-b", "updates"]], + ]: + expect_run.expect(cmd, cwd=worktree) + expect_run.expect( + ["copier", "check-update", "--output-format", "json"], + cwd=worktree, + capture_output=True, + ) + expect_run.expect(["copier", "update", "-l"], cwd=worktree) + expect_run.expect( + ["git", "status", "--porcelain"], cwd=worktree, capture_output=True + ) + expect_run.expect([str(mock_shell)], cwd=worktree, check=False) + expect_run.expect( + ["git", "status", "--porcelain"], cwd=worktree, capture_output=True + ) + expect_run.expect(["uv", "run", "poe", "lt"], cwd=worktree) + expect_run.expect([str(mock_shell)], cwd=worktree, check=False) + for cmd in [ + [*["git", "add", "-A"]], + [*["git", "commit", "-m", commit_message]], + ]: + expect_run.expect(cmd, cwd=worktree) + + with expect_run.patch(mock=False, shell_callback=_shell): + update(["gh:ness/pkfire"], dry_run=True) + + +@pytest.mark.parametrize( + ("arg", "repo"), + [ + pytest.param(arg, repo, id=arg) + for repo in ["ness/pkfire", "Ness/PK_Thunder-Alpha"] + for base in [ + *[repo, f"gh:{repo}", f"github.com/{repo}"], + *[ + f"{proto}{user}@github.com:{repo}" + for user in ("git", repo.split("/", 1)[0]) + for proto in ("", "ssh://", "git+ssh://") + ], + *[f"{proto}://github.com/{repo}" for proto in ("http", "https")], + ] + for base_dotgit in [base, f"{base}.git"] + for arg in [base_dotgit, f"{base_dotgit}/"] + ], +) +def test_main_update_repo_arguments( + mock_temp_dir: Path, expect_run: ExpectRun, arg: str, repo: str +) -> None: + expect_run.expect( + [ + "git", + "clone", + f"https://github.com/{repo}", + str(mock_temp_dir / "worktree"), + ] + ) + with expect_run.patch(mock=True, has_any=True): + update([arg], dry_run=True) + + +def test_main_update_with_project( + mock_temp_dir: Path, + expect_run: ExpectRun, + create_project: Callable[..., Path], + end_ref: str, + start_ref: str, + *, + dry_run: bool, +) -> None: + create_project(vcs_ref=start_ref) + worktree = mock_temp_dir / "worktree" + commit_message = os.linesep.join( + ( + "Apply template updates", + "", + (f"Applied updates from template: {start_ref}...{end_ref}"), + ( + f"https://github.com/ness/pkfire/compare/" + f"{start_ref}...{end_ref}" + ), + ) + ) + + expect_run.expect( + ["git", "clone", "https://github.com/ness/pkfire", str(worktree)] + ) + for cmd in [ + [ + *["git", "remote", "set-url", "--push", "origin"], + "git@github.com:ness/pkfire.git", + ], + [*["poe", "setup"]], + [*["git", "checkout", "-b", "updates"]], + ]: + expect_run.expect(cmd, cwd=worktree) + expect_run.expect( + ["copier", "check-update", "--output-format", "json"], + cwd=worktree, + capture_output=True, + ) + expect_run.expect(["copier", "update", "-l"], cwd=worktree) + expect_run.expect( + ["git", "status", "--porcelain"], cwd=worktree, capture_output=True + ) + for cmd in [ + [*["uv", "run", "poe", "lt"]], + [*["git", "add", "-A"]], + [*["git", "commit", "-m", commit_message]], + ]: + expect_run.expect(cmd, cwd=worktree) + if not dry_run: + for cmd in [ + ["git", "push", "-u", "origin", "updates", "--force-with-lease"], + ]: + expect_run.expect(cmd, cwd=worktree) + + title, _, *body = commit_message.splitlines() + expect_run.expect( + [ + "gh", + "pr", + "create", + "--title", + title, + "--body", + os.linesep.join(body), + "--head", + "updates", + ], + cwd=worktree, + check=False, + capture_output=True, + ) + with expect_run.patch(mock=False): + update(["gh:ness/pkfire"], dry_run=dry_run) diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 0000000..d050e1c --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +import sys +from contextlib import contextmanager +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any +from unittest.mock import MagicMock, patch + +if sys.version_info >= (3, 11): + from typing import Self +else: + from typing_extensions import Self + +if TYPE_CHECKING: + from collections.abc import Callable, Iterator + + import pytest + + +@dataclass +class DisallowCallable: + request: pytest.FixtureRequest + obj: object + attr: str + original: Callable[..., Any] = field(init=False) + enabled: bool = field(default=False, init=False) + mock_attr: MagicMock | None = field(default=None, init=False) + + @dataclass + class DisallowedError(Exception): + mock: DisallowCallable + + def __str__(self) -> str: + fn = self.mock.original + name = f"{fn.__module__}." + ( + getattr(fn, "__qualname__", None) + or getattr(fn, "__name__", "(unknown)") + ) + return f"`{name}` disallowed via {self.mock.request.fixturename}" + + def __post_init__(self) -> None: + self.original = getattr(self.obj, self.attr) + + @contextmanager + def __call__(self) -> Iterator[Self]: + + def wrapper(*args: Any, **kwargs: Any) -> Any: + if self.enabled: + return self.original(*args, **kwargs) + raise self.DisallowedError(self) + + with patch.object( + self.obj, self.attr, autospec=True, side_effect=wrapper + ) as mock_attr: + self.mock_attr = mock_attr + yield self + + @contextmanager + def pause(self) -> Iterator[None]: + if self.mock_attr: + with patch.object(self.obj, self.attr, self.original): + yield + self.mock_attr.assert_not_called()