Skip to content

Commit 799436b

Browse files
committed
Implement count min sketch1
1 parent 5bb6181 commit 799436b

1 file changed

Lines changed: 294 additions & 1 deletion

File tree

Lines changed: 294 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,298 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
120
package org.apache.datasketches.count;
221

22+
import org.apache.datasketches.common.Family;
23+
import org.apache.datasketches.common.SketchesException;
24+
import org.apache.datasketches.hash.MurmurHash3;
25+
import org.apache.datasketches.tuple.Util;
26+
27+
import java.nio.Buffer;
28+
import java.nio.ByteBuffer;
29+
import java.nio.charset.StandardCharsets;
30+
import java.util.Random;
31+
32+
333
public class CountMinSketch {
4-
34+
private final byte numHashes_;
35+
private final int numBuckets_;
36+
private final long seed_;
37+
private final long[] hashSeeds_;
38+
private final long[] sketchArray_;
39+
private long totalWeight_;
40+
41+
private static final int IS_EMPTY = 0;
42+
43+
/**
44+
* Creates a CountMin sketch with given number of hash functions and buckets,
45+
* and a user-specified seed.
46+
*
47+
* @param numHashes The number of hash functions to apply to items
48+
* @param numBuckets Array size for each of the hashing function
49+
* @param seed The base hash seed
50+
*/
51+
CountMinSketch(final byte numHashes, final int numBuckets, final long seed) {
52+
numHashes_ = numHashes;
53+
numBuckets_ = numBuckets;
54+
seed_ = seed;
55+
hashSeeds_ = new long[numHashes];
56+
sketchArray_ = new long[numHashes * numBuckets];
57+
totalWeight_ = 0;
58+
59+
if (numBuckets < 3) {
60+
throw new SketchesException("Using fewer than 3 buckets incurs relative error greater than 1.");
61+
}
62+
63+
// This check is to ensure later compatibility with a Java implementation whose maximum size can only
64+
// be 2^31-1. We check only against 2^30 for simplicity.
65+
if (numBuckets * numHashes >= 1 << 30) {
66+
throw new SketchesException("These parameters generate a sketch that exceeds 2^30 elements. \n" +
67+
"Try reducing either the number of buckets or the number of hash functions.");
68+
}
69+
70+
Random rand = new Random();
71+
for (int i = 0; i < numHashes; i++) {
72+
hashSeeds_[i] = rand.nextLong();
73+
}
74+
}
75+
76+
private long[] getHashes(byte[] item) {
77+
long[] updateLocations = new long[numHashes_];
78+
79+
for (int i = 0; i < numHashes_; i++) {
80+
long[] index = MurmurHash3.hash(item, hashSeeds_[i]);
81+
updateLocations[i] = i * (long)numBuckets_ + index[0] % numBuckets_;
82+
}
83+
84+
return updateLocations;
85+
}
86+
87+
public boolean isEmpty() {
88+
return totalWeight_ == 0;
89+
}
90+
91+
public byte getNumHashes_() {
92+
return numHashes_;
93+
}
94+
95+
public int getNumBuckets_() {
96+
return numBuckets_;
97+
}
98+
99+
public long getSeed_() {
100+
return seed_;
101+
}
102+
103+
public long getTotalWeight_() {
104+
return totalWeight_;
105+
}
106+
107+
public double getRelativeError() {
108+
return Math.exp(1.0) / (double)numBuckets_;
109+
}
110+
111+
public byte suggestNumHashes(double confidence) {
112+
if (confidence < 0 || confidence > 1) {
113+
throw new SketchesException("Confidence must be between 0 and 1.0 (inclusive).");
114+
}
115+
int value = (int) Math.ceil(Math.log(1.0 / (1.0 - confidence)));
116+
return (byte) Math.min(value, 127);
117+
}
118+
119+
public int suggestNumBuckets(double relativeError) {
120+
return (int) Math.ceil(Math.exp(1.0) / relativeError);
121+
}
122+
123+
public void update(final long item, final long weight) {
124+
byte[] longByte = ByteBuffer.allocate(8).putLong(item).array();
125+
update(longByte, weight);
126+
}
127+
128+
public void update(final String item, final long weight) {
129+
if (item == null || item.isEmpty()) {
130+
return;
131+
}
132+
final byte[] strByte = item.getBytes(StandardCharsets.UTF_8);
133+
update(strByte, weight);
134+
}
135+
136+
public void update(final byte[] item, final long weight) {
137+
if (item.length == 0) {
138+
return;
139+
}
140+
141+
totalWeight_ += weight > 0 ? weight : -weight;
142+
long[] hashLocations = getHashes(item);
143+
for (long h : hashLocations) {
144+
sketchArray_[(int) h] += weight;
145+
}
146+
}
147+
148+
public long getEstimate(final long item) {
149+
byte[] longByte = ByteBuffer.allocate(8).putLong(item).array();
150+
return getEstimate(longByte);
151+
}
152+
153+
public long getEstimate(final String item) {
154+
if (item == null || item.isEmpty()) {
155+
return 0;
156+
}
157+
158+
final byte[] strByte = item.getBytes(StandardCharsets.UTF_8);
159+
return getEstimate(strByte);
160+
}
161+
162+
public long getEstimate(final byte[] item) {
163+
if (item.length == 0) {
164+
return 0;
165+
}
166+
167+
long[] hashLocations = getHashes(item);
168+
long res = sketchArray_[(int) hashLocations[0]];
169+
for (long h : hashLocations) {
170+
res = Math.min(res, sketchArray_[(int) h]);
171+
}
172+
173+
return res;
174+
}
175+
176+
public long getUpperBound(final long item) {
177+
byte[] longByte = ByteBuffer.allocate(8).putLong(item).array();
178+
return getUpperBound(longByte);
179+
}
180+
181+
public long getUpperBound(final String item) {
182+
if (item == null || item.isEmpty()) {
183+
return 0;
184+
}
185+
186+
byte[] strByte = item.getBytes(StandardCharsets.UTF_8);
187+
return getUpperBound(strByte);
188+
}
189+
190+
public long getUpperBound(final byte[] item) {
191+
if (item.length == 0) {
192+
return 0;
193+
}
194+
195+
return getEstimate(item) + (long)(getRelativeError() * getTotalWeight_());
196+
}
197+
198+
public long getLowerBound(final long item) {
199+
byte[] longByte = ByteBuffer.allocate(8).putLong(item).array();
200+
return getLowerBound(longByte);
201+
}
202+
203+
public long getLowerBound(final String item) {
204+
if (item == null || item.isEmpty()) {
205+
return 0;
206+
}
207+
208+
byte[] strByte = item.getBytes(StandardCharsets.UTF_8);
209+
return getLowerBound(strByte);
210+
}
211+
212+
public long getLowerBound(final byte[] item) {
213+
return getEstimate(item);
214+
}
215+
216+
public void merge(final CountMinSketch other) {
217+
if (this == other) {
218+
throw new SketchesException("Cannot merge a sketch with itself");
219+
}
220+
221+
boolean acceptableConfig = getNumBuckets_() == other.getNumBuckets_() &&
222+
getNumHashes_() == other.getNumHashes_() && getSeed_() == other.getSeed_();
223+
224+
if (!acceptableConfig) {
225+
throw new SketchesException("Incompatible sketch configuration.");
226+
}
227+
228+
for (int i = 0; i < sketchArray_.length; i++) {
229+
sketchArray_[i] += other.sketchArray_[i];
230+
}
231+
232+
totalWeight_ += other.getTotalWeight_();
233+
}
234+
235+
public void serialize(ByteBuffer buf) {
236+
// Long 0
237+
final int preambleLongs = Family.COUNTMIN.getMinPreLongs();
238+
buf.put((byte) preambleLongs);
239+
final int serialVersion = 1;
240+
buf.put((byte) serialVersion);
241+
final int familyId = Family.COUNTMIN.getID();
242+
buf.put((byte) familyId);
243+
final int flagsByte = isEmpty() ? 1 << IS_EMPTY : 0;
244+
buf.put((byte)flagsByte);
245+
final int NULL_32 = 0;
246+
buf.putInt(NULL_32);
247+
248+
// Long 1
249+
buf.putInt(numBuckets_);
250+
buf.putShort(numHashes_);
251+
buf.putShort(Util.computeSeedHash(seed_));
252+
final byte NULL_8 = 0;
253+
buf.put(NULL_8);
254+
if (isEmpty()) {
255+
return;
256+
}
257+
258+
buf.putLong(totalWeight_);
259+
260+
for (long estimate: sketchArray_) {
261+
buf.putLong(estimate);
262+
}
263+
}
264+
265+
public static CountMinSketch deserialize(final byte[] b, final long seed) {
266+
ByteBuffer buf = ByteBuffer.allocate(b.length);
267+
buf.put(b);
268+
269+
final byte preambleLongs = buf.get();
270+
final byte serialVersion = buf.get();
271+
final byte familyId = buf.get();
272+
final byte flagsByte = buf.get();
273+
final int NULL_32 = buf.getInt();
274+
275+
final int numBuckets = buf.getInt();
276+
final byte numHashes = buf.get();
277+
final short seedHash = buf.getShort();
278+
final byte NULL_8 = buf.get();
279+
280+
if (seedHash != Util.computeSeedHash(seed)) {
281+
throw new SketchesException("Incompatible seed hashes: " + String.valueOf(seedHash) + ", "
282+
+ String.valueOf(Util.computeSeedHash(seed)));
283+
}
284+
285+
CountMinSketch cms = new CountMinSketch(numHashes, numBuckets, seed);
286+
final boolean empty = (flagsByte & (1 << IS_EMPTY)) > 0;
287+
if (empty) {
288+
return cms;
289+
}
290+
291+
int i = 0;
292+
while (buf.hasRemaining()) {
293+
cms.sketchArray_[i] = buf.getLong();
294+
}
295+
296+
return cms;
297+
}
5298
}

0 commit comments

Comments
 (0)