diff --git a/internal/cmd/supply_chain.go b/internal/cmd/supply_chain.go index 0105c88..e46cc75 100644 --- a/internal/cmd/supply_chain.go +++ b/internal/cmd/supply_chain.go @@ -35,9 +35,10 @@ compromised maintainer accounts, or dependency confusion. Supported ecosystems: Node: npm, pnpm, bun, yarn (transparent proxy enforcement) Python: pip, uv (transparent proxy); poetry, pipenv, pdm (pre-install block) + Java: maven (pom.xml), gradle (gradle.lockfile) (pre-install block) No Armis Cloud authentication is required — supply-chain queries public registries -(npm registry, PyPI).`, +(npm registry, PyPI, Maven Central).`, Example: ` # Audit lockfile for recently-published packages (CI mode) armis-cli supply-chain check diff --git a/internal/cmd/supply_chain_init.go b/internal/cmd/supply_chain_init.go index cf60069..6d5c2df 100644 --- a/internal/cmd/supply_chain_init.go +++ b/internal/cmd/supply_chain_init.go @@ -26,7 +26,8 @@ var scInitCmd = &cobra.Command{ This wraps your package manager (auto-detected from lockfiles) so that armis-cli can enforce age policies on package installations. Node PMs (npm, pnpm, bun, yarn) and pip/uv use a transparent proxy that filters registry responses. poetry, pipenv, -and pdm use a pre-install check that blocks the build if violations are found. +pdm, mvn, and gradle use a pre-install check that blocks the build if violations +are found. Four modes are available: rc — Inject shell functions into ~/.bashrc / ~/.zshrc (default, interactive) @@ -144,6 +145,10 @@ func ecosystemToPM(eco supplychain.Ecosystem) string { return pmPDM case supplychain.EcosystemUV: return pmUV + case supplychain.EcosystemMaven: + return pmMaven + case supplychain.EcosystemGradle: + return pmGradle default: return "" } diff --git a/internal/cmd/supply_chain_wrap.go b/internal/cmd/supply_chain_wrap.go index 32ca36e..eb35e4b 100644 --- a/internal/cmd/supply_chain_wrap.go +++ b/internal/cmd/supply_chain_wrap.go @@ -39,6 +39,8 @@ const ( pmPoetry = "poetry" pmPipenv = "pipenv" pmPDM = "pdm" + pmMaven = "mvn" + pmGradle = "gradle" ) var scWrapCmd = &cobra.Command{ @@ -57,6 +59,7 @@ func init() { var allowedPMs = map[string]bool{ pmNPM: true, pmPNPM: true, pmBun: true, pmYarn: true, pmPip: true, pmUV: true, pmPoetry: true, pmPipenv: true, pmPDM: true, + pmMaven: true, pmGradle: true, } func runSupplyChainWrap(cmd *cobra.Command, args []string) error { @@ -70,7 +73,7 @@ func runSupplyChainWrap(cmd *cobra.Command, args []string) error { canonical := canonicalPM(pmName) if !allowedPMs[canonical] { - return fmt.Errorf("unsupported package manager: %s (allowed: npm, pnpm, bun, yarn, pip, uv, poetry, pipenv, pdm)", pmName) + return fmt.Errorf("unsupported package manager: %s (allowed: npm, pnpm, bun, yarn, pip, uv, poetry, pipenv, pdm, mvn, gradle)", pmName) } if os.Getenv(envSCActive) == "1" { @@ -176,6 +179,10 @@ func execPM(pm string, args []string, extraEnv []string) (int, error) { pmName = pmPipenv case pmPDM: pmName = pmPDM + case pmMaven: + pmName = pmMaven + case pmGradle: + pmName = pmGradle default: // Versioned pip variants (pip3, pip3.11, pip3.12) must execute the exact // binary the user invoked so the install lands in that interpreter's @@ -187,7 +194,7 @@ func execPM(pm string, args []string, extraEnv []string) (int, error) { pmName = pm break } - return 1, fmt.Errorf("unsupported package manager: %s (allowed: npm, pnpm, bun, yarn, pip, uv, poetry, pipenv, pdm)", pm) + return 1, fmt.Errorf("unsupported package manager: %s (allowed: npm, pnpm, bun, yarn, pip, uv, poetry, pipenv, pdm, mvn, gradle)", pm) } // armis:ignore cwe:426 cwe:427 reason:pmName is one of the hardcoded string literals selected by the switch above, never the user argument; resolving the user's own PM from PATH is the point of a transparent wrapper @@ -421,7 +428,7 @@ func resolveWrapPolicy() supplychain.Policy { // too-young package. func requiresPreInstallBlock(pm string) bool { switch pm { - case pmPoetry, pmPipenv, pmPDM: + case pmPoetry, pmPipenv, pmPDM, pmMaven, pmGradle: return true } return false @@ -561,6 +568,10 @@ func pmToEcosystem(pm string) supplychain.Ecosystem { return supplychain.EcosystemPipfile case pmPDM: return supplychain.EcosystemPDM + case pmMaven: + return supplychain.EcosystemMaven + case pmGradle: + return supplychain.EcosystemGradle default: return "" } diff --git a/internal/supplychain/check/check.go b/internal/supplychain/check/check.go index 1c45ecb..84ae174 100644 --- a/internal/supplychain/check/check.go +++ b/internal/supplychain/check/check.go @@ -104,6 +104,10 @@ func parseLockfile(ecosystem supplychain.Ecosystem, path string) ([]PackageEntry return ParsePDMLockfile(path) case supplychain.EcosystemUV: return ParseUVLockfile(path) + case supplychain.EcosystemMaven: + return ParseMavenDeps(path) + case supplychain.EcosystemGradle: + return ParseGradleLockfile(path) default: return ParseNPMLockfile(path) } @@ -114,6 +118,9 @@ func queryRegistry(ctx context.Context, ecosystem supplychain.Ecosystem, package case supplychain.EcosystemPip, supplychain.EcosystemPoetry, supplychain.EcosystemPipfile, supplychain.EcosystemPDM, supplychain.EcosystemUV: client := registry.NewPyPIClient() return client.GetPublishDates(ctx, packages) + case supplychain.EcosystemMaven, supplychain.EcosystemGradle: + client := registry.NewMavenClient() + return client.GetPublishDates(ctx, packages) default: client := registry.NewClient() return client.GetPublishDates(ctx, packages) @@ -137,6 +144,10 @@ func detectEcosystemFromPath(path string) supplychain.Ecosystem { return supplychain.EcosystemPDM case strings.HasSuffix(lower, "uv.lock"): return supplychain.EcosystemUV + case strings.HasSuffix(lower, "pom.xml"): + return supplychain.EcosystemMaven + case strings.HasSuffix(lower, "gradle.lockfile"): + return supplychain.EcosystemGradle case isRequirementsFile(lower): return supplychain.EcosystemPip default: diff --git a/internal/supplychain/check/gradle.go b/internal/supplychain/check/gradle.go new file mode 100644 index 0000000..af4cbc5 --- /dev/null +++ b/internal/supplychain/check/gradle.go @@ -0,0 +1,79 @@ +package check + +import ( + "bufio" + "bytes" + "fmt" + "strings" +) + +// ParseGradleLockfile parses a Gradle lockfile (gradle.lockfile). +// Format: one dependency per line as "group:artifact:version=configurations" +// after a header, where the suffix after "=" is a comma-separated list of the +// configurations that resolved the dependency (e.g. "compileClasspath,runtimeClasspath"). +// The parser treats everything after "=" as metadata and ignores it. +// armis:ignore cwe:22 cwe:23 cwe:73 reason:local CLI reading the user's own lockfile; path is from local detection or an explicit --lockfile flag, not untrusted input crossing a trust boundary +func ParseGradleLockfile(path string) ([]PackageEntry, error) { + // armis:ignore cwe:22 cwe:23 cwe:73 reason:local CLI reading the user's own lockfile; path is from local detection or an explicit --lockfile flag, not untrusted input crossing a trust boundary + data, err := readLockfile(path) + if err != nil { + return nil, err + } + + scanner := bufio.NewScanner(bytes.NewReader(data)) + // Gradle lockfile lines can carry a long, comma-separated configuration list + // after "=", so raise the scanner's per-line cap. data is already size-bounded + // by readLockfile. + scanner.Buffer(make([]byte, 0, bufio.MaxScanTokenSize), maxLockfileSize) + + var entries []PackageEntry + headerPassed := false + + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + + if line == "" || strings.HasPrefix(line, "#") { + continue + } + + // The header line "empty=" signals end of preamble in some formats + if !headerPassed { + if strings.Contains(line, "=") && !strings.Contains(line, ":") { + // Metadata line like "empty=" + continue + } + headerPassed = true + } + + // Expected: group:artifact:version=configurations + eqIdx := strings.Index(line, "=") + gav := line + if eqIdx > 0 { + gav = line[:eqIdx] + } + + parts := strings.Split(gav, ":") + if len(parts) < 3 { + continue + } + + group := parts[0] + artifact := parts[1] + version := parts[2] + + if group == "" || artifact == "" || version == "" { + continue + } + + entries = append(entries, PackageEntry{ + Name: group + ":" + artifact, + Version: version, + }) + } + + if err := scanner.Err(); err != nil { + return nil, fmt.Errorf("scanning gradle lockfile: %w", err) + } + + return entries, nil +} diff --git a/internal/supplychain/check/gradle_test.go b/internal/supplychain/check/gradle_test.go new file mode 100644 index 0000000..4290401 --- /dev/null +++ b/internal/supplychain/check/gradle_test.go @@ -0,0 +1,44 @@ +package check + +import ( + "path/filepath" + "sort" + "testing" +) + +func TestParseGradleLockfile(t *testing.T) { + t.Run("valid gradle lockfile", func(t *testing.T) { + entries, err := ParseGradleLockfile(filepath.Join("testdata", "gradle.lockfile")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + sort.Slice(entries, func(i, j int) bool { + return entries[i].Name < entries[j].Name + }) + + expected := []PackageEntry{ + {Name: "com.fasterxml.jackson.core:jackson-core", Version: "2.16.0"}, + {Name: "com.google.guava:guava", Version: "32.1.3-jre"}, + {Name: "org.slf4j:slf4j-api", Version: "2.0.9"}, + {Name: "org.springframework:spring-core", Version: "6.1.2"}, + } + + if len(entries) != len(expected) { + t.Fatalf("expected %d entries, got %d: %+v", len(expected), len(entries), entries) + } + + for i, e := range entries { + if e.Name != expected[i].Name || e.Version != expected[i].Version { + t.Errorf("entry %d: expected %s@%s, got %s@%s", i, expected[i].Name, expected[i].Version, e.Name, e.Version) + } + } + }) + + t.Run("file not found", func(t *testing.T) { + _, err := ParseGradleLockfile("testdata/nonexistent.lockfile") + if err == nil { + t.Fatal("expected error for nonexistent file") + } + }) +} diff --git a/internal/supplychain/check/maven.go b/internal/supplychain/check/maven.go new file mode 100644 index 0000000..2e86e29 --- /dev/null +++ b/internal/supplychain/check/maven.go @@ -0,0 +1,99 @@ +package check + +import ( + "encoding/xml" + "fmt" + "strings" +) + +type pomProject struct { + XMLName xml.Name `xml:"project"` + Dependencies pomDeps `xml:"dependencies"` + DepMgmt pomDepMgmt `xml:"dependencyManagement"` +} + +type pomDepMgmt struct { + Dependencies pomDeps `xml:"dependencies"` +} + +type pomDeps struct { + Dependency []pomDependency `xml:"dependency"` +} + +type pomDependency struct { + GroupID string `xml:"groupId"` + ArtifactID string `xml:"artifactId"` + Version string `xml:"version"` + Scope string `xml:"scope"` +} + +// ParseMavenDeps parses a pom.xml file for direct dependencies with explicit versions. +// Only direct dependencies are covered; transitive dependencies resolved by Maven +// at build time are not present in pom.xml. Entries under +// are used only as a fallback version source for dependencies declared in +// that omit their own ; managed entries are not treated +// as dependencies themselves, since declaring a managed version does not pull a +// package into the build. +// armis:ignore cwe:22 cwe:23 cwe:73 reason:local CLI reading the user's own lockfile; path is from local detection or an explicit --lockfile flag, not untrusted input crossing a trust boundary +func ParseMavenDeps(path string) ([]PackageEntry, error) { + // armis:ignore cwe:22 cwe:23 cwe:73 reason:local CLI reading the user's own lockfile; path is from local detection or an explicit --lockfile flag, not untrusted input crossing a trust boundary + data, err := readLockfile(path) + if err != nil { + return nil, err + } + + var project pomProject + // armis:ignore cwe:502 cwe:770 reason:xml.Unmarshal into a typed struct does not execute code; data is size-bounded by readLockfile and is the user's own lockfile, not untrusted data + if err := xml.Unmarshal(data, &project); err != nil { + return nil, fmt.Errorf("parsing pom.xml: %w", err) + } + + // Build a groupId:artifactId -> version index from so + // dependencies that omit their can inherit the managed value. + managedVersions := make(map[string]string) + for _, dep := range project.DepMgmt.Dependencies.Dependency { + if dep.GroupID == "" || dep.ArtifactID == "" || dep.Version == "" { + continue + } + managedVersions[dep.GroupID+":"+dep.ArtifactID] = dep.Version + } + + var entries []PackageEntry + seen := make(map[string]bool) + + for _, dep := range project.Dependencies.Dependency { + // Backfill a missing version from before converting. + if dep.Version == "" { + dep.Version = managedVersions[dep.GroupID+":"+dep.ArtifactID] + } + entry := mavenDepToEntry(dep) + if entry != nil && !seen[entry.Name+"@"+entry.Version] { + seen[entry.Name+"@"+entry.Version] = true + entries = append(entries, *entry) + } + } + + return entries, nil +} + +func mavenDepToEntry(dep pomDependency) *PackageEntry { + if dep.GroupID == "" || dep.ArtifactID == "" || dep.Version == "" { + return nil + } + + // Skip property references that can't be resolved + if strings.Contains(dep.Version, "${") { + return nil + } + + // Skip test and provided scope + scope := strings.ToLower(dep.Scope) + if scope == "test" || scope == "provided" { + return nil + } + + return &PackageEntry{ + Name: dep.GroupID + ":" + dep.ArtifactID, + Version: dep.Version, + } +} diff --git a/internal/supplychain/check/maven_test.go b/internal/supplychain/check/maven_test.go new file mode 100644 index 0000000..b00a5df --- /dev/null +++ b/internal/supplychain/check/maven_test.go @@ -0,0 +1,103 @@ +package check + +import ( + "path/filepath" + "sort" + "testing" +) + +func TestParseMavenDeps(t *testing.T) { + t.Run("valid pom.xml", func(t *testing.T) { + entries, err := ParseMavenDeps(filepath.Join("testdata", "pom.xml")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + sort.Slice(entries, func(i, j int) bool { + return entries[i].Name < entries[j].Name + }) + + // Should include guava, jackson-core, and commons-io (version inherited + // from ). + // Should skip: junit (test scope), servlet-api (provided scope), + // spring-core (${property}), and commons-lang3 (managed-only, never + // declared under ). + expected := []PackageEntry{ + {Name: "com.fasterxml.jackson.core:jackson-core", Version: "2.16.0"}, + {Name: "com.google.guava:guava", Version: "32.1.3-jre"}, + {Name: "commons-io:commons-io", Version: "2.15.1"}, + } + + if len(entries) != len(expected) { + t.Fatalf("expected %d entries, got %d: %+v", len(expected), len(entries), entries) + } + + for i, e := range entries { + if e.Name != expected[i].Name || e.Version != expected[i].Version { + t.Errorf("entry %d: expected %s@%s, got %s@%s", i, expected[i].Name, expected[i].Version, e.Name, e.Version) + } + } + }) + + t.Run("skips test scope", func(t *testing.T) { + entries, err := ParseMavenDeps(filepath.Join("testdata", "pom.xml")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + for _, e := range entries { + if e.Name == "org.junit.jupiter:junit-jupiter" { + t.Error("should have skipped test-scope dependency") + } + } + }) + + t.Run("skips property version refs", func(t *testing.T) { + entries, err := ParseMavenDeps(filepath.Join("testdata", "pom.xml")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + for _, e := range entries { + if e.Name == "org.springframework:spring-core" { + t.Error("should have skipped property-referenced version") + } + } + }) + + t.Run("inherits version from dependencyManagement", func(t *testing.T) { + entries, err := ParseMavenDeps(filepath.Join("testdata", "pom.xml")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + var found bool + for _, e := range entries { + if e.Name == "commons-io:commons-io" { + found = true + if e.Version != "2.15.1" { + t.Errorf("expected commons-io to inherit managed version 2.15.1, got %s", e.Version) + } + } + } + if !found { + t.Error("expected commons-io (versionless dependency) to inherit a managed version") + } + }) + + t.Run("ignores managed-only dependencies", func(t *testing.T) { + entries, err := ParseMavenDeps(filepath.Join("testdata", "pom.xml")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + for _, e := range entries { + if e.Name == "org.apache.commons:commons-lang3" { + t.Error("should not include a dependencyManagement-only entry that is never declared under ") + } + } + }) + + t.Run("file not found", func(t *testing.T) { + _, err := ParseMavenDeps("testdata/nonexistent.xml") + if err == nil { + t.Fatal("expected error for nonexistent file") + } + }) +} diff --git a/internal/supplychain/check/testdata/gradle.lockfile b/internal/supplychain/check/testdata/gradle.lockfile new file mode 100644 index 0000000..e1f7b45 --- /dev/null +++ b/internal/supplychain/check/testdata/gradle.lockfile @@ -0,0 +1,8 @@ +# This is a Gradle generated file for dependency locking. +# Manual edits can mess up your build. +# This file is expected to be part of source control. +com.google.guava:guava:32.1.3-jre=compileClasspath,runtimeClasspath +com.fasterxml.jackson.core:jackson-core:2.16.0=compileClasspath,runtimeClasspath +org.slf4j:slf4j-api:2.0.9=compileClasspath,runtimeClasspath +org.springframework:spring-core:6.1.2=compileClasspath,runtimeClasspath +empty= diff --git a/internal/supplychain/check/testdata/pom.xml b/internal/supplychain/check/testdata/pom.xml new file mode 100644 index 0000000..19525aa --- /dev/null +++ b/internal/supplychain/check/testdata/pom.xml @@ -0,0 +1,61 @@ + + + 4.0.0 + com.example + my-app + 1.0.0 + + + + + + org.apache.commons + commons-lang3 + 3.14.0 + + + + commons-io + commons-io + 2.15.1 + + + + + + + com.google.guava + guava + 32.1.3-jre + + + com.fasterxml.jackson.core + jackson-core + 2.16.0 + + + + commons-io + commons-io + + + org.junit.jupiter + junit-jupiter + 5.10.1 + test + + + javax.servlet + javax.servlet-api + 4.0.1 + provided + + + org.springframework + spring-core + ${spring.version} + + + diff --git a/internal/supplychain/detect.go b/internal/supplychain/detect.go index 5e77602..a93c0c7 100644 --- a/internal/supplychain/detect.go +++ b/internal/supplychain/detect.go @@ -18,6 +18,8 @@ const ( EcosystemPipfile Ecosystem = "pipfile" EcosystemPDM Ecosystem = "pdm" EcosystemUV Ecosystem = "uv" + EcosystemMaven Ecosystem = "maven" + EcosystemGradle Ecosystem = "gradle" ) type DetectedEcosystem struct { @@ -45,6 +47,8 @@ var lockfileChecks = []lockfileCheck{ {"pnpm-lock.yaml", EcosystemPNPM, true}, {"bun.lock", EcosystemBun, true}, {"yarn.lock", EcosystemYarn, true}, + {"pom.xml", EcosystemMaven, true}, + {"gradle.lockfile", EcosystemGradle, true}, {"poetry.lock", EcosystemPoetry, true}, {"Pipfile.lock", EcosystemPipfile, true}, {"pdm.lock", EcosystemPDM, true}, @@ -83,7 +87,7 @@ func DetectEcosystems(dir string) ([]DetectedEcosystem, error) { } if len(detected) == 0 { - return nil, fmt.Errorf("no supported lockfile found in %s\n\n Supported: package-lock.json, pnpm-lock.yaml, bun.lock, yarn.lock,\n poetry.lock, Pipfile.lock, pdm.lock, uv.lock, requirements.txt\n Try: armis-cli supply-chain check \n Or use: --lockfile ", dir) + return nil, fmt.Errorf("no supported lockfile found in %s\n\n Supported: package-lock.json, pnpm-lock.yaml, bun.lock, yarn.lock,\n pom.xml, gradle.lockfile, poetry.lock, Pipfile.lock,\n pdm.lock, uv.lock, requirements.txt\n Try: armis-cli supply-chain check \n Or use: --lockfile ", dir) } return detected, nil diff --git a/internal/supplychain/registry/maven.go b/internal/supplychain/registry/maven.go new file mode 100644 index 0000000..6494b7c --- /dev/null +++ b/internal/supplychain/registry/maven.go @@ -0,0 +1,209 @@ +package registry + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "regexp" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/cenkalti/backoff/v4" +) + +const ( + defaultMavenURL = "https://search.maven.org" + // mavenMaxElapsed bounds the total time spent retrying a single coordinate + // against Maven Central's rate limiter before giving up. + mavenMaxElapsed = 30 * time.Second + mavenInitialBackoff = 500 * time.Millisecond +) + +var validMavenCoordinate = regexp.MustCompile(`^[a-zA-Z0-9._-]+$`) + +// MavenClient queries Maven Central for artifact publication dates. Unlike the +// npm and PyPI clients, Maven Central rate-limits aggressively (HTTP 429), so +// each lookup is wrapped in an exponential backoff. +type MavenClient struct { + httpClient *http.Client + baseURL string + cache sync.Map // map[string]time.Time, keyed by "group:artifact@version" + cacheLen atomic.Int64 +} + +type mavenSearchResponse struct { + Response struct { + Docs []struct { + Timestamp int64 `json:"timestamp"` + } `json:"docs"` + } `json:"response"` +} + +func NewMavenClient() *MavenClient { + return &MavenClient{ + httpClient: &http.Client{Timeout: 30 * time.Second}, + baseURL: defaultMavenURL, + } +} + +// NewMavenClientWithHTTP builds a MavenClient with an injected HTTP client and +// base URL. It exists for tests that point the client at an httptest server; the +// baseURL is therefore a trusted construction-time value, not request- or +// network-derived input. Production code uses NewMavenClient, which hardcodes +// the search.maven.org HTTPS endpoint. +func NewMavenClientWithHTTP(client *http.Client, baseURL string) *MavenClient { + if baseURL == "" { + baseURL = defaultMavenURL + } + return &MavenClient{ + httpClient: client, + baseURL: baseURL, + } +} + +func (c *MavenClient) GetPublishDate(ctx context.Context, name, version string) (time.Time, error) { + parts := strings.SplitN(name, ":", 2) + if len(parts) != 2 { + return time.Time{}, fmt.Errorf("invalid maven coordinate: %s (expected group:artifact)", name) + } + groupID, artifactID := parts[0], parts[1] + + if !validMavenCoordinate.MatchString(groupID) { + return time.Time{}, fmt.Errorf("invalid maven groupId: %s", groupID) + } + if !validMavenCoordinate.MatchString(artifactID) { + return time.Time{}, fmt.Errorf("invalid maven artifactId: %s", artifactID) + } + + cacheKey := name + "@" + version + if cached, ok := c.cache.Load(cacheKey); ok { + return cached.(time.Time), nil + } + + var publishTime time.Time + operation := func() error { + t, err := c.fetchPublishDate(ctx, groupID, artifactID, version) + if err != nil { + return err + } + publishTime = t + return nil + } + + bo := backoff.NewExponentialBackOff() + bo.MaxElapsedTime = mavenMaxElapsed + bo.InitialInterval = mavenInitialBackoff + + if err := backoff.Retry(operation, backoff.WithContext(bo, ctx)); err != nil { + return time.Time{}, err + } + + // Memoize, but stop inserting once the cache reaches maxCacheEntries so it + // cannot grow without bound (CWE-770). + if c.cacheLen.Load() < maxCacheEntries { + if _, loaded := c.cache.LoadOrStore(cacheKey, publishTime); !loaded { + c.cacheLen.Add(1) + } + } + return publishTime, nil +} + +// escapeSolrQueryValue escapes the characters that are special inside a +// double-quoted Solr query term. URL-escaping alone does not prevent query +// injection: the value is decoded before Solr parses it, so a raw `"` or `\` +// in a lockfile-provided version could otherwise break out of the quoted term +// and change the query's semantics (potentially returning an unrelated +// artifact's timestamp and bypassing release-age enforcement). Backslash must +// be escaped first so the backslashes added for quotes are not doubled again. +func escapeSolrQueryValue(s string) string { + s = strings.ReplaceAll(s, `\`, `\\`) + s = strings.ReplaceAll(s, `"`, `\"`) + return s +} + +func (c *MavenClient) fetchPublishDate(ctx context.Context, groupID, artifactID, version string) (time.Time, error) { + q := fmt.Sprintf(`g:"%s" AND a:"%s" AND v:"%s"`, + escapeSolrQueryValue(groupID), + escapeSolrQueryValue(artifactID), + escapeSolrQueryValue(version)) + // armis:ignore cwe:918 reason:baseURL is a trusted construction-time config value (production NewMavenClient hardcodes the search.maven.org HTTPS constant; the URL-accepting NewMavenClientWithHTTP is test-only); groupID/artifactID are regex-validated, every interpolated value is Solr-escaped against query injection, and the whole query is QueryEscaped, so the host is not attacker-controlled + reqURL := fmt.Sprintf("%s/solrsearch/select?q=%s&rows=1&wt=json", c.baseURL, url.QueryEscape(q)) + + // armis:ignore cwe:918 reason:reqURL is built from the trusted baseURL constant + a QueryEscaped query whose interpolated values are Solr-escaped (group/artifact also regex-validated), so the host is not attacker-controlled + req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL, nil) + if err != nil { + return time.Time{}, backoff.Permanent(fmt.Errorf("creating request: %w", err)) + } + req.Header.Set("Accept", "application/json") + + // armis:ignore cwe:918 reason:c.baseURL is a trusted construction-time config value (production NewMavenClient hardcodes the search.maven.org HTTPS constant; the URL-accepting NewMavenClientWithHTTP is test-only), so the request host is not attacker-controlled; group/artifact are regex-validated and every interpolated query value is Solr-escaped then QueryEscaped + resp, err := c.httpClient.Do(req) //nolint:gosec // G704: reqURL is a constant/configured registry host + Solr-escaped, QueryEscaped coordinates + if err != nil { + return time.Time{}, fmt.Errorf("fetching maven metadata: %w", err) + } + defer resp.Body.Close() //nolint:errcheck // best-effort close on read path + + // A 429 is transient: return a plain (retryable) error so backoff retries it. + if resp.StatusCode == http.StatusTooManyRequests { + return time.Time{}, fmt.Errorf("maven central rate limited (429)") + } + // Any other non-200 is permanent: wrap so backoff stops immediately. + if resp.StatusCode != http.StatusOK { + return time.Time{}, backoff.Permanent(fmt.Errorf("maven central returned %d for %s:%s", resp.StatusCode, groupID, artifactID)) + } + + // Read one byte past the cap so an oversize response is detectable instead of + // being silently truncated and failing as a confusing JSON parse error. + body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseSize+1)) + if err != nil { + return time.Time{}, backoff.Permanent(fmt.Errorf("reading response: %w", err)) + } + if int64(len(body)) > maxResponseSize { + return time.Time{}, backoff.Permanent(fmt.Errorf("maven central response for %s:%s too large (max %d bytes)", groupID, artifactID, maxResponseSize)) + } + + var searchResp mavenSearchResponse + if err := json.Unmarshal(body, &searchResp); err != nil { + return time.Time{}, backoff.Permanent(fmt.Errorf("parsing maven response: %w", err)) + } + + if len(searchResp.Response.Docs) == 0 { + return time.Time{}, backoff.Permanent(fmt.Errorf("artifact not found on maven central: %s:%s:%s", groupID, artifactID, version)) + } + + timestamp := searchResp.Response.Docs[0].Timestamp + return time.UnixMilli(timestamp), nil +} + +func (c *MavenClient) GetPublishDates(ctx context.Context, packages []PackageRequest) []QueryResult { + results := make([]QueryResult, len(packages)) + sem := make(chan struct{}, maxConcurrent) + var wg sync.WaitGroup + + for i, pkg := range packages { + // Acquire the semaphore before spawning so that goroutine creation itself + // is bounded by maxConcurrent rather than launching one stack per package. + sem <- struct{}{} + wg.Add(1) + go func(idx int, name, version string) { + defer wg.Done() + defer func() { <-sem }() + + publishTime, err := c.GetPublishDate(ctx, name, version) + results[idx] = QueryResult{ + Name: name, + Version: version, + PublishTime: publishTime, + Err: err, + } + }(i, pkg.Name, pkg.Version) + } + + wg.Wait() + return results +} diff --git a/internal/supplychain/registry/maven_test.go b/internal/supplychain/registry/maven_test.go new file mode 100644 index 0000000..89a3dca --- /dev/null +++ b/internal/supplychain/registry/maven_test.go @@ -0,0 +1,155 @@ +package registry + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" +) + +func TestMavenGetPublishDate(t *testing.T) { + t.Run("success", func(t *testing.T) { + // 2023-09-30T12:00:00Z in milliseconds since epoch. + ts := time.Date(2023, 9, 30, 12, 0, 0, 0, time.UTC).UnixMilli() + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/solrsearch/select" { + t.Errorf("unexpected path: %s", r.URL.Path) + } + q := r.URL.Query().Get("q") + if !strings.Contains(q, `g:"org.springframework"`) || !strings.Contains(q, `a:"spring-core"`) { + t.Errorf("unexpected query: %s", q) + } + w.Header().Set("Content-Type", "application/json") + _, _ = fmt.Fprintf(w, `{"response":{"docs":[{"timestamp":%d}]}}`, ts) + })) + defer server.Close() + + client := NewMavenClientWithHTTP(server.Client(), server.URL) + publishTime, err := client.GetPublishDate(context.Background(), "org.springframework:spring-core", "6.0.0") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + expected := time.UnixMilli(ts) + if !publishTime.Equal(expected) { + t.Errorf("expected %v, got %v", expected, publishTime) + } + }) + + t.Run("invalid coordinate without colon", func(t *testing.T) { + client := NewMavenClient() + _, err := client.GetPublishDate(context.Background(), "not-a-coordinate", "1.0.0") + if err == nil { + t.Error("expected error for coordinate missing group:artifact separator") + } + }) + + t.Run("invalid groupId characters", func(t *testing.T) { + client := NewMavenClient() + _, err := client.GetPublishDate(context.Background(), `org.evil"/../:spring-core`, "1.0.0") + if err == nil { + t.Error("expected error for invalid groupId characters") + } + }) + + t.Run("artifact not found", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"response":{"docs":[]}}`)) + })) + defer server.Close() + + client := NewMavenClientWithHTTP(server.Client(), server.URL) + _, err := client.GetPublishDate(context.Background(), "com.example:missing", "1.0.0") + if err == nil { + t.Error("expected error for empty docs") + } + }) + + t.Run("non-200 status", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer server.Close() + + client := NewMavenClientWithHTTP(server.Client(), server.URL) + _, err := client.GetPublishDate(context.Background(), "com.example:lib", "1.0.0") + if err == nil { + t.Error("expected error for 500 status") + } + }) + + t.Run("escapes solr query special characters in version", func(t *testing.T) { + ts := time.Date(2020, 5, 5, 0, 0, 0, 0, time.UTC).UnixMilli() + var gotQuery string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotQuery = r.URL.Query().Get("q") + w.Header().Set("Content-Type", "application/json") + _, _ = fmt.Fprintf(w, `{"response":{"docs":[{"timestamp":%d}]}}`, ts) + })) + defer server.Close() + + client := NewMavenClientWithHTTP(server.Client(), server.URL) + // A version containing a quote and backslash would, unescaped, break out + // of the quoted Solr term. After escaping, the decoded query must keep the + // value contained inside v:"..." with both characters backslash-escaped. + if _, err := client.GetPublishDate(context.Background(), "com.example:lib", `1.0"\inject`); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if want := `v:"1.0\"\\inject"`; !strings.Contains(gotQuery, want) { + t.Errorf("expected escaped version term %q in query, got %q", want, gotQuery) + } + }) + + t.Run("cached result is reused", func(t *testing.T) { + ts := time.Date(2022, 1, 1, 0, 0, 0, 0, time.UTC).UnixMilli() + var calls int + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + calls++ + w.Header().Set("Content-Type", "application/json") + _, _ = fmt.Fprintf(w, `{"response":{"docs":[{"timestamp":%d}]}}`, ts) + })) + defer server.Close() + + client := NewMavenClientWithHTTP(server.Client(), server.URL) + for i := 0; i < 3; i++ { + if _, err := client.GetPublishDate(context.Background(), "com.example:cached", "1.0.0"); err != nil { + t.Fatalf("unexpected error: %v", err) + } + } + if calls != 1 { + t.Errorf("expected 1 upstream call (rest cached), got %d", calls) + } + }) +} + +func TestMavenGetPublishDates(t *testing.T) { + ts := time.Date(2021, 6, 15, 0, 0, 0, 0, time.UTC).UnixMilli() + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = fmt.Fprintf(w, `{"response":{"docs":[{"timestamp":%d}]}}`, ts) + })) + defer server.Close() + + client := NewMavenClientWithHTTP(server.Client(), server.URL) + packages := []PackageRequest{ + {Name: "org.springframework:spring-core", Version: "6.0.0"}, + {Name: "com.google.guava:guava", Version: "32.0.0"}, + } + results := client.GetPublishDates(context.Background(), packages) + + if len(results) != len(packages) { + t.Fatalf("expected %d results, got %d", len(packages), len(results)) + } + for i, r := range results { + if r.Err != nil { + t.Errorf("result %d: unexpected error: %v", i, r.Err) + } + if r.Name != packages[i].Name { + t.Errorf("result %d: expected name %q, got %q", i, packages[i].Name, r.Name) + } + } +}