Skip to content

Commit 9587ef3

Browse files
authored
Merge branch 'main' into ffm_phase7
2 parents 4bc22a2 + f62d552 commit 9587ef3

2 files changed

Lines changed: 132 additions & 66 deletions

File tree

src/main/java/org/apache/datasketches/count/CountMinSketch.java

Lines changed: 128 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -19,19 +19,29 @@
1919

2020
package org.apache.datasketches.count;
2121

22-
import java.io.ByteArrayOutputStream;
23-
import java.nio.ByteBuffer;
22+
import static java.lang.foreign.ValueLayout.JAVA_BYTE;
23+
import static java.lang.foreign.ValueLayout.JAVA_LONG_UNALIGNED;
24+
25+
import java.lang.foreign.MemorySegment;
2426
import java.nio.charset.StandardCharsets;
2527
import java.util.Random;
2628

2729
import org.apache.datasketches.common.Family;
2830
import org.apache.datasketches.common.SketchesArgumentException;
2931
import org.apache.datasketches.common.SketchesException;
3032
import org.apache.datasketches.common.Util;
33+
import org.apache.datasketches.common.positional.PositionalSegment;
3134
import org.apache.datasketches.hash.MurmurHash3;
3235

3336
/**
34-
* CountMinSketch.
37+
* Java implementation of the CountMin sketch data structure of Cormode and Muthukrishnan.
38+
* This implementation is inspired by and compatible with the datasketches-cpp version by Charlie Dickens.
39+
*
40+
* The CountMin sketch is a probabilistic data structure that provides frequency estimates for items
41+
* in a data stream. It uses multiple hash functions to distribute items across a two-dimensional array,
42+
* providing approximate counts with configurable error bounds.
43+
*
44+
* Reference: http://dimacs.rutgers.edu/~graham/pubs/papers/cm-full.pdf
3545
*/
3646
public class CountMinSketch {
3747
private final byte numHashes_;
@@ -41,6 +51,10 @@ public class CountMinSketch {
4151
private final long[] sketchArray_;
4252
private long totalWeight_;
4353

54+
// Thread-local scratch MemorySegment reused for long-to-byte conversion; JAVA_LONG_UNALIGNED writes in the platform's native byte order, and toArray still allocates the returned byte[]
55+
private static final ThreadLocal<MemorySegment> LONG_SEGMENT =
56+
ThreadLocal.withInitial(() -> MemorySegment.ofArray(new byte[Long.BYTES]));
57+
4458
private enum Flag {
4559
IS_EMPTY;
4660

@@ -58,30 +72,58 @@ int mask() {
5872
* @param seed The base hash seed
5973
*/
6074
CountMinSketch(final byte numHashes, final int numBuckets, final long seed) {
61-
numHashes_ = numHashes;
62-
numBuckets_ = numBuckets;
63-
seed_ = seed;
64-
hashSeeds_ = new long[numHashes];
65-
sketchArray_ = new long[numHashes * numBuckets];
66-
totalWeight_ = 0;
75+
// Validate numHashes
76+
if (numHashes <= 0) {
77+
throw new SketchesArgumentException("Number of hash functions must be positive, got: " + numHashes);
78+
}
6779

80+
// Validate numBuckets: the relative error is e / numBuckets, so at least 3 buckets are needed to keep it at or below 1.0
81+
if (numBuckets <= 0) {
82+
throw new SketchesArgumentException("Number of buckets must be positive, got: " + numBuckets);
83+
}
6884
if (numBuckets < 3) {
69-
throw new SketchesArgumentException("Using fewer than 3 buckets incurs relative error greater than 1.");
85+
throw new SketchesArgumentException("Number of buckets must be at least 3 to ensure relative error ≤ 1.0. "
86+
+ "With " + numBuckets + " buckets, relative error would be " + String.format("%.3f", Math.exp(1.0) / numBuckets));
87+
}
88+
89+
// Check for potential overflow in array size calculation
90+
// Use long arithmetic to detect overflow before casting
91+
final long totalSize = (long) numHashes * (long) numBuckets;
92+
if (totalSize > Integer.MAX_VALUE) {
93+
throw new SketchesArgumentException("Sketch array size would overflow: " + numHashes + " * " + numBuckets
94+
+ " = " + totalSize + " > " + Integer.MAX_VALUE);
7095
}
7196

7297
// This check is to ensure later compatibility with a Java implementation whose maximum size can only
7398
// be 2^31-1. We check only against 2^30 for simplicity.
74-
if ((numBuckets * numHashes) >= (1 << 30)) {
75-
throw new SketchesArgumentException("These parameters generate a sketch that exceeds 2^30 elements. \b"
76-
+ "Try reducing either the number of buckets or the number of hash functions.");
99+
if (totalSize >= (1L << 30)) {
100+
throw new SketchesArgumentException("Sketch would require excessive memory: " + numHashes + " * " + numBuckets
101+
+ " = " + totalSize + " elements (~" + String.format("%d", (totalSize * Long.BYTES) / (1024 * 1024 * 1024)) + " GB). "
102+
+ "Consider reducing numHashes or numBuckets.");
77103
}
78104

105+
numHashes_ = numHashes;
106+
numBuckets_ = numBuckets;
107+
seed_ = seed;
108+
hashSeeds_ = new long[numHashes];
109+
sketchArray_ = new long[(int) totalSize];
110+
totalWeight_ = 0;
111+
79112
final Random rand = new Random(seed);
80113
for (int i = 0; i < numHashes; i++) {
81114
hashSeeds_[i] = rand.nextLong();
82115
}
83116
}
84117

118+
/**
119+
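* Converts a long to a new 8-byte array using the thread-local MemorySegment. The bytes are in the platform's native byte order (JAVA_LONG_UNALIGNED), unlike the big-endian ByteBuffer.putLong code this replaces.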
120+
*/
121+
private static byte[] longToBytes(final long value) {
122+
final MemorySegment segment = LONG_SEGMENT.get();
123+
segment.set(JAVA_LONG_UNALIGNED, 0, value);
124+
return segment.toArray(JAVA_BYTE);
125+
}
126+
85127
private long[] getHashes(final byte[] item) {
86128
final long[] updateLocations = new long[numHashes_];
87129

@@ -172,8 +214,7 @@ public static int suggestNumBuckets(final double relativeError) {
172214
* @param weight The weight of the item.
173215
*/
174216
public void update(final long item, final long weight) {
175-
final byte[] longByte = ByteBuffer.allocate(8).putLong(item).array();
176-
update(longByte, weight);
217+
update(longToBytes(item), weight);
177218
}
178219

179220
/**
@@ -212,8 +253,7 @@ public void update(final byte[] item, final long weight) {
212253
* @return Estimated frequency.
213254
*/
214255
public long getEstimate(final long item) {
215-
final byte[] longByte = ByteBuffer.allocate(8).putLong(item).array();
216-
return getEstimate(longByte);
256+
return getEstimate(longToBytes(item));
217257
}
218258

219259
/**
@@ -242,8 +282,9 @@ public long getEstimate(final byte[] item) {
242282

243283
final long[] hashLocations = getHashes(item);
244284
long res = sketchArray_[(int) hashLocations[0]];
245-
for (final long h : hashLocations) {
246-
res = Math.min(res, sketchArray_[(int) h]);
285+
// Start from index 1 to avoid processing first element twice
286+
for (int i = 1; i < hashLocations.length; i++) {
287+
res = Math.min(res, sketchArray_[(int) hashLocations[i]]);
247288
}
248289

249290
return res;
@@ -255,8 +296,7 @@ public long getEstimate(final byte[] item) {
255296
* @return Upper bound of estimated frequency.
256297
*/
257298
public long getUpperBound(final long item) {
258-
final byte[] longByte = ByteBuffer.allocate(8).putLong(item).array();
259-
return getUpperBound(longByte);
299+
return getUpperBound(longToBytes(item));
260300
}
261301

262302
/**
@@ -270,7 +310,7 @@ public long getUpperBound(final String item) {
270310
}
271311

272312
final byte[] strByte = item.getBytes(StandardCharsets.UTF_8);
273-
return getUpperBound(strByte);
313+
return getUpperBound(strByte);
274314
}
275315

276316
/**
@@ -292,8 +332,7 @@ public long getUpperBound(final byte[] item) {
292332
* @return Lower bound of estimated frequency.
293333
*/
294334
public long getLowerBound(final long item) {
295-
final byte[] longByte = ByteBuffer.allocate(8).putLong(item).array();
296-
return getLowerBound(longByte);
335+
return getLowerBound(longToBytes(item));
297336
}
298337

299338
/**
@@ -343,39 +382,56 @@ public void merge(final CountMinSketch other) {
343382
}
344383

345384
/**
346-
* Serializes the sketch into the provided ByteBuffer.
347-
* @param buf The ByteBuffer to write into.
385+
* Returns the serialized size in bytes.
348386
*/
349-
public void serialize(final ByteArrayOutputStream buf) {
387+
private int getSerializedSizeBytes() {
388+
final int preambleBytes = Family.COUNTMIN.getMinPreLongs() * Long.BYTES;
389+
if (isEmpty()) {
390+
return preambleBytes;
391+
}
392+
return preambleBytes + Long.BYTES + (sketchArray_.length * Long.BYTES);
393+
}
394+
395+
/**
396+
* Returns the sketch as a byte array.
397+
* @return the result byte array
398+
*/
399+
public byte[] toByteArray() {
400+
final int serializedSizeBytes = getSerializedSizeBytes();
401+
final byte[] bytes = new byte[serializedSizeBytes];
402+
final PositionalSegment posSeg = PositionalSegment.wrap(MemorySegment.ofArray(bytes));
403+
350404
// Long 0
351405
final int preambleLongs = Family.COUNTMIN.getMinPreLongs();
352-
buf.write((byte) preambleLongs);
406+
posSeg.setByte((byte) preambleLongs);
353407
final int serialVersion = 1;
354-
buf.write((byte) serialVersion);
408+
posSeg.setByte((byte) serialVersion);
355409
final int familyId = Family.COUNTMIN.getID();
356-
buf.write((byte) familyId);
410+
posSeg.setByte((byte) familyId);
357411
final int flagsByte = isEmpty() ? Flag.IS_EMPTY.mask() : 0;
358-
buf.write((byte)flagsByte);
412+
posSeg.setByte((byte) flagsByte);
359413
final int NULL_32 = 0;
360-
buf.writeBytes(ByteBuffer.allocate(4).putInt(NULL_32).array());
414+
posSeg.setInt(NULL_32);
361415

362416
// Long 1
363-
buf.writeBytes(ByteBuffer.allocate(4).putInt(numBuckets_).array());
364-
buf.write(numHashes_);
417+
posSeg.setInt(numBuckets_);
418+
posSeg.setByte(numHashes_);
365419
final short hashSeed = Util.computeSeedHash(seed_);
366-
buf.writeBytes(ByteBuffer.allocate(2).putShort(hashSeed).array());
420+
posSeg.setShort(hashSeed);
367421
final byte NULL_8 = 0;
368-
buf.write(NULL_8);
422+
posSeg.setByte(NULL_8);
423+
369424
if (isEmpty()) {
370-
return;
425+
return bytes;
371426
}
372427

373-
final byte[] totWeightByte = ByteBuffer.allocate(8).putLong(totalWeight_).array();
374-
buf.writeBytes(totWeightByte);
428+
posSeg.setLong(totalWeight_);
375429

376430
for (final long w: sketchArray_) {
377-
buf.writeBytes(ByteBuffer.allocate(8).putLong(w).array());
431+
posSeg.setLong(w);
378432
}
433+
434+
return bytes;
379435
}
380436

381437
/**
@@ -384,38 +440,52 @@ public void serialize(final ByteArrayOutputStream buf) {
384440
* @param seed The seed used during serialization.
385441
* @return The deserialized CountMinSketch.
386442
*/
387-
@SuppressWarnings("unused")
388443
public static CountMinSketch deserialize(final byte[] b, final long seed) {
389-
final ByteBuffer buf = ByteBuffer.allocate(b.length);
390-
buf.put(b);
391-
buf.flip();
392-
393-
final byte preambleLongs = buf.get();
394-
final byte serialVersion = buf.get();
395-
final byte familyId = buf.get();
396-
final byte flagsByte = buf.get();
397-
final int NULL_32 = buf.getInt();
444+
final PositionalSegment posSeg = PositionalSegment.wrap(MemorySegment.ofArray(b));
445+
446+
final byte preambleLongs = posSeg.getByte();
447+
final byte serialVersion = posSeg.getByte();
448+
final byte familyId = posSeg.getByte();
449+
final byte flagsByte = posSeg.getByte();
450+
posSeg.getInt(); // skip NULL_32
451+
452+
// Validate serialization format
453+
final int expectedPreambleLongs = Family.COUNTMIN.getMinPreLongs();
454+
if (preambleLongs != expectedPreambleLongs) {
455+
throw new SketchesArgumentException("Preamble longs mismatch: expected " + expectedPreambleLongs
456+
+ ", actual " + preambleLongs);
457+
}
458+
final int expectedSerialVersion = 1;
459+
if (serialVersion != expectedSerialVersion) {
460+
throw new SketchesArgumentException("Serial version mismatch: expected " + expectedSerialVersion
461+
+ ", actual " + serialVersion);
462+
}
463+
final int expectedFamilyId = Family.COUNTMIN.getID();
464+
if (familyId != expectedFamilyId) {
465+
throw new SketchesArgumentException("Family ID mismatch: expected " + expectedFamilyId
466+
+ ", actual " + familyId);
467+
}
398468

399-
final int numBuckets = buf.getInt();
400-
final byte numHashes = buf.get();
401-
final short seedHash = buf.getShort();
402-
final byte NULL_8 = buf.get();
469+
final int numBuckets = posSeg.getInt();
470+
final byte numHashes = posSeg.getByte();
471+
final short seedHash = posSeg.getShort();
472+
posSeg.getByte(); // skip NULL_8
403473

404474
if (seedHash != Util.computeSeedHash(seed)) {
405-
throw new SketchesArgumentException("Incompatible seed hashes: " + String.valueOf(seedHash) + ", "
406-
+ String.valueOf(Util.computeSeedHash(seed)));
475+
throw new SketchesArgumentException("Incompatible seed hashes: " + seedHash + ", "
476+
+ Util.computeSeedHash(seed));
407477
}
408478

409479
final CountMinSketch cms = new CountMinSketch(numHashes, numBuckets, seed);
410480
final boolean empty = (flagsByte & Flag.IS_EMPTY.mask()) > 0;
411481
if (empty) {
412482
return cms;
413483
}
414-
final long w = buf.getLong();
484+
final long w = posSeg.getLong();
415485
cms.totalWeight_ = w;
416486

417487
for (int i = 0; i < cms.sketchArray_.length; i++) {
418-
cms.sketchArray_[i] = buf.getLong();
488+
cms.sketchArray_[i] = posSeg.getLong();
419489
}
420490

421491
return cms;

src/test/java/org/apache/datasketches/count/CountMinSketchTest.java

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -204,10 +204,7 @@ public void serializeDeserializeEmptyTest() {
204204
final long seed = 123456;
205205
final CountMinSketch c = new CountMinSketch(numHashes, numBuckets, seed);
206206

207-
final ByteArrayOutputStream buf = new ByteArrayOutputStream();
208-
c.serialize(buf);
209-
210-
final byte[] b = buf.toByteArray();
207+
byte[] b = c.toByteArray();
211208
assertThrows(SketchesArgumentException.class, () -> CountMinSketch.deserialize(b, seed - 1));
212209

213210
final CountMinSketch d = CountMinSketch.deserialize(b, seed);
@@ -229,11 +226,10 @@ public void serializeDeserializeTest() {
229226
c.update(i, 10*i*i);
230227
}
231228

232-
final ByteArrayOutputStream buf = new ByteArrayOutputStream();
233-
c.serialize(buf);
229+
byte[] b = c.toByteArray();
234230

235-
assertThrows(SketchesArgumentException.class, () -> CountMinSketch.deserialize(buf.toByteArray(), seed - 1));
236-
final CountMinSketch d = CountMinSketch.deserialize(buf.toByteArray(), seed);
231+
assertThrows(SketchesArgumentException.class, () -> CountMinSketch.deserialize(b, seed - 1));
232+
CountMinSketch d = CountMinSketch.deserialize(b, seed);
237233

238234
assertEquals(d.getNumHashes_(), c.getNumHashes_());
239235
assertEquals(d.getNumBuckets_(), c.getNumBuckets_());

0 commit comments

Comments
 (0)