diff --git a/graph.go b/graph.go index 053e9ac..48fa793 100644 --- a/graph.go +++ b/graph.go @@ -64,7 +64,7 @@ func (n *layerNode[K]) addNeighbor(newNode *layerNode[K], m int, dist DistanceFu delete(n.neighbors, worst.Key) // Delete backlink from the worst neighbor. delete(worst.neighbors, n.Key) - worst.replenish(m) + worst.replenish(m, dist) } type searchCandidate[K cmp.Ordered] struct { @@ -76,17 +76,16 @@ func (s searchCandidate[K]) Less(o searchCandidate[K]) bool { return s.dist < o.dist } -// search returns the layer node closest to the target node -// within the same layer. +// search returns the nearest neighbors of target within this layer. +// Implements HNSW Algorithm 5 (SEARCH-LAYER): the result set is bounded +// by efSearch during exploration; only the k closest are returned. func (n *layerNode[K]) search( - // k is the number of candidates in the result set. + // k is the number of nearest neighbors to return. k int, efSearch int, target Vector, distance DistanceFunc, ) []searchCandidate[K] { - // This is a basic greedy algorithm to find the entry point at the given level - // that is closest to the target node. candidates := heap.Heap[searchCandidate[K]]{} candidates.Init(make([]searchCandidate[K], 0, efSearch)) candidates.Push( @@ -99,56 +98,64 @@ func (n *layerNode[K]) search( result = heap.Heap[searchCandidate[K]]{} visited = make(map[K]bool) ) - result.Init(make([]searchCandidate[K], 0, k)) + result.Init(make([]searchCandidate[K], 0, efSearch)) // Begin with the entry node in the result set. result.Push(candidates.Min()) visited[n.Key] = true for candidates.Len() > 0 { - var ( - current = candidates.Pop().node - improved = false - ) + current := candidates.Pop() + + // Standard HNSW termination: if the closest remaining candidate + // is farther than the worst in the ef-bounded result set, + // no further improvement is possible. + if result.Len() >= efSearch && current.dist > result.Max().dist { + break + } // We iterate the map in a sorted, deterministic fashion for // tests. - neighborKeys := maps.Keys(current.neighbors) + neighborKeys := maps.Keys(current.node.neighbors) slices.Sort(neighborKeys) for _, neighborID := range neighborKeys { - neighbor := current.neighbors[neighborID] + neighbor := current.node.neighbors[neighborID] if visited[neighborID] { continue } visited[neighborID] = true dist := distance(neighbor.Value, target) - improved = improved || dist < result.Min().dist - if result.Len() < k { + if result.Len() < efSearch { result.Push(searchCandidate[K]{node: neighbor, dist: dist}) + candidates.Push(searchCandidate[K]{node: neighbor, dist: dist}) } else if dist < result.Max().dist { result.PopLast() result.Push(searchCandidate[K]{node: neighbor, dist: dist}) - } - - candidates.Push(searchCandidate[K]{node: neighbor, dist: dist}) - // Always store candidates if we haven't reached the limit. - if candidates.Len() > efSearch { - candidates.PopLast() + candidates.Push(searchCandidate[K]{node: neighbor, dist: dist}) } } + } - // Termination condition: no improvement in distance and at least - // kMin candidates in the result set. - if !improved && result.Len() >= k { - break + // Return only the k closest from the ef-sized result set. + res := result.Slice() + slices.SortFunc(res, func(a, b searchCandidate[K]) int { + if a.dist < b.dist { + return -1 } + if a.dist > b.dist { + return 1 + } + // Deterministic tiebreaking by key. + return cmp.Compare(a.node.Key, b.node.Key) + }) + if len(res) > k { + res = res[:k] } - - return result.Slice() + return res } -func (n *layerNode[K]) replenish(m int) { +func (n *layerNode[K]) replenish(m int, dist DistanceFunc) { if len(n.neighbors) >= m { return } @@ -165,7 +172,7 @@ func (n *layerNode[K]) replenish(m int) { if candidate == n { continue } - n.addNeighbor(candidate, m, CosineDistance) + n.addNeighbor(candidate, m, dist) if len(n.neighbors) >= m { return } @@ -175,13 +182,13 @@ func (n *layerNode[K]) replenish(m int) { // isolates remove the node from the graph by removing all connections // to neighbors. -func (n *layerNode[K]) isolate(m int) { +func (n *layerNode[K]) isolate(m int, dist DistanceFunc) { for _, neighbor := range n.neighbors { delete(neighbor.neighbors, n.Key) } for _, neighbor := range n.neighbors { - neighbor.replenish(m) + neighbor.replenish(m, dist) } } @@ -501,7 +508,7 @@ func (h *Graph[K]) Delete(key K) bool { if len(layer.nodes) == 0 { deleteLayer[i] = struct{}{} } - node.isolate(h.M) + node.isolate(h.M, h.Distance) deleted = true } diff --git a/graph_test.go b/graph_test.go index df795e6..8211507 100644 --- a/graph_test.go +++ b/graph_test.go @@ -113,13 +113,15 @@ func TestGraph_AddSearch(t *testing.T) { ) require.Len(t, nearest, 4) + // The two closest are 64 and 65 (distance 0.5 each). + // The next two are 63 and 66 (distance 1.5 each). require.EqualValues( t, []Node[int]{ {64, Vector{64}}, {65, Vector{65}}, - {62, Vector{62}}, {63, Vector{63}}, + {66, Vector{66}}, }, nearest, ) @@ -259,3 +261,48 @@ func TestGraph_RemoveAllNodes(t *testing.T) { g.Add(MakeNode(1, vec)) } } + +func TestGraph_SearchFindsCorrectNearest(t *testing.T) { + // With the old search algorithm, termination was too aggressive (stopped + // when no neighbor beat result.Min) and the result set was bounded by k + // instead of efSearch. This caused the search to miss true nearest + // neighbors, especially with small k and large efSearch. + g := newTestGraph[int]() + for i := 0; i < 100; i++ { + g.Add(Node[int]{Key: i, Value: Vector{float32(i)}}) + } + + // Search for k=1 nearest to 50.5. The answer must be 50 or 51. + results := g.Search(Vector{50.5}, 1) + require.Len(t, results, 1) + require.Contains(t, []int{50, 51}, results[0].Key, + "expected nearest neighbor to 50.5, got key=%d", results[0].Key) + + // Search for k=3 nearest to 0.0. Must return 0, 1, 2. + results = g.Search(Vector{0.0}, 3) + require.Len(t, results, 3) + keys := make([]int, 3) + for i, r := range results { + keys[i] = r.Key + } + require.Subset(t, []int{0, 1, 2}, keys) +} + +func TestGraph_DeleteReplenishUsesGraphDistance(t *testing.T) { + // replenish() previously hardcoded CosineDistance. After deleting a + // node from a EuclideanDistance graph, replenish must use the correct + // distance function or the topology becomes corrupted. + g := newTestGraph[int]() // uses EuclideanDistance + for i := 0; i < 20; i++ { + g.Add(Node[int]{Key: i, Value: Vector{float32(i)}}) + } + + // Delete a node in the middle to trigger replenish. + g.Delete(10) + + // Search should still find the correct nearest neighbor. + results := g.Search(Vector{9.5}, 1) + require.Len(t, results, 1) + // Must be 9 or 11 (both distance 0.5 from 9.5). + require.Contains(t, []int{9, 11}, results[0].Key) +} diff --git a/heap/heap.go b/heap/heap.go index 7a5052a..c919c3c 100644 --- a/heap/heap.go +++ b/heap/heap.go @@ -70,8 +70,9 @@ func (h *Heap[T]) Pop() T { return heap.Pop(&h.inner).(T) } +// PopLast removes and returns the maximum element from the heap. func (h *Heap[T]) PopLast() T { - return h.Remove(h.Len() - 1) + return h.Remove(h.maxIndex()) } // Remove removes and returns the element at index i from the heap. @@ -85,9 +86,22 @@ func (h *Heap[T]) Min() T { return h.inner.data[0] } +// maxIndex returns the index of the maximum element by scanning leaf nodes. +// In a min-heap the max is always a leaf (indices n/2 .. n-1). +func (h *Heap[T]) maxIndex() int { + n := h.inner.Len() + best := n / 2 + for i := best + 1; i < n; i++ { + if h.inner.data[best].Less(h.inner.data[i]) { + best = i + } + } + return best +} + // Max returns the maximum element in the heap. func (h *Heap[T]) Max() T { - return h.inner.data[h.inner.Len()-1] + return h.inner.data[h.maxIndex()] } func (h *Heap[T]) Slice() []T { diff --git a/heap/heap_test.go b/heap/heap_test.go index 265723e..fd2a9c3 100644 --- a/heap/heap_test.go +++ b/heap/heap_test.go @@ -32,3 +32,20 @@ func TestHeap(t *testing.T) { t.Errorf("Heap did not return sorted elements: %+v", inOrder) } } + +func TestHeap_MaxAndPopLast(t *testing.T) { + h := Heap[Int]{} + values := []Int{5, 1, 9, 3, 7, 2, 8, 4, 6} + for _, v := range values { + h.Push(v) + } + + require.Equal(t, Int(9), h.Max(), "Max should return the largest element") + require.Equal(t, Int(1), h.Min(), "Min should return the smallest element") + + // PopLast should remove and return the maximum. + popped := h.PopLast() + require.Equal(t, Int(9), popped) + require.Equal(t, Int(8), h.Max(), "Max should be 8 after removing 9") + require.Equal(t, 8, h.Len()) +}