Skip to content

Commit 8bd1423

Browse files
Address all remaining review comments
1 parent ea9fda9 commit 8bd1423

1 file changed

Lines changed: 74 additions & 69 deletions

File tree

src/main/java/org/apache/datasketches/count/CountMinSketch.java

Lines changed: 74 additions & 69 deletions
Original file line number | Diff line number | Diff line change
@@ -23,14 +23,14 @@
2323
import org.apache.datasketches.common.SketchesArgumentException;
2424
import org.apache.datasketches.common.SketchesException;
2525
import org.apache.datasketches.common.Util;
26+
import org.apache.datasketches.common.positional.PositionalSegment;
2627
import org.apache.datasketches.hash.MurmurHash3;
2728

2829
import java.lang.foreign.MemorySegment;
2930
import java.nio.charset.StandardCharsets;
30-
31-
import static java.lang.foreign.ValueLayout.JAVA_BYTE;
3231
import java.util.Random;
3332

33+
import static java.lang.foreign.ValueLayout.JAVA_BYTE;
3434
import static java.lang.foreign.ValueLayout.JAVA_INT_UNALIGNED;
3535
import static java.lang.foreign.ValueLayout.JAVA_LONG_UNALIGNED;
3636
import static java.lang.foreign.ValueLayout.JAVA_SHORT_UNALIGNED;
@@ -46,7 +46,7 @@ public class CountMinSketch {
4646

4747
// Thread-local MemorySegment to avoid allocations in hot paths with explicit endianness control
4848
private static final ThreadLocal<MemorySegment> LONG_SEGMENT =
49-
ThreadLocal.withInitial(() -> MemorySegment.ofArray(new byte[8]));
49+
ThreadLocal.withInitial(() -> MemorySegment.ofArray(new byte[Long.BYTES]));
5050

5151
private enum Flag {
5252
IS_EMPTY;
@@ -83,16 +83,16 @@ int mask() {
8383
// Use long arithmetic to detect overflow before casting
8484
final long totalSize = (long) numHashes * (long) numBuckets;
8585
if (totalSize > Integer.MAX_VALUE) {
86-
throw new SketchesArgumentException("Sketch array size would overflow: " + numHashes + " * " + numBuckets +
87-
" = " + totalSize + " > " + Integer.MAX_VALUE);
86+
throw new SketchesArgumentException("Sketch array size would overflow: " + numHashes + " * " + numBuckets
87+
+ " = " + totalSize + " > " + Integer.MAX_VALUE);
8888
}
8989

9090
// This check is to ensure later compatibility with a Java implementation whose maximum size can only
9191
// be 2^31-1. We check only against 2^30 for simplicity.
9292
if (totalSize >= (1L << 30)) {
93-
throw new SketchesArgumentException("Sketch would require excessive memory: " + numHashes + " * " + numBuckets +
94-
" = " + totalSize + " elements (~" + String.format("%.1f", totalSize * 8.0 / (1024 * 1024 * 1024)) + " GB). " +
95-
"Consider reducing numHashes or numBuckets.");
93+
throw new SketchesArgumentException("Sketch would require excessive memory: " + numHashes + " * " + numBuckets
94+
+ " = " + totalSize + " elements (~" + String.format("%d", totalSize * Long.BYTES / (1024 * 1024 * 1024)) + " GB). "
95+
+ "Consider reducing numHashes or numBuckets.");
9696
}
9797

9898
numHashes_ = numHashes;
@@ -102,7 +102,7 @@ int mask() {
102102
sketchArray_ = new long[(int) totalSize];
103103
totalWeight_ = 0;
104104

105-
Random rand = new Random(seed);
105+
final Random rand = new Random(seed);
106106
for (int i = 0; i < numHashes; i++) {
107107
hashSeeds_[i] = rand.nextLong();
108108
}
@@ -118,11 +118,11 @@ private static byte[] longToBytes(final long value) {
118118
}
119119

120120

121-
private long[] getHashes(byte[] item) {
122-
long[] updateLocations = new long[numHashes_];
121+
private long[] getHashes(final byte[] item) {
122+
final long[] updateLocations = new long[numHashes_];
123123

124124
for (int i = 0; i < numHashes_; i++) {
125-
long[] index = MurmurHash3.hash(item, hashSeeds_[i]);
125+
final long[] index = MurmurHash3.hash(item, hashSeeds_[i]);
126126
updateLocations[i] = i * (long)numBuckets_ + Math.floorMod(index[0], numBuckets_);
127127
}
128128

@@ -182,11 +182,11 @@ public double getRelativeError() {
182182
* @param confidence The desired confidence level between 0 and 1.
183183
* @return Suggested number of hash functions.
184184
*/
185-
public static byte suggestNumHashes(double confidence) {
185+
public static byte suggestNumHashes(final double confidence) {
186186
if (confidence < 0 || confidence > 1) {
187187
throw new SketchesException("Confidence must be between 0 and 1.0 (inclusive).");
188188
}
189-
int value = (int) Math.ceil(Math.log(1.0 / (1.0 - confidence)));
189+
final int value = (int) Math.ceil(Math.log(1.0 / (1.0 - confidence)));
190190
return (byte) Math.min(value, 127);
191191
}
192192

@@ -195,7 +195,7 @@ public static byte suggestNumHashes(double confidence) {
195195
* @param relativeError The desired relative error.
196196
* @return Suggested number of buckets.
197197
*/
198-
public static int suggestNumBuckets(double relativeError) {
198+
public static int suggestNumBuckets(final double relativeError) {
199199
if (relativeError < 0.) {
200200
throw new SketchesException("Relative error must be at least 0.");
201201
}
@@ -235,8 +235,8 @@ public void update(final byte[] item, final long weight) {
235235
}
236236

237237
totalWeight_ += weight > 0 ? weight : -weight;
238-
long[] hashLocations = getHashes(item);
239-
for (long h : hashLocations) {
238+
final long[] hashLocations = getHashes(item);
239+
for (final long h : hashLocations) {
240240
sketchArray_[(int) h] += weight;
241241
}
242242
}
@@ -274,7 +274,7 @@ public long getEstimate(final byte[] item) {
274274
return 0;
275275
}
276276

277-
long[] hashLocations = getHashes(item);
277+
final long[] hashLocations = getHashes(item);
278278
long res = sketchArray_[(int) hashLocations[0]];
279279
// Start from index 1 to avoid processing first element twice
280280
for (int i = 1; i < hashLocations.length; i++) {
@@ -303,8 +303,8 @@ public long getUpperBound(final String item) {
303303
return 0;
304304
}
305305

306-
byte[] strByte = item.getBytes(StandardCharsets.UTF_8);
307-
return getUpperBound(strByte);
306+
final byte[] strByte = item.getBytes(StandardCharsets.UTF_8);
307+
return getUpperBound(strByte);
308308
}
309309

310310
/**
@@ -339,7 +339,7 @@ public long getLowerBound(final String item) {
339339
return 0;
340340
}
341341

342-
byte[] strByte = item.getBytes(StandardCharsets.UTF_8);
342+
final byte[] strByte = item.getBytes(StandardCharsets.UTF_8);
343343
return getLowerBound(strByte);
344344
}
345345

@@ -361,8 +361,8 @@ public void merge(final CountMinSketch other) {
361361
throw new SketchesException("Cannot merge a sketch with itself");
362362
}
363363

364-
boolean acceptableConfig = getNumBuckets_() == other.getNumBuckets_() &&
365-
getNumHashes_() == other.getNumHashes_() && getSeed_() == other.getSeed_();
364+
final boolean acceptableConfig = getNumBuckets_() == other.getNumBuckets_()
365+
&& getNumHashes_() == other.getNumHashes_() && getSeed_() == other.getSeed_();
366366

367367
if (!acceptableConfig) {
368368
throw new SketchesException("Incompatible sketch configuration.");
@@ -392,46 +392,40 @@ private int getSerializedSizeBytes() {
392392
*/
393393
public byte[] toByteArray() {
394394
final int serializedSizeBytes = getSerializedSizeBytes();
395-
final MemorySegment wseg = MemorySegment.ofArray(new byte[serializedSizeBytes]);
396-
397-
long offset = 0;
395+
final byte[] bytes = new byte[serializedSizeBytes];
396+
final PositionalSegment posSeg = PositionalSegment.wrap(MemorySegment.ofArray(bytes));
398397

399398
// Long 0
400399
final int preambleLongs = Family.COUNTMIN.getMinPreLongs();
401-
wseg.set(JAVA_BYTE, offset++, (byte) preambleLongs);
400+
posSeg.setByte((byte) preambleLongs);
402401
final int serialVersion = 1;
403-
wseg.set(JAVA_BYTE, offset++, (byte) serialVersion);
402+
posSeg.setByte((byte) serialVersion);
404403
final int familyId = Family.COUNTMIN.getID();
405-
wseg.set(JAVA_BYTE, offset++, (byte) familyId);
404+
posSeg.setByte((byte) familyId);
406405
final int flagsByte = isEmpty() ? Flag.IS_EMPTY.mask() : 0;
407-
wseg.set(JAVA_BYTE, offset++, (byte) flagsByte);
406+
posSeg.setByte((byte) flagsByte);
408407
final int NULL_32 = 0;
409-
wseg.set(JAVA_INT_UNALIGNED, offset, NULL_32);
410-
offset += 4;
408+
posSeg.setInt(NULL_32);
411409

412410
// Long 1
413-
wseg.set(JAVA_INT_UNALIGNED, offset, numBuckets_);
414-
offset += 4;
415-
wseg.set(JAVA_BYTE, offset++, numHashes_);
416-
short hashSeed = Util.computeSeedHash(seed_);
417-
wseg.set(JAVA_SHORT_UNALIGNED, offset, hashSeed);
418-
offset += 2;
411+
posSeg.setInt(numBuckets_);
412+
posSeg.setByte(numHashes_);
413+
final short hashSeed = Util.computeSeedHash(seed_);
414+
posSeg.setShort(hashSeed);
419415
final byte NULL_8 = 0;
420-
wseg.set(JAVA_BYTE, offset++, NULL_8);
416+
posSeg.setByte(NULL_8);
421417

422418
if (isEmpty()) {
423-
return wseg.toArray(JAVA_BYTE);
419+
return bytes;
424420
}
425421

426-
wseg.set(JAVA_LONG_UNALIGNED, offset, totalWeight_);
427-
offset += 8;
422+
posSeg.setLong(totalWeight_);
428423

429-
for (long w: sketchArray_) {
430-
wseg.set(JAVA_LONG_UNALIGNED, offset, w);
431-
offset += 8;
424+
for (final long w: sketchArray_) {
425+
posSeg.setLong(w);
432426
}
433427

434-
return wseg.toArray(JAVA_BYTE);
428+
return bytes;
435429
}
436430

437431
/**
@@ -441,40 +435,51 @@ public byte[] toByteArray() {
441435
* @return The deserialized CountMinSketch.
442436
*/
443437
public static CountMinSketch deserialize(final byte[] b, final long seed) {
444-
final MemorySegment buf = MemorySegment.ofArray(b);
445-
long offset = 0;
446-
447-
final byte preambleLongs = buf.get(JAVA_BYTE, offset++);
448-
final byte serialVersion = buf.get(JAVA_BYTE, offset++);
449-
final byte familyId = buf.get(JAVA_BYTE, offset++);
450-
final byte flagsByte = buf.get(JAVA_BYTE, offset++);
451-
final int NULL_32 = buf.get(JAVA_INT_UNALIGNED, offset);
452-
offset += 4;
453-
454-
final int numBuckets = buf.get(JAVA_INT_UNALIGNED, offset);
455-
offset += 4;
456-
final byte numHashes = buf.get(JAVA_BYTE, offset++);
457-
final short seedHash = buf.get(JAVA_SHORT_UNALIGNED, offset);
458-
offset += 2;
459-
final byte NULL_8 = buf.get(JAVA_BYTE, offset++);
438+
final PositionalSegment posSeg = PositionalSegment.wrap(MemorySegment.ofArray(b));
439+
440+
final byte preambleLongs = posSeg.getByte();
441+
final byte serialVersion = posSeg.getByte();
442+
final byte familyId = posSeg.getByte();
443+
final byte flagsByte = posSeg.getByte();
444+
posSeg.getInt(); // skip NULL_32
445+
446+
// Validate serialization format
447+
final int expectedPreambleLongs = Family.COUNTMIN.getMinPreLongs();
448+
if (preambleLongs != expectedPreambleLongs) {
449+
throw new SketchesArgumentException("Preamble longs mismatch: expected " + expectedPreambleLongs
450+
+ ", actual " + preambleLongs);
451+
}
452+
final int expectedSerialVersion = 1;
453+
if (serialVersion != expectedSerialVersion) {
454+
throw new SketchesArgumentException("Serial version mismatch: expected " + expectedSerialVersion
455+
+ ", actual " + serialVersion);
456+
}
457+
final int expectedFamilyId = Family.COUNTMIN.getID();
458+
if (familyId != expectedFamilyId) {
459+
throw new SketchesArgumentException("Family ID mismatch: expected " + expectedFamilyId
460+
+ ", actual " + familyId);
461+
}
462+
463+
final int numBuckets = posSeg.getInt();
464+
final byte numHashes = posSeg.getByte();
465+
final short seedHash = posSeg.getShort();
466+
posSeg.getByte(); // skip NULL_8
460467

461468
if (seedHash != Util.computeSeedHash(seed)) {
462-
throw new SketchesArgumentException("Incompatible seed hashes: " + String.valueOf(seedHash) + ", "
463-
+ String.valueOf(Util.computeSeedHash(seed)));
469+
throw new SketchesArgumentException("Incompatible seed hashes: " + seedHash + ", "
470+
+ Util.computeSeedHash(seed));
464471
}
465472

466-
CountMinSketch cms = new CountMinSketch(numHashes, numBuckets, seed);
473+
final CountMinSketch cms = new CountMinSketch(numHashes, numBuckets, seed);
467474
final boolean empty = (flagsByte & Flag.IS_EMPTY.mask()) > 0;
468475
if (empty) {
469476
return cms;
470477
}
471-
long w = buf.get(JAVA_LONG_UNALIGNED, offset);
472-
offset += 8;
478+
final long w = posSeg.getLong();
473479
cms.totalWeight_ = w;
474480

475481
for (int i = 0; i < cms.sketchArray_.length; i++) {
476-
cms.sketchArray_[i] = buf.get(JAVA_LONG_UNALIGNED, offset);
477-
offset += 8;
482+
cms.sketchArray_[i] = posSeg.getLong();
478483
}
479484

480485
return cms;

0 commit comments

Comments
 (0)