2525import org .apache .datasketches .common .Util ;
2626import org .apache .datasketches .hash .MurmurHash3 ;
2727
28- import java .io .ByteArrayOutputStream ;
29- import java .nio .ByteBuffer ;
28+ import java .lang .foreign .MemorySegment ;
3029import java .nio .charset .StandardCharsets ;
30+
31+ import static java .lang .foreign .ValueLayout .JAVA_BYTE ;
3132import java .util .Random ;
3233
34+ import static org .apache .datasketches .common .SpecialValueLayouts .JAVA_INT_UNALIGNED_BIG_ENDIAN ;
35+ import static org .apache .datasketches .common .SpecialValueLayouts .JAVA_LONG_UNALIGNED_BIG_ENDIAN ;
36+ import static org .apache .datasketches .common .SpecialValueLayouts .JAVA_SHORT_UNALIGNED_BIG_ENDIAN ;
37+
3338
3439public class CountMinSketch {
3540 private final byte numHashes_ ;
@@ -39,9 +44,9 @@ public class CountMinSketch {
3944 private final long [] sketchArray_ ;
4045 private long totalWeight_ ;
4146
42- // Thread-local ByteBuffer to avoid allocations in hot paths
43- private static final ThreadLocal <ByteBuffer > LONG_BUFFER =
44- ThreadLocal .withInitial (() -> ByteBuffer . allocate ( 8 ));
47+ // Thread-local MemorySegment to avoid allocations in hot paths with explicit endianness control
48+ private static final ThreadLocal <MemorySegment > LONG_SEGMENT =
49+ ThreadLocal .withInitial (() -> MemorySegment . ofArray ( new byte [ 8 ] ));
4550
4651 private enum Flag {
4752 IS_EMPTY ;
@@ -104,15 +109,15 @@ int mask() {
104109 }
105110
106111 /**
107- * Efficiently converts a long to byte array using thread-local buffer to avoid allocations .
112+ * Efficiently converts a long to byte array using thread-local MemorySegment with explicit endianness .
108113 */
109114 private static byte [] longToBytes (final long value ) {
110- final ByteBuffer buffer = LONG_BUFFER .get ();
111- buffer .clear ();
112- buffer .putLong (value );
113- return buffer .array ();
115+ final MemorySegment segment = LONG_SEGMENT .get ();
116+ segment .set (JAVA_LONG_UNALIGNED_BIG_ENDIAN , 0 , value );
117+ return segment .toArray (JAVA_BYTE );
114118 }
115119
120+
116121 private long [] getHashes (byte [] item ) {
117122 long [] updateLocations = new long [numHashes_ ];
118123
@@ -371,39 +376,62 @@ public void merge(final CountMinSketch other) {
371376 }
372377
373378 /**
374- * Serializes the sketch into the provided ByteBuffer.
375- * @param buf The ByteBuffer to write into.
379+ * Returns the serialized size in bytes.
376380 */
377- public void serialize (ByteArrayOutputStream buf ) {
381+ private int getSerializedSizeBytes () {
382+ final int preambleBytes = Family .COUNTMIN .getMinPreLongs () * Long .BYTES ;
383+ if (isEmpty ()) {
384+ return preambleBytes ;
385+ }
386+ return preambleBytes + Long .BYTES + (sketchArray_ .length * Long .BYTES );
387+ }
388+
389+
390+ /**
391+ * Returns the sketch as a byte array.
392+ */
393+ public byte [] toByteArray () {
394+ final int serializedSizeBytes = getSerializedSizeBytes ();
395+ final MemorySegment wseg = MemorySegment .ofArray (new byte [serializedSizeBytes ]);
396+
397+ long offset = 0 ;
398+
378399 // Long 0
379400 final int preambleLongs = Family .COUNTMIN .getMinPreLongs ();
380- buf . write ( (byte ) preambleLongs );
401+ wseg . set ( JAVA_BYTE , offset ++, (byte ) preambleLongs );
381402 final int serialVersion = 1 ;
382- buf . write ( (byte ) serialVersion );
403+ wseg . set ( JAVA_BYTE , offset ++, (byte ) serialVersion );
383404 final int familyId = Family .COUNTMIN .getID ();
384- buf . write ( (byte ) familyId );
405+ wseg . set ( JAVA_BYTE , offset ++, (byte ) familyId );
385406 final int flagsByte = isEmpty () ? Flag .IS_EMPTY .mask () : 0 ;
386- buf . write ( (byte )flagsByte );
407+ wseg . set ( JAVA_BYTE , offset ++, (byte ) flagsByte );
387408 final int NULL_32 = 0 ;
388- buf .writeBytes (ByteBuffer .allocate (4 ).putInt (NULL_32 ).array ());
409+ wseg .set (JAVA_INT_UNALIGNED_BIG_ENDIAN , offset , NULL_32 );
410+ offset += 4 ;
389411
390412 // Long 1
391- buf .writeBytes (ByteBuffer .allocate (4 ).putInt (numBuckets_ ).array ());
392- buf .write (numHashes_ );
413+ wseg .set (JAVA_INT_UNALIGNED_BIG_ENDIAN , offset , numBuckets_ );
414+ offset += 4 ;
415+ wseg .set (JAVA_BYTE , offset ++, numHashes_ );
393416 short hashSeed = Util .computeSeedHash (seed_ );
394- buf .writeBytes (ByteBuffer .allocate (2 ).putShort (hashSeed ).array ());
417+ wseg .set (JAVA_SHORT_UNALIGNED_BIG_ENDIAN , offset , hashSeed );
418+ offset += 2 ;
395419 final byte NULL_8 = 0 ;
396- buf .write (NULL_8 );
420+ wseg .set (JAVA_BYTE , offset ++, NULL_8 );
421+
397422 if (isEmpty ()) {
398- return ;
423+ return wseg . toArray ( JAVA_BYTE ) ;
399424 }
400425
401- final byte [] totWeightByte = ByteBuffer . allocate ( 8 ). putLong ( totalWeight_ ). array ( );
402- buf . writeBytes ( totWeightByte ) ;
426+ wseg . set ( JAVA_LONG_UNALIGNED_BIG_ENDIAN , offset , totalWeight_ );
427+ offset += 8 ;
403428
404429 for (long w : sketchArray_ ) {
405- buf .writeBytes (ByteBuffer .allocate (8 ).putLong (w ).array ());
430+ wseg .set (JAVA_LONG_UNALIGNED_BIG_ENDIAN , offset , w );
431+ offset += 8 ;
406432 }
433+
434+ return wseg .toArray (JAVA_BYTE );
407435 }
408436
409437 /**
@@ -413,20 +441,22 @@ public void serialize(ByteArrayOutputStream buf) {
413441 * @return The deserialized CountMinSketch.
414442 */
415443 public static CountMinSketch deserialize (final byte [] b , final long seed ) {
416- ByteBuffer buf = ByteBuffer .allocate (b .length );
417- buf .put (b );
418- buf .flip ();
419-
420- final byte preambleLongs = buf .get ();
421- final byte serialVersion = buf .get ();
422- final byte familyId = buf .get ();
423- final byte flagsByte = buf .get ();
424- final int NULL_32 = buf .getInt ();
425-
426- final int numBuckets = buf .getInt ();
427- final byte numHashes = buf .get ();
428- final short seedHash = buf .getShort ();
429- final byte NULL_8 = buf .get ();
444+ final MemorySegment buf = MemorySegment .ofArray (b );
445+ long offset = 0 ;
446+
447+ final byte preambleLongs = buf .get (JAVA_BYTE , offset ++);
448+ final byte serialVersion = buf .get (JAVA_BYTE , offset ++);
449+ final byte familyId = buf .get (JAVA_BYTE , offset ++);
450+ final byte flagsByte = buf .get (JAVA_BYTE , offset ++);
451+ final int NULL_32 = buf .get (JAVA_INT_UNALIGNED_BIG_ENDIAN , offset );
452+ offset += 4 ;
453+
454+ final int numBuckets = buf .get (JAVA_INT_UNALIGNED_BIG_ENDIAN , offset );
455+ offset += 4 ;
456+ final byte numHashes = buf .get (JAVA_BYTE , offset ++);
457+ final short seedHash = buf .get (JAVA_SHORT_UNALIGNED_BIG_ENDIAN , offset );
458+ offset += 2 ;
459+ final byte NULL_8 = buf .get (JAVA_BYTE , offset ++);
430460
431461 if (seedHash != Util .computeSeedHash (seed )) {
432462 throw new SketchesArgumentException ("Incompatible seed hashes: " + String .valueOf (seedHash ) + ", "
@@ -438,11 +468,13 @@ public static CountMinSketch deserialize(final byte[] b, final long seed) {
438468 if (empty ) {
439469 return cms ;
440470 }
441- long w = buf .getLong ();
471+ long w = buf .get (JAVA_LONG_UNALIGNED_BIG_ENDIAN , offset );
472+ offset += 8 ;
442473 cms .totalWeight_ = w ;
443474
444475 for (int i = 0 ; i < cms .sketchArray_ .length ; i ++) {
445- cms .sketchArray_ [i ] = buf .getLong ();
476+ cms .sketchArray_ [i ] = buf .get (JAVA_LONG_UNALIGNED_BIG_ENDIAN , offset );
477+ offset += 8 ;
446478 }
447479
448480 return cms ;
0 commit comments