2323import org .apache .datasketches .common .SketchesArgumentException ;
2424import org .apache .datasketches .common .SketchesException ;
2525import org .apache .datasketches .common .Util ;
26+ import org .apache .datasketches .common .positional .PositionalSegment ;
2627import org .apache .datasketches .hash .MurmurHash3 ;
2728
2829import java .lang .foreign .MemorySegment ;
2930import java .nio .charset .StandardCharsets ;
30-
31- import static java .lang .foreign .ValueLayout .JAVA_BYTE ;
3231import java .util .Random ;
3332
33+ import static java .lang .foreign .ValueLayout .JAVA_BYTE ;
3434import static java .lang .foreign .ValueLayout .JAVA_INT_UNALIGNED ;
3535import static java .lang .foreign .ValueLayout .JAVA_LONG_UNALIGNED ;
3636import static java .lang .foreign .ValueLayout .JAVA_SHORT_UNALIGNED ;
@@ -46,7 +46,7 @@ public class CountMinSketch {
4646
4747 // Thread-local MemorySegment to avoid allocations in hot paths with explicit endianness control
4848 private static final ThreadLocal <MemorySegment > LONG_SEGMENT =
49- ThreadLocal .withInitial (() -> MemorySegment .ofArray (new byte [8 ]));
49+ ThreadLocal .withInitial (() -> MemorySegment .ofArray (new byte [Long . BYTES ]));
5050
5151 private enum Flag {
5252 IS_EMPTY ;
@@ -83,16 +83,16 @@ int mask() {
8383 // Use long arithmetic to detect overflow before casting
8484 final long totalSize = (long ) numHashes * (long ) numBuckets ;
8585 if (totalSize > Integer .MAX_VALUE ) {
86- throw new SketchesArgumentException ("Sketch array size would overflow: " + numHashes + " * " + numBuckets +
87- " = " + totalSize + " > " + Integer .MAX_VALUE );
86+ throw new SketchesArgumentException ("Sketch array size would overflow: " + numHashes + " * " + numBuckets
87+ + " = " + totalSize + " > " + Integer .MAX_VALUE );
8888 }
8989
9090 // This check is to ensure later compatibility with a Java implementation whose maximum size can only
9191 // be 2^31-1. We check only against 2^30 for simplicity.
9292 if (totalSize >= (1L << 30 )) {
93- throw new SketchesArgumentException ("Sketch would require excessive memory: " + numHashes + " * " + numBuckets +
94- " = " + totalSize + " elements (~" + String .format ("%.1f " , totalSize * 8.0 / (1024 * 1024 * 1024 )) + " GB). " +
95- "Consider reducing numHashes or numBuckets." );
93+ throw new SketchesArgumentException ("Sketch would require excessive memory: " + numHashes + " * " + numBuckets
94+ + " = " + totalSize + " elements (~" + String .format ("%d " , totalSize * Long . BYTES / (1024 * 1024 * 1024 )) + " GB). "
95+ + "Consider reducing numHashes or numBuckets." );
9696 }
9797
9898 numHashes_ = numHashes ;
@@ -102,7 +102,7 @@ int mask() {
102102 sketchArray_ = new long [(int ) totalSize ];
103103 totalWeight_ = 0 ;
104104
105- Random rand = new Random (seed );
105+ final Random rand = new Random (seed );
106106 for (int i = 0 ; i < numHashes ; i ++) {
107107 hashSeeds_ [i ] = rand .nextLong ();
108108 }
@@ -118,11 +118,11 @@ private static byte[] longToBytes(final long value) {
118118 }
119119
120120
121- private long [] getHashes (byte [] item ) {
122- long [] updateLocations = new long [numHashes_ ];
121+ private long [] getHashes (final byte [] item ) {
122+ final long [] updateLocations = new long [numHashes_ ];
123123
124124 for (int i = 0 ; i < numHashes_ ; i ++) {
125- long [] index = MurmurHash3 .hash (item , hashSeeds_ [i ]);
125+ final long [] index = MurmurHash3 .hash (item , hashSeeds_ [i ]);
126126 updateLocations [i ] = i * (long )numBuckets_ + Math .floorMod (index [0 ], numBuckets_ );
127127 }
128128
@@ -182,11 +182,11 @@ public double getRelativeError() {
182182 * @param confidence The desired confidence level between 0 and 1.
183183 * @return Suggested number of hash functions.
184184 */
185- public static byte suggestNumHashes (double confidence ) {
185+ public static byte suggestNumHashes (final double confidence ) {
186186 if (confidence < 0 || confidence > 1 ) {
187187 throw new SketchesException ("Confidence must be between 0 and 1.0 (inclusive)." );
188188 }
189- int value = (int ) Math .ceil (Math .log (1.0 / (1.0 - confidence )));
189+ final int value = (int ) Math .ceil (Math .log (1.0 / (1.0 - confidence )));
190190 return (byte ) Math .min (value , 127 );
191191 }
192192
@@ -195,7 +195,7 @@ public static byte suggestNumHashes(double confidence) {
195195 * @param relativeError The desired relative error.
196196 * @return Suggested number of buckets.
197197 */
198- public static int suggestNumBuckets (double relativeError ) {
198+ public static int suggestNumBuckets (final double relativeError ) {
199199 if (relativeError < 0. ) {
200200 throw new SketchesException ("Relative error must be at least 0." );
201201 }
@@ -235,8 +235,8 @@ public void update(final byte[] item, final long weight) {
235235 }
236236
237237 totalWeight_ += weight > 0 ? weight : -weight ;
238- long [] hashLocations = getHashes (item );
239- for (long h : hashLocations ) {
238+ final long [] hashLocations = getHashes (item );
239+ for (final long h : hashLocations ) {
240240 sketchArray_ [(int ) h ] += weight ;
241241 }
242242 }
@@ -274,7 +274,7 @@ public long getEstimate(final byte[] item) {
274274 return 0 ;
275275 }
276276
277- long [] hashLocations = getHashes (item );
277+ final long [] hashLocations = getHashes (item );
278278 long res = sketchArray_ [(int ) hashLocations [0 ]];
279279 // Start from index 1 to avoid processing first element twice
280280 for (int i = 1 ; i < hashLocations .length ; i ++) {
@@ -303,8 +303,8 @@ public long getUpperBound(final String item) {
303303 return 0 ;
304304 }
305305
306- byte [] strByte = item .getBytes (StandardCharsets .UTF_8 );
307- return getUpperBound (strByte );
306+ final byte [] strByte = item .getBytes (StandardCharsets .UTF_8 );
307+ return getUpperBound (strByte );
308308 }
309309
310310 /**
@@ -339,7 +339,7 @@ public long getLowerBound(final String item) {
339339 return 0 ;
340340 }
341341
342- byte [] strByte = item .getBytes (StandardCharsets .UTF_8 );
342+ final byte [] strByte = item .getBytes (StandardCharsets .UTF_8 );
343343 return getLowerBound (strByte );
344344 }
345345
@@ -361,8 +361,8 @@ public void merge(final CountMinSketch other) {
361361 throw new SketchesException ("Cannot merge a sketch with itself" );
362362 }
363363
364- boolean acceptableConfig = getNumBuckets_ () == other .getNumBuckets_ () &&
365- getNumHashes_ () == other .getNumHashes_ () && getSeed_ () == other .getSeed_ ();
364+ final boolean acceptableConfig = getNumBuckets_ () == other .getNumBuckets_ ()
365+ && getNumHashes_ () == other .getNumHashes_ () && getSeed_ () == other .getSeed_ ();
366366
367367 if (!acceptableConfig ) {
368368 throw new SketchesException ("Incompatible sketch configuration." );
@@ -392,46 +392,40 @@ private int getSerializedSizeBytes() {
392392 */
393393 public byte [] toByteArray () {
394394 final int serializedSizeBytes = getSerializedSizeBytes ();
395- final MemorySegment wseg = MemorySegment .ofArray (new byte [serializedSizeBytes ]);
396-
397- long offset = 0 ;
395+ final byte [] bytes = new byte [serializedSizeBytes ];
396+ final PositionalSegment posSeg = PositionalSegment .wrap (MemorySegment .ofArray (bytes ));
398397
399398 // Long 0
400399 final int preambleLongs = Family .COUNTMIN .getMinPreLongs ();
401- wseg . set ( JAVA_BYTE , offset ++, (byte ) preambleLongs );
400+ posSeg . setByte ( (byte ) preambleLongs );
402401 final int serialVersion = 1 ;
403- wseg . set ( JAVA_BYTE , offset ++, (byte ) serialVersion );
402+ posSeg . setByte ( (byte ) serialVersion );
404403 final int familyId = Family .COUNTMIN .getID ();
405- wseg . set ( JAVA_BYTE , offset ++, (byte ) familyId );
404+ posSeg . setByte ( (byte ) familyId );
406405 final int flagsByte = isEmpty () ? Flag .IS_EMPTY .mask () : 0 ;
407- wseg . set ( JAVA_BYTE , offset ++, (byte ) flagsByte );
406+ posSeg . setByte ( (byte ) flagsByte );
408407 final int NULL_32 = 0 ;
409- wseg .set (JAVA_INT_UNALIGNED , offset , NULL_32 );
410- offset += 4 ;
408+ posSeg .setInt (NULL_32 );
411409
412410 // Long 1
413- wseg .set (JAVA_INT_UNALIGNED , offset , numBuckets_ );
414- offset += 4 ;
415- wseg .set (JAVA_BYTE , offset ++, numHashes_ );
416- short hashSeed = Util .computeSeedHash (seed_ );
417- wseg .set (JAVA_SHORT_UNALIGNED , offset , hashSeed );
418- offset += 2 ;
411+ posSeg .setInt (numBuckets_ );
412+ posSeg .setByte (numHashes_ );
413+ final short hashSeed = Util .computeSeedHash (seed_ );
414+ posSeg .setShort (hashSeed );
419415 final byte NULL_8 = 0 ;
420- wseg . set ( JAVA_BYTE , offset ++, NULL_8 );
416+ posSeg . setByte ( NULL_8 );
421417
422418 if (isEmpty ()) {
423- return wseg . toArray ( JAVA_BYTE ) ;
419+ return bytes ;
424420 }
425421
426- wseg .set (JAVA_LONG_UNALIGNED , offset , totalWeight_ );
427- offset += 8 ;
422+ posSeg .setLong (totalWeight_ );
428423
429- for (long w : sketchArray_ ) {
430- wseg .set (JAVA_LONG_UNALIGNED , offset , w );
431- offset += 8 ;
424+ for (final long w : sketchArray_ ) {
425+ posSeg .setLong (w );
432426 }
433427
434- return wseg . toArray ( JAVA_BYTE ) ;
428+ return bytes ;
435429 }
436430
437431 /**
@@ -441,40 +435,51 @@ public byte[] toByteArray() {
441435 * @return The deserialized CountMinSketch.
442436 */
443437 public static CountMinSketch deserialize (final byte [] b , final long seed ) {
444- final MemorySegment buf = MemorySegment .ofArray (b );
445- long offset = 0 ;
446-
447- final byte preambleLongs = buf .get (JAVA_BYTE , offset ++);
448- final byte serialVersion = buf .get (JAVA_BYTE , offset ++);
449- final byte familyId = buf .get (JAVA_BYTE , offset ++);
450- final byte flagsByte = buf .get (JAVA_BYTE , offset ++);
451- final int NULL_32 = buf .get (JAVA_INT_UNALIGNED , offset );
452- offset += 4 ;
453-
454- final int numBuckets = buf .get (JAVA_INT_UNALIGNED , offset );
455- offset += 4 ;
456- final byte numHashes = buf .get (JAVA_BYTE , offset ++);
457- final short seedHash = buf .get (JAVA_SHORT_UNALIGNED , offset );
458- offset += 2 ;
459- final byte NULL_8 = buf .get (JAVA_BYTE , offset ++);
438+ final PositionalSegment posSeg = PositionalSegment .wrap (MemorySegment .ofArray (b ));
439+
440+ final byte preambleLongs = posSeg .getByte ();
441+ final byte serialVersion = posSeg .getByte ();
442+ final byte familyId = posSeg .getByte ();
443+ final byte flagsByte = posSeg .getByte ();
444+ posSeg .getInt (); // skip NULL_32
445+
446+ // Validate serialization format
447+ final int expectedPreambleLongs = Family .COUNTMIN .getMinPreLongs ();
448+ if (preambleLongs != expectedPreambleLongs ) {
449+ throw new SketchesArgumentException ("Preamble longs mismatch: expected " + expectedPreambleLongs
450+ + ", actual " + preambleLongs );
451+ }
452+ final int expectedSerialVersion = 1 ;
453+ if (serialVersion != expectedSerialVersion ) {
454+ throw new SketchesArgumentException ("Serial version mismatch: expected " + expectedSerialVersion
455+ + ", actual " + serialVersion );
456+ }
457+ final int expectedFamilyId = Family .COUNTMIN .getID ();
458+ if (familyId != expectedFamilyId ) {
459+ throw new SketchesArgumentException ("Family ID mismatch: expected " + expectedFamilyId
460+ + ", actual " + familyId );
461+ }
462+
463+ final int numBuckets = posSeg .getInt ();
464+ final byte numHashes = posSeg .getByte ();
465+ final short seedHash = posSeg .getShort ();
466+ posSeg .getByte (); // skip NULL_8
460467
461468 if (seedHash != Util .computeSeedHash (seed )) {
462- throw new SketchesArgumentException ("Incompatible seed hashes: " + String . valueOf ( seedHash ) + ", "
463- + String . valueOf ( Util .computeSeedHash (seed ) ));
469+ throw new SketchesArgumentException ("Incompatible seed hashes: " + seedHash + ", "
470+ + Util .computeSeedHash (seed ));
464471 }
465472
466- CountMinSketch cms = new CountMinSketch (numHashes , numBuckets , seed );
473+ final CountMinSketch cms = new CountMinSketch (numHashes , numBuckets , seed );
467474 final boolean empty = (flagsByte & Flag .IS_EMPTY .mask ()) > 0 ;
468475 if (empty ) {
469476 return cms ;
470477 }
471- long w = buf .get (JAVA_LONG_UNALIGNED , offset );
472- offset += 8 ;
478+ final long w = posSeg .getLong ();
473479 cms .totalWeight_ = w ;
474480
475481 for (int i = 0 ; i < cms .sketchArray_ .length ; i ++) {
476- cms .sketchArray_ [i ] = buf .get (JAVA_LONG_UNALIGNED , offset );
477- offset += 8 ;
482+ cms .sketchArray_ [i ] = posSeg .getLong ();
478483 }
479484
480485 return cms ;
0 commit comments