diff --git a/bench-jmh/src/main/scala/org/apache/pekko/stream/BroadcastHubBenchRunner.scala b/bench-jmh/src/main/scala/org/apache/pekko/stream/BroadcastHubBenchRunner.scala new file mode 100644 index 0000000000..0efe9e8d49 --- /dev/null +++ b/bench-jmh/src/main/scala/org/apache/pekko/stream/BroadcastHubBenchRunner.scala @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.pekko.stream + +import java.util.concurrent.{ CountDownLatch, TimeUnit } + +import scala.concurrent.Await +import scala.concurrent.duration._ + +import org.apache.pekko.actor.ActorSystem +import org.apache.pekko.remote.artery.{ BenchTestSource, LatchSink } +import org.apache.pekko.stream.scaladsl._ + +import com.typesafe.config.ConfigFactory + +/** + * Standalone benchmark runner for BroadcastHub consumer wheel performance. + * Run with: sbt "bench-jmh/runMain org.apache.pekko.stream.BroadcastHubBenchRunner" + */ +object BroadcastHubBenchRunner { + + final val Elements = 100000 + final val SmallBuffer = 64 + final val LargeBuffer = 256 + final val WarmupRuns = 2 + final val MeasureRuns = 3 + + def main(args: Array[String]): Unit = { + val config = ConfigFactory.parseString(""" + pekko.actor.default-dispatcher { + executor = "fork-join-executor" + fork-join-executor { + parallelism-factor = 1 + } + } + """) + + val consumerCounts = Array(64, 256, 1000, 2000) + + println("=" * 80) + println("BroadcastHub Consumer Wheel Benchmark") + println(s"Elements per run: $Elements") + println(s"Warmup: $WarmupRuns runs, Measure: $MeasureRuns runs") + println("=" * 80) + + for (bufferSize <- Array(SmallBuffer, LargeBuffer)) { + println(s"\n--- Buffer size: $bufferSize (wheel slots: ${bufferSize * 2}) ---") + println(f"${"Consumers"}%-12s ${"Avg (elem/s)"}%16s ${"Min"}%12s ${"Max"}%12s ${"StdDev"}%10s") + println("-" * 70) + + for (consumerCount <- consumerCounts) { + implicit val system: ActorSystem = ActorSystem(s"bench-$consumerCount-$bufferSize", config) + + // eager init + SystemMaterializer(system).materializer + + val results = new Array[Double](WarmupRuns + MeasureRuns) + + for (run <- 0 until WarmupRuns + MeasureRuns) { + val latch = new CountDownLatch(consumerCount) + val broadcastSink = + BroadcastHub.sink[java.lang.Integer](bufferSize = bufferSize, startAfterNrOfConsumers = consumerCount) + val testSource = Source.fromGraph(new BenchTestSource(Elements)) + val source = testSource.runWith(broadcastSink) + + val start = System.nanoTime() + var idx = 0 + while (idx < consumerCount) { + source.runWith(new LatchSink(Elements, latch)) + idx += 1 + } + + if (!latch.await(120, TimeUnit.SECONDS)) { + println(s" TIMEOUT at consumers=$consumerCount buffer=$bufferSize run=$run") + Await.result(system.terminate(), 10.seconds) + System.exit(1) + } + val elapsed = (System.nanoTime() - start) / 1e9 + results(run) = Elements / elapsed + } + + val measured = results.drop(WarmupRuns) + val avg = measured.sum / measured.length + val min = measured.min + val max = measured.max + val variance = measured.map(x => (x - avg) * (x - avg)).sum / measured.length + val stddev = math.sqrt(variance) + + println(f"$consumerCount%-12d $avg%16.0f $min%12.0f $max%12.0f $stddev%10.0f") + + Await.result(system.terminate(), 10.seconds) + } + } + + println("\n" + "=" * 80) + println("Done.") + } +} diff --git a/bench-jmh/src/main/scala/org/apache/pekko/stream/BroadcastHubBenchmark.scala b/bench-jmh/src/main/scala/org/apache/pekko/stream/BroadcastHubBenchmark.scala index c46bcf5aee..687f695183 100644 --- a/bench-jmh/src/main/scala/org/apache/pekko/stream/BroadcastHubBenchmark.scala +++ b/bench-jmh/src/main/scala/org/apache/pekko/stream/BroadcastHubBenchmark.scala @@ -32,8 +32,24 @@ import org.apache.pekko.stream.testkit.scaladsl.StreamTestKit import com.typesafe.config.ConfigFactory +/** + * Benchmarks BroadcastHub throughput under high-fan-out lockstep consumer scenarios. + * + * The consumer wheel uses a LongMap per slot for O(1) keyed add/remove without Long boxing. + * In lockstep, all consumers cluster in the same wheel slot, maximizing per-slot contention. + * With a small buffer (64), the wheel has only 128 slots, so `consumerCount / 128` consumers + * share each slot — the old ArrayList.removeIf was O(k) per removal, now O(1). + * + * The `broadcast` benchmark parameterizes over consumer count with a fixed small buffer, + * measuring how throughput scales as wheel slot pressure increases. + * + * The `broadcastLargeBuffer` benchmark uses a larger buffer (256) for comparison, + * showing how the optimization holds up when consumers are spread across more slots. + */ object BroadcastHubBenchmark { final val OperationsPerInvocation = 100000 + final val SmallBufferSize = 64 + final val LargeBufferSize = 256 } @State(Scope.Benchmark) @@ -56,7 +72,7 @@ class BroadcastHubBenchmark { var testSource: Source[java.lang.Integer, NotUsed] = _ - @Param(Array("64", "256")) + @Param(Array("64", "256", "1000", "2000")) var parallelism = 0 @Setup @@ -71,12 +87,40 @@ class BroadcastHubBenchmark { Await.result(system.terminate(), 5.seconds) } + /** + * Lockstep broadcast with small buffer (64). + * All consumers stay at roughly the same wheel offset, clustering in the same slot. + * With 128 wheel slots and 2000 consumers, ~16 consumers share each slot on average; + * during NeedWakeup bursts, thousands cluster in a single slot. + * This maximizes the O(1) vs O(k) per-removal difference. + */ @Benchmark @OperationsPerInvocation(OperationsPerInvocation) def broadcast(): Unit = { val latch = new CountDownLatch(parallelism) val broadcastSink = - BroadcastHub.sink[java.lang.Integer](bufferSize = parallelism, startAfterNrOfConsumers = parallelism) + BroadcastHub.sink[java.lang.Integer](bufferSize = SmallBufferSize, startAfterNrOfConsumers = parallelism) + val sink = new LatchSink(OperationsPerInvocation, latch) + val source = testSource.runWith(broadcastSink) + var idx = 0 + while (idx < parallelism) { + source.runWith(sink) + idx += 1 + } + awaitLatch(latch) + } + + /** + * Lockstep broadcast with larger buffer (256) for comparison. + * The wheel has 512 slots, so consumers are spread more thinly. + * Shows how the optimization scales when per-slot pressure is lower. + */ + @Benchmark + @OperationsPerInvocation(OperationsPerInvocation) + def broadcastLargeBuffer(): Unit = { + val latch = new CountDownLatch(parallelism) + val broadcastSink = + BroadcastHub.sink[java.lang.Integer](bufferSize = LargeBufferSize, startAfterNrOfConsumers = parallelism) val sink = new LatchSink(OperationsPerInvocation, latch) val source = testSource.runWith(broadcastSink) var idx = 0 @@ -88,7 +132,7 @@ class BroadcastHubBenchmark { } private def awaitLatch(latch: CountDownLatch): Unit = { - if (!latch.await(30, TimeUnit.SECONDS)) { + if (!latch.await(60, TimeUnit.SECONDS)) { StreamTestKit.printDebugDump(SystemMaterializer(system).materializer.supervisor) throw new RuntimeException("Latch didn't complete in time") } diff --git a/stream-tests/src/test/scala/org/apache/pekko/stream/scaladsl/HubSpec.scala b/stream-tests/src/test/scala/org/apache/pekko/stream/scaladsl/HubSpec.scala index a6d57dee19..b32c8746b1 100644 --- a/stream-tests/src/test/scala/org/apache/pekko/stream/scaladsl/HubSpec.scala +++ b/stream-tests/src/test/scala/org/apache/pekko/stream/scaladsl/HubSpec.scala @@ -629,6 +629,51 @@ class HubSpec extends StreamSpec { in.sendComplete() sinkProbe2.cancel() } + + "deliver all elements in order to many consumers" in { + val consumerCount = 200 + val messageCount = 2000 + + val source = Source(0 until messageCount).runWith(BroadcastHub.sink(bufferSize = 256, + startAfterNrOfConsumers = consumerCount)) + + val futures = (0 until consumerCount).map { _ => + source.runWith(Sink.seq) + } + + val results = Await.result(Future.sequence(futures), 30.seconds) + results.foreach { result => + result should ===(0 until messageCount) + } + } + + "handle many consumers when some cancel mid-stream" in { + val totalConsumers = 64 + val cancellingConsumers = 16 + val cancelAfter = 64 + val messageCount = 512 + + val source = Source(0 until messageCount).runWith( + BroadcastHub.sink(bufferSize = 256, startAfterNrOfConsumers = totalConsumers)) + + val cancellingFutures = (0 until cancellingConsumers).map { _ => + source.take(cancelAfter).runWith(Sink.seq) + } + + val remainingFutures = (0 until (totalConsumers - cancellingConsumers)).map { _ => + source.runWith(Sink.seq) + } + + val cancellingResults = Await.result(Future.sequence(cancellingFutures), 30.seconds) + cancellingResults.foreach { result => + result should ===(0 until cancelAfter) + } + + val remainingResults = Await.result(Future.sequence(remainingFutures), 30.seconds) + remainingResults.foreach { result => + result should ===(0 until messageCount) + } + } } "PartitionHub" must { diff --git a/stream/src/main/scala/org/apache/pekko/stream/scaladsl/Hub.scala b/stream/src/main/scala/org/apache/pekko/stream/scaladsl/Hub.scala index 4b6ca0063e..8849529e6d 100644 --- a/stream/src/main/scala/org/apache/pekko/stream/scaladsl/Hub.scala +++ b/stream/src/main/scala/org/apache/pekko/stream/scaladsl/Hub.scala @@ -536,14 +536,17 @@ private[pekko] class BroadcastHub[T](startAfterNrOfConsumers: Int, bufferSize: I * of priorities always fall to a range * * This wheel tracks the position of Consumers relative to the slowest ones. Every slot - * contains a list of Consumers being known at that location (this might be out of date!). + * contains a map of Consumers being known at that location (this might be out of date!). * Consumers from time to time send Advance messages to indicate that they have progressed * by reading from the broadcast queue. Consumers that are blocked (due to reaching tail) request * a wakeup and update their position at the same time. * + * Each slot uses a LongMap keyed by Consumer.id for O(1) add/remove without Long boxing. + * Empty slots are null (no backing map allocated), reducing baseline memory and GC pressure. + * When a slot drains to zero consumers, its map is released (set to null). */ private[this] val consumerWheel = - Array.fill[java.util.ArrayList[Consumer]](bufferSize * 2)(new util.ArrayList[Consumer]()) + new Array[LongMap[Consumer]](bufferSize * 2) private[this] var activeConsumers = 0 override def preStart(): Unit = { @@ -574,15 +577,19 @@ private[pekko] class BroadcastHub[T](startAfterNrOfConsumers: Int, bufferSize: I val newOffset = previousOffset + DemandThreshold // Move the consumer from its last known offset to its new one. Check if we are unblocked. val consumer = findAndRemoveConsumer(id, previousOffset) - addConsumer(consumer, newOffset) + if (consumer ne null) { + addConsumer(consumer, newOffset) + } checkUnblock(previousOffset) case NeedWakeup(id, previousOffset, currentOffset) => // Move the consumer from its last known offset to its new one. Check if we are unblocked. val consumer = findAndRemoveConsumer(id, previousOffset) - addConsumer(consumer, currentOffset) + if (consumer ne null) { + addConsumer(consumer, currentOffset) - // Also check if the consumer is now unblocked since we published an element since it went asleep. - if (currentOffset != tail) consumer.callback.invoke(Wakeup) + // Also check if the consumer is now unblocked since we published an element since it went asleep. + if (currentOffset != tail) consumer.callback.invoke(Wakeup) + } checkUnblock(previousOffset) case RegistrationPending => @@ -650,10 +657,14 @@ private[pekko] class BroadcastHub[T](startAfterNrOfConsumers: Int, bufferSize: I consumer.callback.invoke(failMessage) } - // Notify registered consumers + // Notify registered consumers — skip null (empty) slots var idx = 0 while (idx < consumerWheel.length) { - consumerWheel(idx).forEach(_.callback.invoke(failMessage)) + val bucket = consumerWheel(idx) + if (bucket ne null) { + val itr = bucket.valuesIterator + while (itr.hasNext) itr.next().callback.invoke(failMessage) + } idx += 1 } failStage(ex) @@ -664,21 +675,19 @@ private[pekko] class BroadcastHub[T](startAfterNrOfConsumers: Int, bufferSize: I * * NB: You cannot remove a consumer without knowing its last offset! Consumers on the Source side always must * track this so this can be a fast operation. + * + * Uses LongMap.getOrNull + -= to avoid Option allocation on the hot path. */ private def findAndRemoveConsumer(id: Long, offset: Int): Consumer = { - // TODO: Try to eliminate modulo division somehow... val wheelSlot = offset & WheelMask - val consumersInSlot = consumerWheel(wheelSlot) - var removedConsumer: Consumer = null - if (consumersInSlot.size() > 0) { - consumersInSlot.removeIf(consumer => { - if (consumer.id == id) { - removedConsumer = consumer - true - } else false - }) + val bucket = consumerWheel(wheelSlot) + if (bucket eq null) return null + val consumer = bucket.getOrNull(id) + if (consumer ne null) { + bucket -= id + if (bucket.isEmpty) consumerWheel(wheelSlot) = null } - removedConsumer + consumer } /* @@ -697,7 +706,7 @@ private[pekko] class BroadcastHub[T](startAfterNrOfConsumers: Int, bufferSize: I if (offsetOfConsumerRemoved == head) { // Try to advance along the wheel. We can skip any wheel slots which have no waiting Consumers, until // we either find a nonempty one, or we reached the end of the buffer. - while (consumerWheel(head & WheelMask).isEmpty && head != tail) { + while (isConsumerWheelSlotEmpty(head & WheelMask) && head != tail) { queue(head & Mask) = null head += 1 unblocked = true @@ -706,18 +715,35 @@ private[pekko] class BroadcastHub[T](startAfterNrOfConsumers: Int, bufferSize: I unblocked } + private def isConsumerWheelSlotEmpty(slot: Int): Boolean = { + val bucket = consumerWheel(slot) + (bucket eq null) || bucket.isEmpty + } + private def addConsumer(consumer: Consumer, offset: Int): Unit = { val slot = offset & WheelMask - consumerWheel(slot).add(consumer) + val bucket = consumerWheel(slot) + if (bucket ne null) bucket.update(consumer.id, consumer) + else { + val newBucket = LongMap.empty[Consumer] + newBucket.update(consumer.id, consumer) + consumerWheel(slot) = newBucket + } } /* * Send a wakeup signal to all the Consumers at a certain wheel index. Note, this needs the actual index, * which is offset modulo (bufferSize + 1). + * + * Enumeration order of the bucket is not semantically significant — every consumer receives the same + * wakeup signal independently. */ private def wakeupIdx(idx: Int): Unit = { - val itr = consumerWheel(idx).iterator - while (itr.hasNext) itr.next().callback.invoke(Wakeup) + val bucket = consumerWheel(idx) + if (bucket ne null) { + val itr = bucket.valuesIterator + while (itr.hasNext) itr.next().callback.invoke(Wakeup) + } } private def complete(): Unit = {