Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 39 additions & 32 deletions graph.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func (n *layerNode[K]) addNeighbor(newNode *layerNode[K], m int, dist DistanceFu
delete(n.neighbors, worst.Key)
// Delete backlink from the worst neighbor.
delete(worst.neighbors, n.Key)
worst.replenish(m)
worst.replenish(m, dist)
}

type searchCandidate[K cmp.Ordered] struct {
Expand All @@ -76,17 +76,16 @@ func (s searchCandidate[K]) Less(o searchCandidate[K]) bool {
return s.dist < o.dist
}

// search returns the layer node closest to the target node
// within the same layer.
// search returns the nearest neighbors of target within this layer.
// Implements HNSW Algorithm 5 (SEARCH-LAYER): the result set is bounded
// by efSearch during exploration; only the k closest are returned.
func (n *layerNode[K]) search(
// k is the number of candidates in the result set.
// k is the number of nearest neighbors to return.
k int,
efSearch int,
target Vector,
distance DistanceFunc,
) []searchCandidate[K] {
// This is a basic greedy algorithm to find the entry point at the given level
// that is closest to the target node.
candidates := heap.Heap[searchCandidate[K]]{}
candidates.Init(make([]searchCandidate[K], 0, efSearch))
candidates.Push(
Expand All @@ -99,56 +98,64 @@ func (n *layerNode[K]) search(
result = heap.Heap[searchCandidate[K]]{}
visited = make(map[K]bool)
)
result.Init(make([]searchCandidate[K], 0, k))
result.Init(make([]searchCandidate[K], 0, efSearch))

// Begin with the entry node in the result set.
result.Push(candidates.Min())
visited[n.Key] = true

for candidates.Len() > 0 {
var (
current = candidates.Pop().node
improved = false
)
current := candidates.Pop()

// Standard HNSW termination: if the closest remaining candidate
// is farther than the worst in the ef-bounded result set,
// no further improvement is possible.
if result.Len() >= efSearch && current.dist > result.Max().dist {
break
}

// We iterate the map in a sorted, deterministic fashion for
// tests.
neighborKeys := maps.Keys(current.neighbors)
neighborKeys := maps.Keys(current.node.neighbors)
slices.Sort(neighborKeys)
for _, neighborID := range neighborKeys {
neighbor := current.neighbors[neighborID]
neighbor := current.node.neighbors[neighborID]
if visited[neighborID] {
continue
}
visited[neighborID] = true

dist := distance(neighbor.Value, target)
improved = improved || dist < result.Min().dist
if result.Len() < k {
if result.Len() < efSearch {
result.Push(searchCandidate[K]{node: neighbor, dist: dist})
candidates.Push(searchCandidate[K]{node: neighbor, dist: dist})
} else if dist < result.Max().dist {
result.PopLast()
result.Push(searchCandidate[K]{node: neighbor, dist: dist})
}

candidates.Push(searchCandidate[K]{node: neighbor, dist: dist})
// Always store candidates if we haven't reached the limit.
if candidates.Len() > efSearch {
candidates.PopLast()
candidates.Push(searchCandidate[K]{node: neighbor, dist: dist})
}
}
}

// Termination condition: no improvement in distance and at least
// kMin candidates in the result set.
if !improved && result.Len() >= k {
break
// Return only the k closest from the ef-sized result set.
res := result.Slice()
slices.SortFunc(res, func(a, b searchCandidate[K]) int {
if a.dist < b.dist {
return -1
}
if a.dist > b.dist {
return 1
}
// Deterministic tiebreaking by key.
return cmp.Compare(a.node.Key, b.node.Key)
})
if len(res) > k {
res = res[:k]
}

return result.Slice()
return res
}

func (n *layerNode[K]) replenish(m int) {
func (n *layerNode[K]) replenish(m int, dist DistanceFunc) {
if len(n.neighbors) >= m {
return
}
Expand All @@ -165,7 +172,7 @@ func (n *layerNode[K]) replenish(m int) {
if candidate == n {
continue
}
n.addNeighbor(candidate, m, CosineDistance)
n.addNeighbor(candidate, m, dist)
if len(n.neighbors) >= m {
return
}
Expand All @@ -175,13 +182,13 @@ func (n *layerNode[K]) replenish(m int) {

// isolates remove the node from the graph by removing all connections
// to neighbors.
func (n *layerNode[K]) isolate(m int) {
func (n *layerNode[K]) isolate(m int, dist DistanceFunc) {
for _, neighbor := range n.neighbors {
delete(neighbor.neighbors, n.Key)
}

for _, neighbor := range n.neighbors {
neighbor.replenish(m)
neighbor.replenish(m, dist)
}
}

Expand Down Expand Up @@ -501,7 +508,7 @@ func (h *Graph[K]) Delete(key K) bool {
if len(layer.nodes) == 0 {
deleteLayer[i] = struct{}{}
}
node.isolate(h.M)
node.isolate(h.M, h.Distance)
deleted = true
}

Expand Down
49 changes: 48 additions & 1 deletion graph_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,13 +113,15 @@ func TestGraph_AddSearch(t *testing.T) {
)

require.Len(t, nearest, 4)
// The two closest are 64 and 65 (distance 0.5 each).
// The next two are 63 and 66 (distance 1.5 each).
require.EqualValues(
t,
[]Node[int]{
{64, Vector{64}},
{65, Vector{65}},
{62, Vector{62}},
{63, Vector{63}},
{66, Vector{66}},
},
nearest,
)
Expand Down Expand Up @@ -259,3 +261,48 @@ func TestGraph_RemoveAllNodes(t *testing.T) {
g.Add(MakeNode(1, vec))
}
}

func TestGraph_SearchFindsCorrectNearest(t *testing.T) {
// With the old search algorithm, termination was too aggressive (stopped
// when no neighbor beat result.Min) and the result set was bounded by k
// instead of efSearch. This caused the search to miss true nearest
// neighbors, especially with small k and large efSearch.
g := newTestGraph[int]()
for i := 0; i < 100; i++ {
g.Add(Node[int]{Key: i, Value: Vector{float32(i)}})
}

// Search for k=1 nearest to 50.5. The answer must be 50 or 51.
results := g.Search(Vector{50.5}, 1)
require.Len(t, results, 1)
require.Contains(t, []int{50, 51}, results[0].Key,
"expected nearest neighbor to 50.5, got key=%d", results[0].Key)

// Search for k=3 nearest to 0.0. Must return 0, 1, 2.
results = g.Search(Vector{0.0}, 3)
require.Len(t, results, 3)
keys := make([]int, 3)
for i, r := range results {
keys[i] = r.Key
}
require.Subset(t, []int{0, 1, 2}, keys)
}

func TestGraph_DeleteReplenishUsesGraphDistance(t *testing.T) {
// replenish() previously hardcoded CosineDistance. After deleting a
// node from a EuclideanDistance graph, replenish must use the correct
// distance function or the topology becomes corrupted.
g := newTestGraph[int]() // uses EuclideanDistance
for i := 0; i < 20; i++ {
g.Add(Node[int]{Key: i, Value: Vector{float32(i)}})
}

// Delete a node in the middle to trigger replenish.
g.Delete(10)

// Search should still find the correct nearest neighbor.
results := g.Search(Vector{9.5}, 1)
require.Len(t, results, 1)
// Must be 9 or 11 (both distance 0.5 from 9.5).
require.Contains(t, []int{9, 11}, results[0].Key)
}
18 changes: 16 additions & 2 deletions heap/heap.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,9 @@ func (h *Heap[T]) Pop() T {
return heap.Pop(&h.inner).(T)
}

// PopLast removes and returns the maximum element from the heap.
func (h *Heap[T]) PopLast() T {
return h.Remove(h.Len() - 1)
return h.Remove(h.maxIndex())
}

// Remove removes and returns the element at index i from the heap.
Expand All @@ -85,9 +86,22 @@ func (h *Heap[T]) Min() T {
return h.inner.data[0]
}

// maxIndex returns the index of the maximum element by scanning leaf nodes.
// In a min-heap the max is always a leaf (indices n/2 .. n-1).
func (h *Heap[T]) maxIndex() int {
n := h.inner.Len()
best := n / 2
for i := best + 1; i < n; i++ {
if h.inner.data[best].Less(h.inner.data[i]) {
best = i
}
}
return best
}

// Max returns the maximum element in the heap.
func (h *Heap[T]) Max() T {
return h.inner.data[h.inner.Len()-1]
return h.inner.data[h.maxIndex()]
}

func (h *Heap[T]) Slice() []T {
Expand Down
17 changes: 17 additions & 0 deletions heap/heap_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,20 @@ func TestHeap(t *testing.T) {
t.Errorf("Heap did not return sorted elements: %+v", inOrder)
}
}

func TestHeap_MaxAndPopLast(t *testing.T) {
h := Heap[Int]{}
values := []Int{5, 1, 9, 3, 7, 2, 8, 4, 6}
for _, v := range values {
h.Push(v)
}

require.Equal(t, Int(9), h.Max(), "Max should return the largest element")
require.Equal(t, Int(1), h.Min(), "Min should return the smallest element")

// PopLast should remove and return the maximum.
popped := h.PopLast()
require.Equal(t, Int(9), popped)
require.Equal(t, Int(8), h.Max(), "Max should be 8 after removing 9")
require.Equal(t, 8, h.Len())
}
Loading