Skip to content

Commit aec9a7e

Browse files
committed
Add tests
1 parent b3d9f3f commit aec9a7e

2 files changed

Lines changed: 277 additions & 22 deletions

File tree

src/main/java/org/apache/datasketches/count/CountMinSketch.java

Lines changed: 31 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -20,10 +20,12 @@
2020
package org.apache.datasketches.count;
2121

2222
import org.apache.datasketches.common.Family;
23+
import org.apache.datasketches.common.SketchesArgumentException;
2324
import org.apache.datasketches.common.SketchesException;
2425
import org.apache.datasketches.hash.MurmurHash3;
2526
import org.apache.datasketches.tuple.Util;
2627

28+
import java.io.ByteArrayOutputStream;
2729
import java.nio.ByteBuffer;
2830
import java.nio.charset.StandardCharsets;
2931
import java.util.Random;
@@ -63,17 +65,17 @@ int mask() {
6365
totalWeight_ = 0;
6466

6567
if (numBuckets < 3) {
66-
throw new SketchesException("Using fewer than 3 buckets incurs relative error greater than 1.");
68+
throw new SketchesArgumentException("Using fewer than 3 buckets incurs relative error greater than 1.");
6769
}
6870

6971
// This check is to ensure later compatibility with a Java implementation whose maximum size can only
7072
// be 2^31-1. We check only against 2^30 for simplicity.
7173
if (numBuckets * numHashes >= 1 << 30) {
72-
throw new SketchesException("These parameters generate a sketch that exceeds 2^30 elements. \n" +
74+
throw new SketchesArgumentException("These parameters generate a sketch that exceeds 2^30 elements. \n" +
7375
"Try reducing either the number of buckets or the number of hash functions.");
7476
}
7577

76-
Random rand = new Random();
78+
Random rand = new Random(seed);
7779
for (int i = 0; i < numHashes; i++) {
7880
hashSeeds_[i] = rand.nextLong();
7981
}
@@ -84,7 +86,7 @@ private long[] getHashes(byte[] item) {
8486

8587
for (int i = 0; i < numHashes_; i++) {
8688
long[] index = MurmurHash3.hash(item, hashSeeds_[i]);
87-
updateLocations[i] = i * (long)numBuckets_ + index[0] % numBuckets_;
89+
updateLocations[i] = i * (long)numBuckets_ + Math.floorMod(index[0], numBuckets_);
8890
}
8991

9092
return updateLocations;
@@ -143,7 +145,7 @@ public double getRelativeError() {
143145
* @param confidence The desired confidence level between 0 and 1.
144146
* @return Suggested number of hash functions.
145147
*/
146-
public byte suggestNumHashes(double confidence) {
148+
public static byte suggestNumHashes(double confidence) {
147149
if (confidence < 0 || confidence > 1) {
148150
throw new SketchesException("Confidence must be between 0 and 1.0 (inclusive).");
149151
}
@@ -156,7 +158,10 @@ public byte suggestNumHashes(double confidence) {
156158
* @param relativeError The desired relative error.
157159
* @return Suggested number of buckets.
158160
*/
159-
public int suggestNumBuckets(double relativeError) {
161+
public static int suggestNumBuckets(double relativeError) {
162+
if (relativeError < 0.) {
163+
throw new SketchesException("Relative error must be at least 0.");
164+
}
160165
return (int) Math.ceil(Math.exp(1.0) / relativeError);
161166
}
162167

@@ -340,33 +345,35 @@ public void merge(final CountMinSketch other) {
340345
* Serializes the sketch into the provided ByteBuffer.
341346
* @param buf The ByteBuffer to write into.
342347
*/
343-
public void serialize(ByteBuffer buf) {
348+
public void serialize(ByteArrayOutputStream buf) {
344349
// Long 0
345350
final int preambleLongs = Family.COUNTMIN.getMinPreLongs();
346-
buf.put((byte) preambleLongs);
351+
buf.write((byte) preambleLongs);
347352
final int serialVersion = 1;
348-
buf.put((byte) serialVersion);
353+
buf.write((byte) serialVersion);
349354
final int familyId = Family.COUNTMIN.getID();
350-
buf.put((byte) familyId);
355+
buf.write((byte) familyId);
351356
final int flagsByte = isEmpty() ? Flag.IS_EMPTY.mask() : 0;
352-
buf.put((byte)flagsByte);
357+
buf.write((byte)flagsByte);
353358
final int NULL_32 = 0;
354-
buf.putInt(NULL_32);
359+
buf.writeBytes(ByteBuffer.allocate(4).putInt(NULL_32).array());
355360

356361
// Long 1
357-
buf.putInt(numBuckets_);
358-
buf.putShort(numHashes_);
359-
buf.putShort(Util.computeSeedHash(seed_));
362+
buf.writeBytes(ByteBuffer.allocate(4).putInt(numBuckets_).array());
363+
buf.write(numHashes_);
364+
short hashSeed = Util.computeSeedHash(seed_);
365+
buf.writeBytes(ByteBuffer.allocate(2).putShort(hashSeed).array());
360366
final byte NULL_8 = 0;
361-
buf.put(NULL_8);
367+
buf.write(NULL_8);
362368
if (isEmpty()) {
363369
return;
364370
}
365371

366-
buf.putLong(totalWeight_);
372+
final byte[] totWeightByte = ByteBuffer.allocate(8).putLong(totalWeight_).array();
373+
buf.writeBytes(totWeightByte);
367374

368-
for (long estimate: sketchArray_) {
369-
buf.putLong(estimate);
375+
for (long w: sketchArray_) {
376+
buf.writeBytes(ByteBuffer.allocate(8).putLong(w).array());
370377
}
371378
}
372379

@@ -379,6 +386,7 @@ public void serialize(ByteBuffer buf) {
379386
public static CountMinSketch deserialize(final byte[] b, final long seed) {
380387
ByteBuffer buf = ByteBuffer.allocate(b.length);
381388
buf.put(b);
389+
buf.flip();
382390

383391
final byte preambleLongs = buf.get();
384392
final byte serialVersion = buf.get();
@@ -392,7 +400,7 @@ public static CountMinSketch deserialize(final byte[] b, final long seed) {
392400
final byte NULL_8 = buf.get();
393401

394402
if (seedHash != Util.computeSeedHash(seed)) {
395-
throw new SketchesException("Incompatible seed hashes: " + String.valueOf(seedHash) + ", "
403+
throw new SketchesArgumentException("Incompatible seed hashes: " + String.valueOf(seedHash) + ", "
396404
+ String.valueOf(Util.computeSeedHash(seed)));
397405
}
398406

@@ -401,9 +409,10 @@ public static CountMinSketch deserialize(final byte[] b, final long seed) {
401409
if (empty) {
402410
return cms;
403411
}
412+
long w = buf.getLong();
413+
cms.totalWeight_ = w;
404414

405-
int i = 0;
406-
while (buf.hasRemaining()) {
415+
for (int i = 0; i < cms.sketchArray_.length; i++) {
407416
cms.sketchArray_[i] = buf.getLong();
408417
}
409418

Lines changed: 246 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,246 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.datasketches.count;
21+
22+
import org.apache.datasketches.common.SketchesArgumentException;
23+
import org.apache.datasketches.common.SketchesException;
24+
import org.testng.annotations.Test;
25+
26+
import java.io.ByteArrayOutputStream;
27+
import java.lang.annotation.Repeatable;
28+
import java.nio.ByteBuffer;
29+
30+
import static org.testng.Assert.*;
31+
32+
public class CountMinSketchTest {
33+
@Test
34+
public void createNewCountMinSketchTest() throws Exception {
35+
assertThrows(SketchesArgumentException.class, () -> new CountMinSketch((byte) 5, 1, 123));
36+
assertThrows(SketchesArgumentException.class, () -> new CountMinSketch((byte) 4, 268435456, 123));
37+
38+
final byte numHashes = 3;
39+
final int numBuckets = 5;
40+
final long seed = 1234567;
41+
CountMinSketch c = new CountMinSketch(numHashes, numBuckets, seed);
42+
43+
assertEquals(c.getNumHashes_(), numHashes);
44+
assertEquals(c.getNumBuckets_(), numBuckets);
45+
assertEquals(c.getSeed_(), seed);
46+
assertTrue(c.isEmpty());
47+
}
48+
49+
@Test
50+
public void parameterSuggestionsTest() {
51+
// Bucket suggestions
52+
assertThrows("Relative error must be at least 0.", SketchesException.class, () -> CountMinSketch.suggestNumBuckets(-1.0));
53+
assertEquals(CountMinSketch.suggestNumBuckets(0.2), 14);
54+
assertEquals(CountMinSketch.suggestNumBuckets(0.1), 28);
55+
assertEquals(CountMinSketch.suggestNumBuckets(0.05), 55);
56+
assertEquals(CountMinSketch.suggestNumBuckets(0.01), 272);
57+
58+
// Check that the sketch get_epsilon acts inversely to suggest_num_buckets
59+
final byte numHashes = 3;
60+
final long seed = 1234567;
61+
assertTrue(new CountMinSketch(numHashes, 14, seed).getRelativeError() <= 0.2);
62+
assertTrue(new CountMinSketch(numHashes, 28, seed).getRelativeError() <= 0.1);
63+
assertTrue(new CountMinSketch(numHashes, 55, seed).getRelativeError() <= 0.05);
64+
assertTrue(new CountMinSketch(numHashes, 272, seed).getRelativeError() <= 0.01);
65+
66+
// Hash suggestions
67+
assertThrows("Confidence must be between 0 and 1.0 (inclusive).", SketchesException.class, () -> CountMinSketch.suggestNumHashes(10));
68+
assertThrows("Confidence must be between 0 and 1.0 (inclusive).", SketchesException.class, () -> CountMinSketch.suggestNumHashes(-1.0));
69+
assertEquals(CountMinSketch.suggestNumHashes(0.682689492), 2);
70+
assertEquals(CountMinSketch.suggestNumHashes(0.954499736), 4);
71+
assertEquals(CountMinSketch.suggestNumHashes(0.997300204), 6);
72+
}
73+
74+
@Test
75+
public void countMinSketchOneUpdateTest() {
76+
final byte numHashes = 3;
77+
final int numBuckets = 5;
78+
final long seed = 1234567;
79+
long insertedWeights = 0;
80+
CountMinSketch c = new CountMinSketch(numHashes, numBuckets, seed);
81+
final String x = "x";
82+
83+
assertTrue(c.isEmpty());
84+
assertEquals(c.getEstimate(x), 0);
85+
c.update(x, 1);
86+
assertFalse(c.isEmpty());
87+
assertEquals(c.getEstimate(x), 1);
88+
insertedWeights++;
89+
90+
final long w = 9;
91+
insertedWeights += w;
92+
c.update(x, w);
93+
assertEquals(c.getEstimate(x), insertedWeights);
94+
95+
final double w1 = 10.0;
96+
insertedWeights += (long) w1;
97+
c.update(x, (long) w1);
98+
assertEquals(c.getEstimate(x), insertedWeights);
99+
assertEquals(c.getTotalWeight_(), insertedWeights);
100+
assertTrue(c.getEstimate(x) <= c.getUpperBound(x));
101+
assertTrue(c.getEstimate(x) >= c.getLowerBound(x));
102+
}
103+
104+
@Test
105+
public void frequencyCancellationTest() {
106+
CountMinSketch c = new CountMinSketch((byte) 1, 5, 123456);
107+
c.update("x", 1);
108+
c.update("y", -1);
109+
assertEquals(c.getTotalWeight_(), 2);
110+
assertEquals(c.getEstimate("x"), 1);
111+
assertEquals(c.getEstimate("y"), -1);
112+
}
113+
114+
@Test
115+
public void frequencyEstimates() {
116+
final int numItems = 10;
117+
long[] data = new long[numItems];
118+
long[] frequencies = new long[numItems];
119+
120+
for (int i = 0; i < numItems; i++) {
121+
data[i] = i;
122+
frequencies[i] = (long) 1 << (numItems - i);
123+
}
124+
125+
final double relativeError = 0.1;
126+
final double confidence = 0.99;
127+
final int numBuckets = CountMinSketch.suggestNumBuckets(relativeError);
128+
final byte numHashes = CountMinSketch.suggestNumHashes(confidence);
129+
130+
CountMinSketch c = new CountMinSketch(numHashes, numBuckets, 1234567);
131+
for (int i = 0; i < numItems; i++) {
132+
final long value = data[i];
133+
final long freq = frequencies[i];
134+
c.update(value, freq);
135+
}
136+
137+
for (final long i : data) {
138+
final long est = c.getEstimate(i);
139+
final long upp = c.getUpperBound(i);
140+
final long low = c.getLowerBound(i);
141+
assertTrue(est <= upp);
142+
assertTrue(est >= low);
143+
}
144+
}
145+
146+
@Test
147+
public void mergeFailTest() {
148+
final double relativeError = 0.25;
149+
final double confidence = 0.9;
150+
final long seed = 1234567;
151+
final int numBuckets = CountMinSketch.suggestNumBuckets(relativeError);
152+
final byte numHashes = CountMinSketch.suggestNumHashes(confidence);
153+
CountMinSketch s = new CountMinSketch(numHashes, numBuckets, seed);
154+
155+
assertThrows("Cannot merge a sketch with itself.", SketchesException.class, () -> s.merge(s));
156+
157+
CountMinSketch s1 = new CountMinSketch((byte) (numHashes + 1), numBuckets, seed);
158+
CountMinSketch s2 = new CountMinSketch(numHashes, numBuckets + 1, seed);
159+
CountMinSketch s3 = new CountMinSketch(numHashes, numBuckets, seed + 1);
160+
161+
CountMinSketch[] sketches = {s1, s2, s3};
162+
for (final CountMinSketch sk : sketches) {
163+
assertThrows("Incompatible sketch configuration.", SketchesException.class, () -> s.merge(sk));
164+
}
165+
}
166+
167+
@Test
168+
public void mergeTest() {
169+
final double relativeError = 0.25;
170+
final double confidence = 0.9;
171+
final long seed = 123456;
172+
final int numBuckets = CountMinSketch.suggestNumBuckets(relativeError);
173+
final byte numHashes = CountMinSketch.suggestNumHashes(confidence);
174+
CountMinSketch c = new CountMinSketch(numHashes, numBuckets, seed);
175+
176+
final byte sHashes = c.getNumHashes_();
177+
final int sBuckets = c.getNumBuckets_();
178+
final long sSeed = c.getSeed_();
179+
CountMinSketch s = new CountMinSketch(sHashes, sBuckets, sSeed);
180+
181+
c.merge(s);
182+
assertEquals(c.getTotalWeight_(), 0);
183+
184+
final long[] data = {2, 3, 5, 7};
185+
for (final long d : data) {
186+
c.update(d, 1);
187+
s.update(d, 1);
188+
}
189+
c.merge(s);
190+
191+
assertEquals(c.getTotalWeight_(), 2 * s.getTotalWeight_());
192+
193+
for (final long d : data) {
194+
assertTrue(c.getEstimate(d) <= c.getUpperBound(d));
195+
assertTrue(s.getEstimate(d) <= 2);
196+
}
197+
}
198+
199+
@Test
200+
public void serializeDeserializeEmptyTest() {
201+
final byte numHashes = 3;
202+
final int numBuckets = 32;
203+
final long seed = 123456;
204+
CountMinSketch c = new CountMinSketch(numHashes, numBuckets, seed);
205+
206+
ByteArrayOutputStream buf = new ByteArrayOutputStream();
207+
c.serialize(buf);
208+
209+
byte[] b = buf.toByteArray();
210+
assertThrows(SketchesArgumentException.class, () -> CountMinSketch.deserialize(b, seed - 1));
211+
212+
CountMinSketch d = CountMinSketch.deserialize(b, seed);
213+
assertEquals(d.getNumHashes_(), c.getNumHashes_());
214+
assertEquals(d.getNumBuckets_(), c.getNumBuckets_());
215+
assertEquals(d.getSeed_(), c.getSeed_());
216+
final long zero = 0;
217+
assertEquals(d.getEstimate(zero), c.getEstimate(zero));
218+
assertEquals(d.getTotalWeight_(), c.getTotalWeight_());
219+
}
220+
221+
@Test
222+
public void serializeDeserializeTest() {
223+
final byte numHashes = 5;
224+
final int numBuckets = 64;
225+
final long seed = 1234456;
226+
CountMinSketch c = new CountMinSketch(numHashes, numBuckets, seed);
227+
for (long i = 0; i < 10; i++) {
228+
c.update(i, 10*i*i);
229+
}
230+
231+
ByteArrayOutputStream buf = new ByteArrayOutputStream();
232+
c.serialize(buf);
233+
234+
assertThrows(SketchesArgumentException.class, () -> CountMinSketch.deserialize(buf.toByteArray(), seed - 1));
235+
CountMinSketch d = CountMinSketch.deserialize(buf.toByteArray(), seed);
236+
237+
assertEquals(d.getNumHashes_(), c.getNumHashes_());
238+
assertEquals(d.getNumBuckets_(), c.getNumBuckets_());
239+
assertEquals(d.getSeed_(), c.getSeed_());
240+
assertEquals(d.getTotalWeight_(), c.getTotalWeight_());
241+
242+
for (long i = 0; i < 10; i++) {
243+
assertEquals(d.getEstimate(i), c.getEstimate(i));
244+
}
245+
}
246+
}

0 commit comments

Comments
 (0)