Skip to content

Commit e0445ef

Browse files
use MemorySegment instead of ByteBuffer in CMS
1 parent c6c90c3 commit e0445ef

2 files changed

Lines changed: 78 additions & 50 deletions

File tree

src/main/java/org/apache/datasketches/count/CountMinSketch.java

Lines changed: 74 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,16 @@
2525
import org.apache.datasketches.common.Util;
2626
import org.apache.datasketches.hash.MurmurHash3;
2727

28-
import java.io.ByteArrayOutputStream;
29-
import java.nio.ByteBuffer;
28+
import java.lang.foreign.MemorySegment;
3029
import java.nio.charset.StandardCharsets;
30+
31+
import static java.lang.foreign.ValueLayout.JAVA_BYTE;
3132
import java.util.Random;
3233

34+
import static org.apache.datasketches.common.SpecialValueLayouts.JAVA_INT_UNALIGNED_BIG_ENDIAN;
35+
import static org.apache.datasketches.common.SpecialValueLayouts.JAVA_LONG_UNALIGNED_BIG_ENDIAN;
36+
import static org.apache.datasketches.common.SpecialValueLayouts.JAVA_SHORT_UNALIGNED_BIG_ENDIAN;
37+
3338

3439
public class CountMinSketch {
3540
private final byte numHashes_;
@@ -39,9 +44,9 @@ public class CountMinSketch {
3944
private final long[] sketchArray_;
4045
private long totalWeight_;
4146

42-
// Thread-local ByteBuffer to avoid allocations in hot paths
43-
private static final ThreadLocal<ByteBuffer> LONG_BUFFER =
44-
ThreadLocal.withInitial(() -> ByteBuffer.allocate(8));
47+
// Thread-local MemorySegment to avoid allocations in hot paths with explicit endianness control
48+
private static final ThreadLocal<MemorySegment> LONG_SEGMENT =
49+
ThreadLocal.withInitial(() -> MemorySegment.ofArray(new byte[8]));
4550

4651
private enum Flag {
4752
IS_EMPTY;
@@ -104,15 +109,15 @@ int mask() {
104109
}
105110

106111
/**
107-
* Efficiently converts a long to byte array using thread-local buffer to avoid allocations.
112+
* Efficiently converts a long to byte array using thread-local MemorySegment with explicit endianness.
108113
*/
109114
private static byte[] longToBytes(final long value) {
110-
final ByteBuffer buffer = LONG_BUFFER.get();
111-
buffer.clear();
112-
buffer.putLong(value);
113-
return buffer.array();
115+
final MemorySegment segment = LONG_SEGMENT.get();
116+
segment.set(JAVA_LONG_UNALIGNED_BIG_ENDIAN, 0, value);
117+
return segment.toArray(JAVA_BYTE);
114118
}
115119

120+
116121
private long[] getHashes(byte[] item) {
117122
long[] updateLocations = new long[numHashes_];
118123

@@ -371,39 +376,62 @@ public void merge(final CountMinSketch other) {
371376
}
372377

373378
/**
374-
* Serializes the sketch into the provided ByteBuffer.
375-
* @param buf The ByteBuffer to write into.
379+
* Returns the serialized size in bytes.
376380
*/
377-
public void serialize(ByteArrayOutputStream buf) {
381+
private int getSerializedSizeBytes() {
382+
final int preambleBytes = Family.COUNTMIN.getMinPreLongs() * Long.BYTES;
383+
if (isEmpty()) {
384+
return preambleBytes;
385+
}
386+
return preambleBytes + Long.BYTES + (sketchArray_.length * Long.BYTES);
387+
}
388+
389+
390+
/**
391+
* Returns the sketch as a byte array.
392+
*/
393+
public byte[] toByteArray() {
394+
final int serializedSizeBytes = getSerializedSizeBytes();
395+
final MemorySegment wseg = MemorySegment.ofArray(new byte[serializedSizeBytes]);
396+
397+
long offset = 0;
398+
378399
// Long 0
379400
final int preambleLongs = Family.COUNTMIN.getMinPreLongs();
380-
buf.write((byte) preambleLongs);
401+
wseg.set(JAVA_BYTE, offset++, (byte) preambleLongs);
381402
final int serialVersion = 1;
382-
buf.write((byte) serialVersion);
403+
wseg.set(JAVA_BYTE, offset++, (byte) serialVersion);
383404
final int familyId = Family.COUNTMIN.getID();
384-
buf.write((byte) familyId);
405+
wseg.set(JAVA_BYTE, offset++, (byte) familyId);
385406
final int flagsByte = isEmpty() ? Flag.IS_EMPTY.mask() : 0;
386-
buf.write((byte)flagsByte);
407+
wseg.set(JAVA_BYTE, offset++, (byte) flagsByte);
387408
final int NULL_32 = 0;
388-
buf.writeBytes(ByteBuffer.allocate(4).putInt(NULL_32).array());
409+
wseg.set(JAVA_INT_UNALIGNED_BIG_ENDIAN, offset, NULL_32);
410+
offset += 4;
389411

390412
// Long 1
391-
buf.writeBytes(ByteBuffer.allocate(4).putInt(numBuckets_).array());
392-
buf.write(numHashes_);
413+
wseg.set(JAVA_INT_UNALIGNED_BIG_ENDIAN, offset, numBuckets_);
414+
offset += 4;
415+
wseg.set(JAVA_BYTE, offset++, numHashes_);
393416
short hashSeed = Util.computeSeedHash(seed_);
394-
buf.writeBytes(ByteBuffer.allocate(2).putShort(hashSeed).array());
417+
wseg.set(JAVA_SHORT_UNALIGNED_BIG_ENDIAN, offset, hashSeed);
418+
offset += 2;
395419
final byte NULL_8 = 0;
396-
buf.write(NULL_8);
420+
wseg.set(JAVA_BYTE, offset++, NULL_8);
421+
397422
if (isEmpty()) {
398-
return;
423+
return wseg.toArray(JAVA_BYTE);
399424
}
400425

401-
final byte[] totWeightByte = ByteBuffer.allocate(8).putLong(totalWeight_).array();
402-
buf.writeBytes(totWeightByte);
426+
wseg.set(JAVA_LONG_UNALIGNED_BIG_ENDIAN, offset, totalWeight_);
427+
offset += 8;
403428

404429
for (long w: sketchArray_) {
405-
buf.writeBytes(ByteBuffer.allocate(8).putLong(w).array());
430+
wseg.set(JAVA_LONG_UNALIGNED_BIG_ENDIAN, offset, w);
431+
offset += 8;
406432
}
433+
434+
return wseg.toArray(JAVA_BYTE);
407435
}
408436

409437
/**
@@ -413,20 +441,22 @@ public void serialize(ByteArrayOutputStream buf) {
413441
* @return The deserialized CountMinSketch.
414442
*/
415443
public static CountMinSketch deserialize(final byte[] b, final long seed) {
416-
ByteBuffer buf = ByteBuffer.allocate(b.length);
417-
buf.put(b);
418-
buf.flip();
419-
420-
final byte preambleLongs = buf.get();
421-
final byte serialVersion = buf.get();
422-
final byte familyId = buf.get();
423-
final byte flagsByte = buf.get();
424-
final int NULL_32 = buf.getInt();
425-
426-
final int numBuckets = buf.getInt();
427-
final byte numHashes = buf.get();
428-
final short seedHash = buf.getShort();
429-
final byte NULL_8 = buf.get();
444+
final MemorySegment buf = MemorySegment.ofArray(b);
445+
long offset = 0;
446+
447+
final byte preambleLongs = buf.get(JAVA_BYTE, offset++);
448+
final byte serialVersion = buf.get(JAVA_BYTE, offset++);
449+
final byte familyId = buf.get(JAVA_BYTE, offset++);
450+
final byte flagsByte = buf.get(JAVA_BYTE, offset++);
451+
final int NULL_32 = buf.get(JAVA_INT_UNALIGNED_BIG_ENDIAN, offset);
452+
offset += 4;
453+
454+
final int numBuckets = buf.get(JAVA_INT_UNALIGNED_BIG_ENDIAN, offset);
455+
offset += 4;
456+
final byte numHashes = buf.get(JAVA_BYTE, offset++);
457+
final short seedHash = buf.get(JAVA_SHORT_UNALIGNED_BIG_ENDIAN, offset);
458+
offset += 2;
459+
final byte NULL_8 = buf.get(JAVA_BYTE, offset++);
430460

431461
if (seedHash != Util.computeSeedHash(seed)) {
432462
throw new SketchesArgumentException("Incompatible seed hashes: " + String.valueOf(seedHash) + ", "
@@ -438,11 +468,13 @@ public static CountMinSketch deserialize(final byte[] b, final long seed) {
438468
if (empty) {
439469
return cms;
440470
}
441-
long w = buf.getLong();
471+
long w = buf.get(JAVA_LONG_UNALIGNED_BIG_ENDIAN, offset);
472+
offset += 8;
442473
cms.totalWeight_ = w;
443474

444475
for (int i = 0; i < cms.sketchArray_.length; i++) {
445-
cms.sketchArray_[i] = buf.getLong();
476+
cms.sketchArray_[i] = buf.get(JAVA_LONG_UNALIGNED_BIG_ENDIAN, offset);
477+
offset += 8;
446478
}
447479

448480
return cms;

src/test/java/org/apache/datasketches/count/CountMinSketchTest.java

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -203,10 +203,7 @@ public void serializeDeserializeEmptyTest() {
203203
final long seed = 123456;
204204
CountMinSketch c = new CountMinSketch(numHashes, numBuckets, seed);
205205

206-
ByteArrayOutputStream buf = new ByteArrayOutputStream();
207-
c.serialize(buf);
208-
209-
byte[] b = buf.toByteArray();
206+
byte[] b = c.toByteArray();
210207
assertThrows(SketchesArgumentException.class, () -> CountMinSketch.deserialize(b, seed - 1));
211208

212209
CountMinSketch d = CountMinSketch.deserialize(b, seed);
@@ -228,11 +225,10 @@ public void serializeDeserializeTest() {
228225
c.update(i, 10*i*i);
229226
}
230227

231-
ByteArrayOutputStream buf = new ByteArrayOutputStream();
232-
c.serialize(buf);
228+
byte[] b = c.toByteArray();
233229

234-
assertThrows(SketchesArgumentException.class, () -> CountMinSketch.deserialize(buf.toByteArray(), seed - 1));
235-
CountMinSketch d = CountMinSketch.deserialize(buf.toByteArray(), seed);
230+
assertThrows(SketchesArgumentException.class, () -> CountMinSketch.deserialize(b, seed - 1));
231+
CountMinSketch d = CountMinSketch.deserialize(b, seed);
236232

237233
assertEquals(d.getNumHashes_(), c.getNumHashes_());
238234
assertEquals(d.getNumBuckets_(), c.getNumBuckets_());

0 commit comments

Comments
 (0)