Skip to content

Commit c091214

Browse files
Fix build following cms/cpc recent PR
1 parent bc9bfd7 commit c091214

2 files changed

Lines changed: 52 additions & 22 deletions

File tree

src/main/java/org/apache/datasketches/count/CountMinSketch.java

Lines changed: 50 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@
2222
import org.apache.datasketches.common.Family;
2323
import org.apache.datasketches.common.SketchesArgumentException;
2424
import org.apache.datasketches.common.SketchesException;
25+
import org.apache.datasketches.common.Util;
2526
import org.apache.datasketches.hash.MurmurHash3;
26-
import org.apache.datasketches.tuple.Util;
2727

2828
import java.io.ByteArrayOutputStream;
2929
import java.nio.ByteBuffer;
@@ -39,6 +39,9 @@ public class CountMinSketch {
3939
private final long[] sketchArray_;
4040
private long totalWeight_;
4141

42+
// Thread-local 8-byte ByteBuffer reused by longToBytes to avoid per-call allocations in hot paths; its backing array is shared across calls on the same thread
43+
private static final ThreadLocal<ByteBuffer> LONG_BUFFER =
44+
ThreadLocal.withInitial(() -> ByteBuffer.allocate(8));
4245

4346
private enum Flag {
4447
IS_EMPTY;
@@ -57,30 +60,59 @@ int mask() {
5760
* @param seed The base hash seed
5861
*/
5962
CountMinSketch(final byte numHashes, final int numBuckets, final long seed) {
60-
numHashes_ = numHashes;
61-
numBuckets_ = numBuckets;
62-
seed_ = seed;
63-
hashSeeds_ = new long[numHashes];
64-
sketchArray_ = new long[numHashes * numBuckets];
65-
totalWeight_ = 0;
63+
// Validate numHashes
64+
if (numHashes <= 0) {
65+
throw new SketchesArgumentException("Number of hash functions must be positive, got: " + numHashes);
66+
}
6667

68+
// Validate numBuckets: relative error is e / numBuckets, so fewer than 3 buckets gives a relative error greater than 1
69+
if (numBuckets <= 0) {
70+
throw new SketchesArgumentException("Number of buckets must be positive, got: " + numBuckets);
71+
}
6772
if (numBuckets < 3) {
68-
throw new SketchesArgumentException("Using fewer than 3 buckets incurs relative error greater than 1.");
73+
throw new SketchesArgumentException("Number of buckets must be at least 3 to ensure relative error ≤ 1.0. " +
74+
"With " + numBuckets + " buckets, relative error would be " + String.format("%.3f", Math.exp(1.0) / numBuckets));
75+
}
76+
77+
// Check for potential overflow in array size calculation
78+
// Use long arithmetic to detect overflow before casting
79+
final long totalSize = (long) numHashes * (long) numBuckets;
80+
if (totalSize > Integer.MAX_VALUE) {
81+
throw new SketchesArgumentException("Sketch array size would overflow: " + numHashes + " * " + numBuckets +
82+
" = " + totalSize + " > " + Integer.MAX_VALUE);
6983
}
7084

7185
// This check is to ensure later compatibility with a Java implementation whose maximum size can only
7286
// be 2^31-1. We check only against 2^30 for simplicity.
73-
if (numBuckets * numHashes >= 1 << 30) {
74-
throw new SketchesArgumentException("These parameters generate a sketch that exceeds 2^30 elements. \n" +
75-
"Try reducing either the number of buckets or the number of hash functions.");
87+
if (totalSize >= (1L << 30)) {
88+
throw new SketchesArgumentException("Sketch would require excessive memory: " + numHashes + " * " + numBuckets +
89+
" = " + totalSize + " elements (~" + String.format("%.1f", totalSize * 8.0 / (1024 * 1024 * 1024)) + " GB). " +
90+
"Consider reducing numHashes or numBuckets.");
7691
}
7792

93+
numHashes_ = numHashes;
94+
numBuckets_ = numBuckets;
95+
seed_ = seed;
96+
hashSeeds_ = new long[numHashes];
97+
sketchArray_ = new long[(int) totalSize];
98+
totalWeight_ = 0;
99+
78100
Random rand = new Random(seed);
79101
for (int i = 0; i < numHashes; i++) {
80102
hashSeeds_[i] = rand.nextLong();
81103
}
82104
}
83105

106+
/**
107+
* Converts a long to its 8-byte representation using the thread-local buffer to avoid allocations.
*
* <p>The returned array is the backing array of the calling thread's {@code LONG_BUFFER}, not a
* copy, so its contents are overwritten by the next call to this method on the same thread.
* Callers must consume the bytes before calling this method again and must not retain the array.
*
* @param value the value to encode
* @return the thread-local buffer's 8-byte backing array, holding {@code value}
108+
*/
109+
private static byte[] longToBytes(final long value) {
110+
final ByteBuffer buffer = LONG_BUFFER.get();
111+
// Reset the position to 0 so putLong overwrites the bytes left by the previous call.
buffer.clear();
112+
buffer.putLong(value);
113+
// Exposes the shared backing array directly; no defensive copy is made.
return buffer.array();
114+
}
115+
84116
private long[] getHashes(byte[] item) {
85117
long[] updateLocations = new long[numHashes_];
86118

@@ -171,8 +203,7 @@ public static int suggestNumBuckets(double relativeError) {
171203
* @param weight The weight of the item.
172204
*/
173205
public void update(final long item, final long weight) {
174-
byte[] longByte = ByteBuffer.allocate(8).putLong(item).array();
175-
update(longByte, weight);
206+
update(longToBytes(item), weight);
176207
}
177208

178209
/**
@@ -211,8 +242,7 @@ public void update(final byte[] item, final long weight) {
211242
* @return Estimated frequency.
212243
*/
213244
public long getEstimate(final long item) {
214-
byte[] longByte = ByteBuffer.allocate(8).putLong(item).array();
215-
return getEstimate(longByte);
245+
return getEstimate(longToBytes(item));
216246
}
217247

218248
/**
@@ -241,8 +271,9 @@ public long getEstimate(final byte[] item) {
241271

242272
long[] hashLocations = getHashes(item);
243273
long res = sketchArray_[(int) hashLocations[0]];
244-
for (long h : hashLocations) {
245-
res = Math.min(res, sketchArray_[(int) h]);
274+
// Start from index 1 to avoid processing first element twice
275+
for (int i = 1; i < hashLocations.length; i++) {
276+
res = Math.min(res, sketchArray_[(int) hashLocations[i]]);
246277
}
247278

248279
return res;
@@ -254,8 +285,7 @@ public long getEstimate(final byte[] item) {
254285
* @return Upper bound of estimated frequency.
255286
*/
256287
public long getUpperBound(final long item) {
257-
byte[] longByte = ByteBuffer.allocate(8).putLong(item).array();
258-
return getUpperBound(longByte);
288+
return getUpperBound(longToBytes(item));
259289
}
260290

261291
/**
@@ -291,8 +321,7 @@ public long getUpperBound(final byte[] item) {
291321
* @return Lower bound of estimated frequency.
292322
*/
293323
public long getLowerBound(final long item) {
294-
byte[] longByte = ByteBuffer.allocate(8).putLong(item).array();
295-
return getLowerBound(longByte);
324+
return getLowerBound(longToBytes(item));
296325
}
297326

298327
/**

src/test/java/org/apache/datasketches/cpc/CpcSketchCrossLanguageTest.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import java.io.IOException;
3232
import java.nio.file.Files;
3333

34+
import org.apache.datasketches.memory.Memory;
3435
import org.testng.annotations.Test;
3536

3637
/**
@@ -89,7 +90,7 @@ public void checkAllFlavorsGo() throws IOException {
8990
int flavorIdx = 0;
9091
for (int n: nArr) {
9192
final byte[] bytes = Files.readAllBytes(goPath.resolve("cpc_n" + n + "_go.sk"));
92-
final CpcSketch sketch = CpcSketch.heapify(Memory.wrap(bytes));
93+
final CpcSketch sketch = CpcSketch.heapify(MemorySegment.ofArray(bytes));
9394
assertEquals(sketch.getFlavor(), flavorArr[flavorIdx++]);
9495
assertEquals(sketch.getEstimate(), n, n * 0.02);
9596
}

0 commit comments

Comments
 (0)