2222import org .apache .datasketches .common .Family ;
2323import org .apache .datasketches .common .SketchesArgumentException ;
2424import org .apache .datasketches .common .SketchesException ;
25+ import org .apache .datasketches .common .Util ;
2526import org .apache .datasketches .hash .MurmurHash3 ;
26- import org .apache .datasketches .tuple .Util ;
2727
2828import java .io .ByteArrayOutputStream ;
2929import java .nio .ByteBuffer ;
@@ -39,6 +39,9 @@ public class CountMinSketch {
3939 private final long [] sketchArray_ ;
4040 private long totalWeight_ ;
4141
42+ // Thread-local ByteBuffer to avoid allocations in hot paths
43+ private static final ThreadLocal <ByteBuffer > LONG_BUFFER =
44+ ThreadLocal .withInitial (() -> ByteBuffer .allocate (8 ));
4245
4346 private enum Flag {
4447 IS_EMPTY ;
@@ -57,30 +60,59 @@ int mask() {
5760 * @param seed The base hash seed
5861 */
5962 CountMinSketch (final byte numHashes , final int numBuckets , final long seed ) {
60- numHashes_ = numHashes ;
61- numBuckets_ = numBuckets ;
62- seed_ = seed ;
63- hashSeeds_ = new long [numHashes ];
64- sketchArray_ = new long [numHashes * numBuckets ];
65- totalWeight_ = 0 ;
63+ // Validate numHashes
64+ if (numHashes <= 0 ) {
65+ throw new SketchesArgumentException ("Number of hash functions must be positive, got: " + numHashes );
66+ }
6667
68+ // Validate numBuckets with clear mathematical justification
69+ if (numBuckets <= 0 ) {
70+ throw new SketchesArgumentException ("Number of buckets must be positive, got: " + numBuckets );
71+ }
6772 if (numBuckets < 3 ) {
68- throw new SketchesArgumentException ("Using fewer than 3 buckets incurs relative error greater than 1." );
73+ throw new SketchesArgumentException ("Number of buckets must be at least 3 to ensure relative error ≤ 1.0. " +
74+ "With " + numBuckets + " buckets, relative error would be " + String .format ("%.3f" , Math .exp (1.0 ) / numBuckets ));
75+ }
76+
77+ // Check for potential overflow in array size calculation
78+ // Use long arithmetic to detect overflow before casting
79+ final long totalSize = (long ) numHashes * (long ) numBuckets ;
80+ if (totalSize > Integer .MAX_VALUE ) {
81+ throw new SketchesArgumentException ("Sketch array size would overflow: " + numHashes + " * " + numBuckets +
82+ " = " + totalSize + " > " + Integer .MAX_VALUE );
6983 }
7084
7185 // This check is to ensure later compatibility with a Java implementation whose maximum size can only
7286 // be 2^31-1. We check only against 2^30 for simplicity.
73- if (numBuckets * numHashes >= 1 << 30 ) {
74- throw new SketchesArgumentException ("These parameters generate a sketch that exceeds 2^30 elements. \n " +
75- "Try reducing either the number of buckets or the number of hash functions." );
87+ if (totalSize >= (1L << 30 )) {
88+ throw new SketchesArgumentException ("Sketch would require excessive memory: " + numHashes + " * " + numBuckets +
89+ " = " + totalSize + " elements (~" + String .format ("%.1f" , totalSize * 8.0 / (1024 * 1024 * 1024 )) + " GB). " +
90+ "Consider reducing numHashes or numBuckets." );
7691 }
7792
93+ numHashes_ = numHashes ;
94+ numBuckets_ = numBuckets ;
95+ seed_ = seed ;
96+ hashSeeds_ = new long [numHashes ];
97+ sketchArray_ = new long [(int ) totalSize ];
98+ totalWeight_ = 0 ;
99+
78100 Random rand = new Random (seed );
79101 for (int i = 0 ; i < numHashes ; i ++) {
80102 hashSeeds_ [i ] = rand .nextLong ();
81103 }
82104 }
83105
106+ /**
107+ * Efficiently converts a long to byte array using thread-local buffer to avoid allocations.
108+ */
109+ private static byte [] longToBytes (final long value ) {
110+ final ByteBuffer buffer = LONG_BUFFER .get ();
111+ buffer .clear ();
112+ buffer .putLong (value );
113+ return buffer .array ();
114+ }
115+
84116 private long [] getHashes (byte [] item ) {
85117 long [] updateLocations = new long [numHashes_ ];
86118
@@ -171,8 +203,7 @@ public static int suggestNumBuckets(double relativeError) {
171203 * @param weight The weight of the item.
172204 */
173205 public void update (final long item , final long weight ) {
174- byte [] longByte = ByteBuffer .allocate (8 ).putLong (item ).array ();
175- update (longByte , weight );
206+ update (longToBytes (item ), weight );
176207 }
177208
178209 /**
@@ -211,8 +242,7 @@ public void update(final byte[] item, final long weight) {
211242 * @return Estimated frequency.
212243 */
213244 public long getEstimate (final long item ) {
214- byte [] longByte = ByteBuffer .allocate (8 ).putLong (item ).array ();
215- return getEstimate (longByte );
245+ return getEstimate (longToBytes (item ));
216246 }
217247
218248 /**
@@ -241,8 +271,9 @@ public long getEstimate(final byte[] item) {
241271
242272 long [] hashLocations = getHashes (item );
243273 long res = sketchArray_ [(int ) hashLocations [0 ]];
244- for (long h : hashLocations ) {
245- res = Math .min (res , sketchArray_ [(int ) h ]);
274+ // Start from index 1 to avoid processing first element twice
275+ for (int i = 1 ; i < hashLocations .length ; i ++) {
276+ res = Math .min (res , sketchArray_ [(int ) hashLocations [i ]]);
246277 }
247278
248279 return res ;
@@ -254,8 +285,7 @@ public long getEstimate(final byte[] item) {
254285 * @return Upper bound of estimated frequency.
255286 */
256287 public long getUpperBound (final long item ) {
257- byte [] longByte = ByteBuffer .allocate (8 ).putLong (item ).array ();
258- return getUpperBound (longByte );
288+ return getUpperBound (longToBytes (item ));
259289 }
260290
261291 /**
@@ -291,8 +321,7 @@ public long getUpperBound(final byte[] item) {
291321 * @return Lower bound of estimated frequency.
292322 */
293323 public long getLowerBound (final long item ) {
294- byte [] longByte = ByteBuffer .allocate (8 ).putLong (item ).array ();
295- return getLowerBound (longByte );
324+ return getLowerBound (longToBytes (item ));
296325 }
297326
298327 /**
0 commit comments