diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000..d3c3021 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,12 @@ +version: 2 +updates: + - package-ecosystem: "pip" + directory: "/" + schedule: + interval: "weekly" + open-pull-requests-limit: 5 + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "weekly" + open-pull-requests-limit: 5 diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..76db5b7 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,159 @@ +name: CI + +on: + pull_request: + branches: + - main + - "codex/**" + push: + branches: + - main + - "codex/**" + +permissions: + contents: read + +env: + PIP_NO_CACHE_DIR: "1" + +jobs: + test: + name: pytest (${{ matrix.os }}, py${{ matrix.python-version }}) + runs-on: ${{ matrix.os }} + timeout-minutes: 30 + strategy: + fail-fast: false + matrix: + os: + - ubuntu-latest + - macos-latest + - windows-2022 + python-version: + - "3.10" + - "3.11" + - "3.12" + - "3.13" + - "3.14" + + steps: + - name: Check out repository + uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # v5.0.1 + with: + fetch-depth: 0 + + - name: Set up Python + uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0 + with: + python-version: ${{ matrix.python-version }} + + - name: Install package and test dependencies + run: | + python -m pip install --upgrade pip + python -m pip install -e ".[dev]" + + - name: Run tests + run: python -m pytest -q + + lint: + name: ruff + runs-on: ubuntu-latest + timeout-minutes: 10 + + steps: + - name: Check out repository + uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # v5.0.1 + with: + fetch-depth: 0 + + - name: Set up Python + uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0 + with: + python-version: "3.12" + + - name: Install lint dependencies + run: | + python -m pip install --upgrade pip + python -m pip install -e ".[dev]" + + - name: Run ruff + run: ruff check . + + supply-chain: + name: supply chain + runs-on: ubuntu-latest + timeout-minutes: 20 + + steps: + - name: Check out repository + uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # v5.0.1 + with: + fetch-depth: 0 + + - name: Set up Python + uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0 + with: + python-version: "3.12" + + - name: Install package and supply-chain tools + run: | + python -m pip install --upgrade pip + python -m pip install ".[dev]" + python -m pip install cyclonedx-bom pip-audit + + - name: Check installed dependency consistency + run: python -m pip check + + - name: Run vulnerability advisory scan + run: python -m pip_audit --strict . --progress-spinner off --cache-dir .pip-audit-cache + + - name: Generate CycloneDX SBOM + run: | + cyclonedx-py environment "$(which python)" \ + --pyproject pyproject.toml \ + --mc-type library \ + --output-reproducible \ + --of JSON \ + -o codebase-graph-sbom.cdx.json + + package: + name: package + runs-on: ubuntu-latest + timeout-minutes: 20 + + steps: + - name: Check out repository + uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # v5.0.1 + with: + fetch-depth: 0 + + - name: Set up Python + uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0 + with: + python-version: "3.12" + + - name: Build distributions + run: | + python -m pip install --upgrade pip + python -m pip install build twine + python -m build + python -m twine check dist/* + + - name: Smoke-test built wheel + shell: bash + run: | + python -m venv /tmp/codebase-graph-wheel + /tmp/codebase-graph-wheel/bin/python -m pip install --upgrade pip + /tmp/codebase-graph-wheel/bin/python -m pip install dist/*.whl + /tmp/codebase-graph-wheel/bin/codebase-graph --help + /tmp/codebase-graph-wheel/bin/codebase-graph-mcp --help + /tmp/codebase-graph-wheel/bin/python scripts/smoke_built_wheel.py /tmp/codebase-graph-wheel/bin/codebase-graph + + - name: Smoke-test source distribution + shell: bash + run: | + python -m venv /tmp/codebase-graph-sdist + /tmp/codebase-graph-sdist/bin/python -m pip install --upgrade pip + /tmp/codebase-graph-sdist/bin/python -m pip install dist/*.tar.gz + /tmp/codebase-graph-sdist/bin/codebase-graph --help + /tmp/codebase-graph-sdist/bin/codebase-graph-mcp --help + /tmp/codebase-graph-sdist/bin/python scripts/smoke_built_wheel.py /tmp/codebase-graph-sdist/bin/codebase-graph diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 0000000..a293ceb --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,291 @@ +name: Release + +on: + push: + branches: + - main + workflow_dispatch: + inputs: + pypi-environment-smoke: + description: Verify the pypi environment and OIDC claims without publishing. + required: false + type: boolean + default: false + +permissions: + contents: read + +env: + PIP_NO_CACHE_DIR: "1" + +jobs: + pypi-environment-smoke: + name: pypi environment smoke + if: ${{ github.event_name == 'workflow_dispatch' && inputs.pypi-environment-smoke }} + runs-on: ubuntu-latest + timeout-minutes: 10 + environment: + name: pypi + permissions: + contents: read + id-token: write + + steps: + - name: Verify PyPI environment OIDC claims + shell: bash + run: | + python - <<'PY' + import base64 + import json + import os + import urllib.request + + request_url = os.environ["ACTIONS_ID_TOKEN_REQUEST_URL"] + separator = "&" if "?" in request_url else "?" + request = urllib.request.Request( + f"{request_url}{separator}audience=pypi", + headers={"Authorization": f"bearer {os.environ['ACTIONS_ID_TOKEN_REQUEST_TOKEN']}"}, + ) + with urllib.request.urlopen(request, timeout=30) as response: + token = json.load(response)["value"] + + payload = token.split(".")[1] + payload += "=" * (-len(payload) % 4) + claims = json.loads(base64.urlsafe_b64decode(payload)) + + expected_claims = { + "aud": "pypi", + "environment": "pypi", + "repository": os.environ["GITHUB_REPOSITORY"], + "repository_owner": os.environ["GITHUB_REPOSITORY_OWNER"], + "workflow": "Release", + } + for claim, expected in expected_claims.items(): + actual = claims.get(claim) + if actual != expected: + raise SystemExit(f"OIDC claim {claim!r} expected {expected!r}, got {actual!r}") + + expected_workflow_ref_prefix = f"{os.environ['GITHUB_REPOSITORY']}/.github/workflows/release.yml@" + workflow_ref = claims.get("workflow_ref") + if not isinstance(workflow_ref, str) or not workflow_ref.startswith(expected_workflow_ref_prefix): + raise SystemExit( + f"OIDC claim 'workflow_ref' expected to start with {expected_workflow_ref_prefix!r}, " + f"got {workflow_ref!r}" + ) + + print("pypi environment OIDC claims verified") + PY + + release-please: + name: release please + if: ${{ !inputs.pypi-environment-smoke }} + runs-on: ubuntu-latest + timeout-minutes: 10 + permissions: + contents: write + pull-requests: write + outputs: + release-created: ${{ steps.release.outputs.release_created }} + tag-name: ${{ steps.release.outputs.tag_name }} + version: ${{ steps.release.outputs.version }} + + steps: + - name: Create release pull request or GitHub release + id: release + uses: googleapis/release-please-action@45996ed1f6d02564a971a2fa1b5860e934307cf7 # v5.0.0 + with: + token: ${{ secrets.RELEASE_PLEASE_TOKEN || github.token }} + config-file: release-please-config.json + manifest-file: .release-please-manifest.json + + production-gate: + name: production release gate + needs: release-please + if: needs.release-please.outputs.release-created == 'true' + runs-on: ubuntu-latest + timeout-minutes: 10 + environment: + name: pypi + permissions: + contents: read + + env: + RELEASE_TAG: ${{ needs.release-please.outputs.tag-name }} + CODEBASE_GRAPH_CONFIRM_TRUSTED_PUBLISHER: ${{ vars.CODEBASE_GRAPH_CONFIRM_TRUSTED_PUBLISHER }} + CODEBASE_GRAPH_CONFIRM_PYPI_ENVIRONMENT: ${{ vars.CODEBASE_GRAPH_CONFIRM_PYPI_ENVIRONMENT }} + CODEBASE_GRAPH_CONFIRM_HOSTED_CI_GREEN: ${{ vars.CODEBASE_GRAPH_CONFIRM_HOSTED_CI_GREEN }} + CODEBASE_GRAPH_CONFIRM_PRIVATE_VULNERABILITY_REPORTING: ${{ vars.CODEBASE_GRAPH_CONFIRM_PRIVATE_VULNERABILITY_REPORTING }} + CODEBASE_GRAPH_REQUIRE_CONDA: ${{ vars.CODEBASE_GRAPH_REQUIRE_CONDA }} + + steps: + - name: Check out release tag + uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # v5.0.1 + with: + ref: ${{ env.RELEASE_TAG }} + fetch-depth: 0 + + - name: Run production release gate + shell: bash + run: | + args=(--production) + if [[ "${CODEBASE_GRAPH_CONFIRM_TRUSTED_PUBLISHER}" == "true" ]]; then + args+=(--confirm trusted-publisher) + fi + if [[ "${CODEBASE_GRAPH_CONFIRM_PYPI_ENVIRONMENT}" == "true" ]]; then + args+=(--confirm pypi-environment) + fi + if [[ "${CODEBASE_GRAPH_CONFIRM_HOSTED_CI_GREEN}" == "true" ]]; then + args+=(--confirm hosted-ci-green) + fi + if [[ "${CODEBASE_GRAPH_CONFIRM_PRIVATE_VULNERABILITY_REPORTING}" == "true" ]]; then + args+=(--confirm private-vulnerability-reporting) + fi + if [[ "${CODEBASE_GRAPH_REQUIRE_CONDA}" == "true" ]]; then + args+=(--require-conda) + fi + python scripts/check_release_gate.py "${args[@]}" + + build: + name: build release distributions + needs: + - release-please + - production-gate + if: needs.release-please.outputs.release-created == 'true' + runs-on: ubuntu-latest + timeout-minutes: 30 + permissions: + contents: write + outputs: + package-version: ${{ steps.verify-version.outputs.package-version }} + + env: + RELEASE_TAG: ${{ needs.release-please.outputs.tag-name }} + + steps: + - name: Check out release tag + uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # v5.0.1 + with: + ref: ${{ env.RELEASE_TAG }} + fetch-depth: 0 + + - name: Set up Python + uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0 + with: + python-version: "3.12" + + - name: Build and validate distributions + run: | + python -m pip install --upgrade pip + python -m pip install build twine + python -m build + python -m twine check dist/* + + - name: Verify package version matches tag + id: verify-version + shell: bash + run: | + python - <<'PY' + import email + import glob + import os + import re + import tarfile + import zipfile + + tag = os.environ["RELEASE_TAG"] + match = re.fullmatch(r"v(?P\d+\.\d+\.\d+)", tag) + if match is None: + raise SystemExit(f"release tag must match vX.Y.Z, got {tag!r}") + expected = match.group("version") + + versions = set() + for artifact in glob.glob("dist/*"): + if artifact.endswith(".whl"): + with zipfile.ZipFile(artifact) as wheel: + metadata_name = next(name for name in wheel.namelist() if name.endswith(".dist-info/METADATA")) + metadata = email.message_from_bytes(wheel.read(metadata_name)) + elif artifact.endswith(".tar.gz"): + with tarfile.open(artifact, "r:gz") as sdist: + metadata_member = next(member for member in sdist.getmembers() if member.name.endswith("/PKG-INFO")) + metadata = email.message_from_binary_file(sdist.extractfile(metadata_member)) + else: + raise SystemExit(f"unexpected distribution artifact: {artifact}") + versions.add(metadata["Version"]) + + if versions != {expected}: + raise SystemExit(f"built package versions {sorted(versions)} do not match release tag {tag}") + + with open(os.environ["GITHUB_OUTPUT"], "a", encoding="utf-8") as output: + output.write(f"package-version={expected}\n") + PY + + - name: Smoke-test built wheel + shell: bash + run: | + python -m venv /tmp/codebase-graph-wheel + /tmp/codebase-graph-wheel/bin/python -m pip install --upgrade pip + /tmp/codebase-graph-wheel/bin/python -m pip install dist/*.whl + /tmp/codebase-graph-wheel/bin/python -m pip install cyclonedx-bom pip-audit + /tmp/codebase-graph-wheel/bin/codebase-graph --help + /tmp/codebase-graph-wheel/bin/codebase-graph-mcp --help + /tmp/codebase-graph-wheel/bin/python scripts/smoke_built_wheel.py /tmp/codebase-graph-wheel/bin/codebase-graph + /tmp/codebase-graph-wheel/bin/python -m pip check + /tmp/codebase-graph-wheel/bin/python -m pip_audit --strict . --progress-spinner off --cache-dir /tmp/pip-audit-cache + /tmp/codebase-graph-wheel/bin/cyclonedx-py environment /tmp/codebase-graph-wheel/bin/python \ + --pyproject pyproject.toml \ + --mc-type library \ + --output-reproducible \ + --of JSON \ + -o codebase-graph-${{ steps.verify-version.outputs.package-version }}-sbom.cdx.json + + - name: Smoke-test source distribution + shell: bash + run: | + python -m venv /tmp/codebase-graph-sdist + /tmp/codebase-graph-sdist/bin/python -m pip install --upgrade pip + /tmp/codebase-graph-sdist/bin/python -m pip install dist/*.tar.gz + /tmp/codebase-graph-sdist/bin/codebase-graph --help + /tmp/codebase-graph-sdist/bin/codebase-graph-mcp --help + /tmp/codebase-graph-sdist/bin/python scripts/smoke_built_wheel.py /tmp/codebase-graph-sdist/bin/codebase-graph + + - name: Upload distributions to GitHub release + env: + GH_TOKEN: ${{ github.token }} + run: gh release upload "$RELEASE_TAG" dist/* codebase-graph-${{ steps.verify-version.outputs.package-version }}-sbom.cdx.json --clobber + + publish-pypi: + name: publish to PyPI + needs: + - release-please + - build + if: needs.release-please.outputs.release-created == 'true' + runs-on: ubuntu-latest + timeout-minutes: 10 + environment: + name: pypi + url: https://pypi.org/p/cbasegraph + permissions: + contents: read + id-token: write + env: + GH_TOKEN: ${{ github.token }} + RELEASE_TAG: ${{ needs.release-please.outputs.tag-name }} + + steps: + - name: Download distributions from GitHub release + shell: bash + run: | + mkdir -p dist + gh release download "$RELEASE_TAG" --dir dist --pattern "*.whl" --pattern "*.tar.gz" + python - <<'PY' + from pathlib import Path + + artifacts = sorted(path.name for path in Path("dist").iterdir()) + if not any(name.endswith(".whl") for name in artifacts): + raise SystemExit(f"release {artifacts=} does not include a wheel") + if not any(name.endswith(".tar.gz") for name in artifacts): + raise SystemExit(f"release {artifacts=} does not include a source distribution") + PY + + - name: Publish distributions to PyPI + uses: pypa/gh-action-pypi-publish@cef221092ed1bacb1cc03d23a2d87d1d172e277b # release/v1 diff --git a/.gitignore b/.gitignore index 83972fa..ea0776e 100644 --- a/.gitignore +++ b/.gitignore @@ -206,6 +206,10 @@ tempCodeRunnerFile.py # Ruff stuff: .ruff_cache/ +# codebaseGraph local graph state +.codebaseGraph/ +.codebase_graph/ + # PyPI configuration file .pypirc diff --git a/.release-please-manifest.json b/.release-please-manifest.json new file mode 100644 index 0000000..e18ee07 --- /dev/null +++ b/.release-please-manifest.json @@ -0,0 +1,3 @@ +{ + ".": "0.0.0" +} diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..fc9f9b8 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,13 @@ + +## codebaseGraph workflow +- Treat the repo-local `.codebaseGraph` graph as the project operating source of truth. It is prohibited to read the code source before you find the target files using the graph. +- AI agents must use block format for `graph-search` and `graph-context`; reserve `--json` for tests, APIs, or explicit structured-payload debugging. +- Use `/Users/rabii/Projects/Repositories/codebaseGraph/.venv/bin/codebase-graph graph-search --repo-root . --no-refresh --detail slim --context-limit 1 --format block` before answering repo-structure questions or performing coding tasks. +- Use `/Users/rabii/Projects/Repositories/codebaseGraph/.venv/bin/codebase-graph graph-context --repo-root . --profile --no-refresh --detail slim --context-limit 2 --format block` when relationships or nearby evidence matter; useful profiles include `definitions`, `dependencies`, `callgraph`, `docs`, `runtime`, and `change_impact`. +- For architecture orientation, run `/Users/rabii/Projects/Repositories/codebaseGraph/.venv/bin/codebase-graph graph-architecture-queries`, then execute selected read-only statements with `/Users/rabii/Projects/Repositories/codebaseGraph/.venv/bin/codebase-graph graph-query "" --repo-root .`. +- Use `/Users/rabii/Projects/Repositories/codebaseGraph/.venv/bin/codebase-graph graph-schema` or `/Users/rabii/Projects/Repositories/codebaseGraph/.venv/bin/codebase-graph graph-query-helpers` before writing raw graph queries, add `--pretty` for indented JSON when humans need to inspect output, and keep `graph-query` read-only. +- Refresh the graph with `/Users/rabii/Projects/Repositories/codebaseGraph/.venv/bin/codebase-graph setup --repo-root . --mcp-client none` when files change materially. Setup config: `/Users/rabii/Projects/Repositories/codebaseGraph/.codebaseGraph/config.json`. + + +## Git Commit Convention +- When you finish your coding task, strictly use Conventional Commits 1.0.0 for commit message and commit your changes. diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..0017fc7 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,3 @@ +# Changelog + +Release notes are managed by release-please. diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..82f2e19 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2026 Rabii Chaarani + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000..10b457c --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,2 @@ +include SECURITY.md +include scripts/check_release_gate.py diff --git a/README.md b/README.md index a39482f..c8c79a5 100644 --- a/README.md +++ b/README.md @@ -1,34 +1,150 @@ # codebaseGraph -`codebase_graph` is a generic project/code knowledge graph engine for Python repositories. It scans a source root, builds a typed graph of files, modules, symbols, imports, calls, dependencies, entry points, and documentation sources, and exposes search, compact context, schema, and read-only query helpers. +`codebaseGraph` builds a repo-local knowledge graph for coding agents. It materializes Python source, `AGENTS.md`, +`CLAUDE.md`, Markdown, and MDX files into a LadyBugDB-backed graph, then exposes search, compact context, schema, query +helpers, and read-only MCP tools. -## Install for local development +Using `codebaseGraph` helps agents orient and reason faster, reduce guesswork, keep prompts focused, and make changes with better +impact awareness. Because the graph stores local source, documentation, spans, and relationships together, it gives +AI agents a compact evidence layer for safer edits, architecture review, dependency tracing, and onboarding while reducing token consumption and tool calling. + +Requires Python 3.10+ + +## Quick start ```bash -python -m pip install -e .[dev] +python -m pip install cbasegraph +codebase-graph setup --repo-root . +codebase-graph graph-search SampleService --repo-root . --no-refresh --format block +``` + +Setup creates: + +```text +.codebaseGraph/ + config.json + manifest.json + _graph.ldb +``` + +For a repository named `my-service`, the database path is `.codebaseGraph/my-service_graph.ldb`. + +The setup command materializes the graph, writes or updates one marked codebaseGraph block in `AGENTS.md` or +`CLAUDE.md`, and installs a Codex MCP client entry unless skipped. + +Useful setup options: + +```bash +codebase-graph setup --repo-root /path/to/repo +codebase-graph setup --mcp-client claude +codebase-graph setup --mcp-client lmstudio +codebase-graph setup --skip-mcp-config +codebase-graph setup --instructions-target claude +codebase-graph setup --dry-run --pretty +``` + +## MCP install + +```bash +codebase-graph mcp install --client codex +``` + +Supported clients are `codex`, `claude`, `claude-project`, `lmstudio`, `hermes`, `openclaw`, and `generic`. + +Server naming: + +- `codebase-graph setup` installs the default MCP server as `codebase_graph`. +- Standalone `codebase-graph mcp install` defaults to `codebase_graph_`. +- Use `--name codebase_graph` to override the standalone installer name. + +The installer builds the server descriptor from `.codebaseGraph/config.json`, uses a supported native client CLI when +available, and falls back to writing the client config file directly. Use `--dry-run --json` to inspect the emitted +command or config patch before writing, and `--verify` to run a stdio smoke test after installation. + +```bash +codebase-graph mcp install --client claude --scope user +codebase-graph mcp install --client claude-project +codebase-graph mcp install --client all --dry-run --json +codebase-graph mcp install --config-path /path/to/.codebaseGraph/config.json +codebase-graph mcp install --verify +``` + +## MCP usage + +Stdio is the default transport for local MCP clients: + +```bash +codebase-graph mcp serve --config .codebaseGraph/config.json +codebase-graph-mcp --config .codebaseGraph/config.json ``` -## Basic usage +HTTP is available for local endpoint clients: + +```bash +codebase-graph mcp http --config .codebaseGraph/config.json --host 127.0.0.1 --port 8765 +``` + +Keep HTTP bound to `127.0.0.1` for normal use. Remote binding requires `--allow-remote` and a bearer token, but does not +provide TLS, rate limiting, authorization scopes, or a multi-user security model. HTTP clients must initialize first and +send the returned `Mcp-Session-Id` header on later requests. + +Available MCP tools: -```python -from codebase_graph import CodebaseGraph +- `graph_health` +- `graph_search` +- `graph_context` +- `graph_schema` +- `graph_query_helpers` +- `graph_architecture_queries` +- `graph_query` with write-like statements blocked -graph = CodebaseGraph(source_root=".", state_dir=".codebase_graph/graph") -graph.materialize() -graph.search("FastAPI routes") -graph.context("SomeClass") -graph.cypher("MATCH (n:PythonClass) RETURN n.label LIMIT 5") +`graph_query` returns at most 1,000 rows per call. Add a narrower `MATCH` pattern or a query-side `LIMIT` for broader +graph exploration. + +## CLI workflow + +The CLI mirrors the MCP tools for clients that do not surface MCP directly: + +```bash +codebase-graph graph-health --repo-root . +codebase-graph graph-context SampleService --repo-root . --profile definitions --format block +codebase-graph graph-query "MATCH (n) RETURN count(n) AS total_nodes LIMIT 1" --repo-root . ``` -## CLI +Use `--format block` for agent-facing output and `--json --pretty` for structured inspection. Retrieval commands also +support `--detail standard|slim`; `slim` drops score diagnostics and duplicate or empty summary fields. + +For coding-task architecture orientation, call `graph_architecture_queries` first, then run selected statements with +`graph_query`. + +## Development ```bash -codebase-graph status --source-root . -codebase-graph materialize --source-root . -codebase-graph schema -codebase-graph search "query" -codebase-graph context "SomeClass" -codebase-graph cypher "MATCH (n:PythonClass) RETURN n.label LIMIT 5" +python -m pip install -e .[dev] +python -m pytest +ruff check . ``` -The base package is intentionally small and importable without optional graph database or parquet bindings. Optional storage backends can be installed through extras as they mature. +## Release and security + +CI runs pytest across Linux, macOS, and Windows for Python 3.10 through 3.14, plus ruff, package-build checks, +supply-chain validation, and smoke tests. See [docs/release.md](docs/release.md) for the full release process and +conda-forge checklist. + +Report suspected vulnerabilities privately. See [SECURITY.md](SECURITY.md) for supported versions, reporting +expectations, and the local-first MCP security boundary. + +## Troubleshooting + +- Missing LadyBugDB: install a package build that includes `real_ladybug`; setup fails before creating `.codebaseGraph` + if the runtime cannot open a graph database. +- Stale graph: rerun `codebase-graph setup --repo-root .` after material source or documentation changes. +- Broken client config: rerun `codebase-graph mcp install --client --verify`. +- PATH or executable issues: run setup from the virtual environment that contains `codebase-graph`; the descriptor + prefers that absolute executable path. +- Unsupported files: binary, vendor, cache, virtualenv, build, dist, `.codebase_graph`, and `.codebaseGraph` paths are + skipped. +- Lock errors: stop other graph materialization or setup processes using the same + `.codebaseGraph/_graph.ldb`. Stale locks with dead writer PIDs are removed automatically; if the error + remains, inspect the `.ldb.lock` file before removing it manually. +- Diagnostics: set `CODEBASE_GRAPH_LOG_LEVEL=INFO` to include setup start/completion events on stderr. diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 0000000..45168f1 --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,31 @@ +# Security Policy + +## Supported Versions + +Security fixes are prepared against the current `main` branch and included in the next package release. + +## Reporting a Vulnerability + +Report suspected vulnerabilities privately through GitHub security advisories or private vulnerability reporting for this repository. Do not open a public issue for exploitable behavior, dependency vulnerabilities with an available proof of concept, credential exposure, or a bypass of the read-only MCP/query contract. + +Include: + +- affected version or commit +- reproduction steps +- expected impact +- relevant logs or proof of concept +- whether the report can be disclosed publicly after a fix is available + +Maintainers should acknowledge reports within 7 days, triage severity, and coordinate disclosure timing with the reporter. + +## Security Scope + +The production security boundary is local-first: + +- The stdio MCP transport is intended for local MCP clients. +- The HTTP MCP transport binds to localhost by default. +- `--allow-remote` requires a bearer token. It does not add TLS, rate limiting, authorization scopes, or a multi-user session model. +- HTTP tool calls require an initialized `Mcp-Session-Id`; one client's initialize request must not unlock tools for another client. +- `graph_query` is intended to remain read-only. Do not relax query restrictions without a parser-level read-only proof or an explicit safe-procedure allowlist. + +Dependency vulnerability scanning runs in hosted CI and release workflows. Local setup commands must not call external advisory services implicitly. diff --git a/conda-forge/recipe/meta.yaml b/conda-forge/recipe/meta.yaml new file mode 100644 index 0000000..157e193 --- /dev/null +++ b/conda-forge/recipe/meta.yaml @@ -0,0 +1,52 @@ +{% set name = "codebase-graph" %} +{% set pypi_name = "cbasegraph" %} +{% set version = "PUT_RELEASE_VERSION_HERE" %} +{% set python_min = "3.10" %} + +package: + name: {{ name|lower }} + version: {{ version }} + +source: + url: https://pypi.org/packages/source/{{ pypi_name[0] }}/{{ pypi_name }}/{{ pypi_name }}-{{ version }}.tar.gz + sha256: PUT_RELEASE_SDIST_SHA256_HERE + +build: + noarch: python + number: 0 + script: {{ PYTHON }} -m pip install . -vv --no-deps --no-build-isolation + entry_points: + - codebase-graph = codebase_graph.cli:main + - codebase-graph-mcp = codebase_graph.mcp.server:main + +requirements: + host: + - python {{ python_min }} + - pip + - setuptools >=77 + - setuptools-scm >=8 + - wheel + run: + - python >={{ python_min }} + - real-ladybug >=0.15.3,<0.16 + - tomli >=2.0.1 # [py<311] + - tree-sitter >=0.25.2,<0.26 + - tree-sitter-python >=0.25.0,<0.26 + +test: + imports: + - codebase_graph + commands: + - codebase-graph --help + - codebase-graph-mcp --help + requires: + - python {{ python_min }} + +about: + home: https://github.com/rabii-chaarani/codebaseGraph + summary: Generic codebase knowledge graph engine for Python projects. + license: MIT + +extra: + recipe-maintainers: + - rabii-chaarani diff --git a/docs/graph_output_token_comparison.md b/docs/graph_output_token_comparison.md new file mode 100644 index 0000000..8f66332 --- /dev/null +++ b/docs/graph_output_token_comparison.md @@ -0,0 +1,37 @@ +# Graph Search Output Token Comparison + +## Method +- Raw format: compact JSON emitted by the current graph-search payload serializer, counted from the exact serialized JSON text with sorted keys and compact separators. +- Ontology-preserving block format: grouped `file path` blocks with readable `Class`, `Method`, `Scope`, relation, `label`, `span`, `id`, and `rank_score` terms left literal. +- Tokenizer/model: encoding `o200k_base`. +- Count method: payload-only tokens using `len(encoding.encode(text))`; chat-message wrapper tokens were not included. + +## Results +| Query | Results | Context edges | Raw tokens | Block tokens | Saved tokens | Reduction % | +|---|---:|---:|---:|---:|---:|---:| +| SearchService | 3 | 6 | 502 | 234 | 268 | 53.4% | + +## Aggregate Summary +- Samples: 1 +- Total raw tokens: 502 +- Total block tokens: 234 +- Total saved tokens: 268 +- Overall reduction: 53.4% +- Mean reduction: 53.4% +- Median reduction: 53.4% +- Min reduction: SearchService (53.4%) +- Max reduction: SearchService (53.4%) +- p90 raw/block tokens: not reported because fewer than 10 samples were compared + +## Ontology Preservation +The validator normalizes raw JSON and block output into canonical result records preserving `type`, `label`, `path`, `span`, `id`, `rank_score`, and ordered context records with `direction`, `relation`, `type`, `label`, `path`, `span`, and non-boilerplate `summary`. + +Intentional omissions: +- `results[0].context[0].summary` +- `results[1].context[0].summary` +- `results[2].context[0].summary` + +Known limitations: the block parser validates the supported graph-search fixture shape and live graph-search output shape; it is not a general-purpose parser for hand-written variants. + +## Recommendation +Use the ontology-preserving block format by default for agent-facing graph-search output when consumers need readable context. Keep JSON available for machine APIs and tests that require strict structured payloads. diff --git a/docs/release.md b/docs/release.md new file mode 100644 index 0000000..86c8d00 --- /dev/null +++ b/docs/release.md @@ -0,0 +1,90 @@ +# Release Process + +`codebaseGraph` releases are managed by release-please. The release workflow opens and maintains a release pull request from Conventional Commit history. When that release pull request is merged, release-please creates the `vX.Y.Z` tag and GitHub Release, then the same workflow builds the source distribution and wheel from that tag, verifies that the package metadata version matches the tag, attaches the distributions to the GitHub Release, and publishes to PyPI with Trusted Publishing. + +## One-time PyPI setup + +Configure a PyPI Trusted Publisher for: + +- PyPI project: `cbasegraph` +- Owner/repository: `rabii-chaarani/codebaseGraph` +- Workflow: `release.yml` +- Environment: `pypi` + +Create the `pypi` GitHub environment before the first release. Use required reviewers on that environment when release approval should be manual. + +Set these `pypi` environment variables to `true` only after the corresponding owner-controlled gate is verified: + +- `CODEBASE_GRAPH_CONFIRM_TRUSTED_PUBLISHER` +- `CODEBASE_GRAPH_CONFIRM_PYPI_ENVIRONMENT` +- `CODEBASE_GRAPH_CONFIRM_HOSTED_CI_GREEN` +- `CODEBASE_GRAPH_CONFIRM_PRIVATE_VULNERABILITY_REPORTING` +- `CODEBASE_GRAPH_REQUIRE_CONDA`, only when conda-forge publication is part of the release + +The release workflow runs `scripts/check_release_gate.py --production` in the protected `pypi` environment before building +or publishing release distributions. If one of these variables is missing or the repository-local gates fail, the release +stops before any package is uploaded. + +## CI + +Pull requests and pushes to `main` or `codex/**` run: + +- `pytest` on Linux, macOS, and Windows for Python 3.10 through 3.14. +- `ruff check .` on Linux. +- Supply-chain checks on Linux with `pip check`, `pip-audit --strict` vulnerability advisory scanning, immutable + GitHub Action pins, and CycloneDX SBOM generation. +- A package build on Linux with `python -m build`, `twine check`, console-script smoke tests from the built wheel and + source distribution, packaged runtime smoke that runs `setup`, `graph-health`, `graph-search`, and stdio MCP handshake + checks, and release SBOM generation. + +## Release flow + +1. Merge normal pull requests into `main` with Conventional Commit-style titles or squash commit messages such as `feat: add graph query helpers` or `fix: preserve MCP config`. +2. The `Release` workflow opens or updates a release pull request that updates `CHANGELOG.md` and `.release-please-manifest.json`. +3. Review and merge the release pull request when ready to publish. +4. The `Release` workflow creates the `vX.Y.Z` tag and GitHub Release, builds the distributions from that tag, verifies `Version: X.Y.Z`, uploads the distributions and SBOM to the GitHub Release, and publishes to PyPI from the protected `pypi` environment. + +## Release gate + +Before publishing a production release, confirm: + +- Hosted CI is green for tests, ruff, package build, supply-chain, wheel smoke, and source-distribution smoke. +- `SECURITY.md` is present and vulnerability reporting expectations are current. +- The project owner has selected an SPDX license, added package license metadata, and included the corresponding license file. +- The PyPI Trusted Publisher, `pypi` GitHub environment, and release-please token posture have been verified in GitHub/PyPI settings. +- Conda-forge submission is either explicitly out of scope for the release or the recipe placeholders have been replaced with the release version, source-distribution SHA256, and chosen SPDX license. + +Run the local release-gate checker before publishing: + +```bash +python scripts/check_release_gate.py +python scripts/check_release_gate.py --production \ + --confirm trusted-publisher \ + --confirm pypi-environment \ + --confirm hosted-ci-green \ + --confirm private-vulnerability-reporting +``` + +Add `--require-conda` when conda-forge submission is in scope for the release. + +Vulnerability advisory scans require an external advisory service. Hosted CI and release workflows run those scans and +fail on known vulnerable dependencies. Local setup stays offline-safe and must not call external advisory APIs +implicitly; run local advisory scans explicitly when that disclosure is acceptable. + +The package version remains tag-derived through `setuptools_scm`; do not add a static `project.version` field to `pyproject.toml` just for release-please. + +To force a specific next version, merge a commit whose body contains a `Release-As: X.Y.Z` trailer. + +For manual maintenance, rerun or dispatch the `Release` workflow. If CI checks must run on release-please pull requests, configure a `RELEASE_PLEASE_TOKEN` secret backed by a personal access token or GitHub App token; the default `GITHUB_TOKEN` can create the pull request but does not trigger follow-up workflows from its own events. + +## Conda-forge release path + +This repository intentionally does not upload directly to Anaconda.org. Conda distribution should go through conda-forge: + +1. Ensure the PyPI release has completed and download the source distribution SHA256. +2. Before submitting `codebase-graph`, verify all runtime dependencies exist on conda-forge. If `real-ladybug` is not available, package that dependency first. +3. Copy `conda-forge/recipe/meta.yaml` into a new `recipes/codebase-graph/` directory in a fork of `conda-forge/staged-recipes`. +4. Replace `version`, `sha256`, and `license` placeholders with release-specific values. +5. Open the staged-recipes pull request and let conda-forge CI validate Linux, macOS, and Windows builds. + +After staged-recipes is merged, future conda releases are handled in the generated `codebase-graph-feedstock`. diff --git a/pyproject.toml b/pyproject.toml index fad2c19..eae5dcf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,39 +1,60 @@ [build-system] -requires = ["setuptools>=68", "wheel"] +requires = ["setuptools>=77", "setuptools-scm>=8", "wheel"] build-backend = "setuptools.build_meta" [project] -name = "codebase-graph" -version = "0.1.0" +name = "cbasegraph" +dynamic = ["version"] description = "Generic codebase knowledge graph engine for Python projects." readme = "README.md" requires-python = ">=3.10" authors = [{ name = "Rabii Chaarani" }] +license = "MIT" +license-files = ["LICENSE"] dependencies = [ - "tomli; python_version < '3.11'", + "real_ladybug>=0.15.3,<0.16", + "tomli>=2.0.1; python_version < '3.11'", + "tree-sitter>=0.25.2,<0.26", + "tree-sitter-python>=0.25.0,<0.26", ] classifiers = [ "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", "Topic :: Software Development :: Libraries :: Python Modules", ] [project.optional-dependencies] -ladybug = ["real_ladybug"] +ladybug = [] parquet = ["pyarrow"] -dev = ["pytest", "ruff"] +dev = ["pytest", "ruff", "tiktoken"] [project.scripts] codebase-graph = "codebase_graph.cli:main" +codebase-graph-mcp = "codebase_graph.mcp.server:main" + +[project.urls] +Homepage = "https://github.com/rabii-chaarani/codebaseGraph" +Repository = "https://github.com/rabii-chaarani/codebaseGraph" +Issues = "https://github.com/rabii-chaarani/codebaseGraph/issues" +Changelog = "https://github.com/rabii-chaarani/codebaseGraph/blob/main/CHANGELOG.md" [tool.setuptools.packages.find] where = ["src"] +include = ["codebase_graph*"] +exclude = ["tests*", "codebase_graph.egg-info*"] + +[tool.setuptools_scm] +version_scheme = "guess-next-dev" +local_scheme = "no-local-version" [tool.ruff] line-length = 120 target-version = "py310" [tool.pytest.ini_options] +pythonpath = ["src", "."] testpaths = ["tests"] diff --git a/release-please-config.json b/release-please-config.json new file mode 100644 index 0000000..5495177 --- /dev/null +++ b/release-please-config.json @@ -0,0 +1,10 @@ +{ + "packages": { + ".": { + "release-type": "python", + "package-name": "codebase-graph", + "include-v-in-tag": true, + "changelog-path": "CHANGELOG.md" + } + } +} diff --git a/scripts/check_release_gate.py b/scripts/check_release_gate.py new file mode 100644 index 0000000..3bbb4e0 --- /dev/null +++ b/scripts/check_release_gate.py @@ -0,0 +1,232 @@ +from __future__ import annotations + +import argparse +import re +import sys +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +try: + import tomllib +except ImportError: # pragma: no cover - Python 3.10 compatibility + import tomli as tomllib + + +REPO_ROOT = Path(__file__).resolve().parents[1] +WORKFLOWS = ( + Path(".github/workflows/ci.yml"), + Path(".github/workflows/release.yml"), +) +PYPI_CONFIRMATION_FLAGS = ( + "trusted-publisher", + "pypi-environment", + "hosted-ci-green", + "private-vulnerability-reporting", +) + + +@dataclass(frozen=True, slots=True) +class GateIssue: + severity: str + code: str + message: str + + def line(self) -> str: + return f"{self.severity}: {self.code}: {self.message}" + + +def main(argv: list[str] | None = None) -> int: + parser = argparse.ArgumentParser(description="Check local and production release readiness gates.") + parser.add_argument("--production", action="store_true", help="Require owner-controlled production release gates.") + parser.add_argument("--require-conda", action="store_true", help="Require the conda-forge recipe to be finalized.") + parser.add_argument( + "--confirm", + action="append", + default=[], + choices=PYPI_CONFIRMATION_FLAGS, + help="Confirm a manually verified external production gate.", + ) + args = parser.parse_args(argv) + + issues = run_checks( + production=args.production, + require_conda=args.require_conda, + confirmations=set(args.confirm), + ) + if not issues: + print("release gate passed") + return 0 + for issue in issues: + print(issue.line(), file=sys.stderr) + return 1 if any(issue.severity == "FAIL" for issue in issues) else 0 + + +def run_checks(*, production: bool, require_conda: bool, confirmations: set[str]) -> list[GateIssue]: + issues: list[GateIssue] = [] + issues.extend(_check_security_policy()) + issues.extend(_check_workflows()) + issues.extend(_check_release_workflow_permissions()) + if production: + issues.extend(_check_license_metadata()) + issues.extend(_check_external_confirmations(confirmations)) + if require_conda: + issues.extend(_check_conda_recipe()) + return issues + + +def _check_security_policy() -> list[GateIssue]: + issues: list[GateIssue] = [] + security = REPO_ROOT / "SECURITY.md" + manifest = REPO_ROOT / "MANIFEST.in" + if not security.exists(): + issues.append(GateIssue("FAIL", "security-policy-missing", "SECURITY.md is required.")) + if security.exists(): + text = security.read_text(encoding="utf-8") + for required in ("Reporting a Vulnerability", "graph_query", "--allow-remote"): + if required not in text: + issues.append(GateIssue("FAIL", "security-policy-incomplete", f"SECURITY.md must mention {required!r}.")) + if not manifest.exists() or "include SECURITY.md" not in manifest.read_text(encoding="utf-8"): + issues.append(GateIssue("FAIL", "security-policy-not-packaged", "MANIFEST.in must include SECURITY.md.")) + return issues + + +def _check_workflows() -> list[GateIssue]: + issues: list[GateIssue] = [] + for relative_path in WORKFLOWS: + path = REPO_ROOT / relative_path + if not path.exists(): + issues.append(GateIssue("FAIL", "workflow-missing", f"{relative_path} is required.")) + continue + text = path.read_text(encoding="utf-8") + issues.extend(_workflow_action_pin_issues(relative_path, text)) + for job in _jobs_missing_timeout(text): + issues.append(GateIssue("FAIL", "workflow-timeout-missing", f"{relative_path}:{job} has no timeout.")) + if re.search(r"pip_audit\b[^\n]*--dry-run", text): + issues.append(GateIssue("FAIL", "workflow-audit-dry-run", f"{relative_path} uses pip-audit --dry-run.")) + if "pip install dist/*.whl" not in text: + issues.append(GateIssue("FAIL", "workflow-wheel-smoke-missing", f"{relative_path} must smoke-test wheels.")) + if "pip install dist/*.tar.gz" not in text: + issues.append(GateIssue("FAIL", "workflow-sdist-smoke-missing", f"{relative_path} must smoke-test sdists.")) + return issues + + +def _workflow_action_pin_issues(relative_path: Path, text: str) -> list[GateIssue]: + issues: list[GateIssue] = [] + uses_pattern = re.compile(r"^\s*(?:-\s*)?uses:\s*(?P[^\s#]+)") + for line_number, line in enumerate(text.splitlines(), start=1): + match = uses_pattern.match(line) + if match is None: + continue + target = match.group("target").strip("'\"") + if target.startswith(("./", "../")): + continue + if "@" not in target: + issues.append(GateIssue("FAIL", "workflow-action-not-pinned", f"{relative_path}:{line_number}: {target}")) + continue + action, ref = target.rsplit("@", 1) + if not re.fullmatch(r"[0-9a-fA-F]{40}", ref): + issues.append( + GateIssue("FAIL", "workflow-action-not-pinned", f"{relative_path}:{line_number}: {action}@{ref}") + ) + return issues + + +def _jobs_missing_timeout(text: str) -> list[str]: + missing: list[str] = [] + in_jobs = False + current_job: str | None = None + current_has_timeout = False + + for line in text.splitlines(): + if line == "jobs:": + in_jobs = True + continue + if not in_jobs: + continue + if line and not line.startswith(" "): + break + job_match = re.match(r"^ ([A-Za-z0-9_-]+):\s*$", line) + if job_match is not None: + if current_job is not None and not current_has_timeout: + missing.append(current_job) + current_job = job_match.group(1) + current_has_timeout = False + continue + if current_job is not None and re.match(r"^ timeout-minutes:\s*\d+\s*$", line): + current_has_timeout = True + + if current_job is not None and not current_has_timeout: + missing.append(current_job) + return missing + + +def _check_release_workflow_permissions() -> list[GateIssue]: + workflow = REPO_ROOT / ".github/workflows/release.yml" + issues: list[GateIssue] = [] + if not workflow.exists(): + return [GateIssue("FAIL", "workflow-missing", ".github/workflows/release.yml is required.")] + + text = workflow.read_text(encoding="utf-8") + if ( + "production-gate:" not in text + or "python scripts/check_release_gate.py" not in text + or "--production" not in text + or "- production-gate" not in text + ): + issues.append( + GateIssue( + "FAIL", + "release-production-gate-missing", + "release workflow must run the production release gate before build/publish.", + ) + ) + if "environment:" not in text or "name: pypi" not in text: + issues.append(GateIssue("FAIL", "pypi-environment-missing", "release workflow must publish through pypi environment.")) + if "id-token: write" not in text: + issues.append(GateIssue("FAIL", "pypi-oidc-missing", "release workflow must grant id-token: write.")) + return issues + + +def _check_license_metadata() -> list[GateIssue]: + pyproject = _load_toml(REPO_ROOT / "pyproject.toml") + project = pyproject.get("project", {}) + license_value = project.get("license") + license_files = project.get("license-files") or pyproject.get("tool", {}).get("setuptools", {}).get("license-files") + license_paths = [path for path in REPO_ROOT.iterdir() if path.name.upper().startswith("LICENSE")] + issues: list[GateIssue] = [] + if not license_value and not license_files: + issues.append(GateIssue("FAIL", "license-metadata-missing", "pyproject.toml must declare package license metadata.")) + if not license_paths: + issues.append(GateIssue("FAIL", "license-file-missing", "repository must include the selected license file.")) + return issues + + +def _check_external_confirmations(confirmations: set[str]) -> list[GateIssue]: + return [ + GateIssue("FAIL", "external-confirmation-missing", f"production release requires --confirm {flag}.") + for flag in PYPI_CONFIRMATION_FLAGS + if flag not in confirmations + ] + + +def _check_conda_recipe() -> list[GateIssue]: + recipe_path = REPO_ROOT / "conda-forge/recipe/meta.yaml" + issues: list[GateIssue] = [] + if not recipe_path.exists(): + return [GateIssue("FAIL", "conda-recipe-missing", "conda-forge/recipe/meta.yaml is required.")] + + recipe = recipe_path.read_text(encoding="utf-8") + for placeholder in ("PUT_RELEASE_VERSION_HERE", "PUT_RELEASE_SDIST_SHA256_HERE", "PUT_SPDX_LICENSE_ID_HERE"): + if placeholder in recipe: + issues.append(GateIssue("FAIL", "conda-placeholder", f"conda recipe still contains {placeholder}.")) + return issues + + +def _load_toml(path: Path) -> dict[str, Any]: + with path.open("rb") as handle: + return tomllib.load(handle) + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/compare_graph_output_tokens.py b/scripts/compare_graph_output_tokens.py new file mode 100644 index 0000000..24c5098 --- /dev/null +++ b/scripts/compare_graph_output_tokens.py @@ -0,0 +1,329 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import argparse +import json +import statistics +import sys +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +REPO_ROOT = Path(__file__).resolve().parents[1] +SRC_ROOT = REPO_ROOT / "src" +if SRC_ROOT.as_posix() not in sys.path: + sys.path.insert(0, SRC_ROOT.as_posix()) + +from codebase_graph.mcp.runtime import runtime_config # noqa: E402 +from codebase_graph.mcp.tools import handle_tool_call # noqa: E402 +from codebase_graph.retrieval.block_format import ( # noqa: E402 + canonicalize_search_payload, + intentional_summary_omissions, + parse_search_block, + serialize_agent_search_block, + serialize_search_block, +) + + +DEFAULT_FIXTURE = REPO_ROOT / "tests" / "fixtures" / "search_service_graph_search.json" +DEFAULT_OUTPUT = REPO_ROOT / "docs" / "graph_output_token_comparison.md" + + +@dataclass(frozen=True, slots=True) +class Tokenizer: + encoding: Any + encoding_name: str + model_name: str | None + fallback_note: str = "" + + +def main(argv: list[str] | None = None) -> int: + args = _parser().parse_args(argv) + tokenizer = resolve_tokenizer(model=args.model, encoding_name=args.encoding) + samples = _load_samples(args) + rows = [_compare_sample(sample, tokenizer, block_format=args.block_format) for sample in samples] + aggregate = _aggregate(rows) + _write_report(args.output, rows, aggregate, tokenizer) + _print_summary(rows, aggregate, tokenizer, args.output) + return 0 + + +def resolve_tokenizer(*, model: str | None = None, encoding_name: str | None = None) -> Tokenizer: + try: + import tiktoken + except ImportError as exc: + raise RuntimeError( + "tiktoken is required for graph output token benchmarking. Install it in the active environment " + "or run this script where tiktoken is available." + ) from exc + + if encoding_name: + encoding = tiktoken.get_encoding(encoding_name) + return Tokenizer(encoding=encoding, encoding_name=encoding.name, model_name=model) + if model: + try: + encoding = tiktoken.encoding_for_model(model) + return Tokenizer(encoding=encoding, encoding_name=encoding.name, model_name=model) + except KeyError: + encoding = tiktoken.get_encoding("o200k_base") + return Tokenizer( + encoding=encoding, + encoding_name=encoding.name, + model_name=model, + fallback_note=f"model-specific encoding unavailable for {model}; defaulted to o200k_base", + ) + encoding = tiktoken.get_encoding("o200k_base") + return Tokenizer(encoding=encoding, encoding_name=encoding.name, model_name=None) + + +def count_tokens(text: str, encoding: Any) -> int: + return len(encoding.encode(text)) + + +def _parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Compare graph-search JSON output with readable block output.") + parser.add_argument("--queries", action="append", default=[], help="Graph-search query to run; repeat as needed") + parser.add_argument("--fixture", action="append", type=Path, default=[], help="Path to a graph-search JSON fixture") + parser.add_argument("--model", default=None, help="Model name used to resolve a tiktoken encoding") + parser.add_argument("--encoding", default=None, help="Explicit tiktoken encoding name") + parser.add_argument("--limit", type=int, default=3, help="Graph-search result limit for live queries") + parser.add_argument("--profile", default="brief", help="Graph-search context profile for live queries") + parser.add_argument("--budget", type=int, default=600, help="Graph-search context budget for live queries") + parser.add_argument("--output", type=Path, default=DEFAULT_OUTPUT, help="Markdown report path") + parser.add_argument("--repo-root", type=Path, default=REPO_ROOT, help="Repository root for live graph-search queries") + parser.add_argument("--config", type=Path, default=None, help="Optional codebaseGraph setup config path") + parser.add_argument("--db", type=Path, default=None, help="Optional codebaseGraph database path") + parser.add_argument("--manifest", type=Path, default=None, help="Optional codebaseGraph manifest path") + parser.add_argument("--context-limit", type=int, default=2, help="Context items per result for live queries") + parser.add_argument("--detail", choices=("standard", "slim"), default="slim", help="Raw graph-search detail level") + parser.add_argument( + "--block-format", + choices=("ontology", "agent"), + default="ontology", + help="Block serializer to compare against raw JSON", + ) + return parser + + +def _load_samples(args: argparse.Namespace) -> list[dict[str, Any]]: + samples: list[dict[str, Any]] = [] + fixture_paths = args.fixture or ([] if args.queries else [DEFAULT_FIXTURE]) + for fixture_path in fixture_paths: + payload = json.loads(fixture_path.read_text(encoding="utf-8")) + if isinstance(payload, list): + samples.extend(_fixture_sample(item, fixture_path) for item in payload) + else: + samples.append(_fixture_sample(payload, fixture_path)) + if args.queries: + runtime = runtime_config( + repo_root=args.repo_root, + config_path=args.config, + db_path=args.db, + manifest_path=args.manifest, + ) + for query in args.queries: + payload = handle_tool_call( + "graph_search", + { + "query": query, + "limit": args.limit, + "profile": args.profile, + "budget": args.budget, + "context_limit": args.context_limit, + "detail": args.detail, + }, + runtime=runtime, + ) + samples.append({"name": query, "payload": payload, "source": "live graph-search"}) + if not samples: + raise ValueError("No samples found. Provide --queries or --fixture.") + return samples + + +def _fixture_sample(payload: dict[str, Any], fixture_path: Path) -> dict[str, Any]: + if "payload" in payload and isinstance(payload["payload"], dict): + name = str(payload.get("name") or payload["payload"].get("query") or fixture_path.stem) + return {"name": name, "payload": payload["payload"], "source": fixture_path.as_posix()} + return {"name": str(payload.get("query") or fixture_path.stem), "payload": payload, "source": fixture_path.as_posix()} + + +def _compare_sample(sample: dict[str, Any], tokenizer: Tokenizer, *, block_format: str) -> dict[str, Any]: + payload = sample["payload"] + raw_text = _raw_json(payload) + if block_format == "agent": + block_text = serialize_agent_search_block(payload) + else: + block_text = serialize_search_block(payload) + raw_canonical = canonicalize_search_payload(payload) + block_canonical = parse_search_block(block_text) + if raw_canonical != block_canonical: + raise AssertionError( + f"Block output is not semantically equivalent for {sample['name']}:\n" + f"raw={json.dumps(raw_canonical, sort_keys=True)}\n" + f"block={json.dumps(block_canonical, sort_keys=True)}" + ) + raw_tokens = count_tokens(raw_text, tokenizer.encoding) + block_tokens = count_tokens(block_text, tokenizer.encoding) + raw_chars = len(raw_text) + block_chars = len(block_text) + token_delta = raw_tokens - block_tokens + char_delta = raw_chars - block_chars + result_count = len(payload.get("results", [])) + context_edges = sum(len(result.get("context", [])) for result in payload.get("results", [])) + return { + "query": sample["name"], + "source": sample["source"], + "raw_chars": raw_chars, + "block_chars": block_chars, + "raw_tokens": raw_tokens, + "block_tokens": block_tokens, + "token_delta": token_delta, + "token_reduction_pct": _pct(token_delta, raw_tokens), + "char_reduction_pct": _pct(char_delta, raw_chars), + "results": result_count, + "context_edges": context_edges, + "tokenizer": tokenizer.encoding_name, + "model": tokenizer.model_name or "", + "block_format": block_format, + "intentional_omissions": intentional_summary_omissions(payload), + } + + +def _aggregate(rows: list[dict[str, Any]]) -> dict[str, Any]: + raw_tokens = [row["raw_tokens"] for row in rows] + block_tokens = [row["block_tokens"] for row in rows] + reductions = [row["token_reduction_pct"] for row in rows] + total_raw = sum(raw_tokens) + total_block = sum(block_tokens) + sorted_by_reduction = sorted(rows, key=lambda row: row["token_reduction_pct"]) + aggregate = { + "sample_count": len(rows), + "total_raw_tokens": total_raw, + "total_block_tokens": total_block, + "total_token_delta": total_raw - total_block, + "overall_token_reduction_pct": _pct(total_raw - total_block, total_raw), + "mean_token_reduction_pct": statistics.fmean(reductions) if reductions else 0.0, + "median_token_reduction_pct": statistics.median(reductions) if reductions else 0.0, + "min_reduction_case": sorted_by_reduction[0]["query"] if sorted_by_reduction else "", + "min_reduction_pct": sorted_by_reduction[0]["token_reduction_pct"] if sorted_by_reduction else 0.0, + "max_reduction_case": sorted_by_reduction[-1]["query"] if sorted_by_reduction else "", + "max_reduction_pct": sorted_by_reduction[-1]["token_reduction_pct"] if sorted_by_reduction else 0.0, + "p90_raw_tokens": None, + "p90_block_tokens": None, + } + if len(rows) >= 10: + aggregate["p90_raw_tokens"] = _p90(raw_tokens) + aggregate["p90_block_tokens"] = _p90(block_tokens) + return aggregate + + +def _write_report(path: Path, rows: list[dict[str, Any]], aggregate: dict[str, Any], tokenizer: Tokenizer) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + block_format = rows[0].get("block_format", "ontology") if rows else "ontology" + table_rows = "\n".join( + "| {query} | {results} | {context_edges} | {raw_tokens:,} | {block_tokens:,} | {token_delta:,} | " + "{token_reduction_pct:.1f}% |".format(**row) + for row in rows + ) + omission_lines = sorted({omission for row in rows for omission in row["intentional_omissions"]}) + omissions = "\n".join(f"- `{item}`" for item in omission_lines) or "- None" + fallback = f"\n- {tokenizer.fallback_note}" if tokenizer.fallback_note else "" + p90 = ( + f"- p90 raw/block tokens: {aggregate['p90_raw_tokens']:,} / {aggregate['p90_block_tokens']:,}\n" + if aggregate["p90_raw_tokens"] is not None + else "- p90 raw/block tokens: not reported because fewer than 10 samples were compared\n" + ) + if block_format == "agent": + format_description = "- Reduced agent block format: grouped `file path` blocks with stable result IDs retained and query settings, boilerplate same-span scope context, duplicate child-result edges, low-value type annotations, and excess score precision removed." + preservation_text = ( + "This reduced display mode intentionally does not preserve every raw JSON field. It keeps the fields most useful " + "for code navigation and follow-up inspection: file path, stable result `id`, ordered result type/label/span, rounded `rank_score`, " + "and non-boilerplate context with summaries. Use `--block-format ontology` when exact canonical equivalence is required." + ) + recommendation = ( + "Use the reduced agent block when the consumer is an interactive coding agent optimizing for quick navigation. " + "Use the ontology-preserving block or raw JSON when stable IDs, exact scores, or complete context records are required." + ) + else: + format_description = "- Ontology-preserving block format: grouped `file path` blocks with readable `Class`, `Method`, `Scope`, relation, `label`, `span`, `id`, and `rank_score` terms left literal." + preservation_text = ( + "The validator normalizes raw JSON and block output into canonical result records preserving `type`, `label`, " + "`path`, `span`, `id`, `rank_score`, and ordered context records with `direction`, `relation`, `type`, " + "`label`, `path`, `span`, and non-boilerplate `summary`." + ) + recommendation = ( + "Use the ontology-preserving block format by default for agent-facing graph-search output when consumers need " + "readable context. Keep JSON available for machine APIs and tests that require strict structured payloads." + ) + path.write_text( + "\n".join( + [ + "# Graph Search Output Token Comparison", + "", + "## Method", + "- Raw format: compact JSON emitted by the current graph-search payload serializer, counted from the exact serialized JSON text with sorted keys and compact separators.", + format_description, + f"- Tokenizer/model: encoding `{tokenizer.encoding_name}`" + + (f" resolved from model `{tokenizer.model_name}`." if tokenizer.model_name else ".") + + fallback, + "- Count method: payload-only tokens using `len(encoding.encode(text))`; chat-message wrapper tokens were not included.", + "", + "## Results", + "| Query | Results | Context edges | Raw tokens | Block tokens | Saved tokens | Reduction % |", + "|---|---:|---:|---:|---:|---:|---:|", + table_rows, + "", + "## Aggregate Summary", + f"- Samples: {aggregate['sample_count']}", + f"- Total raw tokens: {aggregate['total_raw_tokens']:,}", + f"- Total block tokens: {aggregate['total_block_tokens']:,}", + f"- Total saved tokens: {aggregate['total_token_delta']:,}", + f"- Overall reduction: {aggregate['overall_token_reduction_pct']:.1f}%", + f"- Mean reduction: {aggregate['mean_token_reduction_pct']:.1f}%", + f"- Median reduction: {aggregate['median_token_reduction_pct']:.1f}%", + f"- Min reduction: {aggregate['min_reduction_case']} ({aggregate['min_reduction_pct']:.1f}%)", + f"- Max reduction: {aggregate['max_reduction_case']} ({aggregate['max_reduction_pct']:.1f}%)", + p90.rstrip(), + "", + "## Ontology Preservation", + preservation_text, + "", + "Intentional omissions:", + omissions, + "", + "Known limitations: the block parser validates the supported graph-search fixture shape and live graph-search output shape; it is not a general-purpose parser for hand-written variants.", + "", + "## Recommendation", + recommendation, + "", + ] + ), + encoding="utf-8", + ) + + +def _print_summary(rows: list[dict[str, Any]], aggregate: dict[str, Any], tokenizer: Tokenizer, output_path: Path) -> None: + print(f"Compared {len(rows)} graph-search outputs using {tokenizer.encoding_name}.") + print(f"Raw: {aggregate['total_raw_tokens']:,} tokens") + print(f"Block: {aggregate['total_block_tokens']:,} tokens") + print(f"Saved: {aggregate['total_token_delta']:,} tokens") + print(f"Reduction: {aggregate['overall_token_reduction_pct']:.1f}%") + print(f"Report written to {output_path.as_posix()}") + + +def _raw_json(payload: dict[str, Any]) -> str: + return json.dumps(payload, separators=(",", ":"), sort_keys=True) + + +def _pct(delta: int | float, original: int | float) -> float: + return (float(delta) / float(original) * 100.0) if original else 0.0 + + +def _p90(values: list[int]) -> int: + ordered = sorted(values) + index = int(0.9 * (len(ordered) - 1)) + return ordered[index] + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/smoke_built_wheel.py b/scripts/smoke_built_wheel.py new file mode 100644 index 0000000..9825ca2 --- /dev/null +++ b/scripts/smoke_built_wheel.py @@ -0,0 +1,153 @@ +from __future__ import annotations + +import json +import subprocess +import sys +import tempfile +from pathlib import Path +from typing import Any, BinaryIO + + +def main(argv: list[str]) -> int: + if len(argv) != 2: + raise SystemExit("usage: smoke_built_wheel.py /path/to/codebase-graph") + executable = Path(argv[1]) + with tempfile.TemporaryDirectory(prefix="codebase-graph-wheel-smoke-") as tmp_dir: + repo_root = _sample_repo(Path(tmp_dir) / "sample_repo") + setup = _run( + [ + executable.as_posix(), + "setup", + "--repo-root", + repo_root.as_posix(), + "--mcp-client", + "none", + "--instructions-target", + "skip", + ] + ) + setup_payload = json.loads(setup.stdout) + config_path = Path(setup_payload["config_path"]) + + health = json.loads(_run([executable.as_posix(), "graph-health", "--repo-root", repo_root.as_posix()]).stdout) + if not health.get("ok") or not health.get("graph_readable"): + raise AssertionError(f"graph-health failed readiness smoke: {health}") + + search = json.loads( + _run( + [ + executable.as_posix(), + "graph-search", + "SampleService", + "--repo-root", + repo_root.as_posix(), + "--no-refresh", + "--detail", + "slim", + "--json", + ] + ).stdout + ) + if not search.get("results"): + raise AssertionError(f"graph-search returned no results: {search}") + + _install_verify_smoke(executable, config_path, Path(tmp_dir) / "mcp.json") + _mcp_smoke([executable.as_posix(), "mcp", "serve", "--config", config_path.as_posix()]) + return 0 + + +def _run(command: list[str]) -> subprocess.CompletedProcess[str]: + return subprocess.run(command, capture_output=True, text=True, check=True) + + +def _install_verify_smoke(executable: Path, config_path: Path, client_config_path: Path) -> None: + verify = json.loads( + _run( + [ + executable.as_posix(), + "mcp", + "install", + "--client", + "generic", + "--config-path", + config_path.as_posix(), + "--client-config-path", + client_config_path.as_posix(), + "--verify", + "--json", + ] + ).stdout + ) + verification = verify.get("verification") or {} + stdio = verification.get("stdio") or {} + checks = stdio.get("checks") or {} + required_checks = ("initialize", "tools_list", "graph_health", "graph_search", "tool_error_result") + if verification.get("ok") is not True or not all(checks.get(check) is True for check in required_checks): + raise AssertionError(f"mcp install --verify failed readiness smoke: {verify}") + + +def _sample_repo(repo_root: Path) -> Path: + package = repo_root / "sample_project" + package.mkdir(parents=True) + (package / "__init__.py").write_text("", encoding="utf-8") + (package / "service.py").write_text( + "class SampleService:\n" + " def run(self) -> str:\n" + " return helper()\n\n" + "def helper() -> str:\n" + " return 'ok'\n", + encoding="utf-8", + ) + (repo_root / "README.md").write_text("# Sample Repo\n\nSampleService smoke fixture.\n", encoding="utf-8") + return repo_root + + +def _mcp_smoke(command: list[str]) -> None: + proc = subprocess.Popen(command, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + assert proc.stdin is not None + assert proc.stdout is not None + try: + initialized = _rpc(proc.stdin, proc.stdout, "initialize", {"protocolVersion": "2025-11-25"}) + listed = _rpc(proc.stdin, proc.stdout, "tools/list", {}) + health = _rpc(proc.stdin, proc.stdout, "tools/call", {"name": "graph_health", "arguments": {}}) + finally: + proc.stdin.close() + proc.wait(timeout=10) + assert proc.stderr is not None + stderr = proc.stderr.read() + if proc.returncode != 0: + raise AssertionError(stderr.decode("utf-8", errors="replace")) + if initialized["result"]["protocolVersion"] != "2025-11-25": + raise AssertionError(initialized) + tool_names = {tool["name"] for tool in listed["result"]["tools"]} + if not {"graph_health", "graph_search", "graph_query"}.issubset(tool_names): + raise AssertionError(listed) + if health["result"]["structuredContent"].get("ok") is not True: + raise AssertionError(health) + + +def _rpc(stdin: BinaryIO, stdout: BinaryIO, method: str, params: dict[str, Any]) -> dict[str, Any]: + request_id = _rpc.counter + _rpc.counter += 1 + body = json.dumps({"jsonrpc": "2.0", "id": request_id, "method": method, "params": params}).encode("utf-8") + stdin.write(f"Content-Length: {len(body)}\r\n\r\n".encode("ascii") + body) + stdin.flush() + return _read_response(stdout) + + +_rpc.counter = 1 # type: ignore[attr-defined] + + +def _read_response(stdout: BinaryIO) -> dict[str, Any]: + header = stdout.readline() + if not header.lower().startswith(b"content-length:"): + raise AssertionError(f"unexpected MCP header: {header!r}") + length = int(header.split(b":", 1)[1].strip()) + separator = stdout.readline() + if separator not in {b"\r\n", b"\n"}: + raise AssertionError(f"unexpected MCP header separator: {separator!r}") + return json.loads(stdout.read(length).decode("utf-8")) + + +if __name__ == "__main__": + raise SystemExit(main(sys.argv)) diff --git a/src/codebase_graph/__init__.py b/src/codebase_graph/__init__.py index 280256d..0edc212 100644 --- a/src/codebase_graph/__init__.py +++ b/src/codebase_graph/__init__.py @@ -1,22 +1,3 @@ -from .graph_core import CodebaseGraph, GraphCoreStatus -from .ladybug import ( - DEFAULT_EMBEDDING_DIMENSIONS, - HashingEmbeddingProvider, - LadybugGraphExport, - LadybugGraphExporter, - LadybugGraphStore, - LadybugUnavailableError, -) -from .ontology import ONTOLOGY_NAME +"""Production namespace for codebase-graph.""" -__all__ = [ - "CodebaseGraph", - "DEFAULT_EMBEDDING_DIMENSIONS", - "GraphCoreStatus", - "HashingEmbeddingProvider", - "LadybugGraphExport", - "LadybugGraphExporter", - "LadybugGraphStore", - "LadybugUnavailableError", - "ONTOLOGY_NAME", -] +__all__ = [] diff --git a/src/codebase_graph/__main__.py b/src/codebase_graph/__main__.py deleted file mode 100644 index bfdcd0c..0000000 --- a/src/codebase_graph/__main__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .cli import main - -if __name__ == "__main__": - raise SystemExit(main()) diff --git a/src/codebase_graph/cli.py b/src/codebase_graph/cli.py deleted file mode 100644 index ce80e54..0000000 --- a/src/codebase_graph/cli.py +++ /dev/null @@ -1,5 +0,0 @@ -from __future__ import annotations - -from .graph_core import main - -__all__ = ["main"] diff --git a/src/codebase_graph/cli/__init__.py b/src/codebase_graph/cli/__init__.py new file mode 100644 index 0000000..b86f24b --- /dev/null +++ b/src/codebase_graph/cli/__init__.py @@ -0,0 +1,312 @@ +from __future__ import annotations + +import argparse +import json +import os +from collections.abc import Sequence +from pathlib import Path + +from codebase_graph.db import create_ladybug_database +from codebase_graph.ingest import GraphMaterializer +from codebase_graph.mcp.graph_commands import ( + add_compact_context_arguments, + add_json_output_arguments, + graph_command_names, + graph_command_spec, + graph_command_specs, +) +from codebase_graph.mcp.runtime import runtime_config +from codebase_graph.mcp.tools import handle_tool_call +from codebase_graph.retrieval import SearchRequest, SearchService, serialize_graph_block +from codebase_graph.setup import SetupError, SetupOptions, run_setup +from codebase_graph.setup.clients import supported_client_ids +from codebase_graph.setup.installer import McpInstallOptions, install_mcp_clients, supported_install_client_ids + + +def main(argv: Sequence[str] | None = None) -> int: + parser = _build_parser() + args = parser.parse_args(argv) + if args.command == "materialize": + materializer = GraphMaterializer( + Path(args.source_root), + db_path=args.db, + manifest_path=args.manifest, + include_fts=not args.no_fts, + ) + try: + result = materializer.materialize(mode=args.mode) + finally: + materializer.close() + _print_json(result.as_dict(), args) + return 0 + if args.command in {"search", "context"}: + return _run_legacy_search_command(parser, args) + if args.command in graph_command_names(): + return _run_graph_command(parser, args) + if args.command == "setup": + try: + result = run_setup( + SetupOptions( + repo_root=args.repo_root, + mcp_client=args.mcp_client, + mcp_config_path=args.mcp_config_path, + skip_mcp_config=args.skip_mcp_config, + dry_run=args.dry_run, + instructions_target=args.instructions_target, + mode=args.mode, + ) + ) + except SetupError as exc: + parser.error(str(exc)) + _print_json(result.as_dict(), args) + return 0 + if args.command == "mcp" and args.mcp_command == "install": + setup_config_path = ( + Path(args.config_path).expanduser().resolve() + if args.config_path is not None + else Path(args.repo_root).expanduser().resolve() / ".codebaseGraph" / "config.json" + ) + try: + results = install_mcp_clients( + McpInstallOptions( + client=args.client, + scope=args.scope, + setup_config_path=setup_config_path, + server_name=args.name, + client_config_path=args.client_config_path, + dry_run=args.dry_run, + verify=args.verify, + ) + ) + except (OSError, ValueError) as exc: + parser.error(str(exc)) + payload: dict[str, object] + if args.client == "all": + payload = {"results": [result.as_dict() for result in results]} + else: + payload = results[0].as_dict() + if args.json: + _print_json(payload, args) + else: + _print_mcp_install_results(results) + return 1 if any(result.action == "failed" for result in results) else 0 + if args.command == "mcp" and args.mcp_command == "serve": + from codebase_graph.mcp.server import serve_stdio + + serve_stdio(repo_root=args.repo_root, config_path=args.config, db_path=args.db, manifest_path=args.manifest) + return 0 + if args.command == "mcp" and args.mcp_command == "http": + from codebase_graph.mcp.server import serve_http + + auth_token = _http_auth_token(args, parser) + serve_http( + repo_root=args.repo_root, + config_path=args.config, + db_path=args.db, + manifest_path=args.manifest, + host=args.host, + port=args.port, + endpoint_path=args.path, + allow_remote=args.allow_remote, + auth_token=auth_token, + ) + return 0 + parser.error(f"Unknown command: {args.command}") + return 2 + + +def _build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(prog="codebase-graph") + subparsers = parser.add_subparsers(dest="command", required=True) + + materialize_parser = subparsers.add_parser("materialize", help="Materialize the code graph") + materialize_parser.add_argument("--source-root", default=".", help="Repository or source root to scan") + materialize_parser.add_argument("--db", default=None, help="LadybugDB path; defaults under .codebaseGraph") + materialize_parser.add_argument("--manifest", default=None, help="Manifest path; defaults under .codebaseGraph") + materialize_parser.add_argument("--mode", choices=("full", "changed"), default="changed") + materialize_parser.add_argument("--no-fts", action="store_true", help="Skip FTS index creation") + _add_json_output_arguments(materialize_parser) + + search_parser = subparsers.add_parser("search", help="Search the code graph with compact context") + _add_search_arguments(search_parser) + + context_parser = subparsers.add_parser("context", help="Return compact context for a search query") + _add_search_arguments(context_parser) + + for spec in graph_command_specs(): + graph_parser = subparsers.add_parser(spec.command_name, help=spec.help) + spec.add_arguments(graph_parser) + + setup_parser = subparsers.add_parser("setup", help="Bootstrap codebaseGraph state for a repository") + setup_parser.add_argument("--repo-root", default=".", help="Repository root to configure") + setup_parser.add_argument("--mcp-client", choices=supported_client_ids(), default="codex") + setup_parser.add_argument("--mcp-config-path", default=None, help="Override MCP client config path") + setup_parser.add_argument("--skip-mcp-config", action="store_true", help="Do not write MCP client config") + setup_parser.add_argument("--dry-run", action="store_true", help="Return the MCP config patch without writing it") + setup_parser.add_argument( + "--instructions-target", + choices=("auto", "agents", "claude", "skip"), + default="auto", + help="Instruction file to update", + ) + setup_parser.add_argument("--mode", choices=("full", "changed"), default="changed", help="Materialization mode") + setup_parser.add_argument("--json", action="store_true", help="Emit JSON output") + _add_json_output_arguments(setup_parser) + + mcp_parser = subparsers.add_parser("mcp", help="Run or inspect the MCP server") + mcp_subparsers = mcp_parser.add_subparsers(dest="mcp_command", required=True) + install_parser = mcp_subparsers.add_parser("install", help="Install the MCP server in a supported client") + install_parser.add_argument("--client", choices=supported_install_client_ids(include_all=True), default="codex") + install_parser.add_argument("--scope", choices=("local", "user", "project"), default="local") + install_parser.add_argument("--name", default=None, help="MCP server name; defaults to codebase_graph-") + install_parser.add_argument("--config-path", default=None, help="Path to .codebaseGraph/config.json") + install_parser.add_argument("--client-config-path", default=None, help="Override the target MCP client config path") + install_parser.add_argument("--repo-root", default=".", help="Repository root used to find .codebaseGraph/config.json") + install_parser.add_argument("--dry-run", action="store_true", help="Show the install action without writing or invoking CLIs") + install_parser.add_argument("--verify", action="store_true", help="Run direct MCP smoke checks after installation") + install_parser.add_argument("--json", action="store_true", help="Emit JSON output") + _add_json_output_arguments(install_parser) + + serve_parser = mcp_subparsers.add_parser("serve", help="Serve graph tools over MCP stdio") + serve_parser.add_argument("--repo-root", default=".", help="Repository root containing .codebaseGraph/config.json") + serve_parser.add_argument("--config", default=None, help="Path to .codebaseGraph/config.json") + serve_parser.add_argument("--db", default=None, help="Override LadyBugDB path") + serve_parser.add_argument("--manifest", default=None, help="Override manifest path") + http_parser = mcp_subparsers.add_parser("http", help="Serve graph tools over Streamable HTTP") + http_parser.add_argument("--repo-root", default=".", help="Repository root containing .codebaseGraph/config.json") + http_parser.add_argument("--config", default=None, help="Path to .codebaseGraph/config.json") + http_parser.add_argument("--db", default=None, help="Override LadyBugDB path") + http_parser.add_argument("--manifest", default=None, help="Override manifest path") + http_parser.add_argument("--host", default="127.0.0.1", help="HTTP bind host; default keeps the server local") + http_parser.add_argument("--port", type=int, default=8765, help="HTTP bind port") + http_parser.add_argument("--path", default="/mcp", help="MCP HTTP endpoint path") + http_parser.add_argument( + "--allow-remote", + action="store_true", + help="Allow binding MCP HTTP to a non-local host; requires an auth token", + ) + http_parser.add_argument( + "--auth-token", + default=None, + help="Bearer token required for HTTP requests; prefer --auth-token-env to avoid shell history exposure", + ) + http_parser.add_argument("--auth-token-env", default=None, help="Environment variable containing the HTTP bearer token") + return parser + + +def _add_search_arguments(parser: argparse.ArgumentParser) -> None: + parser.add_argument("query", help="Search query") + parser.add_argument("--source-root", default=".", help="Repository or source root to search") + parser.add_argument("--db", default=None, help="LadybugDB path; defaults under .codebaseGraph") + parser.add_argument("--manifest", default=None, help="Manifest path; defaults under .codebaseGraph") + add_compact_context_arguments(parser) + parser.add_argument("--no-refresh", action="store_true", help="Query the existing graph without changed materialization") + parser.add_argument("--json", action="store_true", help="Emit compact JSON output") + + +def _runtime(args: argparse.Namespace) -> object: + return runtime_config( + repo_root=args.repo_root, + config_path=args.config, + db_path=args.db, + manifest_path=args.manifest, + ) + + +def _run_legacy_search_command(parser: argparse.ArgumentParser, args: argparse.Namespace) -> int: + try: + request = _search_request_from_args(args) + except ValueError as exc: + parser.error(str(exc)) + materializer = GraphMaterializer( + Path(args.source_root), + db_path=args.db, + manifest_path=args.manifest, + include_fts=True, + ) + if args.no_refresh: + with create_ladybug_database(materializer.db_path, include_fts=True, read_only=True) as store: + payload = SearchService(store).search(request) + else: + try: + materializer.materialize(mode="changed") + payload = SearchService(materializer.store).search(request) + finally: + materializer.close() + _print_payload(payload.as_dict(detail=args.detail), args) + return 0 + + +def _search_request_from_args(args: argparse.Namespace) -> SearchRequest: + request = SearchRequest( + query=args.query, + limit=args.limit, + profile=args.profile, + budget=args.budget, + max_depth=args.max_depth, + context_limit=args.context_limit, + detail=args.detail, + ) + request.validate() + return request + + +def _run_graph_command(parser: argparse.ArgumentParser, args: argparse.Namespace) -> int: + spec = graph_command_spec(args.command) + try: + arguments = spec.payload_from_args(args) + runtime = _runtime(args) if spec.requires_runtime else None + payload = handle_tool_call(spec.tool_name, arguments, runtime=runtime) + except (OSError, ValueError) as exc: + parser.error(str(exc)) + _print_payload(payload, args) + return 0 + + +def _add_json_output_arguments(parser: argparse.ArgumentParser) -> None: + add_json_output_arguments(parser) + + +def _print_json(payload: object, args: argparse.Namespace) -> None: + print(_json_dumps(payload, pretty=getattr(args, "pretty", False))) + + +def _print_payload(payload: dict[str, object], args: argparse.Namespace) -> None: + if getattr(args, "json", False): + _print_json(payload, args) + return + if getattr(args, "format", "json") == "block": + print(serialize_graph_block(payload), end="") + return + _print_json(payload, args) + + +def _json_dumps(payload: object, *, pretty: bool) -> str: + if pretty: + return json.dumps(payload, indent=2, sort_keys=True) + return json.dumps(payload, separators=(",", ":"), sort_keys=True) + + +def _print_mcp_install_results(results: Sequence[object]) -> None: + for result in results: + action = getattr(result, "action") + client = getattr(result, "client") + method = getattr(result, "method") or "none" + server_name = getattr(result, "server_name") + target = getattr(result, "path") or " ".join(getattr(result, "command") or []) + suffix = f" -> {target}" if target else "" + print(f"{client}: {action} {server_name} via {method}{suffix}") + + +def _http_auth_token(args: argparse.Namespace, parser: argparse.ArgumentParser) -> str | None: + if args.auth_token and args.auth_token_env: + parser.error("mcp http accepts either --auth-token or --auth-token-env, not both") + if args.auth_token_env: + value = os.environ.get(args.auth_token_env) + if not value: + parser.error(f"Environment variable {args.auth_token_env!r} must contain the HTTP bearer token") + return value + return args.auth_token + + +__all__ = ["main"] diff --git a/src/codebase_graph/code_map.py b/src/codebase_graph/code_map.py deleted file mode 100644 index f98d3b4..0000000 --- a/src/codebase_graph/code_map.py +++ /dev/null @@ -1,224 +0,0 @@ -from __future__ import annotations - -import ast -import hashlib -from dataclasses import dataclass, field -from pathlib import Path -from typing import Any - -CODE_EXTENSIONS = {".py"} -MAX_INDEXED_FILE_BYTES = 1_000_000 -EXCLUDED_FILENAMES = {".DS_Store"} -EXCLUDED_PARTS = { - ".git", - ".hg", - ".mypy_cache", - ".pytest_cache", - ".ruff_cache", - ".tox", - ".venv", - ".codebase_graph", - "__pycache__", - "build", - "dist", - "htmlcov", - "node_modules", - "site-packages", -} - -@dataclass(slots=True) -class CodeSymbol: - id: str - label: str - kind: str - path: str - module_name: str - qualified_name: str - line_start: int | None = None - line_end: int | None = None - decorators: list[str] = field(default_factory=list) - bases: list[str] = field(default_factory=list) - summary: str = "" - -@dataclass(slots=True) -class CodeFile: - id: str - path: str - module_name: str - language: str - line_count: int - summary: str = "" - imports: list[str] = field(default_factory=list) - calls: list[str] = field(default_factory=list) - symbols: list[CodeSymbol] = field(default_factory=list) - -@dataclass(slots=True) -class CodebaseMap: - files: list[CodeFile] - - def as_dict(self) -> dict[str, Any]: - return {"files": [_file_as_dict(file) for file in self.files]} - -class CodebaseGraphBuilder: - def __init__(self, root: str | Path) -> None: - self.root = Path(root) - - def build(self) -> CodebaseMap: - files = [_parse_python_file(path, self.root) for path in _iter_python_files(self.root)] - return CodebaseMap(files=files) - -def is_excluded_codebase_path_parts(parts: tuple[str, ...]) -> bool: - return any(part in EXCLUDED_PARTS for part in parts) - -def _iter_indexable_files(root: Path, suffixes: set[str], *, case_insensitive_suffixes: bool = False) -> list[Path]: - if not root.exists(): - return [] - paths: list[Path] = [] - for path in root.rglob("*"): - if _is_indexable_file(path, root, suffixes, case_insensitive_suffixes=case_insensitive_suffixes): - paths.append(path) - return sorted(paths) - -def _is_indexable_file( - path: Path, - root: Path, - suffixes: set[str], - *, - case_insensitive_suffixes: bool = False, -) -> bool: - if not path.is_file() or path.name in EXCLUDED_FILENAMES: - return False - suffix = path.suffix.lower() if case_insensitive_suffixes else path.suffix - if suffix not in suffixes: - return False - try: - rel_parts = path.relative_to(root).parts - except ValueError: - return False - if is_excluded_codebase_path_parts(rel_parts): - return False - try: - return path.stat().st_size <= MAX_INDEXED_FILE_BYTES - except OSError: - return False - -def _iter_python_files(root: Path) -> list[Path]: - return _iter_indexable_files(root, CODE_EXTENSIONS) - -def _parse_python_file(path: Path, root: Path) -> CodeFile: - rel_path = path.relative_to(root).as_posix() - text = path.read_text(encoding="utf-8", errors="replace") - module_name = _module_name(rel_path) - try: - tree = ast.parse(text) - except SyntaxError: - return CodeFile(_id("file", rel_path), rel_path, module_name, "python", len(text.splitlines()), "Syntax error") - - imports: list[str] = [] - calls: list[str] = [] - symbols: list[CodeSymbol] = [] - for node in ast.walk(tree): - if isinstance(node, ast.Import): - imports.extend(alias.name for alias in node.names) - elif isinstance(node, ast.ImportFrom): - module = "." * node.level + (node.module or "") - imports.append(module) - elif isinstance(node, ast.Call): - calls.append(_call_name(node.func)) - - for node in tree.body: - if isinstance(node, ast.ClassDef): - class_qn = f"{module_name}.{node.name}" if module_name else node.name - symbols.append( - CodeSymbol( - id=_id("symbol", class_qn), - label=node.name, - kind="python_class", - path=rel_path, - module_name=module_name, - qualified_name=class_qn, - line_start=getattr(node, "lineno", None), - line_end=getattr(node, "end_lineno", None), - decorators=[_call_name(item) for item in node.decorator_list], - bases=[_call_name(item) for item in node.bases], - summary=ast.get_docstring(node) or "", - ) - ) - for child in node.body: - if isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef)): - method_qn = f"{class_qn}.{child.name}" - symbols.append( - CodeSymbol( - id=_id("symbol", method_qn), - label=child.name, - kind="python_method", - path=rel_path, - module_name=module_name, - qualified_name=method_qn, - line_start=getattr(child, "lineno", None), - line_end=getattr(child, "end_lineno", None), - decorators=[_call_name(item) for item in child.decorator_list], - summary=ast.get_docstring(child) or "", - ) - ) - elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): - function_qn = f"{module_name}.{node.name}" if module_name else node.name - symbols.append( - CodeSymbol( - id=_id("symbol", function_qn), - label=node.name, - kind="python_function", - path=rel_path, - module_name=module_name, - qualified_name=function_qn, - line_start=getattr(node, "lineno", None), - line_end=getattr(node, "end_lineno", None), - decorators=[_call_name(item) for item in node.decorator_list], - summary=ast.get_docstring(node) or "", - ) - ) - - return CodeFile( - id=_id("file", rel_path), - path=rel_path, - module_name=module_name, - language="python", - line_count=len(text.splitlines()), - summary=ast.get_docstring(tree) or "", - imports=sorted(set(imports)), - calls=sorted({call for call in calls if call}), - symbols=symbols, - ) - -def _module_name(rel_path: str) -> str: - without_suffix = rel_path[:-3] if rel_path.endswith(".py") else rel_path - parts = without_suffix.split("/") - if parts[-1] == "__init__": - parts = parts[:-1] - return ".".join(part for part in parts if part) - -def _call_name(node: ast.AST) -> str: - if isinstance(node, ast.Name): - return node.id - if isinstance(node, ast.Attribute): - base = _call_name(node.value) - return f"{base}.{node.attr}" if base else node.attr - if isinstance(node, ast.Constant): - return repr(node.value) - return "" - -def _id(prefix: str, value: str) -> str: - return f"{prefix}:{hashlib.sha1(value.encode('utf-8')).hexdigest()[:20]}" - -def _file_as_dict(file: CodeFile) -> dict[str, Any]: - return { - "id": file.id, - "path": file.path, - "module_name": file.module_name, - "language": file.language, - "line_count": file.line_count, - "summary": file.summary, - "imports": file.imports, - "calls": file.calls, - "symbols": [symbol.__dict__ for symbol in file.symbols], - } diff --git a/src/codebase_graph/context_builder.py b/src/codebase_graph/context_builder.py deleted file mode 100644 index 7a98dc4..0000000 --- a/src/codebase_graph/context_builder.py +++ /dev/null @@ -1,8 +0,0 @@ -from __future__ import annotations - -from typing import Any - -from .graph_context import build_compact_graph_context - -def assemble_context(query: str, graph: dict[str, Any], *, budget: int = 1200) -> dict[str, Any]: - return build_compact_graph_context(graph, query, budget=budget, limit=5, include_raw=False) diff --git a/src/codebase_graph/core/__init__.py b/src/codebase_graph/core/__init__.py new file mode 100644 index 0000000..1d480ea --- /dev/null +++ b/src/codebase_graph/core/__init__.py @@ -0,0 +1,5 @@ +"""Public API, models, and protocols for the code memory graph.""" + +from .graph import CodeGraph, GraphEdge, GraphNode + +__all__ = ["CodeGraph", "GraphEdge", "GraphNode"] diff --git a/src/codebase_graph/core/graph.py b/src/codebase_graph/core/graph.py new file mode 100644 index 0000000..c469bae --- /dev/null +++ b/src/codebase_graph/core/graph.py @@ -0,0 +1,166 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + +from codebase_graph.ontology import ONTOLOGY_NAME, get_relation_type + + +@dataclass(slots=True) +class GraphNode: + id: str + table: str + label: str + kind: str = "" + language: str = "" + path: str = "" + qualified_name: str = "" + scope_id: str = "" + line_start: int | None = None + line_end: int | None = None + byte_start: int | None = None + byte_end: int | None = None + tree_sitter_node_type: str = "" + capture_name: str = "" + summary: str = "" + metadata: dict[str, Any] = field(default_factory=dict) + + def as_dict(self) -> dict[str, Any]: + return { + "id": self.id, + "table": self.table, + "label": self.label, + "kind": self.kind, + "language": self.language, + "path": self.path, + "qualified_name": self.qualified_name, + "scope_id": self.scope_id, + "line_start": self.line_start, + "line_end": self.line_end, + "byte_start": self.byte_start, + "byte_end": self.byte_end, + "tree_sitter_node_type": self.tree_sitter_node_type, + "capture_name": self.capture_name, + "summary": self.summary, + "metadata": self.metadata, + } + + +@dataclass(slots=True) +class GraphEdge: + id: str + type: str + source_id: str + target_id: str + kind: str = "" + confidence: float = 1.0 + line_start: int | None = None + line_end: int | None = None + byte_start: int | None = None + byte_end: int | None = None + metadata: dict[str, Any] = field(default_factory=dict) + + def as_dict(self) -> dict[str, Any]: + return { + "id": self.id, + "type": self.type, + "source_id": self.source_id, + "target_id": self.target_id, + "kind": self.kind, + "confidence": self.confidence, + "line_start": self.line_start, + "line_end": self.line_end, + "byte_start": self.byte_start, + "byte_end": self.byte_end, + "metadata": self.metadata, + } + + +@dataclass(slots=True) +class CodeGraph: + nodes: dict[str, GraphNode] = field(default_factory=dict) + edges: dict[str, GraphEdge] = field(default_factory=dict) + ontology: str = ONTOLOGY_NAME + metadata: dict[str, Any] = field(default_factory=dict) + + def add_node(self, node: GraphNode) -> GraphNode: + existing = self.nodes.get(node.id) + if existing is None: + self.nodes[node.id] = node + return node + _merge_node(existing, node) + return existing + + def add_edge(self, edge: GraphEdge) -> GraphEdge: + self.edges.setdefault(edge.id, edge) + return self.edges[edge.id] + + def nodes_by_type(self, table: str) -> list[GraphNode]: + return [node for node in self.nodes.values() if node.table == table] + + def edges_by_type(self, edge_type: str) -> list[GraphEdge]: + return [edge for edge in self.edges.values() if edge.type == edge_type] + + def as_dict(self) -> dict[str, Any]: + return { + "ontology": self.ontology, + "metadata": self.metadata, + "nodes": [ + node.as_dict() + for node in sorted(self.nodes.values(), key=lambda item: (item.table, item.id)) + ], + "edges": [ + edge.as_dict() + for edge in sorted(self.edges.values(), key=lambda item: (item.type, item.id)) + ], + } + + def summary(self) -> dict[str, Any]: + node_counts: dict[str, int] = {} + edge_counts: dict[str, int] = {} + for node in self.nodes.values(): + node_counts[node.table] = node_counts.get(node.table, 0) + 1 + for edge in self.edges.values(): + edge_counts[edge.type] = edge_counts.get(edge.type, 0) + 1 + return { + "ontology": self.ontology, + "node_count": len(self.nodes), + "edge_count": len(self.edges), + "node_counts": node_counts, + "edge_counts": edge_counts, + } + + def validate_schema(self) -> None: + node_tables = {node.id: node.table for node in self.nodes.values()} + for edge in self.edges.values(): + if edge.source_id not in node_tables: + raise ValueError(f"Relation {edge.id} source is missing: {edge.source_id}") + if edge.target_id not in node_tables: + raise ValueError(f"Relation {edge.id} target is missing: {edge.target_id}") + spec = get_relation_type(edge.type) + source_table = node_tables[edge.source_id] + target_table = node_tables[edge.target_id] + if source_table not in spec.source_types: + raise ValueError(f"{edge.type} cannot start from {source_table}") + if target_table not in spec.target_types: + raise ValueError(f"{edge.type} cannot target {target_table}") + + +def _merge_node(existing: GraphNode, incoming: GraphNode) -> None: + for field_name in ( + "label", + "kind", + "language", + "path", + "qualified_name", + "scope_id", + "tree_sitter_node_type", + "capture_name", + "summary", + ): + if not getattr(existing, field_name) and getattr(incoming, field_name): + setattr(existing, field_name, getattr(incoming, field_name)) + for field_name in ("line_start", "line_end", "byte_start", "byte_end"): + if getattr(existing, field_name) is None and getattr(incoming, field_name) is not None: + setattr(existing, field_name, getattr(incoming, field_name)) + existing.metadata.update(incoming.metadata) diff --git a/src/codebase_graph/db/__init__.py b/src/codebase_graph/db/__init__.py new file mode 100644 index 0000000..c3beabe --- /dev/null +++ b/src/codebase_graph/db/__init__.py @@ -0,0 +1,20 @@ +"""Storage adapters, schema management, and migrations.""" + +from .query import GraphNeighbor, GraphQueryAdapter, LadybugGraphQueryAdapter, SearchIndexRow, graph_query_adapter +from .schema import build_ladybug_schema, build_ladybug_schema_statements, ladybug_type, quote_identifier +from .store import LadybugCodeGraphStore, LadybugUnavailableError, create_ladybug_database + +__all__ = [ + "GraphNeighbor", + "GraphQueryAdapter", + "LadybugCodeGraphStore", + "LadybugGraphQueryAdapter", + "LadybugUnavailableError", + "SearchIndexRow", + "build_ladybug_schema", + "build_ladybug_schema_statements", + "create_ladybug_database", + "graph_query_adapter", + "ladybug_type", + "quote_identifier", +] diff --git a/src/codebase_graph/db/query.py b/src/codebase_graph/db/query.py new file mode 100644 index 0000000..02a1d98 --- /dev/null +++ b/src/codebase_graph/db/query.py @@ -0,0 +1,186 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Protocol + +from codebase_graph.ontology import get_relation_type + +from .schema import quote_identifier + + +@dataclass(frozen=True, slots=True) +class GraphNeighbor: + node_id: str + node_type: str + label: str + qualified_name: str = "" + path: str = "" + line_start: int | None = None + line_end: int | None = None + summary: str = "" + + +@dataclass(frozen=True, slots=True) +class SearchIndexRow: + id: str + node_type: str + label: str + qualified_name: str = "" + path: str = "" + line_start: int | None = None + line_end: int | None = None + summary: str = "" + score: float = 0.0 + metadata: dict[str, Any] = field(default_factory=dict) + + +class GraphQueryAdapter(Protocol): + def search_index(self, *, node_type: str, index_name: str, query: str, limit: int) -> list[SearchIndexRow]: + ... + + def neighbors( + self, + *, + node_id: str, + node_type: str, + relation: str, + direction: str, + limit: int, + ) -> list[GraphNeighbor]: + ... + + +class LadybugGraphQueryAdapter: + def __init__(self, store: Any) -> None: + self.store = store + + def search_index(self, *, node_type: str, index_name: str, query: str, limit: int) -> list[SearchIndexRow]: + rows = self.store.execute( + _fts_query_statement(node_type=node_type, index_name=index_name), + {"query": query, "top": limit}, + ).get_all() + return [ + SearchIndexRow( + id=_text(_value(row, 0)), + node_type=node_type, + label=_text(_value(row, 1)), + qualified_name=_text(_value(row, 2)), + path=_text(_value(row, 3)), + line_start=_optional_int(_value(row, 4)), + line_end=_optional_int(_value(row, 5)), + summary=_text(_value(row, 6)), + score=float(_value(row, 7) or 0.0), + ) + for row in rows + ] + + def neighbors( + self, + *, + node_id: str, + node_type: str, + relation: str, + direction: str, + limit: int, + ) -> list[GraphNeighbor]: + if direction not in {"outgoing", "incoming"}: + raise ValueError(f"Unsupported relation direction: {direction}") + try: + relation_type = get_relation_type(relation) + except KeyError: + return [] + + if direction == "outgoing": + if node_type not in relation_type.source_types: + return [] + neighbor_types = relation_type.target_types + else: + if node_type not in relation_type.target_types: + return [] + neighbor_types = relation_type.source_types + + neighbors: list[GraphNeighbor] = [] + for neighbor_type in neighbor_types: + remaining = limit - len(neighbors) + if remaining <= 0: + break + rows = self.store.execute( + _neighbor_statement( + node_type=node_type, + neighbor_type=neighbor_type, + relation=relation, + direction=direction, + limit=remaining, + ), + {"node_id": node_id}, + ).get_all() + neighbors.extend(_neighbor_from_row(row, neighbor_type) for row in rows) + return neighbors + + +def graph_query_adapter(store: Any) -> GraphQueryAdapter: + adapter = getattr(store, "graph_query_adapter", None) + if adapter is not None: + return adapter + return LadybugGraphQueryAdapter(store) + + +def _fts_query_statement(*, node_type: str, index_name: str) -> str: + return ( + f"CALL QUERY_FTS_INDEX('{node_type}', '{index_name}', $query, TOP := $top) " + "RETURN node.id, node.label, node.qualified_name, node.path, " + "node.line_start, node.line_end, node.summary, score" + ) + + +def _neighbor_statement( + *, + node_type: str, + neighbor_type: str, + relation: str, + direction: str, + limit: int, +) -> str: + if direction == "outgoing": + return ( + f"MATCH (source:{quote_identifier(node_type)} {{id: $node_id}})" + f"-[:{quote_identifier(f'FROM_{relation}')}]->(edge:{quote_identifier(relation)})" + f"-[:{quote_identifier(f'TO_{relation}')}]->(neighbor:{quote_identifier(neighbor_type)}) " + "RETURN neighbor.id, neighbor.label, neighbor.qualified_name, neighbor.path, " + f"neighbor.line_start, neighbor.line_end, neighbor.summary LIMIT {int(limit)}" + ) + return ( + f"MATCH (neighbor:{quote_identifier(neighbor_type)})" + f"-[:{quote_identifier(f'FROM_{relation}')}]->(edge:{quote_identifier(relation)})" + f"-[:{quote_identifier(f'TO_{relation}')}]->(target:{quote_identifier(node_type)} {{id: $node_id}}) " + "RETURN neighbor.id, neighbor.label, neighbor.qualified_name, neighbor.path, " + f"neighbor.line_start, neighbor.line_end, neighbor.summary LIMIT {int(limit)}" + ) + + +def _neighbor_from_row(row: Any, node_type: str) -> GraphNeighbor: + return GraphNeighbor( + node_id=_text(_value(row, 0)), + node_type=node_type, + label=_text(_value(row, 1)), + qualified_name=_text(_value(row, 2)), + path=_text(_value(row, 3)), + line_start=_optional_int(_value(row, 4)), + line_end=_optional_int(_value(row, 5)), + summary=_text(_value(row, 6)), + ) + + +def _optional_int(value: Any) -> int | None: + return None if value is None else int(value) + + +def _text(value: Any) -> str: + return "" if value is None else str(value) + + +def _value(row: Any, index: int) -> Any: + try: + return row[index] + except IndexError: + return None diff --git a/src/codebase_graph/db/schema.py b/src/codebase_graph/db/schema.py new file mode 100644 index 0000000..2b8fa8b --- /dev/null +++ b/src/codebase_graph/db/schema.py @@ -0,0 +1,119 @@ +from __future__ import annotations + +from collections.abc import Iterable + +from codebase_graph.ontology import EDGE_FIELDS, NODE_TYPES, RELATION_TYPES, SEARCH_INDEXES, FieldSpec + +TYPE_MAP = { + "string": "STRING", + "integer": "INT64", + "number": "DOUBLE", + "boolean": "BOOLEAN", + "json": "JSON", +} + + +def quote_identifier(name: str) -> str: + return f"`{name.replace('`', '``')}`" + + +def ladybug_type(value_type: str) -> str: + try: + return TYPE_MAP[value_type] + except KeyError as exc: + raise ValueError(f"Unsupported ontology field type for LadyBugDB: {value_type}") from exc + + +def build_ladybug_schema(*, include_fts: bool = True) -> str: + return ";\n\n".join(build_ladybug_schema_statements(include_fts=include_fts)) + ";" + + +def build_ladybug_schema_statements(*, include_fts: bool = True) -> list[str]: + statements = [ + "INSTALL json", + "LOAD json", + ] + if include_fts: + statements.extend(("INSTALL fts", "LOAD fts")) + statements.extend(_semantic_node_table_sql()) + statements.extend(_edge_node_table_sql()) + statements.extend(_connector_table_sql()) + if include_fts: + statements.extend(_fts_index_sql()) + return statements + + +def _semantic_node_table_sql() -> list[str]: + return [ + _node_table_sql(node_type.name, node_type.fields) + for node_type in NODE_TYPES + ] + + +def _edge_node_table_sql() -> list[str]: + return [ + _node_table_sql(relation_type.name, relation_type.fields or EDGE_FIELDS) + for relation_type in RELATION_TYPES + ] + + +def _connector_table_sql() -> list[str]: + statements: list[str] = [] + for relation_type in RELATION_TYPES: + relation_name = relation_type.name + source_pairs = _dedupe_pairs((source_type, relation_name) for source_type in relation_type.source_types) + target_pairs = _dedupe_pairs((relation_name, target_type) for target_type in relation_type.target_types) + statements.append(_relation_table_sql(f"FROM_{relation_name}", source_pairs, role="source")) + statements.append(_relation_table_sql(f"TO_{relation_name}", target_pairs, role="target")) + return statements + + +def _node_table_sql(table_name: str, fields: Iterable[FieldSpec]) -> str: + columns = [_field_sql(field) for field in _dedupe_fields(fields)] + return f"CREATE NODE TABLE IF NOT EXISTS {quote_identifier(table_name)}(\n" + ",\n".join(columns) + "\n)" + + +def _relation_table_sql(table_name: str, endpoint_pairs: Iterable[tuple[str, str]], *, role: str) -> str: + endpoints = [ + f" FROM {quote_identifier(source_type)} TO {quote_identifier(target_type)}" + for source_type, target_type in endpoint_pairs + ] + columns = [*endpoints, f" {quote_identifier('role')} STRING DEFAULT '{role}'"] + return f"CREATE REL TABLE IF NOT EXISTS {quote_identifier(table_name)}(\n" + ",\n".join(columns) + "\n)" + + +def _field_sql(field: FieldSpec) -> str: + primary_key = " PRIMARY KEY" if field.name == "id" else "" + return f" {quote_identifier(field.name)} {ladybug_type(field.value_type)}{primary_key}" + + +def _dedupe_fields(fields: Iterable[FieldSpec]) -> list[FieldSpec]: + seen: set[str] = set() + deduped: list[FieldSpec] = [] + for field in fields: + if field.name in seen: + continue + seen.add(field.name) + deduped.append(field) + return deduped + + +def _dedupe_pairs(pairs: Iterable[tuple[str, str]]) -> list[tuple[str, str]]: + seen: set[tuple[str, str]] = set() + deduped: list[tuple[str, str]] = [] + for pair in pairs: + if pair in seen: + continue + seen.add(pair) + deduped.append(pair) + return deduped + + +def _fts_index_sql() -> list[str]: + statements: list[str] = [] + for index in SEARCH_INDEXES: + fields = ", ".join(repr(field) for field in index["fields"]) + for node_type in index["node_types"]: + index_name = f"{index['name']}_{node_type}" + statements.append(f"CALL CREATE_FTS_INDEX('{node_type}', '{index_name}', [{fields}])") + return statements diff --git a/src/codebase_graph/db/store.py b/src/codebase_graph/db/store.py new file mode 100644 index 0000000..3b79e47 --- /dev/null +++ b/src/codebase_graph/db/store.py @@ -0,0 +1,475 @@ +from __future__ import annotations + +import csv +import json +import tempfile +from collections import defaultdict +from collections.abc import Mapping +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +from codebase_graph.core import CodeGraph +from codebase_graph.ontology import NODE_TYPES, RELATION_TYPES + +from .schema import build_ladybug_schema, build_ladybug_schema_statements, quote_identifier + + +class LadybugUnavailableError(RuntimeError): + pass + + +@dataclass(frozen=True, slots=True) +class BulkLoadStats: + node_rows: int = 0 + edge_rows: int = 0 + connector_rows: int = 0 + copy_calls: int = 0 + + +class LadybugCodeGraphStore: + def __init__( + self, + db_path: str | Path = ":memory:", + *, + include_fts: bool = True, + read_only: bool = False, + ) -> None: + self.db_path = db_path + self.include_fts = include_fts + self.read_only = read_only + try: + import real_ladybug as lb + except ImportError as exc: + raise LadybugUnavailableError( + "LadyBugDB Python bindings are required for codebaseGraph. " + "Install a valid `codebase-graph` runtime with `real_ladybug` available." + ) from exc + + self._lb = lb + if str(db_path) != ":memory:" and not read_only: + Path(db_path).parent.mkdir(parents=True, exist_ok=True) + self.db = lb.Database(str(db_path), read_only=read_only) + self.conn = lb.Connection(self.db) + + @property + def schema_sql(self) -> str: + return build_ladybug_schema(include_fts=self.include_fts) + + def ensure_schema(self) -> None: + for statement in build_ladybug_schema_statements(include_fts=self.include_fts): + self._execute_ignoring_existing(statement) + + def load_extensions(self) -> None: + for statement in build_ladybug_schema_statements(include_fts=self.include_fts): + if statement.upper().startswith("LOAD "): + self.execute(statement) + + def execute(self, statement: str, parameters: dict[str, Any] | None = None) -> Any: + if parameters is None: + return self.conn.execute(statement) + return self.conn.execute(statement, parameters) + + def close(self) -> None: + self.conn.close() + self.db.close() + + def __enter__(self) -> LadybugCodeGraphStore: + return self + + def __exit__(self, exc_type: object, exc: object, traceback: object) -> None: + self.close() + + def clear_graph(self) -> None: + for relation_type in RELATION_TYPES: + self._execute_ignoring_missing(f"MATCH ()-[r:{quote_identifier(f'FROM_{relation_type.name}')}]->() DELETE r") + self._execute_ignoring_missing(f"MATCH ()-[r:{quote_identifier(f'TO_{relation_type.name}')}]->() DELETE r") + for relation_type in RELATION_TYPES: + self._execute_ignoring_missing(f"MATCH (n:{quote_identifier(relation_type.name)}) DELETE n") + for node_type in NODE_TYPES: + self._execute_ignoring_missing(f"MATCH (n:{quote_identifier(node_type.name)}) DELETE n") + + def replace_partition( + self, + path: str, + graph: CodeGraph, + *, + previous_entry: Mapping[str, Any] | Any | None = None, + retained_node_ids: set[str] | None = None, + retained_edge_ids: set[str] | None = None, + ) -> None: + if previous_entry is not None: + self.delete_partition( + path, + manifest_entry=previous_entry, + retained_node_ids=retained_node_ids, + retained_edge_ids=retained_edge_ids, + ) + + self.insert_graphs_bulk( + [graph], + skip_node_ids=retained_node_ids, + skip_edge_ids=retained_edge_ids, + ) + + def insert_graphs_bulk( + self, + graphs: list[CodeGraph] | tuple[CodeGraph, ...], + *, + skip_node_ids: set[str] | None = None, + skip_edge_ids: set[str] | None = None, + ) -> BulkLoadStats: + staging_tables = _build_bulk_staging_tables( + graphs, + skip_node_ids=skip_node_ids, + skip_edge_ids=skip_edge_ids, + ) + if staging_tables.is_empty: + return BulkLoadStats() + + with tempfile.TemporaryDirectory(prefix="codebase-graph-ladybug-") as staging_dir: + staging = staging_tables.write(Path(staging_dir)) + for statement in staging.copy_statements: + self.execute(statement) + return BulkLoadStats( + node_rows=staging.node_rows, + edge_rows=staging.edge_rows, + connector_rows=staging.connector_rows, + copy_calls=len(staging.copy_statements), + ) + + def delete_partition( + self, + path: str, + *, + manifest_entry: Mapping[str, Any] | Any | None = None, + retained_node_ids: set[str] | None = None, + retained_edge_ids: set[str] | None = None, + ) -> None: + if manifest_entry is None: + return + retained = retained_node_ids or set() + retained_edges = retained_edge_ids or set() + edge_types = _entry_mapping(manifest_entry, "edge_types") + node_types = _entry_mapping(manifest_entry, "node_types") + + for edge_id in _entry_values(manifest_entry, "edge_ids"): + if edge_id in retained_edges: + continue + edge_type = edge_types.get(edge_id) + if edge_type: + self._delete_edge(edge_id, edge_type) + + for node_id in _entry_values(manifest_entry, "node_ids"): + if node_id in retained: + continue + node_type = node_types.get(node_id) + if node_type: + self._delete_node(node_id, node_type) + + def read_manifest(self, path: str | Path) -> Any: + from codebase_graph.ingest.materializer import MaterializationManifest + + return MaterializationManifest.load(Path(path)) + + def write_manifest(self, manifest: Any, path: str | Path) -> None: + manifest.write(Path(path)) + + def _execute_ignoring_existing(self, statement: str) -> None: + try: + self.conn.execute(statement) + except Exception as exc: + message = str(exc).lower() + if "already exists" not in message and "exists already" not in message and "already installed" not in message: + raise + + def _execute_ignoring_missing(self, statement: str, parameters: dict[str, Any] | None = None) -> None: + try: + self.execute(statement, parameters) + except Exception as exc: + message = str(exc).lower() + if "does not exist" not in message and "not found" not in message: + raise + + def _delete_edge(self, edge_id: str, edge_type: str) -> None: + self._execute_ignoring_missing( + f"MATCH ()-[r:{quote_identifier(f'FROM_{edge_type}')}]->(edge:{quote_identifier(edge_type)} {{id: $id}}) DELETE r", + {"id": edge_id}, + ) + self._execute_ignoring_missing( + f"MATCH (edge:{quote_identifier(edge_type)} {{id: $id}})-[r:{quote_identifier(f'TO_{edge_type}')}]->() DELETE r", + {"id": edge_id}, + ) + self._execute_ignoring_missing( + f"MATCH (edge:{quote_identifier(edge_type)} {{id: $id}}) DELETE edge", + {"id": edge_id}, + ) + + def _delete_node(self, node_id: str, node_type: str) -> None: + self._execute_ignoring_missing( + f"MATCH (node:{quote_identifier(node_type)} {{id: $id}}) DELETE node", + {"id": node_id}, + ) + + +def create_ladybug_database( + db_path: str | Path = ":memory:", + *, + include_fts: bool = True, + read_only: bool = False, +) -> LadybugCodeGraphStore: + store = LadybugCodeGraphStore(db_path, include_fts=include_fts, read_only=read_only) + if read_only: + store.load_extensions() + else: + store.ensure_schema() + return store + + +NODE_FIELDS = { + node_type.name: tuple(field for field in node_type.fields) + for node_type in NODE_TYPES +} +_OMIT_JSON_VALUE = object() +EDGE_FIELDS_BY_TYPE = { + relation_type.name: tuple(field for field in relation_type.fields) + for relation_type in RELATION_TYPES +} + + +@dataclass(slots=True) +class _BulkStagingTables: + nodes: dict[str, dict[str, dict[str, Any]]] + edges: dict[str, dict[str, dict[str, Any]]] + connectors: dict[tuple[str, str, str], dict[tuple[str, str, str], dict[str, str]]] + + @property + def is_empty(self) -> bool: + return not any(self.nodes.values()) and not any(self.edges.values()) and not any(self.connectors.values()) + + def write(self, staging_dir: Path) -> _BulkStagingResult: + staging_dir.mkdir(parents=True, exist_ok=True) + copy_statements: list[str] = [] + node_rows = 0 + edge_rows = 0 + connector_rows = 0 + + for node_type in NODE_TYPES: + rows = self.nodes.get(node_type.name, {}) + if not rows: + continue + path = staging_dir / f"{_stage_file_stem(node_type.name)}.json" + _write_json_rows(path, rows.values()) + node_rows += len(rows) + copy_statements.append(f'COPY {quote_identifier(node_type.name)} FROM "{_copy_path(path)}";') + + for relation_type in RELATION_TYPES: + rows = self.edges.get(relation_type.name, {}) + if not rows: + continue + path = staging_dir / f"{_stage_file_stem(relation_type.name)}.json" + _write_json_rows(path, rows.values()) + edge_rows += len(rows) + copy_statements.append(f'COPY {quote_identifier(relation_type.name)} FROM "{_copy_path(path)}";') + + for relation_type in RELATION_TYPES: + for connector_table in (f"FROM_{relation_type.name}", f"TO_{relation_type.name}"): + connector_groups = [ + (endpoint_pair, rows) + for endpoint_pair, rows in self.connectors.items() + if endpoint_pair[0] == connector_table and rows + ] + for (table, source_type, target_type), rows in sorted(connector_groups): + path = staging_dir / ( + f"{_stage_file_stem(table)}__" + f"{_stage_file_stem(source_type)}__{_stage_file_stem(target_type)}.csv" + ) + _write_csv_rows(path, ("from_id", "to_id", "role"), rows.values()) + connector_rows += len(rows) + copy_statements.append( + f'COPY {quote_identifier(table)} FROM "{_copy_path(path)}" ' + f'(header=true, from="{source_type}", to="{target_type}");' + ) + + return _BulkStagingResult( + copy_statements=tuple(copy_statements), + node_rows=node_rows, + edge_rows=edge_rows, + connector_rows=connector_rows, + ) + + +@dataclass(frozen=True, slots=True) +class _BulkStagingResult: + copy_statements: tuple[str, ...] + node_rows: int + edge_rows: int + connector_rows: int + + +def _build_bulk_staging_tables( + graphs: list[CodeGraph] | tuple[CodeGraph, ...], + *, + skip_node_ids: set[str] | None = None, + skip_edge_ids: set[str] | None = None, +) -> _BulkStagingTables: + skipped_nodes = skip_node_ids or set() + skipped_edges = skip_edge_ids or set() + node_rows: dict[str, dict[str, dict[str, Any]]] = defaultdict(dict) + edge_rows: dict[str, dict[str, dict[str, Any]]] = defaultdict(dict) + connector_rows: dict[tuple[str, str, str], dict[tuple[str, str, str], dict[str, str]]] = defaultdict(dict) + + for graph in graphs: + for node in graph.nodes.values(): + if node.id in skipped_nodes: + continue + row = _row_for_fields(node.as_dict(), NODE_FIELDS[node.table], for_json_copy=True) + _merge_staged_row(node_rows[node.table], node.id, row) + + for edge in graph.edges.values(): + if edge.id in skipped_edges: + continue + row = _row_for_fields(edge.as_dict(), EDGE_FIELDS_BY_TYPE[edge.type], for_json_copy=True) + _merge_staged_row(edge_rows[edge.type], edge.id, row) + + for edge in graph.edges.values(): + if edge.id in skipped_edges: + continue + source = graph.nodes[edge.source_id] + target = graph.nodes[edge.target_id] + _add_connector_row( + connector_rows, + table=f"FROM_{edge.type}", + source_type=source.table, + target_type=edge.type, + from_id=source.id, + to_id=edge.id, + role="source", + ) + _add_connector_row( + connector_rows, + table=f"TO_{edge.type}", + source_type=edge.type, + target_type=target.table, + from_id=edge.id, + to_id=target.id, + role="target", + ) + + return _BulkStagingTables(nodes=dict(node_rows), edges=dict(edge_rows), connectors=dict(connector_rows)) + + +def _row_for_fields(row: Mapping[str, Any], fields: tuple[Any, ...], *, for_json_copy: bool = False) -> dict[str, Any]: + return { + field.name: _copy_field_value(field.name, row, field.value_type, for_json_copy=for_json_copy) + for field in fields + } + + +def _copy_field_value(name: str, row: Mapping[str, Any], value_type: str, *, for_json_copy: bool = False) -> Any: + if not for_json_copy or value_type != "json": + return _field_value(name, row, value_type) + if name in row: + value = row[name] + else: + metadata = row.get("metadata") if isinstance(row.get("metadata"), Mapping) else {} + value = metadata.get(name) + safe = _json_safe(value if value is not None else {}) + if safe is _OMIT_JSON_VALUE: + return {} + return safe + + +def _merge_staged_row(rows: dict[str, dict[str, Any]], row_id: str, row: dict[str, Any]) -> None: + existing = rows.get(row_id) + if existing is None: + rows[row_id] = row + return + for key, value in row.items(): + if value not in (None, "", {}, []) and existing.get(key) in (None, "", {}, []): + existing[key] = value + existing_metadata = existing.get("metadata") + incoming_metadata = row.get("metadata") + if isinstance(existing_metadata, dict) and isinstance(incoming_metadata, dict): + existing_metadata.update(incoming_metadata) + + +def _add_connector_row( + rows: dict[tuple[str, str, str], dict[tuple[str, str, str], dict[str, str]]], + *, + table: str, + source_type: str, + target_type: str, + from_id: str, + to_id: str, + role: str, +) -> None: + key = (table, source_type, target_type) + rows[key][(from_id, to_id, role)] = {"from_id": from_id, "to_id": to_id, "role": role} + + +def _write_json_rows(path: Path, rows: Any) -> None: + with path.open("w", encoding="utf-8") as handle: + json.dump(list(rows), handle, separators=(",", ":"), sort_keys=True) + handle.write("\n") + + +def _write_csv_rows(path: Path, columns: tuple[str, ...], rows: Any) -> None: + with path.open("w", newline="", encoding="utf-8") as handle: + writer = csv.DictWriter(handle, fieldnames=columns, extrasaction="ignore") + writer.writeheader() + for row in rows: + writer.writerow({column: row.get(column, "") for column in columns}) + + +def _stage_file_stem(name: str) -> str: + return "".join(character.lower() if character.isalnum() else "_" for character in name).strip("_") or "table" + + +def _copy_path(path: Path) -> str: + return path.as_posix().replace('"', '\\"') + + +def _field_value(name: str, row: Mapping[str, Any], value_type: str) -> Any: + if name in row: + value = row[name] + else: + metadata = row.get("metadata") if isinstance(row.get("metadata"), Mapping) else {} + value = metadata.get(name) + if value_type == "json": + return json.dumps(_json_safe(value if value is not None else {}), sort_keys=True) + return value + + +def _json_safe(value: Any) -> Any: + if isinstance(value, Mapping): + safe_items = {} + for key, item in value.items(): + safe_item = _json_safe(item) + if safe_item is _OMIT_JSON_VALUE: + continue + safe_items[str(key)] = safe_item + return safe_items + if isinstance(value, list | tuple): + if not value: + return _OMIT_JSON_VALUE + return [_json_safe(item) for item in value] + if value is None: + return _OMIT_JSON_VALUE + return value + + +def _entry_values(entry: Mapping[str, Any] | Any, field_name: str) -> tuple[str, ...]: + if isinstance(entry, Mapping): + values = entry.get(field_name, ()) + else: + values = getattr(entry, field_name, ()) + return tuple(str(value) for value in values) + + +def _entry_mapping(entry: Mapping[str, Any] | Any, field_name: str) -> dict[str, str]: + if isinstance(entry, Mapping): + values = entry.get(field_name, {}) + else: + values = getattr(entry, field_name, {}) + return {str(key): str(value) for key, value in dict(values).items()} diff --git a/src/codebase_graph/diagnostics.py b/src/codebase_graph/diagnostics.py new file mode 100644 index 0000000..ec4e5ca --- /dev/null +++ b/src/codebase_graph/diagnostics.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +import json +import os +import sys +from datetime import datetime, timezone +from typing import Any + +LOG_LEVEL_ENV = "CODEBASE_GRAPH_LOG_LEVEL" +_LEVELS = { + "DEBUG": 10, + "INFO": 20, + "WARNING": 30, + "ERROR": 40, + "CRITICAL": 50, +} + + +def log_event(event: str, *, level: str = "INFO", **fields: Any) -> None: + normalized_level = level.upper() + if _LEVELS.get(normalized_level, 20) < _configured_level(): + return + payload = { + "event": event, + "level": normalized_level, + "timestamp": datetime.now(timezone.utc).isoformat(), + **_safe_fields(fields), + } + print(json.dumps(payload, separators=(",", ":"), sort_keys=True), file=sys.stderr) + + +def _configured_level() -> int: + configured = os.environ.get(LOG_LEVEL_ENV, "WARNING").upper() + return _LEVELS.get(configured, _LEVELS["WARNING"]) + + +def _safe_fields(fields: dict[str, Any]) -> dict[str, Any]: + safe: dict[str, Any] = {} + for key, value in fields.items(): + if value is None or isinstance(value, (str, int, float, bool)): + safe[key] = value + elif isinstance(value, (list, tuple)): + safe[key] = [_safe_value(item) for item in value] + elif isinstance(value, dict): + safe[key] = {str(item_key): _safe_value(item_value) for item_key, item_value in value.items()} + else: + safe[key] = str(value) + return safe + + +def _safe_value(value: Any) -> Any: + if value is None or isinstance(value, (str, int, float, bool)): + return value + if isinstance(value, (list, tuple)): + return [_safe_value(item) for item in value] + if isinstance(value, dict): + return {str(key): _safe_value(item) for key, item in value.items()} + return str(value) diff --git a/src/codebase_graph/document_layers.py b/src/codebase_graph/document_layers.py deleted file mode 100644 index 0e1cb37..0000000 --- a/src/codebase_graph/document_layers.py +++ /dev/null @@ -1,56 +0,0 @@ -from __future__ import annotations - -from dataclasses import dataclass -import re - -@dataclass(slots=True) -class LogicalChunk: - id: str - heading: str - text: str - ordinal: int - -class LogicalChunker: - def __init__(self, max_chars: int = 1600) -> None: - self.max_chars = max_chars - - def chunk(self, text: str) -> list[LogicalChunk]: - sections = _split_sections(text) - chunks: list[LogicalChunk] = [] - for index, (heading, body) in enumerate(sections): - body = body.strip() - if not body: - continue - for part_index, part in enumerate(_split_body(body, self.max_chars)): - suffix = f"-{part_index}" if part_index else "" - chunks.append(LogicalChunk(id=f"chunk-{index}{suffix}", heading=heading, text=part, ordinal=len(chunks))) - if not chunks and text.strip(): - chunks.append(LogicalChunk(id="chunk-0", heading="Document", text=text.strip()[: self.max_chars], ordinal=0)) - return chunks - -def _split_sections(text: str) -> list[tuple[str, str]]: - matches = list(re.finditer(r"^(#{1,6})\s+(.+)$", text, flags=re.MULTILINE)) - if not matches: - return [("Document", text)] - sections: list[tuple[str, str]] = [] - for index, match in enumerate(matches): - start = match.end() - end = matches[index + 1].start() if index + 1 < len(matches) else len(text) - sections.append((match.group(2).strip(), text[start:end])) - return sections - -def _split_body(body: str, max_chars: int) -> list[str]: - if len(body) <= max_chars: - return [body] - paragraphs = [part.strip() for part in body.split("\n\n") if part.strip()] - chunks: list[str] = [] - current = "" - for paragraph in paragraphs: - if current and len(current) + len(paragraph) + 2 > max_chars: - chunks.append(current) - current = paragraph - else: - current = f"{current}\n\n{paragraph}".strip() - if current: - chunks.append(current) - return chunks or [body[:max_chars]] diff --git a/src/codebase_graph/extract/__init__.py b/src/codebase_graph/extract/__init__.py new file mode 100644 index 0000000..f54048f --- /dev/null +++ b/src/codebase_graph/extract/__init__.py @@ -0,0 +1,21 @@ +"""Code entity and relation extraction.""" + +from .graph_builder import ( + CaptureRecord, + CaptureTableRegistry, + CaptureTableResolver, + GraphBuilder, + GraphBuildResult, + ParseBundle, + default_capture_table_registry, +) + +__all__ = [ + "CaptureRecord", + "CaptureTableRegistry", + "CaptureTableResolver", + "GraphBuilder", + "GraphBuildResult", + "ParseBundle", + "default_capture_table_registry", +] diff --git a/src/codebase_graph/extract/graph_builder.py b/src/codebase_graph/extract/graph_builder.py new file mode 100644 index 0000000..9552624 --- /dev/null +++ b/src/codebase_graph/extract/graph_builder.py @@ -0,0 +1,1327 @@ +from __future__ import annotations + +import hashlib +from collections.abc import Callable, Iterable, Mapping, Sequence +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +from codebase_graph.core import CodeGraph, GraphEdge, GraphNode +from codebase_graph.ontology import ONTOLOGY_NAME, get_relation_type, node_type_names, relation_type_names + + +@dataclass(frozen=True, slots=True) +class CaptureRecord: + capture: str + node: Any + + +@dataclass(frozen=True, slots=True) +class ParseBundle: + language: str + path: str + source_text: str = "" + tree: Any | None = None + captures: Sequence[CaptureRecord | Mapping[str, Any] | tuple[Any, str]] = () + repository_label: str = "repository" + source_root: str = "." + content_hash: str = "" + + +@dataclass(frozen=True, slots=True) +class GraphBuildResult: + nodes: list[dict[str, Any]] + edges: list[dict[str, Any]] + diagnostics: list[str] + unresolved: list[str] + graph: CodeGraph + + def as_dict(self) -> dict[str, Any]: + return { + "nodes": self.nodes, + "edges": self.edges, + "diagnostics": self.diagnostics, + "unresolved": self.unresolved, + "summary": self.graph.summary(), + } + + +@dataclass(frozen=True, slots=True) +class ParserNode: + node_type: str + fields: Mapping[str, Any] + children: tuple[Any, ...] + line_start: int | None = None + line_end: int | None = None + byte_start: int | None = None + byte_end: int | None = None + capture_name: str = "" + text: str = "" + + +@dataclass(frozen=True, slots=True) +class BuildContext: + path: str + language: str + source_text: str + repository_label: str + source_root: str + + +@dataclass(frozen=True, slots=True) +class ScopeFrame: + node_id: str + table: str + label: str + scope_id: str + qualified_name: str + + +CaptureTableResolver = Callable[[str, ScopeFrame], str | None] + + +class CaptureTableRegistry: + def __init__(self) -> None: + self._exact: dict[str, str | CaptureTableResolver] = {} + self._prefix: list[tuple[str, str | CaptureTableResolver]] = [] + + def register_exact(self, capture_name: str, table: str | CaptureTableResolver) -> None: + self._exact[_normalize_capture_name(capture_name)] = table + + def register_prefix(self, prefix: str, table: str | CaptureTableResolver) -> None: + self._prefix.append((_normalize_capture_name(prefix), table)) + + def table_for(self, capture_name: str, owner: ScopeFrame) -> str | None: + capture = _normalize_capture_name(capture_name) + if not capture: + return None + if capture in self._exact: + return _resolve_capture_table(self._exact[capture], capture, owner) + for prefix, table in self._prefix: + if capture.startswith(prefix): + return _resolve_capture_table(table, capture, owner) + return None + + +def default_capture_table_registry() -> CaptureTableRegistry: + registry = CaptureTableRegistry() + for capture in ("definition.class", "definition.struct", "definition.interface"): + registry.register_exact(capture, "Class") + registry.register_exact("definition.component", "Component") + registry.register_exact("component", "Component") + registry.register_exact("definition.method", "Method") + registry.register_exact("definition.function", _function_capture_table) + registry.register_exact("definition.parameter", "Parameter") + registry.register_exact("parameter", "Parameter") + registry.register_exact("type.return", "ReturnType") + registry.register_exact("return_type", "ReturnType") + for capture in ("type", "type.annotation", "reference.type"): + registry.register_exact(capture, "TypeAnnotation") + registry.register_exact("definition.type_alias", "TypeAlias") + registry.register_exact("definition.constant", "Constant") + registry.register_exact("definition.variable", "Variable") + registry.register_exact("decorator", "Decorator") + registry.register_exact("definition.decorator", "Decorator") + for capture in ("reference.import", "reference.include", "reference.require", "reference.use", "import"): + registry.register_exact(capture, "ImportDeclaration") + registry.register_exact("export", "ExportDeclaration") + registry.register_exact("definition.export", "ExportDeclaration") + registry.register_exact("reference.call", "CallExpression") + registry.register_exact("call", "CallExpression") + registry.register_prefix("query.", "Query") + registry.register_prefix("secret.", "SecretRef") + registry.register_exact("entrypoint.api", "APIEndpoint") + registry.register_exact("endpoint", "APIEndpoint") + registry.register_exact("route", "Route") + registry.register_exact("doc.source", "DocumentationSource") + registry.register_prefix("doc", "DocumentationChunk") + registry.register_exact("literal", "Literal") + registry.register_exact("string", "Literal") + registry.register_exact("number", "Literal") + registry.register_exact("control_flow", "ControlFlowBlock") + registry.register_exact("exception", "ExceptionFlow") + registry.register_exact("raises", "ExceptionFlow") + registry.register_exact("handles", "ExceptionFlow") + registry.register_prefix("reference", "Reference") + return registry + + +class GraphBuilder: + """Build an ontology graph from tree-sitter-shaped parser output. + + The builder deliberately uses duck typing instead of importing tree-sitter. + It accepts dictionaries, Python AST-like objects, and tree-sitter Node-like + objects with ``type``, ``children``, ``start_point``, and ``end_point``. + """ + + def __init__( + self, + *, + default_language: str = "", + repository_label: str = "repository", + source_root: str | Path = ".", + include_syntax_captures: bool = True, + capture_table_registry: CaptureTableRegistry | None = None, + ) -> None: + self.default_language = default_language + self.repository_label = repository_label + self.source_root = Path(source_root).as_posix() + self.include_syntax_captures = include_syntax_captures + self.capture_table_registry = capture_table_registry or default_capture_table_registry() + self._node_types = set(node_type_names()) + self._relation_types = set(relation_type_names()) + self._graph = CodeGraph() + self._context = BuildContext("", "", "", repository_label, self.source_root) + self._syntax_nodes: dict[int, str] = {} + self._symbols_by_name: dict[str, list[str]] = {} + self._diagnostics: list[str] = [] + self._unresolved: list[str] = [] + + def build_file_graph(self, bundle: ParseBundle) -> GraphBuildResult: + if bundle.captures: + graph = self.build_from_captures( + bundle.captures, + source_path=bundle.path, + language=bundle.language, + source_text=bundle.source_text, + repository_label=bundle.repository_label, + source_root=bundle.source_root, + ) + else: + tree = bundle.tree or {"type": "Module", "children": []} + graph = self.build( + tree, + source_path=bundle.path, + language=bundle.language, + source_text=bundle.source_text, + repository_label=bundle.repository_label, + source_root=bundle.source_root, + ) + if bundle.content_hash: + for node in graph.nodes_by_type("File"): + node.metadata["content_hash"] = bundle.content_hash + return GraphBuildResult( + nodes=graph.as_dict()["nodes"], + edges=graph.as_dict()["edges"], + diagnostics=list(self._diagnostics), + unresolved=list(self._unresolved), + graph=graph, + ) + + def build( + self, + parse_tree: Any, + *, + source_path: str | Path, + language: str | None = None, + source_text: str = "", + repository_label: str | None = None, + source_root: str | Path | None = None, + ) -> CodeGraph: + path = Path(source_path).as_posix() + root = Path(source_root).as_posix() if source_root is not None else self.source_root + repo_label = repository_label or self.repository_label + self._graph = CodeGraph( + ontology=ONTOLOGY_NAME, + metadata={"source_path": path, "language": language or self.default_language, "source_root": root}, + ) + self._context = BuildContext( + path=path, + language=language or self.default_language, + source_text=source_text, + repository_label=repo_label, + source_root=root, + ) + self._syntax_nodes = {} + self._symbols_by_name = {} + self._diagnostics = [] + self._unresolved = [] + + repository = self._support_node("Repository", repo_label, repo_label, path="") + source = self._support_node("SourceRoot", root, root, path=root) + file = self._support_node("File", path, Path(path).name, path=path) + self._edge("Contains", repository.id, source.id, "repository_source_root") + self._edge("Contains", source.id, file.id, "source_root_file") + + root_node = self._normalize(parse_tree) + if root_node.node_type in {"Module", "module", "program", "source_file"}: + module = self._semantic_node("Module", root_node, label=_module_label(path), owner=file) + module_scope = self._scope_for(module) + self._edge("Contains", file.id, module.id, "file_module") + self._edge("Contains", module.id, module_scope.id, "module_contains_scope") + self._edge("HasScope", module.id, module_scope.id, "module_scope") + self._traverse(root_node, ScopeFrame(module.id, "Module", module.label, module_scope.id, module.label)) + else: + file_scope = self._scope_for(file) + self._edge("HasScope", file.id, file_scope.id, "file_scope") + self._traverse(root_node, ScopeFrame(file.id, "File", file.label, file_scope.id, file.label)) + + self._graph.validate_schema() + return self._graph + + def build_from_captures( + self, + captures: Iterable[CaptureRecord | Mapping[str, Any] | tuple[Any, str]], + *, + source_path: str | Path, + language: str | None = None, + source_text: str = "", + repository_label: str | None = None, + source_root: str | Path | None = None, + ) -> CodeGraph: + root = { + "type": "Module", + "children": [ + {"type": _capture_node_type(capture), "capture_name": _capture_name(capture), "node": _capture_node(capture)} + for capture in captures + ], + } + return self.build( + root, + source_path=source_path, + language=language, + source_text=source_text, + repository_label=repository_label, + source_root=source_root, + ) + + def _traverse(self, raw_node: Any, owner: ScopeFrame) -> None: + node = self._normalize(raw_node) + syntax_id = self._syntax_capture(node) + next_owner = owner + capture_table = self.capture_table_registry.table_for(node.capture_name, owner) + + if capture_table is not None: + semantic = self._emit_captured_semantic(capture_table, node, owner, syntax_id) + if capture_table in {"Class", "Function", "Method", "Component"}: + scope = self._scope_for(semantic) + self._edge("Contains", semantic.id, scope.id, f"{capture_table.lower()}_contains_scope") + self._edge("HasScope", semantic.id, scope.id, f"{capture_table.lower()}_scope") + next_owner = ScopeFrame(semantic.id, capture_table, semantic.label, scope.id, semantic.qualified_name) + elif node.node_type in {"Module", "module", "program", "source_file"} and owner.table != "Module": + semantic = self._semantic_node("Module", node, label=_module_label(self._context.path), owner_id=owner.node_id) + scope = self._scope_for(semantic) + self._edge("Contains", owner.node_id, semantic.id, "contains_module") + self._edge("Contains", semantic.id, scope.id, "module_contains_scope") + self._edge("HasScope", semantic.id, scope.id, "module_scope") + self._derived_from(semantic.id, syntax_id) + next_owner = ScopeFrame(semantic.id, "Module", semantic.label, scope.id, semantic.qualified_name) + elif node.node_type in IMPORT_NODE_TYPES: + self._emit_import(node, owner, syntax_id) + elif node.node_type in EXPORT_NODE_TYPES: + self._emit_simple_semantic("ExportDeclaration", node, owner, syntax_id) + elif node.node_type in CLASS_NODE_TYPES: + semantic = self._emit_declaration("Class", node, owner, syntax_id) + scope = self._scope_for(semantic) + self._edge("Contains", semantic.id, scope.id, "class_contains_scope") + self._edge("HasScope", semantic.id, scope.id, "class_scope") + next_owner = ScopeFrame(semantic.id, "Class", semantic.label, scope.id, semantic.qualified_name) + self._emit_decorators(node, semantic) + elif node.node_type in FUNCTION_NODE_TYPES: + table = "Method" if owner.table in {"Class", "Component"} else "Function" + semantic = self._emit_declaration(table, node, owner, syntax_id) + scope = self._scope_for(semantic) + self._edge("Contains", semantic.id, scope.id, f"{table.lower()}_contains_scope") + self._edge("HasScope", semantic.id, scope.id, f"{table.lower()}_scope") + next_owner = ScopeFrame(semantic.id, table, semantic.label, scope.id, semantic.qualified_name) + self._emit_parameters(node, semantic) + self._emit_return_type(node, semantic) + self._emit_decorators(node, semantic) + elif node.node_type in ASSIGNMENT_NODE_TYPES: + self._emit_assignment(node, owner, syntax_id) + elif node.node_type in CALL_NODE_TYPES: + self._emit_call(node, owner, syntax_id) + elif node.node_type in REFERENCE_NODE_TYPES: + self._emit_reference(node, owner, syntax_id) + elif node.node_type in LITERAL_NODE_TYPES: + self._emit_simple_semantic("Literal", node, owner, syntax_id) + elif node.node_type in PARAMETER_NODE_TYPES: + self._emit_simple_semantic("Parameter", node, owner, syntax_id) + elif node.node_type in RETURN_TYPE_NODE_TYPES: + self._emit_simple_semantic("ReturnType", node, owner, syntax_id) + elif node.node_type in TYPE_NODE_TYPES: + self._emit_simple_semantic("TypeAnnotation", node, owner, syntax_id) + elif node.node_type in CONTROL_FLOW_NODE_TYPES: + self._emit_simple_semantic("ControlFlowBlock", node, owner, syntax_id) + elif node.node_type in EXCEPTION_FLOW_NODE_TYPES: + self._emit_simple_semantic("ExceptionFlow", node, owner, syntax_id) + + for child in self._semantic_children(node): + self._traverse(child, next_owner) + + def _emit_captured_semantic( + self, + table: str, + node: ParserNode, + owner: ScopeFrame, + syntax_id: str, + ) -> GraphNode: + if table == "ImportDeclaration": + return self._emit_import(node, owner, syntax_id) + if table == "ExportDeclaration": + return self._emit_simple_semantic("ExportDeclaration", node, owner, syntax_id) + if table in {"Class", "Function", "Method"}: + return self._emit_declaration(table, node, owner, syntax_id) + if table == "CallExpression": + return self._emit_call(node, owner, syntax_id) + if table == "Reference": + return self._emit_reference(node, owner, syntax_id) + return self._emit_simple_semantic(table, node, owner, syntax_id) + + def _emit_import(self, node: ParserNode, owner: ScopeFrame, syntax_id: str) -> GraphNode: + imported = _import_label(node) or _label_for(node) + semantic = self._semantic_node( + "ImportDeclaration", + node, + label=imported or node.node_type, + owner_id=owner.node_id, + metadata={"imported_name": imported}, + ) + self._connect_owner(owner, semantic) + self._edge_if_allowed("Imports", _import_source_id(owner), semantic.id, "declares_import") + self._derived_from(semantic.id, syntax_id) + if imported: + dependency = self._support_node("Dependency", imported, imported, path=self._context.path) + self._edge("DependsOn", semantic.id, dependency.id, "import_dependency") + self._edge("EvidencedBy", dependency.id, syntax_id, "parser_evidence") + return semantic + + def _emit_declaration(self, table: str, node: ParserNode, owner: ScopeFrame, syntax_id: str) -> GraphNode: + semantic = self._semantic_node(table, node, owner_id=owner.node_id, owner_qualified_name=owner.qualified_name) + self._connect_owner(owner, semantic) + self._edge("Defines", owner.node_id, semantic.id, f"defines_{table.lower()}") + if owner.table in {"Module", "Scope", "Class", "Function", "Method"}: + self._edge("Declares", owner.node_id, semantic.id, f"declares_{table.lower()}") + self._derived_from(semantic.id, syntax_id) + return semantic + + def _emit_assignment(self, node: ParserNode, owner: ScopeFrame, syntax_id: str) -> GraphNode: + assignment = self._semantic_node("Assignment", node, owner_id=owner.node_id, owner_qualified_name=owner.qualified_name) + self._connect_owner(owner, assignment) + self._derived_from(assignment.id, syntax_id) + + target_label = _assignment_target_label(node) + if target_label: + target_table = _assignment_target_table(target_label, owner, node) + target = self._semantic_node( + target_table, + node, + label=target_label, + owner_id=owner.node_id, + owner_qualified_name=owner.qualified_name, + ) + self._connect_owner(owner, target) + self._edge("Defines", owner.node_id, target.id, f"defines_{target_table.lower()}") + self._edge("Assigns", assignment.id, target.id, "assignment_target") + self._derived_from(target.id, syntax_id) + annotation = _field(node, "annotation") + if annotation is not None: + type_node = self._emit_type_annotation(annotation, target) + self._edge("HasTypeAnnotation", target.id, type_node.id, "assignment_annotation") + + value = _field(node, "value") + if value is not None and _normalized_type(value) in CALL_NODE_TYPES: + call = self._emit_call(self._normalize(value), owner, self._syntax_capture(self._normalize(value))) + self._edge("Assigns", assignment.id, call.id, "assignment_value") + + return assignment + + def _emit_call(self, node: ParserNode, owner: ScopeFrame, syntax_id: str) -> GraphNode: + call = self._semantic_node( + "CallExpression", + node, + label=_call_label(node) or _label_for(node), + owner_id=owner.node_id, + owner_qualified_name=owner.qualified_name, + ) + self._connect_owner(owner, call) + if owner.table in {"Function", "Method", "APIEndpoint", "Route", "Component"}: + self._edge("Calls", owner.node_id, call.id, "body_call") + target = self._emit_reference_edges(call, call.label, kind_prefix="call") + if target is not None: + self._edge_if_allowed("Calls", call.id, target.id, "call_target") + self._derived_from(call.id, syntax_id) + return call + + def _emit_reference(self, node: ParserNode, owner: ScopeFrame, syntax_id: str) -> GraphNode: + reference = self._semantic_node( + "Reference", + node, + label=_label_for(node), + owner_id=owner.node_id, + owner_qualified_name=owner.qualified_name, + ) + self._connect_owner(owner, reference) + self._emit_reference_edges(reference, reference.label, kind_prefix="reference") + self._derived_from(reference.id, syntax_id) + return reference + + def _emit_simple_semantic(self, table: str, node: ParserNode, owner: ScopeFrame, syntax_id: str) -> GraphNode: + semantic = self._semantic_node( + table, + node, + label=_label_for(node), + owner_id=owner.node_id, + owner_qualified_name=owner.qualified_name, + ) + self._connect_owner(owner, semantic) + self._emit_contextual_relations(semantic, node, owner, syntax_id) + self._derived_from(semantic.id, syntax_id) + return semantic + + def _emit_parameters(self, node: ParserNode, callable_node: GraphNode) -> None: + for index, parameter in enumerate(_parameters(node)): + parser_node = self._normalize(parameter) + syntax_id = self._syntax_capture(parser_node) + param_node = self._semantic_node( + "Parameter", + parser_node, + label=_label_for(parser_node) or f"param_{index}", + owner_id=callable_node.id, + owner_qualified_name=callable_node.qualified_name, + ) + self._edge("HasParameter", callable_node.id, param_node.id, "callable_parameter", metadata={"ordinal": index}) + self._derived_from(param_node.id, syntax_id) + annotation = _field(parser_node, "annotation") + if annotation is not None: + type_node = self._emit_type_annotation(annotation, param_node) + self._edge("HasTypeAnnotation", param_node.id, type_node.id, "parameter_annotation") + + def _emit_return_type(self, node: ParserNode, callable_node: GraphNode) -> None: + raw_return = _field(node, "returns") or _field(node, "return_type") + if raw_return is None: + return + return_parser = self._normalize(raw_return) + syntax_id = self._syntax_capture(return_parser) + return_node = self._semantic_node( + "ReturnType", + return_parser, + label=_label_for(return_parser), + owner_id=callable_node.id, + owner_qualified_name=callable_node.qualified_name, + ) + self._edge("HasReturnType", callable_node.id, return_node.id, "callable_return_type") + type_node = self._emit_type_annotation(return_parser, return_node) + self._edge("HasTypeAnnotation", return_node.id, type_node.id, "return_type_annotation") + self._derived_from(return_node.id, syntax_id) + + def _emit_type_annotation(self, raw_node: Any, owner: GraphNode) -> GraphNode: + parser_node = self._normalize(raw_node) + syntax_id = self._syntax_capture(parser_node) + type_node = self._semantic_node( + "TypeAnnotation", + parser_node, + label=_label_for(parser_node), + owner_id=owner.id, + owner_qualified_name=owner.qualified_name, + ) + self._emit_reference_edges(type_node, type_node.label, kind_prefix="type_annotation") + self._derived_from(type_node.id, syntax_id) + return type_node + + def _emit_decorators(self, node: ParserNode, declaration: GraphNode) -> None: + for raw_decorator in _iter_field_items(node, "decorator_list", "decorators"): + decorator_node = self._normalize(raw_decorator) + syntax_id = self._syntax_capture(decorator_node) + decorator = self._semantic_node( + "Decorator", + decorator_node, + label=_call_label(decorator_node) or _label_for(decorator_node), + owner_id=declaration.id, + owner_qualified_name=declaration.qualified_name, + ) + self._edge("DecoratedBy", declaration.id, decorator.id, "declaration_decorator") + target = self._emit_reference_edges(decorator, decorator.label, kind_prefix="decorator") + if target is not None: + self._edge_if_allowed("Calls", decorator.id, target.id, "decorator_call") + self._derived_from(decorator.id, syntax_id) + + def _emit_contextual_relations( + self, + semantic: GraphNode, + node: ParserNode, + owner: ScopeFrame, + syntax_id: str, + ) -> None: + table = semantic.table + + if table == "ExportDeclaration": + self._edge_if_allowed("Exports", owner.node_id, semantic.id, "exports_declaration") + target = self._resolve_reference_target(_export_target_label(node) or semantic.label, EXPORT_TARGET_TYPES) + if target is not None and target.id != semantic.id: + self._edge_if_allowed("Exports", owner.node_id, target.id, "exports_symbol") + + if table in DEFINED_CAPTURE_TABLES: + self._edge_if_allowed("Defines", owner.node_id, semantic.id, f"defines_{table.lower()}") + self._edge_if_allowed("Declares", owner.node_id, semantic.id, f"declares_{table.lower()}") + + if table in {"Component", "APIEndpoint", "Route"}: + self._edge_if_allowed("Exposes", owner.node_id, semantic.id, f"exposes_{table.lower()}") + + if table in {"Route", "APIEndpoint"}: + target = self._runtime_target(node, owner, syntax_id) + if target is not None and target.id != semantic.id: + self._edge_if_allowed("RoutesTo", semantic.id, target.id, "runtime_handler") + self._edge_if_allowed("Exposes", semantic.id, target.id, "runtime_surface") + + if table == "Parameter": + self._edge_if_allowed("HasParameter", owner.node_id, semantic.id, "captured_parameter") + annotation = _field(node, "annotation", "type_annotation") + if annotation is not None: + type_node = self._emit_type_annotation(annotation, semantic) + self._edge("HasTypeAnnotation", semantic.id, type_node.id, "parameter_annotation") + + if table == "ReturnType": + self._edge_if_allowed("HasReturnType", owner.node_id, semantic.id, "captured_return_type") + type_node = self._emit_type_annotation(node, semantic) + self._edge("HasTypeAnnotation", semantic.id, type_node.id, "return_type_annotation") + + if table == "TypeAnnotation": + self._edge_if_allowed("HasTypeAnnotation", owner.node_id, semantic.id, "captured_type_annotation") + self._emit_reference_edges(semantic, semantic.label, kind_prefix="type_annotation") + + if table == "TypeAlias": + annotation = _field(node, "annotation", "type_annotation", "value") + if annotation is not None: + type_node = self._emit_type_annotation(annotation, semantic) + self._edge_if_allowed("HasTypeAnnotation", semantic.id, type_node.id, "type_alias_annotation") + + if table == "Decorator": + self._edge_if_allowed("DecoratedBy", owner.node_id, semantic.id, "captured_decorator") + target = self._emit_reference_edges(semantic, semantic.label, kind_prefix="decorator") + if target is not None: + self._edge_if_allowed("Calls", semantic.id, target.id, "decorator_call") + + if table == "Query": + self._edge_if_allowed("ExecutesQuery", owner.node_id, semantic.id, "executes_query") + self._emit_reference_edges(semantic, _query_reference_label(node), kind_prefix="query") + + if table == "SecretRef": + self._edge_if_allowed("UsesSecret", owner.node_id, semantic.id, "uses_secret") + self._emit_reference_edges(semantic, semantic.label, kind_prefix="secret") + + if table in {"DocumentationSource", "DocumentationChunk"}: + self._edge_if_allowed("Documents", semantic.id, owner.node_id, "documents_owner") + self._edge_if_allowed("EvidencedBy", semantic.id, syntax_id, "parser_evidence") + + if table == "ExceptionFlow": + if _is_raise_flow(node): + self._edge_if_allowed("Raises", owner.node_id, semantic.id, "raises_exception") + if _is_handle_flow(node): + self._edge_if_allowed("Handles", owner.node_id, semantic.id, "handles_exception") + + if table == "Reference": + self._emit_reference_edges(semantic, semantic.label, kind_prefix="reference") + + if table == "ControlFlowBlock": + self._emit_reference_edges(semantic, _control_flow_reference_label(node), kind_prefix="control_flow") + + def _emit_reference_edges( + self, + source: GraphNode, + label: str, + *, + kind_prefix: str, + target_tables: set[str] | None = None, + ) -> GraphNode | None: + target = self._resolve_reference_target(label, target_tables) + if target is None or target.id == source.id: + return None + metadata = {"label": label, "resolver": "label"} + self._edge_if_allowed("References", source.id, target.id, f"{kind_prefix}_reference", metadata=metadata) + self._edge_if_allowed("ResolvesTo", source.id, target.id, f"{kind_prefix}_resolution", metadata=metadata) + return target + + def _connect_owner(self, owner: ScopeFrame, semantic: GraphNode) -> None: + self._edge("Contains", owner.node_id, semantic.id, f"contains_{semantic.table.lower()}") + if owner.scope_id: + self._edge("Contains", owner.scope_id, semantic.id, f"scope_contains_{semantic.table.lower()}") + + def _support_node(self, table: str, stable_key: str, label: str, *, path: str) -> GraphNode: + node = GraphNode( + id=_id(table, stable_key), + table=table, + label=label, + kind=table.lower(), + path=path, + summary=label, + metadata={"canonical_key": stable_key}, + ) + added = self._graph.add_node(node) + self._register_resolvable(added) + return added + + def _semantic_node( + self, + table: str, + parser_node: ParserNode, + *, + label: str | None = None, + owner: GraphNode | None = None, + owner_id: str = "", + owner_qualified_name: str = "", + metadata: dict[str, Any] | None = None, + ) -> GraphNode: + if table not in self._node_types: + raise ValueError(f"Unknown ontology node type: {table}") + semantic_label = label or _label_for(parser_node) or table + qualified_name = _qualified_name(owner_qualified_name or (owner.qualified_name if owner else ""), semantic_label) + stable_key = "|".join( + str(value) + for value in ( + self._context.path, + table, + qualified_name, + parser_node.node_type, + parser_node.line_start, + parser_node.byte_start, + semantic_label, + ) + ) + node = GraphNode( + id=_id(table, stable_key), + table=table, + label=semantic_label, + kind=_kind_for(table, parser_node), + language=self._context.language, + path=self._context.path, + qualified_name=qualified_name, + scope_id=owner_id or (owner.id if owner else ""), + line_start=parser_node.line_start, + line_end=parser_node.line_end, + byte_start=parser_node.byte_start, + byte_end=parser_node.byte_end, + tree_sitter_node_type=parser_node.node_type, + capture_name=parser_node.capture_name, + summary=_summary_for(table, semantic_label, parser_node), + metadata={"canonical_key": stable_key, **(metadata or {})}, + ) + added = self._graph.add_node(node) + self._register_resolvable(added) + return added + + def _symbol_node(self, label: str) -> GraphNode | None: + symbol_label = label.strip() + if not symbol_label: + return None + stable_key = f"{self._context.path}|Symbol|{symbol_label}" + node = GraphNode( + id=_id("Symbol", stable_key), + table="Symbol", + label=symbol_label, + kind="symbol_reference", + language=self._context.language, + path=self._context.path, + qualified_name=symbol_label, + summary=symbol_label, + metadata={"canonical_key": stable_key, "resolution": "name_placeholder"}, + ) + added = self._graph.add_node(node) + self._register_resolvable(added) + return added + + def _register_resolvable(self, node: GraphNode) -> None: + if node.table not in RESOLVABLE_NODE_TYPES: + return + keys = {node.label, node.qualified_name, str(node.metadata.get("imported_name") or "")} + for key in keys: + normalized = _symbol_key(key) + if not normalized: + continue + self._symbols_by_name.setdefault(normalized, []) + if node.id not in self._symbols_by_name[normalized]: + self._symbols_by_name[normalized].append(node.id) + + def _resolve_reference_target(self, label: str, target_tables: set[str] | None = None) -> GraphNode | None: + reference_label = label.strip() + if not reference_label: + return None + candidate_labels = (reference_label, reference_label.rsplit(".", 1)[-1]) + for candidate_label in candidate_labels: + for node_id in reversed(self._symbols_by_name.get(_symbol_key(candidate_label), ())): + node = self._graph.nodes.get(node_id) + if node is not None and (target_tables is None or node.table in target_tables): + return node + if target_tables is not None and "Symbol" not in target_tables: + return None + return self._symbol_node(reference_label) + + def _scope_for(self, owner: GraphNode) -> GraphNode: + stable_key = f"{self._context.path}|{owner.id}|scope" + scope = GraphNode( + id=_id("Scope", stable_key), + table="Scope", + label=f"{owner.label} scope", + kind=f"{owner.table.lower()}_scope", + language=owner.language, + path=owner.path, + qualified_name=f"{owner.qualified_name or owner.label}.", + scope_id=owner.id, + line_start=owner.line_start, + line_end=owner.line_end, + byte_start=owner.byte_start, + byte_end=owner.byte_end, + summary=f"Scope for {owner.label}", + metadata={"canonical_key": stable_key}, + ) + return self._graph.add_node(scope) + + def _syntax_capture(self, node: ParserNode) -> str: + stable_key = "|".join( + str(value) + for value in (self._context.path, node.node_type, node.line_start, node.byte_start, _label_for(node)) + ) + syntax_id = _id("SyntaxCapture", stable_key) + if not self.include_syntax_captures: + return syntax_id + if id(node) in self._syntax_nodes: + return self._syntax_nodes[id(node)] + syntax = GraphNode( + id=syntax_id, + table="SyntaxCapture", + label=node.capture_name or node.node_type, + kind=node.node_type, + language=self._context.language, + path=self._context.path, + line_start=node.line_start, + line_end=node.line_end, + byte_start=node.byte_start, + byte_end=node.byte_end, + tree_sitter_node_type=node.node_type, + capture_name=node.capture_name, + summary=node.text[:160], + metadata={"canonical_key": stable_key, "fields": sorted(node.fields.keys())}, + ) + self._graph.add_node(syntax) + self._syntax_nodes[id(node)] = syntax_id + return syntax_id + + def _derived_from(self, semantic_id: str, syntax_id: str) -> None: + if self.include_syntax_captures and syntax_id in self._graph.nodes: + self._edge("DerivedFrom", semantic_id, syntax_id, "parser_capture") + + def _runtime_target(self, node: ParserNode, owner: ScopeFrame, syntax_id: str) -> GraphNode | None: + label = _runtime_target_label(node) + if label: + target = self._resolve_reference_target(label, RUNTIME_TARGET_TYPES) + if target is not None: + return target + endpoint = self._semantic_node( + "APIEndpoint", + node, + label=label, + owner_id=owner.node_id, + owner_qualified_name=owner.qualified_name, + metadata={"inferred_from": "runtime_target"}, + ) + self._connect_owner(owner, endpoint) + self._edge_if_allowed("Defines", owner.node_id, endpoint.id, "defines_inferred_endpoint") + self._edge_if_allowed("Exposes", owner.node_id, endpoint.id, "exposes_inferred_endpoint") + self._derived_from(endpoint.id, syntax_id) + return endpoint + if owner.table in RUNTIME_TARGET_TYPES: + return self._graph.nodes.get(owner.node_id) + return None + + def _edge_if_allowed( + self, + edge_type: str, + source_id: str, + target_id: str, + kind: str, + *, + metadata: dict[str, Any] | None = None, + ) -> GraphEdge | None: + source = self._graph.nodes.get(source_id) + target = self._graph.nodes.get(target_id) + if source is None or target is None: + return None + spec = get_relation_type(edge_type) + if source.table not in spec.source_types or target.table not in spec.target_types: + return None + return self._edge(edge_type, source_id, target_id, kind, metadata=metadata) + + def _edge( + self, + edge_type: str, + source_id: str, + target_id: str, + kind: str, + *, + metadata: dict[str, Any] | None = None, + ) -> GraphEdge: + if edge_type not in self._relation_types: + raise ValueError(f"Unknown ontology relation type: {edge_type}") + edge = GraphEdge( + id=_id("edge", f"{edge_type}|{source_id}|{target_id}|{kind}"), + type=edge_type, + source_id=source_id, + target_id=target_id, + kind=kind, + metadata={"canonical_key": f"{edge_type}|{source_id}|{target_id}|{kind}", **(metadata or {})}, + ) + return self._graph.add_edge(edge) + + def _normalize(self, raw_node: Any) -> ParserNode: + if isinstance(raw_node, ParserNode): + return raw_node + if isinstance(raw_node, Mapping): + nested = raw_node.get("node") + if nested is not None: + nested_node = self._normalize(nested) + return ParserNode( + node_type=str(raw_node.get("type") or nested_node.node_type), + fields={**nested_node.fields, **{key: value for key, value in raw_node.items() if key != "node"}}, + children=nested_node.children, + line_start=nested_node.line_start, + line_end=nested_node.line_end, + byte_start=nested_node.byte_start, + byte_end=nested_node.byte_end, + capture_name=str(raw_node.get("capture_name") or nested_node.capture_name or ""), + text=nested_node.text, + ) + fields = {key: value for key, value in raw_node.items() if key not in DICT_NODE_META_KEYS} + children = tuple(_coerce_children(raw_node)) + return ParserNode( + node_type=str(raw_node.get("type") or raw_node.get("node_type") or raw_node.get("kind") or "unknown"), + fields=fields, + children=children, + line_start=_line(raw_node, "line_start", "start_line"), + line_end=_line(raw_node, "line_end", "end_line"), + byte_start=_line(raw_node, "byte_start", "start_byte"), + byte_end=_line(raw_node, "byte_end", "end_byte"), + capture_name=str(raw_node.get("capture_name") or raw_node.get("capture") or ""), + text=str(raw_node.get("text") or ""), + ) + node_type = getattr(raw_node, "type", "") or type(raw_node).__name__ + fields = _object_fields(raw_node) + return ParserNode( + node_type=str(node_type), + fields=fields, + children=tuple(getattr(raw_node, "children", ()) or _field_children(fields)), + line_start=_point_line(getattr(raw_node, "start_point", None)) or getattr(raw_node, "lineno", None), + line_end=_point_line(getattr(raw_node, "end_point", None)) or getattr(raw_node, "end_lineno", None), + byte_start=getattr(raw_node, "start_byte", None) or getattr(raw_node, "col_offset", None), + byte_end=getattr(raw_node, "end_byte", None) or getattr(raw_node, "end_col_offset", None), + text=_node_text(raw_node), + ) + + def _semantic_children(self, node: ParserNode) -> tuple[Any, ...]: + ignored_fields = {"name", "id", "module", "names", "args", "returns", "return_type", "decorator_list", "decorators"} + children: list[Any] = list(node.children) + for field_name, value in node.fields.items(): + if field_name in ignored_fields: + continue + if _is_parser_like(value): + children.append(value) + elif isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)): + children.extend(item for item in value if _is_parser_like(item)) + return tuple(children) + + +IMPORT_NODE_TYPES = {"import_statement", "import_from_statement", "import_declaration", "Import", "ImportFrom"} +EXPORT_NODE_TYPES = {"export_statement", "export_clause", "export_declaration"} +CLASS_NODE_TYPES = {"class_definition", "class_declaration", "struct_item", "interface_declaration", "ClassDef"} +FUNCTION_NODE_TYPES = {"function_definition", "function_declaration", "method_definition", "method_declaration", "FunctionDef"} +PARAMETER_NODE_TYPES = {"parameter", "typed_parameter", "default_parameter", "arg"} +RETURN_TYPE_NODE_TYPES = {"return_type", "returns"} +TYPE_NODE_TYPES = {"type", "type_identifier", "type_annotation", "annotation"} +ASSIGNMENT_NODE_TYPES = {"assignment", "assignment_expression", "variable_declaration", "Assign", "AnnAssign"} +CALL_NODE_TYPES = {"call", "call_expression", "invocation_expression", "Call"} +REFERENCE_NODE_TYPES = {"identifier", "field_identifier", "attribute", "Name", "Attribute"} +LITERAL_NODE_TYPES = {"string", "integer", "float", "true", "false", "null", "none", "Constant"} +CONTROL_FLOW_NODE_TYPES = {"if_statement", "for_statement", "while_statement", "match_statement", "switch_statement"} +EXCEPTION_FLOW_NODE_TYPES = {"try_statement", "except_clause", "catch_clause", "raise_statement", "throw_statement"} +RESOLVABLE_NODE_TYPES = { + "Symbol", + "Module", + "Class", + "Function", + "Method", + "Variable", + "Constant", + "ClassAttribute", + "InstanceAttribute", + "Property", + "Parameter", + "Dependency", + "APIEndpoint", + "Component", +} +EXPORT_TARGET_TYPES = { + "Class", + "Function", + "Method", + "Variable", + "Constant", + "ClassAttribute", + "InstanceAttribute", + "Property", + "APIEndpoint", + "Component", +} +RUNTIME_TARGET_TYPES = {"Function", "Method", "Component", "APIEndpoint"} +IMPORT_SOURCE_TYPES = {"File", "Module", "Scope"} +DEFINED_CAPTURE_TABLES = { + "APIEndpoint", + "Component", + "Route", + "TypeAlias", + "Variable", + "Constant", + "ClassAttribute", + "InstanceAttribute", + "Property", +} +DICT_NODE_META_KEYS = { + "type", + "node_type", + "kind", + "children", + "body", + "line_start", + "line_end", + "start_line", + "end_line", + "byte_start", + "byte_end", + "start_byte", + "end_byte", + "capture", + "capture_name", + "text", +} + + +def _capture_node(capture: Mapping[str, Any] | tuple[Any, str]) -> Any: + if isinstance(capture, CaptureRecord): + return capture.node + if isinstance(capture, tuple): + return capture[0] + return capture.get("node") or capture + + +def _capture_name(capture: Mapping[str, Any] | tuple[Any, str]) -> str: + if isinstance(capture, CaptureRecord): + return capture.capture + if isinstance(capture, tuple): + return str(capture[1]) + return str(capture.get("capture_name") or capture.get("capture") or "") + + +def _capture_node_type(capture: Mapping[str, Any] | tuple[Any, str]) -> str: + node = _capture_node(capture) + if isinstance(node, Mapping): + return str(node.get("type") or node.get("node_type") or node.get("kind") or "unknown") + return str(getattr(node, "type", "") or type(node).__name__) + + +def _table_from_capture(capture_name: str, owner: ScopeFrame) -> str | None: + return default_capture_table_registry().table_for(capture_name, owner) + + +def _normalize_capture_name(capture_name: str) -> str: + return capture_name.lstrip("@") + + +def _resolve_capture_table(table: str | CaptureTableResolver, capture: str, owner: ScopeFrame) -> str | None: + if callable(table): + return table(capture, owner) + return table + + +def _function_capture_table(_capture: str, owner: ScopeFrame) -> str: + return "Method" if owner.table in {"Class", "Component"} else "Function" + + +def _import_source_id(owner: ScopeFrame) -> str: + if owner.table in IMPORT_SOURCE_TYPES: + return owner.node_id + return owner.scope_id or owner.node_id + + +def _id(prefix: str, value: str) -> str: + return f"{prefix}:{hashlib.sha1(value.encode('utf-8')).hexdigest()[:20]}" + + +def _module_label(path: str) -> str: + stem = path.rsplit(".", 1)[0] + return stem.replace("/", ".") + + +def _qualified_name(owner: str, label: str) -> str: + if not owner or owner == label: + return label + if not label: + return owner + return f"{owner}.{label}" + + +def _kind_for(table: str, node: ParserNode) -> str: + if table == "Method": + return "method" + if table == "Function": + return "function" + if table == "Class": + return "class" + return node.node_type + + +def _field(node: ParserNode, *names: str) -> Any: + for name in names: + if name in node.fields: + return node.fields[name] + return None + + +def _iter_field_items(node: ParserNode, *names: str) -> tuple[Any, ...]: + items: list[Any] = [] + for name in names: + value = node.fields.get(name) + if value is None: + continue + if isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)): + items.extend(value) + else: + items.append(value) + return tuple(items) + + +def _label_for(node: ParserNode) -> str: + for key in ("name", "id", "arg", "attr", "module"): + value = node.fields.get(key) + label = _value_label(value) + if label: + return label + if "value" in node.fields: + return _value_label(node.fields["value"]) + return node.text.strip() or node.node_type + + +def _summary_for(table: str, label: str, node: ParserNode) -> str: + if table in {"DocumentationSource", "DocumentationChunk"} and node.text.strip(): + return node.text.strip() + return label + + +def _value_label(value: Any) -> str: + if value is None: + return "" + if isinstance(value, str): + return value + if isinstance(value, (int, float, bool)): + return str(value) + if isinstance(value, Mapping): + if "id" in value: + return str(value["id"]) + if "name" in value: + return str(value["name"]) + if "arg" in value: + return str(value["arg"]) + if "attr" in value: + base = _value_label(value.get("value")) + return f"{base}.{value['attr']}" if base else str(value["attr"]) + if "value" in value: + return _value_label(value["value"]) + if hasattr(value, "id"): + return str(getattr(value, "id")) + if hasattr(value, "name"): + return str(getattr(value, "name")) + if hasattr(value, "arg"): + return str(getattr(value, "arg")) + if hasattr(value, "attr"): + base = _value_label(getattr(value, "value", None)) + return f"{base}.{getattr(value, 'attr')}" if base else str(getattr(value, "attr")) + if hasattr(value, "value"): + return _value_label(getattr(value, "value")) + return "" + + +def _symbol_key(label: str) -> str: + return label.strip().lower() + + +def _export_target_label(node: ParserNode) -> str: + for field_name in ("exported", "target", "name", "declaration"): + label = _value_label(node.fields.get(field_name)) + if label: + return label + return _label_for(node) + + +def _runtime_target_label(node: ParserNode) -> str: + for field_name in ("handler", "endpoint", "target", "function", "callback"): + label = _value_label(node.fields.get(field_name)) + if label: + return label + return "" + + +def _query_reference_label(node: ParserNode) -> str: + for field_name in ("table", "collection", "model", "target", "index"): + label = _value_label(node.fields.get(field_name)) + if label: + return label + return "" + + +def _control_flow_reference_label(node: ParserNode) -> str: + for field_name in ("test", "condition", "subject"): + label = _value_label(node.fields.get(field_name)) + if label: + return label + return "" + + +def _is_raise_flow(node: ParserNode) -> bool: + capture = node.capture_name.lstrip("@") + return capture == "raises" or node.node_type in {"raise_statement", "throw_statement"} + + +def _is_handle_flow(node: ParserNode) -> bool: + capture = node.capture_name.lstrip("@") + return capture == "handles" or node.node_type in {"try_statement", "except_clause", "catch_clause"} + + +def _import_label(node: ParserNode) -> str: + module = _value_label(node.fields.get("module")) + names = node.fields.get("names") + imported_names: list[str] = [] + if isinstance(names, Sequence) and not isinstance(names, (str, bytes, bytearray)): + imported_names = [_value_label(name) for name in names if _value_label(name)] + elif names is not None: + imported_names = [_value_label(names)] + if module and imported_names: + return ", ".join(f"{module}.{name}" for name in imported_names) + return module or ", ".join(imported_names) + + +def _call_label(node: ParserNode) -> str: + return _value_label(node.fields.get("func")) or _value_label(node.fields.get("function")) + + +def _assignment_target_label(node: ParserNode) -> str: + target = node.fields.get("target") + targets = node.fields.get("targets") + if target is not None: + return _value_label(target) + if isinstance(targets, Sequence) and not isinstance(targets, (str, bytes, bytearray)) and targets: + return _value_label(targets[0]) + return _value_label(targets) + + +def _assignment_target_table(label: str, owner: ScopeFrame, node: ParserNode) -> str: + if label.isupper(): + return "Constant" + if owner.table == "Class": + return "ClassAttribute" + if "." in label: + return "InstanceAttribute" + if node.node_type == "AnnAssign" and owner.table == "Class": + return "ClassAttribute" + return "Variable" + + +def _parameters(node: ParserNode) -> tuple[Any, ...]: + raw_args = node.fields.get("args") or node.fields.get("parameters") + if raw_args is None: + return () + if isinstance(raw_args, Mapping): + args = raw_args.get("args") or raw_args.get("children") or () + if isinstance(args, Sequence) and not isinstance(args, (str, bytes, bytearray)): + return tuple(args) + if hasattr(raw_args, "args"): + args = getattr(raw_args, "args") + if isinstance(args, Sequence): + return tuple(args) + if isinstance(raw_args, Sequence) and not isinstance(raw_args, (str, bytes, bytearray)): + return tuple(raw_args) + return (raw_args,) + + +def _normalized_type(raw_node: Any) -> str: + if isinstance(raw_node, ParserNode): + return raw_node.node_type + if isinstance(raw_node, Mapping): + return str(raw_node.get("type") or raw_node.get("node_type") or raw_node.get("kind") or "unknown") + return str(getattr(raw_node, "type", "") or type(raw_node).__name__) + + +def _coerce_children(raw_node: Mapping[str, Any]) -> tuple[Any, ...]: + children: list[Any] = [] + for key in ("children", "body"): + value = raw_node.get(key) + if isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)): + children.extend(value) + elif value is not None: + children.append(value) + return tuple(children) + + +def _field_children(fields: Mapping[str, Any]) -> tuple[Any, ...]: + children: list[Any] = [] + for value in fields.values(): + if _is_parser_like(value): + children.append(value) + elif isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)): + children.extend(item for item in value if _is_parser_like(item)) + return tuple(children) + + +def _object_fields(raw_node: Any) -> Mapping[str, Any]: + if hasattr(raw_node, "_fields"): + return {name: getattr(raw_node, name) for name in getattr(raw_node, "_fields")} + if hasattr(raw_node, "child_by_field_name"): + fields: dict[str, Any] = {} + for name in ("name", "body", "parameters", "return_type", "function", "argument", "left", "right"): + try: + value = raw_node.child_by_field_name(name) + except Exception: + value = None + if value is not None: + fields[name] = value + return fields + return { + key: value + for key, value in vars(raw_node).items() + if not key.startswith("_") and key not in {"children", "type"} + } if hasattr(raw_node, "__dict__") else {} + + +def _is_parser_like(value: Any) -> bool: + if value is None or isinstance(value, (str, bytes, bytearray, int, float, bool)): + return False + if isinstance(value, Mapping): + return any(key in value for key in ("type", "node_type", "kind", "body", "children")) + return hasattr(value, "type") or hasattr(value, "_fields") + + +def _line(raw_node: Mapping[str, Any], *keys: str) -> int | None: + for key in keys: + value = raw_node.get(key) + if isinstance(value, int): + return value + start_point = raw_node.get("start_point") + end_point = raw_node.get("end_point") + if "start" in keys[0] and start_point is not None: + return _point_line(start_point) + if "end" in keys[0] and end_point is not None: + return _point_line(end_point) + return None + + +def _point_line(point: Any) -> int | None: + if point is None: + return None + if isinstance(point, Sequence) and point: + return int(point[0]) + 1 + if hasattr(point, "row"): + return int(getattr(point, "row")) + 1 + return None + + +def _node_text(raw_node: Any) -> str: + text = getattr(raw_node, "text", b"") + if isinstance(text, bytes): + return text.decode("utf-8", errors="replace") + return str(text or "") diff --git a/src/codebase_graph/graph_context.py b/src/codebase_graph/graph_context.py deleted file mode 100644 index c5e1eb7..0000000 --- a/src/codebase_graph/graph_context.py +++ /dev/null @@ -1,58 +0,0 @@ -from __future__ import annotations - -from typing import Any - -def build_compact_graph_context( - graph: dict[str, Any], - query: str, - *, - kind: str | None = None, - profile: str = "dependencies", - limit: int = 3, - max_depth: int = 1, - budget: int = 600, - include_raw: bool = False, -) -> dict[str, Any]: - nodes = list(graph.get("nodes", [])) - edges = list(graph.get("edges", [])) - matches = _match_nodes(nodes, query, kind=kind)[:limit] - match_ids = {node["id"] for node in matches} - related_edges = [edge for edge in edges if edge.get("source_id") in match_ids or edge.get("target_id") in match_ids] - related_ids = {edge.get("source_id") for edge in related_edges} | {edge.get("target_id") for edge in related_edges} - related_nodes = [node for node in nodes if node.get("id") in related_ids and node.get("id") not in match_ids] - lines: list[str] = [] - for node in matches: - label = node.get("qualified_name") or node.get("label") or node.get("id") - path = node.get("path") or "" - lines.append(f"- {node.get('table')}: {label} {f'({path})' if path else ''}".strip()) - for edge in related_edges[: max(0, budget // 80)]: - lines.append(f"- {edge.get('type')}: {edge.get('source_id')} -> {edge.get('target_id')}") - text = "\n".join(lines) - if len(text) > budget: - text = text[:budget].rstrip() - payload = { - "query": query, - "profile": profile, - "max_depth": max_depth, - "context": text, - "items": matches, - "related": related_nodes[:limit], - "edge_count": len(related_edges), - } - if include_raw: - payload["raw_edges"] = related_edges - return payload - -def _match_nodes(nodes: list[dict[str, Any]], query: str, kind: str | None = None) -> list[dict[str, Any]]: - terms = [term for term in query.lower().replace("_", " ").replace(".", " ").split() if term] - scored: list[tuple[int, dict[str, Any]]] = [] - for node in nodes: - if kind and node.get("table") != kind and node.get("kind") != kind: - continue - haystack = " ".join( - str(node.get(field, "")) for field in ("id", "label", "qualified_name", "path", "summary", "kind", "table") - ).lower() - score = sum(1 for term in terms if term in haystack) - if score or not terms: - scored.append((score, node)) - return [node for _, node in sorted(scored, key=lambda item: (-item[0], item[1].get("id", "")))] diff --git a/src/codebase_graph/graph_core.py b/src/codebase_graph/graph_core.py deleted file mode 100644 index 1497fd2..0000000 --- a/src/codebase_graph/graph_core.py +++ /dev/null @@ -1,323 +0,0 @@ -from __future__ import annotations - -import argparse -import json -import re -import sys -from dataclasses import dataclass -from pathlib import Path -from typing import Any, Sequence - -from .code_map import CODE_EXTENSIONS, _iter_indexable_files -from .graph_context import build_compact_graph_context -from .ladybug import HashingEmbeddingProvider, LadybugGraphExporter, LadybugGraphStore -from .ontology import schema_payload - -DEFAULT_STATE_DIR = Path(".codebase_graph/graph") -DEFAULT_DB_FILENAME = "knowledge_graph.json" -DEFAULT_STAGING_DIRNAME = "staging" - -@dataclass(slots=True) -class GraphCoreStatus: - source_root: Path - state_dir: Path - database_path: Path - staging_dir: Path - database_exists: bool - stale: bool - source_file_count: int - latest_source_mtime: float | None - database_mtime: float | None - - def as_dict(self) -> dict[str, Any]: - return { - "source_root": str(self.source_root), - "state_dir": str(self.state_dir), - "database_path": str(self.database_path), - "staging_dir": str(self.staging_dir), - "database_exists": self.database_exists, - "stale": self.stale, - "source_file_count": self.source_file_count, - "latest_source_mtime": self.latest_source_mtime, - "database_mtime": self.database_mtime, - "recommended_search_command": "codebase-graph search '' --source-root .", - "recommended_cypher_command": 'codebase-graph cypher "MATCH (n:PythonClass) RETURN n.id, n.label LIMIT 5" --source-root .', - "recommended_schema_command": "codebase-graph schema", - "recommended_context_command": "codebase-graph context '' --source-root .", - } - -class CodebaseGraph: - def __init__( - self, - source_root: str | Path = ".", - state_dir: str | Path | None = None, - database_path: str | Path | None = None, - staging_dir: str | Path | None = None, - embedding_provider: HashingEmbeddingProvider | None = None, - ) -> None: - self.source_root = Path(source_root) - self.state_dir = Path(state_dir) if state_dir is not None else self.source_root / DEFAULT_STATE_DIR - self.database_path = Path(database_path) if database_path is not None else self.state_dir / DEFAULT_DB_FILENAME - self.staging_dir = Path(staging_dir) if staging_dir is not None else self.state_dir / DEFAULT_STAGING_DIRNAME - self.embedding_provider = embedding_provider or HashingEmbeddingProvider() - - def status(self) -> GraphCoreStatus: - mtimes = _source_mtimes(self.source_root) - database_exists = self.database_path.exists() - database_mtime = self.database_path.stat().st_mtime if database_exists else None - latest_source_mtime = max(mtimes) if mtimes else None - stale = not database_exists or ( - latest_source_mtime is not None and database_mtime is not None and latest_source_mtime > database_mtime - ) - return GraphCoreStatus( - source_root=self.source_root, - state_dir=self.state_dir, - database_path=self.database_path, - staging_dir=self.staging_dir, - database_exists=database_exists, - stale=stale, - source_file_count=len(mtimes), - latest_source_mtime=latest_source_mtime, - database_mtime=database_mtime, - ) - - def materialize(self, overwrite: bool = True) -> dict[str, Any]: - if self.database_path.exists() and not overwrite: - raise ValueError(f"Graph database already exists: {self.database_path}") - export = LadybugGraphExporter(self.source_root, embedding_provider=self.embedding_provider).build_export() - self.staging_dir.mkdir(parents=True, exist_ok=True) - store = LadybugGraphStore(self.database_path) - store.write_export(export) - return {"database_path": str(self.database_path), "summary": export.summary()} - - def ensure_current(self) -> dict[str, Any]: - status = self.status() - if status.stale: - return self.materialize(overwrite=True) - return {"database_path": str(self.database_path), "summary": {"status": "current"}} - - def schema(self) -> dict[str, Any]: - return schema_payload() - - def search(self, query: str, limit: int = 10, refresh: bool = True, reinforce: bool = True) -> dict[str, Any]: - if refresh: - self.ensure_current() - graph = self._read_graph() - terms = _terms(query) - scored: list[tuple[float, dict[str, Any]]] = [] - for node in graph.get("nodes", []): - haystack = _node_text(node) - score = sum(2.0 if term in str(node.get("label", "")).lower() else 1.0 for term in terms if term in haystack) - if score > 0 or not terms: - item = _compact_node(node) - item["score"] = score - scored.append((score, item)) - items = [item for _, item in sorted(scored, key=lambda pair: (-pair[0], pair[1].get("id", "")))[:limit]] - return { - "query": query, - "items": items, - "count": len(items), - "database_path": str(self.database_path), - "retrieval": "lexical_graph", - } - - def context( - self, - query: str, - *, - kind: str | None = None, - profile: str = "dependencies", - limit: int = 3, - max_depth: int = 1, - budget: int = 600, - include_raw: bool = False, - refresh: bool = True, - ) -> dict[str, Any]: - if refresh: - self.ensure_current() - return build_compact_graph_context( - self._read_graph(), - query, - kind=kind, - profile=profile, - limit=limit, - max_depth=max_depth, - budget=budget, - include_raw=include_raw, - ) - - def cypher(self, query: str, parameters: dict[str, Any] | None = None, refresh: bool = True) -> dict[str, Any]: - if refresh: - self.ensure_current() - if not _is_read_only_query(query): - raise ValueError("Only read-only MATCH queries are supported") - graph = self._read_graph() - return _run_simple_match_query(graph, query, parameters or {}) - - def _read_graph(self) -> dict[str, Any]: - return LadybugGraphStore(self.database_path).read_export() - -def main(argv: Sequence[str] | None = None) -> int: - argv = _normalize_global_args(list(sys.argv[1:] if argv is None else argv)) - parser = argparse.ArgumentParser(description="Query or rebuild a generic codebase graph.") - parser.add_argument("--source-root", default=".") - parser.add_argument("--state-dir", default=None) - parser.add_argument("--db-path", default=None) - parser.add_argument("--staging-dir", default=None) - subparsers = parser.add_subparsers(dest="command", required=True) - subparsers.add_parser("status") - materialize_parser = subparsers.add_parser("materialize") - materialize_parser.add_argument("--no-overwrite", action="store_true") - subparsers.add_parser("schema") - search_parser = subparsers.add_parser("search") - search_parser.add_argument("query") - search_parser.add_argument("--limit", type=int, default=10) - search_parser.add_argument("--no-refresh", action="store_true") - context_parser = subparsers.add_parser("context") - context_parser.add_argument("query") - context_parser.add_argument("--kind") - context_parser.add_argument("--profile", default="dependencies") - context_parser.add_argument("--limit", type=int, default=3) - context_parser.add_argument("--max-depth", type=int, default=1) - context_parser.add_argument("--budget", type=int, default=600) - context_parser.add_argument("--include-raw", action="store_true") - context_parser.add_argument("--no-refresh", action="store_true") - cypher_parser = subparsers.add_parser("cypher") - cypher_parser.add_argument("query") - cypher_parser.add_argument("--params-json", default="{}") - cypher_parser.add_argument("--no-refresh", action="store_true") - args = parser.parse_args(argv) - core = CodebaseGraph( - source_root=args.source_root, - state_dir=args.state_dir, - database_path=args.db_path, - staging_dir=args.staging_dir, - ) - if args.command == "status": - payload = core.status().as_dict() - elif args.command == "schema": - payload = core.schema() - elif args.command == "materialize": - payload = core.materialize(overwrite=not args.no_overwrite) - elif args.command == "search": - payload = core.search(args.query, limit=args.limit, refresh=not args.no_refresh) - elif args.command == "context": - payload = core.context( - args.query, - kind=args.kind, - profile=args.profile, - limit=args.limit, - max_depth=args.max_depth, - budget=args.budget, - include_raw=args.include_raw, - refresh=not args.no_refresh, - ) - elif args.command == "cypher": - params = json.loads(args.params_json) - if not isinstance(params, dict): - raise ValueError("--params-json must decode to an object") - payload = core.cypher(args.query, parameters=params, refresh=not args.no_refresh) - else: - parser.error(f"unsupported command: {args.command}") - print(json.dumps(payload, indent=2, sort_keys=True)) - return 0 - - -def _normalize_global_args(argv: list[str]) -> list[str]: - value_flags = {"--source-root", "--state-dir", "--db-path", "--staging-dir"} - extracted: list[str] = [] - remaining: list[str] = [] - index = 0 - while index < len(argv): - item = argv[index] - if item in value_flags and index + 1 < len(argv): - extracted.extend([item, argv[index + 1]]) - index += 2 - continue - remaining.append(item) - index += 1 - return extracted + remaining - -def _source_mtimes(source_root: Path) -> list[float]: - mtimes: list[float] = [] - paths = { - *_iter_indexable_files(source_root, CODE_EXTENSIONS), - *_iter_indexable_files(source_root, {".md", ".txt", ".rst", ".toml"}, case_insensitive_suffixes=True), - } - for path in sorted(paths): - try: - mtimes.append(path.stat().st_mtime) - except OSError: - continue - return mtimes - -def _terms(query: str) -> list[str]: - return [term for term in re.split(r"[^a-zA-Z0-9_]+", query.lower()) if term] - -def _node_text(node: dict[str, Any]) -> str: - return " ".join( - str(node.get(field, "")) for field in ("id", "table", "label", "kind", "path", "qualified_name", "summary") - ).lower() - -def _compact_node(node: dict[str, Any]) -> dict[str, Any]: - return { - "id": node.get("id"), - "table": node.get("table"), - "label": node.get("label"), - "kind": node.get("kind"), - "path": node.get("path"), - "qualified_name": node.get("qualified_name"), - "line_start": node.get("line_start"), - "summary": node.get("summary"), - } - -def _is_read_only_query(query: str) -> bool: - lowered = query.strip().lower() - return lowered.startswith("match ") and not any( - token in lowered for token in (" create ", " merge ", " delete ", " set ", " drop ", " copy ", " load ") - ) - -def _run_simple_match_query(graph: dict[str, Any], query: str, parameters: dict[str, Any]) -> dict[str, Any]: - match = re.search(r"MATCH\s*\(\s*(\w+)\s*:\s*(\w+)\s*\)", query, flags=re.IGNORECASE) - if not match: - raise ValueError("Only simple MATCH (n:Label) queries are supported") - variable, table = match.group(1), match.group(2) - where = re.search(r"WHERE\s+(.+?)\s+RETURN", query, flags=re.IGNORECASE | re.DOTALL) - return_match = re.search(r"RETURN\s+(.+?)(?:\s+LIMIT\s+(\d+))?\s*$", query, flags=re.IGNORECASE | re.DOTALL) - if not return_match: - raise ValueError("Query must include RETURN") - columns = [column.strip() for column in return_match.group(1).split(",")] - limit = int(return_match.group(2) or 100) - rows: list[dict[str, Any]] = [] - for node in graph.get("nodes", []): - if node.get("table") != table: - continue - if where and not _where_matches(node, variable, where.group(1), parameters): - continue - row: dict[str, Any] = {} - for column in columns: - if column == variable: - row[column] = node - elif column.startswith(f"{variable}."): - field = column.split(".", 1)[1] - row[column] = node.get(field) - else: - row[column] = node.get(column) - rows.append(row) - if len(rows) >= limit: - break - return {"query": query, "columns": columns, "rows": rows, "count": len(rows), "database_path": graph.get("database_path")} - -def _where_matches(node: dict[str, Any], variable: str, expression: str, parameters: dict[str, Any]) -> bool: - equals = re.match(rf"{re.escape(variable)}\.(\w+)\s*=\s*(.+)$", expression.strip()) - if not equals: - return True - field, raw_value = equals.group(1), equals.group(2).strip() - if raw_value.startswith("$"): - expected = parameters.get(raw_value[1:]) - else: - expected = raw_value.strip("\'\"") - return node.get(field) == expected - -if __name__ == "__main__": - raise SystemExit(main()) diff --git a/src/codebase_graph/ingest/__init__.py b/src/codebase_graph/ingest/__init__.py new file mode 100644 index 0000000..3aef5f0 --- /dev/null +++ b/src/codebase_graph/ingest/__init__.py @@ -0,0 +1,39 @@ +"""Repository, documentation, issue, and tool-output ingestion.""" + +from .materializer import ( + GraphMaterializer, + ManifestDiff, + ManifestEntry, + MaterializationManifest, + MaterializationResult, + MaterializeMode, + SourceSnapshot, +) +from .tree_sitter_parser import ( + ParserRegistry, + ParserRegistration, + ParserUnavailableError, + SourceParser, + TreeSitterPythonParser, + default_parser_registry, + parser_for_language, +) +from .document_parser import MarkdownDocumentParser + +__all__ = [ + "GraphMaterializer", + "MarkdownDocumentParser", + "ManifestDiff", + "ManifestEntry", + "MaterializationManifest", + "MaterializationResult", + "MaterializeMode", + "ParserRegistry", + "ParserRegistration", + "ParserUnavailableError", + "SourceParser", + "SourceSnapshot", + "TreeSitterPythonParser", + "default_parser_registry", + "parser_for_language", +] diff --git a/src/codebase_graph/ingest/document_parser.py b/src/codebase_graph/ingest/document_parser.py new file mode 100644 index 0000000..5c75e4c --- /dev/null +++ b/src/codebase_graph/ingest/document_parser.py @@ -0,0 +1,104 @@ +from __future__ import annotations + +import re +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +from codebase_graph.extract import ParseBundle + +HEADING_RE = re.compile(r"^(#{1,6})\s+(.+?)\s*$") + + +@dataclass(frozen=True, slots=True) +class MarkdownDocumentParser: + language: str = "markdown" + parser_version: str = "markdown-docs-v1" + + def parse_file( + self, + path: Path, + *, + relative_path: str, + source_root: Path, + repository_label: str, + content_hash: str, + ) -> ParseBundle: + source_text = path.read_text(encoding="utf-8") + return ParseBundle( + language=self.language, + path=relative_path, + source_text=source_text, + captures=_document_captures(relative_path, source_text), + repository_label=repository_label, + source_root=source_root.as_posix(), + content_hash=content_hash, + ) + + +def _document_captures(path: str, source_text: str) -> tuple[dict[str, Any], ...]: + lines = source_text.splitlines() + total_lines = max(len(lines), 1) + captures: list[dict[str, Any]] = [ + { + "capture_name": "doc.source", + "node": { + "type": "DocumentationSource", + "name": path, + "line_start": 1, + "line_end": total_lines, + "text": _summary(source_text), + }, + } + ] + for index, section in enumerate(_sections(lines), start=1): + label = section.heading or f"{path} section {index}" + captures.append( + { + "capture_name": "doc.chunk", + "node": { + "type": "DocumentationChunk", + "name": label, + "heading": section.heading, + "level": section.level, + "line_start": section.line_start, + "line_end": section.line_end, + "text": _summary(section.text), + }, + } + ) + return tuple(captures) + + +@dataclass(frozen=True, slots=True) +class _Section: + heading: str + level: int + line_start: int + line_end: int + text: str + + +def _sections(lines: list[str]) -> tuple[_Section, ...]: + headings: list[tuple[int, int, str]] = [] + for line_number, line in enumerate(lines, start=1): + match = HEADING_RE.match(line) + if match is None: + continue + headings.append((line_number, len(match.group(1)), match.group(2).strip())) + + if not headings: + text = "\n".join(lines).strip() + return (_Section("", 0, 1, max(len(lines), 1), text),) if text else () + + sections: list[_Section] = [] + for index, (line_start, level, heading) in enumerate(headings): + line_end = headings[index + 1][0] - 1 if index + 1 < len(headings) else len(lines) + text = "\n".join(lines[line_start - 1 : line_end]).strip() + if text: + sections.append(_Section(heading, level, line_start, line_end, text)) + return tuple(sections) + + +def _summary(text: str) -> str: + return text.strip()[:2000] diff --git a/src/codebase_graph/ingest/materializer.py b/src/codebase_graph/ingest/materializer.py new file mode 100644 index 0000000..96d7d36 --- /dev/null +++ b/src/codebase_graph/ingest/materializer.py @@ -0,0 +1,730 @@ +from __future__ import annotations + +import hashlib +import json +import os +import tempfile +from collections.abc import Mapping +from dataclasses import dataclass, field +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Literal + +from codebase_graph.core import CodeGraph +from codebase_graph.db import LadybugCodeGraphStore, create_ladybug_database +from codebase_graph.diagnostics import log_event +from codebase_graph.extract import GraphBuilder +from codebase_graph.ontology import ONTOLOGY_NAME +from codebase_graph.paths import DEFAULT_STATE_DIR, derive_graph_state_paths + +from .tree_sitter_parser import ParserRegistry, ParserUnavailableError, default_parser_registry + +MaterializeMode = Literal["full", "changed"] + +MANIFEST_SCHEMA_VERSION = 1 +PARSER_VERSION = "tree-sitter-python-v1+markdown-docs-v1" +EXCLUDED_PARTS = { + ".git", + ".venv", + ".cache", + ".coverage", + ".hypothesis", + ".mypy_cache", + ".nox", + ".pyre", + ".pytest_cache", + ".pytype", + ".ruff_cache", + ".tox", + ".vscode", + "__pycache__", + "build", + "coverage", + "dist", + "htmlcov", + "node_modules", + "vendor", + ".codebase_graph", + DEFAULT_STATE_DIR, +} + + +@dataclass(frozen=True, slots=True) +class SourceSnapshot: + path: str + absolute_path: Path + content_hash: str + language: str | None + + +@dataclass(frozen=True, slots=True) +class ManifestEntry: + path: str + content_hash: str + language: str + partition_id: str + node_ids: tuple[str, ...] + edge_ids: tuple[str, ...] + node_types: Mapping[str, str] = field(default_factory=dict) + edge_types: Mapping[str, str] = field(default_factory=dict) + materialized_at: str = "" + + @classmethod + def from_dict(cls, payload: Mapping[str, Any]) -> ManifestEntry: + return cls( + path=str(payload["path"]), + content_hash=str(payload["content_hash"]), + language=str(payload["language"]), + partition_id=str(payload["partition_id"]), + node_ids=tuple(str(value) for value in payload.get("node_ids", ())), + edge_ids=tuple(str(value) for value in payload.get("edge_ids", ())), + node_types={str(key): str(value) for key, value in dict(payload.get("node_types", {})).items()}, + edge_types={str(key): str(value) for key, value in dict(payload.get("edge_types", {})).items()}, + materialized_at=str(payload.get("materialized_at", "")), + ) + + def as_dict(self) -> dict[str, Any]: + return { + "path": self.path, + "content_hash": self.content_hash, + "language": self.language, + "partition_id": self.partition_id, + "node_ids": list(self.node_ids), + "edge_ids": list(self.edge_ids), + "node_types": dict(self.node_types), + "edge_types": dict(self.edge_types), + "materialized_at": self.materialized_at, + } + + +@dataclass(frozen=True, slots=True) +class MaterializationManifest: + schema_version: int = MANIFEST_SCHEMA_VERSION + ontology: str = ONTOLOGY_NAME + parser_version: str = PARSER_VERSION + files: Mapping[str, ManifestEntry] = field(default_factory=dict) + + @classmethod + def empty(cls, *, parser_version: str = PARSER_VERSION) -> MaterializationManifest: + return cls(parser_version=parser_version, files={}) + + @classmethod + def load(cls, path: Path) -> MaterializationManifest: + if not path.exists(): + return cls.empty() + with path.open("r", encoding="utf-8") as handle: + payload = json.load(handle) + return cls( + schema_version=int(payload.get("schema_version", 0)), + ontology=str(payload.get("ontology", "")), + parser_version=str(payload.get("parser_version", "")), + files={ + str(file_payload["path"]): ManifestEntry.from_dict(file_payload) + for file_payload in payload.get("files", []) + }, + ) + + def as_dict(self) -> dict[str, Any]: + return { + "schema_version": self.schema_version, + "ontology": self.ontology, + "parser_version": self.parser_version, + "files": [entry.as_dict() for entry in sorted(self.files.values(), key=lambda item: item.path)], + } + + def write(self, path: Path) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + tmp_path = path.with_suffix(path.suffix + ".tmp") + with tmp_path.open("w", encoding="utf-8") as handle: + json.dump(self.as_dict(), handle, indent=2, sort_keys=True) + handle.write("\n") + os.replace(tmp_path, path) + + def is_compatible(self, *, parser_version: str = PARSER_VERSION) -> bool: + return ( + self.schema_version == MANIFEST_SCHEMA_VERSION + and self.ontology == ONTOLOGY_NAME + and self.parser_version == parser_version + ) + + def diff(self, current_files: Mapping[str, SourceSnapshot], *, parser_version: str = PARSER_VERSION) -> ManifestDiff: + if not self.is_compatible(parser_version=parser_version): + return ManifestDiff( + added=tuple(sorted(current_files)), + modified=(), + unchanged=(), + deleted=tuple(sorted(path for path in self.files if path not in current_files)), + force_rebuild=True, + ) + + added: list[str] = [] + modified: list[str] = [] + unchanged: list[str] = [] + for path, snapshot in current_files.items(): + previous = self.files.get(path) + if previous is None: + added.append(path) + elif previous.content_hash != snapshot.content_hash or previous.language != snapshot.language: + modified.append(path) + else: + unchanged.append(path) + + deleted = [path for path in self.files if path not in current_files] + return ManifestDiff( + added=tuple(sorted(added)), + modified=tuple(sorted(modified)), + unchanged=tuple(sorted(unchanged)), + deleted=tuple(sorted(deleted)), + force_rebuild=False, + ) + + +@dataclass(frozen=True, slots=True) +class ManifestDiff: + added: tuple[str, ...] + modified: tuple[str, ...] + unchanged: tuple[str, ...] + deleted: tuple[str, ...] + force_rebuild: bool = False + + @property + def rebuild_paths(self) -> tuple[str, ...]: + return tuple(sorted((*self.added, *self.modified))) + + +@dataclass(frozen=True, slots=True) +class MaterializationResult: + mode: MaterializeMode + scanned: int + rebuilt: int + skipped: int + deleted: int + diagnostics: tuple[str, ...] + manifest_path: str + rebuilt_paths: tuple[str, ...] + skipped_paths: tuple[str, ...] + deleted_paths: tuple[str, ...] + graph_summary: Mapping[str, Any] + + def as_dict(self) -> dict[str, Any]: + return { + "mode": self.mode, + "scanned": self.scanned, + "rebuilt": self.rebuilt, + "skipped": self.skipped, + "deleted": self.deleted, + "diagnostics": list(self.diagnostics), + "manifest_path": self.manifest_path, + "rebuilt_paths": list(self.rebuilt_paths), + "skipped_paths": list(self.skipped_paths), + "deleted_paths": list(self.deleted_paths), + "graph_summary": dict(self.graph_summary), + } + + +class GraphMaterializer: + def __init__( + self, + source_root: str | Path, + db_path: str | Path | None = None, + *, + manifest_path: str | Path | None = None, + include_fts: bool = True, + repository_label: str | None = None, + store: LadybugCodeGraphStore | None = None, + parser_registry: ParserRegistry | None = None, + graph_builder: GraphBuilder | None = None, + ) -> None: + self.source_root = Path(source_root).resolve() + paths = derive_graph_state_paths(self.source_root) + self.state_dir = paths.state_dir + self.db_path = _normalize_db_path(db_path if db_path is not None else paths.db_path) + self.manifest_path = Path(manifest_path) if manifest_path is not None else paths.manifest_path + self.include_fts = include_fts + self.repository_label = repository_label or self.source_root.name or "repository" + self._store = store + self._store_injected = store is not None + self.parser_registry = parser_registry or default_parser_registry() + self.parser_version = self.parser_registry.parser_version + self.builder = graph_builder or GraphBuilder(repository_label=self.repository_label, source_root=self.source_root) + + @property + def store(self) -> LadybugCodeGraphStore: + if self._store is None: + self._store = create_ladybug_database(self.db_path, include_fts=self.include_fts) + return self._store + + @store.setter + def store(self, value: LadybugCodeGraphStore | None) -> None: + self._store = value + self._store_injected = value is not None + + def close(self) -> None: + self._close_store() + + def materialize(self, mode: MaterializeMode = "changed") -> MaterializationResult: + if mode not in {"full", "changed"}: + raise ValueError(f"Unsupported materialization mode: {mode}") + + previous_manifest = self._read_manifest() + snapshots, diagnostics = self._scan_source_files() + supported = {path: snapshot for path, snapshot in snapshots.items() if snapshot.language is not None} + force_atomic_recovery = self._should_force_atomic_recovery() + + if mode == "full" or force_atomic_recovery: + diff = ManifestDiff( + added=tuple(sorted(supported)), + modified=(), + unchanged=(), + deleted=tuple(sorted(previous_manifest.files)), + force_rebuild=True, + ) + if self._can_atomic_rebuild(): + return self._materialize_full_atomic( + mode=mode, + snapshots=snapshots, + diagnostics=diagnostics, + supported=supported, + diff=diff, + ) + self.store.clear_graph() + retained_node_ids: set[str] = set() + retained_edge_ids: set[str] = set() + else: + diff = previous_manifest.diff(supported, parser_version=self.parser_version) + if diff.force_rebuild: + if self._can_atomic_rebuild(): + return self._materialize_full_atomic( + mode=mode, + snapshots=snapshots, + diagnostics=diagnostics, + supported=supported, + diff=diff, + ) + self.store.clear_graph() + retained_node_ids = set() + retained_edge_ids = set() + elif self._can_atomic_rebuild() and _diff_has_changes(diff): + return self._materialize_full_atomic( + mode=mode, + snapshots=snapshots, + diagnostics=diagnostics, + supported=supported, + diff=ManifestDiff( + added=tuple(sorted(supported)), + modified=(), + unchanged=(), + deleted=tuple(sorted(previous_manifest.files)), + force_rebuild=True, + ), + ) + else: + touched_paths = set(diff.rebuild_paths) | set(diff.deleted) + retained_node_ids = _retained_node_ids(previous_manifest, touched_paths) + retained_edge_ids = _retained_edge_ids(previous_manifest, touched_paths) + for path in diff.deleted: + self.store.delete_partition( + path, + manifest_entry=previous_manifest.files.get(path), + retained_node_ids=retained_node_ids, + retained_edge_ids=retained_edge_ids, + ) + + rebuilt_entries: dict[str, ManifestEntry] = {} + rebuilt_graphs: dict[str, CodeGraph] = {} + for path in diff.rebuild_paths: + snapshot = supported[path] + graph = self._build_graph(snapshot) + rebuilt_graphs[path] = graph + rebuilt_entries[path] = _manifest_entry(snapshot, graph) + + if not diff.force_rebuild: + for path in diff.rebuild_paths: + self.store.delete_partition( + path, + manifest_entry=previous_manifest.files.get(path), + retained_node_ids=retained_node_ids, + retained_edge_ids=retained_edge_ids, + ) + + if rebuilt_graphs: + self.store.insert_graphs_bulk( + [rebuilt_graphs[path] for path in sorted(rebuilt_graphs)], + skip_node_ids=retained_node_ids, + skip_edge_ids=retained_edge_ids, + ) + + next_files = { + path: entry + for path, entry in previous_manifest.files.items() + if path not in set(diff.deleted) | set(diff.rebuild_paths) + } + next_files.update(rebuilt_entries) + next_manifest = MaterializationManifest(parser_version=self.parser_version, files=next_files) + self._write_manifest(next_manifest) + + return _materialization_result( + mode=mode, + snapshots=snapshots, + diagnostics=diagnostics, + diff=diff, + manifest_path=self.manifest_path, + rebuilt_entries=rebuilt_entries, + next_manifest=next_manifest, + ) + + def _materialize_full_atomic( + self, + *, + mode: MaterializeMode, + snapshots: Mapping[str, SourceSnapshot], + diagnostics: list[str], + supported: Mapping[str, SourceSnapshot], + diff: ManifestDiff, + ) -> MaterializationResult: + target_db_path = _filesystem_db_path(self.db_path) + lock_fd, lock_path = _acquire_materialization_lock(target_db_path) + try: + rebuilt_entries: dict[str, ManifestEntry] = {} + rebuilt_graphs: dict[str, CodeGraph] = {} + for path in diff.rebuild_paths: + snapshot = supported[path] + graph = self._build_graph(snapshot) + rebuilt_graphs[path] = graph + rebuilt_entries[path] = _manifest_entry(snapshot, graph) + + next_manifest = MaterializationManifest(parser_version=self.parser_version, files=rebuilt_entries) + temp_db_path = _temporary_sibling(target_db_path, suffix=".lbug.tmp") + temp_manifest_path = _temporary_sibling(self.manifest_path, suffix=".manifest.tmp") + marker_path = self._rebuild_marker_path + temp_store: LadybugCodeGraphStore | None = None + try: + temp_store = create_ladybug_database(temp_db_path, include_fts=self.include_fts) + if rebuilt_graphs: + temp_store.insert_graphs_bulk([rebuilt_graphs[path] for path in sorted(rebuilt_graphs)]) + temp_store.close() + temp_store = None + + next_manifest.write(temp_manifest_path) + _write_rebuild_marker(marker_path, target_db_path, self.manifest_path) + self._close_store() + _unlink_db_sidecars(target_db_path) + os.replace(temp_db_path, target_db_path) + os.replace(temp_manifest_path, self.manifest_path) + _unlink_db_sidecars(target_db_path) + _unlink_if_exists(marker_path) + self._store = None + except Exception: + if temp_store is not None: + temp_store.close() + _unlink_if_exists(temp_db_path) + _unlink_db_sidecars(temp_db_path) + _unlink_if_exists(temp_manifest_path) + _unlink_if_exists(temp_manifest_path.with_suffix(temp_manifest_path.suffix + ".tmp")) + raise + finally: + _release_materialization_lock(lock_fd, lock_path) + + return _materialization_result( + mode=mode, + snapshots=snapshots, + diagnostics=diagnostics, + diff=diff, + manifest_path=self.manifest_path, + rebuilt_entries=rebuilt_entries, + next_manifest=next_manifest, + ) + + def _read_manifest(self) -> MaterializationManifest: + if self._store_injected and self._store is not None and hasattr(self._store, "read_manifest"): + return self._store.read_manifest(self.manifest_path) + return MaterializationManifest.load(self.manifest_path) + + def _write_manifest(self, manifest: MaterializationManifest) -> None: + if self._store_injected and self._store is not None and hasattr(self._store, "write_manifest"): + self._store.write_manifest(manifest, self.manifest_path) + return + manifest.write(self.manifest_path) + + def _can_atomic_rebuild(self) -> bool: + return not self._store_injected and not _is_memory_db_path(self.db_path) + + def _should_force_atomic_recovery(self) -> bool: + return self._can_atomic_rebuild() and self._rebuild_marker_path.exists() + + @property + def _rebuild_marker_path(self) -> Path: + return self.manifest_path.with_suffix(self.manifest_path.suffix + ".rebuild-pending") + + def _close_store(self) -> None: + if self._store is None: + return + close = getattr(self._store, "close", None) + if callable(close): + close() + self._store = None + + def _scan_source_files(self) -> tuple[dict[str, SourceSnapshot], list[str]]: + snapshots: dict[str, SourceSnapshot] = {} + diagnostics: list[str] = [] + for current_root, dirnames, filenames in os.walk(self.source_root): + dirnames[:] = [name for name in sorted(dirnames) if not _is_excluded_part(name)] + current_path = Path(current_root) + for filename in sorted(filenames): + path = current_path / filename + if _is_excluded(path, self.source_root): + continue + relative_path = path.relative_to(self.source_root).as_posix() + language = self.parser_registry.language_for_path(path) + if language is None: + snapshots[relative_path] = SourceSnapshot( + path=relative_path, + absolute_path=path, + content_hash="", + language=None, + ) + diagnostics.append(f"Skipped unsupported file: {relative_path}") + continue + snapshots[relative_path] = SourceSnapshot( + path=relative_path, + absolute_path=path, + content_hash=_file_hash(path), + language=language, + ) + return snapshots, diagnostics + + def _build_graph(self, snapshot: SourceSnapshot) -> CodeGraph: + if snapshot.language is None: + raise ValueError(f"Cannot build graph for unsupported file: {snapshot.path}") + try: + parser = self.parser_registry.parser_for_language(snapshot.language) + bundle = parser.parse_file( + snapshot.absolute_path, + relative_path=snapshot.path, + source_root=self.source_root, + repository_label=self.repository_label, + content_hash=snapshot.content_hash, + ) + except ParserUnavailableError: + raise + result = self.builder.build_file_graph(bundle) + return result.graph + + +def _is_excluded_part(part: str) -> bool: + return part in EXCLUDED_PARTS or part.endswith(".egg-info") + + +def _normalize_db_path(db_path: str | Path) -> str | Path: + if _is_memory_db_path(db_path): + return ":memory:" + return Path(db_path) + + +def _is_memory_db_path(db_path: str | Path) -> bool: + return str(db_path) == ":memory:" + + +def _filesystem_db_path(db_path: str | Path) -> Path: + if _is_memory_db_path(db_path): + raise ValueError("In-memory databases do not have a filesystem path") + return Path(db_path) + + +def _temporary_sibling(path: Path, *, suffix: str) -> Path: + path.parent.mkdir(parents=True, exist_ok=True) + descriptor, temp_path = tempfile.mkstemp(prefix=f".{path.name}.", suffix=suffix, dir=path.parent) + os.close(descriptor) + os.unlink(temp_path) + return Path(temp_path) + + +def _acquire_materialization_lock(db_path: Path) -> tuple[int, Path]: + lock_path = Path(f"{db_path}.lock") + lock_path.parent.mkdir(parents=True, exist_ok=True) + while True: + try: + descriptor = os.open(lock_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY) + break + except FileExistsError as exc: + if _materialization_lock_is_stale(lock_path): + _unlink_if_exists(lock_path) + log_event( + "materializer.stale_lock_removed", + level="WARNING", + db_path=db_path.as_posix(), + lock_path=lock_path.as_posix(), + ) + continue + log_event( + "materializer.lock_exists", + level="WARNING", + db_path=db_path.as_posix(), + lock_path=lock_path.as_posix(), + ) + raise RuntimeError( + f"codebaseGraph materialization is already in progress for {db_path}. " + f"If no materializer is running, inspect the lock file before removing it: {lock_path}" + ) from exc + payload = { + "created_at": datetime.now(timezone.utc).isoformat(), + "pid": os.getpid(), + "db_path": db_path.as_posix(), + } + try: + os.write(descriptor, (json.dumps(payload, sort_keys=True) + "\n").encode("utf-8")) + except Exception: + os.close(descriptor) + _unlink_if_exists(lock_path) + raise + return descriptor, lock_path + + +def _materialization_lock_is_stale(lock_path: Path) -> bool: + try: + payload = json.loads(lock_path.read_text(encoding="utf-8")) + except (OSError, json.JSONDecodeError): + return False + pid = payload.get("pid") if isinstance(payload, dict) else None + if not isinstance(pid, int) or pid <= 0 or pid == os.getpid(): + return False + return not _process_is_running(pid) + + +def _process_is_running(pid: int) -> bool: + try: + os.kill(pid, 0) + except ProcessLookupError: + return False + except PermissionError: + return True + return True + + +def _release_materialization_lock(descriptor: int, lock_path: Path) -> None: + os.close(descriptor) + _unlink_if_exists(lock_path) + + +def _write_rebuild_marker(marker_path: Path, db_path: Path, manifest_path: Path) -> None: + marker_path.parent.mkdir(parents=True, exist_ok=True) + tmp_path = marker_path.with_suffix(marker_path.suffix + ".tmp") + with tmp_path.open("w", encoding="utf-8") as handle: + json.dump( + { + "created_at": datetime.now(timezone.utc).isoformat(), + "db_path": db_path.as_posix(), + "manifest_path": manifest_path.as_posix(), + }, + handle, + indent=2, + sort_keys=True, + ) + handle.write("\n") + os.replace(tmp_path, marker_path) + + +def _unlink_if_exists(path: Path) -> None: + try: + path.unlink() + except FileNotFoundError: + return + + +def _unlink_db_sidecars(db_path: Path) -> None: + for suffix in (".wal", ".shm", ".shadow"): + _unlink_if_exists(Path(f"{db_path}{suffix}")) + + +def _diff_has_changes(diff: ManifestDiff) -> bool: + return bool(diff.rebuild_paths or diff.deleted) + + +def _is_excluded(path: Path, source_root: Path) -> bool: + parts = path.relative_to(source_root).parts + return any(_is_excluded_part(part) for part in parts) + + +def _file_hash(path: Path) -> str: + digest = hashlib.sha256() + with path.open("rb") as handle: + for chunk in iter(lambda: handle.read(1024 * 1024), b""): + digest.update(chunk) + return digest.hexdigest() + + +def _partition_id(path: str) -> str: + return hashlib.sha1(path.encode("utf-8")).hexdigest()[:20] + + +def _manifest_entry(snapshot: SourceSnapshot, graph: CodeGraph) -> ManifestEntry: + return ManifestEntry( + path=snapshot.path, + content_hash=snapshot.content_hash, + language=snapshot.language or "", + partition_id=_partition_id(snapshot.path), + node_ids=tuple(sorted(graph.nodes)), + edge_ids=tuple(sorted(graph.edges)), + node_types={node_id: node.table for node_id, node in graph.nodes.items()}, + edge_types={edge_id: edge.type for edge_id, edge in graph.edges.items()}, + materialized_at=datetime.now(timezone.utc).isoformat(), + ) + + +def _materialization_result( + *, + mode: MaterializeMode, + snapshots: Mapping[str, SourceSnapshot], + diagnostics: list[str], + diff: ManifestDiff, + manifest_path: Path, + rebuilt_entries: Mapping[str, ManifestEntry], + next_manifest: MaterializationManifest, +) -> MaterializationResult: + unsupported_paths = tuple(path for path, snapshot in snapshots.items() if snapshot.language is None) + skipped_paths = tuple(sorted((*diff.unchanged, *unsupported_paths))) + return MaterializationResult( + mode=mode, + scanned=len(snapshots), + rebuilt=len(rebuilt_entries), + skipped=len(skipped_paths), + deleted=len(diff.deleted), + diagnostics=tuple(diagnostics), + manifest_path=manifest_path.as_posix(), + rebuilt_paths=tuple(sorted(rebuilt_entries)), + skipped_paths=skipped_paths, + deleted_paths=diff.deleted, + graph_summary=_manifest_summary(next_manifest), + ) + + +def _retained_node_ids(manifest: MaterializationManifest, touched_paths: set[str]) -> set[str]: + retained: set[str] = set() + for path, entry in manifest.files.items(): + if path in touched_paths: + continue + retained.update(entry.node_ids) + return retained + + +def _retained_edge_ids(manifest: MaterializationManifest, touched_paths: set[str]) -> set[str]: + retained: set[str] = set() + for path, entry in manifest.files.items(): + if path in touched_paths: + continue + retained.update(entry.edge_ids) + return retained + + +def _manifest_summary(manifest: MaterializationManifest) -> dict[str, int | str]: + node_ids: set[str] = set() + edge_ids: set[str] = set() + for entry in manifest.files.values(): + node_ids.update(entry.node_ids) + edge_ids.update(entry.edge_ids) + return { + "ontology": manifest.ontology, + "partition_count": len(manifest.files), + "node_count": len(node_ids), + "edge_count": len(edge_ids), + } diff --git a/src/codebase_graph/ingest/tree_sitter_parser.py b/src/codebase_graph/ingest/tree_sitter_parser.py new file mode 100644 index 0000000..95d870d --- /dev/null +++ b/src/codebase_graph/ingest/tree_sitter_parser.py @@ -0,0 +1,332 @@ +from __future__ import annotations + +import re +from collections.abc import Callable, Mapping +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Protocol + +from codebase_graph.extract import ParseBundle +from .document_parser import MarkdownDocumentParser + + +class ParserUnavailableError(RuntimeError): + pass + + +class SourceParser(Protocol): + language: str + parser_version: str + + def parse_file( + self, + path: Path, + *, + relative_path: str, + source_root: Path, + repository_label: str, + content_hash: str, + ) -> ParseBundle: + ... + + +@dataclass(frozen=True, slots=True) +class ParserRegistration: + language: str + suffixes: tuple[str, ...] + parser_factory: Callable[[], SourceParser] + parser_version: str + + +class ParserRegistry: + def __init__(self, registrations: Mapping[str, ParserRegistration] | None = None) -> None: + self._registrations: dict[str, ParserRegistration] = dict(registrations or {}) + self._suffix_to_language: dict[str, str] = {} + for registration in self._registrations.values(): + self._register_suffixes(registration) + + @property + def parser_version(self) -> str: + return "+".join( + registration.parser_version + for registration in self._registrations.values() + ) + + def register( + self, + language: str, + *, + suffixes: tuple[str, ...], + parser_factory: Callable[[], SourceParser], + parser_version: str, + ) -> None: + registration = ParserRegistration(language, suffixes, parser_factory, parser_version) + self._registrations[language] = registration + self._register_suffixes(registration) + + def language_for_path(self, path: Path) -> str | None: + return self._suffix_to_language.get(path.suffix) + + def parser_for_language(self, language: str) -> SourceParser: + try: + registration = self._registrations[language] + except KeyError as exc: + raise ValueError(f"Unsupported materializer language: {language}") from exc + return registration.parser_factory() + + def _register_suffixes(self, registration: ParserRegistration) -> None: + for suffix in registration.suffixes: + self._suffix_to_language[suffix] = registration.language + + +@dataclass(frozen=True, slots=True) +class TreeSitterPythonParser: + language: str = "python" + parser_version: str = "tree-sitter-python-v1" + + def parse_file( + self, + path: Path, + *, + relative_path: str, + source_root: Path, + repository_label: str, + content_hash: str, + ) -> ParseBundle: + source_text = path.read_text(encoding="utf-8") + return ParseBundle( + language=self.language, + path=relative_path, + source_text=source_text, + tree=self.parse_source(source_text), + repository_label=repository_label, + source_root=source_root.as_posix(), + content_hash=content_hash, + ) + + def parse_source(self, source_text: str) -> dict[str, Any]: + parser = _python_parser() + source_bytes = source_text.encode("utf-8") + tree = parser.parse(source_bytes) + return _convert_node(tree.root_node, source_bytes) + + +def default_parser_registry() -> ParserRegistry: + registry = ParserRegistry() + registry.register( + "python", + suffixes=(".py",), + parser_factory=TreeSitterPythonParser, + parser_version=TreeSitterPythonParser().parser_version, + ) + registry.register( + "markdown", + suffixes=(".md", ".mdx"), + parser_factory=MarkdownDocumentParser, + parser_version=MarkdownDocumentParser().parser_version, + ) + return registry + + +def parser_for_language(language: str) -> SourceParser: + return default_parser_registry().parser_for_language(language) + + +def _python_parser() -> Any: + try: + from tree_sitter import Language, Parser + import tree_sitter_python + except ImportError as exc: + raise ParserUnavailableError( + "Tree-sitter Python parsing requires `tree-sitter` and `tree-sitter-python`." + ) from exc + + raw_language = tree_sitter_python.language() + try: + language = Language(raw_language) + except TypeError: + language = raw_language + + parser = Parser() + if hasattr(parser, "set_language"): + parser.set_language(language) + else: + parser.language = language + return parser + + +def _convert_node(node: Any, source_bytes: bytes, decorators: tuple[dict[str, Any], ...] = ()) -> dict[str, Any]: + if node.type == "decorated_definition": + converted_decorators = tuple( + _convert_node(child, source_bytes) + for child in _named_children(node) + if child.type == "decorator" + ) + for child in _named_children(node): + if child.type in {"class_definition", "function_definition"}: + return _convert_node(child, source_bytes, converted_decorators) + + converted: dict[str, Any] = { + "type": node.type, + "text": _node_text(node, source_bytes), + "line_start": _line_start(node), + "line_end": _line_end(node), + "byte_start": node.start_byte, + "byte_end": node.end_byte, + } + + if node.type == "module": + converted["children"] = [_convert_node(child, source_bytes) for child in _named_children(node)] + elif node.type == "class_definition": + converted.update(_class_fields(node, source_bytes, decorators)) + elif node.type == "function_definition": + converted.update(_function_fields(node, source_bytes, decorators)) + elif node.type in {"import_statement", "import_from_statement"}: + converted.update(_import_fields(node, source_bytes)) + elif node.type == "call": + converted.update(_call_fields(node, source_bytes)) + elif node.type == "assignment": + converted.update(_assignment_fields(node, source_bytes)) + elif node.type in {"identifier", "type_identifier"}: + converted["id"] = _node_text(node, source_bytes) + elif node.type == "attribute": + converted.update(_attribute_fields(node, source_bytes)) + elif node.type in {"string", "integer", "float", "true", "false", "none"}: + converted["value"] = _literal_value(node, source_bytes) + + converted.setdefault("children", [_convert_node(child, source_bytes) for child in _semantic_children(node)]) + return converted + + +def _class_fields( + node: Any, + source_bytes: bytes, + decorators: tuple[dict[str, Any], ...], +) -> dict[str, Any]: + fields: dict[str, Any] = {"name": _field_text(node, "name", source_bytes)} + if decorators: + fields["decorator_list"] = list(decorators) + body = node.child_by_field_name("body") + fields["children"] = [_convert_node(child, source_bytes) for child in _named_children(body)] + return fields + + +def _function_fields( + node: Any, + source_bytes: bytes, + decorators: tuple[dict[str, Any], ...], +) -> dict[str, Any]: + fields: dict[str, Any] = {"name": _field_text(node, "name", source_bytes)} + parameters = node.child_by_field_name("parameters") + if parameters is not None: + fields["args"] = {"type": "arguments", "args": [_parameter_node(child, source_bytes) for child in _named_children(parameters)]} + return_type = node.child_by_field_name("return_type") + if return_type is not None: + fields["returns"] = _convert_node(return_type, source_bytes) + if decorators: + fields["decorator_list"] = list(decorators) + body = node.child_by_field_name("body") + fields["children"] = [_convert_node(child, source_bytes) for child in _named_children(body)] + return fields + + +def _parameter_node(node: Any, source_bytes: bytes) -> dict[str, Any]: + text = _node_text(node, source_bytes) + name = text.split(":", 1)[0].split("=", 1)[0].strip().lstrip("*") + parameter: dict[str, Any] = { + "type": "arg", + "arg": name, + "text": text, + "line_start": _line_start(node), + "line_end": _line_end(node), + "byte_start": node.start_byte, + "byte_end": node.end_byte, + } + if ":" in text: + annotation = text.split(":", 1)[1].split("=", 1)[0].strip() + if annotation: + parameter["annotation"] = {"type": "type", "id": annotation, "text": annotation} + return parameter + + +def _import_fields(node: Any, source_bytes: bytes) -> dict[str, Any]: + text = _node_text(node, source_bytes).strip() + if node.type == "import_from_statement": + match = re.match(r"from\s+([.\w]+)\s+import\s+(.+)", text) + if match: + names = [_import_alias(name) for name in match.group(2).split(",")] + return {"module": match.group(1), "names": names} + if node.type == "import_statement": + imported = text.removeprefix("import").strip() + return {"names": [_import_alias(name) for name in imported.split(",")]} + return {} + + +def _import_alias(raw_name: str) -> dict[str, str]: + name = raw_name.strip().split(" as ", 1)[0].strip() + return {"type": "alias", "name": name} + + +def _call_fields(node: Any, source_bytes: bytes) -> dict[str, Any]: + function = node.child_by_field_name("function") + if function is not None: + return {"func": _convert_node(function, source_bytes)} + text = _node_text(node, source_bytes) + return {"func": {"type": "identifier", "id": text.split("(", 1)[0].strip()}} + + +def _assignment_fields(node: Any, source_bytes: bytes) -> dict[str, Any]: + left = node.child_by_field_name("left") + right = node.child_by_field_name("right") + fields: dict[str, Any] = {} + if left is not None: + fields["target"] = _convert_node(left, source_bytes) + if right is not None: + fields["value"] = _convert_node(right, source_bytes) + return fields + + +def _attribute_fields(node: Any, source_bytes: bytes) -> dict[str, Any]: + text = _node_text(node, source_bytes) + if "." not in text: + return {"id": text} + base, attr = text.rsplit(".", 1) + return {"value": {"type": "identifier", "id": base}, "attr": attr} + + +def _literal_value(node: Any, source_bytes: bytes) -> str: + return _node_text(node, source_bytes).strip("'\"") + + +def _semantic_children(node: Any) -> tuple[Any, ...]: + ignored = {"identifier", "type_identifier", "parameters", "decorator", "block"} + return tuple(child for child in _named_children(node) if child.type not in ignored) + + +def _named_children(node: Any | None) -> tuple[Any, ...]: + if node is None: + return () + return tuple(getattr(node, "named_children", ()) or ()) + + +def _field_text(node: Any, field_name: str, source_bytes: bytes) -> str: + child = node.child_by_field_name(field_name) + return _node_text(child, source_bytes) if child is not None else "" + + +def _node_text(node: Any, source_bytes: bytes) -> str: + return source_bytes[node.start_byte:node.end_byte].decode("utf-8", errors="replace") + + +def _line_start(node: Any) -> int: + return _point_row(node.start_point) + 1 + + +def _line_end(node: Any) -> int: + return _point_row(node.end_point) + 1 + + +def _point_row(point: Any) -> int: + if hasattr(point, "row"): + return int(point.row) + return int(point[0]) diff --git a/src/codebase_graph/ladybug.py b/src/codebase_graph/ladybug.py deleted file mode 100644 index 01123c3..0000000 --- a/src/codebase_graph/ladybug.py +++ /dev/null @@ -1,77 +0,0 @@ -from __future__ import annotations - -import hashlib -import json -import math -from dataclasses import dataclass -from pathlib import Path -from typing import Any, Protocol - -from .production import GraphExport, ProductionGraphBuilder - -DEFAULT_EMBEDDING_DIMENSIONS = 384 - -class LadybugUnavailableError(RuntimeError): - pass - -class EmbeddingProvider(Protocol): - dimensions: int - - def embed(self, text: str) -> list[float]: - ... - -class HashingEmbeddingProvider: - def __init__(self, dimensions: int = DEFAULT_EMBEDDING_DIMENSIONS) -> None: - self.dimensions = dimensions - - def embed(self, text: str) -> list[float]: - vector = [0.0] * self.dimensions - for token in text.lower().split(): - digest = hashlib.sha1(token.encode("utf-8")).digest() - index = int.from_bytes(digest[:4], "big") % self.dimensions - vector[index] += 1.0 - norm = math.sqrt(sum(value * value for value in vector)) or 1.0 - return [value / norm for value in vector] - -@dataclass(slots=True) -class LadybugGraphExport: - export: GraphExport - embedding_dimensions: int = DEFAULT_EMBEDDING_DIMENSIONS - - def as_dict(self) -> dict[str, Any]: - payload = self.export.as_dict() - payload["embedding_dimensions"] = self.embedding_dimensions - return payload - - def summary(self) -> dict[str, Any]: - payload = self.export.summary() - payload["embedding_dimensions"] = self.embedding_dimensions - return payload - -class LadybugGraphExporter: - def __init__(self, repo_root: str | Path = ".", embedding_provider: EmbeddingProvider | None = None) -> None: - self.repo_root = Path(repo_root) - self.embedding_provider = embedding_provider or HashingEmbeddingProvider() - - def build_export(self) -> LadybugGraphExport: - export = ProductionGraphBuilder(self.repo_root).build_export() - return LadybugGraphExport(export, int(self.embedding_provider.dimensions)) - -class LadybugGraphStore: - def __init__(self, db_path: str | Path) -> None: - self.db_path = Path(db_path) - - def write_export(self, export: LadybugGraphExport) -> None: - self.db_path.parent.mkdir(parents=True, exist_ok=True) - self.db_path.write_text(json.dumps(export.as_dict(), indent=2, sort_keys=True), encoding="utf-8") - - def read_export(self) -> dict[str, Any]: - if not self.db_path.exists(): - return {"ontology": "", "metadata": {}, "nodes": [], "edges": []} - return json.loads(self.db_path.read_text(encoding="utf-8")) - - def ensure_schema(self, embedding_dimensions: int = DEFAULT_EMBEDDING_DIMENSIONS) -> None: - self.db_path.parent.mkdir(parents=True, exist_ok=True) - - def copy_from_staging(self, staging: Any) -> None: - raise LadybugUnavailableError("Staging copy is not implemented for the JSON-backed base store.") diff --git a/src/codebase_graph/markdown.py b/src/codebase_graph/markdown.py deleted file mode 100644 index f2e6deb..0000000 --- a/src/codebase_graph/markdown.py +++ /dev/null @@ -1,32 +0,0 @@ -from __future__ import annotations - -import re -from typing import Any - -FRONTMATTER_RE = re.compile(r"^---\s*\n(.*?)\n---\s*\n", re.DOTALL) -WIKI_LINK_RE = re.compile(r"\[\[([^\]#|]+)(?:#[^\]|]+)?(?:\|([^\]]+))?\]\]") - -def normalize_slug(value: str) -> str: - slug = re.sub(r"[^a-zA-Z0-9]+", "-", value.strip().lower()).strip("-") - return slug or "untitled" - -def extract_wiki_links(markdown: str) -> list[str]: - return [normalize_slug(match.group(1)) for match in WIKI_LINK_RE.finditer(markdown)] - -def parse_markdown(content: str) -> tuple[dict[str, Any], str]: - match = FRONTMATTER_RE.match(content) - if not match: - return {}, content - frontmatter: dict[str, Any] = {} - for line in match.group(1).splitlines(): - if ":" not in line: - continue - key, value = line.split(":", 1) - frontmatter[key.strip()] = value.strip().strip('"') - return frontmatter, content[match.end():] - -def plain_text(markdown: str) -> str: - text = WIKI_LINK_RE.sub(lambda match: match.group(2) or match.group(1), markdown) - text = re.sub(r"`{1,3}[^`]*`{1,3}", " ", text) - text = re.sub(r"[#*_>\-]+", " ", text) - return re.sub(r"\s+", " ", text).strip() diff --git a/src/codebase_graph/mcp/__init__.py b/src/codebase_graph/mcp/__init__.py new file mode 100644 index 0000000..ede6b1a --- /dev/null +++ b/src/codebase_graph/mcp/__init__.py @@ -0,0 +1,5 @@ +"""MCP server surface for codebaseGraph.""" + +from .server import McpGraphServer, handle_tool_call, serve_http, serve_stdio + +__all__ = ["McpGraphServer", "handle_tool_call", "serve_http", "serve_stdio"] diff --git a/src/codebase_graph/mcp/graph_commands.py b/src/codebase_graph/mcp/graph_commands.py new file mode 100644 index 0000000..5af427b --- /dev/null +++ b/src/codebase_graph/mcp/graph_commands.py @@ -0,0 +1,281 @@ +from __future__ import annotations + +import argparse +import json +from collections.abc import Callable, Sequence +from dataclasses import dataclass +from typing import Any + +from codebase_graph.ontology import CONTEXT_PROFILES +from codebase_graph.retrieval import DETAIL_LEVELS + + +MAX_GRAPH_QUERY_LIMIT = 1000 + +PayloadBuilder = Callable[[argparse.Namespace], dict[str, Any]] +ArgumentAdder = Callable[[argparse.ArgumentParser], None] + + +@dataclass(frozen=True, slots=True) +class GraphCommandSpec: + command_name: str + tool_name: str + help: str + description: str + input_schema: dict[str, Any] + add_arguments: ArgumentAdder + payload_from_args: PayloadBuilder + requires_runtime: bool = True + + def tool_spec(self) -> dict[str, Any]: + return { + "name": self.tool_name, + "description": self.description, + "inputSchema": self.input_schema, + } + + +def graph_command_specs() -> tuple[GraphCommandSpec, ...]: + return GRAPH_COMMAND_SPECS + + +def graph_command_names() -> set[str]: + return {spec.command_name for spec in GRAPH_COMMAND_SPECS} + + +def graph_tool_specs() -> list[dict[str, Any]]: + return [spec.tool_spec() for spec in GRAPH_COMMAND_SPECS] + + +def graph_command_spec(command_name: str) -> GraphCommandSpec: + for spec in GRAPH_COMMAND_SPECS: + if spec.command_name == command_name: + return spec + raise KeyError(command_name) + + +def search_arguments_payload(args: argparse.Namespace) -> dict[str, Any]: + payload: dict[str, Any] = { + "limit": args.limit, + "profile": args.profile, + "budget": args.budget, + "context_limit": args.context_limit, + "detail": args.detail, + } + if getattr(args, "query", None): + payload["query"] = args.query + if args.max_depth is not None: + payload["max_depth"] = args.max_depth + return payload + + +def _empty_payload(args: argparse.Namespace) -> dict[str, Any]: + return {} + + +def _architecture_payload(args: argparse.Namespace) -> dict[str, Any]: + payload: dict[str, Any] = {} + if args.group: + payload["group"] = args.group + return payload + + +def _context_payload(args: argparse.Namespace) -> dict[str, Any]: + if not args.query and not (args.node_id and args.node_type): + raise ValueError("graph-context requires a query or both --node-id and --node-type") + if (args.node_id and not args.node_type) or (args.node_type and not args.node_id): + raise ValueError("graph-context explicit lookup requires both --node-id and --node-type") + payload = search_arguments_payload(args) + if args.node_id and args.node_type: + payload["node_id"] = args.node_id + payload["node_type"] = args.node_type + return payload + + +def _query_payload(args: argparse.Namespace) -> dict[str, Any]: + try: + parameters = json.loads(args.parameters) + except json.JSONDecodeError as exc: + raise ValueError(f"graph-query --parameters must be a JSON object: {exc}") from exc + if not isinstance(parameters, dict): + raise ValueError("graph-query --parameters must be a JSON object") + return {"statement": args.statement, "parameters": parameters, "limit": args.limit} + + +def add_json_output_arguments(parser: argparse.ArgumentParser) -> None: + parser.add_argument("--pretty", action="store_true", help="Emit indented JSON output") + + +def add_compact_context_arguments(parser: argparse.ArgumentParser) -> None: + parser.add_argument("--limit", type=int, default=3, help="Maximum search hits to return") + parser.add_argument("--profile", choices=sorted(CONTEXT_PROFILES), default="brief", help="Context profile") + parser.add_argument("--budget", type=int, default=600, help="Approximate per-hit context character budget") + parser.add_argument("--max-depth", type=int, default=None, help="Override the context profile depth") + parser.add_argument("--context-limit", type=int, default=3, help="Maximum context items per search hit") + parser.add_argument("--detail", choices=sorted(DETAIL_LEVELS), default="standard", help="Output detail level") + parser.add_argument("--format", choices=("json", "block"), default="json", help="Output format") + add_json_output_arguments(parser) + + +def add_runtime_arguments(parser: argparse.ArgumentParser) -> None: + parser.add_argument("--repo-root", default=".", help="Repository root containing .codebaseGraph/config.json") + parser.add_argument("--config", default=None, help="Path to .codebaseGraph/config.json") + parser.add_argument("--db", default=None, help="Override LadyBugDB path") + parser.add_argument("--manifest", default=None, help="Override manifest path") + + +def add_graph_compatibility_arguments(parser: argparse.ArgumentParser) -> None: + parser.add_argument("--no-refresh", action="store_true", help="Accepted for search/context command parity") + parser.add_argument("--json", action="store_true", help="Accepted for search/context command parity; same as --format json") + + +def _add_graph_health_arguments(parser: argparse.ArgumentParser) -> None: + add_runtime_arguments(parser) + add_json_output_arguments(parser) + + +def _add_graph_search_arguments(parser: argparse.ArgumentParser) -> None: + parser.add_argument("query", help="Search query") + add_compact_context_arguments(parser) + add_runtime_arguments(parser) + add_graph_compatibility_arguments(parser) + + +def _add_graph_context_arguments(parser: argparse.ArgumentParser) -> None: + parser.add_argument("query", nargs="?", help="Search query") + parser.add_argument("--node-id", default=None, help="Explicit graph node id") + parser.add_argument("--node-type", default=None, help="Explicit graph node type") + add_compact_context_arguments(parser) + add_runtime_arguments(parser) + add_graph_compatibility_arguments(parser) + + +def _add_graph_architecture_arguments(parser: argparse.ArgumentParser) -> None: + parser.add_argument("--group", default=None, help="Optional architecture query group") + add_json_output_arguments(parser) + + +def _add_graph_query_arguments(parser: argparse.ArgumentParser) -> None: + parser.add_argument("statement", help="Read-only graph query statement") + parser.add_argument("--parameters", default="{}", help="JSON object with query parameters") + parser.add_argument("--limit", type=int, default=100, help="Maximum rows to return") + add_runtime_arguments(parser) + add_json_output_arguments(parser) + + +def _object_schema( + properties: dict[str, Any] | None = None, + *, + required: Sequence[str] = (), +) -> dict[str, Any]: + schema: dict[str, Any] = { + "type": "object", + "properties": properties or {}, + "additionalProperties": False, + } + if required: + schema["required"] = list(required) + return schema + + +def _search_schema(*, required: Sequence[str]) -> dict[str, Any]: + return _object_schema( + { + "query": {"type": "string"}, + "limit": {"type": "integer", "minimum": 1}, + "profile": {"type": "string"}, + "budget": {"type": "integer", "minimum": 0}, + "max_depth": {"type": "integer", "minimum": 0}, + "context_limit": {"type": "integer", "minimum": 0}, + "detail": {"type": "string", "enum": sorted(DETAIL_LEVELS)}, + "output_format": {"type": "string", "enum": ["json", "block"]}, + "node_id": {"type": "string"}, + "node_type": {"type": "string"}, + }, + required=required, + ) + + +GRAPH_COMMAND_SPECS = ( + GraphCommandSpec( + command_name="graph-health", + tool_name="graph_health", + help="Check configured graph paths", + description="Check the configured codebaseGraph database path and manifest path.", + input_schema=_object_schema(), + add_arguments=_add_graph_health_arguments, + payload_from_args=_empty_payload, + ), + GraphCommandSpec( + command_name="graph-search", + tool_name="graph_search", + help="Search the code graph with compact context", + description="Search code, documentation, paths, and dependencies with compact graph context.", + input_schema=_search_schema(required=("query",)), + add_arguments=_add_graph_search_arguments, + payload_from_args=search_arguments_payload, + ), + GraphCommandSpec( + command_name="graph-context", + tool_name="graph_context", + help="Return compact graph context", + description="Return compact context for a search query or explicit node_id/node_type pair.", + input_schema=_search_schema(required=()), + add_arguments=_add_graph_context_arguments, + payload_from_args=_context_payload, + ), + GraphCommandSpec( + command_name="graph-schema", + tool_name="graph_schema", + help="Return ontology schema, indexes, profiles, and helpers", + description="Return ontology schema, search indexes, context profiles, and query helper metadata.", + input_schema=_object_schema(), + add_arguments=add_json_output_arguments, + payload_from_args=_empty_payload, + requires_runtime=False, + ), + GraphCommandSpec( + command_name="graph-query-helpers", + tool_name="graph_query_helpers", + help="Return named read-only graph query helpers", + description="Return named read-only query helpers for common graph exploration tasks.", + input_schema=_object_schema(), + add_arguments=add_json_output_arguments, + payload_from_args=_empty_payload, + requires_runtime=False, + ), + GraphCommandSpec( + command_name="graph-architecture-queries", + tool_name="graph_architecture_queries", + help="Return the architecture-discovery query catalog", + description="Return the grouped architecture-discovery Cypher catalog for coding-agent first-step orientation.", + input_schema=_object_schema( + { + "group": { + "type": "string", + "description": "Optional architecture query group to return.", + }, + } + ), + add_arguments=_add_graph_architecture_arguments, + payload_from_args=_architecture_payload, + requires_runtime=False, + ), + GraphCommandSpec( + command_name="graph-query", + tool_name="graph_query", + help="Execute a restricted read-only graph query", + description="Execute a restricted read-only graph query against the configured database.", + input_schema=_object_schema( + { + "statement": {"type": "string"}, + "parameters": {"type": "object"}, + "limit": {"type": "integer", "minimum": 1, "maximum": MAX_GRAPH_QUERY_LIMIT}, + }, + required=("statement",), + ), + add_arguments=_add_graph_query_arguments, + payload_from_args=_query_payload, + ), +) + diff --git a/src/codebase_graph/mcp/protocol.py b/src/codebase_graph/mcp/protocol.py new file mode 100644 index 0000000..e963755 --- /dev/null +++ b/src/codebase_graph/mcp/protocol.py @@ -0,0 +1,102 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +from codebase_graph.paths import MCP_SERVER_NAME + +from .runtime import GraphRuntimeConfig, package_version +from .tools import UnknownToolError, call_tool_result, tool_specs + +SUPPORTED_PROTOCOL_VERSIONS = ("2025-11-25", "2025-06-18", "2025-03-26", "2024-11-05") +LATEST_PROTOCOL_VERSION = SUPPORTED_PROTOCOL_VERSIONS[0] + + +@dataclass(slots=True) +class ProtocolSession: + protocol_version: str | None = None + initialized: bool = False + + +class McpGraphServer: + def __init__(self, runtime: GraphRuntimeConfig) -> None: + self.runtime = runtime + self.session = ProtocolSession() + + @classmethod + def from_paths( + cls, + *, + repo_root: str = ".", + config_path: str | None = None, + db_path: str | None = None, + manifest_path: str | None = None, + ) -> McpGraphServer: + from .runtime import runtime_config + + runtime = runtime_config( + repo_root=repo_root, + config_path=config_path, + db_path=db_path, + manifest_path=manifest_path, + ) + return cls(runtime) + + def handle_json_rpc(self, message: dict[str, Any]) -> dict[str, Any] | None: + method = str(message.get("method", "")) + request_id = message.get("id") + if method == "notifications/initialized": + self.session.initialized = True + return None + if method.startswith("notifications/"): + return None + if method in {"tools/list", "tools/call"} and self.session.protocol_version is None: + return rpc_error(request_id, -32002, "MCP session is not initialized") + try: + if method == "initialize": + result = self._initialize(dict(message.get("params") or {})) + elif method == "ping": + result = {} + elif method == "tools/list": + result = {"tools": tool_specs()} + elif method == "tools/call": + result = self._call_tool(dict(message.get("params") or {})) + else: + return rpc_error(request_id, -32601, f"Unsupported MCP method: {method}") + except UnknownToolError as exc: + return rpc_error(request_id, -32602, str(exc)) + except ValueError as exc: + return rpc_error(request_id, -32602, str(exc)) + except Exception as exc: + return rpc_error(request_id, -32000, str(exc)) + return {"jsonrpc": "2.0", "id": request_id, "result": result} + + def _initialize(self, params: dict[str, Any]) -> dict[str, Any]: + requested = str(params.get("protocolVersion") or "") + protocol_version = negotiate_protocol_version(requested) + self.session.protocol_version = protocol_version + return { + "protocolVersion": protocol_version, + "capabilities": {"tools": {"listChanged": False}}, + "serverInfo": {"name": MCP_SERVER_NAME, "version": package_version()}, + } + + def _call_tool(self, params: dict[str, Any]) -> dict[str, Any]: + return call_tool_result( + str(params.get("name", "")), + dict(params.get("arguments") or {}), + runtime=self.runtime, + ) + + +def negotiate_protocol_version(requested: str) -> str: + if requested in SUPPORTED_PROTOCOL_VERSIONS: + return requested + return LATEST_PROTOCOL_VERSION + + +def rpc_error(request_id: Any, code: int, message: str, data: dict[str, Any] | None = None) -> dict[str, Any]: + error: dict[str, Any] = {"code": code, "message": message} + if data is not None: + error["data"] = data + return {"jsonrpc": "2.0", "id": request_id, "error": error} diff --git a/src/codebase_graph/mcp/runtime.py b/src/codebase_graph/mcp/runtime.py new file mode 100644 index 0000000..70f44ce --- /dev/null +++ b/src/codebase_graph/mcp/runtime.py @@ -0,0 +1,53 @@ +from __future__ import annotations + +from dataclasses import dataclass +from importlib.metadata import PackageNotFoundError, version +from pathlib import Path +from typing import Any + +from codebase_graph.db import LadybugCodeGraphStore, create_ladybug_database +from codebase_graph.setup.state import derive_setup_paths, load_setup_config + + +@dataclass(frozen=True, slots=True) +class GraphRuntimeConfig: + repo_root: Path + db_path: Path + manifest_path: Path | None = None + + +def runtime_config( + *, + repo_root: str | Path, + config_path: str | Path | None, + db_path: str | Path | None, + manifest_path: str | Path | None, +) -> GraphRuntimeConfig: + root = Path(repo_root).expanduser().resolve() + config = Path(config_path).expanduser().resolve() if config_path else derive_setup_paths(root).config_path + payload: dict[str, Any] = {} + if config.exists(): + payload = load_setup_config(config) + root = Path(str(payload["repo_root"])).expanduser().resolve() + elif db_path is None: + raise FileNotFoundError(f"codebaseGraph setup config is missing: {config}") + resolved_db = Path(db_path or payload["database_path"]).expanduser().resolve() + resolved_manifest = ( + Path(manifest_path or payload.get("manifest_path", "")).expanduser().resolve() + if (manifest_path or payload.get("manifest_path")) + else None + ) + if not resolved_db.exists(): + raise FileNotFoundError(f"codebaseGraph database is missing: {resolved_db}") + return GraphRuntimeConfig(repo_root=root, db_path=resolved_db, manifest_path=resolved_manifest) + + +def open_graph_store(runtime: GraphRuntimeConfig) -> LadybugCodeGraphStore: + return create_ladybug_database(runtime.db_path, include_fts=True, read_only=True) + + +def package_version() -> str: + try: + return version("codebase-graph") + except PackageNotFoundError: + return "0.1.0" diff --git a/src/codebase_graph/mcp/server.py b/src/codebase_graph/mcp/server.py new file mode 100644 index 0000000..e1ffff5 --- /dev/null +++ b/src/codebase_graph/mcp/server.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +from .protocol import LATEST_PROTOCOL_VERSION, SUPPORTED_PROTOCOL_VERSIONS, McpGraphServer, negotiate_protocol_version +from .runtime import GraphRuntimeConfig +from .tools import handle_tool_call +from .transports.http import build_http_server, serve_http +from .transports.stdio import serve_stdio + + +def main() -> int: + import argparse + + parser = argparse.ArgumentParser(prog="codebase-graph-mcp") + parser.add_argument("--repo-root", default=".", help="Repository root containing .codebaseGraph/config.json") + parser.add_argument("--config", default=None, help="Path to .codebaseGraph/config.json") + parser.add_argument("--db", default=None, help="Override LadyBugDB path") + parser.add_argument("--manifest", default=None, help="Override manifest path") + args = parser.parse_args() + serve_stdio(repo_root=args.repo_root, config_path=args.config, db_path=args.db, manifest_path=args.manifest) + return 0 + + +__all__ = [ + "LATEST_PROTOCOL_VERSION", + "SUPPORTED_PROTOCOL_VERSIONS", + "GraphRuntimeConfig", + "McpGraphServer", + "build_http_server", + "handle_tool_call", + "negotiate_protocol_version", + "serve_http", + "serve_stdio", +] diff --git a/src/codebase_graph/mcp/tools.py b/src/codebase_graph/mcp/tools.py new file mode 100644 index 0000000..0d98722 --- /dev/null +++ b/src/codebase_graph/mcp/tools.py @@ -0,0 +1,257 @@ +from __future__ import annotations + +import json +import re +from typing import Any + +from codebase_graph.db import LadybugCodeGraphStore +from codebase_graph.diagnostics import log_event +from codebase_graph.ontology import QUERY_HELPERS, schema_payload +from codebase_graph.reasoning import CompactContextBuilder, architecture_query_catalog +from codebase_graph.retrieval import DETAIL_LEVELS, SearchRequest, SearchService, serialize_graph_block + +from .graph_commands import MAX_GRAPH_QUERY_LIMIT, graph_tool_specs +from .runtime import GraphRuntimeConfig, open_graph_store + +READ_ONLY_DENY_RE = re.compile( + r"\b(" + r"ALTER|ATTACH|CALL|COPY|CREATE|DELETE|DETACH|DROP|EXPORT|IMPORT|INSERT|INSTALL|LOAD|MERGE|REMOVE|RENAME|SET|" + r"TRUNCATE|UPDATE|USE" + r")\b", + re.IGNORECASE, +) + + +class UnknownToolError(ValueError): + pass + + +def handle_tool_call(name: str, arguments: dict[str, Any], *, runtime: GraphRuntimeConfig | None) -> dict[str, Any]: + if name == "graph_health": + return _health(runtime) + if name == "graph_schema": + return schema_payload() + if name == "graph_query_helpers": + return {"query_helpers": [helper.as_dict() for helper in QUERY_HELPERS]} + if name == "graph_architecture_queries": + return architecture_query_catalog(group=_optional_str(arguments.get("group"))) + if name == "graph_search": + with open_graph_store(_require_runtime(runtime, name)) as store: + request = _search_request(arguments) + return SearchService(store).search(request).as_dict(detail=request.detail) + if name == "graph_context": + with open_graph_store(_require_runtime(runtime, name)) as store: + return _context_payload(store, arguments) + if name == "graph_query": + with open_graph_store(_require_runtime(runtime, name)) as store: + return _query_payload(store, arguments) + raise UnknownToolError(f"Unknown codebaseGraph MCP tool: {name}") + + +def call_tool_result(name: str, arguments: dict[str, Any], *, runtime: GraphRuntimeConfig) -> dict[str, Any]: + try: + payload = handle_tool_call(name, arguments, runtime=runtime) + return tool_result(name, payload, arguments) + except UnknownToolError: + raise + except Exception as exc: + return tool_error_result(name, exc) + + +def _require_runtime(runtime: GraphRuntimeConfig | None, tool_name: str) -> GraphRuntimeConfig: + if runtime is None: + raise ValueError(f"{tool_name} requires a graph runtime") + return runtime + + +def tool_result(name: str, payload: dict[str, Any], arguments: dict[str, Any] | None = None) -> dict[str, Any]: + text = json.dumps(payload, separators=(",", ":"), sort_keys=True) + if name in {"graph_search", "graph_context"} and _output_format(arguments or {}) == "block": + text = serialize_graph_block(payload) + return { + "content": [{"type": "text", "text": text}], + "structuredContent": payload, + "isError": False, + } + + +def tool_error_result(name: str, exc: Exception) -> dict[str, Any]: + log_event( + "mcp.tool_error", + level="WARNING", + tool=name, + error_type=exc.__class__.__name__, + message=str(exc), + ) + payload = { + "error": { + "tool": name, + "type": exc.__class__.__name__, + "message": str(exc), + } + } + return { + "content": [{"type": "text", "text": f"{name} failed: {exc}"}], + "structuredContent": payload, + "isError": True, + } + + +def tool_specs() -> list[dict[str, Any]]: + return graph_tool_specs() + + +def _health(runtime: GraphRuntimeConfig) -> dict[str, Any]: + payload: dict[str, Any] = { + "ok": False, + "repo_root": runtime.repo_root.as_posix(), + "database_path": runtime.db_path.as_posix(), + "manifest_path": runtime.manifest_path.as_posix() if runtime.manifest_path else None, + "database_exists": runtime.db_path.exists(), + "manifest_exists": runtime.manifest_path.exists() if runtime.manifest_path else None, + } + if not runtime.db_path.exists(): + return payload + try: + with open_graph_store(runtime) as store: + rows = store.execute("MATCH (n) RETURN count(n) AS total_nodes LIMIT 1").get_n(1) + except Exception as exc: + payload["graph_readable"] = False + payload["error"] = {"type": exc.__class__.__name__, "message": str(exc)} + log_event( + "mcp.graph_health_failed", + level="WARNING", + database_path=runtime.db_path.as_posix(), + error_type=exc.__class__.__name__, + message=str(exc), + ) + return payload + payload["ok"] = True + payload["graph_readable"] = True + payload["total_nodes"] = _json_safe(rows[0][0]) if rows and rows[0] else 0 + return payload + + +def _search_request(arguments: dict[str, Any]) -> SearchRequest: + request = SearchRequest( + query=str(arguments.get("query", "")), + limit=int(arguments.get("limit", 3)), + profile=str(arguments.get("profile", "brief")), + budget=int(arguments.get("budget", 600)), + max_depth=_optional_int(arguments.get("max_depth")), + context_limit=int(arguments.get("context_limit", 3)), + detail=_detail(arguments), + ) + request.validate() + return request + + +def _context_payload(store: LadybugCodeGraphStore, arguments: dict[str, Any]) -> dict[str, Any]: + node_id = str(arguments.get("node_id") or "") + node_type = str(arguments.get("node_type") or "") + if node_id and node_type: + profile = str(arguments.get("profile", "brief")) + detail = _detail(arguments) + context = CompactContextBuilder(store).build( + node_id, + node_type, + profile=profile, + limit=int(arguments.get("limit", 3)), + budget=int(arguments.get("budget", 600)), + max_depth=_optional_int(arguments.get("max_depth")), + ) + return { + "node_id": node_id, + "node_type": node_type, + "profile": profile, + "context": [node.as_dict(detail=detail) for node in context], + } + request = _search_request(arguments) + return SearchService(store).search(request).as_dict(detail=request.detail) + + +def _query_payload(store: LadybugCodeGraphStore, arguments: dict[str, Any]) -> dict[str, Any]: + statement = str(arguments.get("statement") or arguments.get("query") or "").strip() + if not statement: + raise ValueError("graph_query requires a non-empty statement") + _validate_read_only_statement(statement) + parameters = arguments.get("parameters") or {} + if not isinstance(parameters, dict): + raise ValueError("graph_query parameters must be a JSON object") + limit = _graph_query_limit(arguments) + result = store.execute(statement, parameters) + try: + rows = result.get_n(limit + 1) + finally: + close = getattr(result, "close", None) + if callable(close): + close() + visible_rows = rows[:limit] + return { + "statement": statement, + "row_count": len(visible_rows), + "rows": [_row_values(row) for row in visible_rows], + "truncated": len(rows) > limit, + } + + +def _validate_read_only_statement(statement: str) -> None: + stripped = statement.strip().rstrip(";") + if ";" in stripped: + raise ValueError("graph_query accepts one read-only statement at a time") + match = READ_ONLY_DENY_RE.search(stripped) + if match is not None: + raise ValueError(f"graph_query is read-only; blocked keyword: {match.group(1).upper()}") + + +def _graph_query_limit(arguments: dict[str, Any]) -> int: + limit = int(arguments.get("limit", 100)) + if limit <= 0: + raise ValueError("graph_query limit must be greater than zero") + if limit > MAX_GRAPH_QUERY_LIMIT: + raise ValueError(f"graph_query limit must be {MAX_GRAPH_QUERY_LIMIT} or less") + return limit + + +def _row_values(row: Any) -> list[Any]: + try: + return [_json_safe(value) for value in row] + except TypeError: + return [_json_safe(row)] + + +def _json_safe(value: Any) -> Any: + if value is None or isinstance(value, (str, int, float, bool)): + return value + if isinstance(value, (list, tuple)): + return [_json_safe(item) for item in value] + if isinstance(value, dict): + return {str(key): _json_safe(item) for key, item in value.items()} + return str(value) + + +def _optional_int(value: Any) -> int | None: + if value is None or value == "": + return None + return int(value) + + +def _optional_str(value: Any) -> str | None: + if value is None or value == "": + return None + return str(value) + + +def _detail(arguments: dict[str, Any]) -> str: + detail = str(arguments.get("detail", "standard")) + if detail not in DETAIL_LEVELS: + valid = ", ".join(sorted(DETAIL_LEVELS)) + raise ValueError(f"Unknown detail level: {detail}. Valid levels: {valid}") + return detail + + +def _output_format(arguments: dict[str, Any]) -> str: + output_format = str(arguments.get("output_format", "json")) + if output_format not in {"json", "block"}: + raise ValueError(f"Unknown output format: {output_format}. Valid formats: block, json") + return output_format diff --git a/src/codebase_graph/mcp/transports/__init__.py b/src/codebase_graph/mcp/transports/__init__.py new file mode 100644 index 0000000..8d63218 --- /dev/null +++ b/src/codebase_graph/mcp/transports/__init__.py @@ -0,0 +1 @@ +"""Transport implementations for the codebaseGraph MCP server.""" diff --git a/src/codebase_graph/mcp/transports/http.py b/src/codebase_graph/mcp/transports/http.py new file mode 100644 index 0000000..af1db03 --- /dev/null +++ b/src/codebase_graph/mcp/transports/http.py @@ -0,0 +1,265 @@ +from __future__ import annotations + +import secrets +import json +from http import HTTPStatus +from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer +from pathlib import Path +from typing import Any +from urllib.parse import urlparse + +from codebase_graph.diagnostics import log_event +from codebase_graph.mcp.protocol import SUPPORTED_PROTOCOL_VERSIONS, McpGraphServer, rpc_error +from codebase_graph.mcp.runtime import GraphRuntimeConfig, runtime_config + +LOCAL_ORIGINS = {"localhost", "127.0.0.1", "::1"} +MAX_HTTP_BODY_BYTES = 1_000_000 + + +class McpHttpServer(ThreadingHTTPServer): + def __init__(self, server_address: tuple[str, int], handler: type[BaseHTTPRequestHandler]) -> None: + super().__init__(server_address, handler) + self.mcp_runtime: GraphRuntimeConfig + self.mcp_sessions: dict[str, McpGraphServer] + self.endpoint_path: str + self.auth_token: str | None + + +def build_http_server( + *, + repo_root: str | Path = ".", + config_path: str | Path | None = None, + db_path: str | Path | None = None, + manifest_path: str | Path | None = None, + host: str = "127.0.0.1", + port: int = 8765, + endpoint_path: str = "/mcp", + allow_remote: bool = False, + auth_token: str | None = None, +) -> McpHttpServer: + if auth_token is not None and not auth_token.strip(): + raise ValueError("MCP HTTP auth token must not be blank") + if allow_remote and auth_token is None: + log_event("mcp.http_remote_bind_rejected", level="WARNING", host=host, port=port) + raise ValueError("MCP HTTP remote bind requires an auth token") + if not allow_remote and host not in LOCAL_ORIGINS: + log_event("mcp.http_remote_bind_rejected", level="WARNING", host=host, port=port) + raise ValueError("MCP HTTP transport may only bind to localhost unless allow_remote is enabled") + graph_runtime = runtime_config( + repo_root=repo_root, + config_path=config_path, + db_path=db_path, + manifest_path=manifest_path, + ) + httpd = McpHttpServer((host, port), _McpHttpHandler) + httpd.mcp_runtime = graph_runtime + httpd.mcp_sessions = {} + httpd.endpoint_path = endpoint_path + httpd.auth_token = auth_token + return httpd + + +def serve_http( + *, + repo_root: str | Path = ".", + config_path: str | Path | None = None, + db_path: str | Path | None = None, + manifest_path: str | Path | None = None, + host: str = "127.0.0.1", + port: int = 8765, + endpoint_path: str = "/mcp", + allow_remote: bool = False, + auth_token: str | None = None, +) -> None: + server = build_http_server( + repo_root=repo_root, + config_path=config_path, + db_path=db_path, + manifest_path=manifest_path, + host=host, + port=port, + endpoint_path=endpoint_path, + allow_remote=allow_remote, + auth_token=auth_token, + ) + try: + server.serve_forever() + finally: + server.server_close() + + +class _McpHttpHandler(BaseHTTPRequestHandler): + server: McpHttpServer + + def do_POST(self) -> None: + if not self._request_path_matches() or not self._valid_origin() or not self._valid_auth(): + return + if not self._valid_protocol_header(): + return + length = self._content_length() + if length is None: + return + try: + message = json.loads(self.rfile.read(length).decode("utf-8")) + except Exception as exc: + log_event("mcp.http_parse_error", level="WARNING", message=str(exc), client_address=self.client_address[0]) + self._send_json(rpc_error(None, -32700, f"Invalid JSON-RPC payload: {exc}"), status=HTTPStatus.BAD_REQUEST) + return + if not isinstance(message, dict): + self._send_json(rpc_error(None, -32600, "JSON-RPC payload must be an object"), status=HTTPStatus.BAD_REQUEST) + return + session_id, server = self._resolve_session(message) + if server is None: + return + response = server.handle_json_rpc(message) + if response is None: + self.send_response(HTTPStatus.ACCEPTED) + self.end_headers() + return + headers = {"Mcp-Session-Id": session_id} if str(message.get("method", "")) == "initialize" else None + self._send_json(response, headers=headers) + + def do_GET(self) -> None: + if not self._request_path_matches() or not self._valid_origin() or not self._valid_auth(): + return + self.send_response(HTTPStatus.METHOD_NOT_ALLOWED) + self.send_header("Allow", "POST") + self.end_headers() + + def log_message(self, format: str, *args: Any) -> None: + return + + def _resolve_session(self, message: dict[str, Any]) -> tuple[str, McpGraphServer] | tuple[None, None]: + method = str(message.get("method", "")) + request_id = message.get("id") + session_id = self.headers.get("Mcp-Session-Id") + if method == "initialize": + if session_id and session_id in self.server.mcp_sessions: + return session_id, self.server.mcp_sessions[session_id] + session_id = secrets.token_urlsafe(32) + server = McpGraphServer(self.server.mcp_runtime) + self.server.mcp_sessions[session_id] = server + return session_id, server + if not session_id or session_id not in self.server.mcp_sessions: + self._send_json( + rpc_error(request_id, -32002, "MCP session is not initialized"), + status=HTTPStatus.BAD_REQUEST, + ) + return None, None + return session_id, self.server.mcp_sessions[session_id] + + def _request_path_matches(self) -> bool: + if urlparse(self.path).path == self.server.endpoint_path: + return True + self._send_json(rpc_error(None, -32601, "MCP endpoint not found"), status=HTTPStatus.NOT_FOUND) + return False + + def _valid_origin(self) -> bool: + origin = self.headers.get("Origin") + if not origin: + return True + hostname = urlparse(origin).hostname + if hostname in LOCAL_ORIGINS: + return True + log_event( + "mcp.http_forbidden_origin", + level="WARNING", + origin=origin, + client_address=self.client_address[0], + ) + self._send_json(rpc_error(None, -32000, "Forbidden origin"), status=HTTPStatus.FORBIDDEN) + return False + + def _valid_auth(self) -> bool: + if self.server.auth_token is None: + return True + authorization = self.headers.get("Authorization", "") + prefix = "Bearer " + if authorization.startswith(prefix) and secrets.compare_digest(authorization[len(prefix) :], self.server.auth_token): + return True + log_event( + "mcp.http_unauthorized", + level="WARNING", + client_address=self.client_address[0], + ) + self._send_json( + rpc_error(None, -32000, "Unauthorized"), + status=HTTPStatus.UNAUTHORIZED, + headers={"WWW-Authenticate": "Bearer"}, + ) + return False + + def _valid_protocol_header(self) -> bool: + requested = self.headers.get("MCP-Protocol-Version") + if requested is None: + return True + if requested in SUPPORTED_PROTOCOL_VERSIONS: + return True + log_event( + "mcp.http_unsupported_protocol", + level="WARNING", + requested=requested, + client_address=self.client_address[0], + ) + self._send_json( + rpc_error( + None, + -32602, + "Unsupported MCP protocol version", + {"supported": list(SUPPORTED_PROTOCOL_VERSIONS), "requested": requested}, + ), + status=HTTPStatus.BAD_REQUEST, + ) + return False + + def _content_length(self) -> int | None: + raw_length = self.headers.get("Content-Length", "0") + try: + length = int(raw_length) + except ValueError: + log_event( + "mcp.http_invalid_content_length", + level="WARNING", + content_length=raw_length, + client_address=self.client_address[0], + ) + self._send_json(rpc_error(None, -32600, "Content-Length must be an integer"), status=HTTPStatus.BAD_REQUEST) + return None + if length < 0: + log_event( + "mcp.http_invalid_content_length", + level="WARNING", + content_length=raw_length, + client_address=self.client_address[0], + ) + self._send_json(rpc_error(None, -32600, "Content-Length must be non-negative"), status=HTTPStatus.BAD_REQUEST) + return None + if length > MAX_HTTP_BODY_BYTES: + log_event( + "mcp.http_body_too_large", + level="WARNING", + content_length=length, + client_address=self.client_address[0], + ) + self._send_json( + rpc_error(None, -32000, "MCP request body is too large", {"max_bytes": MAX_HTTP_BODY_BYTES}), + status=HTTPStatus.REQUEST_ENTITY_TOO_LARGE, + ) + return None + return length + + def _send_json( + self, + payload: dict[str, Any], + *, + status: HTTPStatus = HTTPStatus.OK, + headers: dict[str, str] | None = None, + ) -> None: + body = json.dumps(payload, separators=(",", ":"), sort_keys=True).encode("utf-8") + self.send_response(status) + self.send_header("Content-Type", "application/json") + self.send_header("Content-Length", str(len(body))) + for name, value in (headers or {}).items(): + self.send_header(name, value) + self.end_headers() + self.wfile.write(body) diff --git a/src/codebase_graph/mcp/transports/stdio.py b/src/codebase_graph/mcp/transports/stdio.py new file mode 100644 index 0000000..8736841 --- /dev/null +++ b/src/codebase_graph/mcp/transports/stdio.py @@ -0,0 +1,81 @@ +from __future__ import annotations + +import json +import sys +from pathlib import Path +from typing import Any, BinaryIO + +from codebase_graph.diagnostics import log_event +from codebase_graph.mcp.protocol import McpGraphServer, rpc_error + + +class StdioMessageError(ValueError): + pass + + +def serve_stdio( + *, + repo_root: str | Path = ".", + config_path: str | Path | None = None, + db_path: str | Path | None = None, + manifest_path: str | Path | None = None, +) -> None: + server = McpGraphServer.from_paths( + repo_root=repo_root, + config_path=config_path, + db_path=db_path, + manifest_path=manifest_path, + ) + while True: + try: + message = read_message(sys.stdin.buffer) + except StdioMessageError as exc: + log_event("mcp.stdio_parse_error", level="WARNING", message=str(exc)) + write_message(sys.stdout.buffer, rpc_error(None, -32700, f"Invalid JSON-RPC payload: {exc}")) + continue + if message is None: + return + response = server.handle_json_rpc(message) + if response is not None: + write_message(sys.stdout.buffer, response) + + +def read_message(stream: BinaryIO) -> dict[str, Any] | None: + line = stream.readline() + if not line: + return None + if line.lower().startswith(b"content-length:"): + try: + length = int(line.split(b":", 1)[1].strip()) + except ValueError as exc: + raise StdioMessageError("Content-Length must be an integer") from exc + if length < 0: + raise StdioMessageError("Content-Length must be non-negative") + while True: + header = stream.readline() + if header in {b"\r\n", b"\n", b""}: + break + body = stream.read(length) + if len(body) != length: + raise StdioMessageError("Body ended before Content-Length bytes were read") + return _json_rpc_payload(body) + return _json_rpc_payload(line) + + +def write_message(stream: BinaryIO, message: dict[str, Any]) -> None: + body = json.dumps(message, separators=(",", ":"), sort_keys=True).encode("utf-8") + stream.write(f"Content-Length: {len(body)}\r\n\r\n".encode("ascii")) + stream.write(body) + stream.flush() + + +def _json_rpc_payload(data: bytes) -> dict[str, Any]: + try: + payload = json.loads(data.decode("utf-8")) + except UnicodeDecodeError as exc: + raise StdioMessageError(f"Body must be UTF-8: {exc}") from exc + except json.JSONDecodeError as exc: + raise StdioMessageError(str(exc)) from exc + if not isinstance(payload, dict): + raise StdioMessageError("JSON-RPC payload must be an object") + return payload diff --git a/src/codebase_graph/memory/__init__.py b/src/codebase_graph/memory/__init__.py new file mode 100644 index 0000000..7b38a52 --- /dev/null +++ b/src/codebase_graph/memory/__init__.py @@ -0,0 +1 @@ +"""Memory store, recall, update, and consolidation workflows.""" diff --git a/src/codebase_graph/ontology.py b/src/codebase_graph/ontology.py deleted file mode 100644 index 320ec88..0000000 --- a/src/codebase_graph/ontology.py +++ /dev/null @@ -1,79 +0,0 @@ -from __future__ import annotations - -ONTOLOGY_NAME = "codebase_graph_v1" - -NODE_TABLES = ( - "Project", - "Repository", - "File", - "DocumentationSource", - "PythonModule", - "PythonClass", - "PythonFunction", - "PythonMethod", - "Import", - "Call", - "Dependency", - "EntryPoint", - "Test", - "Risk", - "Verification", -) - -EDGE_NODE_TABLES = ( - "Contains", - "Imports", - "Calls", - "DependsOn", - "Defines", - "CoveredBy", - "RoutesTo", - "Describes", - "Produces", -) - -TABLE_COLUMNS: dict[str, tuple[str, ...]] = { - table: ( - "id", - "label", - "kind", - "path", - "qualified_name", - "module_name", - "line_start", - "line_end", - "summary", - "metadata", - ) - for table in NODE_TABLES -} - -VECTOR_INDEXES: tuple[tuple[str, str, str], ...] = () -FTS_INDEXES: tuple[tuple[str, str, tuple[str, ...]], ...] = ( - ("File", "idx_file_text", ("label", "path", "summary")), - ("PythonClass", "idx_python_class_text", ("label", "qualified_name", "summary")), - ("PythonFunction", "idx_python_function_text", ("label", "qualified_name", "summary")), - ("PythonMethod", "idx_python_method_text", ("label", "qualified_name", "summary")), - ("DocumentationSource", "idx_documentation_source_text", ("label", "path", "summary")), -) - -def schema_payload() -> dict[str, object]: - return { - "ontology": ONTOLOGY_NAME, - "node_tables": list(NODE_TABLES), - "edge_tables": list(EDGE_NODE_TABLES), - "table_columns": {name: list(columns) for name, columns in TABLE_COLUMNS.items()}, - "vector_indexes": [ - {"table": table, "index": index, "column": column} - for table, index, column in VECTOR_INDEXES - ], - "fts_indexes": [ - {"table": table, "index": index, "columns": list(columns)} - for table, index, columns in FTS_INDEXES - ], - "examples": [ - "MATCH (n:PythonClass) RETURN n.id, n.label, n.qualified_name LIMIT 5", - "MATCH (n:File) RETURN n.path, n.summary LIMIT 10", - "MATCH (n:EntryPoint) RETURN n.label, n.kind, n.path LIMIT 5", - ], - } diff --git a/src/codebase_graph/ontology/__init__.py b/src/codebase_graph/ontology/__init__.py new file mode 100644 index 0000000..dbb9244 --- /dev/null +++ b/src/codebase_graph/ontology/__init__.py @@ -0,0 +1,47 @@ +"""Language-neutral code graph ontology.""" + +from .ontology import ( + COMMON_NODE_FIELDS, + CONTEXT_PROFILES, + EDGE_FIELDS, + ONTOLOGY_NAME, + ONTOLOGY_VERSION, + NODE_TYPES, + PARSER_NODE_MAPPINGS, + QUERY_HELPERS, + RELATION_TYPES, + SEARCH_INDEXES, + FieldSpec, + NodeTypeSpec, + ParserNodeMappingSpec, + QueryHelperSpec, + RelationTypeSpec, + get_node_type, + get_relation_type, + node_type_names, + relation_type_names, + schema_payload, +) + +__all__ = [ + "COMMON_NODE_FIELDS", + "CONTEXT_PROFILES", + "EDGE_FIELDS", + "ONTOLOGY_NAME", + "ONTOLOGY_VERSION", + "NODE_TYPES", + "PARSER_NODE_MAPPINGS", + "QUERY_HELPERS", + "RELATION_TYPES", + "SEARCH_INDEXES", + "FieldSpec", + "NodeTypeSpec", + "ParserNodeMappingSpec", + "QueryHelperSpec", + "RelationTypeSpec", + "get_node_type", + "get_relation_type", + "node_type_names", + "relation_type_names", + "schema_payload", +] diff --git a/src/codebase_graph/ontology/ontology.py b/src/codebase_graph/ontology/ontology.py new file mode 100644 index 0000000..d7e9b50 --- /dev/null +++ b/src/codebase_graph/ontology/ontology.py @@ -0,0 +1,800 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +ONTOLOGY_NAME = "code_ontology_v1" +ONTOLOGY_VERSION = "1.0.0" + + +@dataclass(frozen=True, slots=True) +class FieldSpec: + name: str + value_type: str + description: str + required: bool = False + + def as_dict(self) -> dict[str, Any]: + return { + "name": self.name, + "type": self.value_type, + "description": self.description, + "required": self.required, + } + + +@dataclass(frozen=True, slots=True) +class NodeTypeSpec: + name: str + description: str + fields: tuple[FieldSpec, ...] = () + parser_node_types: tuple[str, ...] = () + constraints: tuple[str, ...] = () + + def as_dict(self) -> dict[str, Any]: + return { + "name": self.name, + "description": self.description, + "fields": [field.as_dict() for field in self.fields], + "parser_node_types": list(self.parser_node_types), + "constraints": list(self.constraints), + } + + +@dataclass(frozen=True, slots=True) +class RelationTypeSpec: + name: str + source_types: tuple[str, ...] + target_types: tuple[str, ...] + description: str + fields: tuple[FieldSpec, ...] = () + constraints: tuple[str, ...] = () + + def as_dict(self) -> dict[str, Any]: + return { + "name": self.name, + "source_types": list(self.source_types), + "target_types": list(self.target_types), + "description": self.description, + "fields": [field.as_dict() for field in self.fields], + "constraints": list(self.constraints), + } + + +@dataclass(frozen=True, slots=True) +class ParserNodeMappingSpec: + name: str + parser_node_types: tuple[str, ...] + captures: tuple[str, ...] + target_node_types: tuple[str, ...] + relation_types: tuple[str, ...] + description: str + context_rule: str = "" + + def as_dict(self) -> dict[str, Any]: + return { + "name": self.name, + "parser_node_types": list(self.parser_node_types), + "captures": list(self.captures), + "target_node_types": list(self.target_node_types), + "relation_types": list(self.relation_types), + "description": self.description, + "context_rule": self.context_rule, + } + + +@dataclass(frozen=True, slots=True) +class QueryHelperSpec: + name: str + description: str + query: str + parameters: tuple[str, ...] = () + returns: tuple[str, ...] = () + + def as_dict(self) -> dict[str, Any]: + return { + "name": self.name, + "description": self.description, + "query": self.query, + "parameters": list(self.parameters), + "returns": list(self.returns), + } + + +COMMON_NODE_FIELDS = ( + FieldSpec("id", "string", "Stable unique node identifier.", True), + FieldSpec("label", "string", "Short human-readable node label.", True), + FieldSpec("kind", "string", "Ontology-specific subtype or parser-derived role."), + FieldSpec("language", "string", "Source language when the node is code-derived."), + FieldSpec("path", "string", "Repository-relative source path."), + FieldSpec("qualified_name", "string", "Best-effort language-neutral qualified name."), + FieldSpec("scope_id", "string", "Containing lexical or semantic scope id."), + FieldSpec("line_start", "integer", "One-based start line in the source file."), + FieldSpec("line_end", "integer", "One-based end line in the source file."), + FieldSpec("byte_start", "integer", "Zero-based start byte in the source file."), + FieldSpec("byte_end", "integer", "Zero-based end byte in the source file."), + FieldSpec("tree_sitter_node_type", "string", "Raw parser node type that produced this node."), + FieldSpec("capture_name", "string", "Tree-sitter query capture name when available."), + FieldSpec("summary", "string", "Compact text summary used for context assembly."), + FieldSpec("metadata", "json", "Structured extractor-specific details."), +) + +EDGE_FIELDS = ( + FieldSpec("id", "string", "Stable unique relation identifier.", True), + FieldSpec("kind", "string", "Relation subtype or evidence role."), + FieldSpec("source_id", "string", "Source node id.", True), + FieldSpec("target_id", "string", "Target node id.", True), + FieldSpec("confidence", "number", "Extractor confidence between 0 and 1."), + FieldSpec("line_start", "integer", "One-based evidence start line."), + FieldSpec("line_end", "integer", "One-based evidence end line."), + FieldSpec("byte_start", "integer", "Zero-based evidence start byte."), + FieldSpec("byte_end", "integer", "Zero-based evidence end byte."), + FieldSpec("metadata", "json", "Structured relation evidence and resolver details."), +) + + +def _node( + name: str, + description: str, + *, + parser_node_types: tuple[str, ...] = (), + fields: tuple[FieldSpec, ...] = (), + constraints: tuple[str, ...] = (), +) -> NodeTypeSpec: + return NodeTypeSpec( + name=name, + description=description, + fields=COMMON_NODE_FIELDS + fields, + parser_node_types=parser_node_types, + constraints=constraints, + ) + + +NODE_TYPES = ( + _node("Repository", "A version-controlled repository or source tree boundary."), + _node("SourceRoot", "A configured root scanned for source, docs, manifests, and generated evidence."), + _node( + "File", + "A source, manifest, configuration, or documentation file.", + fields=( + FieldSpec("content_hash", "string", "Hash of file content at extraction time."), + FieldSpec("size_bytes", "integer", "File size in bytes at extraction time."), + ), + ), + _node( + "Module", + "A language-level compilation or namespace unit derived from a source file.", + parser_node_types=("module", "program", "source_file", "Module"), + ), + _node( + "ImportDeclaration", + "An import/include/use/require declaration.", + parser_node_types=( + "import_statement", + "import_from_statement", + "import_declaration", + "Import", + "ImportFrom", + ), + fields=(FieldSpec("imported_name", "string", "Imported module, package, symbol, or path."),), + ), + _node( + "ExportDeclaration", + "An exported symbol or module boundary declaration.", + parser_node_types=("export_statement", "export_clause", "export_declaration"), + ), + _node("Symbol", "A named code artifact when the exact semantic subtype is unresolved."), + _node("Scope", "A lexical or semantic boundary for name resolution."), + _node( + "Class", + "A class, struct, trait, interface, enum class, or similar type container.", + parser_node_types=("class_definition", "class_declaration", "struct_item", "ClassDef"), + ), + _node( + "Function", + "A standalone function, lambda with stable name, or callable declaration.", + parser_node_types=("function_definition", "function_declaration", "FunctionDef"), + ), + _node( + "Method", + "A function declared inside a class, trait, component, or object scope.", + parser_node_types=("method_definition", "method_declaration", "FunctionDef"), + ), + _node("Parameter", "A callable parameter.", parser_node_types=("parameter", "typed_parameter", "arg")), + _node("ReturnType", "A callable return type annotation.", parser_node_types=("return_type", "returns")), + _node( + "TypeAnnotation", + "A type annotation attached to a symbol, parameter, assignment, or return value.", + parser_node_types=("type", "type_identifier", "type_annotation", "annotation"), + ), + _node("TypeAlias", "A named alias for a type expression.", parser_node_types=("type_alias", "type_alias_declaration")), + _node("Variable", "A mutable or local named binding.", parser_node_types=("variable_declaration", "Name")), + _node("Constant", "A named binding treated as stable or immutable by convention or syntax."), + _node("ClassAttribute", "A class-level attribute or static field.", parser_node_types=("AnnAssign", "field_declaration")), + _node("InstanceAttribute", "An instance-level attribute or field assignment."), + _node("Property", "A computed or decorated property exposed as an attribute."), + _node("Decorator", "A decorator, annotation, macro, or attribute attached to a declaration."), + _node("CallExpression", "A call, constructor invocation, message send, or macro invocation.", parser_node_types=("call", "Call")), + _node("Assignment", "An assignment, binding, or destructuring declaration.", parser_node_types=("assignment", "Assign", "AnnAssign")), + _node("Reference", "A name or member reference that may resolve to another node.", parser_node_types=("identifier", "Name")), + _node("Literal", "A literal value from source code.", parser_node_types=("string", "integer", "float", "Constant")), + _node("Expression", "A non-literal expression worth preserving for context or reasoning."), + _node("ControlFlowBlock", "A branch, loop, match, switch, or guard block."), + _node("ExceptionFlow", "A raise, throw, try, catch, except, rescue, or finally flow unit."), + _node("APIEndpoint", "A network, RPC, CLI, event, or message endpoint exposed by code."), + _node("Component", "A UI, service, package, or runtime component represented in source."), + _node("Route", "A route pattern, path binding, or router entry."), + _node("Query", "A database, search, analytics, or graph query string/expression."), + _node("SecretRef", "A reference to a secret, credential, token, key, or sensitive environment variable."), + _node( + "Dependency", + "An external package, library, framework, service, or runtime dependency.", + fields=( + FieldSpec("version", "string", "Declared version or version constraint."), + FieldSpec("ecosystem", "string", "Dependency ecosystem such as pypi, npm, cargo, or go."), + ), + ), + _node("DocumentationSource", "A documentation file or generated documentation artifact."), + _node("DocumentationChunk", "A chunk or heading-level section of documentation."), + _node( + "SyntaxCapture", + "Raw parser evidence preserving the concrete syntax node and capture name.", + fields=( + FieldSpec("sexp", "string", "Optional S-expression or compact parse-tree fragment."), + FieldSpec("text", "string", "Optional source text captured for this syntax node."), + ), + ), +) + + +def _relation( + name: str, + source_types: tuple[str, ...], + target_types: tuple[str, ...], + description: str, + *, + constraints: tuple[str, ...] = (), +) -> RelationTypeSpec: + return RelationTypeSpec( + name=name, + source_types=source_types, + target_types=target_types, + description=description, + fields=EDGE_FIELDS, + constraints=constraints, + ) + + +DECLARATION_NODES = ( + "Symbol", + "Class", + "Function", + "Method", + "Parameter", + "ReturnType", + "TypeAnnotation", + "TypeAlias", + "Variable", + "Constant", + "ClassAttribute", + "InstanceAttribute", + "Property", + "Decorator", + "Assignment", + "APIEndpoint", + "Component", + "Route", + "Query", + "SecretRef", +) + +EXPRESSION_NODES = ( + "CallExpression", + "Assignment", + "Reference", + "Literal", + "Expression", + "ControlFlowBlock", + "ExceptionFlow", + "Query", + "SecretRef", +) + +DOCUMENTATION_NODES = ("DocumentationSource", "DocumentationChunk") + +RELATION_TYPES = ( + _relation( + "Contains", + ("Repository", "SourceRoot", "File", "Module", "Scope", "Class", "Function", "Method", "Component"), + ( + "SourceRoot", + "File", + "Module", + "Scope", + "ImportDeclaration", + "ExportDeclaration", + *DECLARATION_NODES, + *EXPRESSION_NODES, + *DOCUMENTATION_NODES, + ), + "Structural containment between repository, files, scopes, declarations, and syntax-derived units.", + ), + _relation( + "Defines", + ("File", "Module", "Scope", "Class", "Function", "Method", "Component"), + DECLARATION_NODES, + "A file, module, scope, or component defines a semantic code node.", + ), + _relation( + "Imports", + ("File", "Module", "Scope"), + ("ImportDeclaration", "Dependency", "Module", "Symbol"), + "A source unit imports, includes, requires, or uses another unit.", + ), + _relation( + "Exports", + ("File", "Module", "Scope", "Component"), + ("ExportDeclaration", *DECLARATION_NODES), + "A source unit exports a declaration or public surface.", + ), + _relation( + "Declares", + ("Module", "Scope", "Class", "Function", "Method", "Assignment"), + DECLARATION_NODES, + "A declaration site introduces a named symbol or subordinate declaration.", + ), + _relation( + "HasScope", + ("File", "Module", *DECLARATION_NODES, *EXPRESSION_NODES), + ("Scope",), + "Connects a node to the lexical or semantic scope used for resolution.", + ), + _relation( + "HasParameter", + ("Function", "Method", "APIEndpoint", "Route", "CallExpression"), + ("Parameter",), + "Connects callables, endpoints, routes, or calls to their parameters or arguments.", + ), + _relation( + "HasReturnType", + ("Function", "Method", "APIEndpoint"), + ("ReturnType",), + "Connects callables or endpoints to their return type node.", + ), + _relation( + "HasTypeAnnotation", + ("Symbol", "Parameter", "ReturnType", "TypeAlias", "Variable", "Constant", "ClassAttribute", "InstanceAttribute"), + ("TypeAnnotation", "Reference", "Literal"), + "Connects a typed code node to its annotation expression.", + ), + _relation( + "Assigns", + ("Assignment", "Variable", "Constant", "ClassAttribute", "InstanceAttribute", "Property"), + ("Variable", "Constant", "ClassAttribute", "InstanceAttribute", "Property", "Literal", "Expression", "CallExpression"), + "Connects an assignment site or assigned symbol to the target or assigned value.", + ), + _relation( + "References", + ( + "Reference", + "Expression", + "CallExpression", + "Assignment", + "ControlFlowBlock", + "TypeAnnotation", + "Decorator", + "Query", + "SecretRef", + ), + ("Symbol", "Class", "Function", "Method", "Variable", "Constant", "ClassAttribute", "InstanceAttribute", "Property", "Parameter", "Module", "Dependency"), + "A source reference mentions another semantic node without necessarily resolving to it.", + ), + _relation( + "Calls", + ("Function", "Method", "CallExpression", "Decorator", "APIEndpoint", "Route", "Component"), + ("CallExpression", "Function", "Method", "Class", "APIEndpoint"), + "A callable or call expression invokes another callable-like target.", + ), + _relation( + "DecoratedBy", + ("Class", "Function", "Method", "Property", "APIEndpoint", "Route", "Component"), + ("Decorator", "CallExpression", "Reference"), + "A declaration is modified by a decorator, annotation, macro, or framework marker.", + ), + _relation( + "ResolvesTo", + ("Reference", "ImportDeclaration", "CallExpression", "TypeAnnotation", "Decorator"), + ("Symbol", "Module", "Class", "Function", "Method", "Variable", "Constant", "Dependency", "Parameter"), + "A resolver maps a syntactic reference to the best semantic target.", + ), + _relation( + "DependsOn", + ("Repository", "SourceRoot", "File", "Module", "ImportDeclaration", "Dependency", "Component"), + ("Dependency", "Module", "Component", "SecretRef"), + "A repository or code unit depends on an external or internal dependency.", + ), + _relation( + "Documents", + ("DocumentationSource", "DocumentationChunk", "Literal"), + ("Repository", "File", "Module", *DECLARATION_NODES), + "Documentation describes a repository, source unit, or semantic declaration.", + ), + _relation( + "RoutesTo", + ("Route", "APIEndpoint", "Component"), + ("APIEndpoint", "Function", "Method", "Component"), + "A route or component dispatches to an endpoint or handler.", + ), + _relation( + "Exposes", + ("Repository", "Module", "Component", "APIEndpoint", "Route"), + ("APIEndpoint", "Route", "Function", "Method", "Component", "ExportDeclaration"), + "A source unit exposes a public runtime or module surface.", + ), + _relation( + "ExecutesQuery", + ("Function", "Method", "CallExpression", "APIEndpoint", "Component"), + ("Query",), + "A code path executes or constructs a query.", + ), + _relation( + "UsesSecret", + ("Function", "Method", "CallExpression", "Component", "APIEndpoint", "Dependency"), + ("SecretRef",), + "A code path or dependency uses a secret or sensitive configuration value.", + ), + _relation( + "Raises", + ("Function", "Method", "CallExpression", "ControlFlowBlock"), + ("ExceptionFlow",), + "A code path raises or throws an exception flow.", + ), + _relation( + "Handles", + ("Function", "Method", "ControlFlowBlock", "ExceptionFlow"), + ("ExceptionFlow",), + "A code path handles or catches an exception flow.", + ), + _relation( + "DerivedFrom", + (*DECLARATION_NODES, *EXPRESSION_NODES, *DOCUMENTATION_NODES, "Module", "ImportDeclaration", "ExportDeclaration"), + ("SyntaxCapture",), + "A semantic node was derived from a raw parser capture.", + ), + _relation( + "EvidencedBy", + ("Repository", "File", "Module", *DECLARATION_NODES, *EXPRESSION_NODES, "Dependency", *DOCUMENTATION_NODES), + ("SyntaxCapture", "File", "DocumentationChunk"), + "A semantic claim is supported by parser, file, or documentation evidence.", + ), +) + +PARSER_NODE_MAPPINGS = ( + ParserNodeMappingSpec( + "module", + ("module", "program", "source_file", "Module"), + ("module", "source_file"), + ("Module",), + ("Contains", "Defines", "DerivedFrom"), + "Create one Module node per parser root or language namespace root.", + ), + ParserNodeMappingSpec( + "imports", + ("import_statement", "import_from_statement", "import_declaration", "Import", "ImportFrom"), + ("import", "reference.import", "reference.include", "reference.require", "reference.use"), + ("ImportDeclaration",), + ("Imports", "DependsOn", "DerivedFrom"), + "Normalize import-like declarations across languages and attach imported names as metadata.", + ), + ParserNodeMappingSpec( + "exports", + ("export_statement", "export_clause", "export_declaration"), + ("export", "definition.export"), + ("ExportDeclaration",), + ("Exports", "DerivedFrom"), + "Capture public export declarations and declarations marked as exported.", + ), + ParserNodeMappingSpec( + "classes", + ("class_definition", "class_declaration", "struct_item", "interface_declaration", "ClassDef"), + ("definition.class", "definition.struct", "definition.interface"), + ("Class",), + ("Defines", "Declares", "HasScope", "DecoratedBy", "DerivedFrom"), + "Map class-like containers to Class nodes with a child Scope.", + ), + ParserNodeMappingSpec( + "functions_and_methods", + ("function_definition", "function_declaration", "method_definition", "method_declaration", "FunctionDef"), + ("definition.function", "definition.method"), + ("Function", "Method"), + ("Defines", "Declares", "HasScope", "HasParameter", "HasReturnType", "DecoratedBy", "DerivedFrom"), + "Create Function for module-level callables and Method when the callable is enclosed by Class or Component.", + context_rule="enclosing Class or Component changes the target node from Function to Method", + ), + ParserNodeMappingSpec( + "parameters", + ("parameter", "typed_parameter", "default_parameter", "arg"), + ("definition.parameter", "parameter"), + ("Parameter",), + ("HasParameter", "HasTypeAnnotation", "DerivedFrom"), + "Create Parameter nodes for callable parameter declarations.", + ), + ParserNodeMappingSpec( + "return_types", + ("return_type", "type", "type_identifier", "returns"), + ("type.return", "return_type"), + ("ReturnType",), + ("HasReturnType", "HasTypeAnnotation", "References", "DerivedFrom"), + "Capture explicit return type annotations.", + ), + ParserNodeMappingSpec( + "type_annotations", + ("type", "type_identifier", "type_annotation", "annotation", "Name"), + ("type", "type.annotation", "reference.type"), + ("TypeAnnotation",), + ("HasTypeAnnotation", "References", "ResolvesTo", "DerivedFrom"), + "Capture type annotation expressions attached to declarations.", + ), + ParserNodeMappingSpec( + "type_aliases", + ("type_alias", "type_alias_declaration"), + ("definition.type_alias",), + ("TypeAlias",), + ("Defines", "HasTypeAnnotation", "DerivedFrom"), + "Capture named type aliases.", + ), + ParserNodeMappingSpec( + "assignments", + ("assignment", "assignment_expression", "variable_declaration", "Assign", "AnnAssign"), + ("definition.variable", "definition.constant", "assignment"), + ("Assignment", "Variable", "Constant", "ClassAttribute", "InstanceAttribute", "Property"), + ("Defines", "Declares", "Assigns", "HasTypeAnnotation", "DerivedFrom"), + "Normalize assignments; scope, naming convention, and receiver decide variable, constant, or attribute node type.", + ), + ParserNodeMappingSpec( + "decorators", + ("decorator", "attribute_item", "annotation", "Call", "Name"), + ("decorator", "definition.decorator"), + ("Decorator",), + ("DecoratedBy", "Calls", "References", "DerivedFrom"), + "Capture decorators, annotations, macros, or framework markers that modify declarations.", + ), + ParserNodeMappingSpec( + "calls", + ("call", "call_expression", "invocation_expression", "Call"), + ("reference.call", "call"), + ("CallExpression",), + ("Calls", "References", "ResolvesTo", "DerivedFrom"), + "Create call-expression nodes and optionally resolve them to callable targets.", + ), + ParserNodeMappingSpec( + "references", + ("identifier", "field_identifier", "attribute", "Name", "Attribute"), + ("reference", "reference.identifier", "reference.member"), + ("Reference",), + ("References", "ResolvesTo", "DerivedFrom"), + "Capture name and member references before or after semantic resolution.", + ), + ParserNodeMappingSpec( + "literals", + ("string", "integer", "float", "true", "false", "null", "none", "Constant"), + ("literal", "string", "number"), + ("Literal",), + ("Contains", "References", "DerivedFrom"), + "Capture literals that are useful for docs, routes, queries, secrets, or assignment values.", + ), + ParserNodeMappingSpec( + "control_flow", + ("if_statement", "for_statement", "while_statement", "match_statement", "switch_statement"), + ("control_flow",), + ("ControlFlowBlock",), + ("Contains", "References", "DerivedFrom"), + "Capture branch and loop blocks when they affect reasoning or dependency paths.", + ), + ParserNodeMappingSpec( + "exception_flow", + ("try_statement", "except_clause", "catch_clause", "raise_statement", "throw_statement"), + ("exception", "raises", "handles"), + ("ExceptionFlow",), + ("Raises", "Handles", "DerivedFrom"), + "Capture exception raising and handling paths.", + ), + ParserNodeMappingSpec( + "routes_and_endpoints", + ("decorator", "call", "route_declaration", "handler_definition"), + ("entrypoint.api", "route", "endpoint"), + ("APIEndpoint", "Route"), + ("RoutesTo", "Exposes", "DecoratedBy", "DerivedFrom"), + "Create APIEndpoint and Route nodes from framework route declarations or decorated handlers.", + ), + ParserNodeMappingSpec( + "components", + ("class_definition", "function_definition", "jsx_element", "component_declaration"), + ("definition.component", "component"), + ("Component",), + ("Defines", "Contains", "Exposes", "DerivedFrom"), + "Capture UI, service, runtime, or package components when extractor rules identify them.", + ), + ParserNodeMappingSpec( + "queries", + ("string", "template_string", "call", "Call"), + ("query.sql", "query.graph", "query.search"), + ("Query",), + ("ExecutesQuery", "References", "DerivedFrom"), + "Capture query strings or query builder expressions.", + ), + ParserNodeMappingSpec( + "secrets", + ("identifier", "string", "attribute", "Name", "Constant"), + ("secret", "secret.env", "secret.ref"), + ("SecretRef",), + ("UsesSecret", "References", "DerivedFrom"), + "Capture secret-looking names, environment references, keys, and credential handles.", + ), + ParserNodeMappingSpec( + "documentation", + ("comment", "string", "docstring", "DocumentationSource", "DocumentationChunk"), + ("doc", "doc.string", "doc.comment"), + ("DocumentationSource", "DocumentationChunk"), + ("Documents", "EvidencedBy"), + "Capture documentation sources and chunks from docs, comments, and docstrings.", + ), +) + +SEARCH_INDEXES = ( + {"name": "idx_code_symbols", "node_types": ["Symbol", "Class", "Function", "Method", "Variable", "Constant"], "fields": ["label", "qualified_name", "summary"]}, + {"name": "idx_source_units", "node_types": ["Repository", "SourceRoot", "File", "Module"], "fields": ["label", "path", "summary"]}, + {"name": "idx_dependencies", "node_types": ["ImportDeclaration", "Dependency"], "fields": ["label", "qualified_name", "summary"]}, + {"name": "idx_runtime_surface", "node_types": ["APIEndpoint", "Component", "Route", "Query", "SecretRef"], "fields": ["label", "qualified_name", "summary"]}, + {"name": "idx_docs", "node_types": ["DocumentationSource", "DocumentationChunk"], "fields": ["label", "path", "summary"]}, +) + +CONTEXT_PROFILES = { + "brief": { + "description": "Smallest useful context: matched nodes plus direct defining file/module.", + "relations": ["Contains", "Defines", "EvidencedBy"], + "max_depth": 1, + }, + "definitions": { + "description": "Definition-oriented context for symbols and scopes.", + "relations": ["Defines", "Declares", "HasScope", "HasParameter", "HasReturnType", "HasTypeAnnotation"], + "max_depth": 2, + }, + "dependencies": { + "description": "Import and dependency context.", + "relations": ["Imports", "DependsOn", "References", "ResolvesTo"], + "max_depth": 2, + }, + "callgraph": { + "description": "Callable neighborhood for callers, callees, and call expressions.", + "relations": ["Calls", "References", "ResolvesTo"], + "max_depth": 2, + }, + "runtime": { + "description": "Runtime surface context for routes, endpoints, queries, and secrets.", + "relations": ["RoutesTo", "Exposes", "ExecutesQuery", "UsesSecret"], + "max_depth": 2, + }, + "docs": { + "description": "Documentation context connected to code artifacts.", + "relations": ["Documents", "EvidencedBy"], + "max_depth": 1, + }, + "change_impact": { + "description": "Context for likely downstream impact of changing a symbol.", + "relations": ["Defines", "References", "Calls", "RoutesTo", "ExecutesQuery", "UsesSecret", "DependsOn"], + "max_depth": 3, + }, +} + +SYMBOL_LOOKUP_DEFINITION_TYPES = ( + "Class", + "Function", + "Method", + "Variable", + "Constant", + "ClassAttribute", + "InstanceAttribute", + "Property", + "Parameter", + "TypeAlias", +) + +SYMBOL_LOOKUP_QUERY = ( + " UNION ALL ".join( + f"MATCH (s:{node_type}) " + "WHERE s.label = $name OR s.qualified_name = $name " + "RETURN s.id, s.label, s.qualified_name, s.path" + for node_type in SYMBOL_LOOKUP_DEFINITION_TYPES + ) + + " LIMIT 25" +) + +QUERY_HELPERS = ( + QueryHelperSpec( + "repository_overview", + "List high-level source roots, files, modules, dependencies, and runtime surfaces.", + "MATCH (n) WHERE n:SourceRoot OR n:File OR n:Module OR n:Dependency OR n:APIEndpoint OR n:Component RETURN n.id, n.label, n.path LIMIT 100", + returns=("id", "label", "path"), + ), + QueryHelperSpec( + "symbol_lookup", + "Find concrete semantic definitions by label or qualified name.", + SYMBOL_LOOKUP_QUERY, + parameters=("name",), + returns=("id", "label", "qualified_name", "path"), + ), + QueryHelperSpec( + "definition_context", + "Find a named class, function, method, variable, or constant definition.", + "MATCH (d) WHERE d:Class OR d:Function OR d:Method OR d:Variable OR d:Constant RETURN d.id, d.label, d.kind, d.path LIMIT 50", + returns=("id", "label", "kind", "path"), + ), + QueryHelperSpec( + "callgraph_neighborhood", + "Find call expressions and resolved callable targets near a symbol.", + "MATCH (c:CallExpression)-[:FROM_ResolvesTo]->(:ResolvesTo)-[:TO_ResolvesTo]->(target) RETURN c.id, c.path, target.id, target.label LIMIT 50", + returns=("call_id", "path", "target_id", "target_label"), + ), + QueryHelperSpec( + "dependency_map", + "Inspect imports and dependencies.", + "MATCH (i:ImportDeclaration)-[:FROM_DependsOn]->(:DependsOn)-[:TO_DependsOn]->(d:Dependency) RETURN i.id, i.label, d.id, d.label LIMIT 100", + returns=("import_id", "import_label", "dependency_id", "dependency_label"), + ), + QueryHelperSpec( + "runtime_surface", + "Inspect routes, endpoints, executed queries, and secret use.", + "MATCH (r:Route)-[:FROM_RoutesTo]->(:RoutesTo)-[:TO_RoutesTo]->(e:APIEndpoint) RETURN r.id, r.label, e.id, e.label LIMIT 100", + returns=("route_id", "route_label", "endpoint_id", "endpoint_label"), + ), + QueryHelperSpec( + "documentation_context", + "Find documentation chunks connected to code nodes.", + "MATCH (d:DocumentationChunk)-[:FROM_Documents]->(:Documents)-[:TO_Documents]->(n) RETURN d.id, d.label, n.id, n.label LIMIT 50", + returns=("doc_id", "doc_label", "node_id", "node_label"), + ), + QueryHelperSpec( + "unresolved_references", + "Find references that have not been resolved to a semantic target.", + "MATCH (r:Reference) " + "WHERE NOT EXISTS { MATCH (r)-[:FROM_ResolvesTo]->(:ResolvesTo)-[:TO_ResolvesTo]->() } " + "RETURN r.id, r.label, r.path, r.line_start LIMIT 100", + returns=("id", "label", "path", "line_start"), + ), +) + + +def node_type_names() -> tuple[str, ...]: + return tuple(node.name for node in NODE_TYPES) + + +def relation_type_names() -> tuple[str, ...]: + return tuple(relation.name for relation in RELATION_TYPES) + + +def get_node_type(name: str) -> NodeTypeSpec: + for node_type in NODE_TYPES: + if node_type.name == name: + return node_type + raise KeyError(name) + + +def get_relation_type(name: str) -> RelationTypeSpec: + for relation_type in RELATION_TYPES: + if relation_type.name == name: + return relation_type + raise KeyError(name) + + +def schema_payload() -> dict[str, Any]: + return { + "ontology": ONTOLOGY_NAME, + "version": ONTOLOGY_VERSION, + "node_types": [node.as_dict() for node in NODE_TYPES], + "relation_types": [relation.as_dict() for relation in RELATION_TYPES], + "parser_node_mappings": [mapping.as_dict() for mapping in PARSER_NODE_MAPPINGS], + "search_indexes": list(SEARCH_INDEXES), + "context_profiles": CONTEXT_PROFILES, + "query_helpers": [helper.as_dict() for helper in QUERY_HELPERS], + } diff --git a/src/codebase_graph/paths.py b/src/codebase_graph/paths.py new file mode 100644 index 0000000..f89fe2b --- /dev/null +++ b/src/codebase_graph/paths.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path + +DEFAULT_STATE_DIR = ".codebaseGraph" +CONFIG_NAME = "config.json" +MANIFEST_NAME = "manifest.json" +MCP_SERVER_NAME = "codebase_graph" + + +@dataclass(frozen=True, slots=True) +class GraphStatePaths: + repo_root: Path + repo_name: str + state_dir: Path + db_path: Path + manifest_path: Path + config_path: Path + + def as_dict(self) -> dict[str, str]: + return { + "repo_root": self.repo_root.as_posix(), + "repo_name": self.repo_name, + "state_dir": self.state_dir.as_posix(), + "db_path": self.db_path.as_posix(), + "manifest_path": self.manifest_path.as_posix(), + "config_path": self.config_path.as_posix(), + } + + +def derive_graph_state_paths(repo_root: str | Path) -> GraphStatePaths: + root = Path(repo_root).expanduser().resolve() + repo_name = _repo_name(root) + state_dir = root / DEFAULT_STATE_DIR + return GraphStatePaths( + repo_root=root, + repo_name=repo_name, + state_dir=state_dir, + db_path=state_dir / f"{repo_name}_graph.ldb", + manifest_path=state_dir / MANIFEST_NAME, + config_path=state_dir / CONFIG_NAME, + ) + + +def _repo_name(root: Path) -> str: + name = root.name.strip() + if name: + return _safe_name(name) + return "repository" + + +def _safe_name(value: str) -> str: + normalized = "".join(character if character.isalnum() or character in {"-", "_"} else "_" for character in value) + return normalized.strip("._-") or "repository" diff --git a/src/codebase_graph/production.py b/src/codebase_graph/production.py deleted file mode 100644 index 2d61278..0000000 --- a/src/codebase_graph/production.py +++ /dev/null @@ -1,245 +0,0 @@ -from __future__ import annotations - -import hashlib -import json -from dataclasses import dataclass, field -from pathlib import Path -from typing import Any - -try: - import tomllib -except ModuleNotFoundError: # pragma: no cover - py310 fallback - import tomli as tomllib # type: ignore[no-redef] - -from .code_map import CodebaseGraphBuilder, _iter_indexable_files -from .document_layers import LogicalChunker -from .markdown import parse_markdown, plain_text -from .ontology import ONTOLOGY_NAME -from .verification import summarize_verification_run - -@dataclass(slots=True) -class GraphExport: - nodes: list[dict[str, Any]] = field(default_factory=list) - edges: list[dict[str, Any]] = field(default_factory=list) - metadata: dict[str, Any] = field(default_factory=dict) - - def as_dict(self) -> dict[str, Any]: - return {"ontology": ONTOLOGY_NAME, "metadata": self.metadata, "nodes": self.nodes, "edges": self.edges} - - def summary(self) -> dict[str, Any]: - node_counts: dict[str, int] = {} - edge_counts: dict[str, int] = {} - for node in self.nodes: - node_counts[node.get("table", "Unknown")] = node_counts.get(node.get("table", "Unknown"), 0) + 1 - for edge in self.edges: - edge_counts[edge.get("type", "Unknown")] = edge_counts.get(edge.get("type", "Unknown"), 0) + 1 - return { - "ontology": ONTOLOGY_NAME, - "node_count": len(self.nodes), - "edge_count": len(self.edges), - "node_counts": node_counts, - "edge_counts": edge_counts, - } - -class ProductionGraphBuilder: - def __init__(self, repo_root: str | Path = ".") -> None: - self.repo_root = Path(repo_root) - self.nodes: dict[str, dict[str, Any]] = {} - self.edges: dict[str, dict[str, Any]] = {} - - def build_export(self) -> GraphExport: - project_name = self._project_name() - project_id = _id("project", project_name) - repository_id = _id("repository", str(self.repo_root.resolve())) - self._node("Project", project_id, project_name, "project", path=".") - self._node("Repository", repository_id, self.repo_root.name, "repository", path=str(self.repo_root)) - self._edge("Contains", project_id, repository_id, "project_repository") - self._add_codebase(repository_id) - self._add_documentation(repository_id) - self._add_dependencies(repository_id) - self._add_entry_points(repository_id) - self._add_verification_sources(repository_id) - return GraphExport( - nodes=sorted(self.nodes.values(), key=lambda item: (item.get("table", ""), item.get("id", ""))), - edges=sorted(self.edges.values(), key=lambda item: item.get("id", "")), - metadata={"project_name": project_name, "source_root": str(self.repo_root)}, - ) - - def _add_codebase(self, repository_id: str) -> None: - code_map = CodebaseGraphBuilder(self.repo_root).build() - for file in code_map.files: - file_node_id = file.id - self._node( - "File", - file_node_id, - Path(file.path).name, - "python_file", - path=file.path, - module_name=file.module_name, - summary=file.summary, - metadata={"line_count": file.line_count, "language": file.language}, - ) - self._edge("Contains", repository_id, file_node_id, "repository_file") - module_id = _id("module", file.module_name or file.path) - self._node( - "PythonModule", - module_id, - file.module_name or file.path, - "python_module", - path=file.path, - module_name=file.module_name, - qualified_name=file.module_name, - summary=file.summary, - ) - self._edge("Defines", file_node_id, module_id, "file_module") - for imported in file.imports: - import_id = _id("import", f"{file.path}:{imported}") - self._node("Import", import_id, imported, "python_import", path=file.path, qualified_name=imported) - self._edge("Imports", module_id, import_id, "module_import") - for call in file.calls: - call_id = _id("call", f"{file.path}:{call}") - self._node("Call", call_id, call, "python_call", path=file.path, qualified_name=call) - self._edge("Calls", module_id, call_id, "module_call") - for symbol in file.symbols: - table = { - "python_class": "PythonClass", - "python_function": "PythonFunction", - "python_method": "PythonMethod", - }.get(symbol.kind, "PythonFunction") - self._node( - table, - symbol.id, - symbol.label, - symbol.kind, - path=symbol.path, - module_name=symbol.module_name, - qualified_name=symbol.qualified_name, - line_start=symbol.line_start, - line_end=symbol.line_end, - summary=symbol.summary, - metadata={"decorators": symbol.decorators, "bases": symbol.bases}, - ) - self._edge("Defines", module_id, symbol.id, "module_symbol") - - def _add_documentation(self, repository_id: str) -> None: - chunker = LogicalChunker(max_chars=1200) - for path in _iter_documentation_files(self.repo_root): - rel_path = path.relative_to(self.repo_root).as_posix() - content = path.read_text(encoding="utf-8", errors="replace") - _, body = parse_markdown(content) if path.suffix.lower() == ".md" else ({}, content) - summary = plain_text(body)[:500] - doc_id = _id("doc", rel_path) - self._node( - "DocumentationSource", - doc_id, - path.name, - "documentation_source", - path=rel_path, - summary=summary, - metadata={"chunks": [_chunk_as_dict(chunk) for chunk in chunker.chunk(body)[:5]]}, - ) - self._edge("Describes", repository_id, doc_id, "repository_documentation") - - def _add_dependencies(self, repository_id: str) -> None: - pyproject = self.repo_root / "pyproject.toml" - if not pyproject.exists(): - return - payload = tomllib.loads(pyproject.read_text(encoding="utf-8")) - project = payload.get("project", {}) if isinstance(payload, dict) else {} - dependencies = project.get("dependencies", []) if isinstance(project, dict) else [] - for dependency in dependencies: - name = str(dependency).split(";", 1)[0].strip() - dep_id = _id("dependency", name) - self._node("Dependency", dep_id, name, "python_dependency", path="pyproject.toml", summary=str(dependency)) - self._edge("DependsOn", repository_id, dep_id, "declared_dependency") - - def _add_entry_points(self, repository_id: str) -> None: - pyproject = self.repo_root / "pyproject.toml" - if not pyproject.exists(): - return - payload = tomllib.loads(pyproject.read_text(encoding="utf-8")) - scripts = payload.get("project", {}).get("scripts", {}) if isinstance(payload, dict) else {} - if not isinstance(scripts, dict): - return - for name, target in sorted(scripts.items()): - entry_id = _id("entry", f"script:{name}") - self._node("EntryPoint", entry_id, name, "console_script", path="pyproject.toml", qualified_name=str(target)) - self._edge("Produces", repository_id, entry_id, "repository_entry_point") - - def _add_verification_sources(self, repository_id: str) -> None: - for directory in (self.repo_root / ".codebase_graph" / "verification_runs", self.repo_root / "verification"): - if not directory.exists(): - continue - for path in sorted(directory.glob("*.json")): - try: - payload = json.loads(path.read_text(encoding="utf-8")) - except json.JSONDecodeError: - continue - command = str(payload.get("command", "")) - output = str(payload.get("output", "")) - exit_code = payload.get("exit_code") - summary = summarize_verification_run(command, output, exit_code if isinstance(exit_code, int) else None) - verification_id = _id("verification", path.relative_to(self.repo_root).as_posix()) - self._node( - "Verification", - verification_id, - summary["tool"], - "verification_run", - path=path.relative_to(self.repo_root).as_posix(), - summary=summary["summary"], - metadata=summary, - ) - self._edge("Produces", repository_id, verification_id, "repository_verification") - - def _project_name(self) -> str: - pyproject = self.repo_root / "pyproject.toml" - if pyproject.exists(): - try: - payload = tomllib.loads(pyproject.read_text(encoding="utf-8")) - name = payload.get("project", {}).get("name") - if name: - return str(name) - except Exception: - pass - return self.repo_root.name - - def _node(self, table: str, node_id: str, label: str, kind: str, **fields: Any) -> None: - existing = self.nodes.get(node_id, {}) - node = { - "id": node_id, - "table": table, - "label": label, - "kind": kind, - "path": fields.pop("path", ""), - "qualified_name": fields.pop("qualified_name", ""), - "module_name": fields.pop("module_name", ""), - "line_start": fields.pop("line_start", None), - "line_end": fields.pop("line_end", None), - "summary": fields.pop("summary", ""), - "metadata": fields.pop("metadata", {}), - } - node.update(fields) - existing.update({key: value for key, value in node.items() if value not in (None, "", {})}) - self.nodes[node_id] = existing or node - - def _edge(self, edge_type: str, source_id: str, target_id: str, kind: str, **fields: Any) -> None: - edge_id = _id("edge", f"{edge_type}:{source_id}:{target_id}:{kind}") - self.edges[edge_id] = { - "id": edge_id, - "type": edge_type, - "kind": kind, - "source_id": source_id, - "target_id": target_id, - "metadata": fields, - } - - -def _chunk_as_dict(chunk: Any) -> dict[str, Any]: - return {"id": chunk.id, "heading": chunk.heading, "text": chunk.text, "ordinal": chunk.ordinal} - -def _iter_documentation_files(root: Path) -> list[Path]: - suffixes = {".md", ".txt", ".rst"} - return _iter_indexable_files(root, suffixes, case_insensitive_suffixes=True) - -def _id(prefix: str, value: str) -> str: - return f"{prefix}:{hashlib.sha1(value.encode('utf-8')).hexdigest()[:20]}" diff --git a/src/codebase_graph/production_queries.py b/src/codebase_graph/production_queries.py deleted file mode 100644 index d08097d..0000000 --- a/src/codebase_graph/production_queries.py +++ /dev/null @@ -1,40 +0,0 @@ -from __future__ import annotations - -from collections import Counter -from typing import Any - -from .ontology import schema_payload - -def graph_schema() -> dict[str, Any]: - return schema_payload() - -def graph_query(core: Any, query: str, parameters: dict[str, Any] | None = None) -> dict[str, Any]: - return core.cypher(query, parameters=parameters or {}) - -def graph_coverage(core: Any) -> dict[str, Any]: - core.ensure_current() - graph = core._read_graph() - counts = Counter(node.get("table", "Unknown") for node in graph.get("nodes", [])) - return {"node_counts": dict(counts), "node_count": sum(counts.values())} - -def repository_analysis(core: Any) -> dict[str, Any]: - search = core.search("project repository python module class function", limit=20) - return {"retrieval": search.get("retrieval"), "items": search.get("items", []), "count": search.get("count", 0)} - -def risk_report(core: Any) -> dict[str, Any]: - result = core.cypher("MATCH (n:Risk) RETURN n.id, n.label, n.summary LIMIT 25") - return {"items": result.get("rows", []), "count": result.get("count", 0)} - -def task_report(core: Any) -> dict[str, Any]: - return {"items": [], "count": 0} - -def artifact_by_id(core: Any, artifact_id: str) -> dict[str, Any] | None: - core.ensure_current() - for node in core._read_graph().get("nodes", []): - if node.get("id") == artifact_id: - return node - return None - -def explain_decision(core: Any, decision_id: str) -> dict[str, Any]: - artifact = artifact_by_id(core, decision_id) - return {"id": decision_id, "artifact": artifact, "explanation": artifact.get("summary") if artifact else ""} diff --git a/src/codebase_graph/question_query_registry.py b/src/codebase_graph/question_query_registry.py deleted file mode 100644 index d7c47ce..0000000 --- a/src/codebase_graph/question_query_registry.py +++ /dev/null @@ -1,124 +0,0 @@ -from __future__ import annotations - -import argparse -import json -from collections.abc import Sequence -from dataclasses import dataclass -from typing import Any, Mapping - -PHASE_ARCHITECTURE_UNDERSTANDING = "architecture_understanding" -PHASE_CHANGE_PREPARATION = "change_preparation" -PHASE_BREAKING_CHANGE_PREPARATION = "breaking_change_preparation" - -@dataclass(frozen=True, slots=True) -class EngineeringQuestionQuery: - id: str - question: str - intent: str - phase: str - query: str - required_params: tuple[str, ...] = () - result_shape: tuple[str, ...] = () - tags: tuple[str, ...] = () - - def validate_params(self, params: Mapping[str, Any]) -> None: - missing = [name for name in self.required_params if not params.get(name)] - if missing: - raise ValueError(f"Missing required params for {self.id}: {', '.join(missing)}") - - def run(self, core: Any, **params: Any) -> dict[str, Any]: - self.validate_params(params) - return core.cypher(self.query, parameters=dict(params)) - -_QUERIES: tuple[EngineeringQuestionQuery, ...] = ( - EngineeringQuestionQuery( - id="se.architecture.entrypoints.v1", - question="What are the main CLI or package entry points?", - intent="Map runtime ingress points before architecture reasoning.", - phase=PHASE_ARCHITECTURE_UNDERSTANDING, - query="MATCH (n:EntryPoint) RETURN n.id, n.label, n.kind, n.path, n.qualified_name LIMIT 50", - result_shape=("id", "label", "kind", "path", "qualified_name"), - tags=("architecture", "entrypoints", "runtime"), - ), - EngineeringQuestionQuery( - id="se.change.tests_for_artifact.v1", - question="What test artifacts exist for this path or symbol?", - intent="Identify existing tests for target path or symbol before edits.", - phase=PHASE_CHANGE_PREPARATION, - query="MATCH (n:Test) RETURN n.id, n.label, n.kind, n.path, n.qualified_name LIMIT 50", - required_params=("path", "symbol"), - result_shape=("id", "label", "kind", "path", "qualified_name"), - tags=("change", "tests", "coverage"), - ), - EngineeringQuestionQuery( - id="se.breaking.consumers_of_contract.v1", - question="Who are the consumers of the old behavior?", - intent="Find direct consumers of a contract before introducing breaking changes.", - phase=PHASE_BREAKING_CHANGE_PREPARATION, - query="MATCH (n:Call) RETURN n.id, n.label, n.kind, n.path, n.qualified_name LIMIT 50", - required_params=("contract_id",), - result_shape=("id", "label", "kind", "path", "qualified_name"), - tags=("breaking-change", "consumers", "contract"), - ), -) - -def list_engineering_question_queries(phase: str | None = None) -> list[EngineeringQuestionQuery]: - return [query for query in _QUERIES if phase is None or query.phase == phase] - -def get_engineering_question_query(query_id: str) -> EngineeringQuestionQuery: - for query in _QUERIES: - if query.id == query_id: - return query - raise KeyError(query_id) - -def main(argv: Sequence[str] | None = None) -> int: - parser = argparse.ArgumentParser(description="List or run versioned engineering question queries.") - parser.add_argument("--source-root", default=".") - parser.add_argument("--state-dir", default=None) - parser.add_argument("--db-path", default=None) - parser.add_argument("--staging-dir", default=None) - subparsers = parser.add_subparsers(dest="command", required=True) - list_parser = subparsers.add_parser("list") - list_parser.add_argument("--phase") - run_parser = subparsers.add_parser("run") - run_parser.add_argument("query_id") - run_parser.add_argument("--params-json", default="{}") - run_parser.add_argument("--no-refresh", action="store_true") - args = parser.parse_args(argv) - if args.command == "list": - queries = list_engineering_question_queries(phase=args.phase) - payload = {"count": len(queries), "items": [_query_as_dict(query) for query in queries]} - elif args.command == "run": - from .graph_core import CodebaseGraph - - params = json.loads(args.params_json) - if not isinstance(params, dict): - raise ValueError("--params-json must decode to an object") - core = CodebaseGraph( - source_root=args.source_root, - state_dir=args.state_dir, - database_path=args.db_path, - staging_dir=args.staging_dir, - ) - if not args.no_refresh: - core.ensure_current() - question = get_engineering_question_query(args.query_id) - payload = {"question": _query_as_dict(question), "result": question.run(core, **params)} - else: - parser.error(f"unsupported command: {args.command}") - print(json.dumps(payload, indent=2, sort_keys=True)) - return 0 - -def _query_as_dict(query: EngineeringQuestionQuery) -> dict[str, Any]: - return { - "id": query.id, - "question": query.question, - "intent": query.intent, - "phase": query.phase, - "required_params": list(query.required_params), - "result_shape": list(query.result_shape), - "tags": list(query.tags), - } - -if __name__ == "__main__": - raise SystemExit(main()) diff --git a/src/codebase_graph/reasoning/__init__.py b/src/codebase_graph/reasoning/__init__.py new file mode 100644 index 0000000..4f3d7b4 --- /dev/null +++ b/src/codebase_graph/reasoning/__init__.py @@ -0,0 +1,20 @@ +"""Path explanation, causal trace, and context assembly.""" + +from .architecture_queries import ( + ARCHITECTURE_QUERY_GROUPS, + ARCHITECTURE_QUERY_ORDER, + ArchitectureQueryGroup, + ArchitectureQuerySpec, + architecture_query_catalog, +) +from .context_builder import CompactContextBuilder, ContextNode + +__all__ = [ + "ARCHITECTURE_QUERY_GROUPS", + "ARCHITECTURE_QUERY_ORDER", + "ArchitectureQueryGroup", + "ArchitectureQuerySpec", + "CompactContextBuilder", + "ContextNode", + "architecture_query_catalog", +] diff --git a/src/codebase_graph/reasoning/architecture_queries.py b/src/codebase_graph/reasoning/architecture_queries.py new file mode 100644 index 0000000..fb98970 --- /dev/null +++ b/src/codebase_graph/reasoning/architecture_queries.py @@ -0,0 +1,372 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +WORKFLOW_NAME = "coding_task_architecture_discovery" +EXECUTION_TOOL = "graph_query" + + +@dataclass(frozen=True, slots=True) +class ArchitectureQuerySpec: + name: str + description: str + statement: str + parameters: tuple[str, ...] = () + returns: tuple[str, ...] = () + + def as_dict(self) -> dict[str, Any]: + return { + "name": self.name, + "description": self.description, + "statement": self.statement, + "parameters": list(self.parameters), + "returns": list(self.returns), + } + + +@dataclass(frozen=True, slots=True) +class ArchitectureQueryGroup: + name: str + goal: str + queries: tuple[ArchitectureQuerySpec, ...] + + def as_dict(self) -> dict[str, Any]: + return { + "name": self.name, + "goal": self.goal, + "queries": [query.as_dict() for query in self.queries], + } + + +ARCHITECTURE_QUERY_ORDER = ( + "overview", + "public_surface", + "dependency_topology", + "execution_flow", + "runtime_data_security", + "documentation_context", + "graph_quality_gaps", +) + + +ARCHITECTURE_QUERY_GROUPS: dict[str, ArchitectureQueryGroup] = { + "overview": ArchitectureQueryGroup( + name="overview", + goal="Check graph coverage and establish the indexed codebase shape.", + queries=( + ArchitectureQuerySpec( + name="graph_coverage", + description="Count all materialized graph nodes as a quick coverage check.", + statement="MATCH (n) RETURN count(n) AS total_nodes LIMIT 1", + returns=("total_nodes",), + ), + ArchitectureQuerySpec( + name="source_unit_inventory", + description="List materialized modules with source paths and spans.", + statement=( + "MATCH (m:Module) " + "RETURN m.id, m.label, m.qualified_name, m.path, m.line_start, m.line_end " + "ORDER BY m.path LIMIT 200" + ), + returns=("id", "label", "qualified_name", "path", "line_start", "line_end"), + ), + ArchitectureQuerySpec( + name="package_directory_shape", + description="List source files with path and content metadata.", + statement=( + "MATCH (f:File) " + "RETURN f.path, f.label, f.size_bytes, f.content_hash " + "ORDER BY f.path LIMIT 300" + ), + returns=("path", "label", "size_bytes", "content_hash"), + ), + ), + ), + "public_surface": ArchitectureQueryGroup( + name="public_surface", + goal="Find how the library exposes behavior through modules, definitions, or runtime entrypoints.", + queries=( + ArchitectureQuerySpec( + name="public_surface_candidates", + description="Find exposed module surfaces and fallback definition-level public candidates.", + statement=( + "MATCH (m:Module)-[:FROM_Exposes]->(:Exposes)-[:TO_Exposes]->(surface) " + "RETURN 'exposed' AS surface_source, m.label AS module_label, m.path AS module_path, " + "surface.id AS surface_id, surface.label AS surface_label, " + "surface.qualified_name AS surface_qualified_name, surface.path AS surface_path, " + "surface.line_start AS line_start " + "UNION ALL " + "MATCH (m:Module)-[:FROM_Defines]->(:Defines)-[:TO_Defines]->(surface:Class) " + "RETURN 'defined' AS surface_source, m.label AS module_label, m.path AS module_path, " + "surface.id AS surface_id, surface.label AS surface_label, " + "surface.qualified_name AS surface_qualified_name, surface.path AS surface_path, " + "surface.line_start AS line_start " + "UNION ALL " + "MATCH (m:Module)-[:FROM_Defines]->(:Defines)-[:TO_Defines]->(surface:Function) " + "RETURN 'defined' AS surface_source, m.label AS module_label, m.path AS module_path, " + "surface.id AS surface_id, surface.label AS surface_label, " + "surface.qualified_name AS surface_qualified_name, surface.path AS surface_path, " + "surface.line_start AS line_start " + "UNION ALL " + "MATCH (m:Module)-[:FROM_Defines]->(:Defines)-[:TO_Defines]->(surface:Method) " + "RETURN 'defined' AS surface_source, m.label AS module_label, m.path AS module_path, " + "surface.id AS surface_id, surface.label AS surface_label, " + "surface.qualified_name AS surface_qualified_name, surface.path AS surface_path, " + "surface.line_start AS line_start LIMIT 200" + ), + returns=( + "surface_source", + "module_label", + "module_path", + "surface_id", + "surface_label", + "surface_qualified_name", + "surface_path", + "line_start", + ), + ), + ArchitectureQuerySpec( + name="entrypoint_runtime_surface", + description="Find function-level name/path candidates for runtime or CLI entrypoints.", + statement=( + "MATCH (d:Function) " + "WHERE d.label = 'main' OR d.label = 'cli' OR d.label CONTAINS 'server' OR d.path CONTAINS 'cli' " + "RETURN 'name_candidate' AS entrypoint_kind, d.id AS entrypoint_id, d.label AS entrypoint_label, " + "d.path AS entrypoint_path, d.id AS target_id, d.label AS target_label, " + "d.qualified_name AS target_qualified_name, d.path AS target_path, d.line_start AS line_start " + "LIMIT 100" + ), + returns=( + "entrypoint_kind", + "entrypoint_id", + "entrypoint_label", + "entrypoint_path", + "target_id", + "target_label", + "target_qualified_name", + "target_path", + "line_start", + ), + ), + ), + ), + "dependency_topology": ArchitectureQueryGroup( + name="dependency_topology", + goal="Map internal and external dependencies so agents can infer layers and adapters.", + queries=( + ArchitectureQuerySpec( + name="external_dependency_map", + description="Map import declarations to external dependency nodes.", + statement=( + "MATCH (i:ImportDeclaration)-[:FROM_DependsOn]->(:DependsOn)-[:TO_DependsOn]->(d:Dependency) " + "RETURN i.path, i.label AS import_label, d.label AS dependency " + "ORDER BY d.label, i.path LIMIT 300" + ), + returns=("path", "import_label", "dependency"), + ), + ArchitectureQuerySpec( + name="module_import_coupling", + description="List modules and their import declarations as a coupling inventory.", + statement=( + "MATCH (m:Module)-[:FROM_Imports]->(:Imports)-[:TO_Imports]->(i:ImportDeclaration) " + "RETURN m.label, m.path, i.label, i.line_start " + "ORDER BY m.path, i.line_start LIMIT 300" + ), + returns=("module_label", "module_path", "import_label", "line_start"), + ), + ), + ), + "execution_flow": ArchitectureQueryGroup( + name="execution_flow", + goal="Identify important call paths, orchestration nodes, and central implementation flows.", + queries=( + ArchitectureQuerySpec( + name="high_fan_in_definitions", + description="Find definitions with many resolved incoming references.", + statement=( + "MATCH (ref)-[:FROM_ResolvesTo]->(:ResolvesTo)-[:TO_ResolvesTo]->(target:Class) " + "RETURN target.id, target.label, target.qualified_name, target.path, count(ref) AS inbound_refs " + "UNION ALL " + "MATCH (ref)-[:FROM_ResolvesTo]->(:ResolvesTo)-[:TO_ResolvesTo]->(target:Function) " + "RETURN target.id, target.label, target.qualified_name, target.path, count(ref) AS inbound_refs " + "UNION ALL " + "MATCH (ref)-[:FROM_ResolvesTo]->(:ResolvesTo)-[:TO_ResolvesTo]->(target:Method) " + "RETURN target.id, target.label, target.qualified_name, target.path, count(ref) AS inbound_refs " + "UNION ALL " + "MATCH (ref)-[:FROM_ResolvesTo]->(:ResolvesTo)-[:TO_ResolvesTo]->(target:Module) " + "RETURN target.id, target.label, target.qualified_name, target.path, count(ref) AS inbound_refs " + "ORDER BY inbound_refs DESC LIMIT 50" + ), + returns=("id", "label", "qualified_name", "path", "inbound_refs"), + ), + ArchitectureQuerySpec( + name="high_fan_out_callers", + description="Find functions or methods that call many downstream nodes.", + statement=( + "MATCH (caller:Function)-[:FROM_Calls]->(:Calls)-[:TO_Calls]->(callee) " + "RETURN caller.id, caller.label, caller.qualified_name, caller.path, count(callee) AS outgoing_calls " + "UNION ALL " + "MATCH (caller:Method)-[:FROM_Calls]->(:Calls)-[:TO_Calls]->(callee) " + "RETURN caller.id, caller.label, caller.qualified_name, caller.path, count(callee) AS outgoing_calls " + "ORDER BY outgoing_calls DESC LIMIT 50" + ), + returns=("id", "label", "qualified_name", "path", "outgoing_calls"), + ), + ArchitectureQuerySpec( + name="callable_neighborhood", + description="Inspect direct callees for a named callable.", + statement=( + "MATCH (caller)-[:FROM_Calls]->(:Calls)-[:TO_Calls]->(callee) " + "WHERE caller.label = $name OR caller.qualified_name = $name " + "RETURN caller.id, caller.label, caller.qualified_name, callee.id, " + "callee.label, callee.qualified_name, callee.path LIMIT 100" + ), + parameters=("name",), + returns=( + "caller_id", + "caller_label", + "caller_qualified_name", + "callee_id", + "callee_label", + "callee_qualified_name", + "callee_path", + ), + ), + ), + ), + "runtime_data_security": ArchitectureQueryGroup( + name="runtime_data_security", + goal="Expose data access, query execution, secrets, and configuration-sensitive paths.", + queries=( + ArchitectureQuerySpec( + name="data_query_touchpoints", + description="Find actors that execute or construct query nodes.", + statement=( + "MATCH (actor)-[:FROM_ExecutesQuery]->(:ExecutesQuery)-[:TO_ExecutesQuery]->(q:Query) " + "RETURN actor.id, actor.label, actor.qualified_name, actor.path, q.label, q.path, q.line_start " + "LIMIT 100" + ), + returns=( + "actor_id", + "actor_label", + "actor_qualified_name", + "actor_path", + "query_label", + "query_path", + "query_line_start", + ), + ), + ArchitectureQuerySpec( + name="secret_configuration_touchpoints", + description="Find actors linked to secret or sensitive configuration references.", + statement=( + "MATCH (actor)-[:FROM_UsesSecret]->(:UsesSecret)-[:TO_UsesSecret]->(s:SecretRef) " + "RETURN actor.id, actor.label, actor.qualified_name, actor.path, s.label, s.path, s.line_start " + "LIMIT 100" + ), + returns=( + "actor_id", + "actor_label", + "actor_qualified_name", + "actor_path", + "secret_label", + "secret_path", + "secret_line_start", + ), + ), + ), + ), + "documentation_context": ArchitectureQueryGroup( + name="documentation_context", + goal="Link architecture claims to documentation and parser evidence.", + queries=( + ArchitectureQuerySpec( + name="documentation_to_code_links", + description="Find documentation chunks connected to code nodes.", + statement=( + "MATCH (d:DocumentationChunk)-[:FROM_Documents]->(:Documents)-[:TO_Documents]->(n) " + "RETURN d.id, d.label, d.path, n.id, n.label, n.qualified_name, n.path LIMIT 100" + ), + returns=( + "doc_id", + "doc_label", + "doc_path", + "node_id", + "node_label", + "node_qualified_name", + "node_path", + ), + ), + ArchitectureQuerySpec( + name="evidence_for_symbol", + description="Return evidence nodes for a named symbol or qualified name.", + statement=( + "MATCH (n)-[:FROM_EvidencedBy]->(:EvidencedBy)-[:TO_EvidencedBy]->(e) " + "WHERE n.label = $name OR n.qualified_name = $name " + "RETURN n.id, n.label, n.qualified_name, n.path, e.id, e.label, e.path, e.line_start, e.line_end " + "LIMIT 100" + ), + parameters=("name",), + returns=( + "node_id", + "node_label", + "node_qualified_name", + "node_path", + "evidence_id", + "evidence_label", + "evidence_path", + "evidence_line_start", + "evidence_line_end", + ), + ), + ), + ), + "graph_quality_gaps": ArchitectureQueryGroup( + name="graph_quality_gaps", + goal="Detect graph gaps that reduce confidence in architecture claims.", + queries=( + ArchitectureQuerySpec( + name="unresolved_reference_risk", + description="Find references without resolved semantic targets.", + statement=( + "MATCH (r:Reference) " + "WHERE NOT EXISTS { MATCH (r)-[:FROM_ResolvesTo]->(:ResolvesTo)-[:TO_ResolvesTo]->() } " + "RETURN r.id, r.label, r.path, r.line_start ORDER BY r.path, r.line_start LIMIT 200" + ), + returns=("id", "label", "path", "line_start"), + ), + ), + ), +} + + +def architecture_query_catalog(group: str | None = None) -> dict[str, Any]: + groups = _selected_groups(group) + return { + "workflow": WORKFLOW_NAME, + "recommended_order": list(ARCHITECTURE_QUERY_ORDER), + "execution_tool": EXECUTION_TOOL, + "groups": [query_group.as_dict() for query_group in groups], + } + + +def _selected_groups(group: str | None) -> tuple[ArchitectureQueryGroup, ...]: + if group is None or group == "": + return tuple(ARCHITECTURE_QUERY_GROUPS[name] for name in ARCHITECTURE_QUERY_ORDER) + try: + return (ARCHITECTURE_QUERY_GROUPS[group],) + except KeyError as exc: + valid = ", ".join(ARCHITECTURE_QUERY_ORDER) + raise ValueError(f"Unknown architecture query group: {group}. Valid groups: {valid}") from exc + + +__all__ = [ + "ARCHITECTURE_QUERY_GROUPS", + "ARCHITECTURE_QUERY_ORDER", + "EXECUTION_TOOL", + "WORKFLOW_NAME", + "ArchitectureQueryGroup", + "ArchitectureQuerySpec", + "architecture_query_catalog", +] diff --git a/src/codebase_graph/reasoning/context_builder.py b/src/codebase_graph/reasoning/context_builder.py new file mode 100644 index 0000000..82d1484 --- /dev/null +++ b/src/codebase_graph/reasoning/context_builder.py @@ -0,0 +1,195 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + +from codebase_graph.db import graph_query_adapter +from codebase_graph.ontology import CONTEXT_PROFILES, RELATION_TYPES + + +DEFAULT_CONTEXT_LIMIT = 3 +DEFAULT_CONTEXT_BUDGET = 600 + + +@dataclass(frozen=True, slots=True) +class ContextNode: + relation: str + direction: str + type: str + label: str + path: str = "" + span: dict[str, int] = field(default_factory=dict) + summary: str = "" + id: str = field(default="", repr=False) + + def as_dict(self, *, detail: str = "standard") -> dict[str, Any]: + if detail not in {"standard", "slim"}: + raise ValueError(f"Unknown detail level: {detail}. Valid levels: slim, standard") + if detail == "slim": + payload: dict[str, Any] = { + "relation": self.relation, + "direction": self.direction, + "type": self.type, + "label": self.label, + } + if self.path: + payload["path"] = self.path + if self.span: + payload["span"] = dict(self.span) + if self.summary and self.summary != self.label: + payload["summary"] = self.summary + return payload + return { + "relation": self.relation, + "direction": self.direction, + "type": self.type, + "label": self.label, + "path": self.path, + "span": dict(self.span), + "summary": self.summary, + } + + +class CompactContextBuilder: + def __init__(self, store: Any) -> None: + self.store = store + self.query = graph_query_adapter(store) + self._relation_names = {relation_type.name for relation_type in RELATION_TYPES} + + def build( + self, + node_id: str, + node_type: str, + *, + profile: str = "brief", + limit: int = DEFAULT_CONTEXT_LIMIT, + budget: int = DEFAULT_CONTEXT_BUDGET, + max_depth: int | None = None, + ) -> list[ContextNode]: + profile_config = self._profile(profile) + if limit <= 0 or budget <= 0: + return [] + depth_limit = profile_config["max_depth"] if max_depth is None else max_depth + if depth_limit <= 0: + return [] + + relations = tuple( + relation + for relation in profile_config["relations"] + if relation in self._relation_names + ) + if not relations: + return [] + + context: list[ContextNode] = [] + seen = {node_id} + frontier = [(node_id, node_type, 0)] + used_budget = 0 + + while frontier and len(context) < limit: + current_id, current_type, depth = frontier.pop(0) + if depth >= depth_limit: + continue + for relation in relations: + for candidate in self._neighbors(current_id, current_type, relation, limit): + if candidate.type == "" or candidate.label == "": + continue + candidate_key = f"{candidate.type}:{candidate.label}:{candidate.path}:{candidate.span}" + node_key = _node_key(candidate) + dedupe_key = node_key or candidate_key + if dedupe_key in seen: + continue + compact_candidate, item_cost = _fit_to_budget(candidate, budget - used_budget) + if compact_candidate is None: + return context + context.append(compact_candidate) + used_budget += item_cost + seen.add(dedupe_key) + if node_key: + frontier.append((node_key, candidate.type, depth + 1)) + if len(context) >= limit: + return context + return context + + def _profile(self, profile: str) -> dict[str, Any]: + if profile not in CONTEXT_PROFILES: + valid = ", ".join(sorted(CONTEXT_PROFILES)) + raise ValueError(f"Unknown context profile: {profile}. Valid profiles: {valid}") + return dict(CONTEXT_PROFILES[profile]) + + def _neighbors(self, node_id: str, node_type: str, relation: str, limit: int) -> list[ContextNode]: + outgoing = self._query_neighbors(node_id, node_type, relation, "outgoing", limit) + incoming = self._query_neighbors(node_id, node_type, relation, "incoming", limit) + return [*outgoing, *incoming] + + def _query_neighbors( + self, + node_id: str, + node_type: str, + relation: str, + direction: str, + limit: int, + ) -> list[ContextNode]: + return [ + ContextNode( + relation=relation, + direction=direction, + type=neighbor.node_type, + label=neighbor.label or neighbor.qualified_name, + path=neighbor.path, + span=_span(neighbor.line_start, neighbor.line_end), + summary=neighbor.summary, + id=neighbor.node_id, + ) + for neighbor in self.query.neighbors( + node_id=node_id, + node_type=node_type, + relation=relation, + direction=direction, + limit=limit, + ) + ] + + +def _fit_to_budget(node: ContextNode, remaining_budget: int) -> tuple[ContextNode | None, int]: + cost = _context_cost(node) + if cost <= remaining_budget: + return node, cost + fixed_cost = _context_cost(ContextNode(node.relation, node.direction, node.type, node.label, node.path, node.span, "")) + summary_budget = remaining_budget - fixed_cost + if summary_budget <= 0: + return None, 0 + summary = node.summary[:summary_budget] + compact = ContextNode(node.relation, node.direction, node.type, node.label, node.path, node.span, summary) + return compact, _context_cost(compact) + + +def _context_cost(node: ContextNode) -> int: + return sum( + len(str(value)) + for value in ( + node.relation, + node.direction, + node.type, + node.label, + node.path, + node.summary, + *node.span.values(), + ) + ) + + +def _node_key(node: ContextNode) -> str: + return node.id + + +def _span(line_start: Any, line_end: Any) -> dict[str, int]: + span: dict[str, int] = {} + if line_start is not None: + span["line_start"] = int(line_start) + if line_end is not None: + span["line_end"] = int(line_end) + return span + + +__all__ = ["CompactContextBuilder", "ContextNode", "DEFAULT_CONTEXT_BUDGET", "DEFAULT_CONTEXT_LIMIT"] diff --git a/src/codebase_graph/retrieval/__init__.py b/src/codebase_graph/retrieval/__init__.py new file mode 100644 index 0000000..761622c --- /dev/null +++ b/src/codebase_graph/retrieval/__init__.py @@ -0,0 +1,29 @@ +"""Keyword, vector, graph traversal, and ranking retrieval.""" + +from .block_format import ( + canonicalize_search_payload, + intentional_summary_omissions, + parse_search_block, + serialize_agent_search_block, + serialize_context_block, + serialize_graph_block, + serialize_parseable_search_block, + serialize_search_block, +) +from .search import DETAIL_LEVELS, CompactContextPayload, SearchHit, SearchRequest, SearchService + +__all__ = [ + "DETAIL_LEVELS", + "CompactContextPayload", + "SearchHit", + "SearchRequest", + "SearchService", + "canonicalize_search_payload", + "intentional_summary_omissions", + "parse_search_block", + "serialize_agent_search_block", + "serialize_context_block", + "serialize_graph_block", + "serialize_parseable_search_block", + "serialize_search_block", +] diff --git a/src/codebase_graph/retrieval/block_format.py b/src/codebase_graph/retrieval/block_format.py new file mode 100644 index 0000000..10c4020 --- /dev/null +++ b/src/codebase_graph/retrieval/block_format.py @@ -0,0 +1,367 @@ +from __future__ import annotations + +import json +import re +import shlex +from typing import Any, Mapping + + +SIMPLE_VALUE_RE = re.compile(r"^[A-Za-z0-9_./:\-\[\]]+$") +SPAN_RE = re.compile(r"^L(?P\d+)-L(?P\d+)$") +ONTOLOGY_TERMS = {"Class", "Method", "Scope", "Contains", "outgoing", "path", "span", "id", "label", "rank_score"} + + +def serialize_parseable_search_block(payload: Mapping[str, Any]) -> str: + """Serialize graph-search JSON into a parseable debug block format.""" + lines = [ + " | ".join( + [ + f"q {_format_value(str(payload.get('query', '')))}", + f"budget {payload.get('budget', '')}", + f"limit {payload.get('limit', '')}", + f"profile {_format_value(str(payload.get('profile', '')))}", + ] + ) + ] + current_path: str | None = None + previous_line_was_file = False + for result in payload.get("results", []): + result_path = str(result.get("path", "")) + if result_path != current_path: + if len(lines) > 1 and not previous_line_was_file: + lines.append("") + lines.append(f"file path {_format_value(result_path)}") + current_path = result_path + previous_line_was_file = True + else: + previous_line_was_file = False + + result_span = _span(result.get("span", {})) + result_parts = [ + f"- {result.get('type', '')}", + f"label={_format_value(str(result.get('label', '')))}", + f"span={_format_span(result_span)}", + ] + if "rank_score" in result: + result_parts.append(f"rank_score={result['rank_score']}") + if "id" in result: + result_parts.append(f"id={_format_value(str(result['id']))}") + summary = _meaningful_summary(result) + if summary: + result_parts.append(f"summary={_format_value(summary)}") + lines.append(" ".join(result_parts)) + + for context in result.get("context", []): + context_path = str(context.get("path", "")) + context_span = _span(context.get("span", {})) + span_text = "L=same" if context_span == result_span else _format_span(context_span) + context_parts = [ + f" {context.get('direction', '')}", + str(context.get("relation", "")), + str(context.get("type", "")), + f"label={_format_value(str(context.get('label', '')))}", + ] + if context_path and context_path != current_path: + context_parts.append(f"path={_format_value(context_path)}") + context_parts.append(f"span={span_text}") + context_summary = _meaningful_summary(context) + if context_summary: + context_parts.append(f"summary={_format_value(context_summary)}") + lines.append(" ".join(context_parts)) + previous_line_was_file = False + return "\n".join(lines) + "\n" + + +def serialize_agent_search_block(payload: Mapping[str, Any]) -> str: + """Serialize graph-search JSON into the compact runtime block format.""" + lines = [f"q {_format_value(str(payload.get('query', '')))}"] + current_path: str | None = None + result_keys = {_record_key(result) for result in payload.get("results", [])} + for result in payload.get("results", []): + result_path = str(result.get("path", "")) + if result_path != current_path: + if len(lines) > 1: + lines.append("") + lines.append(f"file path {_format_value(result_path)}") + current_path = result_path + + result_span = _span(result.get("span", {})) + result_parts = [ + f"- {result.get('type', '')}", + _format_value(str(result.get("label", ""))), + _format_span(result_span), + ] + if "rank_score" in result: + result_parts.append(f"rank_score={float(result['rank_score']):.2f}") + if "id" in result: + result_parts.append(f"id={_format_value(str(result['id']))}") + summary = _meaningful_summary(result) + if summary: + result_parts.append(f"summary={_format_value(summary)}") + lines.append(" ".join(result_parts)) + + for context in result.get("context", []): + if _omit_agent_context(context, parent_span=result_span, result_keys=result_keys): + continue + context_path = str(context.get("path", "")) + context_span = _span(context.get("span", {})) + context_parts = [ + f" {context.get('direction', '')}", + str(context.get("relation", "")), + str(context.get("type", "")), + _format_value(str(context.get("label", ""))), + _format_span(context_span), + ] + if context_path and context_path != current_path: + context_parts.append(f"path={_format_value(context_path)}") + context_summary = _meaningful_summary(context) + if context_summary: + context_parts.append(f"summary={_format_value(context_summary)}") + lines.append(" ".join(context_parts)) + return "\n".join(lines) + "\n" + + +def serialize_context_block(payload: Mapping[str, Any]) -> str: + """Serialize an explicit graph-context payload into a readable block.""" + header = [ + f"context {payload.get('node_type', '')}", + f"id={_format_value(str(payload.get('node_id', '')))}", + f"profile={_format_value(str(payload.get('profile', '')))}", + ] + lines = [" ".join(header)] + current_path: str | None = None + for context in payload.get("context", []): + context_path = str(context.get("path", "")) + if context_path != current_path: + if len(lines) > 1: + lines.append("") + lines.append(f"file path {_format_value(context_path)}") + current_path = context_path + context_parts = [ + f" {context.get('direction', '')}", + str(context.get("relation", "")), + str(context.get("type", "")), + _format_value(str(context.get("label", ""))), + _format_span(_span(context.get("span", {}))), + ] + context_summary = _meaningful_summary(context) + if context_summary: + context_parts.append(f"summary={_format_value(context_summary)}") + lines.append(" ".join(context_parts)) + return "\n".join(lines) + "\n" + + +def serialize_graph_block(payload: Mapping[str, Any]) -> str: + if "results" in payload: + return serialize_agent_search_block(payload) + if "context" in payload and "node_id" in payload and "node_type" in payload: + return serialize_context_block(payload) + raise ValueError("Block format is only supported for graph-search and graph-context payloads") + + +def serialize_search_block(payload: Mapping[str, Any]) -> str: + """Backward-compatible alias for the parseable debug block format.""" + return serialize_parseable_search_block(payload) + + +def canonicalize_search_payload(payload: Mapping[str, Any]) -> dict[str, Any]: + records: list[dict[str, Any]] = [] + for result in payload.get("results", []): + result_record = { + "type": result.get("type", ""), + "label": result.get("label", ""), + "path": result.get("path", ""), + "span": _span(result.get("span", {})), + "id": result.get("id", ""), + "rank_score": result.get("rank_score"), + "context": [], + } + result_summary = _meaningful_summary(result) + if result_summary: + result_record["summary"] = result_summary + for context in result.get("context", []): + context_record = { + "direction": context.get("direction", ""), + "relation": context.get("relation", ""), + "type": context.get("type", ""), + "label": context.get("label", ""), + "path": context.get("path", ""), + "span": _span(context.get("span", {})), + } + context_summary = _meaningful_summary(context) + if context_summary: + context_record["summary"] = context_summary + result_record["context"].append(context_record) + records.append(result_record) + return {"results": records} + + +def parse_search_block(text: str) -> dict[str, Any]: + records: list[dict[str, Any]] = [] + current_path = "" + current_result: dict[str, Any] | None = None + for raw_line in text.splitlines(): + if not raw_line.strip() or raw_line.startswith("q "): + continue + if raw_line.startswith("file path "): + current_path = _parse_value(raw_line[len("file path ") :]) + current_result = None + continue + if raw_line.startswith("- "): + tokens = shlex.split(raw_line) + fields = _keyed_fields(tokens[2:]) + current_result = { + "type": tokens[1], + "label": fields.get("label", ""), + "path": current_path, + "span": _parse_span(fields.get("span", "")), + "id": fields.get("id", ""), + "rank_score": _parse_number(fields.get("rank_score")), + "context": [], + } + if fields.get("summary"): + current_result["summary"] = fields["summary"] + records.append(current_result) + continue + if raw_line.startswith(" "): + if current_result is None: + raise ValueError(f"Context line has no parent result: {raw_line}") + tokens = shlex.split(raw_line.strip()) + fields = _keyed_fields(tokens[3:]) + span = current_result["span"] if fields.get("span") == "L=same" else _parse_span(fields.get("span", "")) + context_record = { + "direction": tokens[0], + "relation": tokens[1], + "type": tokens[2], + "label": fields.get("label", ""), + "path": fields.get("path", current_path), + "span": span, + } + if fields.get("summary"): + context_record["summary"] = fields["summary"] + current_result["context"].append(context_record) + continue + raise ValueError(f"Unknown block line: {raw_line}") + return {"results": records} + + +def intentional_summary_omissions(payload: Mapping[str, Any]) -> list[str]: + omissions: list[str] = [] + for result_index, result in enumerate(payload.get("results", [])): + if _is_boilerplate_summary(result): + omissions.append(f"results[{result_index}].summary") + for context_index, context in enumerate(result.get("context", [])): + if _is_boilerplate_summary(context): + omissions.append(f"results[{result_index}].context[{context_index}].summary") + return omissions + + +def _keyed_fields(tokens: list[str]) -> dict[str, str]: + fields: dict[str, str] = {} + for token in tokens: + if "=" not in token: + continue + key, value = token.split("=", 1) + fields[key] = value + return fields + + +def _format_value(value: str) -> str: + if value and SIMPLE_VALUE_RE.match(value): + return value + return json.dumps(value, ensure_ascii=True) + + +def _parse_value(value: str) -> str: + if value.startswith('"'): + return str(json.loads(value)) + return value + + +def _span(value: Any) -> dict[str, int]: + if not isinstance(value, Mapping): + return {} + span: dict[str, int] = {} + if value.get("line_start") is not None: + span["line_start"] = int(value["line_start"]) + if value.get("line_end") is not None: + span["line_end"] = int(value["line_end"]) + return span + + +def _format_span(span: Mapping[str, int]) -> str: + start = span.get("line_start") + end = span.get("line_end") + if start is None or end is None: + return "L?" + return f"L{start}-L{end}" + + +def _parse_span(value: str) -> dict[str, int]: + match = SPAN_RE.match(value) + if not match: + return {} + return {"line_start": int(match.group("start")), "line_end": int(match.group("end"))} + + +def _parse_number(value: str | None) -> int | float | None: + if value is None: + return None + try: + as_float = float(value) + except ValueError: + return None + return int(as_float) if as_float.is_integer() else as_float + + +def _meaningful_summary(record: Mapping[str, Any]) -> str: + summary = str(record.get("summary", "")) + return "" if _is_boilerplate_summary(record) else summary + + +def _is_boilerplate_summary(record: Mapping[str, Any]) -> bool: + summary = str(record.get("summary", "")) + label = str(record.get("label", "")) + node_type = str(record.get("type", "")) + if not summary or summary == label: + return bool(summary) + if node_type == "Scope" and label.endswith(" scope"): + scoped_label = label[: -len(" scope")] + return summary == f"Scope for {scoped_label}" + return False + + +def _omit_agent_context( + context: Mapping[str, Any], + *, + parent_span: Mapping[str, int], + result_keys: set[tuple[str, str, str, tuple[tuple[str, int], ...]]], +) -> bool: + context_span = _span(context.get("span", {})) + if _is_boilerplate_summary(context) and context_span == dict(parent_span): + return True + if _record_key(context) in result_keys: + return True + return context.get("type") == "TypeAnnotation" + + +def _record_key(record: Mapping[str, Any]) -> tuple[str, str, str, tuple[tuple[str, int], ...]]: + return ( + str(record.get("type", "")), + str(record.get("label", "")), + str(record.get("path", "")), + tuple(sorted(_span(record.get("span", {})).items())), + ) + + +__all__ = [ + "ONTOLOGY_TERMS", + "canonicalize_search_payload", + "intentional_summary_omissions", + "parse_search_block", + "serialize_context_block", + "serialize_agent_search_block", + "serialize_graph_block", + "serialize_parseable_search_block", + "serialize_search_block", +] diff --git a/src/codebase_graph/retrieval/search.py b/src/codebase_graph/retrieval/search.py new file mode 100644 index 0000000..73645f7 --- /dev/null +++ b/src/codebase_graph/retrieval/search.py @@ -0,0 +1,350 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + +from codebase_graph.db import SearchIndexRow, graph_query_adapter +from codebase_graph.ontology import CONTEXT_PROFILES, SEARCH_INDEXES +from codebase_graph.reasoning.context_builder import CompactContextBuilder, ContextNode, DEFAULT_CONTEXT_BUDGET, DEFAULT_CONTEXT_LIMIT + + +DEFAULT_SEARCH_LIMIT = 3 +MAX_CANDIDATE_LIMIT = 50 +MIN_CANDIDATE_LIMIT = 10 +DETAIL_LEVELS = {"standard", "slim"} +DEFINITION_TYPES = {"Class", "Function", "Method", "Variable", "Constant"} +GENERIC_TYPES = {"Symbol", "Dependency"} + + +@dataclass(frozen=True, slots=True) +class SearchRequest: + query: str + limit: int = DEFAULT_SEARCH_LIMIT + profile: str = "brief" + budget: int = DEFAULT_CONTEXT_BUDGET + max_depth: int | None = None + context_limit: int = DEFAULT_CONTEXT_LIMIT + detail: str = "standard" + + def validate(self) -> None: + if not self.query.strip(): + raise ValueError("Search query must not be empty") + if self.limit <= 0: + raise ValueError("Search limit must be greater than zero") + if self.budget < 0: + raise ValueError("Context budget must be zero or greater") + if self.max_depth is not None and self.max_depth < 0: + raise ValueError("Context max depth must be zero or greater") + if self.context_limit < 0: + raise ValueError("Context limit must be zero or greater") + _validate_detail(self.detail) + if self.profile not in CONTEXT_PROFILES: + valid = ", ".join(sorted(CONTEXT_PROFILES)) + raise ValueError(f"Unknown context profile: {self.profile}. Valid profiles: {valid}") + + +@dataclass(slots=True) +class SearchHit: + id: str + type: str + label: str + qualified_name: str = "" + path: str = "" + span: dict[str, int] = field(default_factory=dict) + score: float = 0.0 + rank_score: float = 0.0 + score_components: dict[str, float] = field(default_factory=dict) + summary: str = "" + context: list[ContextNode] = field(default_factory=list) + index_order: int = 0 + + def as_dict(self, *, detail: str = "standard") -> dict[str, Any]: + _validate_detail(detail) + if detail == "slim": + payload: dict[str, Any] = { + "id": self.id, + "type": self.type, + "label": self.label, + "rank_score": self.rank_score, + } + _set_non_empty(payload, "path", self.path) + _set_non_empty(payload, "span", dict(self.span)) + _set_meaningful_summary(payload, self.summary, self.label) + context = [node.as_dict(detail=detail) for node in self.context] + _set_non_empty(payload, "context", context) + return payload + return { + "id": self.id, + "type": self.type, + "label": self.label, + "qualified_name": self.qualified_name, + "path": self.path, + "span": dict(self.span), + "score": self.score, + "rank_score": self.rank_score, + "score_components": dict(self.score_components), + "summary": self.summary, + "context": [node.as_dict(detail=detail) for node in self.context], + } + + +@dataclass(frozen=True, slots=True) +class CompactContextPayload: + query: str + profile: str + limit: int + budget: int + results: tuple[SearchHit, ...] + + def as_dict(self, *, detail: str = "standard") -> dict[str, Any]: + _validate_detail(detail) + return { + "query": self.query, + "profile": self.profile, + "limit": self.limit, + "budget": self.budget, + "results": [hit.as_dict(detail=detail) for hit in self.results], + } + + +@dataclass(frozen=True, slots=True) +class FTSIndexSpec: + node_type: str + index_name: str + order: int + + +class SearchService: + def __init__(self, store: Any) -> None: + self.store = store + self.query = graph_query_adapter(store) + self.indexes = tuple(_fts_index_specs()) + + def search(self, request: SearchRequest) -> CompactContextPayload: + request.validate() + candidate_limit = _candidate_limit(request.limit) + hits = self._rank_hits( + self._query_fts(request.query, candidate_limit), + query=request.query, + profile=request.profile, + ) + context_builder = CompactContextBuilder(self.store) + compact_hits: list[SearchHit] = [] + for hit in hits[: request.limit]: + hit.context = context_builder.build( + hit.id, + hit.type, + profile=request.profile, + limit=request.context_limit, + budget=request.budget, + max_depth=request.max_depth, + ) + compact_hits.append(hit) + return CompactContextPayload( + query=request.query, + profile=request.profile, + limit=request.limit, + budget=request.budget, + results=tuple(compact_hits), + ) + + def _query_fts(self, query: str, limit: int) -> list[SearchHit]: + hits: list[SearchHit] = [] + for spec in self.indexes: + hits.extend( + _hit_from_index_row(row, spec) + for row in self.query.search_index( + node_type=spec.node_type, + index_name=spec.index_name, + query=query, + limit=limit, + ) + ) + return hits + + def _rank_hits(self, hits: list[SearchHit], *, query: str = "", profile: str = "brief") -> list[SearchHit]: + best_by_id: dict[str, SearchHit] = {} + for hit in hits: + previous = best_by_id.get(hit.id) + if previous is None or _raw_hit_sort_key(hit) < _raw_hit_sort_key(previous): + best_by_id[hit.id] = hit + deduped = list(best_by_id.values()) + _assign_rank_scores(deduped, query=query, profile=profile) + return sorted(deduped, key=_ranked_hit_sort_key) + + +def _fts_index_specs() -> list[FTSIndexSpec]: + specs: list[FTSIndexSpec] = [] + order = 0 + for index in SEARCH_INDEXES: + index_name = str(index["name"]) + for node_type in index["node_types"]: + specs.append(FTSIndexSpec(node_type=str(node_type), index_name=f"{index_name}_{node_type}", order=order)) + order += 1 + return specs + + +def _hit_from_index_row(row: SearchIndexRow, spec: FTSIndexSpec) -> SearchHit: + return SearchHit( + id=row.id, + type=spec.node_type, + label=row.label, + qualified_name=row.qualified_name, + path=row.path, + span=_span(row.line_start, row.line_end), + summary=row.summary, + score=row.score, + index_order=spec.order, + ) + + +def _assign_rank_scores(hits: list[SearchHit], *, query: str, profile: str) -> None: + if not hits: + return + max_score = max((hit.score for hit in hits), default=0.0) + concrete_labels = { + _normalize(hit.label) + for hit in hits + if hit.type in DEFINITION_TYPES and hit.label + } + intent = _query_intent(query, profile) + + for hit in hits: + fts_score = hit.score / max_score if max_score > 0 else 0.0 + lexical_score = _lexical_score(query, hit) + type_score = _type_score(hit.type, intent) + generic_penalty = _generic_penalty(hit, concrete_labels) + rank_score = (0.45 * fts_score) + (0.35 * lexical_score) + type_score - generic_penalty + hit.score_components = { + "fts": round(fts_score, 6), + "lexical": round(lexical_score, 6), + "type": round(type_score, 6), + "generic_penalty": round(generic_penalty, 6), + } + hit.rank_score = round(rank_score, 6) + + +def _candidate_limit(limit: int) -> int: + return min(max(limit * 4, MIN_CANDIDATE_LIMIT), MAX_CANDIDATE_LIMIT) + + +def _query_intent(query: str, profile: str) -> str: + if profile in {"dependencies", "runtime", "docs"}: + return profile + if _looks_like_path(query): + return "path" + if _looks_like_identifier(query): + return "definition" + return "general" + + +def _lexical_score(query: str, hit: SearchHit) -> float: + normalized_query = _normalize(query) + if not normalized_query: + return 0.0 + label = _normalize(hit.label) + qualified_name = _normalize(hit.qualified_name) + path = _normalize(hit.path) + + if label == normalized_query: + return 1.0 + if qualified_name == normalized_query: + return 0.95 + if qualified_name.endswith(f".{normalized_query}") or qualified_name.endswith(f"/{normalized_query}"): + return 0.85 + if path == normalized_query or path.endswith(f"/{normalized_query}"): + return 0.8 + if normalized_query in label: + return 0.55 + if normalized_query in qualified_name: + return 0.45 + if normalized_query in path: + return 0.35 + return 0.0 + + +def _type_score(node_type: str, intent: str) -> float: + if intent == "definition": + if node_type in {"Class", "Function", "Method"}: + return 0.7 + if node_type in {"Variable", "Constant"}: + return 0.6 + if node_type == "Module": + return 0.2 + return 0.0 + if intent == "path": + return {"File": 0.7, "Module": 0.6, "SourceRoot": 0.25, "Repository": 0.2}.get(node_type, 0.0) + if intent == "dependencies": + return {"Dependency": 0.7, "ImportDeclaration": 0.65, "Module": 0.2}.get(node_type, 0.0) + if intent == "runtime": + return {"APIEndpoint": 0.7, "Route": 0.65, "Component": 0.55, "Query": 0.45, "SecretRef": 0.35}.get(node_type, 0.0) + if intent == "docs": + return {"DocumentationSource": 0.7, "DocumentationChunk": 0.65}.get(node_type, 0.0) + if node_type in DEFINITION_TYPES: + return 0.25 + return 0.0 + + +def _generic_penalty(hit: SearchHit, concrete_labels: set[str]) -> float: + if hit.type in GENERIC_TYPES and _normalize(hit.label) in concrete_labels: + return 0.45 + return 0.0 + + +def _looks_like_identifier(query: str) -> bool: + cleaned = query.strip() + return cleaned.replace("_", "").isalnum() and not cleaned[0:1].isdigit() + + +def _looks_like_path(query: str) -> bool: + cleaned = query.strip() + return "/" in cleaned or "\\" in cleaned or cleaned.endswith((".py", ".toml", ".md", ".json", ".yaml", ".yml")) + + +def _normalize(value: str) -> str: + return value.strip().lower() + + +def _ranked_hit_sort_key(hit: SearchHit) -> tuple[float, int, str, str, str]: + return (-hit.rank_score, hit.index_order, hit.type, hit.path, hit.label) + + +def _raw_hit_sort_key(hit: SearchHit) -> tuple[float, int, str, str, str]: + return (-hit.score, hit.index_order, hit.type, hit.path, hit.label) + + +def _span(line_start: Any, line_end: Any) -> dict[str, int]: + span: dict[str, int] = {} + if line_start is not None: + span["line_start"] = int(line_start) + if line_end is not None: + span["line_end"] = int(line_end) + return span + + +def _validate_detail(detail: str) -> None: + if detail not in DETAIL_LEVELS: + valid = ", ".join(sorted(DETAIL_LEVELS)) + raise ValueError(f"Unknown detail level: {detail}. Valid levels: {valid}") + + +def _set_non_empty(payload: dict[str, Any], key: str, value: Any) -> None: + if value not in ("", None, [], {}): + payload[key] = value + + +def _set_meaningful_summary(payload: dict[str, Any], summary: str, label: str) -> None: + if summary and summary != label: + payload["summary"] = summary + + +__all__ = [ + "CompactContextPayload", + "DETAIL_LEVELS", + "DEFAULT_SEARCH_LIMIT", + "FTSIndexSpec", + "MAX_CANDIDATE_LIMIT", + "SearchHit", + "SearchRequest", + "SearchService", +] diff --git a/src/codebase_graph/setup/__init__.py b/src/codebase_graph/setup/__init__.py new file mode 100644 index 0000000..eba8a55 --- /dev/null +++ b/src/codebase_graph/setup/__init__.py @@ -0,0 +1,54 @@ +"""Production setup orchestration for repository graph bootstrapping.""" + +from importlib import import_module +from typing import Any + +from .state import ( + CONFIG_NAME, + DEFAULT_STATE_DIR, + GraphStatePaths, + MANIFEST_NAME, + SetupPaths, + derive_graph_state_paths, + derive_setup_paths, + load_setup_config, +) + +_LAZY_EXPORTS = { + "McpInstallOptions": (".installer", "McpInstallOptions"), + "McpInstallResult": (".installer", "McpInstallResult"), + "SetupError": (".orchestrator", "SetupError"), + "SetupOptions": (".orchestrator", "SetupOptions"), + "SetupResult": (".orchestrator", "SetupResult"), + "install_mcp_clients": (".installer", "install_mcp_clients"), + "install_mcp_server": (".installer", "install_mcp_server"), + "run_setup": (".orchestrator", "run_setup"), +} + +__all__ = [ + "CONFIG_NAME", + "DEFAULT_STATE_DIR", + "GraphStatePaths", + "MANIFEST_NAME", + "McpInstallOptions", + "McpInstallResult", + "SetupError", + "SetupOptions", + "SetupPaths", + "SetupResult", + "derive_graph_state_paths", + "derive_setup_paths", + "load_setup_config", + "install_mcp_clients", + "install_mcp_server", + "run_setup", +] + + +def __getattr__(name: str) -> Any: + if name not in _LAZY_EXPORTS: + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + module_name, attribute_name = _LAZY_EXPORTS[name] + value = getattr(import_module(module_name, __name__), attribute_name) + globals()[name] = value + return value diff --git a/src/codebase_graph/setup/clients/__init__.py b/src/codebase_graph/setup/clients/__init__.py new file mode 100644 index 0000000..a482497 --- /dev/null +++ b/src/codebase_graph/setup/clients/__init__.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +from .base import ClientConfigAdapter, RenderedClientConfig +from .codex import CodexAdapter +from .hermes import HermesAdapter +from .json_clients import ClaudeAdapter, ClaudeProjectAdapter, GenericAdapter, LmStudioAdapter, OpenClawAdapter + +ADAPTERS: dict[str, ClientConfigAdapter] = { + adapter.client_id: adapter + for adapter in ( + CodexAdapter(), + ClaudeAdapter(), + ClaudeProjectAdapter(), + LmStudioAdapter(), + HermesAdapter(), + OpenClawAdapter(), + GenericAdapter(), + ) +} + + +def get_client_adapter(client_id: str) -> ClientConfigAdapter: + try: + return ADAPTERS[client_id] + except KeyError as exc: + supported = ", ".join(sorted([*ADAPTERS, "none"])) + raise ValueError(f"Unsupported MCP client: {client_id}. Supported clients: {supported}") from exc + + +def supported_client_ids() -> tuple[str, ...]: + return tuple(sorted([*ADAPTERS, "none"])) + + +__all__ = ["ADAPTERS", "ClientConfigAdapter", "RenderedClientConfig", "get_client_adapter", "supported_client_ids"] diff --git a/src/codebase_graph/setup/clients/base.py b/src/codebase_graph/setup/clients/base.py new file mode 100644 index 0000000..53db3e3 --- /dev/null +++ b/src/codebase_graph/setup/clients/base.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Protocol + +from codebase_graph.setup.descriptor import McpServerDescriptor + + +@dataclass(frozen=True, slots=True) +class RenderedClientConfig: + text: str + action: str + entry: dict[str, Any] + patch: Any + payload: Any + + +class ClientConfigAdapter(Protocol): + client_id: str + + def default_config_path(self, descriptor: McpServerDescriptor) -> Path: + ... + + def render(self, existing_text: str | None, descriptor: McpServerDescriptor) -> RenderedClientConfig: + ... + + +def action_for_server(previous: Any, next_value: Any, *, file_exists: bool) -> str: + if previous is None: + return "created" + if previous == next_value: + return "unchanged" + return "updated" diff --git a/src/codebase_graph/setup/clients/codex.py b/src/codebase_graph/setup/clients/codex.py new file mode 100644 index 0000000..abdde9a --- /dev/null +++ b/src/codebase_graph/setup/clients/codex.py @@ -0,0 +1,81 @@ +from __future__ import annotations + +import os +import re +from pathlib import Path +from typing import Any + +from codebase_graph.setup.descriptor import McpServerDescriptor + +from .base import RenderedClientConfig + + +class CodexAdapter: + client_id = "codex" + + def default_config_path(self, descriptor: McpServerDescriptor) -> Path: + base = Path(os.environ.get("CODEX_HOME", Path.home() / ".codex")) + return base / "config.toml" + + def render(self, existing_text: str | None, descriptor: McpServerDescriptor) -> RenderedClientConfig: + entry = descriptor.stdio_entry(include_timeout=True) + patch = _toml_block(descriptor, entry) + next_text, previous = _upsert_toml_block(existing_text or "", descriptor.name, patch) + if previous is None: + action = "created" + elif previous == patch.rstrip(): + action = "unchanged" + else: + action = "updated" + if existing_text == next_text: + action = "unchanged" + return RenderedClientConfig(text=next_text, action=action, entry=entry, patch=patch, payload=patch) + + +def _upsert_toml_block(existing: str, server_name: str, block: str) -> tuple[str, str | None]: + lines = existing.splitlines() + start: int | None = None + end = len(lines) + header_re = re.compile(rf"^\[mcp_servers\.{re.escape(server_name)}(?:\.env)?\]\s*$") + any_header_re = re.compile(r"^\[[^\]]+\]\s*$") + for index, line in enumerate(lines): + if header_re.match(line): + start = index + break + if start is None: + prefix = existing.rstrip() + separator = "\n\n" if prefix else "" + return f"{prefix}{separator}{block}", None + for index in range(start + 1, len(lines)): + if any_header_re.match(lines[index]) and not header_re.match(lines[index]): + end = index + break + previous = "\n".join(lines[start:end]).rstrip() + next_lines = [*lines[:start], *block.rstrip().splitlines(), *lines[end:]] + return "\n".join(next_lines).rstrip() + "\n", previous + + +def _toml_block(descriptor: McpServerDescriptor, entry: dict[str, Any]) -> str: + lines = [ + f"[mcp_servers.{descriptor.name}]", + f"command = {_toml_string(entry['command'])}", + f"args = {_toml_array(entry['args'])}", + f"startup_timeout_sec = {int(entry['startup_timeout_sec'])}", + ] + if descriptor.cwd: + lines.append(f"cwd = {_toml_string(descriptor.cwd)}") + if descriptor.env: + lines.append("") + lines.append(f"[mcp_servers.{descriptor.name}.env]") + for key, value in sorted(descriptor.env.items()): + lines.append(f"{key} = {_toml_string(value)}") + return "\n".join(lines) + "\n" + + +def _toml_array(values: list[str]) -> str: + return "[" + ", ".join(_toml_string(value) for value in values) + "]" + + +def _toml_string(value: str) -> str: + escaped = value.replace("\\", "\\\\").replace('"', '\\"') + return f'"{escaped}"' diff --git a/src/codebase_graph/setup/clients/hermes.py b/src/codebase_graph/setup/clients/hermes.py new file mode 100644 index 0000000..015ade4 --- /dev/null +++ b/src/codebase_graph/setup/clients/hermes.py @@ -0,0 +1,71 @@ +from __future__ import annotations + +from pathlib import Path +from typing import Any + +from codebase_graph.setup.descriptor import McpServerDescriptor + +from .base import RenderedClientConfig + +START_MARKER = "# codebaseGraph MCP server start" +END_MARKER = "# codebaseGraph MCP server end" + + +class HermesAdapter: + client_id = "hermes" + + def default_config_path(self, descriptor: McpServerDescriptor) -> Path: + return Path.home() / ".hermes" / "config.yaml" + + def render(self, existing_text: str | None, descriptor: McpServerDescriptor) -> RenderedClientConfig: + entry = descriptor.stdio_entry(include_type=True) + patch = _yaml_block(descriptor, entry) + next_text, previous = _upsert_marked_block(existing_text or "", patch) + if previous is None: + action = "created" + elif previous == patch.rstrip(): + action = "unchanged" + else: + action = "updated" + if existing_text == next_text: + action = "unchanged" + return RenderedClientConfig(text=next_text, action=action, entry=entry, patch=patch, payload=patch) + + +def _upsert_marked_block(existing: str, block: str) -> tuple[str, str | None]: + start = existing.find(START_MARKER) + end = existing.find(END_MARKER) + if start == -1 or end == -1 or end < start: + prefix = existing.rstrip() + separator = "\n\n" if prefix else "" + return f"{prefix}{separator}{block}", None + after_end = end + len(END_MARKER) + previous = existing[start:after_end].rstrip() + next_text = existing[:start].rstrip() + "\n\n" + block.rstrip() + "\n\n" + existing[after_end:].lstrip() + return next_text.rstrip() + "\n", previous + + +def _yaml_block(descriptor: McpServerDescriptor, entry: dict[str, Any]) -> str: + lines = [ + START_MARKER, + "mcp_servers:", + f" {descriptor.name}:", + " type: stdio", + f" command: {_yaml_scalar(entry['command'])}", + " args:", + ] + for arg in entry["args"]: + lines.append(f" - {_yaml_scalar(arg)}") + if descriptor.cwd: + lines.append(f" cwd: {_yaml_scalar(descriptor.cwd)}") + if descriptor.env: + lines.append(" env:") + for key, value in sorted(descriptor.env.items()): + lines.append(f" {key}: {_yaml_scalar(value)}") + lines.append(END_MARKER) + return "\n".join(lines) + "\n" + + +def _yaml_scalar(value: str) -> str: + escaped = value.replace("\\", "\\\\").replace('"', '\\"') + return f'"{escaped}"' diff --git a/src/codebase_graph/setup/clients/json_clients.py b/src/codebase_graph/setup/clients/json_clients.py new file mode 100644 index 0000000..7886be1 --- /dev/null +++ b/src/codebase_graph/setup/clients/json_clients.py @@ -0,0 +1,95 @@ +from __future__ import annotations + +import json +import os +from copy import deepcopy +from pathlib import Path +from typing import Any + +from codebase_graph.setup.descriptor import McpServerDescriptor + +from .base import RenderedClientConfig, action_for_server + + +class JsonMcpServersAdapter: + client_id = "generic" + include_type = True + root_path = ("mcpServers",) + + def default_config_path(self, descriptor: McpServerDescriptor) -> Path: + return Path.home() / ".config" / "mcp" / "mcp.json" + + def entry(self, descriptor: McpServerDescriptor) -> dict[str, Any]: + return descriptor.stdio_entry(include_type=self.include_type) + + def render(self, existing_text: str | None, descriptor: McpServerDescriptor) -> RenderedClientConfig: + payload = _read_json_text(existing_text) + next_payload = deepcopy(payload) + container = _container(next_payload, self.root_path) + entry = self.entry(descriptor) + previous = container.get(descriptor.name) + container[descriptor.name] = entry + action = action_for_server(previous, entry, file_exists=existing_text is not None) + text = json.dumps(next_payload, indent=2, sort_keys=True) + "\n" + if existing_text == text: + action = "unchanged" + return RenderedClientConfig(text=text, action=action, entry=entry, patch=next_payload, payload=next_payload) + + +class ClaudeAdapter(JsonMcpServersAdapter): + client_id = "claude" + include_type = False + + def default_config_path(self, descriptor: McpServerDescriptor) -> Path: + mac_path = Path.home() / "Library" / "Application Support" / "Claude" / "claude_desktop_config.json" + if mac_path.parent.exists(): + return mac_path + return Path.home() / ".config" / "claude" / "claude_desktop_config.json" + + +class ClaudeProjectAdapter(JsonMcpServersAdapter): + client_id = "claude-project" + + def default_config_path(self, descriptor: McpServerDescriptor) -> Path: + if descriptor.repo_root: + return Path(descriptor.repo_root) / ".mcp.json" + return Path.cwd() / ".mcp.json" + + +class LmStudioAdapter(JsonMcpServersAdapter): + client_id = "lmstudio" + + def default_config_path(self, descriptor: McpServerDescriptor) -> Path: + return Path.home() / ".lmstudio" / "mcp.json" + + +class GenericAdapter(JsonMcpServersAdapter): + client_id = "generic" + include_type = False + + +class OpenClawAdapter(JsonMcpServersAdapter): + client_id = "openclaw" + root_path = ("mcp", "servers") + + def default_config_path(self, descriptor: McpServerDescriptor) -> Path: + return Path(os.environ.get("OPENCLAW_HOME", Path.home() / ".openclaw")) / "mcp.json5" + + +def _read_json_text(existing_text: str | None) -> dict[str, Any]: + if existing_text is None or not existing_text.strip(): + return {} + payload = json.loads(existing_text) + if not isinstance(payload, dict): + raise ValueError("MCP config must contain a JSON object") + return payload + + +def _container(payload: dict[str, Any], path: tuple[str, ...]) -> dict[str, Any]: + cursor = payload + for key in path: + next_value = cursor.setdefault(key, {}) + if not isinstance(next_value, dict): + raise ValueError(f"MCP config key must contain an object: {'.'.join(path)}") + cursor = next_value + return cursor diff --git a/src/codebase_graph/setup/descriptor.py b/src/codebase_graph/setup/descriptor.py new file mode 100644 index 0000000..781ac76 --- /dev/null +++ b/src/codebase_graph/setup/descriptor.py @@ -0,0 +1,81 @@ +from __future__ import annotations + +import os +import shutil +import sys +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Mapping + +from .state import MCP_SERVER_NAME + + +@dataclass(frozen=True, slots=True) +class McpServerDescriptor: + name: str + transport: str + command: str + args: tuple[str, ...] + env: Mapping[str, str] = field(default_factory=dict) + cwd: str | None = None + setup_config_path: str | None = None + repo_root: str | None = None + timeout: int = 60 + tool_policy: str | None = "graph_query_read_only" + + def as_dict(self) -> dict[str, Any]: + payload: dict[str, Any] = { + "name": self.name, + "transport": self.transport, + "command": self.command, + "args": list(self.args), + "env": dict(sorted(self.env.items())), + "cwd": self.cwd, + "setup_config_path": self.setup_config_path, + "repo_root": self.repo_root, + "timeout": self.timeout, + } + if self.tool_policy: + payload["tool_policy"] = self.tool_policy + return payload + + def stdio_entry(self, *, include_type: bool = False, include_timeout: bool = False) -> dict[str, Any]: + entry: dict[str, Any] = {"command": self.command, "args": list(self.args)} + if include_type: + entry["type"] = "stdio" + if self.env: + entry["env"] = dict(sorted(self.env.items())) + if self.cwd: + entry["cwd"] = self.cwd + if include_timeout: + entry["startup_timeout_sec"] = self.timeout + return entry + + +def build_server_descriptor( + setup_config_path: Path, + *, + repo_root: Path | None = None, + name: str = MCP_SERVER_NAME, + timeout: int = 60, +) -> McpServerDescriptor: + config_path = setup_config_path.expanduser().resolve() + resolved_repo_root = repo_root.expanduser().resolve() if repo_root is not None else config_path.parent.parent + return McpServerDescriptor( + name=name, + transport="stdio", + command=resolve_server_command(), + args=("mcp", "serve", "--config", config_path.as_posix()), + env={}, + cwd=None, + setup_config_path=config_path.as_posix(), + repo_root=resolved_repo_root.as_posix(), + timeout=timeout, + ) + + +def resolve_server_command() -> str: + sibling_script = Path(sys.executable).with_name("codebase-graph") + if sibling_script.exists() and os.access(sibling_script, os.X_OK): + return sibling_script.as_posix() + return shutil.which("codebase-graph") or "codebase-graph" diff --git a/src/codebase_graph/setup/installer.py b/src/codebase_graph/setup/installer.py new file mode 100644 index 0000000..340ec10 --- /dev/null +++ b/src/codebase_graph/setup/installer.py @@ -0,0 +1,552 @@ +from __future__ import annotations + +import json +import os +import re +import shutil +import subprocess +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Callable + +from codebase_graph.mcp.protocol import LATEST_PROTOCOL_VERSION + +from .clients import get_client_adapter +from .descriptor import McpServerDescriptor, build_server_descriptor +from .state import MCP_SERVER_NAME, load_setup_config + +SCOPES = ("local", "user", "project") +NativeCommandBuilder = Callable[[McpServerDescriptor, str], list[str]] +VisibilityCommandBuilder = Callable[[], list[str]] + + +@dataclass(frozen=True, slots=True) +class McpInstallOptions: + client: str = "codex" + scope: str = "local" + setup_config_path: str | Path = ".codebaseGraph/config.json" + server_name: str | None = None + client_config_path: str | Path | None = None + dry_run: bool = False + verify: bool = False + skip: bool = False + prefer_native: bool = True + require_setup_config: bool = True + + +@dataclass(frozen=True, slots=True) +class McpInstallResult: + action: str + client: str + scope: str + server_name: str + method: str | None + path: str | None + command: list[str] | None + descriptor: dict[str, Any] + entry: dict[str, Any] + patch: Any = None + payload: Any = None + verification: dict[str, Any] | None = None + error: str | None = None + native_command: list[str] | None = None + native_error: str | None = None + + def as_dict(self) -> dict[str, Any]: + payload: dict[str, Any] = { + "action": self.action, + "client": self.client, + "scope": self.scope, + "server_name": self.server_name, + "method": self.method, + "path": self.path, + "command": self.command, + "descriptor": self.descriptor, + "entry": self.entry, + } + if self.patch is not None: + payload["patch"] = self.patch + if self.payload is not None: + payload["payload"] = self.payload + if self.verification is not None: + payload["verification"] = self.verification + if self.error is not None: + payload["error"] = self.error + if self.native_command is not None: + payload["native_command"] = self.native_command + if self.native_error is not None: + payload["native_error"] = self.native_error + return payload + + +@dataclass(frozen=True, slots=True) +class InstallClientStrategy: + client_id: str + adapter_id: str | None = None + project_adapter_id: str | None = None + forced_scope: str | None = None + native_executable: str | None = None + native_command_builder: NativeCommandBuilder | None = None + visibility_command_builder: VisibilityCommandBuilder | None = None + + def install_scope(self, scope: str) -> str: + return self.forced_scope or scope + + def adapter_client_id(self, scope: str) -> str: + if self.project_adapter_id is not None and self.install_scope(scope) == "project": + return self.project_adapter_id + return self.adapter_id or self.client_id + + def native_command(self, descriptor: McpServerDescriptor, *, scope: str) -> list[str] | None: + if self.native_command_builder is None: + return None + return self.native_command_builder(descriptor, self.install_scope(scope)) + + def visibility_command(self) -> list[str] | None: + if self.visibility_command_builder is None: + return None + return self.visibility_command_builder() + + +def _codex_native_command(descriptor: McpServerDescriptor, scope: str) -> list[str]: + return ["codex", "mcp", "add", descriptor.name, "--", descriptor.command, *descriptor.args] + + +def _claude_native_command(descriptor: McpServerDescriptor, scope: str) -> list[str]: + return [ + "claude", + "mcp", + "add", + "--transport", + "stdio", + "--scope", + scope, + descriptor.name, + "--", + descriptor.command, + *descriptor.args, + ] + + +def _openclaw_native_command(descriptor: McpServerDescriptor, scope: str) -> list[str]: + entry = descriptor.stdio_entry(include_type=True) + return ["openclaw", "mcp", "set", descriptor.name, json.dumps(entry, separators=(",", ":"), sort_keys=True)] + + +INSTALL_STRATEGIES: dict[str, InstallClientStrategy] = { + "codex": InstallClientStrategy( + client_id="codex", + native_executable="codex", + native_command_builder=_codex_native_command, + visibility_command_builder=lambda: ["codex", "mcp", "list"], + ), + "claude": InstallClientStrategy( + client_id="claude", + project_adapter_id="claude-project", + native_executable="claude", + native_command_builder=_claude_native_command, + visibility_command_builder=lambda: ["claude", "mcp", "list"], + ), + "claude-project": InstallClientStrategy( + client_id="claude-project", + forced_scope="project", + native_executable="claude", + native_command_builder=_claude_native_command, + visibility_command_builder=lambda: ["claude", "mcp", "list"], + ), + "lmstudio": InstallClientStrategy(client_id="lmstudio"), + "hermes": InstallClientStrategy(client_id="hermes"), + "openclaw": InstallClientStrategy( + client_id="openclaw", + native_executable="openclaw", + native_command_builder=_openclaw_native_command, + visibility_command_builder=lambda: ["openclaw", "mcp", "list"], + ), + "generic": InstallClientStrategy(client_id="generic"), +} +INSTALL_CLIENTS = tuple(INSTALL_STRATEGIES) + + +def supported_install_client_ids(*, include_all: bool = False) -> tuple[str, ...]: + values = [*INSTALL_CLIENTS] + if include_all: + values.append("all") + return tuple(sorted(values)) + + +def default_server_name(repo_name: str | None) -> str: + safe_repo_name = _safe_name(repo_name or "repository") + return f"{MCP_SERVER_NAME}_{safe_repo_name}" + + +def install_mcp_clients(options: McpInstallOptions) -> list[McpInstallResult]: + if options.client == "all": + return [_install_with_failure_result(options, client) for client in INSTALL_CLIENTS] + return [install_mcp_server(options)] + + +def install_mcp_server(options: McpInstallOptions) -> McpInstallResult: + _validate_options(options) + strategy = _client_strategy(options.client) + descriptor = _build_descriptor(options) + entry = descriptor.stdio_entry() + if options.skip or options.client == "none": + return McpInstallResult( + action="skipped", + client=options.client, + scope=options.scope, + server_name=descriptor.name, + method=None, + path=None, + command=None, + descriptor=descriptor.as_dict(), + entry=entry, + ) + + native_command = strategy.native_command(descriptor, scope=options.scope) + use_native = ( + options.prefer_native + and options.client_config_path is None + and native_command is not None + and strategy.native_executable is not None + and shutil.which(strategy.native_executable) + ) + if options.dry_run: + if use_native: + return _native_result("dry_run", options, descriptor, native_command, verification=None) + return _file_adapter_result(options, descriptor, dry_run=True, native_command=native_command) + + if use_native and native_command is not None: + try: + completed = subprocess.run(native_command, capture_output=True, text=True, check=False, timeout=30) + except subprocess.TimeoutExpired as exc: + native_error = f"timed out after {exc.timeout}s" + except OSError as exc: + native_error = str(exc) + else: + if completed.returncode == 0: + result = _native_result("updated", options, descriptor, native_command, verification=None) + return _with_verification(result, descriptor, options.verify) + native_error = _subprocess_error(completed) + return _file_adapter_result( + options, + descriptor, + dry_run=False, + native_command=native_command, + native_error=native_error, + ) + + return _file_adapter_result( + options, + descriptor, + dry_run=False, + native_command=native_command, + native_error=_missing_native_error(strategy) if native_command is not None else None, + ) + + +def _install_with_failure_result(options: McpInstallOptions, client: str) -> McpInstallResult: + client_options = McpInstallOptions( + client=client, + scope=_client_strategy(client).install_scope(options.scope), + setup_config_path=options.setup_config_path, + server_name=options.server_name, + client_config_path=options.client_config_path, + dry_run=options.dry_run, + verify=options.verify, + skip=options.skip, + prefer_native=options.prefer_native, + require_setup_config=options.require_setup_config, + ) + try: + return install_mcp_server(client_options) + except Exception as exc: + try: + descriptor = _build_descriptor(client_options) + entry = descriptor.stdio_entry() + descriptor_payload = descriptor.as_dict() + server_name = descriptor.name + except Exception: + entry = {} + descriptor_payload = {} + server_name = client_options.server_name or MCP_SERVER_NAME + return McpInstallResult( + action="failed", + client=client, + scope=client_options.scope, + server_name=server_name, + method=None, + path=None, + command=None, + descriptor=descriptor_payload, + entry=entry, + error=str(exc), + ) + + +def _file_adapter_result( + options: McpInstallOptions, + descriptor: McpServerDescriptor, + *, + dry_run: bool, + native_command: list[str] | None = None, + native_error: str | None = None, +) -> McpInstallResult: + adapter = get_client_adapter(_client_strategy(options.client).adapter_client_id(options.scope)) + path = ( + Path(options.client_config_path).expanduser().resolve() + if options.client_config_path is not None + else adapter.default_config_path(descriptor) + ) + existing_text = path.read_text(encoding="utf-8") if path.exists() else None + rendered = adapter.render(existing_text, descriptor) + action = "dry_run" if dry_run else rendered.action + if not dry_run: + path.parent.mkdir(parents=True, exist_ok=True) + tmp_path = path.with_suffix(path.suffix + ".tmp") + with tmp_path.open("w", encoding="utf-8") as handle: + handle.write(rendered.text) + os.replace(tmp_path, path) + result = McpInstallResult( + action=action, + client=options.client, + scope=options.scope, + server_name=descriptor.name, + method="file_adapter", + path=path.as_posix(), + command=None, + descriptor=descriptor.as_dict(), + entry=rendered.entry, + patch=rendered.patch, + payload=rendered.payload, + native_command=native_command, + native_error=native_error, + ) + return _with_verification(result, descriptor, options.verify and not dry_run) + + +def _native_result( + action: str, + options: McpInstallOptions, + descriptor: McpServerDescriptor, + command: list[str], + *, + verification: dict[str, Any] | None, +) -> McpInstallResult: + return McpInstallResult( + action=action, + client=options.client, + scope=options.scope, + server_name=descriptor.name, + method="native_cli", + path=None, + command=command, + descriptor=descriptor.as_dict(), + entry=descriptor.stdio_entry(), + verification=verification, + ) + + +def _with_verification( + result: McpInstallResult, + descriptor: McpServerDescriptor, + enabled: bool, +) -> McpInstallResult: + if not enabled: + return result + verification = verify_mcp_install(descriptor, client=result.client, server_name=result.server_name) + return McpInstallResult( + action=result.action, + client=result.client, + scope=result.scope, + server_name=result.server_name, + method=result.method, + path=result.path, + command=result.command, + descriptor=result.descriptor, + entry=result.entry, + patch=result.patch, + payload=result.payload, + verification=verification, + error=result.error, + native_command=result.native_command, + native_error=result.native_error, + ) + + +def verify_mcp_install( + descriptor: McpServerDescriptor, + *, + client: str, + server_name: str, + timeout: int = 10, +) -> dict[str, Any]: + stdio = _verify_stdio(descriptor, timeout=timeout) + visibility = _verify_client_visibility(client, server_name, timeout=timeout) + return { + "ok": bool(stdio.get("ok")) and bool(visibility.get("ok", True)), + "stdio": stdio, + "client_visibility": visibility, + } + + +def _verify_stdio(descriptor: McpServerDescriptor, *, timeout: int) -> dict[str, Any]: + command = [descriptor.command, *descriptor.args] + payload = b"".join( + _frame_json_rpc(method, params, request_id=index) + for index, (method, params) in enumerate( + ( + ("initialize", {"protocolVersion": LATEST_PROTOCOL_VERSION}), + ("tools/list", {}), + ("tools/call", {"name": "graph_health", "arguments": {}}), + ("tools/call", {"name": "graph_search", "arguments": {"query": descriptor.name, "limit": 1}}), + ("tools/call", {"name": "graph_query", "arguments": {"statement": "MATCH (n) DELETE n"}}), + ), + start=1, + ) + ) + try: + process = subprocess.Popen(command, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + stdout, stderr = process.communicate(payload, timeout=timeout) + except subprocess.TimeoutExpired: + process.kill() # type: ignore[possibly-unbound] + return {"ok": False, "command": command, "error": f"stdio smoke timed out after {timeout}s"} + except OSError as exc: + return {"ok": False, "command": command, "error": str(exc)} + responses = _parse_stdio_frames(stdout) + if process.returncode != 0: + return { + "ok": False, + "command": command, + "returncode": process.returncode, + "stderr": stderr.decode("utf-8", errors="replace"), + "responses": responses, + } + checks = _stdio_checks(responses) + return { + "ok": all(checks.values()), + "command": command, + "checks": checks, + "responses": responses, + "stderr": stderr.decode("utf-8", errors="replace"), + } + + +def _verify_client_visibility(client: str, server_name: str, *, timeout: int) -> dict[str, Any]: + command = _client_strategy(client).visibility_command() + if command is None: + return {"ok": True, "skipped": True, "reason": f"{client} has no CLI visibility check"} + executable = command[0] + if shutil.which(executable) is None: + return {"ok": True, "skipped": True, "reason": f"{executable} executable not found"} + completed = subprocess.run(command, capture_output=True, text=True, check=False, timeout=timeout) + output = f"{completed.stdout}\n{completed.stderr}" + return { + "ok": completed.returncode == 0 and server_name in output, + "command": command, + "returncode": completed.returncode, + "found": server_name in output, + "stdout": completed.stdout, + "stderr": completed.stderr, + } + + +def _stdio_checks(responses: list[dict[str, Any]]) -> dict[str, bool]: + by_id = {response.get("id"): response for response in responses} + initialized = by_id.get(1, {}).get("result", {}).get("protocolVersion") == LATEST_PROTOCOL_VERSION + tools = by_id.get(2, {}).get("result", {}).get("tools", []) + listed = {"graph_health", "graph_search"}.issubset({tool.get("name") for tool in tools}) + health = by_id.get(3, {}).get("result", {}).get("structuredContent", {}).get("ok") is True + search_no_rpc_error = "error" not in by_id.get(4, {}) + tool_error = by_id.get(5, {}).get("result", {}).get("isError") is True + return { + "initialize": initialized, + "tools_list": listed, + "graph_health": health, + "graph_search": search_no_rpc_error, + "tool_error_result": tool_error, + } + + +def _parse_stdio_frames(data: bytes) -> list[dict[str, Any]]: + messages: list[dict[str, Any]] = [] + cursor = 0 + while cursor < len(data): + header_end = data.find(b"\r\n\r\n", cursor) + delimiter_length = 4 + if header_end == -1: + header_end = data.find(b"\n\n", cursor) + delimiter_length = 2 + if header_end == -1: + break + header = data[cursor:header_end].decode("ascii", errors="replace") + length = None + for line in header.splitlines(): + if line.lower().startswith("content-length:"): + length = int(line.split(":", 1)[1].strip()) + break + if length is None: + break + body_start = header_end + delimiter_length + body_end = body_start + length + messages.append(json.loads(data[body_start:body_end].decode("utf-8"))) + cursor = body_end + return messages + + +def _frame_json_rpc(method: str, params: dict[str, Any], *, request_id: int) -> bytes: + body = json.dumps( + {"jsonrpc": "2.0", "id": request_id, "method": method, "params": params}, + separators=(",", ":"), + sort_keys=True, + ).encode("utf-8") + return f"Content-Length: {len(body)}\r\n\r\n".encode("ascii") + body + + +def _build_descriptor(options: McpInstallOptions) -> McpServerDescriptor: + config_path = Path(options.setup_config_path).expanduser().resolve() + repo_root: Path | None = None + repo_name: str | None = None + if config_path.exists(): + setup_payload = load_setup_config(config_path) + repo_root = Path(str(setup_payload["repo_root"])).expanduser().resolve() + repo_name = str(setup_payload.get("repo_name") or repo_root.name) + elif options.require_setup_config: + raise FileNotFoundError( + f"codebaseGraph setup config does not exist: {config_path}. " + "Run `codebase-graph setup --mcp-client none` first." + ) + server_name = options.server_name or default_server_name(repo_name or config_path.parent.parent.name) + return build_server_descriptor(config_path, repo_root=repo_root, name=server_name) + + +def _validate_options(options: McpInstallOptions) -> None: + if options.client not in {*INSTALL_CLIENTS, "none"}: + supported = ", ".join(sorted([*INSTALL_CLIENTS, "all", "none"])) + raise ValueError(f"Unsupported MCP client: {options.client}. Supported clients: {supported}") + if options.scope not in SCOPES: + raise ValueError(f"Unsupported MCP install scope: {options.scope}. Supported scopes: {', '.join(SCOPES)}") + + +def _client_strategy(client: str) -> InstallClientStrategy: + if client == "none": + return InstallClientStrategy(client_id="none") + return INSTALL_STRATEGIES[client] + + +def _missing_native_error(strategy: InstallClientStrategy) -> str | None: + if strategy.native_executable is None: + return None + return f"{strategy.native_executable} executable not found" + + +def _subprocess_error(completed: subprocess.CompletedProcess[str]) -> str: + output = "\n".join(part for part in (completed.stdout.strip(), completed.stderr.strip()) if part) + if output: + return f"exit {completed.returncode}: {output}" + return f"exit {completed.returncode}" + + +def _safe_name(value: str) -> str: + normalized = re.sub(r"[^A-Za-z0-9_-]+", "_", value.strip()) + return normalized.strip("._-").lower() or "repository" diff --git a/src/codebase_graph/setup/instructions.py b/src/codebase_graph/setup/instructions.py new file mode 100644 index 0000000..98664c7 --- /dev/null +++ b/src/codebase_graph/setup/instructions.py @@ -0,0 +1,105 @@ +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path + +START_MARKER = "" +END_MARKER = "" + + +@dataclass(frozen=True, slots=True) +class InstructionResult: + action: str + path: str | None + + def as_dict(self) -> dict[str, str | None]: + return {"action": self.action, "path": self.path} + + +def upsert_instruction_block( + repo_root: Path, + *, + target: str = "auto", + server_name: str, + config_path: Path, + setup_command: str = "codebase-graph", +) -> InstructionResult: + if target == "skip": + return InstructionResult("skipped", None) + path = _select_instruction_path(repo_root, target) + existing = path.read_text(encoding="utf-8") if path.exists() else "" + block = _instruction_block(server_name=server_name, config_path=config_path, setup_command=setup_command) + next_text, action = _upsert_block(existing, block, created=not path.exists()) + if next_text == existing: + return InstructionResult("unchanged", path.as_posix()) + path.write_text(next_text, encoding="utf-8") + return InstructionResult(action, path.as_posix()) + + +def instruction_target_path(repo_root: Path, *, target: str = "auto") -> Path | None: + if target == "skip": + return None + return _select_instruction_path(repo_root, target) + + +def remove_instruction_block(path: Path) -> bool: + if not path.exists(): + return False + existing = path.read_text(encoding="utf-8") + start = existing.find(START_MARKER) + end = existing.find(END_MARKER) + if start == -1 or end == -1 or end < start: + return False + after_end = end + len(END_MARKER) + next_text = (existing[:start].rstrip() + "\n\n" + existing[after_end:].lstrip()).strip() + "\n" + path.write_text(next_text, encoding="utf-8") + return True + + +def _select_instruction_path(repo_root: Path, target: str) -> Path: + if target == "agents": + return repo_root / "AGENTS.md" + if target == "claude": + return repo_root / "CLAUDE.md" + if target != "auto": + raise ValueError(f"Unsupported instruction target: {target}") + agents = repo_root / "AGENTS.md" + claude = repo_root / "CLAUDE.md" + if agents.exists(): + return agents + if claude.exists(): + return claude + return agents + + +def _instruction_block(*, server_name: str, config_path: Path, setup_command: str) -> str: + return ( + f"{START_MARKER}\n" + "## codebaseGraph workflow\n" + "- Treat the repo-local `.codebaseGraph` graph as the project operating source of truth. It is prohibited to read the code source before you find the target files using the graph.\n" + "- AI agents must use block format for `graph-search` and `graph-context`; reserve `--json` for tests, APIs, or explicit structured-payload debugging.\n" + f"- Use `{setup_command} graph-search --repo-root . --no-refresh --detail slim --context-limit 1 --format block` before answering repo-structure questions or performing coding tasks.\n" + f"- Use `{setup_command} graph-context --repo-root . --profile --no-refresh --detail slim --context-limit 2 --format block` when relationships or nearby evidence matter; useful profiles include `definitions`, `dependencies`, `callgraph`, `docs`, `runtime`, and `change_impact`.\n" + f"- For architecture orientation, run `{setup_command} graph-architecture-queries`, then execute selected read-only statements with `{setup_command} graph-query \"\" --repo-root .`.\n" + f"- Use `{setup_command} graph-schema` or `{setup_command} graph-query-helpers` before writing raw graph queries, add `--pretty` for indented JSON when humans need to inspect output, and keep `graph-query` read-only.\n" + f"- Refresh the graph with `{setup_command} setup --repo-root . --mcp-client none` when files change materially. Setup config: `{config_path.as_posix()}`.\n" + f"{END_MARKER}\n" + ) + + +def _upsert_block(existing: str, block: str, *, created: bool) -> tuple[str, str]: + if not existing.strip(): + return block, "created" + start = existing.find(START_MARKER) + end = existing.find(END_MARKER) + if start != -1 and end != -1 and end > start: + after_end = end + len(END_MARKER) + return _join_sections(existing[:start], block, existing[after_end:]), "updated" + separator = "" if existing.endswith("\n") else "\n" + action = "created" if created else "updated" + return existing.rstrip() + separator + "\n" + block, action + + +def _join_sections(prefix: str, block: str, suffix: str) -> str: + sections = [section.strip() for section in (prefix, block, suffix) if section.strip()] + return "\n\n".join(sections) + "\n" diff --git a/src/codebase_graph/setup/mcp_config.py b/src/codebase_graph/setup/mcp_config.py new file mode 100644 index 0000000..9944ea7 --- /dev/null +++ b/src/codebase_graph/setup/mcp_config.py @@ -0,0 +1,107 @@ +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +from .clients import get_client_adapter +from .descriptor import build_server_descriptor +from .installer import McpInstallOptions, McpInstallResult, install_mcp_server +from .state import MCP_SERVER_NAME + + +@dataclass(frozen=True, slots=True) +class McpConfigResult: + action: str + client: str + path: str | None + server_name: str + entry: dict[str, Any] + descriptor: dict[str, Any] | None = None + method: str | None = None + scope: str | None = None + command: list[str] | None = None + patch: Any = None + payload: Any = None + verification: dict[str, Any] | None = None + native_command: list[str] | None = None + native_error: str | None = None + + def as_dict(self) -> dict[str, Any]: + payload = { + "action": self.action, + "client": self.client, + "path": self.path, + "server_name": self.server_name, + "entry": self.entry, + } + if self.descriptor is not None: + payload["descriptor"] = self.descriptor + if self.method is not None: + payload["method"] = self.method + if self.scope is not None: + payload["scope"] = self.scope + if self.command is not None: + payload["command"] = self.command + if self.patch is not None: + payload["patch"] = self.patch + if self.payload is not None: + payload["payload"] = self.payload + if self.verification is not None: + payload["verification"] = self.verification + if self.native_command is not None: + payload["native_command"] = self.native_command + if self.native_error is not None: + payload["native_error"] = self.native_error + return payload + + @classmethod + def from_install_result(cls, result: McpInstallResult) -> McpConfigResult: + return cls( + action=result.action, + client=result.client, + path=result.path, + server_name=result.server_name, + entry=result.entry, + descriptor=result.descriptor, + method=result.method, + scope=result.scope, + command=result.command, + patch=result.patch, + payload=result.payload, + verification=result.verification, + native_command=result.native_command, + native_error=result.native_error, + ) + + +def configure_mcp_client( + *, + client: str, + config_path: str | Path | None, + setup_config_path: Path, + dry_run: bool = False, + skip: bool = False, +) -> McpConfigResult: + result = install_mcp_server( + McpInstallOptions( + client=client, + scope="project" if client == "claude-project" else "local", + setup_config_path=setup_config_path, + server_name=MCP_SERVER_NAME, + client_config_path=config_path, + dry_run=dry_run, + skip=skip, + require_setup_config=False, + ) + ) + return McpConfigResult.from_install_result(result) + + +def server_entry(setup_config_path: Path) -> dict[str, Any]: + return build_server_descriptor(setup_config_path).stdio_entry() + + +def default_config_path(client: str) -> Path: + descriptor = build_server_descriptor(Path.cwd() / ".codebaseGraph" / "config.json") + return get_client_adapter(client).default_config_path(descriptor) diff --git a/src/codebase_graph/setup/orchestrator.py b/src/codebase_graph/setup/orchestrator.py new file mode 100644 index 0000000..d7c667f --- /dev/null +++ b/src/codebase_graph/setup/orchestrator.py @@ -0,0 +1,244 @@ +from __future__ import annotations + +import shutil +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +from codebase_graph.diagnostics import log_event +from codebase_graph.ingest import GraphMaterializer + +from .instructions import InstructionResult, instruction_target_path, upsert_instruction_block +from .mcp_config import McpConfigResult, configure_mcp_client, server_entry +from .preflight import validate_ladybug_runtime +from .state import MCP_SERVER_NAME, SetupPaths, build_setup_config, derive_setup_paths, write_setup_config + + +class SetupError(RuntimeError): + pass + + +@dataclass(frozen=True, slots=True) +class SetupOptions: + repo_root: str | Path = "." + mcp_client: str = "codex" + mcp_config_path: str | Path | None = None + skip_mcp_config: bool = False + dry_run: bool = False + instructions_target: str = "auto" + mode: str = "changed" + + +@dataclass(frozen=True, slots=True) +class SetupResult: + paths: SetupPaths + config_action: str + materialization: Any + mcp_config: McpConfigResult + instructions: InstructionResult + legacy_state_detected: bool + + def as_dict(self) -> dict[str, Any]: + return { + **self.paths.as_dict(), + "config_action": self.config_action, + "legacy_state_detected": self.legacy_state_detected, + "mcp_config": self.mcp_config.as_dict(), + "instructions": self.instructions.as_dict(), + "materialization": _materialization_payload(self.materialization), + } + + +def run_setup(options: SetupOptions) -> SetupResult: + try: + log_event( + "setup.start", + level="INFO", + repo_root=str(options.repo_root), + mcp_client=options.mcp_client, + dry_run=options.dry_run, + ) + paths = derive_setup_paths(options.repo_root) + validate_ladybug_runtime() + mcp_entry = server_entry(paths.config_path) + config_payload = build_setup_config(paths, mcp_command=[mcp_entry["command"], *mcp_entry["args"]]) + if options.dry_run: + materialization = _dry_run_materialization(paths) + config_action = "dry_run" if _config_would_change(paths.config_path, config_payload) else "unchanged" + target_path = instruction_target_path(paths.repo_root, target=options.instructions_target) + instructions = InstructionResult("dry_run" if target_path is not None else "skipped", _path_text(target_path)) + else: + target_path = instruction_target_path(paths.repo_root, target=options.instructions_target) + previous_config = _snapshot_file(paths.config_path) + previous_instructions = _snapshot_file(target_path) + state_dir_existed = paths.state_dir.exists() + materializer = GraphMaterializer( + paths.repo_root, + db_path=paths.db_path, + manifest_path=paths.manifest_path, + include_fts=True, + repository_label=paths.repo_name, + ) + try: + config_action = write_setup_config(paths.config_path, config_payload) + instructions = upsert_instruction_block( + paths.repo_root, + target=options.instructions_target, + server_name=MCP_SERVER_NAME, + config_path=paths.config_path, + setup_command=mcp_entry["command"], + ) + materialization = materializer.materialize(mode=options.mode) # type: ignore[arg-type] + mcp_result = configure_mcp_client( + client=options.mcp_client, + config_path=options.mcp_config_path, + setup_config_path=paths.config_path, + dry_run=False, + skip=options.skip_mcp_config, + ) + except Exception: + _restore_file(paths.config_path, previous_config) + _restore_file(target_path, previous_instructions) + if not state_dir_existed: + shutil.rmtree(paths.state_dir, ignore_errors=True) + raise + finally: + materializer.close() + if options.dry_run and not options.skip_mcp_config: + mcp_result = configure_mcp_client( + client=options.mcp_client, + config_path=options.mcp_config_path, + setup_config_path=paths.config_path, + dry_run=True, + skip=False, + ) + elif options.dry_run: + mcp_result = configure_mcp_client( + client=options.mcp_client, + config_path=options.mcp_config_path, + setup_config_path=paths.config_path, + dry_run=True, + skip=True, + ) + except Exception as exc: + log_event( + "setup.failed", + level="ERROR", + repo_root=str(options.repo_root), + error_type=exc.__class__.__name__, + message=str(exc), + ) + if isinstance(exc, SetupError): + raise + raise SetupError(str(exc)) from exc + log_event( + "setup.completed", + level="INFO", + repo_root=paths.repo_root.as_posix(), + config_action=config_action, + rebuilt=getattr(materialization, "rebuilt"), + deleted=getattr(materialization, "deleted"), + mcp_action=mcp_result.action, + ) + return SetupResult( + paths=paths, + config_action=config_action, + materialization=materialization, + mcp_config=mcp_result, + instructions=instructions, + legacy_state_detected=(paths.repo_root / ".codebase_graph").exists(), + ) + + +def _materialization_payload(result: Any) -> dict[str, Any]: + as_dict = getattr(result, "as_dict", None) + if callable(as_dict): + return as_dict() + raise TypeError(f"Unsupported materialization result: {type(result).__name__}") + + +def _dry_run_materialization(paths: SetupPaths) -> Any: + materializer = GraphMaterializer( + paths.repo_root, + db_path=paths.db_path, + manifest_path=paths.manifest_path, + include_fts=True, + repository_label=paths.repo_name, + ) + try: + snapshots, diagnostics = materializer._scan_source_files() + finally: + materializer.close() + skipped_paths = tuple(sorted(path for path, snapshot in snapshots.items() if snapshot.language is None)) + return _DryRunMaterialization( + scanned=len(snapshots), + skipped=len(skipped_paths), + diagnostics=tuple(diagnostics), + manifest_path=paths.manifest_path.as_posix(), + skipped_paths=skipped_paths, + ) + + +@dataclass(frozen=True, slots=True) +class _DryRunMaterialization: + scanned: int + skipped: int + diagnostics: tuple[str, ...] + manifest_path: str + skipped_paths: tuple[str, ...] + mode: str = "dry_run" + rebuilt: int = 0 + deleted: int = 0 + rebuilt_paths: tuple[str, ...] = () + deleted_paths: tuple[str, ...] = () + graph_summary: dict[str, Any] = field(default_factory=dict) + + def as_dict(self) -> dict[str, Any]: + return { + "mode": self.mode, + "scanned": self.scanned, + "rebuilt": self.rebuilt, + "skipped": self.skipped, + "deleted": self.deleted, + "diagnostics": list(self.diagnostics), + "manifest_path": self.manifest_path, + "rebuilt_paths": list(self.rebuilt_paths), + "skipped_paths": list(self.skipped_paths), + "deleted_paths": list(self.deleted_paths), + "graph_summary": dict(self.graph_summary), + } + + +def _config_would_change(path: Path, payload: dict[str, Any]) -> bool: + if not path.exists(): + return True + try: + import json + + with path.open("r", encoding="utf-8") as handle: + return json.load(handle) != payload + except Exception: + return True + + +def _path_text(path: Path | None) -> str | None: + return path.as_posix() if path is not None else None + + +def _snapshot_file(path: Path | None) -> str | None: + if path is None or not path.exists(): + return None + return path.read_text(encoding="utf-8") + + +def _restore_file(path: Path | None, previous: str | None) -> None: + if path is None: + return + if previous is None: + try: + path.unlink() + except FileNotFoundError: + return + return + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(previous, encoding="utf-8") diff --git a/src/codebase_graph/setup/preflight.py b/src/codebase_graph/setup/preflight.py new file mode 100644 index 0000000..0a3aca5 --- /dev/null +++ b/src/codebase_graph/setup/preflight.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +import tempfile +from pathlib import Path + +from codebase_graph.db import LadybugUnavailableError, create_ladybug_database + + +def validate_ladybug_runtime() -> None: + """Fail before setup creates repo state if LadyBugDB cannot create a graph DB.""" + try: + import real_ladybug # noqa: F401 + except ImportError as exc: + raise LadybugUnavailableError( + "LadyBugDB is required for codebaseGraph setup. Install a package build that includes `real_ladybug`." + ) from exc + + with tempfile.TemporaryDirectory(prefix="codebase-graph-preflight-") as temp_dir: + db_path = Path(temp_dir) / "preflight.ldb" + store = create_ladybug_database(db_path, include_fts=False) + store.close() diff --git a/src/codebase_graph/setup/state.py b/src/codebase_graph/setup/state.py new file mode 100644 index 0000000..d1baa08 --- /dev/null +++ b/src/codebase_graph/setup/state.py @@ -0,0 +1,114 @@ +from __future__ import annotations + +import json +import os +from importlib.metadata import PackageNotFoundError, version +from pathlib import Path +from typing import Any + +from codebase_graph import paths as graph_paths +from codebase_graph.ontology import ONTOLOGY_VERSION + +CONFIG_NAME = graph_paths.CONFIG_NAME +DEFAULT_STATE_DIR = graph_paths.DEFAULT_STATE_DIR +MANIFEST_NAME = graph_paths.MANIFEST_NAME +MCP_SERVER_NAME = graph_paths.MCP_SERVER_NAME +GraphStatePaths = graph_paths.GraphStatePaths +derive_graph_state_paths = graph_paths.derive_graph_state_paths +SetupPaths = graph_paths.GraphStatePaths + + +def derive_setup_paths(repo_root: str | Path) -> SetupPaths: + paths = derive_graph_state_paths(repo_root) + if not paths.repo_root.exists(): + raise FileNotFoundError(f"Repository root does not exist: {paths.repo_root}") + if not paths.repo_root.is_dir(): + raise NotADirectoryError(f"Repository root is not a directory: {paths.repo_root}") + if DEFAULT_STATE_DIR in paths.repo_root.parts: + raise ValueError( + f"Repository root may not be inside a {DEFAULT_STATE_DIR} state directory: {paths.repo_root}" + ) + return paths + + +def build_setup_config(paths: SetupPaths, *, mcp_command: list[str]) -> dict[str, Any]: + return { + "schema_version": 1, + "repo_root": paths.repo_root.as_posix(), + "repo_name": paths.repo_name, + "state_dir": paths.state_dir.as_posix(), + "database_path": paths.db_path.as_posix(), + "manifest_path": paths.manifest_path.as_posix(), + "ontology_version": ONTOLOGY_VERSION, + "package_version": _package_version(), + "mcp": { + "server_name": MCP_SERVER_NAME, + "command": list(mcp_command), + }, + } + + +def load_setup_config(path: str | Path) -> dict[str, Any]: + config_path = Path(path).expanduser().resolve() + with config_path.open("r", encoding="utf-8") as handle: + payload = json.load(handle) + _validate_setup_config(payload, config_path) + return payload + + +def write_setup_config(path: Path, payload: dict[str, Any]) -> str: + previous = _read_json_if_exists(path) + action = "created" + if previous == payload: + return "unchanged" + if previous is not None: + action = "updated" + path.parent.mkdir(parents=True, exist_ok=True) + tmp_path = path.with_suffix(path.suffix + ".tmp") + with tmp_path.open("w", encoding="utf-8") as handle: + json.dump(payload, handle, indent=2, sort_keys=True) + handle.write("\n") + os.replace(tmp_path, path) + return action + + +def _package_version() -> str: + try: + return version("codebase-graph") + except PackageNotFoundError: + return "0.1.0" + + +def _read_json_if_exists(path: Path) -> dict[str, Any] | None: + if not path.exists(): + return None + with path.open("r", encoding="utf-8") as handle: + return json.load(handle) + + +def _validate_setup_config(payload: dict[str, Any], path: Path) -> None: + required = ("repo_root", "repo_name", "state_dir", "database_path", "manifest_path") + missing = [key for key in required if not payload.get(key)] + if missing: + joined = ", ".join(missing) + raise ValueError(f"Invalid codebaseGraph setup config at {path}: missing {joined}") + + repo_root = Path(str(payload["repo_root"])).expanduser().resolve() + repo_name = str(payload["repo_name"]) + state_dir = Path(str(payload["state_dir"])).expanduser().resolve() + database_path = Path(str(payload["database_path"])).expanduser().resolve() + manifest_path = Path(str(payload["manifest_path"])).expanduser().resolve() + expected_state_dir = repo_root / DEFAULT_STATE_DIR + expected_database_path = state_dir / f"{repo_name}_graph.ldb" + expected_manifest_path = state_dir / MANIFEST_NAME + + if DEFAULT_STATE_DIR in repo_root.parts: + raise ValueError(f"Invalid codebaseGraph setup config at {path}: repo_root may not be inside {DEFAULT_STATE_DIR}") + if state_dir != expected_state_dir: + raise ValueError(f"Invalid codebaseGraph setup config at {path}: state_dir must be {expected_state_dir}") + if path.parent.resolve() != state_dir: + raise ValueError(f"Invalid codebaseGraph setup config at {path}: config must live under {state_dir}") + if database_path != expected_database_path: + raise ValueError(f"Invalid codebaseGraph setup config at {path}: database_path must be {expected_database_path}") + if manifest_path != expected_manifest_path: + raise ValueError(f"Invalid codebaseGraph setup config at {path}: manifest_path must be {expected_manifest_path}") diff --git a/src/codebase_graph/verification.py b/src/codebase_graph/verification.py deleted file mode 100644 index 11aea62..0000000 --- a/src/codebase_graph/verification.py +++ /dev/null @@ -1,32 +0,0 @@ -from __future__ import annotations - -import re -from typing import Any - -def summarize_verification_run(command: str, output: str, exit_code: int | None = None) -> dict[str, Any]: - status = "passed" if exit_code == 0 else "failed" if exit_code else "unknown" - return { - "command": command, - "status": status, - "exit_code": exit_code, - "summary": _compact_output(output), - "tool": _tool_name_from_command(command), - } - -def redact_verification_text(text: str) -> str: - text = re.sub(r"(?i)(api[_-]?key|token|secret|password)=\S+", r"\1=", text) - return text - -def _compact_output(output: str, limit: int = 1200) -> str: - cleaned = redact_verification_text(output).strip() - if len(cleaned) <= limit: - return cleaned - return f"{cleaned[:limit].rstrip()}..." - -def _tool_name_from_command(command: str) -> str: - parts = command.strip().split() - if not parts: - return "unknown" - if parts[0] in {"python", "python3"} and len(parts) > 2 and parts[1] == "-m": - return parts[2] - return parts[0] diff --git a/tests/conftest.py b/tests/conftest.py deleted file mode 100644 index fdcbc1f..0000000 --- a/tests/conftest.py +++ /dev/null @@ -1,9 +0,0 @@ -from __future__ import annotations - -import sys -from pathlib import Path - -ROOT = Path(__file__).resolve().parents[1] -SRC = ROOT / "src" -if str(SRC) not in sys.path: - sys.path.insert(0, str(SRC)) diff --git a/tests/fixtures/search_service_graph_search.json b/tests/fixtures/search_service_graph_search.json new file mode 100644 index 0000000..4c0b601 --- /dev/null +++ b/tests/fixtures/search_service_graph_search.json @@ -0,0 +1,117 @@ +{ + "budget": 600, + "limit": 3, + "profile": "brief", + "query": "SearchService", + "results": [ + { + "context": [ + { + "direction": "outgoing", + "label": "SearchService scope", + "path": "src/codebase_graph/retrieval/search.py", + "relation": "Contains", + "span": { + "line_end": 173, + "line_start": 117 + }, + "summary": "Scope for SearchService", + "type": "Scope" + }, + { + "direction": "outgoing", + "label": "__init__", + "path": "src/codebase_graph/retrieval/search.py", + "relation": "Contains", + "span": { + "line_end": 121, + "line_start": 118 + }, + "type": "Method" + } + ], + "id": "Class:943d6556d328f1c7ca67", + "label": "SearchService", + "path": "src/codebase_graph/retrieval/search.py", + "rank_score": 1.351608, + "span": { + "line_end": 173, + "line_start": 117 + }, + "type": "Class" + }, + { + "context": [ + { + "direction": "outgoing", + "label": "__init__ scope", + "path": "src/codebase_graph/retrieval/search.py", + "relation": "Contains", + "span": { + "line_end": 121, + "line_start": 118 + }, + "summary": "Scope for __init__", + "type": "Scope" + }, + { + "direction": "outgoing", + "label": "self.store", + "path": "src/codebase_graph/retrieval/search.py", + "relation": "Contains", + "span": { + "line_end": 119, + "line_start": 119 + }, + "summary": "Stores the graph backend for later search calls.", + "type": "InstanceAttribute" + } + ], + "id": "Method:3c775c9656a4d6b85843", + "label": "__init__", + "path": "src/codebase_graph/retrieval/search.py", + "rank_score": 1.047561, + "span": { + "line_end": 121, + "line_start": 118 + }, + "type": "Method" + }, + { + "context": [ + { + "direction": "outgoing", + "label": "search scope", + "path": "src/codebase_graph/retrieval/search.py", + "relation": "Contains", + "span": { + "line_end": 149, + "line_start": 123 + }, + "summary": "Scope for search", + "type": "Scope" + }, + { + "direction": "outgoing", + "label": "list[SearchHit]", + "path": "src/codebase_graph/retrieval/search.py", + "relation": "Contains", + "span": { + "line_end": 132, + "line_start": 132 + }, + "type": "TypeAnnotation" + } + ], + "id": "Method:9a6ff8b159b17320c004", + "label": "search", + "path": "src/codebase_graph/retrieval/search.py", + "rank_score": 1.047561, + "span": { + "line_end": 149, + "line_start": 123 + }, + "type": "Method" + } + ] +} diff --git a/tests/test_architecture_queries.py b/tests/test_architecture_queries.py new file mode 100644 index 0000000..c72bc03 --- /dev/null +++ b/tests/test_architecture_queries.py @@ -0,0 +1,93 @@ +from __future__ import annotations + +import re + +import pytest + +from codebase_graph.reasoning import ( + ARCHITECTURE_QUERY_GROUPS, + ARCHITECTURE_QUERY_ORDER, + architecture_query_catalog, +) + + +def test_architecture_query_catalog_serializes_in_stable_workflow_order() -> None: + payload = architecture_query_catalog() + + assert payload["workflow"] == "coding_task_architecture_discovery" + assert payload["execution_tool"] == "graph_query" + assert payload["recommended_order"] == list(ARCHITECTURE_QUERY_ORDER) + assert [group["name"] for group in payload["groups"]] == list(ARCHITECTURE_QUERY_ORDER) + + +def test_architecture_query_catalog_groups_expected_queries() -> None: + assert set(ARCHITECTURE_QUERY_GROUPS) == set(ARCHITECTURE_QUERY_ORDER) + assert _query_names("overview") == { + "graph_coverage", + "source_unit_inventory", + "package_directory_shape", + } + assert _query_names("public_surface") == { + "public_surface_candidates", + "entrypoint_runtime_surface", + } + assert _query_names("dependency_topology") == { + "external_dependency_map", + "module_import_coupling", + } + assert _query_names("execution_flow") == { + "high_fan_in_definitions", + "high_fan_out_callers", + "callable_neighborhood", + } + assert _query_names("runtime_data_security") == { + "data_query_touchpoints", + "secret_configuration_touchpoints", + } + assert _query_names("documentation_context") == { + "documentation_to_code_links", + "evidence_for_symbol", + } + assert _query_names("graph_quality_gaps") == {"unresolved_reference_risk"} + + +def test_architecture_query_names_are_unique_and_count_matches_catalog_contract() -> None: + names = [ + query.name + for group_name in ARCHITECTURE_QUERY_ORDER + for query in ARCHITECTURE_QUERY_GROUPS[group_name].queries + ] + + assert len(names) == 15 + assert len(names) == len(set(names)) + + +def test_architecture_query_catalog_filters_by_group() -> None: + payload = architecture_query_catalog("execution_flow") + + assert payload["recommended_order"] == list(ARCHITECTURE_QUERY_ORDER) + assert [group["name"] for group in payload["groups"]] == ["execution_flow"] + + +def test_architecture_query_catalog_rejects_unknown_group() -> None: + with pytest.raises(ValueError, match="Valid groups: overview"): + architecture_query_catalog("missing") + + +def test_architecture_queries_are_read_only_and_use_edge_node_traversal() -> None: + forbidden = re.compile( + r"\b(CREATE|MERGE|DELETE|SET|DROP|LOAD|COPY|INSERT|ALTER|REMOVE|RENAME|DETACH|INSTALL)\b", + re.IGNORECASE, + ) + direct_relation = re.compile(r"-\[:(?!FROM_|TO_)([A-Za-z][A-Za-z0-9_]*)\]->") + + for group_name in ARCHITECTURE_QUERY_ORDER: + for query in ARCHITECTURE_QUERY_GROUPS[group_name].queries: + assert query.statement.lstrip().upper().startswith("MATCH "), query.name + assert ";" not in query.statement, query.name + assert not forbidden.search(query.statement), query.name + assert not direct_relation.search(query.statement), query.name + + +def _query_names(group_name: str) -> set[str]: + return {query.name for query in ARCHITECTURE_QUERY_GROUPS[group_name].queries} diff --git a/tests/test_cli.py b/tests/test_cli.py deleted file mode 100644 index ecc9726..0000000 --- a/tests/test_cli.py +++ /dev/null @@ -1,20 +0,0 @@ -from __future__ import annotations - -import json -from pathlib import Path - -from codebase_graph.cli import main - -FIXTURE = Path(__file__).parent / "fixtures" / "sample_project" - -def test_cli_status_schema_and_search(tmp_path: Path, capsys) -> None: - state_dir = tmp_path / "graph" - assert main(["--source-root", str(FIXTURE), "--state-dir", str(state_dir), "status"]) == 0 - status = json.loads(capsys.readouterr().out) - assert status["stale"] is True - assert main(["--source-root", str(FIXTURE), "--state-dir", str(state_dir), "schema"]) == 0 - schema = json.loads(capsys.readouterr().out) - assert schema["ontology"] == "codebase_graph_v1" - assert main(["--source-root", str(FIXTURE), "--state-dir", str(state_dir), "search", "SampleService"]) == 0 - search = json.loads(capsys.readouterr().out) - assert search["count"] >= 1 diff --git a/tests/test_diagnostics.py b/tests/test_diagnostics.py new file mode 100644 index 0000000..7673762 --- /dev/null +++ b/tests/test_diagnostics.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +import json + +from codebase_graph.diagnostics import LOG_LEVEL_ENV, log_event + + +def test_log_event_emits_json_to_stderr_when_level_allows( + monkeypatch, + capsys, +) -> None: + monkeypatch.setenv(LOG_LEVEL_ENV, "INFO") + + log_event("sample.event", level="INFO", count=2, payload={"ok": True}) + + captured = capsys.readouterr() + assert captured.out == "" + event = json.loads(captured.err) + assert event["event"] == "sample.event" + assert event["level"] == "INFO" + assert event["count"] == 2 + assert event["payload"] == {"ok": True} + assert event["timestamp"] + + +def test_log_event_respects_configured_level(monkeypatch, capsys) -> None: + monkeypatch.setenv(LOG_LEVEL_ENV, "ERROR") + + log_event("sample.event", level="INFO") + + assert capsys.readouterr().err == "" diff --git a/tests/test_graph_builder.py b/tests/test_graph_builder.py new file mode 100644 index 0000000..a474f6b --- /dev/null +++ b/tests/test_graph_builder.py @@ -0,0 +1,233 @@ +from __future__ import annotations + +from codebase_graph.extract import CaptureRecord, CaptureTableRegistry, GraphBuilder, ParseBundle +from codebase_graph.ontology import PARSER_NODE_MAPPINGS + + +def test_graph_builder_maps_python_ast_shaped_tree_to_ontology() -> None: + parse_tree = { + "type": "Module", + "body": [ + { + "type": "ImportFrom", + "module": "dataclasses", + "names": [{"type": "alias", "name": "dataclass"}], + "line_start": 1, + }, + { + "type": "ClassDef", + "name": "WikiConfig", + "line_start": 5, + "decorator_list": [ + { + "type": "Call", + "func": {"type": "Name", "id": "dataclass"}, + "keywords": [{"type": "keyword", "arg": "slots", "value": {"type": "Constant", "value": True}}], + } + ], + "body": [ + { + "type": "AnnAssign", + "target": {"type": "Name", "id": "vault_dir"}, + "annotation": {"type": "Name", "id": "Path"}, + "value": { + "type": "Call", + "func": {"type": "Name", "id": "Path"}, + "args": [{"type": "Constant", "value": "wiki"}], + }, + "line_start": 7, + }, + { + "type": "FunctionDef", + "name": "raw_dir", + "args": {"type": "arguments", "args": [{"type": "arg", "arg": "self"}]}, + "returns": {"type": "Name", "id": "Path"}, + "decorator_list": [{"type": "Name", "id": "property"}], + "body": [ + { + "type": "Return", + "value": { + "type": "Attribute", + "value": {"type": "Name", "id": "self"}, + "attr": "vault_dir", + }, + } + ], + "line_start": 10, + }, + ], + }, + { + "type": "Assign", + "targets": [{"type": "Name", "id": "PAGE_KINDS"}], + "value": { + "type": "Tuple", + "elts": [ + {"type": "Constant", "value": "sources"}, + {"type": "Constant", "value": "entities"}, + ], + }, + "line_start": 15, + }, + ], + } + + graph = GraphBuilder(default_language="python").build(parse_tree, source_path="wiki_config.py") + + labels_by_type = { + table: {node.label for node in graph.nodes_by_type(table)} + for table in ("ImportDeclaration", "Class", "Method", "ClassAttribute", "Constant", "Decorator") + } + assert "dataclasses.dataclass" in labels_by_type["ImportDeclaration"] + assert "WikiConfig" in labels_by_type["Class"] + assert "raw_dir" in labels_by_type["Method"] + assert "vault_dir" in labels_by_type["ClassAttribute"] + assert "PAGE_KINDS" in labels_by_type["Constant"] + assert {"dataclass", "property"} <= labels_by_type["Decorator"] + assert graph.edges_by_type("DerivedFrom") + assert graph.edges_by_type("HasReturnType") + assert graph.edges_by_type("HasTypeAnnotation") + + +def test_graph_builder_uses_capture_names_as_primary_semantic_signal() -> None: + bundle = ParseBundle( + language="python", + path="api.py", + captures=( + CaptureRecord( + "definition.function", + {"type": "identifier", "text": "handler", "start_byte": 10, "end_byte": 17}, + ), + CaptureRecord( + "reference.call", + {"type": "identifier", "text": "json_response", "start_byte": 24, "end_byte": 37}, + ), + CaptureRecord( + "doc.string", + {"type": "string", "text": "Handle the API request.", "start_byte": 40, "end_byte": 65}, + ), + ), + ) + + result = GraphBuilder(repository_label="sample").build_file_graph(bundle) + graph = result.graph + + assert {node.label for node in graph.nodes_by_type("Function")} == {"handler"} + assert {node.label for node in graph.nodes_by_type("CallExpression")} == {"json_response"} + assert {node.label for node in graph.nodes_by_type("DocumentationChunk")} == {"Handle the API request."} + assert not result.diagnostics + assert not result.unresolved + + +def test_graph_builder_accepts_registered_capture_table_mapping() -> None: + registry = CaptureTableRegistry() + registry.register_exact("custom.component", "Component") + bundle = ParseBundle( + language="custom", + path="component.custom", + captures=( + CaptureRecord( + "custom.component", + {"type": "identifier", "name": "RegisteredWidget", "line_start": 1}, + ), + ), + ) + + result = GraphBuilder(capture_table_registry=registry).build_file_graph(bundle) + + assert {node.label for node in result.graph.nodes_by_type("Component")} == {"RegisteredWidget"} + + +def test_graph_builder_routes_local_imports_through_containing_scope() -> None: + parse_tree = { + "type": "Module", + "body": [ + { + "type": "ClassDef", + "name": "Loader", + "body": [ + { + "type": "FunctionDef", + "name": "load", + "body": [ + { + "type": "ImportFrom", + "module": "pathlib", + "names": [{"type": "alias", "name": "Path"}], + "line_start": 3, + }, + ], + }, + ], + }, + ], + } + + graph = GraphBuilder(default_language="python").build(parse_tree, source_path="loader.py") + + import_edge = next( + edge + for edge in graph.edges_by_type("Imports") + if graph.nodes[edge.target_id].label == "pathlib.Path" + ) + assert graph.nodes[import_edge.source_id].table == "Scope" + assert graph.nodes[import_edge.target_id].table == "ImportDeclaration" + + +def test_graph_builder_emits_relation_families_advertised_by_parser_mappings() -> None: + parse_tree = { + "type": "Module", + "body": [ + { + "type": "ImportFrom", + "module": "fastapi", + "names": [{"type": "alias", "name": "APIRouter"}], + }, + {"type": "FunctionDef", "name": "helper", "body": []}, + {"type": "FunctionDef", "name": "auth_required", "body": []}, + { + "type": "FunctionDef", + "name": "list_users", + "args": { + "type": "arguments", + "args": [ + { + "type": "arg", + "arg": "user_id", + "annotation": {"type": "Name", "id": "int"}, + } + ], + }, + "returns": {"type": "Name", "id": "Response"}, + "decorator_list": [{"type": "Name", "id": "auth_required"}], + "body": [ + {"type": "call", "capture_name": "route", "text": "/users", "handler": "list_users"}, + {"type": "Call", "func": {"type": "Name", "id": "helper"}}, + { + "type": "string", + "capture_name": "query.sql", + "text": "SELECT * FROM users", + "table": "users", + }, + {"type": "Name", "capture_name": "secret.env", "id": "DATABASE_URL"}, + { + "type": "Assign", + "targets": [{"type": "Name", "id": "CACHE"}], + "value": {"type": "Call", "func": {"type": "Name", "id": "helper"}}, + }, + {"type": "Name", "capture_name": "reference.identifier", "id": "helper"}, + {"type": "raise_statement", "capture_name": "raises", "name": "ValueError"}, + {"type": "except_clause", "capture_name": "handles", "name": "ValueError"}, + {"type": "docstring", "capture_name": "doc.string", "text": "List users."}, + ], + }, + {"type": "component_declaration", "capture_name": "component", "name": "UserService"}, + {"type": "export_statement", "name": "list_users"}, + ], + } + + graph = GraphBuilder(default_language="python").build(parse_tree, source_path="api.py") + + mapped_relations = {relation for mapping in PARSER_NODE_MAPPINGS for relation in mapping.relation_types} + emitted_relations = set(graph.summary()["edge_counts"]) + assert mapped_relations <= emitted_relations diff --git a/tests/test_graph_core.py b/tests/test_graph_core.py deleted file mode 100644 index 98c0f3d..0000000 --- a/tests/test_graph_core.py +++ /dev/null @@ -1,51 +0,0 @@ -from __future__ import annotations - -from pathlib import Path - -from codebase_graph import CodebaseGraph -from codebase_graph.code_map import MAX_INDEXED_FILE_BYTES - -FIXTURE = Path(__file__).parent / "fixtures" / "sample_project" - -def test_import_and_status_defaults(tmp_path: Path) -> None: - graph = CodebaseGraph(source_root=FIXTURE, state_dir=tmp_path / "graph") - status = graph.status().as_dict() - assert status["database_exists"] is False - assert status["stale"] is True - assert status["database_path"].endswith("knowledge_graph.json") - -def test_materialize_schema_search_context_and_cypher(tmp_path: Path) -> None: - graph = CodebaseGraph(source_root=FIXTURE, state_dir=tmp_path / "graph") - materialized = graph.materialize() - assert materialized["summary"]["ontology"] == "codebase_graph_v1" - assert materialized["summary"]["node_count"] > 0 - schema = graph.schema() - assert schema["ontology"] == "codebase_graph_v1" - search = graph.search("SampleService", limit=5) - assert any(item["label"] == "SampleService" for item in search["items"]) - context = graph.context("SampleService", budget=500) - assert "SampleService" in context["context"] - cypher = graph.cypher("MATCH (n:PythonClass) RETURN n.id, n.label, n.qualified_name LIMIT 5") - assert cypher["count"] == 1 - assert cypher["rows"][0]["n.label"] == "SampleService" - -def test_status_and_materialize_ignore_non_indexable_files(tmp_path: Path) -> None: - repo = tmp_path / "repo" - repo.mkdir() - (repo / "pyproject.toml").write_text('[project]\nname = "filter-fixture"\nversion = "0.1.0"\n') - (repo / "included.py").write_text("class IncludedClass:\n pass\n") - (repo / "README.MD").write_text("# Fixture\n\nUppercase markdown suffix should stay indexable.\n") - - (repo / "build").mkdir() - (repo / "build" / "generated.py").write_text("class IgnoredBuildClass:\n pass\n") - (repo / "oversized.py").write_text("#" * (MAX_INDEXED_FILE_BYTES + 1)) - (repo / "payload.json").write_text('{"ignored": true}\n') - - graph = CodebaseGraph(source_root=repo, state_dir=tmp_path / "graph") - assert graph.status().source_file_count == 3 - - graph.materialize() - assert any(item["label"] == "IncludedClass" for item in graph.search("IncludedClass", limit=5)["items"]) - assert graph.search("IgnoredBuildClass", limit=5)["count"] == 0 - assert graph.search("oversized", limit=5)["count"] == 0 - assert graph.search("payload", limit=5)["count"] == 0 diff --git a/tests/test_graph_output_block_format.py b/tests/test_graph_output_block_format.py new file mode 100644 index 0000000..6a8f38a --- /dev/null +++ b/tests/test_graph_output_block_format.py @@ -0,0 +1,159 @@ +from __future__ import annotations + +import importlib.util +import json +import sys +from pathlib import Path +from typing import Any + +from codebase_graph.retrieval.block_format import ( + ONTOLOGY_TERMS, + canonicalize_search_payload, + parse_search_block, + serialize_agent_search_block, + serialize_context_block, + serialize_parseable_search_block, + serialize_search_block, +) + + +FIXTURE_PATH = Path(__file__).parent / "fixtures" / "search_service_graph_search.json" +SCRIPT_PATH = Path(__file__).parents[1] / "scripts" / "compare_graph_output_tokens.py" + + +class _WhitespaceEncoding: + def encode(self, text: str) -> list[str]: + return text.split() + + +def test_token_counting_uses_encoded_text_length() -> None: + module = _load_benchmark_script() + + assert module.count_tokens("Class SearchService Method", _WhitespaceEncoding()) == 3 + + +def test_raw_vs_block_comparison_preserves_search_service_fixture() -> None: + payload = json.loads(FIXTURE_PATH.read_text(encoding="utf-8")) + block = serialize_parseable_search_block(payload) + + assert parse_search_block(block) == canonicalize_search_payload(payload) + assert serialize_search_block(payload) == block + + +def test_l_same_is_only_emitted_for_matching_context_spans() -> None: + payload = { + "query": "SearchService", + "profile": "brief", + "limit": 1, + "budget": 600, + "results": [ + { + "id": "Class:1", + "type": "Class", + "label": "SearchService", + "path": "src/codebase_graph/retrieval/search.py", + "span": {"line_start": 10, "line_end": 20}, + "rank_score": 1.0, + "context": [ + { + "direction": "outgoing", + "relation": "Contains", + "type": "Scope", + "label": "SearchService scope", + "path": "src/codebase_graph/retrieval/search.py", + "span": {"line_start": 10, "line_end": 20}, + "summary": "Scope for SearchService", + }, + { + "direction": "outgoing", + "relation": "Contains", + "type": "Method", + "label": "__init__", + "path": "src/codebase_graph/retrieval/search.py", + "span": {"line_start": 11, "line_end": 12}, + }, + ], + } + ], + } + + block = serialize_search_block(payload) + + assert block.count("span=L=same") == 1 + assert "Method label=__init__ span=L11-L12" in block + + +def test_non_boilerplate_context_summaries_are_preserved() -> None: + payload = json.loads(FIXTURE_PATH.read_text(encoding="utf-8")) + block = serialize_search_block(payload) + + assert 'summary="Stores the graph backend for later search calls."' in block + assert parse_search_block(block) == canonicalize_search_payload(payload) + + +def test_block_format_keeps_ontology_terms_literal() -> None: + payload = json.loads(FIXTURE_PATH.read_text(encoding="utf-8")) + block = serialize_search_block(payload) + + for term in ONTOLOGY_TERMS: + assert term in block + assert "rank_score=" in block + assert "label=" in block + assert "span=" in block + assert "path " in block + + +def test_agent_block_reduces_display_only_boilerplate() -> None: + payload = json.loads(FIXTURE_PATH.read_text(encoding="utf-8")) + block = serialize_agent_search_block(payload) + + assert "q SearchService\n" in block + assert "budget" not in block + assert "limit" not in block + assert "profile" not in block + assert "id=Class:943d6556d328f1c7ca67" in block + assert "id=Method:3c775c9656a4d6b85843" in block + assert "rank_score=1.35" in block + assert "rank_score=1.351608" not in block + assert "SearchService scope" not in block + assert "search scope" not in block + assert "outgoing Contains Method __init__" not in block + assert "TypeAnnotation" not in block + assert ( + 'outgoing Contains InstanceAttribute self.store L119-L119 ' + 'summary="Stores the graph backend for later search calls."' + ) in block + + +def test_context_block_serializes_explicit_node_context() -> None: + block = serialize_context_block( + { + "node_id": "Class:943d6556d328f1c7ca67", + "node_type": "Class", + "profile": "definitions", + "context": [ + { + "direction": "outgoing", + "relation": "Contains", + "type": "Method", + "label": "search", + "path": "src/codebase_graph/retrieval/search.py", + "span": {"line_start": 123, "line_end": 149}, + } + ], + } + ) + + assert block.startswith("context Class id=Class:943d6556d328f1c7ca67 profile=definitions") + assert "file path src/codebase_graph/retrieval/search.py" in block + assert "outgoing Contains Method search L123-L149" in block + + +def _load_benchmark_script() -> Any: + spec = importlib.util.spec_from_file_location("compare_graph_output_tokens", SCRIPT_PATH) + assert spec is not None + module = importlib.util.module_from_spec(spec) + assert spec.loader is not None + sys.modules[spec.name] = module + spec.loader.exec_module(module) + return module diff --git a/tests/test_materializer.py b/tests/test_materializer.py new file mode 100644 index 0000000..376a9a3 --- /dev/null +++ b/tests/test_materializer.py @@ -0,0 +1,534 @@ +from __future__ import annotations + +import json +import os +import shutil +from pathlib import Path + +import pytest + +import codebase_graph.ingest.materializer as materializer_module +from codebase_graph.db import LadybugCodeGraphStore +from codebase_graph.ingest import ( + GraphMaterializer, + MarkdownDocumentParser, + ManifestEntry, + MaterializationManifest, + ParserRegistry, + SourceSnapshot, + TreeSitterPythonParser, +) +from codebase_graph.ontology import ONTOLOGY_NAME + + +def test_manifest_diff_tracks_added_modified_unchanged_and_deleted(tmp_path: Path) -> None: + manifest = MaterializationManifest( + files={ + "same.py": _entry("same.py", "same"), + "changed.py": _entry("changed.py", "old"), + "deleted.py": _entry("deleted.py", "old"), + } + ) + current = { + "same.py": _snapshot(tmp_path, "same.py", "same"), + "changed.py": _snapshot(tmp_path, "changed.py", "new"), + "added.py": _snapshot(tmp_path, "added.py", "new"), + } + + diff = manifest.diff(current) + + assert diff.added == ("added.py",) + assert diff.modified == ("changed.py",) + assert diff.unchanged == ("same.py",) + assert diff.deleted == ("deleted.py",) + assert diff.rebuild_paths == ("added.py", "changed.py") + assert not diff.force_rebuild + + +def test_manifest_diff_forces_rebuild_on_contract_mismatch(tmp_path: Path) -> None: + manifest = MaterializationManifest(schema_version=0, ontology=ONTOLOGY_NAME, parser_version="old", files={}) + current = {"service.py": _snapshot(tmp_path, "service.py", "hash")} + + diff = manifest.diff(current) + + assert diff.force_rebuild + assert diff.added == ("service.py",) + assert diff.rebuild_paths == ("service.py",) + + +def test_tree_sitter_python_parser_maps_sample_fixture_to_graph_tree() -> None: + pytest.importorskip("tree_sitter") + pytest.importorskip("tree_sitter_python") + fixture = Path("tests/fixtures/sample_project/sample_project/service.py") + parser = TreeSitterPythonParser() + + bundle = parser.parse_file( + fixture, + relative_path="sample_project/service.py", + source_root=fixture.parents[1], + repository_label="sample", + content_hash="hash", + ) + + assert bundle.language == "python" + assert bundle.tree["type"] == "module" + assert any(child["type"] == "class_definition" and child["name"] == "SampleService" for child in bundle.tree["children"]) + assert any(child["type"] == "function_definition" and child["name"] == "helper" for child in bundle.tree["children"]) + + +def test_materializer_defaults_to_canonical_codebasegraph_state_paths(tmp_path: Path) -> None: + source_root = tmp_path / "sample repo" + source_root.mkdir() + + materializer = GraphMaterializer(source_root, store=object()) + + assert materializer.state_dir == source_root / ".codebaseGraph" + assert materializer.db_path == source_root / ".codebaseGraph" / "sample_repo_graph.ldb" + assert materializer.manifest_path == source_root / ".codebaseGraph" / "manifest.json" + + +def test_scan_source_files_uses_parser_registry_for_suffix_mapping(tmp_path: Path) -> None: + registry = ParserRegistry() + registry.register( + "notes", + suffixes=(".notes",), + parser_factory=MarkdownDocumentParser, + parser_version="notes-v1", + ) + source_root = tmp_path / "project" + source_root.mkdir() + (source_root / "handoff.notes").write_text("# Handoff\n", encoding="utf-8") + + materializer = GraphMaterializer(source_root, store=object(), parser_registry=registry) + snapshots, diagnostics = materializer._scan_source_files() + + assert snapshots["handoff.notes"].language == "notes" + assert not diagnostics + + +def test_scan_source_files_prunes_excluded_directories(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + source_dir = tmp_path / "src" + source_dir.mkdir() + (source_dir / "app.py").write_text("VALUE = 1\n", encoding="utf-8") + observed_dirnames: list[tuple[str, ...]] = [] + + def fake_walk(root: Path) -> object: + dirnames = [".mypy_cache", ".tox", ".venv", "node_modules", "src", "package.egg-info", "vendor"] + yield Path(root).as_posix(), dirnames, [] + observed_dirnames.append(tuple(dirnames)) + for dirname in dirnames: + yield (Path(root) / dirname).as_posix(), [], ["app.py"] + + monkeypatch.setattr("codebase_graph.ingest.materializer.os.walk", fake_walk) + materializer = GraphMaterializer(tmp_path, db_path=":memory:", manifest_path=tmp_path / "manifest.json", store=object()) + + snapshots, diagnostics = materializer._scan_source_files() + + assert observed_dirnames == [("src",)] + assert tuple(snapshots) == ("src/app.py",) + assert not diagnostics + + +def test_scan_source_files_does_not_hash_unsupported_files( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + source_root = tmp_path / "project" + source_root.mkdir() + unsupported = source_root / "archive.bin" + supported = source_root / "service.py" + unsupported.write_bytes(b"\0" * 1024) + supported.write_text("VALUE = 1\n", encoding="utf-8") + hashed_paths: list[str] = [] + real_file_hash = materializer_module._file_hash + + def recording_file_hash(path: Path) -> str: + hashed_paths.append(path.name) + return real_file_hash(path) + + monkeypatch.setattr(materializer_module, "_file_hash", recording_file_hash) + materializer = GraphMaterializer(source_root, db_path=":memory:", manifest_path=tmp_path / "manifest.json", store=object()) + + snapshots, diagnostics = materializer._scan_source_files() + + assert hashed_paths == ["service.py"] + assert snapshots["archive.bin"].language is None + assert snapshots["archive.bin"].content_hash == "" + assert snapshots["service.py"].content_hash + assert diagnostics == ["Skipped unsupported file: archive.bin"] + + +def test_full_materialization_writes_python_graph_to_ladybug(tmp_path: Path) -> None: + pytest.importorskip("tree_sitter") + pytest.importorskip("tree_sitter_python") + pytest.importorskip("real_ladybug") + source_root = _copy_fixture(tmp_path) + ignored_dir = source_root / ".venv" + ignored_dir.mkdir() + (ignored_dir / "ignored.py").write_text("def ignored() -> None:\n pass\n", encoding="utf-8") + + materializer = GraphMaterializer( + source_root, + db_path=":memory:", + manifest_path=tmp_path / "manifest.json", + include_fts=False, + ) + result = materializer.materialize(mode="full") + + assert result.rebuilt == 4 + assert result.deleted == 0 + assert result.graph_summary["partition_count"] == 4 + assert _labels(materializer, "File") == { + "__init__.py", + "README.md", + "cli.py", + "service.py", + } + assert "README.md" in _labels(materializer, "DocumentationSource") + assert _labels(materializer, "DocumentationChunk") + assert "SampleService" in _labels(materializer, "Class") + assert "run" in _labels(materializer, "Method") + assert {"helper", "main"} <= _labels(materializer, "Function") + assert "ignored" not in _labels(materializer, "Function") + + +def test_full_materialization_handles_local_imports_inside_methods(tmp_path: Path) -> None: + pytest.importorskip("tree_sitter") + pytest.importorskip("tree_sitter_python") + pytest.importorskip("real_ladybug") + source_root = tmp_path / "local_import_project" + source_root.mkdir() + (source_root / "service.py").write_text( + "class Loader:\n" + " def load(self) -> object:\n" + " from pathlib import Path\n" + " return Path('.')\n", + encoding="utf-8", + ) + + materializer = GraphMaterializer( + source_root, + db_path=":memory:", + manifest_path=tmp_path / "manifest.json", + include_fts=False, + ) + result = materializer.materialize(mode="full") + + assert result.rebuilt == 1 + assert "pathlib.Path" in _labels(materializer, "ImportDeclaration") + assert "Path" in _labels(materializer, "CallExpression") + metadata = materializer.store.execute( + "MATCH (n:`ImportDeclaration` {label: 'pathlib.Path'}) RETURN n.metadata" + ).get_all() + assert '"imported_name":"pathlib.Path"' in metadata[0][0] + + +def test_changed_materialization_only_rebuilds_changed_files(tmp_path: Path) -> None: + pytest.importorskip("tree_sitter") + pytest.importorskip("tree_sitter_python") + pytest.importorskip("real_ladybug") + source_root = _copy_fixture(tmp_path) + manifest_path = tmp_path / "manifest.json" + materializer = GraphMaterializer(source_root, db_path=":memory:", manifest_path=manifest_path, include_fts=False) + + first = materializer.materialize(mode="changed") + second = materializer.materialize(mode="changed") + service_path = source_root / "sample_project" / "service.py" + service_path.write_text(service_path.read_text(encoding="utf-8") + "\n\ndef added() -> str:\n return 'added'\n", encoding="utf-8") + third = materializer.materialize(mode="changed") + (source_root / "sample_project" / "cli.py").unlink() + fourth = materializer.materialize(mode="changed") + + assert first.rebuilt == 4 + assert second.rebuilt == 0 + assert third.rebuilt == 1 + assert third.rebuilt_paths == ("sample_project/service.py",) + assert "added" in _labels(materializer, "Function") + assert fourth.rebuilt == 0 + assert fourth.deleted == 1 + assert fourth.deleted_paths == ("sample_project/cli.py",) + assert "cli.py" not in _labels(materializer, "File") + + +def test_changed_ondisk_materialization_rebuilds_atomically_without_inplace_deletes( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + pytest.importorskip("tree_sitter") + pytest.importorskip("tree_sitter_python") + pytest.importorskip("real_ladybug") + source_root = _copy_fixture(tmp_path) + db_path = tmp_path / "graph.lbug" + manifest_path = tmp_path / "manifest.json" + + GraphMaterializer(source_root, db_path=db_path, manifest_path=manifest_path, include_fts=False).materialize(mode="full") + service_path = source_root / "sample_project" / "service.py" + service_path.write_text( + service_path.read_text(encoding="utf-8") + "\n\ndef changed_mode_added() -> str:\n return 'added'\n", + encoding="utf-8", + ) + + def fail_clear_graph(self: LadybugCodeGraphStore) -> None: + raise AssertionError("on-disk changed mode must not clear the target DB in place") + + def fail_delete_partition(self: LadybugCodeGraphStore, *args: object, **kwargs: object) -> None: + raise AssertionError("on-disk changed mode must not delete target partitions in place") + + monkeypatch.setattr(LadybugCodeGraphStore, "clear_graph", fail_clear_graph) + monkeypatch.setattr(LadybugCodeGraphStore, "delete_partition", fail_delete_partition) + + result = GraphMaterializer(source_root, db_path=db_path, manifest_path=manifest_path, include_fts=False).materialize( + mode="changed" + ) + + assert result.mode == "changed" + assert result.rebuilt == 4 + assert "changed_mode_added" in _labels( + GraphMaterializer(source_root, db_path=db_path, manifest_path=manifest_path, include_fts=False), + "Function", + ) + + +def test_changed_ondisk_materialization_noop_does_not_rebuild(tmp_path: Path) -> None: + pytest.importorskip("tree_sitter") + pytest.importorskip("tree_sitter_python") + pytest.importorskip("real_ladybug") + source_root = _copy_fixture(tmp_path) + db_path = tmp_path / "graph.lbug" + manifest_path = tmp_path / "manifest.json" + + GraphMaterializer(source_root, db_path=db_path, manifest_path=manifest_path, include_fts=False).materialize(mode="full") + result = GraphMaterializer(source_root, db_path=db_path, manifest_path=manifest_path, include_fts=False).materialize( + mode="changed" + ) + + assert result.mode == "changed" + assert result.rebuilt == 0 + assert result.deleted == 0 + + +def test_changed_ondisk_materialization_failure_keeps_previous_db_and_manifest( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + pytest.importorskip("tree_sitter") + pytest.importorskip("tree_sitter_python") + pytest.importorskip("real_ladybug") + source_root = tmp_path / "project" + source_root.mkdir() + service_path = source_root / "service.py" + service_path.write_text("def old_name() -> str:\n return 'old'\n", encoding="utf-8") + db_path = tmp_path / "graph.lbug" + manifest_path = tmp_path / "manifest.json" + + GraphMaterializer(source_root, db_path=db_path, manifest_path=manifest_path, include_fts=False).materialize(mode="full") + previous_manifest = manifest_path.read_text(encoding="utf-8") + service_path.write_text("def new_name() -> str:\n return 'new'\n", encoding="utf-8") + real_create_ladybug_database = materializer_module.create_ladybug_database + + def failing_create_ladybug_database(db_path: str | Path, *, include_fts: bool = True) -> LadybugCodeGraphStore: + store = real_create_ladybug_database(db_path, include_fts=include_fts) + + def fail_insert(*args: object, **kwargs: object) -> None: + raise RuntimeError("changed bulk insert failed") + + store.insert_graphs_bulk = fail_insert # type: ignore[method-assign] + return store + + monkeypatch.setattr(materializer_module, "create_ladybug_database", failing_create_ladybug_database) + + with pytest.raises(RuntimeError, match="changed bulk insert failed"): + GraphMaterializer(source_root, db_path=db_path, manifest_path=manifest_path, include_fts=False).materialize( + mode="changed" + ) + + assert manifest_path.read_text(encoding="utf-8") == previous_manifest + reader = GraphMaterializer(source_root, db_path=db_path, manifest_path=manifest_path, include_fts=False) + assert "old_name" in _labels(reader, "Function") + assert "new_name" not in _labels(reader, "Function") + assert not _marker_path(manifest_path).exists() + + +def test_full_ondisk_materialization_failure_keeps_previous_db_and_manifest( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + pytest.importorskip("tree_sitter") + pytest.importorskip("tree_sitter_python") + pytest.importorskip("real_ladybug") + source_root = tmp_path / "project" + source_root.mkdir() + service_path = source_root / "service.py" + service_path.write_text("def old_name() -> str:\n return 'old'\n", encoding="utf-8") + db_path = tmp_path / "graph.lbug" + manifest_path = tmp_path / "manifest.json" + + GraphMaterializer(source_root, db_path=db_path, manifest_path=manifest_path, include_fts=False).materialize(mode="full") + previous_manifest = manifest_path.read_text(encoding="utf-8") + service_path.write_text("def new_name() -> str:\n return 'new'\n", encoding="utf-8") + real_create_ladybug_database = materializer_module.create_ladybug_database + + def failing_create_ladybug_database(db_path: str | Path, *, include_fts: bool = True) -> LadybugCodeGraphStore: + store = real_create_ladybug_database(db_path, include_fts=include_fts) + + def fail_insert(*args: object, **kwargs: object) -> None: + raise RuntimeError("bulk insert failed") + + store.insert_graphs_bulk = fail_insert # type: ignore[method-assign] + return store + + monkeypatch.setattr(materializer_module, "create_ladybug_database", failing_create_ladybug_database) + + with pytest.raises(RuntimeError, match="bulk insert failed"): + GraphMaterializer(source_root, db_path=db_path, manifest_path=manifest_path, include_fts=False).materialize(mode="full") + + assert manifest_path.read_text(encoding="utf-8") == previous_manifest + reader = GraphMaterializer(source_root, db_path=db_path, manifest_path=manifest_path, include_fts=False) + assert "old_name" in _labels(reader, "Function") + assert "new_name" not in _labels(reader, "Function") + assert not _marker_path(manifest_path).exists() + + +def test_full_ondisk_materialization_replaces_stale_db_without_clear_graph( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + pytest.importorskip("tree_sitter") + pytest.importorskip("tree_sitter_python") + pytest.importorskip("real_ladybug") + source_root = tmp_path / "project" + source_root.mkdir() + old_path = source_root / "old_module.py" + old_path.write_text("def old_name() -> str:\n return 'old'\n", encoding="utf-8") + db_path = tmp_path / "graph.lbug" + manifest_path = tmp_path / "manifest.json" + + GraphMaterializer(source_root, db_path=db_path, manifest_path=manifest_path, include_fts=False).materialize(mode="full") + old_path.unlink() + (source_root / "new_module.py").write_text("def new_name() -> str:\n return 'new'\n", encoding="utf-8") + + def fail_clear_graph(self: LadybugCodeGraphStore) -> None: + raise AssertionError("full on-disk rebuild must not clear the target DB in place") + + monkeypatch.setattr(LadybugCodeGraphStore, "clear_graph", fail_clear_graph) + result = GraphMaterializer(source_root, db_path=db_path, manifest_path=manifest_path, include_fts=False).materialize(mode="full") + + assert result.rebuilt == 1 + reader = GraphMaterializer(source_root, db_path=db_path, manifest_path=manifest_path, include_fts=False) + assert "new_name" in _labels(reader, "Function") + assert "old_name" not in _labels(reader, "Function") + + +def test_full_ondisk_materialization_replaces_stale_sidecars(tmp_path: Path) -> None: + pytest.importorskip("tree_sitter") + pytest.importorskip("tree_sitter_python") + pytest.importorskip("real_ladybug") + source_root = _copy_fixture(tmp_path) + db_path = tmp_path / "graph.lbug" + manifest_path = tmp_path / "manifest.json" + + GraphMaterializer(source_root, db_path=db_path, manifest_path=manifest_path, include_fts=False).materialize(mode="full") + Path(f"{db_path}.wal").write_text("stale wal from previous database", encoding="utf-8") + Path(f"{db_path}.shadow").write_text("stale shadow from previous database", encoding="utf-8") + + GraphMaterializer(source_root, db_path=db_path, manifest_path=manifest_path, include_fts=False).materialize(mode="full") + + reader = GraphMaterializer(source_root, db_path=db_path, manifest_path=manifest_path, include_fts=False) + assert "SampleService" in _labels(reader, "Class") + assert not Path(f"{db_path}.shadow").exists() + + +def test_ondisk_materialization_rejects_concurrent_writer_lock(tmp_path: Path) -> None: + pytest.importorskip("tree_sitter") + pytest.importorskip("tree_sitter_python") + pytest.importorskip("real_ladybug") + source_root = _copy_fixture(tmp_path) + db_path = tmp_path / "graph.lbug" + manifest_path = tmp_path / "manifest.json" + lock_path = Path(f"{db_path}.lock") + lock_path.write_text(json.dumps({"pid": os.getpid()}) + "\n", encoding="utf-8") + + with pytest.raises(RuntimeError, match="materialization is already in progress"): + GraphMaterializer(source_root, db_path=db_path, manifest_path=manifest_path, include_fts=False).materialize( + mode="full" + ) + + assert lock_path.exists() + + +def test_ondisk_materialization_recovers_stale_writer_lock(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + pytest.importorskip("tree_sitter") + pytest.importorskip("tree_sitter_python") + pytest.importorskip("real_ladybug") + source_root = _copy_fixture(tmp_path) + db_path = tmp_path / "graph.lbug" + manifest_path = tmp_path / "manifest.json" + lock_path = Path(f"{db_path}.lock") + lock_path.write_text(json.dumps({"pid": 123456, "db_path": db_path.as_posix()}) + "\n", encoding="utf-8") + monkeypatch.setattr(materializer_module, "_process_is_running", lambda pid: False) + + result = GraphMaterializer(source_root, db_path=db_path, manifest_path=manifest_path, include_fts=False).materialize( + mode="full" + ) + + assert result.rebuilt == 4 + assert not lock_path.exists() + + +def test_pending_rebuild_marker_forces_changed_mode_atomic_rebuild( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + pytest.importorskip("tree_sitter") + pytest.importorskip("tree_sitter_python") + pytest.importorskip("real_ladybug") + source_root = _copy_fixture(tmp_path) + db_path = tmp_path / "graph.lbug" + manifest_path = tmp_path / "manifest.json" + + GraphMaterializer(source_root, db_path=db_path, manifest_path=manifest_path, include_fts=False).materialize(mode="full") + marker_path = _marker_path(manifest_path) + marker_path.write_text("{}\n", encoding="utf-8") + + def fail_clear_graph(self: LadybugCodeGraphStore) -> None: + raise AssertionError("marker recovery must use the atomic rebuild path") + + monkeypatch.setattr(LadybugCodeGraphStore, "clear_graph", fail_clear_graph) + result = GraphMaterializer(source_root, db_path=db_path, manifest_path=manifest_path, include_fts=False).materialize(mode="changed") + + assert result.mode == "changed" + assert result.rebuilt == 4 + assert not marker_path.exists() + reader = GraphMaterializer(source_root, db_path=db_path, manifest_path=manifest_path, include_fts=False) + assert "SampleService" in _labels(reader, "Class") + + +def _entry(path: str, content_hash: str) -> ManifestEntry: + return ManifestEntry( + path=path, + content_hash=content_hash, + language="python", + partition_id=path, + node_ids=(), + edge_ids=(), + ) + + +def _snapshot(tmp_path: Path, path: str, content_hash: str) -> SourceSnapshot: + return SourceSnapshot(path=path, absolute_path=tmp_path / path, content_hash=content_hash, language="python") + + +def _copy_fixture(tmp_path: Path) -> Path: + source = Path("tests/fixtures/sample_project") + target = tmp_path / "sample_project" + shutil.copytree(source, target) + return target + + +def _labels(materializer: GraphMaterializer, table: str) -> set[str]: + result = materializer.store.execute(f"MATCH (n:`{table}`) RETURN n.label") + return {row[0] for row in result.get_all()} + + +def _marker_path(manifest_path: Path) -> Path: + return manifest_path.with_suffix(manifest_path.suffix + ".rebuild-pending") diff --git a/tests/test_mcp_entrypoint.py b/tests/test_mcp_entrypoint.py new file mode 100644 index 0000000..dbc876b --- /dev/null +++ b/tests/test_mcp_entrypoint.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +import subprocess +import sys + + +def test_mcp_entrypoint_help_imports_without_setup_cycle() -> None: + completed = subprocess.run( + [ + sys.executable, + "-c", + "from codebase_graph.mcp.server import main; raise SystemExit(main())", + "--help", + ], + capture_output=True, + text=True, + check=False, + ) + + assert completed.returncode == 0, completed.stderr + assert "usage: codebase-graph-mcp" in completed.stdout diff --git a/tests/test_mcp_installer.py b/tests/test_mcp_installer.py new file mode 100644 index 0000000..7809422 --- /dev/null +++ b/tests/test_mcp_installer.py @@ -0,0 +1,309 @@ +from __future__ import annotations + +import json +import subprocess +from pathlib import Path + +import pytest + +from codebase_graph.cli import main as cli_main +from codebase_graph.setup.clients import get_client_adapter +from codebase_graph.setup.descriptor import build_server_descriptor +from codebase_graph.setup.installer import ( + INSTALL_CLIENTS, + INSTALL_STRATEGIES, + McpInstallOptions, + default_server_name, + install_mcp_clients, + install_mcp_server, +) +from codebase_graph.setup.state import build_setup_config, derive_setup_paths, write_setup_config + + +def test_default_server_name_is_namespace_safe() -> None: + assert default_server_name("My Service") == "codebase_graph_my_service" + + +def test_install_strategy_registry_covers_advertised_clients() -> None: + assert set(INSTALL_CLIENTS) == set(INSTALL_STRATEGIES) + for client, strategy in INSTALL_STRATEGIES.items(): + assert strategy.adapter_client_id("local") + if strategy.native_command_builder is not None: + assert strategy.native_executable + if client == "claude": + assert strategy.adapter_client_id("project") == "claude-project" + if client == "claude-project": + assert strategy.install_scope("local") == "project" + + +def test_codex_native_command_generation_uses_repo_server_name( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + config_path = _write_setup_config(tmp_path / "fresh repo") + monkeypatch.setattr("codebase_graph.setup.installer.shutil.which", lambda name: f"/usr/bin/{name}") + + result = install_mcp_server(McpInstallOptions(setup_config_path=config_path, dry_run=True)) + + assert result.action == "dry_run" + assert result.method == "native_cli" + assert result.server_name == "codebase_graph_fresh_repo" + assert result.command[:4] == ["codex", "mcp", "add", "codebase_graph_fresh_repo"] + assert result.command[4] == "--" + + +def test_claude_native_command_includes_transport_and_scope( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + config_path = _write_setup_config(tmp_path / "fresh_repo") + monkeypatch.setattr("codebase_graph.setup.installer.shutil.which", lambda name: f"/usr/bin/{name}") + + result = install_mcp_server( + McpInstallOptions(client="claude", scope="user", setup_config_path=config_path, dry_run=True) + ) + + assert result.command[:8] == [ + "claude", + "mcp", + "add", + "--transport", + "stdio", + "--scope", + "user", + "codebase_graph_fresh_repo", + ] + + +def test_claude_project_native_command_forces_project_scope( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + config_path = _write_setup_config(tmp_path / "fresh_repo") + monkeypatch.setattr("codebase_graph.setup.installer.shutil.which", lambda name: f"/usr/bin/{name}") + + result = install_mcp_server( + McpInstallOptions(client="claude-project", scope="user", setup_config_path=config_path, dry_run=True) + ) + + assert result.command[6:8] == ["project", "codebase_graph_fresh_repo"] + + +def test_openclaw_native_command_emits_server_json( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + config_path = _write_setup_config(tmp_path / "fresh_repo") + monkeypatch.setattr("codebase_graph.setup.installer.shutil.which", lambda name: f"/usr/bin/{name}") + + result = install_mcp_server( + McpInstallOptions(client="openclaw", setup_config_path=config_path, dry_run=True) + ) + entry = json.loads(result.command[-1]) + + assert result.command[:4] == ["openclaw", "mcp", "set", "codebase_graph_fresh_repo"] + assert entry["type"] == "stdio" + assert entry["args"][:2] == ["mcp", "serve"] + + +def test_missing_native_cli_falls_back_to_file_adapter( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + config_path = _write_setup_config(tmp_path / "fresh_repo") + codex_home = tmp_path / "codex-home" + monkeypatch.setenv("CODEX_HOME", codex_home.as_posix()) + monkeypatch.setattr("codebase_graph.setup.installer.shutil.which", lambda name: None) + + result = install_mcp_server(McpInstallOptions(client="codex", setup_config_path=config_path)) + + assert result.action == "created" + assert result.method == "file_adapter" + assert result.path == (codex_home / "config.toml").as_posix() + assert "executable not found" in result.native_error + assert (codex_home / "config.toml").exists() + + +def test_native_cli_failure_falls_back_to_adapter( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + config_path = _write_setup_config(tmp_path / "fresh_repo") + codex_home = tmp_path / "codex-home" + monkeypatch.setenv("CODEX_HOME", codex_home.as_posix()) + monkeypatch.setattr("codebase_graph.setup.installer.shutil.which", lambda name: f"/usr/bin/{name}") + + def fail_run(command: list[str], **kwargs: object) -> subprocess.CompletedProcess[str]: + return subprocess.CompletedProcess(command, 2, stdout="", stderr="native failed") + + monkeypatch.setattr("codebase_graph.setup.installer.subprocess.run", fail_run) + + result = install_mcp_server(McpInstallOptions(client="codex", setup_config_path=config_path)) + + assert result.action == "created" + assert result.method == "file_adapter" + assert result.native_command[:4] == ["codex", "mcp", "add", "codebase_graph_fresh_repo"] + assert result.native_error == "exit 2: native failed" + assert (codex_home / "config.toml").exists() + + +def test_dry_run_never_writes_files_or_calls_native_cli( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + config_path = _write_setup_config(tmp_path / "fresh_repo") + monkeypatch.setenv("HOME", tmp_path.as_posix()) + + def fail_run(*args: object, **kwargs: object) -> subprocess.CompletedProcess[str]: + raise AssertionError("dry-run must not call subprocess.run") + + monkeypatch.setattr("codebase_graph.setup.installer.subprocess.run", fail_run) + + result = install_mcp_server( + McpInstallOptions(client="generic", setup_config_path=config_path, dry_run=True) + ) + + assert result.action == "dry_run" + assert result.method == "file_adapter" + assert not (tmp_path / ".config" / "mcp" / "mcp.json").exists() + + +def test_setup_compatibility_uses_snake_case_server_name_and_file_adapter( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, + capsys: pytest.CaptureFixture[str], +) -> None: + pytest.importorskip("tree_sitter") + pytest.importorskip("tree_sitter_python") + pytest.importorskip("real_ladybug") + repo_root = _fresh_repo(tmp_path) + mcp_config_path = tmp_path / "codex.toml" + monkeypatch.setattr("codebase_graph.setup.installer.shutil.which", lambda name: f"/usr/bin/{name}") + + exit_code = cli_main( + [ + "setup", + "--repo-root", + repo_root.as_posix(), + "--mcp-client", + "codex", + "--mcp-config-path", + mcp_config_path.as_posix(), + "--instructions-target", + "skip", + ] + ) + output = json.loads(capsys.readouterr().out) + + assert exit_code == 0 + assert output["mcp_config"]["server_name"] == "codebase_graph" + assert output["mcp_config"]["method"] == "file_adapter" + assert output["mcp_config"]["path"] == mcp_config_path.as_posix() + + +def test_hermes_default_path_is_documented_home_config( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setattr(Path, "home", classmethod(lambda cls: tmp_path)) + descriptor = build_server_descriptor(tmp_path / ".codebaseGraph" / "config.json") + + assert get_client_adapter("hermes").default_config_path(descriptor) == tmp_path / ".hermes" / "config.yaml" + + +def test_all_client_install_reports_partial_failure( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + import codebase_graph.setup.installer as installer + + config_path = _write_setup_config(tmp_path / "fresh_repo") + monkeypatch.setenv("CODEX_HOME", (tmp_path / "codex-home").as_posix()) + monkeypatch.setattr(installer, "INSTALL_CLIENTS", ("codex", "generic")) + monkeypatch.setattr(installer.shutil, "which", lambda name: None) + original_get_adapter = installer.get_client_adapter + + def get_adapter(client: str) -> object: + if client == "generic": + raise ValueError("adapter unavailable") + return original_get_adapter(client) + + monkeypatch.setattr(installer, "get_client_adapter", get_adapter) + + results = install_mcp_clients(McpInstallOptions(client="all", setup_config_path=config_path)) + + assert [result.client for result in results] == ["codex", "generic"] + assert results[0].action == "created" + assert results[1].action == "failed" + assert results[1].error == "adapter unavailable" + + +def test_mcp_install_cli_dry_run_json( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, + capsys: pytest.CaptureFixture[str], +) -> None: + config_path = _write_setup_config(tmp_path / "fresh_repo") + monkeypatch.setattr("codebase_graph.setup.installer.shutil.which", lambda name: f"/usr/bin/{name}") + + exit_code = cli_main( + ["mcp", "install", "--client", "codex", "--config-path", config_path.as_posix(), "--dry-run", "--json"] + ) + output = json.loads(capsys.readouterr().out) + + assert exit_code == 0 + assert output["action"] == "dry_run" + assert output["method"] == "native_cli" + assert output["server_name"] == default_server_name("fresh_repo") + + +def test_mcp_install_cli_accepts_client_config_path( + tmp_path: Path, + capsys: pytest.CaptureFixture[str], +) -> None: + config_path = _write_setup_config(tmp_path / "fresh_repo") + client_config_path = tmp_path / "client" / "mcp.json" + + exit_code = cli_main( + [ + "mcp", + "install", + "--client", + "generic", + "--config-path", + config_path.as_posix(), + "--client-config-path", + client_config_path.as_posix(), + "--json", + ] + ) + output = json.loads(capsys.readouterr().out) + + assert exit_code == 0 + assert output["path"] == client_config_path.as_posix() + assert client_config_path.exists() + + +def _write_setup_config(repo_root: Path) -> Path: + repo_root.mkdir(parents=True) + paths = derive_setup_paths(repo_root) + mcp_command = ["codebase-graph", "mcp", "serve", "--config", paths.config_path.as_posix()] + payload = build_setup_config(paths, mcp_command=mcp_command) + write_setup_config(paths.config_path, payload) + return paths.config_path + + +def _fresh_repo(tmp_path: Path) -> Path: + repo_root = tmp_path / "fresh_repo" + package = repo_root / "sample_project" + package.mkdir(parents=True) + (package / "__init__.py").write_text("", encoding="utf-8") + (package / "service.py").write_text( + "class SampleService:\n" + " def run(self) -> str:\n" + " return helper()\n\n" + "def helper() -> str:\n" + " return 'ok'\n", + encoding="utf-8", + ) + return repo_root diff --git a/tests/test_mcp_portability.py b/tests/test_mcp_portability.py new file mode 100644 index 0000000..f13eecd --- /dev/null +++ b/tests/test_mcp_portability.py @@ -0,0 +1,499 @@ +from __future__ import annotations + +import json +import subprocess +import sys +import threading +import urllib.error +import urllib.request +from pathlib import Path +from typing import Any, BinaryIO + +try: + import tomllib +except ImportError: # pragma: no cover - Python 3.10 compatibility + import tomli as tomllib + +import pytest + +from codebase_graph.mcp.protocol import LATEST_PROTOCOL_VERSION, SUPPORTED_PROTOCOL_VERSIONS, McpGraphServer +from codebase_graph.mcp.runtime import GraphRuntimeConfig +from codebase_graph.mcp.transports.http import build_http_server +from codebase_graph.setup import SetupOptions, run_setup +from codebase_graph.setup.clients import supported_client_ids +from codebase_graph.setup.descriptor import build_server_descriptor +from codebase_graph.setup.mcp_config import configure_mcp_client + + +def test_initialize_negotiates_supported_and_fallback_protocol_versions(tmp_path: Path) -> None: + db_path = tmp_path / "graph.ldb" + db_path.write_text("", encoding="utf-8") + server = McpGraphServer(GraphRuntimeConfig(repo_root=tmp_path, db_path=db_path)) + + older = server.handle_json_rpc( + {"jsonrpc": "2.0", "id": 1, "method": "initialize", "params": {"protocolVersion": "2024-11-05"}} + ) + fallback = server.handle_json_rpc( + {"jsonrpc": "2.0", "id": 2, "method": "initialize", "params": {"protocolVersion": "1.0.0"}} + ) + + assert older is not None + assert fallback is not None + assert older["result"]["protocolVersion"] == "2024-11-05" + assert fallback["result"]["protocolVersion"] == LATEST_PROTOCOL_VERSION + assert "2025-11-25" in SUPPORTED_PROTOCOL_VERSIONS + + +def test_architecture_query_catalog_is_available_over_mcp_without_opening_graph(tmp_path: Path) -> None: + db_path = tmp_path / "graph.ldb" + db_path.write_text("", encoding="utf-8") + server = McpGraphServer(GraphRuntimeConfig(repo_root=tmp_path, db_path=db_path)) + + server.handle_json_rpc({"jsonrpc": "2.0", "id": 0, "method": "initialize", "params": {}}) + listed = server.handle_json_rpc({"jsonrpc": "2.0", "id": 1, "method": "tools/list", "params": {}}) + all_queries = server.handle_json_rpc( + { + "jsonrpc": "2.0", + "id": 2, + "method": "tools/call", + "params": {"name": "graph_architecture_queries", "arguments": {}}, + } + ) + filtered = server.handle_json_rpc( + { + "jsonrpc": "2.0", + "id": 3, + "method": "tools/call", + "params": {"name": "graph_architecture_queries", "arguments": {"group": "overview"}}, + } + ) + invalid = server.handle_json_rpc( + { + "jsonrpc": "2.0", + "id": 4, + "method": "tools/call", + "params": {"name": "graph_architecture_queries", "arguments": {"group": "missing"}}, + } + ) + + assert listed is not None + assert all_queries is not None + assert filtered is not None + assert invalid is not None + assert any(tool["name"] == "graph_architecture_queries" for tool in listed["result"]["tools"]) + assert all_queries["result"]["structuredContent"]["workflow"] == "coding_task_architecture_discovery" + assert all_queries["result"]["structuredContent"]["execution_tool"] == "graph_query" + assert [group["name"] for group in filtered["result"]["structuredContent"]["groups"]] == ["overview"] + assert invalid["result"]["isError"] is True + assert invalid["result"]["structuredContent"]["error"]["type"] == "ValueError" + + +def test_mcp_rejects_tools_before_initialize(tmp_path: Path) -> None: + db_path = tmp_path / "graph.ldb" + db_path.write_text("", encoding="utf-8") + server = McpGraphServer(GraphRuntimeConfig(repo_root=tmp_path, db_path=db_path)) + + listed = server.handle_json_rpc({"jsonrpc": "2.0", "id": 1, "method": "tools/list", "params": {}}) + called = server.handle_json_rpc( + {"jsonrpc": "2.0", "id": 2, "method": "tools/call", "params": {"name": "graph_health", "arguments": {}}} + ) + + assert listed is not None + assert called is not None + assert listed["error"]["code"] == -32002 + assert called["error"]["code"] == -32002 + + +def test_descriptor_prefers_current_environment_script(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + bin_dir = tmp_path / "venv" / "bin" + bin_dir.mkdir(parents=True) + python_path = bin_dir / "python" + script_path = bin_dir / "codebase-graph" + python_path.write_text("", encoding="utf-8") + script_path.write_text("", encoding="utf-8") + script_path.chmod(0o755) + monkeypatch.setattr(sys, "executable", python_path.as_posix()) + monkeypatch.setenv("PATH", "") + + descriptor = build_server_descriptor(tmp_path / ".codebaseGraph" / "config.json") + + assert descriptor.command == script_path.as_posix() + assert descriptor.stdio_entry()["command"] == script_path.as_posix() + assert descriptor.as_dict()["transport"] == "stdio" + + +def test_client_adapters_emit_native_config_shapes(tmp_path: Path) -> None: + setup_config_path = tmp_path / ".codebaseGraph" / "config.json" + setup_config_path.parent.mkdir() + clients = set(supported_client_ids()) - {"none"} + + rendered = { + client: configure_mcp_client( + client=client, + config_path=tmp_path / f"{client}.config", + setup_config_path=setup_config_path, + dry_run=True, + ).as_dict() + for client in clients + } + + codex_patch = rendered["codex"]["patch"] + codex_payload = tomllib.loads(codex_patch) + assert codex_payload["mcp_servers"]["codebase_graph"]["command"] + assert codex_payload["mcp_servers"]["codebase_graph"]["startup_timeout_sec"] == 60 + assert "type" not in rendered["claude"]["payload"]["mcpServers"]["codebase_graph"] + assert rendered["claude"]["payload"]["mcpServers"]["codebase_graph"]["command"] + assert rendered["claude-project"]["payload"]["mcpServers"]["codebase_graph"]["type"] == "stdio" + assert rendered["lmstudio"]["payload"]["mcpServers"]["codebase_graph"]["type"] == "stdio" + assert rendered["generic"]["payload"]["mcpServers"]["codebase_graph"]["args"][0:2] == ["mcp", "serve"] + assert rendered["openclaw"]["payload"]["mcp"]["servers"]["codebase_graph"]["type"] == "stdio" + assert "mcp_servers:" in rendered["hermes"]["patch"] + + +def test_unsupported_mcp_client_lists_supported_clients(tmp_path: Path) -> None: + with pytest.raises(ValueError, match="Supported clients:"): + configure_mcp_client( + client="missing", + config_path=tmp_path / "missing.json", + setup_config_path=tmp_path / ".codebaseGraph" / "config.json", + ) + + +def test_stdio_mcp_wire_initialize_list_call_and_tool_error(tmp_path: Path) -> None: + pytest.importorskip("tree_sitter") + pytest.importorskip("tree_sitter_python") + pytest.importorskip("real_ladybug") + repo_root = _fresh_repo(tmp_path) + result = run_setup(SetupOptions(repo_root=repo_root, mcp_client="none", instructions_target="skip")) + setup_payload = json.loads(result.paths.config_path.read_text(encoding="utf-8")) + command = setup_payload["mcp"]["command"] + + proc = subprocess.Popen( + command, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + assert proc.stdin is not None + assert proc.stdout is not None + try: + initialized = _rpc(proc.stdin, proc.stdout, "initialize", {"protocolVersion": "2025-11-25"}) + listed = _rpc(proc.stdin, proc.stdout, "tools/list", {}) + health = _rpc(proc.stdin, proc.stdout, "tools/call", {"name": "graph_health", "arguments": {}}) + search = _rpc( + proc.stdin, + proc.stdout, + "tools/call", + {"name": "graph_search", "arguments": {"query": "SampleService", "limit": 2}}, + ) + block_search = _rpc( + proc.stdin, + proc.stdout, + "tools/call", + {"name": "graph_search", "arguments": {"query": "SampleService", "limit": 2, "output_format": "block"}}, + ) + failure = _rpc( + proc.stdin, + proc.stdout, + "tools/call", + {"name": "graph_query", "arguments": {"statement": "MATCH (n) DELETE n"}}, + ) + finally: + proc.stdin.close() + proc.wait(timeout=10) + + assert initialized["result"]["protocolVersion"] == "2025-11-25" + assert {tool["name"] for tool in listed["result"]["tools"]} >= {"graph_health", "graph_search", "graph_query"} + graph_search_tool = next(tool for tool in listed["result"]["tools"] if tool["name"] == "graph_search") + assert "context_limit" in graph_search_tool["inputSchema"]["properties"] + assert graph_search_tool["inputSchema"]["properties"]["detail"]["enum"] == ["slim", "standard"] + assert graph_search_tool["inputSchema"]["properties"]["output_format"]["enum"] == ["json", "block"] + assert health["result"]["structuredContent"]["ok"] is True + assert search["result"]["structuredContent"]["results"] + assert "\n " not in search["result"]["content"][0]["text"] + assert block_search["result"]["structuredContent"] == search["result"]["structuredContent"] + assert block_search["result"]["content"][0]["text"].startswith("q SampleService\n") + assert "id=Class:" in block_search["result"]["content"][0]["text"] + assert "error" not in failure + assert failure["result"]["isError"] is True + assert failure["result"]["structuredContent"]["error"]["type"] == "ValueError" + + +def test_stdio_mcp_malformed_frame_returns_parse_error(tmp_path: Path) -> None: + pytest.importorskip("tree_sitter") + pytest.importorskip("tree_sitter_python") + pytest.importorskip("real_ladybug") + repo_root = _fresh_repo(tmp_path) + result = run_setup(SetupOptions(repo_root=repo_root, mcp_client="none", instructions_target="skip")) + setup_payload = json.loads(result.paths.config_path.read_text(encoding="utf-8")) + + completed = subprocess.run( + setup_payload["mcp"]["command"], + input=b"Content-Length: 1\r\n\r\n{", + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + check=False, + ) + + responses = _stdio_messages(completed.stdout) + assert completed.returncode == 0 + stderr_events = [json.loads(line) for line in completed.stderr.decode("utf-8").splitlines()] + assert stderr_events[0]["event"] == "mcp.stdio_parse_error" + assert responses[0]["error"]["code"] == -32700 + + +def test_http_mcp_rejects_remote_bind_without_explicit_opt_in(tmp_path: Path) -> None: + with pytest.raises(ValueError, match="localhost"): + build_http_server(repo_root=tmp_path, db_path=tmp_path / "missing.ldb", host="0.0.0.0", port=0) + + +def test_http_mcp_rejects_remote_bind_without_auth_token(tmp_path: Path) -> None: + with pytest.raises(ValueError, match="auth token"): + build_http_server( + repo_root=tmp_path, + db_path=tmp_path / "missing.ldb", + host="0.0.0.0", + port=0, + allow_remote=True, + ) + + +def test_http_mcp_accepts_remote_bind_with_auth_token(tmp_path: Path) -> None: + db_path = tmp_path / "graph.ldb" + db_path.write_text("", encoding="utf-8") + try: + httpd = build_http_server( + repo_root=tmp_path, + db_path=db_path, + host="0.0.0.0", + port=0, + allow_remote=True, + auth_token="secret-token", + ) + except PermissionError as exc: + pytest.skip(f"remote socket bind is unavailable in this environment: {exc}") + + httpd.server_close() + + +def test_http_mcp_transport_handles_initialize_list_and_call(tmp_path: Path) -> None: + pytest.importorskip("tree_sitter") + pytest.importorskip("tree_sitter_python") + pytest.importorskip("real_ladybug") + repo_root = _fresh_repo(tmp_path) + result = run_setup(SetupOptions(repo_root=repo_root, mcp_client="none", instructions_target="skip")) + try: + httpd = build_http_server(config_path=result.paths.config_path, host="127.0.0.1", port=0) + except PermissionError as exc: + pytest.skip(f"local socket bind is unavailable in this environment: {exc}") + thread = threading.Thread(target=httpd.serve_forever, daemon=True) + thread.start() + host, port = httpd.server_address + try: + initialize, session_id = _http_rpc_with_session(host, port, "initialize", {"protocolVersion": "2025-11-25"}) + with pytest.raises(urllib.error.HTTPError) as missing_session: + _http_rpc(host, port, "tools/list", {}) + listed = _http_rpc(host, port, "tools/list", {}, session_id=session_id) + health = _http_rpc( + host, + port, + "tools/call", + {"name": "graph_health", "arguments": {}}, + session_id=session_id, + ) + with pytest.raises(urllib.error.HTTPError) as exc_info: + _http_rpc(host, port, "ping", {}, protocol_version="1900-01-01") + finally: + httpd.shutdown() + httpd.server_close() + thread.join(timeout=10) + + assert initialize["result"]["protocolVersion"] == "2025-11-25" + assert missing_session.value.code == 400 + assert any(tool["name"] == "graph_context" for tool in listed["result"]["tools"]) + assert health["result"]["structuredContent"]["ok"] is True + assert exc_info.value.code == 400 + + +def test_http_mcp_transport_enforces_bearer_token_when_configured(tmp_path: Path) -> None: + pytest.importorskip("tree_sitter") + pytest.importorskip("tree_sitter_python") + pytest.importorskip("real_ladybug") + repo_root = _fresh_repo(tmp_path) + result = run_setup(SetupOptions(repo_root=repo_root, mcp_client="none", instructions_target="skip")) + try: + httpd = build_http_server(config_path=result.paths.config_path, host="127.0.0.1", port=0, auth_token="secret") + except PermissionError as exc: + pytest.skip(f"local socket bind is unavailable in this environment: {exc}") + thread = threading.Thread(target=httpd.serve_forever, daemon=True) + thread.start() + host, port = httpd.server_address + try: + with pytest.raises(urllib.error.HTTPError) as missing_exc: + _http_rpc(host, port, "initialize", {"protocolVersion": "2025-11-25"}, origin=f"http://{host}:{port}") + with pytest.raises(urllib.error.HTTPError) as wrong_exc: + _http_rpc( + host, + port, + "initialize", + {"protocolVersion": "2025-11-25"}, + auth_token="wrong", + origin=f"http://{host}:{port}", + ) + initialized = _http_rpc( + host, + port, + "initialize", + {"protocolVersion": "2025-11-25"}, + auth_token="secret", + origin=f"http://{host}:{port}", + ) + finally: + httpd.shutdown() + httpd.server_close() + thread.join(timeout=10) + + assert missing_exc.value.code == 401 + assert wrong_exc.value.code == 401 + assert initialized["result"]["protocolVersion"] == "2025-11-25" + + +def _rpc(stdin: BinaryIO, stdout: BinaryIO, method: str, params: dict[str, Any]) -> dict[str, Any]: + request_id = _rpc.counter + _rpc.counter += 1 + payload = json.dumps({"jsonrpc": "2.0", "id": request_id, "method": method, "params": params}).encode("utf-8") + stdin.write(f"Content-Length: {len(payload)}\r\n\r\n".encode("ascii") + payload) + stdin.flush() + return _read_stdio_response(stdout) + + +_rpc.counter = 1 # type: ignore[attr-defined] + + +def _read_stdio_response(stdout: BinaryIO) -> dict[str, Any]: + header = stdout.readline() + assert header.lower().startswith(b"content-length:") + length = int(header.split(b":", 1)[1].strip()) + assert stdout.readline() in {b"\r\n", b"\n"} + return json.loads(stdout.read(length).decode("utf-8")) + + +def _stdio_messages(data: bytes) -> list[dict[str, Any]]: + messages: list[dict[str, Any]] = [] + cursor = 0 + while cursor < len(data): + header_end = data.find(b"\r\n\r\n", cursor) + assert header_end != -1 + header = data[cursor:header_end].decode("ascii") + length = None + for line in header.splitlines(): + if line.lower().startswith("content-length:"): + length = int(line.split(":", 1)[1].strip()) + break + assert length is not None + body_start = header_end + 4 + body_end = body_start + length + messages.append(json.loads(data[body_start:body_end].decode("utf-8"))) + cursor = body_end + return messages + + +def _http_rpc( + host: str, + port: int, + method: str, + params: dict[str, Any], + *, + protocol_version: str = "2025-11-25", + auth_token: str | None = None, + origin: str | None = None, + session_id: str | None = None, +) -> dict[str, Any]: + return _http_rpc_with_headers( + host, + port, + method, + params, + protocol_version=protocol_version, + auth_token=auth_token, + origin=origin, + session_id=session_id, + )[0] + + +def _http_rpc_with_session( + host: str, + port: int, + method: str, + params: dict[str, Any], + *, + protocol_version: str = "2025-11-25", + auth_token: str | None = None, + origin: str | None = None, + session_id: str | None = None, +) -> tuple[dict[str, Any], str]: + payload, headers = _http_rpc_with_headers( + host, + port, + method, + params, + protocol_version=protocol_version, + auth_token=auth_token, + origin=origin, + session_id=session_id, + ) + resolved_session_id = headers.get("Mcp-Session-Id") + assert resolved_session_id + return payload, resolved_session_id + + +def _http_rpc_with_headers( + host: str, + port: int, + method: str, + params: dict[str, Any], + *, + protocol_version: str = "2025-11-25", + auth_token: str | None = None, + origin: str | None = None, + session_id: str | None = None, +) -> tuple[dict[str, Any], Any]: + payload = json.dumps({"jsonrpc": "2.0", "id": 1, "method": method, "params": params}).encode("utf-8") + headers = { + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + "MCP-Protocol-Version": protocol_version, + "Origin": origin or f"http://{host}:{port}", + } + if auth_token is not None: + headers["Authorization"] = f"Bearer {auth_token}" + if session_id is not None: + headers["Mcp-Session-Id"] = session_id + request = urllib.request.Request( + f"http://{host}:{port}/mcp", + data=payload, + headers=headers, + method="POST", + ) + with urllib.request.urlopen(request, timeout=10) as response: + return json.loads(response.read().decode("utf-8")), response.headers + + +def _fresh_repo(tmp_path: Path) -> Path: + repo_root = tmp_path / "fresh_repo" + package = repo_root / "sample_project" + package.mkdir(parents=True) + (package / "__init__.py").write_text("", encoding="utf-8") + (package / "service.py").write_text( + "class SampleService:\n" + " def run(self) -> str:\n" + " return helper()\n\n" + "def helper() -> str:\n" + " return 'ok'\n", + encoding="utf-8", + ) + (repo_root / "README.md").write_text( + "# Fresh Repo\n\nThis repository documents the SampleService workflow.\n", + encoding="utf-8", + ) + return repo_root diff --git a/tests/test_ontology.py b/tests/test_ontology.py new file mode 100644 index 0000000..7ea6e6a --- /dev/null +++ b/tests/test_ontology.py @@ -0,0 +1,153 @@ +from __future__ import annotations + +import json +import re + +from codebase_graph.ontology import ( + ONTOLOGY_NAME, + PARSER_NODE_MAPPINGS, + QUERY_HELPERS, + RELATION_TYPES, + get_node_type, + get_relation_type, + node_type_names, + relation_type_names, + schema_payload, +) + + +def test_schema_payload_is_json_serializable() -> None: + payload = schema_payload() + + encoded = json.dumps(payload, sort_keys=True) + + assert ONTOLOGY_NAME in encoded + assert payload["node_types"] + assert payload["relation_types"] + + +def test_required_node_types_are_declared() -> None: + names = set(node_type_names()) + + assert { + "Module", + "ImportDeclaration", + "ExportDeclaration", + "Symbol", + "Scope", + "Class", + "Function", + "Method", + "Parameter", + "ReturnType", + "TypeAnnotation", + "TypeAlias", + "Variable", + "Constant", + "ClassAttribute", + "InstanceAttribute", + "Property", + "Decorator", + "CallExpression", + "Assignment", + "Reference", + "Literal", + "Expression", + "ControlFlowBlock", + "ExceptionFlow", + "APIEndpoint", + "Component", + "Route", + "Query", + "SecretRef", + "Repository", + "SourceRoot", + "File", + "Dependency", + "DocumentationSource", + "DocumentationChunk", + "SyntaxCapture", + } <= names + + +def test_declared_relation_endpoints_reference_declared_node_types() -> None: + names = set(node_type_names()) + + for relation in RELATION_TYPES: + assert relation.source_types + assert relation.target_types + assert set(relation.source_types) <= names + assert set(relation.target_types) <= names + + +def test_parser_node_mappings_reference_declared_nodes_and_relations() -> None: + nodes = set(node_type_names()) + relations = set(relation_type_names()) + + for mapping in PARSER_NODE_MAPPINGS: + assert mapping.parser_node_types + assert mapping.target_node_types + assert set(mapping.target_node_types) <= nodes + assert set(mapping.relation_types) <= relations + + +def test_example_parser_shapes_are_covered() -> None: + covered_parser_nodes = {node for mapping in PARSER_NODE_MAPPINGS for node in mapping.parser_node_types} + + assert { + "Module", + "ImportFrom", + "ClassDef", + "FunctionDef", + "AnnAssign", + "Assign", + "Call", + "Name", + "Attribute", + "Constant", + } <= covered_parser_nodes + + +def test_query_helpers_are_read_only() -> None: + forbidden = re.compile(r"\b(CREATE|MERGE|DELETE|SET|DROP|LOAD|COPY)\b", re.IGNORECASE) + + for helper in QUERY_HELPERS: + assert helper.query.lstrip().upper().startswith("MATCH ") + assert not forbidden.search(helper.query) + + +def test_query_helpers_use_edge_node_relation_traversal() -> None: + direct_relation = re.compile(r"-\[:(?!FROM_|TO_)([A-Za-z][A-Za-z0-9_]*)\]->") + + for helper in QUERY_HELPERS: + assert not direct_relation.search(helper.query), helper.name + + helper_queries = {helper.name: helper.query for helper in QUERY_HELPERS} + assert "[:FROM_ResolvesTo]" in helper_queries["callgraph_neighborhood"] + assert "[:TO_ResolvesTo]" in helper_queries["callgraph_neighborhood"] + assert "[:FROM_DependsOn]" in helper_queries["dependency_map"] + assert "[:TO_DependsOn]" in helper_queries["dependency_map"] + assert "[:FROM_RoutesTo]" in helper_queries["runtime_surface"] + assert "[:TO_RoutesTo]" in helper_queries["runtime_surface"] + assert "[:FROM_Documents]" in helper_queries["documentation_context"] + assert "[:TO_Documents]" in helper_queries["documentation_context"] + + +def test_query_helper_semantics_match_public_names() -> None: + helper_queries = {helper.name: helper.query for helper in QUERY_HELPERS} + + symbol_lookup = helper_queries["symbol_lookup"] + assert ":Class" in symbol_lookup + assert ":Function" in symbol_lookup + assert ":Method" in symbol_lookup + assert "s:Symbol" not in symbol_lookup + + unresolved_references = helper_queries["unresolved_references"] + assert "NOT EXISTS" in unresolved_references + assert "FROM_ResolvesTo" in unresolved_references + assert "TO_ResolvesTo" in unresolved_references + + +def test_lookup_helpers_return_expected_specs() -> None: + assert get_node_type("Class").name == "Class" + assert get_relation_type("Calls").name == "Calls" diff --git a/tests/test_question_query_registry.py b/tests/test_question_query_registry.py deleted file mode 100644 index cf0468f..0000000 --- a/tests/test_question_query_registry.py +++ /dev/null @@ -1,61 +0,0 @@ -from __future__ import annotations - -import json - -import pytest - -import codebase_graph.graph_core as graph_core -from codebase_graph.question_query_registry import ( - PHASE_ARCHITECTURE_UNDERSTANDING, - PHASE_BREAKING_CHANGE_PREPARATION, - get_engineering_question_query, - list_engineering_question_queries, - main, -) - -class _StubCore: - def cypher(self, query, parameters=None): - return {"query": query, "parameters": parameters or {}} - -def test_registry_lists_queries_and_filters_by_phase() -> None: - all_queries = list_engineering_question_queries() - assert len(all_queries) >= 3 - architecture_queries = list_engineering_question_queries(phase=PHASE_ARCHITECTURE_UNDERSTANDING) - assert architecture_queries - assert all(query.phase == PHASE_ARCHITECTURE_UNDERSTANDING for query in architecture_queries) - -def test_get_query_by_id_and_run() -> None: - query = get_engineering_question_query("se.breaking.consumers_of_contract.v1") - assert query.phase == PHASE_BREAKING_CHANGE_PREPARATION - response = query.run(_StubCore(), contract_id="contract:example") - assert response["parameters"]["contract_id"] == "contract:example" - -def test_required_params_are_validated() -> None: - query = get_engineering_question_query("se.change.tests_for_artifact.v1") - with pytest.raises(ValueError): - query.run(_StubCore(), path="src/package", symbol="") - -def test_cli_lists_queries_by_phase(capsys) -> None: - exit_code = main(["list", "--phase", PHASE_ARCHITECTURE_UNDERSTANDING]) - assert exit_code == 0 - payload = json.loads(capsys.readouterr().out) - assert payload["count"] >= 1 - assert all(item["phase"] == PHASE_ARCHITECTURE_UNDERSTANDING for item in payload["items"]) - -def test_cli_runs_query_with_params(monkeypatch, capsys) -> None: - class _FakeGraphCore: - def __init__(self, **kwargs): - self.kwargs = kwargs - - def ensure_current(self): - return {"ok": True} - - def cypher(self, query, parameters=None): - return {"query": query, "parameters": parameters or {}} - - monkeypatch.setattr(graph_core, "CodebaseGraph", _FakeGraphCore) - exit_code = main(["run", "se.breaking.consumers_of_contract.v1", "--params-json", '{"contract_id": "api:contract"}']) - assert exit_code == 0 - payload = json.loads(capsys.readouterr().out) - assert payload["question"]["id"] == "se.breaking.consumers_of_contract.v1" - assert payload["result"]["parameters"] == {"contract_id": "api:contract"} diff --git a/tests/test_release_workflows.py b/tests/test_release_workflows.py new file mode 100644 index 0000000..7217620 --- /dev/null +++ b/tests/test_release_workflows.py @@ -0,0 +1,219 @@ +from __future__ import annotations + +import re +from pathlib import Path + +from scripts import check_release_gate +from scripts.check_release_gate import ( + PYPI_CONFIRMATION_FLAGS, + _jobs_missing_timeout, + _workflow_action_pin_issues, + run_checks, +) + + +WORKFLOWS = ( + Path(".github/workflows/ci.yml"), + Path(".github/workflows/release.yml"), +) + + +def test_github_actions_are_pinned_to_immutable_commits() -> None: + for path in WORKFLOWS: + mutable_refs = _workflow_action_pin_issues(path, path.read_text(encoding="utf-8")) + + assert mutable_refs == [] + + +def test_action_pin_checker_rejects_bare_external_actions() -> None: + text = """ +jobs: + lint: + steps: + - uses: actions/checkout + - uses: ./.github/actions/local-smoke + - uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 +""" + + issues = _workflow_action_pin_issues(Path(".github/workflows/example.yml"), text) + + assert [issue.code for issue in issues] == ["workflow-action-not-pinned"] + assert "actions/checkout" in issues[0].message + + +def test_release_workflows_smoke_test_wheel_and_sdist() -> None: + for path in WORKFLOWS: + text = path.read_text(encoding="utf-8") + assert "pip install dist/*.whl" in text + assert "pip install dist/*.tar.gz" in text + + +def test_release_workflow_enforces_production_gate_before_build() -> None: + text = Path(".github/workflows/release.yml").read_text(encoding="utf-8") + + assert "production-gate:" in text + assert "python scripts/check_release_gate.py" in text + assert "--production" in text + assert "build:\n name: build release distributions\n needs:\n - release-please\n - production-gate" in text + + +def test_release_workflow_can_smoke_test_pypi_environment_without_publishing() -> None: + text = Path(".github/workflows/release.yml").read_text(encoding="utf-8") + + assert "pypi-environment-smoke:" in text + assert "github.event_name == 'workflow_dispatch' && inputs.pypi-environment-smoke" in text + assert "name: pypi" in text + assert "id-token: write" in text + assert "audience=pypi" in text + assert '"environment": "pypi"' in text + assert ".github/workflows/release.yml@" in text + assert "pypi environment OIDC claims verified" in text + + +def test_release_please_is_skipped_during_pypi_environment_smoke() -> None: + text = Path(".github/workflows/release.yml").read_text(encoding="utf-8") + + assert "release-please:\n name: release please\n if: ${{ !inputs.pypi-environment-smoke }}" in text + + +def test_conda_recipe_uses_bounded_runtime_dependencies() -> None: + text = Path("conda-forge/recipe/meta.yaml").read_text(encoding="utf-8") + + assert '{% set pypi_name = "cbasegraph" %}' in text + assert "setuptools >=77" in text + assert "real-ladybug >=0.15.3,<0.16" in text + assert "tomli >=2.0.1" in text + assert "tree-sitter >=0.25.2,<0.26" in text + assert "tree-sitter-python >=0.25.0,<0.26" in text + assert "license: MIT" in text + assert "PUT_SPDX_LICENSE_ID_HERE" not in text + + +def test_hosted_workflows_run_real_vulnerability_scans() -> None: + for path in WORKFLOWS: + text = path.read_text(encoding="utf-8") + assert "pip_audit --strict" in text + assert "pip_audit --strict ." in text + assert "--skip-editable" not in text + assert re.search(r"pip_audit\b[^\n]*--dry-run", text) is None + + +def test_supply_chain_workflow_audits_project_dependencies() -> None: + text = Path(".github/workflows/ci.yml").read_text(encoding="utf-8") + match = re.search(r" supply-chain:\n(?P.*?)(?=\n [A-Za-z0-9_-]+:|\Z)", text, re.DOTALL) + + assert match is not None + body = match.group("body") + assert 'python -m pip install ".[dev]"' in body + assert 'python -m pip install -e ".[dev]"' not in body + assert "python -m pip_audit --strict ." in body + + +def test_project_metadata_uses_published_pypi_name() -> None: + text = Path("pyproject.toml").read_text(encoding="utf-8") + + assert 'name = "cbasegraph"' in text + + +def test_security_policy_exists() -> None: + text = Path("SECURITY.md").read_text(encoding="utf-8") + + assert "Reporting a Vulnerability" in text + assert "graph_query" in text + assert "--allow-remote" in text + + +def test_release_docs_list_production_confirmation_flags() -> None: + text = Path("docs/release.md").read_text(encoding="utf-8") + + assert "PyPI project: `cbasegraph`" in text + for flag in PYPI_CONFIRMATION_FLAGS: + env_var = f"CODEBASE_GRAPH_CONFIRM_{flag.upper().replace('-', '_')}" + assert env_var in text + assert f"--confirm {flag}" in text + assert "CODEBASE_GRAPH_REQUIRE_CONDA" in text + assert "--require-conda" in text + + +def test_workflow_jobs_have_timeouts() -> None: + for path in WORKFLOWS: + missing = _jobs_missing_timeout(path.read_text(encoding="utf-8")) + + assert missing == [] + + +def test_workflows_pin_node24_capable_first_party_actions() -> None: + for path in WORKFLOWS: + text = path.read_text(encoding="utf-8") + + assert "actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd" in text + assert "actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405" in text + assert "actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5" not in text + assert "actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065" not in text + + +def test_workflows_avoid_node20_artifact_actions() -> None: + for path in WORKFLOWS: + text = path.read_text(encoding="utf-8") + + assert "actions/upload-artifact@" not in text + assert "actions/download-artifact@" not in text + + +def test_release_workflow_downloads_distributions_from_github_release() -> None: + text = Path(".github/workflows/release.yml").read_text(encoding="utf-8") + + assert 'gh release download "$RELEASE_TAG" --dir dist' in text + assert "release {artifacts=} does not include a wheel" in text + assert "release {artifacts=} does not include a source distribution" in text + + +def test_workflows_avoid_hosted_cache_warning_annotations() -> None: + for path in WORKFLOWS: + text = path.read_text(encoding="utf-8") + + assert 'PIP_NO_CACHE_DIR: "1"' in text + assert "cache: pip" not in text + + +def test_ci_uses_explicit_windows_runner_label() -> None: + text = Path(".github/workflows/ci.yml").read_text(encoding="utf-8") + + assert "windows-2022" in text + assert "windows-latest" not in text + + +def test_local_release_gate_passes() -> None: + assert run_checks(production=False, require_conda=False, confirmations=set()) == [] + + +def test_production_release_gate_reports_owner_controlled_blockers() -> None: + issues = run_checks(production=True, require_conda=True, confirmations=set()) + codes = {issue.code for issue in issues} + messages = {issue.message for issue in issues} + + assert "license-metadata-missing" not in codes + assert "license-file-missing" not in codes + assert "external-confirmation-missing" in codes + assert "conda-placeholder" in codes + assert "conda recipe still contains PUT_RELEASE_VERSION_HERE." in messages + assert "conda recipe still contains PUT_RELEASE_SDIST_SHA256_HERE." in messages + assert "conda recipe still contains PUT_SPDX_LICENSE_ID_HERE." not in messages + + +def test_release_gate_reports_missing_release_workflow(monkeypatch, tmp_path) -> None: + monkeypatch.setattr(check_release_gate, "REPO_ROOT", tmp_path) + + issues = check_release_gate._check_release_workflow_permissions() + + assert [issue.code for issue in issues] == ["workflow-missing"] + assert ".github/workflows/release.yml is required." in issues[0].message + + +def test_release_gate_reports_missing_conda_recipe(monkeypatch, tmp_path) -> None: + monkeypatch.setattr(check_release_gate, "REPO_ROOT", tmp_path) + + issues = check_release_gate._check_conda_recipe() + + assert [issue.code for issue in issues] == ["conda-recipe-missing"] + assert "conda-forge/recipe/meta.yaml is required." in issues[0].message diff --git a/tests/test_schema.py b/tests/test_schema.py new file mode 100644 index 0000000..b961856 --- /dev/null +++ b/tests/test_schema.py @@ -0,0 +1,161 @@ +from __future__ import annotations + +from pathlib import Path + +import pytest + +from codebase_graph.core import CodeGraph, GraphEdge, GraphNode +from codebase_graph.db import ( + LadybugCodeGraphStore, + build_ladybug_schema, + build_ladybug_schema_statements, + create_ladybug_database, + ladybug_type, + quote_identifier, +) +from codebase_graph.ontology import NODE_TYPES, RELATION_TYPES, SEARCH_INDEXES + + +def test_ladybug_schema_declares_all_ontology_nodes_and_edge_nodes() -> None: + schema = build_ladybug_schema() + + for node_type in NODE_TYPES: + assert f"CREATE NODE TABLE IF NOT EXISTS `{node_type.name}`" in schema + for relation_type in RELATION_TYPES: + assert f"CREATE NODE TABLE IF NOT EXISTS `{relation_type.name}`" in schema + + +def test_ladybug_schema_declares_from_and_to_connector_tables() -> None: + schema = build_ladybug_schema() + + for relation_type in RELATION_TYPES: + assert f"CREATE REL TABLE IF NOT EXISTS `FROM_{relation_type.name}`" in schema + assert f"CREATE REL TABLE IF NOT EXISTS `TO_{relation_type.name}`" in schema + + for source_type in set(relation_type.source_types): + assert f"FROM `{source_type}` TO `{relation_type.name}`" in schema + for target_type in set(relation_type.target_types): + assert f"FROM `{relation_type.name}` TO `{target_type}`" in schema + + +def test_ladybug_schema_keeps_relation_payload_on_edge_nodes() -> None: + schema = build_ladybug_schema() + + contains_start = schema.index("CREATE NODE TABLE IF NOT EXISTS `Contains`") + from_contains_start = schema.index("CREATE REL TABLE IF NOT EXISTS `FROM_Contains`") + contains_table = schema[contains_start:from_contains_start] + + assert "`id` STRING PRIMARY KEY" in contains_table + assert "`kind` STRING" in contains_table + assert "`source_id` STRING" in contains_table + assert "`target_id` STRING" in contains_table + assert "`confidence` DOUBLE" in contains_table + assert "`metadata` JSON" in contains_table + + +def test_ladybug_schema_maps_types_and_quotes_identifiers() -> None: + assert ladybug_type("string") == "STRING" + assert ladybug_type("integer") == "INT64" + assert ladybug_type("number") == "DOUBLE" + assert ladybug_type("boolean") == "BOOLEAN" + assert ladybug_type("json") == "JSON" + assert quote_identifier("Query") == "`Query`" + assert quote_identifier("odd`name") == "`odd``name`" + + +def test_ladybug_schema_rejects_unknown_field_type() -> None: + with pytest.raises(ValueError, match="Unsupported ontology field type"): + ladybug_type("object") + + +def test_ladybug_schema_deduplicates_connector_endpoint_pairs() -> None: + schema = build_ladybug_schema() + + assert schema.count("FROM `Contains` TO `Assignment`") == 1 + assert schema.count("FROM `Contains` TO `Query`") == 1 + assert "TO `Assignment` |" not in schema + + +def test_ladybug_schema_creates_fts_indexes_for_semantic_node_tables_only() -> None: + statements = build_ladybug_schema_statements(include_fts=True) + schema = ";\n".join(statements) + relation_names = {relation_type.name for relation_type in RELATION_TYPES} + + assert "INSTALL fts" in statements + for index in SEARCH_INDEXES: + for node_type in index["node_types"]: + assert f"CALL CREATE_FTS_INDEX('{node_type}', '{index['name']}_{node_type}'" in schema + assert node_type not in relation_names + + +def test_ladybug_schema_can_skip_fts_statements() -> None: + statements = build_ladybug_schema_statements(include_fts=False) + + assert "INSTALL json" in statements + assert "LOAD json" in statements + assert "INSTALL fts" not in statements + assert "LOAD fts" not in statements + assert not any(statement.startswith("CALL CREATE_FTS_INDEX") for statement in statements) + + +def test_ladybug_schema_executes_against_in_memory_database() -> None: + real_ladybug = pytest.importorskip("real_ladybug") + conn = real_ladybug.Connection(real_ladybug.Database(":memory:")) + + for statement in build_ladybug_schema_statements(): + conn.execute(statement) + + +def test_ladybug_store_creates_in_memory_database_without_persistent_file() -> None: + pytest.importorskip("real_ladybug") + + store = create_ladybug_database(":memory:") + + assert isinstance(store, LadybugCodeGraphStore) + assert store.schema_sql.startswith("INSTALL json") + + +def test_ladybug_store_schema_setup_is_idempotent() -> None: + pytest.importorskip("real_ladybug") + store = create_ladybug_database(":memory:") + + store.ensure_schema() + + +def test_ladybug_store_allows_multiple_read_only_handles(tmp_path: Path) -> None: + pytest.importorskip("real_ladybug") + db_path = tmp_path / "graph.ldb" + writer = create_ladybug_database(db_path, include_fts=False) + writer.close() + + first = create_ladybug_database(db_path, include_fts=False, read_only=True) + try: + second = create_ladybug_database(db_path, include_fts=False, read_only=True) + second.close() + finally: + first.close() + + +def test_ladybug_store_bulk_loader_groups_rows_by_table() -> None: + graph = CodeGraph() + graph.add_node(GraphNode(id="file:service", table="File", label="service.py", kind="source_file")) + graph.add_node(GraphNode(id="function:one", table="Function", label="one", kind="function")) + graph.add_node(GraphNode(id="function:two", table="Function", label="two", kind="function")) + graph.add_edge(GraphEdge(id="contains:one", type="Contains", source_id="file:service", target_id="function:one")) + graph.add_edge(GraphEdge(id="contains:two", type="Contains", source_id="file:service", target_id="function:two")) + store = object.__new__(LadybugCodeGraphStore) + statements: list[str] = [] + store.execute = lambda statement, parameters=None: statements.append(statement) # type: ignore[method-assign] + + stats = store.insert_graphs_bulk([graph]) + + assert stats.node_rows == 3 + assert stats.edge_rows == 2 + assert stats.connector_rows == 4 + assert stats.copy_calls == 5 + assert len(statements) == 5 + assert any(statement.startswith("COPY `File`") and statement.endswith('";') for statement in statements) + assert any(statement.startswith("COPY `Function`") and statement.endswith('";') for statement in statements) + assert any(statement.startswith("COPY `Contains`") and statement.endswith('";') for statement in statements) + assert any('COPY `FROM_Contains`' in statement and 'from="File", to="Contains"' in statement for statement in statements) + assert any('COPY `TO_Contains`' in statement and 'from="Contains", to="Function"' in statement for statement in statements) diff --git a/tests/test_search.py b/tests/test_search.py new file mode 100644 index 0000000..1e76066 --- /dev/null +++ b/tests/test_search.py @@ -0,0 +1,803 @@ +from __future__ import annotations + +import json +import shutil +from pathlib import Path +from typing import Any + +import pytest + +from codebase_graph.cli import _build_parser, main as cli_main +from codebase_graph.db import GraphNeighbor, SearchIndexRow +from codebase_graph.ingest import GraphMaterializer +from codebase_graph.mcp.graph_commands import graph_command_spec, graph_tool_specs +from codebase_graph.mcp.runtime import GraphRuntimeConfig +from codebase_graph.mcp.tools import MAX_GRAPH_QUERY_LIMIT, _query_payload, handle_tool_call, tool_specs +from codebase_graph.reasoning import CompactContextBuilder, ContextNode +from codebase_graph.retrieval.search import CompactContextPayload, SearchHit, SearchRequest, SearchService + + +class _Result: + def __init__(self, rows: list[list[Any]]) -> None: + self.rows = rows + self.requested_n: int | None = None + self.closed = False + + def get_all(self) -> list[list[Any]]: + return self.rows + + def get_n(self, count: int) -> list[list[Any]]: + self.requested_n = count + return self.rows[:count] + + def close(self) -> None: + self.closed = True + + +class _RecordingStore: + def __init__(self, rows: list[list[Any]] | None = None) -> None: + self.rows = rows or [] + self.calls: list[tuple[str, dict[str, Any] | None]] = [] + + def execute(self, statement: str, parameters: dict[str, Any] | None = None) -> _Result: + self.calls.append((statement, parameters)) + self.result = _Result(self.rows) + return self.result + + +class _Adapter: + def __init__(self) -> None: + self.search_calls: list[dict[str, Any]] = [] + self.neighbor_calls: list[dict[str, Any]] = [] + + def search_index(self, *, node_type: str, index_name: str, query: str, limit: int) -> list[SearchIndexRow]: + self.search_calls.append({"node_type": node_type, "index_name": index_name, "query": query, "limit": limit}) + if node_type != "Class": + return [] + return [ + SearchIndexRow( + id="opaque-class-id", + node_type="Class", + label="SampleService", + qualified_name="sample.SampleService", + path="sample/service.py", + score=1.0, + ) + ] + + def neighbors( + self, + *, + node_id: str, + node_type: str, + relation: str, + direction: str, + limit: int, + ) -> list[GraphNeighbor]: + self.neighbor_calls.append( + { + "node_id": node_id, + "node_type": node_type, + "relation": relation, + "direction": direction, + "limit": limit, + } + ) + if relation != "Defines" or direction != "outgoing": + return [] + return [ + GraphNeighbor( + node_id="opaque-neighbor-id", + node_type="Method", + label="run", + path="sample/service.py", + line_start=2, + line_end=3, + summary="Run the service.", + ) + ] + + +class _AdapterStore: + def __init__(self, adapter: _Adapter) -> None: + self.graph_query_adapter = adapter + + +def test_search_query_uses_ontology_index_names_and_parameterized_user_text() -> None: + malicious_query = "SampleService'); MATCH (n) RETURN n" + store = _RecordingStore() + + SearchService(store).search(SearchRequest(malicious_query, limit=2, budget=0)) + + assert store.calls + for statement, parameters in store.calls: + assert statement.startswith("CALL QUERY_FTS_INDEX('") + assert malicious_query not in statement + assert parameters == {"query": malicious_query, "top": 10} + + +def test_search_result_ranking_dedupes_by_node_id_preserving_best_raw_score() -> None: + service = SearchService(_RecordingStore()) + hits = [ + SearchHit(id="Function:helper", type="Function", label="helper", score=0.2, index_order=4), + SearchHit(id="Function:helper", type="Function", label="helper", score=0.7, index_order=4), + SearchHit(id="Class:SampleService", type="Class", label="SampleService", score=0.7, index_order=2), + SearchHit(id="File:service", type="File", label="service.py", path="service.py", score=0.1, index_order=8), + ] + + ranked = service._rank_hits(hits) + + assert [hit.id for hit in ranked] == ["Class:SampleService", "Function:helper", "File:service"] + assert ranked[1].score == 0.7 + + +def test_identifier_query_reranks_concrete_definition_above_generic_symbol() -> None: + service = SearchService(_RecordingStore()) + hits = [ + SearchHit( + id="Symbol:SampleService", + type="Symbol", + label="SampleService", + path="sample_project/cli.py", + score=1.4, + index_order=0, + ), + SearchHit( + id="Class:SampleService", + type="Class", + label="SampleService", + qualified_name="sample_project.service.SampleService", + path="sample_project/service.py", + score=0.2, + index_order=2, + ), + ] + + ranked = service._rank_hits(hits, query="SampleService", profile="brief") + + assert ranked[0].type == "Class" + assert ranked[0].score == 0.2 + assert ranked[0].rank_score > ranked[1].rank_score + assert ranked[1].score_components["generic_penalty"] > 0 + + +def test_generic_penalty_only_applies_when_matching_concrete_definition_exists() -> None: + service = SearchService(_RecordingStore()) + + without_definition = service._rank_hits( + [SearchHit(id="Symbol:SampleService", type="Symbol", label="SampleService", score=1.0)], + query="SampleService", + profile="brief", + ) + with_definition = service._rank_hits( + [ + SearchHit(id="Symbol:SampleService", type="Symbol", label="SampleService", score=1.0), + SearchHit(id="Class:SampleService", type="Class", label="SampleService", score=0.1), + ], + query="SampleService", + profile="brief", + ) + + assert without_definition[0].score_components["generic_penalty"] == 0 + symbol_hit = next(hit for hit in with_definition if hit.type == "Symbol") + assert symbol_hit.score_components["generic_penalty"] > 0 + + +def test_path_and_dependency_intents_boost_matching_ontology_families() -> None: + service = SearchService(_RecordingStore()) + path_ranked = service._rank_hits( + [ + SearchHit(id="Function:helper", type="Function", label="helper", path="sample_project/service.py", score=1.0), + SearchHit(id="File:service", type="File", label="service.py", path="sample_project/service.py", score=0.3), + ], + query="service.py", + profile="brief", + ) + dependency_ranked = service._rank_hits( + [ + SearchHit(id="Class:SampleService", type="Class", label="SampleService", score=1.0), + SearchHit(id="Dependency:service", type="Dependency", label=".service.SampleService", score=0.3), + ], + query=".service.SampleService", + profile="dependencies", + ) + + assert path_ranked[0].type == "File" + assert dependency_ranked[0].type == "Dependency" + + +def test_compact_context_respects_max_depth_limit_and_budget() -> None: + long_summary = "x" * 200 + store = _RecordingStore( + [["Method:run", "run", "sample.SampleService.run", "sample_project/service.py", 4, 6, long_summary]] + ) + builder = CompactContextBuilder(store) + + assert builder.build("Class:SampleService", "Class", profile="definitions", max_depth=0) == [] + + context = builder.build( + "Class:SampleService", + "Class", + profile="definitions", + limit=1, + budget=80, + max_depth=1, + ) + + assert len(context) == 1 + assert context[0].relation == "Defines" + assert context[0].direction == "outgoing" + assert context[0].summary + assert len(context[0].summary) < len(long_summary) + + +def test_compact_context_uses_adapter_types_and_opaque_node_ids() -> None: + adapter = _Adapter() + builder = CompactContextBuilder(_AdapterStore(adapter)) + + context = builder.build("opaque-class-id", "Class", profile="definitions", limit=1, budget=120, max_depth=1) + + assert context[0].id == "opaque-neighbor-id" + assert context[0].type == "Method" + assert context[0].label == "run" + assert adapter.neighbor_calls[0]["node_id"] == "opaque-class-id" + + +def test_search_service_uses_query_adapter_for_fts() -> None: + adapter = _Adapter() + + payload = SearchService(_AdapterStore(adapter)).search(SearchRequest("SampleService", limit=1, budget=0)) + + data = payload.as_dict() + assert data["results"][0]["id"] == "opaque-class-id" + assert data["results"][0]["type"] == "Class" + assert adapter.search_calls + + +def test_search_request_rejects_invalid_context_limit_and_detail() -> None: + with pytest.raises(ValueError, match="Context limit must be zero or greater"): + SearchRequest("SampleService", context_limit=-1).validate() + with pytest.raises(ValueError, match="Unknown detail level"): + SearchRequest("SampleService", detail="debug").validate() + + +def test_search_service_respects_zero_context_limit() -> None: + adapter = _Adapter() + + payload = SearchService(_AdapterStore(adapter)).search(SearchRequest("SampleService", limit=1, context_limit=0)) + + data = payload.as_dict() + assert data["results"][0]["context"] == [] + assert adapter.search_calls + assert adapter.neighbor_calls == [] + + +def test_search_request_rejects_invalid_profile() -> None: + with pytest.raises(ValueError, match="Unknown context profile"): + SearchRequest("SampleService", profile="missing").validate() + + +def test_slim_payload_omits_diagnostics_and_duplicate_summaries() -> None: + payload = CompactContextPayload( + query="run", + profile="brief", + limit=1, + budget=600, + results=( + SearchHit( + id="Method:run", + type="Method", + label="run", + qualified_name="sample.Service.run", + path="sample/service.py", + span={"line_start": 4, "line_end": 8}, + score=2.0, + rank_score=0.9, + score_components={"fts": 1.0}, + summary="run", + context=[ + ContextNode("Defines", "incoming", "Module", "sample.service", "sample/service.py", summary="sample.service"), + ContextNode("Documents", "outgoing", "DocumentationChunk", "Usage", "README.md", summary="Use run to start the service."), + ], + ), + ), + ) + + hit = payload.as_dict(detail="slim")["results"][0] + + assert hit == { + "id": "Method:run", + "type": "Method", + "label": "run", + "rank_score": 0.9, + "path": "sample/service.py", + "span": {"line_start": 4, "line_end": 8}, + "context": [ + { + "relation": "Defines", + "direction": "incoming", + "type": "Module", + "label": "sample.service", + "path": "sample/service.py", + }, + { + "relation": "Documents", + "direction": "outgoing", + "type": "DocumentationChunk", + "label": "Usage", + "path": "README.md", + "summary": "Use run to start the service.", + }, + ], + } + + +def test_search_service_returns_sample_class_with_compact_context(tmp_path: Path) -> None: + _require_graph_runtime() + materializer = _materialize_fixture(tmp_path, include_fts=True) + + payload = SearchService(materializer.store).search(SearchRequest("SampleService", limit=3)) + data = payload.as_dict() + + assert data["query"] == "SampleService" + assert data["profile"] == "brief" + class_hit = next(hit for hit in data["results"] if hit["type"] == "Class" and hit["label"] == "SampleService") + assert class_hit["path"] == "sample_project/service.py" + assert class_hit["score"] > 0 + assert class_hit["rank_score"] > 0 + assert class_hit["score_components"]["type"] > 0 + assert class_hit["context"] + assert any(item["type"] in {"Module", "Method"} for item in class_hit["context"]) + + +def test_search_service_returns_sample_class_first_for_exact_identifier(tmp_path: Path) -> None: + _require_graph_runtime() + materializer = _materialize_fixture(tmp_path, include_fts=True) + + payload = SearchService(materializer.store).search(SearchRequest("SampleService", limit=1)) + hit = payload.as_dict()["results"][0] + + assert hit["type"] == "Class" + assert hit["label"] == "SampleService" + assert hit["path"] == "sample_project/service.py" + assert hit["score"] < 1.0 + assert hit["rank_score"] > hit["score"] + + +def test_search_service_returns_function_hit_with_score_and_context(tmp_path: Path) -> None: + _require_graph_runtime() + materializer = _materialize_fixture(tmp_path, include_fts=True) + + payload = SearchService(materializer.store).search(SearchRequest("helper", limit=3)) + helper_hit = next(hit for hit in payload.as_dict()["results"] if hit["type"] == "Function" and hit["label"] == "helper") + + assert helper_hit["path"] == "sample_project/service.py" + assert helper_hit["span"]["line_start"] > 0 + assert helper_hit["score"] > 0 + assert helper_hit["rank_score"] > 0 + assert helper_hit["context"] + + +def test_cli_search_and_context_return_compact_json_without_refresh(tmp_path: Path, capsys: pytest.CaptureFixture[str]) -> None: + _require_graph_runtime() + source_root = _copy_fixture(tmp_path) + db_path = tmp_path / "graph.lbug" + manifest_path = tmp_path / "manifest.json" + + assert cli_main([ + "materialize", + "--source-root", + source_root.as_posix(), + "--db", + db_path.as_posix(), + "--manifest", + manifest_path.as_posix(), + "--mode", + "full", + ]) == 0 + capsys.readouterr() + + assert cli_main([ + "search", + "SampleService", + "--source-root", + source_root.as_posix(), + "--db", + db_path.as_posix(), + "--manifest", + manifest_path.as_posix(), + "--no-refresh", + "--json", + ]) == 0 + search_payload = json.loads(capsys.readouterr().out) + assert search_payload["results"] + assert any(hit["label"] == "SampleService" for hit in search_payload["results"]) + + assert cli_main([ + "search", + "SampleService", + "--source-root", + source_root.as_posix(), + "--db", + db_path.as_posix(), + "--manifest", + manifest_path.as_posix(), + "--limit", + "1", + "--no-refresh", + "--json", + ]) == 0 + top_payload = json.loads(capsys.readouterr().out) + assert top_payload["results"][0]["type"] == "Class" + assert top_payload["results"][0]["rank_score"] > top_payload["results"][0]["score"] + + assert cli_main([ + "context", + "helper", + "--source-root", + source_root.as_posix(), + "--db", + db_path.as_posix(), + "--manifest", + manifest_path.as_posix(), + "--no-refresh", + ]) == 0 + context_payload = json.loads(capsys.readouterr().out) + assert context_payload["results"] + assert any(hit["label"] == "helper" and hit["context"] for hit in context_payload["results"]) + + +def test_cli_graph_commands_match_mcp_tool_payloads(tmp_path: Path, capsys: pytest.CaptureFixture[str]) -> None: + _require_graph_runtime() + source_root = _copy_fixture(tmp_path) + db_path = tmp_path / "graph.lbug" + manifest_path = tmp_path / "manifest.json" + + assert cli_main([ + "materialize", + "--source-root", + source_root.as_posix(), + "--db", + db_path.as_posix(), + "--manifest", + manifest_path.as_posix(), + "--mode", + "full", + ]) == 0 + capsys.readouterr() + runtime = GraphRuntimeConfig(repo_root=source_root, db_path=db_path, manifest_path=manifest_path) + + assert cli_main([ + "graph-health", + "--repo-root", + source_root.as_posix(), + "--db", + db_path.as_posix(), + "--manifest", + manifest_path.as_posix(), + ]) == 0 + assert json.loads(capsys.readouterr().out) == handle_tool_call("graph_health", {}, runtime=runtime) + + search_args = { + "query": "SampleService", + "limit": 2, + "profile": "brief", + "budget": 600, + "context_limit": 1, + "detail": "slim", + } + assert cli_main([ + "graph-search", + "SampleService", + "--repo-root", + source_root.as_posix(), + "--db", + db_path.as_posix(), + "--manifest", + manifest_path.as_posix(), + "--limit", + "2", + "--context-limit", + "1", + "--detail", + "slim", + "--no-refresh", + "--json", + ]) == 0 + search_payload = json.loads(capsys.readouterr().out) + assert search_payload == handle_tool_call("graph_search", search_args, runtime=runtime) + assert "score" not in search_payload["results"][0] + assert len(search_payload["results"][0].get("context", [])) <= 1 + + assert cli_main([ + "graph-search", + "SampleService", + "--repo-root", + source_root.as_posix(), + "--db", + db_path.as_posix(), + "--manifest", + manifest_path.as_posix(), + "--limit", + "2", + "--context-limit", + "1", + "--detail", + "slim", + "--no-refresh", + "--format", + "block", + ]) == 0 + block_output = capsys.readouterr().out + assert block_output.startswith("q SampleService\n") + assert "file path sample_project/service.py" in block_output + assert "id=Class:" in block_output + assert not block_output.lstrip().startswith("{") + + hit = next(item for item in search_payload["results"] if item["label"] == "SampleService") + context_args = { + "node_id": hit["id"], + "node_type": hit["type"], + "limit": 1, + "profile": "definitions", + "budget": 600, + "context_limit": 3, + "detail": "slim", + } + assert cli_main([ + "graph-context", + "--node-id", + hit["id"], + "--node-type", + hit["type"], + "--repo-root", + source_root.as_posix(), + "--db", + db_path.as_posix(), + "--manifest", + manifest_path.as_posix(), + "--profile", + "definitions", + "--limit", + "1", + "--detail", + "slim", + ]) == 0 + assert json.loads(capsys.readouterr().out) == handle_tool_call("graph_context", context_args, runtime=runtime) + + assert cli_main([ + "graph-context", + "--node-id", + hit["id"], + "--node-type", + hit["type"], + "--repo-root", + source_root.as_posix(), + "--db", + db_path.as_posix(), + "--manifest", + manifest_path.as_posix(), + "--profile", + "definitions", + "--limit", + "1", + "--detail", + "slim", + "--format", + "block", + ]) == 0 + context_block = capsys.readouterr().out + assert context_block.startswith(f"context {hit['type']} id={hit['id']} profile=definitions\n") + assert "file path " in context_block + + statement = "MATCH (n) RETURN count(n) AS total_nodes LIMIT 1" + query_args = {"statement": statement, "parameters": {}, "limit": 5} + assert cli_main([ + "graph-query", + statement, + "--repo-root", + source_root.as_posix(), + "--db", + db_path.as_posix(), + "--manifest", + manifest_path.as_posix(), + "--limit", + "5", + ]) == 0 + assert json.loads(capsys.readouterr().out) == handle_tool_call("graph_query", query_args, runtime=runtime) + + +def test_graph_command_specs_drive_mcp_tool_specs() -> None: + assert tool_specs() == graph_tool_specs() + + +def test_graph_command_specs_build_cli_payloads() -> None: + parser = _build_parser() + cases = [ + ( + [ + "graph-search", + "SampleService", + "--limit", + "2", + "--context-limit", + "1", + "--detail", + "slim", + ], + "graph_search", + { + "query": "SampleService", + "limit": 2, + "profile": "brief", + "budget": 600, + "context_limit": 1, + "detail": "slim", + }, + ), + ( + [ + "graph-context", + "--node-id", + "Class:1", + "--node-type", + "Class", + "--profile", + "definitions", + "--limit", + "1", + "--detail", + "slim", + ], + "graph_context", + { + "node_id": "Class:1", + "node_type": "Class", + "limit": 1, + "profile": "definitions", + "budget": 600, + "context_limit": 3, + "detail": "slim", + }, + ), + ( + [ + "graph-query", + "MATCH (n) RETURN n", + "--parameters", + '{"limit": 1}', + "--limit", + "5", + ], + "graph_query", + {"statement": "MATCH (n) RETURN n", "parameters": {"limit": 1}, "limit": 5}, + ), + ] + for argv, tool_name, expected_payload in cases: + args = parser.parse_args(argv) + spec = graph_command_spec(args.command) + + assert spec.tool_name == tool_name + assert spec.payload_from_args(args) == expected_payload + + +def test_cli_graph_metadata_commands_do_not_open_graph_db(capsys: pytest.CaptureFixture[str]) -> None: + assert cli_main(["graph-schema"]) == 0 + schema_output = capsys.readouterr().out + assert "\n " not in schema_output + schema = json.loads(schema_output) + assert schema["ontology"] + assert schema["context_profiles"] + + assert cli_main(["graph-schema", "--pretty"]) == 0 + pretty_schema_output = capsys.readouterr().out + assert "\n " in pretty_schema_output + assert json.loads(pretty_schema_output)["ontology"] + + assert cli_main(["graph-query-helpers"]) == 0 + helpers = json.loads(capsys.readouterr().out) + assert any(helper["name"] == "repository_overview" for helper in helpers["query_helpers"]) + + assert cli_main(["graph-architecture-queries", "--group", "overview"]) == 0 + architecture = json.loads(capsys.readouterr().out) + assert [group["name"] for group in architecture["groups"]] == ["overview"] + + +def test_cli_graph_query_rejects_write_like_statements(tmp_path: Path) -> None: + _require_graph_runtime() + source_root = _copy_fixture(tmp_path) + db_path = tmp_path / "graph.lbug" + manifest_path = tmp_path / "manifest.json" + materializer = GraphMaterializer(source_root, db_path=db_path, manifest_path=manifest_path) + try: + materializer.materialize(mode="full") + finally: + materializer.close() + + with pytest.raises(SystemExit) as exc_info: + cli_main([ + "graph-query", + "MATCH (n) DELETE n", + "--repo-root", + source_root.as_posix(), + "--db", + db_path.as_posix(), + "--manifest", + manifest_path.as_posix(), + ]) + + assert exc_info.value.code == 2 + + +def test_graph_query_fetches_limit_plus_one_rows_without_materializing_all() -> None: + store = _RecordingStore([[1], [2], [3], [4]]) + + payload = _query_payload(store, {"statement": "MATCH (n) RETURN n", "limit": 2}) + + assert store.result.requested_n == 3 + assert store.result.closed is True + assert payload == { + "statement": "MATCH (n) RETURN n", + "row_count": 2, + "rows": [[1], [2]], + "truncated": True, + } + + +def test_graph_query_rejects_unbounded_response_limits() -> None: + store = _RecordingStore([[1]]) + + with pytest.raises(ValueError, match="greater than zero"): + _query_payload(store, {"statement": "MATCH (n) RETURN n", "limit": 0}) + with pytest.raises(ValueError, match=f"{MAX_GRAPH_QUERY_LIMIT} or less"): + _query_payload(store, {"statement": "MATCH (n) RETURN n", "limit": MAX_GRAPH_QUERY_LIMIT + 1}) + + +def test_graph_query_rejects_procedure_calls() -> None: + store = _RecordingStore([[1]]) + + with pytest.raises(ValueError, match="blocked keyword: CALL"): + _query_payload(store, {"statement": "CALL CREATE_FTS_INDEX('File', 'label')"}) + + +@pytest.mark.parametrize( + ("statement", "keyword"), + [ + ("EXPORT DATABASE '/tmp/graph-export'", "EXPORT"), + ("IMPORT DATABASE '/tmp/graph-export'", "IMPORT"), + ("ATTACH '/tmp/other.ldb' AS other", "ATTACH"), + ("USE other", "USE"), + ("TRUNCATE TABLE File", "TRUNCATE"), + ("UPDATE File SET label = 'x'", "UPDATE"), + ], +) +def test_graph_query_rejects_database_administration_statements(statement: str, keyword: str) -> None: + store = _RecordingStore([[1]]) + + with pytest.raises(ValueError, match=f"blocked keyword: {keyword}"): + _query_payload(store, {"statement": statement}) + + +def _require_graph_runtime() -> None: + pytest.importorskip("tree_sitter") + pytest.importorskip("tree_sitter_python") + pytest.importorskip("real_ladybug") + + +def _materialize_fixture(tmp_path: Path, *, include_fts: bool) -> GraphMaterializer: + source_root = _copy_fixture(tmp_path) + materializer = GraphMaterializer( + source_root, + db_path=":memory:", + manifest_path=tmp_path / "manifest.json", + include_fts=include_fts, + ) + materializer.materialize(mode="full") + return materializer + + +def _copy_fixture(tmp_path: Path) -> Path: + source = Path("tests/fixtures/sample_project") + target = tmp_path / "sample_project" + shutil.copytree(source, target) + return target diff --git a/tests/test_setup_workflow.py b/tests/test_setup_workflow.py new file mode 100644 index 0000000..5ec3739 --- /dev/null +++ b/tests/test_setup_workflow.py @@ -0,0 +1,348 @@ +from __future__ import annotations + +import json +import re +import sys +from pathlib import Path + +try: + import tomllib +except ImportError: # pragma: no cover - Python 3.10 compatibility + import tomli as tomllib + +import pytest + +from codebase_graph.cli import main as cli_main +from codebase_graph.db import LadybugUnavailableError +from codebase_graph.mcp.runtime import runtime_config +from codebase_graph.mcp.server import McpGraphServer, handle_tool_call +from codebase_graph.setup import SetupError, SetupOptions, run_setup +from codebase_graph.setup.instructions import END_MARKER, START_MARKER, upsert_instruction_block +from codebase_graph.setup.mcp_config import configure_mcp_client, server_entry +from codebase_graph.setup.state import build_setup_config, derive_setup_paths, load_setup_config, write_setup_config + + +def test_setup_cli_creates_state_db_mcp_config_instructions_and_searchable_docs( + tmp_path: Path, + capsys: pytest.CaptureFixture[str], +) -> None: + pytest.importorskip("tree_sitter") + pytest.importorskip("tree_sitter_python") + pytest.importorskip("real_ladybug") + repo_root = _fresh_repo(tmp_path) + mcp_config_path = tmp_path / "config.toml" + + exit_code = cli_main( + [ + "setup", + "--repo-root", + repo_root.as_posix(), + "--mcp-client", + "codex", + "--mcp-config-path", + mcp_config_path.as_posix(), + ] + ) + first_output = json.loads(capsys.readouterr().out) + + assert exit_code == 0 + assert first_output["state_dir"] == (repo_root / ".codebaseGraph").as_posix() + assert first_output["db_path"] == (repo_root / ".codebaseGraph" / "fresh_repo_graph.ldb").as_posix() + assert Path(first_output["db_path"]).exists() + assert Path(first_output["config_path"]).exists() + assert first_output["materialization"]["rebuilt"] == 4 + assert first_output["instructions"]["path"] == (repo_root / "AGENTS.md").as_posix() + assert first_output["mcp_config"]["action"] == "created" + + agents_text = (repo_root / "AGENTS.md").read_text(encoding="utf-8") + assert agents_text.count(START_MARKER) == 1 + assert agents_text.count(END_MARKER) == 1 + assert "graph-search" in agents_text + assert "graph-context" in agents_text + assert "--format block" in agents_text + assert re.search(r"graph-search .*--json", agents_text) is None + assert re.search(r"graph-context .*--json", agents_text) is None + assert "AI agents must use block format" in agents_text + assert "graph-architecture-queries" in agents_text + assert "MCP server" not in agents_text + assert "graph_architecture_queries" not in agents_text + assert "graph_query" not in agents_text + assert ( + "It is prohibited to read the code source before you find the target files using the graph." + in agents_text + ) + mcp_payload = tomllib.loads(mcp_config_path.read_text(encoding="utf-8")) + assert "otherServer" not in mcp_payload.get("mcp_servers", {}) + assert mcp_payload["mcp_servers"]["codebase_graph"]["args"] == [ + "mcp", + "serve", + "--config", + (repo_root / ".codebaseGraph" / "config.json").as_posix(), + ] + + second_exit_code = cli_main( + [ + "setup", + "--repo-root", + repo_root.as_posix(), + "--mcp-config-path", + mcp_config_path.as_posix(), + ] + ) + second_output = json.loads(capsys.readouterr().out) + + assert second_exit_code == 0 + assert second_output["config_action"] == "unchanged" + assert second_output["instructions"]["action"] == "unchanged" + assert second_output["mcp_config"]["action"] == "unchanged" + assert (repo_root / "AGENTS.md").read_text(encoding="utf-8").count(START_MARKER) == 1 + + server = McpGraphServer.from_paths(config_path=repo_root / ".codebaseGraph" / "config.json") + docs_payload = handle_tool_call( + "graph_search", + {"query": "codebaseGraph workflow", "profile": "docs", "limit": 5}, + runtime=server.runtime, + ) + health_payload = handle_tool_call("graph_health", {}, runtime=server.runtime) + symbol_payload = handle_tool_call( + "graph_search", + {"query": "SampleService", "profile": "brief", "limit": 3}, + runtime=server.runtime, + ) + + assert health_payload["ok"] is True + assert health_payload["graph_readable"] is True + assert health_payload["total_nodes"] > 0 + assert any(hit["path"] == "AGENTS.md" for hit in docs_payload["results"]) + assert any(hit["label"] == "SampleService" for hit in symbol_payload["results"]) + + +def test_claude_instruction_target_uses_block_format(tmp_path: Path) -> None: + repo_root = tmp_path / "fresh_repo" + repo_root.mkdir() + + result = upsert_instruction_block( + repo_root, + target="claude", + server_name="codebase_graph", + config_path=repo_root / ".codebaseGraph" / "config.json", + ) + claude_text = (repo_root / "CLAUDE.md").read_text(encoding="utf-8") + + assert result.action == "created" + assert result.path == (repo_root / "CLAUDE.md").as_posix() + assert not (repo_root / "AGENTS.md").exists() + assert "--format block" in claude_text + assert re.search(r"graph-search .*--json", claude_text) is None + assert re.search(r"graph-context .*--json", claude_text) is None + + +def test_mcp_config_dry_run_preserves_existing_json_servers(tmp_path: Path) -> None: + config_path = tmp_path / "mcp.json" + config_path.write_text( + json.dumps({"mcpServers": {"otherServer": {"command": "other", "args": []}}}), + encoding="utf-8", + ) + setup_config_path = tmp_path / ".codebaseGraph" / "config.json" + + dry_run = configure_mcp_client( + client="generic", + config_path=config_path, + setup_config_path=setup_config_path, + dry_run=True, + ) + + assert dry_run.action == "dry_run" + assert "codebase_graph" not in json.loads(config_path.read_text(encoding="utf-8"))["mcpServers"] + + written = configure_mcp_client( + client="generic", + config_path=config_path, + setup_config_path=setup_config_path, + ) + payload = json.loads(config_path.read_text(encoding="utf-8")) + + assert written.action == "created" + assert set(payload["mcpServers"]) == {"otherServer", "codebase_graph"} + + +def test_server_entry_prefers_current_environment_script(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + bin_dir = tmp_path / "venv" / "bin" + bin_dir.mkdir(parents=True) + python_path = bin_dir / "python" + script_path = bin_dir / "codebase-graph" + python_path.write_text("", encoding="utf-8") + script_path.write_text("", encoding="utf-8") + script_path.chmod(0o755) + monkeypatch.setattr(sys, "executable", python_path.as_posix()) + monkeypatch.setenv("PATH", "") + + entry = server_entry(tmp_path / ".codebaseGraph" / "config.json") + + assert entry["command"] == script_path.as_posix() + + +def test_setup_preflight_failure_stops_before_state_creation(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + repo_root = _fresh_repo(tmp_path) + + def fail_preflight() -> None: + raise LadybugUnavailableError("missing LadyBugDB") + + monkeypatch.setattr("codebase_graph.setup.orchestrator.validate_ladybug_runtime", fail_preflight) + + with pytest.raises(SetupError, match="missing LadyBugDB"): + run_setup(SetupOptions(repo_root=repo_root, mcp_client="none")) + + assert not (repo_root / ".codebaseGraph").exists() + + +def test_setup_rejects_state_directory_as_repo_root(tmp_path: Path) -> None: + state_root = tmp_path / ".codebaseGraph" + state_root.mkdir() + + with pytest.raises(SetupError, match="state directory"): + run_setup(SetupOptions(repo_root=state_root, mcp_client="none")) + + assert not (state_root / ".codebaseGraph").exists() + + +def test_setup_dry_run_does_not_write_repo_or_client_state( + tmp_path: Path, + capsys: pytest.CaptureFixture[str], +) -> None: + pytest.importorskip("real_ladybug") + repo_root = _fresh_repo(tmp_path) + mcp_config_path = tmp_path / "config.toml" + + exit_code = cli_main( + [ + "setup", + "--repo-root", + repo_root.as_posix(), + "--mcp-client", + "codex", + "--mcp-config-path", + mcp_config_path.as_posix(), + "--dry-run", + ] + ) + payload = json.loads(capsys.readouterr().out) + + assert exit_code == 0 + assert payload["config_action"] == "dry_run" + assert payload["materialization"]["mode"] == "dry_run" + assert payload["instructions"]["action"] == "dry_run" + assert payload["mcp_config"]["action"] == "dry_run" + assert not (repo_root / ".codebaseGraph").exists() + assert not (repo_root / "AGENTS.md").exists() + assert not mcp_config_path.exists() + + +def test_setup_materialization_failure_rolls_back_published_control_files( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + pytest.importorskip("real_ladybug") + repo_root = _fresh_repo(tmp_path) + + def fail_materialize(self: object, *, mode: str = "changed") -> object: + raise RuntimeError("materialization failed") + + monkeypatch.setattr("codebase_graph.setup.orchestrator.GraphMaterializer.materialize", fail_materialize) + + with pytest.raises(SetupError, match="materialization failed"): + run_setup(SetupOptions(repo_root=repo_root, mcp_client="none")) + + assert not (repo_root / ".codebaseGraph").exists() + assert not (repo_root / "AGENTS.md").exists() + + +def test_mcp_graph_query_rejects_write_like_statements(tmp_path: Path) -> None: + pytest.importorskip("tree_sitter") + pytest.importorskip("tree_sitter_python") + pytest.importorskip("real_ladybug") + repo_root = _fresh_repo(tmp_path) + result = run_setup(SetupOptions(repo_root=repo_root, mcp_client="none", instructions_target="skip")) + server = McpGraphServer.from_paths(config_path=result.paths.config_path) + + with pytest.raises(ValueError, match="read-only"): + handle_tool_call( + "graph_query", + {"statement": "MATCH (n) DELETE n"}, + runtime=server.runtime, + ) + + +def test_setup_invalid_repo_root_exits_nonzero(tmp_path: Path) -> None: + missing = tmp_path / "missing" + + with pytest.raises(SystemExit) as exc_info: + cli_main(["setup", "--repo-root", missing.as_posix(), "--mcp-client", "none"]) + + assert exc_info.value.code == 2 + + +def test_runtime_config_uses_repo_root_from_setup_config(tmp_path: Path) -> None: + repo_root = _fresh_repo(tmp_path) + paths = derive_setup_paths(repo_root) + payload = build_setup_config(paths, mcp_command=["codebase-graph", "mcp", "serve", "--config", paths.config_path.as_posix()]) + write_setup_config(paths.config_path, payload) + paths.db_path.write_text("", encoding="utf-8") + paths.manifest_path.write_text("{}", encoding="utf-8") + other_root = tmp_path / "other_repo" + other_root.mkdir() + + runtime = runtime_config(repo_root=other_root, config_path=paths.config_path, db_path=None, manifest_path=None) + + assert runtime.repo_root == repo_root.resolve() + assert runtime.db_path == paths.db_path + assert runtime.manifest_path == paths.manifest_path + + +def test_setup_config_rejects_database_path_outside_state_dir(tmp_path: Path) -> None: + repo_root = _fresh_repo(tmp_path) + paths = derive_setup_paths(repo_root) + payload = build_setup_config(paths, mcp_command=["codebase-graph", "mcp", "serve", "--config", paths.config_path.as_posix()]) + payload["database_path"] = (tmp_path / "other.ldb").as_posix() + paths.config_path.parent.mkdir(parents=True) + paths.config_path.write_text(json.dumps(payload), encoding="utf-8") + + with pytest.raises(ValueError, match="database_path must be"): + load_setup_config(paths.config_path) + + +def test_packaging_requires_ladybug_and_namespaced_package_discovery() -> None: + payload = tomllib.loads(Path("pyproject.toml").read_text(encoding="utf-8")) + dependencies = "\n".join(payload["project"]["dependencies"]) + + assert "real_ladybug>=0.15.3,<0.16" in dependencies + assert "tree-sitter>=0.25.2,<0.26" in dependencies + assert "tree-sitter-python>=0.25.0,<0.26" in dependencies + assert "setuptools>=77" in payload["build-system"]["requires"] + assert payload["project"]["license"] == "MIT" + assert payload["project"]["license-files"] == ["LICENSE"] + assert payload["project"]["scripts"]["codebase-graph"] == "codebase_graph.cli:main" + assert payload["project"]["scripts"]["codebase-graph-mcp"] == "codebase_graph.mcp.server:main" + assert payload["project"]["urls"]["Repository"] == "https://github.com/rabii-chaarani/codebaseGraph" + assert payload["project"]["urls"]["Issues"] == "https://github.com/rabii-chaarani/codebaseGraph/issues" + assert payload["tool"]["setuptools"]["packages"]["find"]["include"] == ["codebase_graph*"] + + +def _fresh_repo(tmp_path: Path) -> Path: + repo_root = tmp_path / "fresh_repo" + package = repo_root / "sample_project" + package.mkdir(parents=True) + (package / "__init__.py").write_text("", encoding="utf-8") + (package / "service.py").write_text( + "class SampleService:\n" + " def run(self) -> str:\n" + " return helper()\n\n" + "def helper() -> str:\n" + " return 'ok'\n", + encoding="utf-8", + ) + (repo_root / "README.md").write_text( + "# Fresh Repo\n\nThis repository documents the SampleService workflow.\n", + encoding="utf-8", + ) + return repo_root