1919
2020package org .apache .datasketches .count ;
2121
22- import java .io .ByteArrayOutputStream ;
23- import java .nio .ByteBuffer ;
22+ import static java .lang .foreign .ValueLayout .JAVA_BYTE ;
23+ import static java .lang .foreign .ValueLayout .JAVA_LONG_UNALIGNED ;
24+
25+ import java .lang .foreign .MemorySegment ;
2426import java .nio .charset .StandardCharsets ;
2527import java .util .Random ;
2628
2729import org .apache .datasketches .common .Family ;
2830import org .apache .datasketches .common .SketchesArgumentException ;
2931import org .apache .datasketches .common .SketchesException ;
3032import org .apache .datasketches .common .Util ;
33+ import org .apache .datasketches .common .positional .PositionalSegment ;
3134import org .apache .datasketches .hash .MurmurHash3 ;
3235
3336/**
34- * CountMinSketch.
37+ * Java implementation of the CountMin sketch data structure of Cormode and Muthukrishnan.
38+ * This implementation is inspired by and compatible with the datasketches-cpp version by Charlie Dickens.
39+ *
40+ * The CountMin sketch is a probabilistic data structure that provides frequency estimates for items
41+ * in a data stream. It uses multiple hash functions to distribute items across a two-dimensional array,
42+ * providing approximate counts with configurable error bounds.
43+ *
44+ * Reference: http://dimacs.rutgers.edu/~graham/pubs/papers/cm-full.pdf
3545 */
3646public class CountMinSketch {
3747 private final byte numHashes_ ;
@@ -41,6 +51,10 @@ public class CountMinSketch {
4151 private final long [] sketchArray_ ;
4252 private long totalWeight_ ;
4353
54+ // Per-thread 8-byte heap MemorySegment used as scratch space by longToBytes(), which writes a long into it and copies the bytes out.
55+ private static final ThreadLocal <MemorySegment > LONG_SEGMENT =
56+ ThreadLocal .withInitial (() -> MemorySegment .ofArray (new byte [Long .BYTES ]));
57+
4458 private enum Flag {
4559 IS_EMPTY ;
4660
@@ -58,30 +72,58 @@ int mask() {
5872 * @param seed The base hash seed
5973 */
6074 CountMinSketch (final byte numHashes , final int numBuckets , final long seed ) {
61- numHashes_ = numHashes ;
62- numBuckets_ = numBuckets ;
63- seed_ = seed ;
64- hashSeeds_ = new long [numHashes ];
65- sketchArray_ = new long [numHashes * numBuckets ];
66- totalWeight_ = 0 ;
75+ // Validate numHashes
76+ if (numHashes <= 0 ) {
77+ throw new SketchesArgumentException ("Number of hash functions must be positive, got: " + numHashes );
78+ }
6779
80+ // Validate numBuckets with clear mathematical justification
81+ if (numBuckets <= 0 ) {
82+ throw new SketchesArgumentException ("Number of buckets must be positive, got: " + numBuckets );
83+ }
6884 if (numBuckets < 3 ) {
69- throw new SketchesArgumentException ("Using fewer than 3 buckets incurs relative error greater than 1." );
85+ throw new SketchesArgumentException ("Number of buckets must be at least 3 to ensure relative error ≤ 1.0. "
86+ + "With " + numBuckets + " buckets, relative error would be " + String .format ("%.3f" , Math .exp (1.0 ) / numBuckets ));
87+ }
88+
89+ // Check for potential overflow in array size calculation
90+ // Use long arithmetic to detect overflow before casting
91+ final long totalSize = (long ) numHashes * (long ) numBuckets ;
92+ if (totalSize > Integer .MAX_VALUE ) {
93+ throw new SketchesArgumentException ("Sketch array size would overflow: " + numHashes + " * " + numBuckets
94+ + " = " + totalSize + " > " + Integer .MAX_VALUE );
7095 }
7196
7297 // This check is to ensure later compatibility with a Java implementation whose maximum size can only
7398 // be 2^31-1. We check only against 2^30 for simplicity.
74- if ((numBuckets * numHashes ) >= (1 << 30 )) {
75- throw new SketchesArgumentException ("These parameters generate a sketch that exceeds 2^30 elements. \b "
76- + "Try reducing either the number of buckets or the number of hash functions." );
99+ if (totalSize >= (1L << 30 )) {
100+ throw new SketchesArgumentException ("Sketch would require excessive memory: " + numHashes + " * " + numBuckets
101+ + " = " + totalSize + " elements (~" + String .format ("%d" , (totalSize * Long .BYTES ) / (1024 * 1024 * 1024 )) + " GB). "
102+ + "Consider reducing numHashes or numBuckets." );
77103 }
78104
105+ numHashes_ = numHashes ;
106+ numBuckets_ = numBuckets ;
107+ seed_ = seed ;
108+ hashSeeds_ = new long [numHashes ];
109+ sketchArray_ = new long [(int ) totalSize ];
110+ totalWeight_ = 0 ;
111+
79112 final Random rand = new Random (seed );
80113 for (int i = 0 ; i < numHashes ; i ++) {
81114 hashSeeds_ [i ] = rand .nextLong ();
82115 }
83116 }
84117
118+ /**
119+ * Efficiently converts a long to byte array using thread-local MemorySegment with explicit endianness.
120+ */
121+ private static byte [] longToBytes (final long value ) {
122+ final MemorySegment segment = LONG_SEGMENT .get ();
123+ segment .set (JAVA_LONG_UNALIGNED , 0 , value );
124+ return segment .toArray (JAVA_BYTE );
125+ }
126+
85127 private long [] getHashes (final byte [] item ) {
86128 final long [] updateLocations = new long [numHashes_ ];
87129
@@ -172,8 +214,7 @@ public static int suggestNumBuckets(final double relativeError) {
172214 * @param weight The weight of the item.
173215 */
174216 public void update (final long item , final long weight ) {
175- final byte [] longByte = ByteBuffer .allocate (8 ).putLong (item ).array ();
176- update (longByte , weight );
217+ update (longToBytes (item ), weight );
177218 }
178219
179220 /**
@@ -212,8 +253,7 @@ public void update(final byte[] item, final long weight) {
212253 * @return Estimated frequency.
213254 */
214255 public long getEstimate (final long item ) {
215- final byte [] longByte = ByteBuffer .allocate (8 ).putLong (item ).array ();
216- return getEstimate (longByte );
256+ return getEstimate (longToBytes (item ));
217257 }
218258
219259 /**
@@ -242,8 +282,9 @@ public long getEstimate(final byte[] item) {
242282
243283 final long [] hashLocations = getHashes (item );
244284 long res = sketchArray_ [(int ) hashLocations [0 ]];
245- for (final long h : hashLocations ) {
246- res = Math .min (res , sketchArray_ [(int ) h ]);
285+ // Start from index 1 to avoid processing first element twice
286+ for (int i = 1 ; i < hashLocations .length ; i ++) {
287+ res = Math .min (res , sketchArray_ [(int ) hashLocations [i ]]);
247288 }
248289
249290 return res ;
@@ -255,8 +296,7 @@ public long getEstimate(final byte[] item) {
255296 * @return Upper bound of estimated frequency.
256297 */
257298 public long getUpperBound (final long item ) {
258- final byte [] longByte = ByteBuffer .allocate (8 ).putLong (item ).array ();
259- return getUpperBound (longByte );
299+ return getUpperBound (longToBytes (item ));
260300 }
261301
262302 /**
@@ -270,7 +310,7 @@ public long getUpperBound(final String item) {
270310 }
271311
272312 final byte [] strByte = item .getBytes (StandardCharsets .UTF_8 );
273- return getUpperBound (strByte );
313+ return getUpperBound (strByte );
274314 }
275315
276316 /**
@@ -292,8 +332,7 @@ public long getUpperBound(final byte[] item) {
292332 * @return Lower bound of estimated frequency.
293333 */
294334 public long getLowerBound (final long item ) {
295- final byte [] longByte = ByteBuffer .allocate (8 ).putLong (item ).array ();
296- return getLowerBound (longByte );
335+ return getLowerBound (longToBytes (item ));
297336 }
298337
299338 /**
@@ -343,39 +382,56 @@ public void merge(final CountMinSketch other) {
343382 }
344383
345384 /**
346- * Serializes the sketch into the provided ByteBuffer.
347- * @param buf The ByteBuffer to write into.
385+ * Returns the serialized size in bytes.
348386 */
349- public void serialize (final ByteArrayOutputStream buf ) {
387+ private int getSerializedSizeBytes () {
388+ final int preambleBytes = Family .COUNTMIN .getMinPreLongs () * Long .BYTES ;
389+ if (isEmpty ()) {
390+ return preambleBytes ;
391+ }
392+ return preambleBytes + Long .BYTES + (sketchArray_ .length * Long .BYTES );
393+ }
394+
395+ /**
396+ * Returns the sketch as a byte array.
397+ * @return the result byte array
398+ */
399+ public byte [] toByteArray () {
400+ final int serializedSizeBytes = getSerializedSizeBytes ();
401+ final byte [] bytes = new byte [serializedSizeBytes ];
402+ final PositionalSegment posSeg = PositionalSegment .wrap (MemorySegment .ofArray (bytes ));
403+
350404 // Long 0
351405 final int preambleLongs = Family .COUNTMIN .getMinPreLongs ();
352- buf . write ((byte ) preambleLongs );
406+ posSeg . setByte ((byte ) preambleLongs );
353407 final int serialVersion = 1 ;
354- buf . write ((byte ) serialVersion );
408+ posSeg . setByte ((byte ) serialVersion );
355409 final int familyId = Family .COUNTMIN .getID ();
356- buf . write ((byte ) familyId );
410+ posSeg . setByte ((byte ) familyId );
357411 final int flagsByte = isEmpty () ? Flag .IS_EMPTY .mask () : 0 ;
358- buf . write ((byte )flagsByte );
412+ posSeg . setByte ((byte ) flagsByte );
359413 final int NULL_32 = 0 ;
360- buf . writeBytes ( ByteBuffer . allocate ( 4 ). putInt ( NULL_32 ). array () );
414+ posSeg . setInt ( NULL_32 );
361415
362416 // Long 1
363- buf . writeBytes ( ByteBuffer . allocate ( 4 ). putInt ( numBuckets_ ). array () );
364- buf . write (numHashes_ );
417+ posSeg . setInt ( numBuckets_ );
418+ posSeg . setByte (numHashes_ );
365419 final short hashSeed = Util .computeSeedHash (seed_ );
366- buf . writeBytes ( ByteBuffer . allocate ( 2 ). putShort ( hashSeed ). array () );
420+ posSeg . setShort ( hashSeed );
367421 final byte NULL_8 = 0 ;
368- buf .write (NULL_8 );
422+ posSeg .setByte (NULL_8 );
423+
369424 if (isEmpty ()) {
370- return ;
425+ return bytes ;
371426 }
372427
373- final byte [] totWeightByte = ByteBuffer .allocate (8 ).putLong (totalWeight_ ).array ();
374- buf .writeBytes (totWeightByte );
428+ posSeg .setLong (totalWeight_ );
375429
376430 for (final long w : sketchArray_ ) {
377- buf . writeBytes ( ByteBuffer . allocate ( 8 ). putLong ( w ). array () );
431+ posSeg . setLong ( w );
378432 }
433+
434+ return bytes ;
379435 }
380436
381437 /**
@@ -384,38 +440,52 @@ public void serialize(final ByteArrayOutputStream buf) {
384440 * @param seed The seed used during serialization.
385441 * @return The deserialized CountMinSketch.
386442 */
387- @ SuppressWarnings ("unused" )
388443 public static CountMinSketch deserialize (final byte [] b , final long seed ) {
389- final ByteBuffer buf = ByteBuffer .allocate (b .length );
390- buf .put (b );
391- buf .flip ();
392-
393- final byte preambleLongs = buf .get ();
394- final byte serialVersion = buf .get ();
395- final byte familyId = buf .get ();
396- final byte flagsByte = buf .get ();
397- final int NULL_32 = buf .getInt ();
444+ final PositionalSegment posSeg = PositionalSegment .wrap (MemorySegment .ofArray (b ));
445+
446+ final byte preambleLongs = posSeg .getByte ();
447+ final byte serialVersion = posSeg .getByte ();
448+ final byte familyId = posSeg .getByte ();
449+ final byte flagsByte = posSeg .getByte ();
450+ posSeg .getInt (); // skip NULL_32
451+
452+ // Validate serialization format
453+ final int expectedPreambleLongs = Family .COUNTMIN .getMinPreLongs ();
454+ if (preambleLongs != expectedPreambleLongs ) {
455+ throw new SketchesArgumentException ("Preamble longs mismatch: expected " + expectedPreambleLongs
456+ + ", actual " + preambleLongs );
457+ }
458+ final int expectedSerialVersion = 1 ;
459+ if (serialVersion != expectedSerialVersion ) {
460+ throw new SketchesArgumentException ("Serial version mismatch: expected " + expectedSerialVersion
461+ + ", actual " + serialVersion );
462+ }
463+ final int expectedFamilyId = Family .COUNTMIN .getID ();
464+ if (familyId != expectedFamilyId ) {
465+ throw new SketchesArgumentException ("Family ID mismatch: expected " + expectedFamilyId
466+ + ", actual " + familyId );
467+ }
398468
399- final int numBuckets = buf .getInt ();
400- final byte numHashes = buf . get ();
401- final short seedHash = buf .getShort ();
402- final byte NULL_8 = buf . get ();
469+ final int numBuckets = posSeg .getInt ();
470+ final byte numHashes = posSeg . getByte ();
471+ final short seedHash = posSeg .getShort ();
472+ posSeg . getByte (); // skip NULL_8
403473
404474 if (seedHash != Util .computeSeedHash (seed )) {
405- throw new SketchesArgumentException ("Incompatible seed hashes: " + String . valueOf ( seedHash ) + ", "
406- + String . valueOf ( Util .computeSeedHash (seed ) ));
475+ throw new SketchesArgumentException ("Incompatible seed hashes: " + seedHash + ", "
476+ + Util .computeSeedHash (seed ));
407477 }
408478
409479 final CountMinSketch cms = new CountMinSketch (numHashes , numBuckets , seed );
410480 final boolean empty = (flagsByte & Flag .IS_EMPTY .mask ()) > 0 ;
411481 if (empty ) {
412482 return cms ;
413483 }
414- final long w = buf .getLong ();
484+ final long w = posSeg .getLong ();
415485 cms .totalWeight_ = w ;
416486
417487 for (int i = 0 ; i < cms .sketchArray_ .length ; i ++) {
418- cms .sketchArray_ [i ] = buf .getLong ();
488+ cms .sketchArray_ [i ] = posSeg .getLong ();
419489 }
420490
421491 return cms ;
0 commit comments