Skip to content

Commit d015e75

Browse files
Create approx_nearest_neighbours.py
1 parent a71618f commit d015e75

1 file changed

Lines changed: 110 additions & 0 deletions

File tree

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
"""
2+
Approximate Nearest Neighbor (ANN) Search
3+
https://en.wikipedia.org/wiki/Nearest_neighbor_search#Approximate_nearest_neighbor
4+
5+
ANN search finds "close enough" vectors instead of the exact nearest neighbor,
6+
which makes it much faster for large datasets.
7+
This implementation uses a simple **random projection hashing** method.
8+
9+
Steps:
10+
1. Generate random hyperplanes to hash vectors into buckets.
11+
2. Place dataset vectors into buckets.
12+
3. For a query vector, look into its bucket (and maybe nearby buckets).
13+
4. Return the approximate nearest neighbor from those candidates.
14+
15+
Each result contains:
16+
1. The nearest (approximate) vector.
17+
2. Its distance from the query vector.
18+
"""
19+
20+
from __future__ import annotations
21+
22+
import math
23+
from collections import defaultdict
24+
25+
import numpy as np
26+
27+
28+
def euclidean(input_a: np.ndarray, input_b: np.ndarray) -> float:
29+
"""
30+
Calculates Euclidean distance between two vectors.
31+
32+
>>> euclidean(np.array([0]), np.array([1]))
33+
1.0
34+
>>> euclidean(np.array([1, 2]), np.array([1, 5]))
35+
3.0
36+
"""
37+
return math.sqrt(sum(pow(a - b, 2) for a, b in zip(input_a, input_b)))
38+
39+
40+
class ANN:
41+
"""
42+
Approximate Nearest Neighbor using random projection hashing.
43+
"""
44+
45+
def __init__(self, dataset: np.ndarray, n_planes: int = 5, seed: int = 42):
46+
"""
47+
:param dataset: ndarray of shape (n_samples, n_features)
48+
:param n_planes: number of random hyperplanes for hashing
49+
:param seed: random seed for reproducibility
50+
"""
51+
self.dataset = dataset
52+
self.n_planes = n_planes
53+
rng = np.random.default_rng(seed)
54+
self.planes = rng.standard_normal((n_planes, dataset.shape[1]))
55+
self.buckets: dict[str, list[np.ndarray]] = defaultdict(list)
56+
self._build_index()
57+
58+
def _hash_vector(self, vec: np.ndarray) -> str:
59+
"""
60+
Hash a vector based on which side of each hyperplane it falls on.
61+
Returns a bit string.
62+
"""
63+
signs = (vec @ self.planes.T) >= 0
64+
return "".join(["1" if s else "0" for s in signs])
65+
66+
def _build_index(self):
67+
"""
68+
Build hash buckets for all dataset vectors.
69+
"""
70+
for vec in self.dataset:
71+
h = self._hash_vector(vec)
72+
self.buckets[h].append(vec)
73+
74+
def query(self, q: np.ndarray) -> list[list[list[float] | float]]:
75+
"""
76+
Find approximate nearest neighbor for query vector(s).
77+
78+
:param q: ndarray of shape (m, n_features)
79+
:return: list of [nearest_vector, distance]
80+
81+
>>> dataset = np.array([[0,0], [1,1], [2,2], [10,10]])
82+
>>> ann = ANN(dataset, n_planes=4, seed=0)
83+
>>> ann.query(np.array([[0,1]])) # doctest: +NORMALIZE_WHITESPACE
84+
[[[0, 0], 1.0]]
85+
"""
86+
results = []
87+
for vec in q:
88+
h = self._hash_vector(vec)
89+
candidates = self.buckets[h]
90+
91+
if not candidates: # fallback: search entire dataset
92+
candidates = self.dataset
93+
94+
# Approximate NN search among candidates
95+
best_vec = candidates[0]
96+
best_dist = euclidean(vec, best_vec)
97+
for cand in candidates[1:]:
98+
d = euclidean(vec, cand)
99+
if d < best_dist:
100+
best_vec, best_dist = cand, d
101+
results.append([best_vec.tolist(), best_dist])
102+
return results
103+
104+
105+
if __name__ == "__main__":
106+
import doctest
107+
doctest.testmod()
108+
109+
110+

0 commit comments

Comments
 (0)