From a01a2fccf3a4bdebeba6c311301d35d6665d69d2 Mon Sep 17 00:00:00 2001 From: Stephen DeRosa Date: Wed, 6 May 2026 21:30:12 -0600 Subject: [PATCH] pimple implementation: baseline build with ~111s, after impl was ~86s. ~22% --- include/livekit/audio_stream.h | 35 +- include/livekit/data_track_stream.h | 40 +- include/livekit/room.h | 28 +- .../livekit/subscription_thread_dispatcher.h | 61 +-- include/livekit/video_stream.h | 32 +- src/audio_stream.cpp | 335 ++++++------- src/data_track_stream.cpp | 183 ++++--- src/room.cpp | 461 ++++++++++-------- src/subscription_thread_dispatcher.cpp | 274 +++++++---- .../test_subscription_thread_dispatcher.cpp | 152 +++--- src/video_stream.cpp | 328 ++++++------- 11 files changed, 965 insertions(+), 964 deletions(-) diff --git a/include/livekit/audio_stream.h b/include/livekit/audio_stream.h index d6dce270..2ff72c06 100644 --- a/include/livekit/audio_stream.h +++ b/include/livekit/audio_stream.h @@ -16,21 +16,17 @@ #pragma once -#include -#include -#include +#include #include -#include -#include #include #include "audio_frame.h" -#include "ffi_handle.h" -#include "participant.h" #include "track.h" namespace livekit { +class Participant; + namespace proto { class FfiEvent; } @@ -119,34 +115,15 @@ class AudioStream { void close(); private: - AudioStream() = default; + AudioStream(); void initFromTrack(const std::shared_ptr &track, const Options &options); void initFromParticipant(Participant &participant, TrackSource track_source, const Options &options); - // FFI event handler (registered with FfiClient) - void onFfiEvent(const proto::FfiEvent &event); - - // Queue helpers - void pushFrame(AudioFrameEvent &&ev); - void pushEos(); - - mutable std::mutex mutex_; - std::condition_variable cv_; - std::deque queue_; - std::size_t capacity_{0}; - bool eof_{false}; - bool closed_{false}; - - Options options_; - - // Underlying FFI audio stream handle - FfiHandle stream_handle_; - - // Listener id registered on FfiClient - std::int32_t listener_id_{0}; + struct Impl; + std::unique_ptr impl_; }; } // namespace livekit diff --git a/include/livekit/data_track_stream.h b/include/livekit/data_track_stream.h index cb81054e..b1593ec2 100644 --- a/include/livekit/data_track_stream.h +++ b/include/livekit/data_track_stream.h @@ -19,11 +19,8 @@ #include "livekit/data_track_frame.h" #include "livekit/ffi_handle.h" -#include #include -#include #include -#include #include namespace livekit { @@ -90,43 +87,12 @@ class DataTrackStream { private: friend class RemoteDataTrack; - DataTrackStream() = default; + DataTrackStream(); /// Internal init helper, called by RemoteDataTrack. void init(FfiHandle subscription_handle); - /// FFI event handler, called by FfiClient. - void onFfiEvent(const proto::FfiEvent &event); - - /// Push a received DataTrackFrame to the internal storage. - void pushFrame(DataTrackFrame &&frame); - - /// Push an end-of-stream signal (EOS). - void pushEos(); - - /** Protects all mutable state below. */ - mutable std::mutex mutex_; - - /** Signalled when a frame is pushed or the subscription ends. */ - std::condition_variable cv_; - - /** - * Received frame awaiting read(). - * NOTE: the Rust side handles buffering, so we should only really ever have - * one item. - */ - std::optional frame_; - - /** True once the remote side signals end-of-stream. */ - bool eof_{false}; - - /** True after close() has been called by the consumer. */ - bool closed_{false}; - - /** RAII handle for the Rust-owned subscription resource. */ - FfiHandle subscription_handle_; - - /** FfiClient listener id for routing FfiEvent callbacks to this object. */ - std::int32_t listener_id_{0}; + struct Impl; + std::unique_ptr impl_; }; } // namespace livekit diff --git a/include/livekit/room.h b/include/livekit/room.h index e65d7d1f..ba79cf78 100644 --- a/include/livekit/room.h +++ b/include/livekit/room.h @@ -19,13 +19,14 @@ #include "livekit/data_stream.h" #include "livekit/e2ee.h" -#include "livekit/ffi_handle.h" #include "livekit/room_event_types.h" #include "livekit/subscription_thread_dispatcher.h" #include #include -#include +#include +#include +#include namespace livekit { @@ -318,27 +319,8 @@ class Room { private: friend class RoomCallbackTest; - mutable std::mutex lock_; - ConnectionState connection_state_ = ConnectionState::Disconnected; - RoomDelegate *delegate_ = nullptr; // Not owned - RoomInfoData room_info_; - std::shared_ptr room_handle_; - std::unique_ptr local_participant_; - std::unordered_map> - remote_participants_; - // Data stream - std::unordered_map text_stream_handlers_; - std::unordered_map byte_stream_handlers_; - std::unordered_map> - text_stream_readers_; - std::unordered_map> - byte_stream_readers_; - // E2EE - std::unique_ptr e2ee_manager_; - std::shared_ptr subscription_thread_dispatcher_; - - // FfiClient listener ID (0 means no listener registered) - int listener_id_{0}; + struct Impl; + std::unique_ptr impl_; void OnEvent(const proto::FfiEvent &event); }; diff --git a/include/livekit/subscription_thread_dispatcher.h b/include/livekit/subscription_thread_dispatcher.h index 8e5fa65c..fa796a44 100644 --- a/include/livekit/subscription_thread_dispatcher.h +++ b/include/livekit/subscription_thread_dispatcher.h @@ -20,14 +20,13 @@ #include "livekit/audio_stream.h" #include "livekit/video_stream.h" +#include #include #include #include -#include #include #include #include -#include #include namespace livekit { @@ -357,13 +356,6 @@ class SubscriptionThreadDispatcher { } }; - /// Active read-side resources for one audio/video subscription dispatch slot. - struct ActiveReader { - std::shared_ptr audio_stream; - std::shared_ptr video_stream; - std::thread thread; - }; - /// Compound lookup key for a remote participant identity and data track name. struct DataCallbackKey { std::string participant_identity; @@ -390,14 +382,6 @@ class SubscriptionThreadDispatcher { DataFrameCallback callback; }; - /// Active read-side resources for one data track stream subscription. - struct ActiveDataReader { - std::shared_ptr remote_track; - std::mutex sub_mutex; - std::shared_ptr stream; // guarded by sub_mutex - std::thread thread; - }; - /// Stored audio callback registration plus stream-construction options. struct RegisteredAudioCallback { AudioFrameCallback callback; @@ -455,39 +439,20 @@ class SubscriptionThreadDispatcher { const std::shared_ptr &track, const DataFrameCallback &cb); - /// Protects callback registration maps and active reader state. - mutable std::mutex lock_; - - /// Registered audio frame callbacks keyed by \ref CallbackKey. - std::unordered_map - audio_callbacks_; - - /// Registered video frame callbacks keyed by \ref CallbackKey. - std::unordered_map - video_callbacks_; - - /// Active stream/thread state keyed by \ref CallbackKey. - std::unordered_map - active_readers_; - - /// Next auto-increment ID for data frame callbacks. - DataFrameCallbackId next_data_callback_id_{0}; - - /// Registered data frame callbacks keyed by opaque callback ID. - std::unordered_map - data_callbacks_; - - /// Active data reader threads keyed by callback ID. - std::unordered_map> - active_data_readers_; - - /// Currently published remote data tracks, keyed by (participant, name). - std::unordered_map, - DataCallbackKeyHash> - remote_data_tracks_; - /// Hard limit on concurrently active per-subscription reader threads. static constexpr int kMaxActiveReaders = 20; + + std::size_t audioCallbackCountForTest() const; + std::size_t videoCallbackCountForTest() const; + std::size_t activeReaderCountForTest() const; + std::size_t dataCallbackCountForTest() const; + std::size_t activeDataReaderCountForTest() const; + std::size_t remoteDataTrackCountForTest() const; + bool hasAudioCallbackForTest(const CallbackKey &key) const; + bool hasVideoCallbackForTest(const CallbackKey &key) const; + + struct Impl; + std::unique_ptr impl_; }; } // namespace livekit diff --git a/include/livekit/video_stream.h b/include/livekit/video_stream.h index 850b5038..d0e395f0 100644 --- a/include/livekit/video_stream.h +++ b/include/livekit/video_stream.h @@ -16,22 +16,19 @@ #pragma once -#include +#include #include -#include -#include #include -#include #include -#include "ffi_handle.h" -#include "participant.h" #include "track.h" #include "video_frame.h" #include "video_source.h" namespace livekit { +class Participant; + // A single video frame event delivered by VideoStream::read(). struct VideoFrameEvent { VideoFrame frame; @@ -110,7 +107,7 @@ class VideoStream { void close(); private: - VideoStream() = default; + VideoStream(); // Internal init helpers, used by the factories void initFromTrack(const std::shared_ptr &track, @@ -118,25 +115,8 @@ class VideoStream { void initFromParticipant(Participant &participant, TrackSource source, const Options &options); - // FFI event handler (registered with FfiClient) - void onFfiEvent(const proto::FfiEvent &event); - - // Queue helpers - void pushFrame(VideoFrameEvent &&ev); - void pushEos(); - - mutable std::mutex mutex_; - std::condition_variable cv_; - std::deque queue_; - std::size_t capacity_{0}; - bool eof_{false}; - bool closed_{false}; - - // Underlying FFI handle for the video stream - FfiHandle stream_handle_; - - // Listener id registered on FfiClient - std::int32_t listener_id_{0}; + struct Impl; + std::unique_ptr impl_; }; } // namespace livekit diff --git a/src/audio_stream.cpp b/src/audio_stream.cpp index 07b1e693..84d1f2d4 100644 --- a/src/audio_stream.cpp +++ b/src/audio_stream.cpp @@ -16,11 +16,17 @@ #include "livekit/audio_stream.h" +#include +#include +#include +#include #include #include "audio_frame.pb.h" #include "ffi.pb.h" #include "ffi_client.h" +#include "livekit/ffi_handle.h" +#include "livekit/participant.h" #include "livekit/track.h" namespace livekit { @@ -28,6 +34,150 @@ namespace livekit { using proto::FfiEvent; using proto::FfiRequest; +struct AudioStream::Impl { + ~Impl() { close(); } + + bool read(AudioFrameEvent &out_event) { + std::unique_lock lock(mutex_); + cv_.wait(lock, [this] { return !queue_.empty() || eof_ || closed_; }); + + if (closed_ || (queue_.empty() && eof_)) { + return false; + } + + out_event = std::move(queue_.front()); + queue_.pop_front(); + return true; + } + + void close() { + FfiHandle stream_handle; + std::int32_t listener_id = 0; + { + const std::scoped_lock lock(mutex_); + if (closed_) { + return; + } + closed_ = true; + stream_handle = std::move(stream_handle_); + listener_id = listener_id_; + listener_id_ = 0; + } + + if (stream_handle.get() != 0) { + stream_handle.reset(); + } + if (listener_id != 0) { + FfiClient::instance().RemoveListener(listener_id); + } + + cv_.notify_all(); + } + + void initFromTrack(const std::shared_ptr &track, + const Options &options) { + capacity_ = options.capacity; + options_ = options; + + listener_id_ = FfiClient::instance().AddListener( + [this](const FfiEvent &e) { this->onFfiEvent(e); }); + + FfiRequest req; + auto *new_audio_stream = req.mutable_new_audio_stream(); + new_audio_stream->set_track_handle( + static_cast(track->ffi_handle_id())); + new_audio_stream->set_type(proto::AudioStreamType::AUDIO_STREAM_NATIVE); + + if (!options_.noise_cancellation_module.empty()) { + new_audio_stream->set_audio_filter_module_id( + options_.noise_cancellation_module); + new_audio_stream->set_audio_filter_options( + options_.noise_cancellation_options_json); + } + + auto resp = FfiClient::instance().sendRequest(req); + const auto &stream = resp.new_audio_stream().stream(); + stream_handle_ = FfiHandle(static_cast(stream.handle().id())); + } + + void initFromParticipant(Participant &participant, TrackSource track_source, + const Options &options) { + capacity_ = options.capacity; + options_ = options; + + listener_id_ = FfiClient::instance().AddListener( + [this](const FfiEvent &e) { this->onFfiEvent(e); }); + + FfiRequest req; + auto *as = req.mutable_audio_stream_from_participant(); + as->set_participant_handle(participant.ffiHandleId()); + as->set_type(proto::AudioStreamType::AUDIO_STREAM_NATIVE); + as->set_track_source(static_cast(track_source)); + + if (!options_.noise_cancellation_module.empty()) { + as->set_audio_filter_module_id(options_.noise_cancellation_module); + as->set_audio_filter_options(options_.noise_cancellation_options_json); + } + + auto resp = FfiClient::instance().sendRequest(req); + const auto &stream = resp.audio_stream_from_participant().stream(); + stream_handle_ = FfiHandle(static_cast(stream.handle().id())); + } + + void onFfiEvent(const FfiEvent &event) { + if (event.message_case() != FfiEvent::kAudioStreamEvent) { + return; + } + const auto &ase = event.audio_stream_event(); + if (ase.stream_handle() != + static_cast(stream_handle_.get())) { + return; + } + if (ase.has_frame_received()) { + const auto &fr = ase.frame_received(); + AudioFrameEvent ev{AudioFrame::fromOwnedInfo(fr.frame())}; + pushFrame(std::move(ev)); + } else if (ase.has_eos()) { + pushEos(); + } + } + + void pushFrame(AudioFrameEvent &&ev) { + { + const std::scoped_lock lock(mutex_); + if (closed_ || eof_) { + return; + } + if (capacity_ > 0 && queue_.size() >= capacity_) { + queue_.pop_front(); + } + queue_.push_back(std::move(ev)); + } + cv_.notify_one(); + } + + void pushEos() { + { + const std::scoped_lock lock(mutex_); + if (eof_) { + return; + } + eof_ = true; + } + cv_.notify_all(); + } + + mutable std::mutex mutex_; + std::condition_variable cv_; + std::deque queue_; + std::size_t capacity_{0}; + bool eof_{false}; + bool closed_{false}; + Options options_; + FfiHandle stream_handle_; + std::int32_t listener_id_{0}; +}; + // ------------------------ // Factory helpers // ------------------------ @@ -52,200 +202,33 @@ AudioStream::fromParticipant(Participant &participant, TrackSource track_source, // Destructor / move // ------------------------ -AudioStream::~AudioStream() { close(); } - -AudioStream::AudioStream(AudioStream &&other) noexcept { - const std::scoped_lock lock(other.mutex_); - queue_ = std::move(other.queue_); - capacity_ = other.capacity_; - eof_ = other.eof_; - closed_ = other.closed_; - options_ = other.options_; - stream_handle_ = std::move(other.stream_handle_); - listener_id_ = other.listener_id_; - - other.listener_id_ = 0; - other.closed_ = true; -} - -AudioStream &AudioStream::operator=(AudioStream &&other) noexcept { - if (this == &other) { - return *this; - } - - close(); +AudioStream::AudioStream() : impl_(std::make_unique()) {} - { - const std::scoped_lock lock_this(mutex_); - const std::scoped_lock lock_other(other.mutex_); +AudioStream::~AudioStream() = default; - queue_ = std::move(other.queue_); - capacity_ = other.capacity_; - eof_ = other.eof_; - closed_ = other.closed_; - options_ = other.options_; - stream_handle_ = std::move(other.stream_handle_); - listener_id_ = other.listener_id_; +AudioStream::AudioStream(AudioStream &&other) noexcept = default; - other.listener_id_ = 0; - other.closed_ = true; - } - - return *this; -} +AudioStream &AudioStream::operator=(AudioStream &&other) noexcept = default; bool AudioStream::read(AudioFrameEvent &out_event) { - std::unique_lock lock(mutex_); - - cv_.wait(lock, [this] { return !queue_.empty() || eof_ || closed_; }); - - if (closed_ || (queue_.empty() && eof_)) { - return false; // EOS / closed - } - - out_event = std::move(queue_.front()); - queue_.pop_front(); - return true; + return impl_ ? impl_->read(out_event) : false; } void AudioStream::close() { - { - const std::scoped_lock lock(mutex_); - if (closed_) { - return; - } - closed_ = true; - } - - // Dispose FFI handle - if (stream_handle_.get() != 0) { - stream_handle_.reset(); - } - - // Remove listener - if (listener_id_ != 0) { - FfiClient::instance().RemoveListener(listener_id_); - listener_id_ = 0; + if (impl_) { + impl_->close(); } - - // Wake any waiting readers - cv_.notify_all(); } -// Internal functions - void AudioStream::initFromTrack(const std::shared_ptr &track, const Options &options) { - capacity_ = options.capacity; - options_ = options; - - // 1) Subscribe to FFI events - listener_id_ = FfiClient::instance().AddListener( - [this](const FfiEvent &e) { this->onFfiEvent(e); }); - - // 2) Send FfiRequest to create a new audio stream bound to this track - FfiRequest req; - auto *new_audio_stream = req.mutable_new_audio_stream(); - new_audio_stream->set_track_handle( - static_cast(track->ffi_handle_id())); - // TODO, sample_rate and num_channels are not useful in AudioStream, remove it - // from FFI. - // new_audio_stream->set_sample_rate(options_.sample_rate); - // new_audio_stream->set_num_channels(options.num_channels); - new_audio_stream->set_type(proto::AudioStreamType::AUDIO_STREAM_NATIVE); - - if (!options_.noise_cancellation_module.empty()) { - new_audio_stream->set_audio_filter_module_id( - options_.noise_cancellation_module); - // Always set options JSON even if empty - backend will treat empty string - // as "no options" - new_audio_stream->set_audio_filter_options( - options_.noise_cancellation_options_json); - } - - auto resp = FfiClient::instance().sendRequest(req); - const auto &stream = resp.new_audio_stream().stream(); - stream_handle_ = FfiHandle(static_cast(stream.handle().id())); + impl_->initFromTrack(track, options); } void AudioStream::initFromParticipant(Participant &participant, TrackSource track_source, const Options &options) { - capacity_ = options.capacity; - options_ = options; - - // 1) Subscribe to FFI events - listener_id_ = FfiClient::instance().AddListener( - [this](const FfiEvent &e) { this->onFfiEvent(e); }); - - // 2) Send FfiRequest to create audio stream from participant + track source - FfiRequest req; - auto *as = req.mutable_audio_stream_from_participant(); - as->set_participant_handle(participant.ffiHandleId()); - // TODO, sample_rate and num_channels are not useful in AudioStream, remove it - // from FFI. - // as->set_sample_rate(options_.sample_rate); - // as->set_num_channels(options_.num_channels); - as->set_type(proto::AudioStreamType::AUDIO_STREAM_NATIVE); - as->set_track_source(static_cast(track_source)); - - if (!options_.noise_cancellation_module.empty()) { - as->set_audio_filter_module_id(options_.noise_cancellation_module); - // Always set options JSON even if empty — backend will treat empty string - // as "no options" - as->set_audio_filter_options(options_.noise_cancellation_options_json); - } - - auto resp = FfiClient::instance().sendRequest(req); - const auto &stream = resp.audio_stream_from_participant().stream(); - stream_handle_ = FfiHandle(static_cast(stream.handle().id())); -} - -void AudioStream::onFfiEvent(const FfiEvent &event) { - if (event.message_case() != FfiEvent::kAudioStreamEvent) { - return; - } - const auto &ase = event.audio_stream_event(); - // Check if this event is for our stream handle. - if (ase.stream_handle() != static_cast(stream_handle_.get())) { - return; - } - if (ase.has_frame_received()) { - const auto &fr = ase.frame_received(); - AudioFrameEvent ev{AudioFrame::fromOwnedInfo(fr.frame())}; - pushFrame(std::move(ev)); - } else if (ase.has_eos()) { - pushEos(); - } -} - -void AudioStream::pushFrame(AudioFrameEvent &&ev) { - { - const std::scoped_lock lock(mutex_); - - if (closed_ || eof_) { - return; - } - - if (capacity_ > 0 && queue_.size() >= capacity_) { - // Ring behavior: drop oldest frame when full. - queue_.pop_front(); - } - - queue_.push_back(std::move(ev)); - } - cv_.notify_one(); -} - -void AudioStream::pushEos() { - { - const std::scoped_lock lock(mutex_); - if (eof_) { - return; - } - eof_ = true; - } - cv_.notify_all(); + impl_->initFromParticipant(participant, track_source, options); } } // namespace livekit diff --git a/src/data_track_stream.cpp b/src/data_track_stream.cpp index 30a355c8..6f426cb3 100644 --- a/src/data_track_stream.cpp +++ b/src/data_track_stream.cpp @@ -19,122 +19,149 @@ #include "data_track.pb.h" #include "ffi.pb.h" #include "ffi_client.h" -#include "lk_log.h" +#include +#include +#include #include namespace livekit { using proto::FfiEvent; -DataTrackStream::~DataTrackStream() { close(); } +struct DataTrackStream::Impl { + ~Impl() { close(); } -void DataTrackStream::init(FfiHandle subscription_handle) { - subscription_handle_ = std::move(subscription_handle); + void init(FfiHandle subscription_handle) { + subscription_handle_ = std::move(subscription_handle); - listener_id_ = FfiClient::instance().AddListener( - [this](const FfiEvent &e) { this->onFfiEvent(e); }); -} + listener_id_ = FfiClient::instance().AddListener( + [this](const FfiEvent &e) { this->onFfiEvent(e); }); + } -bool DataTrackStream::read(DataTrackFrame &out) { - { - const std::scoped_lock lock(mutex_); - if (closed_ || eof_) { - return false; - } + bool read(DataTrackFrame &out) { + { + const std::scoped_lock lock(mutex_); + if (closed_ || eof_) { + return false; + } - const auto subscription_handle = - static_cast(subscription_handle_.get()); + const auto subscription_handle = + static_cast(subscription_handle_.get()); - // Signal the Rust side that we're ready to receive the next frame. - // The Rust SubscriptionTask uses a demand-driven protocol: it won't pull - // from the underlying stream until notified via this request. - proto::FfiRequest req; - auto *msg = req.mutable_data_track_stream_read(); - msg->set_stream_handle(subscription_handle); - FfiClient::instance().sendRequest(req); - } + proto::FfiRequest req; + auto *msg = req.mutable_data_track_stream_read(); + msg->set_stream_handle(subscription_handle); + FfiClient::instance().sendRequest(req); + } - std::unique_lock lock(mutex_); - cv_.wait(lock, [this] { return frame_.has_value() || eof_ || closed_; }); + std::unique_lock lock(mutex_); + cv_.wait(lock, [this] { return frame_.has_value() || eof_ || closed_; }); - if (closed_ || (!frame_.has_value() && eof_)) { - return false; + if (closed_ || (!frame_.has_value() && eof_)) { + return false; + } + + out = std::move(*frame_); // NOLINT(bugprone-unchecked-optional-access) + frame_.reset(); + return true; } - out = std::move(*frame_); // NOLINT(bugprone-unchecked-optional-access) - frame_.reset(); - return true; -} + void close() { + FfiHandle subscription_handle; + std::int32_t listener_id = 0; + { + const std::scoped_lock lock(mutex_); + if (closed_) { + return; + } + closed_ = true; + subscription_handle = std::move(subscription_handle_); + listener_id = listener_id_; + listener_id_ = 0; + } -void DataTrackStream::close() { - std::int32_t listener_id = -1; - { - const std::scoped_lock lock(mutex_); - if (closed_) { - return; + if (subscription_handle.get() != 0) { + subscription_handle.reset(); + } + if (listener_id != 0) { + FfiClient::instance().RemoveListener(listener_id); } - closed_ = true; - subscription_handle_.reset(); - listener_id = listener_id_; - listener_id_ = 0; - } - if (listener_id != -1) { - FfiClient::instance().RemoveListener(listener_id); + cv_.notify_all(); } - cv_.notify_all(); -} + void onFfiEvent(const FfiEvent &event) { + if (event.message_case() != FfiEvent::kDataTrackStreamEvent) { + return; + } + + const auto &dts = event.data_track_stream_event(); + { + const std::scoped_lock lock(mutex_); + if (closed_ || dts.stream_handle() != static_cast( + subscription_handle_.get())) { + return; + } + } -void DataTrackStream::onFfiEvent(const FfiEvent &event) { - if (event.message_case() != FfiEvent::kDataTrackStreamEvent) { - return; + if (dts.has_frame_received()) { + const auto &fr = dts.frame_received().frame(); + DataTrackFrame frame = DataTrackFrame::fromOwnedInfo(fr); + pushFrame(std::move(frame)); + } else if (dts.has_eos()) { + pushEos(); + } } - const auto &dts = event.data_track_stream_event(); - { + void pushFrame(DataTrackFrame &&frame) { const std::scoped_lock lock(mutex_); - if (closed_ || dts.stream_handle() != - static_cast(subscription_handle_.get())) { + + if (closed_ || eof_) { return; } + + assert(!frame_.has_value()); + frame_ = std::move(frame); + cv_.notify_one(); } - if (dts.has_frame_received()) { - const auto &fr = dts.frame_received().frame(); - DataTrackFrame frame = DataTrackFrame::fromOwnedInfo(fr); - pushFrame(std::move(frame)); - } else if (dts.has_eos()) { - pushEos(); + void pushEos() { + { + const std::scoped_lock lock(mutex_); + if (eof_) { + return; + } + eof_ = true; + } + cv_.notify_all(); } -} -void DataTrackStream::pushFrame(DataTrackFrame &&frame) { - const std::scoped_lock lock(mutex_); + mutable std::mutex mutex_; + std::condition_variable cv_; + std::optional frame_; + bool eof_{false}; + bool closed_{false}; + FfiHandle subscription_handle_; + std::int32_t listener_id_{0}; +}; - if (closed_ || eof_) { - return; - } +DataTrackStream::DataTrackStream() : impl_(std::make_unique()) {} - // rust side handles buffering, so we should only really ever have one item - assert(!frame_.has_value()); +DataTrackStream::~DataTrackStream() = default; - frame_ = std::move(frame); +void DataTrackStream::init(FfiHandle subscription_handle) { + impl_->init(std::move(subscription_handle)); +} - // notify no matter what since we got a new frame - cv_.notify_one(); +bool DataTrackStream::read(DataTrackFrame &out) { + return impl_ ? impl_->read(out) : false; } -void DataTrackStream::pushEos() { - { - const std::scoped_lock lock(mutex_); - if (eof_) { - return; - } - eof_ = true; +void DataTrackStream::close() { + if (impl_) { + impl_->close(); } - cv_.notify_all(); } } // namespace livekit diff --git a/src/room.cpp b/src/room.cpp index 9eae941d..8a33b126 100644 --- a/src/room.cpp +++ b/src/room.cpp @@ -18,9 +18,8 @@ #include "livekit/audio_stream.h" #include "livekit/e2ee.h" -#include "livekit/local_data_track.h" +#include "livekit/ffi_handle.h" #include "livekit/local_participant.h" -#include "livekit/local_track_publication.h" #include "livekit/remote_audio_track.h" #include "livekit/remote_data_track.h" #include "livekit/remote_participant.h" @@ -29,7 +28,6 @@ #include "livekit/room_delegate.h" #include "livekit/room_event_types.h" -#include "data_track.pb.h" #include "ffi.pb.h" #include "ffi_client.h" #include "livekit_ffi.h" @@ -40,6 +38,8 @@ #include "track.pb.h" #include "track_proto_converter.h" #include +#include +#include namespace livekit { @@ -68,23 +68,44 @@ createRemoteParticipant(const proto::OwnedParticipant &owned) { } } // namespace -Room::Room() - : subscription_thread_dispatcher_( - std::make_unique()) {} + +struct Room::Impl { + mutable std::mutex lock_; + ConnectionState connection_state_ = ConnectionState::Disconnected; + RoomDelegate *delegate_ = nullptr; // Not owned + RoomInfoData room_info_; + std::shared_ptr room_handle_; + std::unique_ptr local_participant_; + std::unordered_map> + remote_participants_; + std::unordered_map text_stream_handlers_; + std::unordered_map byte_stream_handlers_; + std::unordered_map> + text_stream_readers_; + std::unordered_map> + byte_stream_readers_; + std::unique_ptr e2ee_manager_; + std::shared_ptr + subscription_thread_dispatcher_ = + std::make_shared(); + int listener_id_{0}; +}; + +Room::Room() : impl_(std::make_unique()) {} Room::~Room() { - if (subscription_thread_dispatcher_) { - subscription_thread_dispatcher_->stopAll(); + if (impl_->subscription_thread_dispatcher_) { + impl_->subscription_thread_dispatcher_->stopAll(); } int listener_to_remove = 0; std::unique_ptr local_participant_to_cleanup; { - const std::scoped_lock g(lock_); - listener_to_remove = listener_id_; - listener_id_ = 0; + const std::scoped_lock g(impl_->lock_); + listener_to_remove = impl_->listener_id_; + impl_->listener_id_ = 0; // Move local participant out for cleanup outside the lock - local_participant_to_cleanup = std::move(local_participant_); + local_participant_to_cleanup = std::move(impl_->local_participant_); } // Shutdown local participant (unregisters RPC handlers, etc.) before @@ -102,8 +123,8 @@ Room::~Room() { } void Room::setDelegate(RoomDelegate *delegate) { - const std::scoped_lock g(lock_); - delegate_ = delegate; + const std::scoped_lock g(impl_->lock_); + impl_->delegate_ = delegate; } bool Room::Connect(const std::string &url, const std::string &token, @@ -111,11 +132,11 @@ bool Room::Connect(const std::string &url, const std::string &token, TRACE_EVENT0("livekit", "Room::Connect"); { - const std::scoped_lock g(lock_); - if (connection_state_ != ConnectionState::Disconnected) { + const std::scoped_lock g(impl_->lock_); + if (impl_->connection_state_ != ConnectionState::Disconnected) { throw std::runtime_error("already connected"); } - connection_state_ = ConnectionState::Reconnecting; + impl_->connection_state_ = ConnectionState::Reconnecting; } auto fut = FfiClient::instance().connectAsync(url, token, options); try { @@ -154,7 +175,7 @@ bool Room::Connect(const std::string &url, const std::string &token, new_remote_participants; { const auto &participants = connectCb.result().participants(); - const std::scoped_lock g(lock_); + const std::scoped_lock g(impl_->lock_); for (const auto &pt : participants) { const auto &owned = pt.participant(); auto rp = createRemoteParticipant(owned); @@ -180,69 +201,69 @@ bool Room::Connect(const std::string &url, const std::string &token, // Publish all state atomically under lock { - const std::scoped_lock g(lock_); - room_handle_ = std::move(new_room_handle); - room_info_ = std::move(new_room_info); - local_participant_ = std::move(new_local_participant); - remote_participants_ = std::move(new_remote_participants); - e2ee_manager_ = std::move(new_e2ee_manager); - connection_state_ = ConnectionState::Connected; + const std::scoped_lock g(impl_->lock_); + impl_->room_handle_ = std::move(new_room_handle); + impl_->room_info_ = std::move(new_room_info); + impl_->local_participant_ = std::move(new_local_participant); + impl_->remote_participants_ = std::move(new_remote_participants); + impl_->e2ee_manager_ = std::move(new_e2ee_manager); + impl_->connection_state_ = ConnectionState::Connected; } // Install listener (Room is fully initialized) auto listenerId = FfiClient::instance().AddListener( [this](const proto::FfiEvent &e) { OnEvent(e); }); { - const std::scoped_lock g(lock_); - listener_id_ = listenerId; + const std::scoped_lock g(impl_->lock_); + impl_->listener_id_ = listenerId; } return true; } catch (const std::exception &e) { - // On error, set the connection_state_ to Disconnected - connection_state_ = ConnectionState::Disconnected; + // On error, set the connection state to Disconnected. + impl_->connection_state_ = ConnectionState::Disconnected; LK_LOG_ERROR("Room::Connect failed: {}", e.what()); return false; } } RoomInfoData Room::room_info() const { - const std::scoped_lock g(lock_); - return room_info_; + const std::scoped_lock g(impl_->lock_); + return impl_->room_info_; } LocalParticipant *Room::localParticipant() const { - const std::scoped_lock g(lock_); - return local_participant_.get(); + const std::scoped_lock g(impl_->lock_); + return impl_->local_participant_.get(); } RemoteParticipant *Room::remoteParticipant(const std::string &identity) const { - const std::scoped_lock g(lock_); - auto it = remote_participants_.find(identity); - return it == remote_participants_.end() ? nullptr : it->second.get(); + const std::scoped_lock g(impl_->lock_); + auto it = impl_->remote_participants_.find(identity); + return it == impl_->remote_participants_.end() ? nullptr : it->second.get(); } std::vector> Room::remoteParticipants() const { - const std::scoped_lock guard(lock_); + const std::scoped_lock guard(impl_->lock_); std::vector> out; - out.reserve(remote_participants_.size()); - for (const auto &kv : remote_participants_) { + out.reserve(impl_->remote_participants_.size()); + for (const auto &kv : impl_->remote_participants_) { out.push_back(kv.second); } return out; } E2EEManager *Room::e2eeManager() const { - const std::scoped_lock g(lock_); - return e2ee_manager_.get(); + const std::scoped_lock g(impl_->lock_); + return impl_->e2ee_manager_.get(); } void Room::registerTextStreamHandler(const std::string &topic, TextStreamHandler handler) { - const std::scoped_lock g(lock_); + const std::scoped_lock g(impl_->lock_); auto [it, inserted] = - text_stream_handlers_.emplace(topic, std::move(handler)); + impl_->text_stream_handlers_.emplace(topic, std::move(handler)); if (!inserted) { throw std::runtime_error("text stream handler for topic '" + topic + "' already set"); @@ -250,15 +271,15 @@ void Room::registerTextStreamHandler(const std::string &topic, } void Room::unregisterTextStreamHandler(const std::string &topic) { - const std::scoped_lock g(lock_); - text_stream_handlers_.erase(topic); + const std::scoped_lock g(impl_->lock_); + impl_->text_stream_handlers_.erase(topic); } void Room::registerByteStreamHandler(const std::string &topic, ByteStreamHandler handler) { - const std::scoped_lock g(lock_); + const std::scoped_lock g(impl_->lock_); auto [it, inserted] = - byte_stream_handlers_.emplace(topic, std::move(handler)); + impl_->byte_stream_handlers_.emplace(topic, std::move(handler)); if (!inserted) { throw std::runtime_error("byte stream handler for topic '" + topic + "' already set"); @@ -266,8 +287,8 @@ void Room::registerByteStreamHandler(const std::string &topic, } void Room::unregisterByteStreamHandler(const std::string &topic) { - const std::scoped_lock g(lock_); - byte_stream_handlers_.erase(topic); + const std::scoped_lock g(impl_->lock_); + impl_->byte_stream_handlers_.erase(topic); } // ------------------------------------------------------------------- @@ -278,8 +299,8 @@ void Room::setOnAudioFrameCallback(const std::string &participant_identity, TrackSource source, AudioFrameCallback callback, const AudioStream::Options &opts) { - if (subscription_thread_dispatcher_) { - subscription_thread_dispatcher_->setOnAudioFrameCallback( + if (impl_->subscription_thread_dispatcher_) { + impl_->subscription_thread_dispatcher_->setOnAudioFrameCallback( participant_identity, source, std::move(callback), opts); } } @@ -288,8 +309,8 @@ void Room::setOnAudioFrameCallback(const std::string &participant_identity, const std::string &track_name, AudioFrameCallback callback, const AudioStream::Options &opts) { - if (subscription_thread_dispatcher_) { - subscription_thread_dispatcher_->setOnAudioFrameCallback( + if (impl_->subscription_thread_dispatcher_) { + impl_->subscription_thread_dispatcher_->setOnAudioFrameCallback( participant_identity, track_name, std::move(callback), opts); } } @@ -298,8 +319,8 @@ void Room::setOnVideoFrameCallback(const std::string &participant_identity, TrackSource source, VideoFrameCallback callback, const VideoStream::Options &opts) { - if (subscription_thread_dispatcher_) { - subscription_thread_dispatcher_->setOnVideoFrameCallback( + if (impl_->subscription_thread_dispatcher_) { + impl_->subscription_thread_dispatcher_->setOnVideoFrameCallback( participant_identity, source, std::move(callback), opts); } } @@ -308,8 +329,8 @@ void Room::setOnVideoFrameCallback(const std::string &participant_identity, const std::string &track_name, VideoFrameCallback callback, const VideoStream::Options &opts) { - if (subscription_thread_dispatcher_) { - subscription_thread_dispatcher_->setOnVideoFrameCallback( + if (impl_->subscription_thread_dispatcher_) { + impl_->subscription_thread_dispatcher_->setOnVideoFrameCallback( participant_identity, track_name, std::move(callback), opts); } } @@ -318,40 +339,40 @@ void Room::setOnVideoFrameEventCallback(const std::string &participant_identity, const std::string &track_name, VideoFrameEventCallback callback, const VideoStream::Options &opts) { - if (subscription_thread_dispatcher_) { - subscription_thread_dispatcher_->setOnVideoFrameEventCallback( + if (impl_->subscription_thread_dispatcher_) { + impl_->subscription_thread_dispatcher_->setOnVideoFrameEventCallback( participant_identity, track_name, std::move(callback), opts); } } void Room::clearOnAudioFrameCallback(const std::string &participant_identity, TrackSource source) { - if (subscription_thread_dispatcher_) { - subscription_thread_dispatcher_->clearOnAudioFrameCallback( + if (impl_->subscription_thread_dispatcher_) { + impl_->subscription_thread_dispatcher_->clearOnAudioFrameCallback( participant_identity, source); } } void Room::clearOnAudioFrameCallback(const std::string &participant_identity, const std::string &track_name) { - if (subscription_thread_dispatcher_) { - subscription_thread_dispatcher_->clearOnAudioFrameCallback( + if (impl_->subscription_thread_dispatcher_) { + impl_->subscription_thread_dispatcher_->clearOnAudioFrameCallback( participant_identity, track_name); } } void Room::clearOnVideoFrameCallback(const std::string &participant_identity, TrackSource source) { - if (subscription_thread_dispatcher_) { - subscription_thread_dispatcher_->clearOnVideoFrameCallback( + if (impl_->subscription_thread_dispatcher_) { + impl_->subscription_thread_dispatcher_->clearOnVideoFrameCallback( participant_identity, source); } } void Room::clearOnVideoFrameCallback(const std::string &participant_identity, const std::string &track_name) { - if (subscription_thread_dispatcher_) { - subscription_thread_dispatcher_->clearOnVideoFrameCallback( + if (impl_->subscription_thread_dispatcher_) { + impl_->subscription_thread_dispatcher_->clearOnVideoFrameCallback( participant_identity, track_name); } } @@ -360,16 +381,16 @@ DataFrameCallbackId Room::addOnDataFrameCallback(const std::string &participant_identity, const std::string &track_name, DataFrameCallback callback) { - if (subscription_thread_dispatcher_) { - return subscription_thread_dispatcher_->addOnDataFrameCallback( + if (impl_->subscription_thread_dispatcher_) { + return impl_->subscription_thread_dispatcher_->addOnDataFrameCallback( participant_identity, track_name, std::move(callback)); } return std::numeric_limits::max(); } void Room::removeOnDataFrameCallback(DataFrameCallbackId id) { - if (subscription_thread_dispatcher_) { - subscription_thread_dispatcher_->removeOnDataFrameCallback(id); + if (impl_->subscription_thread_dispatcher_) { + impl_->subscription_thread_dispatcher_->removeOnDataFrameCallback(id); } } @@ -378,8 +399,8 @@ void Room::OnEvent(const FfiEvent &event) { // lock. RoomDelegate *delegate_snapshot = nullptr; { - const std::scoped_lock guard(lock_); - delegate_snapshot = delegate_; + const std::scoped_lock guard(impl_->lock_); + delegate_snapshot = impl_->delegate_; } // First, handle RPC method invocations (not part of RoomEvent). @@ -388,18 +409,18 @@ void Room::OnEvent(const FfiEvent &event) { LocalParticipant *lp = nullptr; { - const std::scoped_lock guard(lock_); - if (!local_participant_) { + const std::scoped_lock guard(impl_->lock_); + if (!impl_->local_participant_) { return; } - auto local_handle = local_participant_->ffiHandleId(); + auto local_handle = impl_->local_participant_->ffiHandleId(); if (local_handle == INVALID_HANDLE || rpc.local_participant_handle() != static_cast(local_handle)) { // RPC is not targeted at this room's local participant; ignore. return; } - lp = local_participant_.get(); + lp = impl_->local_participant_.get(); } // Call outside the lock to avoid deadlocks / re-entrancy issues. @@ -417,9 +438,10 @@ void Room::OnEvent(const FfiEvent &event) { // Check if this event is for our room handle { - const std::scoped_lock guard(lock_); - if (!room_handle_ || - re.room_handle() != static_cast(room_handle_->get())) { + const std::scoped_lock guard(impl_->lock_); + if (!impl_->room_handle_ || + re.room_handle() != + static_cast(impl_->room_handle_->get())) { return; } } @@ -428,12 +450,12 @@ void Room::OnEvent(const FfiEvent &event) { case proto::RoomEvent::kParticipantConnected: { std::shared_ptr new_participant; { - const std::scoped_lock guard(lock_); + const std::scoped_lock guard(impl_->lock_); const auto &owned = re.participant_connected().info(); // createRemoteParticipant takes proto::OwnedParticipant new_participant = createRemoteParticipant(owned); - remote_participants_.emplace(new_participant->identity(), - new_participant); + impl_->remote_participants_.emplace(new_participant->identity(), + new_participant); } ParticipantConnectedEvent ev; ev.participant = new_participant.get(); @@ -446,15 +468,15 @@ void Room::OnEvent(const FfiEvent &event) { std::shared_ptr removed; DisconnectReason reason = DisconnectReason::Unknown; { - const std::scoped_lock guard(lock_); + const std::scoped_lock guard(impl_->lock_); const auto &pd = re.participant_disconnected(); const std::string &identity = pd.participant_identity(); reason = toDisconnectReason(pd.disconnect_reason()); - auto it = remote_participants_.find(identity); - if (it != remote_participants_.end()) { + auto it = impl_->remote_participants_.find(identity); + if (it != impl_->remote_participants_.end()) { removed = it->second; - remote_participants_.erase(it); + impl_->remote_participants_.erase(it); } else { // We saw a disconnect event for a participant we don't track // internally. This can happen on races or if we never created a @@ -476,14 +498,14 @@ void Room::OnEvent(const FfiEvent &event) { case proto::RoomEvent::kLocalTrackPublished: { LocalTrackPublishedEvent ev; { - const std::scoped_lock guard(lock_); - if (!local_participant_) { + const std::scoped_lock guard(impl_->lock_); + if (!impl_->local_participant_) { LK_LOG_ERROR("kLocalTrackPublished: local_participant_ is nullptr"); break; } const auto <p = re.local_track_published(); const std::string &sid = ltp.track_sid(); - const auto pubs = local_participant_->trackPublications(); + const auto pubs = impl_->local_participant_->trackPublications(); auto it = pubs.find(sid); if (it == pubs.end()) { LK_LOG_WARN("local_track_published for unknown sid: {}", sid); @@ -500,14 +522,14 @@ void Room::OnEvent(const FfiEvent &event) { case proto::RoomEvent::kLocalTrackUnpublished: { LocalTrackUnpublishedEvent ev; { - const std::scoped_lock guard(lock_); - if (!local_participant_) { + const std::scoped_lock guard(impl_->lock_); + if (!impl_->local_participant_) { LK_LOG_ERROR("kLocalTrackUnpublished: local_participant_ is nullptr"); break; } const auto <u = re.local_track_unpublished(); const std::string &pub_sid = ltu.publication_sid(); - const auto pubs = local_participant_->trackPublications(); + const auto pubs = impl_->local_participant_->trackPublications(); auto it = pubs.find(pub_sid); if (it == pubs.end()) { LK_LOG_WARN("local_track_unpublished for unknown publication sid: {}", @@ -524,13 +546,13 @@ void Room::OnEvent(const FfiEvent &event) { case proto::RoomEvent::kLocalTrackSubscribed: { LocalTrackSubscribedEvent ev; { - const std::scoped_lock guard(lock_); - if (!local_participant_) { + const std::scoped_lock guard(impl_->lock_); + if (!impl_->local_participant_) { break; } const auto <s = re.local_track_subscribed(); const std::string &sid = lts.track_sid(); - const auto pubs = local_participant_->trackPublications(); + const auto pubs = impl_->local_participant_->trackPublications(); auto it = pubs.find(sid); if (it == pubs.end()) { LK_LOG_WARN("local_track_subscribed for unknown sid: {}", sid); @@ -548,11 +570,11 @@ void Room::OnEvent(const FfiEvent &event) { case proto::RoomEvent::kTrackPublished: { TrackPublishedEvent ev; { - const std::scoped_lock guard(lock_); + const std::scoped_lock guard(impl_->lock_); const auto &tp = re.track_published(); const std::string &identity = tp.participant_identity(); - auto it = remote_participants_.find(identity); - if (it != remote_participants_.end()) { + auto it = impl_->remote_participants_.find(identity); + if (it != impl_->remote_participants_.end()) { RemoteParticipant *rparticipant = it->second.get(); const auto &owned_publication = tp.publication(); auto rpublication = @@ -577,12 +599,12 @@ void Room::OnEvent(const FfiEvent &event) { case proto::RoomEvent::kTrackUnpublished: { TrackUnpublishedEvent ev; { - const std::scoped_lock guard(lock_); + const std::scoped_lock guard(impl_->lock_); const auto &tu = re.track_unpublished(); const std::string &identity = tu.participant_identity(); const std::string &pub_sid = tu.publication_sid(); - auto pit = remote_participants_.find(identity); - if (pit == remote_participants_.end()) { + auto pit = impl_->remote_participants_.find(identity); + if (pit == impl_->remote_participants_.end()) { LK_LOG_WARN("track_unpublished for unknown participant: {}", identity); break; @@ -615,10 +637,10 @@ void Room::OnEvent(const FfiEvent &event) { RemoteParticipant *rparticipant = nullptr; std::shared_ptr remote_track; { - const std::scoped_lock guard(lock_); + const std::scoped_lock guard(impl_->lock_); // Find participant - auto pit = remote_participants_.find(identity); - if (pit == remote_participants_.end()) { + auto pit = impl_->remote_participants_.find(identity); + if (pit == impl_->remote_participants_.end()) { LK_LOG_WARN("track_subscribed for unknown participant: {}", identity); break; } @@ -658,8 +680,9 @@ void Room::OnEvent(const FfiEvent &event) { delegate_snapshot->onTrackSubscribed(*this, ev); } - if (subscription_thread_dispatcher_ && remote_track && rpublication) { - subscription_thread_dispatcher_->handleTrackSubscribed( + if (impl_->subscription_thread_dispatcher_ && remote_track && + rpublication) { + impl_->subscription_thread_dispatcher_->handleTrackSubscribed( identity, rpublication->source(), rpublication->name(), remote_track); } @@ -670,12 +693,12 @@ void Room::OnEvent(const FfiEvent &event) { TrackSource unsub_source = TrackSource::SOURCE_UNKNOWN; std::string unsub_identity; { - const std::scoped_lock guard(lock_); + const std::scoped_lock guard(impl_->lock_); const auto &tu = re.track_unsubscribed(); unsub_identity = tu.participant_identity(); const std::string &track_sid = tu.track_sid(); - auto pit = remote_participants_.find(unsub_identity); - if (pit == remote_participants_.end()) { + auto pit = impl_->remote_participants_.find(unsub_identity); + if (pit == impl_->remote_participants_.end()) { LK_LOG_WARN("track_unsubscribed for unknown participant: {}", unsub_identity); break; @@ -703,9 +726,9 @@ void Room::OnEvent(const FfiEvent &event) { delegate_snapshot->onTrackUnsubscribed(*this, ev); } - if (subscription_thread_dispatcher_ && + if (impl_->subscription_thread_dispatcher_ && unsub_source != TrackSource::SOURCE_UNKNOWN) { - subscription_thread_dispatcher_->handleTrackUnsubscribed( + impl_->subscription_thread_dispatcher_->handleTrackUnsubscribed( unsub_identity, unsub_source, ev.publication ? ev.publication->name() : ""); } @@ -714,11 +737,11 @@ void Room::OnEvent(const FfiEvent &event) { case proto::RoomEvent::kTrackSubscriptionFailed: { TrackSubscriptionFailedEvent ev; { - const std::scoped_lock guard(lock_); + const std::scoped_lock guard(impl_->lock_); const auto &tsf = re.track_subscription_failed(); const std::string &identity = tsf.participant_identity(); - auto pit = remote_participants_.find(identity); - if (pit == remote_participants_.end()) { + auto pit = impl_->remote_participants_.find(identity); + if (pit == impl_->remote_participants_.end()) { LK_LOG_WARN("track_subscription_failed for unknown participant: {}", identity); break; @@ -737,8 +760,9 @@ void Room::OnEvent(const FfiEvent &event) { auto remote_track = std::shared_ptr(new RemoteDataTrack(rdtp.track())); - if (subscription_thread_dispatcher_) { - subscription_thread_dispatcher_->handleDataTrackPublished(remote_track); + if (impl_->subscription_thread_dispatcher_) { + impl_->subscription_thread_dispatcher_->handleDataTrackPublished( + remote_track); } DataTrackPublishedEvent ev; @@ -751,8 +775,9 @@ void Room::OnEvent(const FfiEvent &event) { case proto::RoomEvent::kDataTrackUnpublished: { const auto &dtu = re.data_track_unpublished(); - if (subscription_thread_dispatcher_) { - subscription_thread_dispatcher_->handleDataTrackUnpublished(dtu.sid()); + if (impl_->subscription_thread_dispatcher_) { + impl_->subscription_thread_dispatcher_->handleDataTrackUnpublished( + dtu.sid()); } DataTrackUnpublishedEvent ev; @@ -766,16 +791,17 @@ void Room::OnEvent(const FfiEvent &event) { TrackMutedEvent ev; bool success = false; { - const std::scoped_lock guard(lock_); + const std::scoped_lock guard(impl_->lock_); const auto &tm = re.track_muted(); const std::string &identity = tm.participant_identity(); const std::string &sid = tm.track_sid(); Participant *participant = nullptr; - if (local_participant_ && local_participant_->identity() == identity) { - participant = local_participant_.get(); + if (impl_->local_participant_ && + impl_->local_participant_->identity() == identity) { + participant = impl_->local_participant_.get(); } else { - auto pit = remote_participants_.find(identity); - if (pit != remote_participants_.end()) { + auto pit = impl_->remote_participants_.find(identity); + if (pit != impl_->remote_participants_.end()) { participant = pit->second.get(); } } @@ -805,16 +831,17 @@ void Room::OnEvent(const FfiEvent &event) { TrackUnmutedEvent ev; bool success = false; { - const std::scoped_lock guard(lock_); + const std::scoped_lock guard(impl_->lock_); const auto &tu = re.track_unmuted(); const std::string &identity = tu.participant_identity(); const std::string &sid = tu.track_sid(); Participant *participant = nullptr; - if (local_participant_ && local_participant_->identity() == identity) { - participant = local_participant_.get(); + if (impl_->local_participant_ && + impl_->local_participant_->identity() == identity) { + participant = impl_->local_participant_.get(); } else { - auto pit = remote_participants_.find(identity); - if (pit != remote_participants_.end()) { + auto pit = impl_->remote_participants_.find(identity); + if (pit != impl_->remote_participants_.end()) { participant = pit->second.get(); } } @@ -848,18 +875,18 @@ void Room::OnEvent(const FfiEvent &event) { case proto::RoomEvent::kActiveSpeakersChanged: { ActiveSpeakersChangedEvent ev; { - const std::scoped_lock guard(lock_); + const std::scoped_lock guard(impl_->lock_); const auto &asc = re.active_speakers_changed(); for (const auto &identity : asc.participant_identities()) { // Appears to be clang-tidy false positive // NOLINTNEXTLINE(misc-const-correctness) Participant *participant = nullptr; - if (local_participant_ && - local_participant_->identity() == identity) { - participant = local_participant_.get(); + if (impl_->local_participant_ && + impl_->local_participant_->identity() == identity) { + participant = impl_->local_participant_.get(); } else { - auto pit = remote_participants_.find(identity); - if (pit != remote_participants_.end()) { + auto pit = impl_->remote_participants_.find(identity); + if (pit != impl_->remote_participants_.end()) { participant = pit->second.get(); } } @@ -876,11 +903,11 @@ void Room::OnEvent(const FfiEvent &event) { case proto::RoomEvent::kRoomMetadataChanged: { RoomMetadataChangedEvent ev; { - const std::scoped_lock guard(lock_); - const auto old_metadata = room_info_.metadata; - room_info_.metadata = re.room_metadata_changed().metadata(); + const std::scoped_lock guard(impl_->lock_); + const auto old_metadata = impl_->room_info_.metadata; + impl_->room_info_.metadata = re.room_metadata_changed().metadata(); ev.old_metadata = old_metadata; - ev.new_metadata = room_info_.metadata; + ev.new_metadata = impl_->room_info_.metadata; } if (delegate_snapshot) { delegate_snapshot->onRoomMetadataChanged(*this, ev); @@ -890,9 +917,9 @@ void Room::OnEvent(const FfiEvent &event) { case proto::RoomEvent::kRoomSidChanged: { RoomSidChangedEvent ev; { - const std::scoped_lock guard(lock_); - room_info_.sid = re.room_sid_changed().sid(); - ev.sid = room_info_.sid.value_or(std::string{}); + const std::scoped_lock guard(impl_->lock_); + impl_->room_info_.sid = re.room_sid_changed().sid(); + ev.sid = impl_->room_info_.sid.value_or(std::string{}); } if (delegate_snapshot) { delegate_snapshot->onRoomSidChanged(*this, ev); @@ -902,15 +929,16 @@ void Room::OnEvent(const FfiEvent &event) { case proto::RoomEvent::kParticipantMetadataChanged: { ParticipantMetadataChangedEvent ev; { - const std::scoped_lock guard(lock_); + const std::scoped_lock guard(impl_->lock_); const auto &pm = re.participant_metadata_changed(); const std::string &identity = pm.participant_identity(); Participant *participant = nullptr; - if (local_participant_ && local_participant_->identity() == identity) { - participant = local_participant_.get(); + if (impl_->local_participant_ && + impl_->local_participant_->identity() == identity) { + participant = impl_->local_participant_.get(); } else { - auto it = remote_participants_.find(identity); - if (it != remote_participants_.end()) { + auto it = impl_->remote_participants_.find(identity); + if (it != impl_->remote_participants_.end()) { participant = it->second.get(); } } @@ -935,15 +963,16 @@ void Room::OnEvent(const FfiEvent &event) { case proto::RoomEvent::kParticipantNameChanged: { ParticipantNameChangedEvent ev; { - const std::scoped_lock guard(lock_); + const std::scoped_lock guard(impl_->lock_); const auto &pn = re.participant_name_changed(); const std::string &identity = pn.participant_identity(); Participant *participant = nullptr; - if (local_participant_ && local_participant_->identity() == identity) { - participant = local_participant_.get(); + if (impl_->local_participant_ && + impl_->local_participant_->identity() == identity) { + participant = impl_->local_participant_.get(); } else { - auto it = remote_participants_.find(identity); - if (it != remote_participants_.end()) { + auto it = impl_->remote_participants_.find(identity); + if (it != impl_->remote_participants_.end()) { participant = it->second.get(); } } @@ -966,15 +995,16 @@ void Room::OnEvent(const FfiEvent &event) { case proto::RoomEvent::kParticipantAttributesChanged: { ParticipantAttributesChangedEvent ev; { - const std::scoped_lock guard(lock_); + const std::scoped_lock guard(impl_->lock_); const auto &pa = re.participant_attributes_changed(); const std::string &identity = pa.participant_identity(); Participant *participant = nullptr; - if (local_participant_ && local_participant_->identity() == identity) { - participant = local_participant_.get(); + if (impl_->local_participant_ && + impl_->local_participant_->identity() == identity) { + participant = impl_->local_participant_.get(); } else { - auto it = remote_participants_.find(identity); - if (it != remote_participants_.end()) { + auto it = impl_->remote_participants_.find(identity); + if (it != impl_->remote_participants_.end()) { participant = it->second.get(); } } @@ -1005,15 +1035,16 @@ void Room::OnEvent(const FfiEvent &event) { case proto::RoomEvent::kParticipantEncryptionStatusChanged: { ParticipantEncryptionStatusChangedEvent ev; { - const std::scoped_lock guard(lock_); + const std::scoped_lock guard(impl_->lock_); const auto &pe = re.participant_encryption_status_changed(); const std::string &identity = pe.participant_identity(); Participant *participant = nullptr; - if (local_participant_ && local_participant_->identity() == identity) { - participant = local_participant_.get(); + if (impl_->local_participant_ && + impl_->local_participant_->identity() == identity) { + participant = impl_->local_participant_.get(); } else { - auto it = remote_participants_.find(identity); - if (it != remote_participants_.end()) { + auto it = impl_->remote_participants_.find(identity); + if (it != impl_->remote_participants_.end()) { participant = it->second.get(); } } @@ -1035,15 +1066,16 @@ void Room::OnEvent(const FfiEvent &event) { case proto::RoomEvent::kConnectionQualityChanged: { ConnectionQualityChangedEvent ev; { - const std::scoped_lock guard(lock_); + const std::scoped_lock guard(impl_->lock_); const auto &cq = re.connection_quality_changed(); const std::string &identity = cq.participant_identity(); Participant *participant = nullptr; - if (local_participant_ && local_participant_->identity() == identity) { - participant = local_participant_.get(); + if (impl_->local_participant_ && + impl_->local_participant_->identity() == identity) { + participant = impl_->local_participant_.get(); } else { - auto it = remote_participants_.find(identity); - if (it != remote_participants_.end()) { + auto it = impl_->remote_participants_.find(identity); + if (it != impl_->remote_participants_.end()) { participant = it->second.get(); } } @@ -1069,9 +1101,9 @@ void Room::OnEvent(const FfiEvent &event) { const auto &dp = re.data_packet_received(); RemoteParticipant *rp = nullptr; { - const std::scoped_lock guard(lock_); - auto it = remote_participants_.find(dp.participant_identity()); - if (it != remote_participants_.end()) { + const std::scoped_lock guard(impl_->lock_); + auto it = impl_->remote_participants_.find(dp.participant_identity()); + if (it != impl_->remote_participants_.end()) { rp = it->second.get(); } } @@ -1094,15 +1126,16 @@ void Room::OnEvent(const FfiEvent &event) { E2eeStateChangedEvent ev; { LK_LOG_DEBUG("e2ee_state_changed for participant"); - const std::scoped_lock guard(lock_); + const std::scoped_lock guard(impl_->lock_); const auto &es = re.e2ee_state_changed(); const std::string &identity = es.participant_identity(); Participant *participant = nullptr; - if (local_participant_ && local_participant_->identity() == identity) { - participant = local_participant_.get(); + if (impl_->local_participant_ && + impl_->local_participant_->identity() == identity) { + participant = impl_->local_participant_.get(); } else { - auto it = remote_participants_.find(identity); - if (it != remote_participants_.end()) { + auto it = impl_->remote_participants_.find(identity); + if (it != impl_->remote_participants_.end()) { participant = it->second.get(); } } @@ -1128,14 +1161,14 @@ void Room::OnEvent(const FfiEvent &event) { case proto::RoomEvent::kConnectionStateChanged: { ConnectionStateChangedEvent ev; { - const std::scoped_lock guard(lock_); + const std::scoped_lock guard(impl_->lock_); const auto &cs = re.connection_state_changed(); // TODO, maybe we should update our |connection_state_| // correspoindingly, but the this kConnectionStateChanged event is never // triggered in my local test. LK_LOG_DEBUG("cs.state() is {} connection_state_ is {}", static_cast(cs.state()), - static_cast(connection_state_)); + static_cast(impl_->connection_state_)); ev.state = static_cast(cs.state()); } if (delegate_snapshot) { @@ -1166,8 +1199,8 @@ void Room::OnEvent(const FfiEvent &event) { break; } case proto::RoomEvent::kEos: { - if (subscription_thread_dispatcher_) { - subscription_thread_dispatcher_->stopAll(); + if (impl_->subscription_thread_dispatcher_) { + impl_->subscription_thread_dispatcher_->stopAll(); } int listener_to_remove = 0; @@ -1185,20 +1218,20 @@ void Room::OnEvent(const FfiEvent &event) { old_byte_readers; { - const std::scoped_lock guard(lock_); - listener_to_remove = listener_id_; - listener_id_ = 0; + const std::scoped_lock guard(impl_->lock_); + listener_to_remove = impl_->listener_id_; + impl_->listener_id_ = 0; // Reset connection state - connection_state_ = ConnectionState::Disconnected; + impl_->connection_state_ = ConnectionState::Disconnected; // Move state out for cleanup outside lock - old_local_participant = std::move(local_participant_); - old_remote_participants = std::move(remote_participants_); - old_room_handle = std::move(room_handle_); - old_e2ee_manager = std::move(e2ee_manager_); - old_text_readers = std::move(text_stream_readers_); - old_byte_readers = std::move(byte_stream_readers_); + old_local_participant = std::move(impl_->local_participant_); + old_remote_participants = std::move(impl_->remote_participants_); + old_room_handle = std::move(impl_->room_handle_); + old_e2ee_manager = std::move(impl_->e2ee_manager_); + old_text_readers = std::move(impl_->text_stream_readers_); + old_byte_readers = std::move(impl_->byte_stream_readers_); } // Remove listener outside lock @@ -1230,14 +1263,14 @@ void Room::OnEvent(const FfiEvent &event) { std::shared_ptr text_reader; std::shared_ptr byte_reader; { - const std::scoped_lock guard(lock_); + const std::scoped_lock guard(impl_->lock_); // Determine stream type from oneof in protobuf // Adjust these names if your generated C++ uses different ones const auto stream_type = header.content_header_case(); if (stream_type == proto::DataStream::Header::kTextHeader) { - auto it = text_stream_handlers_.find(header.topic()); - if (it == text_stream_handlers_.end()) { + auto it = impl_->text_stream_handlers_.find(header.topic()); + if (it == impl_->text_stream_handlers_.end()) { // Ignore if no callback attached break; } @@ -1245,17 +1278,17 @@ void Room::OnEvent(const FfiEvent &event) { const TextStreamInfo info = makeTextInfo(header); text_reader = std::make_shared(info); - text_stream_readers_[header.stream_id()] = text_reader; + impl_->text_stream_readers_[header.stream_id()] = text_reader; } else if (stream_type == proto::DataStream::Header::kByteHeader) { - auto it = byte_stream_handlers_.find(header.topic()); - if (it == byte_stream_handlers_.end()) { + auto it = impl_->byte_stream_handlers_.find(header.topic()); + if (it == impl_->byte_stream_handlers_.end()) { break; } byte_cb = it->second; const ByteStreamInfo info = makeByteInfo(header); byte_reader = std::make_shared(info); - byte_stream_readers_[header.stream_id()] = byte_reader; + impl_->byte_stream_readers_[header.stream_id()] = byte_reader; } else { // unknown header type: ignore @@ -1277,13 +1310,13 @@ void Room::OnEvent(const FfiEvent &event) { std::shared_ptr text_reader; std::shared_ptr byte_reader; { - const std::scoped_lock guard(lock_); - auto itT = text_stream_readers_.find(chunk.stream_id()); - if (itT != text_stream_readers_.end()) { + const std::scoped_lock guard(impl_->lock_); + auto itT = impl_->text_stream_readers_.find(chunk.stream_id()); + if (itT != impl_->text_stream_readers_.end()) { text_reader = itT->second; } else { - auto itB = byte_stream_readers_.find(chunk.stream_id()); - if (itB != byte_stream_readers_.end()) { + auto itB = impl_->byte_stream_readers_.find(chunk.stream_id()); + if (itB != impl_->byte_stream_readers_.end()) { byte_reader = itB->second; } } @@ -1309,16 +1342,16 @@ void Room::OnEvent(const FfiEvent &event) { trailer_attrs.emplace(kv.first, kv.second); } { - const std::scoped_lock guard(lock_); - auto itT = text_stream_readers_.find(trailer.stream_id()); - if (itT != text_stream_readers_.end()) { + const std::scoped_lock guard(impl_->lock_); + auto itT = impl_->text_stream_readers_.find(trailer.stream_id()); + if (itT != impl_->text_stream_readers_.end()) { text_reader = itT->second; - text_stream_readers_.erase(itT); + impl_->text_stream_readers_.erase(itT); } else { - auto itB = byte_stream_readers_.find(trailer.stream_id()); - if (itB != byte_stream_readers_.end()) { + auto itB = impl_->byte_stream_readers_.find(trailer.stream_id()); + if (itB != impl_->byte_stream_readers_.end()) { byte_reader = itB->second; - byte_stream_readers_.erase(itB); + impl_->byte_stream_readers_.erase(itB); } } } @@ -1368,18 +1401,18 @@ void Room::OnEvent(const FfiEvent &event) { case proto::RoomEvent::kParticipantsUpdated: { ParticipantsUpdatedEvent ev; { - const std::scoped_lock guard(lock_); + const std::scoped_lock guard(impl_->lock_); const auto &pu = re.participants_updated(); for (const auto &info : pu.participants()) { const std::string &identity = info.identity(); Participant *participant = nullptr; - if (local_participant_ && - identity == local_participant_->identity()) { - participant = local_participant_.get(); + if (impl_->local_participant_ && + identity == impl_->local_participant_->identity()) { + participant = impl_->local_participant_.get(); } else { - auto it = remote_participants_.find(identity); - if (it != remote_participants_.end()) { + auto it = impl_->remote_participants_.find(identity); + if (it != impl_->remote_participants_.end()) { participant = it->second.get(); } } diff --git a/src/subscription_thread_dispatcher.cpp b/src/subscription_thread_dispatcher.cpp index 91877877..72078992 100644 --- a/src/subscription_thread_dispatcher.cpp +++ b/src/subscription_thread_dispatcher.cpp @@ -23,6 +23,8 @@ #include "lk_log.h" #include +#include +#include #include #include @@ -43,7 +45,39 @@ const char *trackKindName(TrackKind kind) { } // namespace -SubscriptionThreadDispatcher::SubscriptionThreadDispatcher() = default; +struct SubscriptionThreadDispatcher::Impl { + struct ActiveReader { + std::shared_ptr audio_stream; + std::shared_ptr video_stream; + std::thread thread; + }; + + struct ActiveDataReader { + std::shared_ptr remote_track; + std::mutex sub_mutex; + std::shared_ptr stream; // guarded by sub_mutex + std::thread thread; + }; + + mutable std::mutex lock_; + std::unordered_map + audio_callbacks_; + std::unordered_map + video_callbacks_; + std::unordered_map + active_readers_; + DataFrameCallbackId next_data_callback_id_{0}; + std::unordered_map + data_callbacks_; + std::unordered_map> + active_data_readers_; + std::unordered_map, + DataCallbackKeyHash> + remote_data_tracks_; +}; + +SubscriptionThreadDispatcher::SubscriptionThreadDispatcher() + : impl_(std::make_unique()) {} // NOLINTBEGIN(bugprone-exception-escape) // Exceptions can be thrown by stopAll() in this desctuctor, and clang flags as @@ -58,13 +92,15 @@ void SubscriptionThreadDispatcher::setOnAudioFrameCallback( const std::string &participant_identity, TrackSource source, AudioFrameCallback callback, const AudioStream::Options &opts) { const CallbackKey key{participant_identity, source, ""}; - const std::scoped_lock lock(lock_); - const bool replacing = audio_callbacks_.find(key) != audio_callbacks_.end(); - audio_callbacks_[key] = RegisteredAudioCallback{std::move(callback), opts}; + const std::scoped_lock lock(impl_->lock_); + const bool replacing = + impl_->audio_callbacks_.find(key) != impl_->audio_callbacks_.end(); + impl_->audio_callbacks_[key] = + RegisteredAudioCallback{std::move(callback), opts}; LK_LOG_DEBUG("Registered audio frame callback for participant={} source={} " "replacing_existing={} total_audio_callbacks={}", participant_identity, static_cast(source), replacing, - audio_callbacks_.size()); + impl_->audio_callbacks_.size()); } void SubscriptionThreadDispatcher::setOnAudioFrameCallback( @@ -72,22 +108,26 @@ void SubscriptionThreadDispatcher::setOnAudioFrameCallback( AudioFrameCallback callback, const AudioStream::Options &opts) { const CallbackKey key{participant_identity, TrackSource::SOURCE_UNKNOWN, track_name}; - const std::scoped_lock lock(lock_); - const bool replacing = audio_callbacks_.find(key) != audio_callbacks_.end(); - audio_callbacks_[key] = RegisteredAudioCallback{std::move(callback), opts}; + const std::scoped_lock lock(impl_->lock_); + const bool replacing = + impl_->audio_callbacks_.find(key) != impl_->audio_callbacks_.end(); + impl_->audio_callbacks_[key] = + RegisteredAudioCallback{std::move(callback), opts}; LK_LOG_DEBUG( "Registered audio frame callback for participant={} track_name={} " "replacing_existing={} total_audio_callbacks={}", - participant_identity, track_name, replacing, audio_callbacks_.size()); + participant_identity, track_name, replacing, + impl_->audio_callbacks_.size()); } void SubscriptionThreadDispatcher::setOnVideoFrameCallback( const std::string &participant_identity, TrackSource source, VideoFrameCallback callback, const VideoStream::Options &opts) { const CallbackKey key{participant_identity, source, ""}; - const std::scoped_lock lock(lock_); - const bool replacing = video_callbacks_.find(key) != video_callbacks_.end(); - video_callbacks_[key] = RegisteredVideoCallback{ + const std::scoped_lock lock(impl_->lock_); + const bool replacing = + impl_->video_callbacks_.find(key) != impl_->video_callbacks_.end(); + impl_->video_callbacks_[key] = RegisteredVideoCallback{ std::move(callback), VideoFrameEventCallback{}, opts, @@ -95,7 +135,7 @@ void SubscriptionThreadDispatcher::setOnVideoFrameCallback( LK_LOG_DEBUG("Registered legacy video frame callback for participant={} " "source={} replacing_existing={} total_video_callbacks={}", participant_identity, static_cast(source), replacing, - video_callbacks_.size()); + impl_->video_callbacks_.size()); } void SubscriptionThreadDispatcher::setOnVideoFrameEventCallback( @@ -103,9 +143,10 @@ void SubscriptionThreadDispatcher::setOnVideoFrameEventCallback( VideoFrameEventCallback callback, const VideoStream::Options &opts) { const CallbackKey key{participant_identity, TrackSource::SOURCE_UNKNOWN, track_name}; - const std::scoped_lock lock(lock_); - const bool replacing = video_callbacks_.find(key) != video_callbacks_.end(); - video_callbacks_[key] = RegisteredVideoCallback{ + const std::scoped_lock lock(impl_->lock_); + const bool replacing = + impl_->video_callbacks_.find(key) != impl_->video_callbacks_.end(); + impl_->video_callbacks_[key] = RegisteredVideoCallback{ VideoFrameCallback{}, std::move(callback), opts, @@ -113,7 +154,8 @@ void SubscriptionThreadDispatcher::setOnVideoFrameEventCallback( LK_LOG_DEBUG( "Registered video frame event callback for participant={} track_name={} " "replacing_existing={} total_video_callbacks={}", - participant_identity, track_name, replacing, video_callbacks_.size()); + participant_identity, track_name, replacing, + impl_->video_callbacks_.size()); } void SubscriptionThreadDispatcher::setOnVideoFrameCallback( @@ -121,9 +163,10 @@ void SubscriptionThreadDispatcher::setOnVideoFrameCallback( VideoFrameCallback callback, const VideoStream::Options &opts) { const CallbackKey key{participant_identity, TrackSource::SOURCE_UNKNOWN, track_name}; - const std::scoped_lock lock(lock_); - const bool replacing = video_callbacks_.find(key) != video_callbacks_.end(); - video_callbacks_[key] = RegisteredVideoCallback{ + const std::scoped_lock lock(impl_->lock_); + const bool replacing = + impl_->video_callbacks_.find(key) != impl_->video_callbacks_.end(); + impl_->video_callbacks_[key] = RegisteredVideoCallback{ std::move(callback), VideoFrameEventCallback{}, opts, @@ -131,7 +174,8 @@ void SubscriptionThreadDispatcher::setOnVideoFrameCallback( LK_LOG_DEBUG( "Registered video frame callback for participant={} track_name={} " "replacing_existing={} total_video_callbacks={}", - participant_identity, track_name, replacing, video_callbacks_.size()); + participant_identity, track_name, replacing, + impl_->video_callbacks_.size()); } void SubscriptionThreadDispatcher::clearOnAudioFrameCallback( @@ -140,14 +184,14 @@ void SubscriptionThreadDispatcher::clearOnAudioFrameCallback( std::thread old_thread; bool removed_callback = false; { - const std::scoped_lock lock(lock_); - removed_callback = audio_callbacks_.erase(key) > 0; + const std::scoped_lock lock(impl_->lock_); + removed_callback = impl_->audio_callbacks_.erase(key) > 0; old_thread = extractReaderThreadLocked(key); LK_LOG_DEBUG( "Clearing audio frame callback for participant={} source={} " "removed_callback={} stopped_reader={} remaining_audio_callbacks={}", participant_identity, static_cast(source), removed_callback, - old_thread.joinable(), audio_callbacks_.size()); + old_thread.joinable(), impl_->audio_callbacks_.size()); } if (old_thread.joinable()) { old_thread.join(); @@ -161,14 +205,14 @@ void SubscriptionThreadDispatcher::clearOnAudioFrameCallback( std::thread old_thread; bool removed_callback = false; { - const std::scoped_lock lock(lock_); - removed_callback = audio_callbacks_.erase(key) > 0; + const std::scoped_lock lock(impl_->lock_); + removed_callback = impl_->audio_callbacks_.erase(key) > 0; old_thread = extractReaderThreadLocked(key); LK_LOG_DEBUG( "Clearing audio frame callback for participant={} track_name={} " "removed_callback={} stopped_reader={} remaining_audio_callbacks={}", participant_identity, track_name, removed_callback, - old_thread.joinable(), audio_callbacks_.size()); + old_thread.joinable(), impl_->audio_callbacks_.size()); } if (old_thread.joinable()) { old_thread.join(); @@ -181,14 +225,14 @@ void SubscriptionThreadDispatcher::clearOnVideoFrameCallback( std::thread old_thread; bool removed_callback = false; { - const std::scoped_lock lock(lock_); - removed_callback = video_callbacks_.erase(key) > 0; + const std::scoped_lock lock(impl_->lock_); + removed_callback = impl_->video_callbacks_.erase(key) > 0; old_thread = extractReaderThreadLocked(key); LK_LOG_DEBUG( "Clearing video frame callback for participant={} source={} " "removed_callback={} stopped_reader={} remaining_video_callbacks={}", participant_identity, static_cast(source), removed_callback, - old_thread.joinable(), video_callbacks_.size()); + old_thread.joinable(), impl_->video_callbacks_.size()); } if (old_thread.joinable()) { old_thread.join(); @@ -202,14 +246,14 @@ void SubscriptionThreadDispatcher::clearOnVideoFrameCallback( std::thread old_thread; bool removed_callback = false; { - const std::scoped_lock lock(lock_); - removed_callback = video_callbacks_.erase(key) > 0; + const std::scoped_lock lock(impl_->lock_); + removed_callback = impl_->video_callbacks_.erase(key) > 0; old_thread = extractReaderThreadLocked(key); LK_LOG_DEBUG( "Clearing video frame callback for participant={} track_name={} " "removed_callback={} stopped_reader={} remaining_video_callbacks={}", participant_identity, track_name, removed_callback, - old_thread.joinable(), video_callbacks_.size()); + old_thread.joinable(), impl_->video_callbacks_.size()); } if (old_thread.joinable()) { old_thread.join(); @@ -236,11 +280,11 @@ void SubscriptionThreadDispatcher::handleTrackSubscribed( const CallbackKey fallback_key{participant_identity, source, ""}; std::thread old_thread; { - const std::scoped_lock lock(lock_); + const std::scoped_lock lock(impl_->lock_); if ((track->kind() == TrackKind::KIND_AUDIO && - audio_callbacks_.find(key) == audio_callbacks_.end()) || + impl_->audio_callbacks_.find(key) == impl_->audio_callbacks_.end()) || (track->kind() == TrackKind::KIND_VIDEO && - video_callbacks_.find(key) == video_callbacks_.end())) { + impl_->video_callbacks_.find(key) == impl_->video_callbacks_.end())) { key = fallback_key; } old_thread = startReaderLocked(key, track); @@ -259,7 +303,7 @@ void SubscriptionThreadDispatcher::handleTrackUnsubscribed( std::thread old_thread; std::thread fallback_old_thread; { - const std::scoped_lock lock(lock_); + const std::scoped_lock lock(impl_->lock_); old_thread = extractReaderThreadLocked(key); fallback_old_thread = extractReaderThreadLocked(fallback_key); LK_LOG_DEBUG("Handling unsubscribed track for participant={} source={} " @@ -285,15 +329,16 @@ DataFrameCallbackId SubscriptionThreadDispatcher::addOnDataFrameCallback( std::thread old_thread; DataFrameCallbackId id; { - const std::scoped_lock lock(lock_); - id = next_data_callback_id_++; + const std::scoped_lock lock(impl_->lock_); + id = impl_->next_data_callback_id_++; const DataCallbackKey key{participant_identity, track_name}; - data_callbacks_[id] = RegisteredDataCallback{key, std::move(callback)}; + impl_->data_callbacks_[id] = + RegisteredDataCallback{key, std::move(callback)}; - auto track_it = remote_data_tracks_.find(key); - if (track_it != remote_data_tracks_.end()) { + auto track_it = impl_->remote_data_tracks_.find(key); + if (track_it != impl_->remote_data_tracks_.end()) { old_thread = startDataReaderLocked(id, key, track_it->second, - data_callbacks_[id].callback); + impl_->data_callbacks_[id].callback); } } if (old_thread.joinable()) { @@ -306,8 +351,8 @@ void SubscriptionThreadDispatcher::removeOnDataFrameCallback( DataFrameCallbackId id) { std::thread old_thread; { - const std::scoped_lock lock(lock_); - data_callbacks_.erase(id); + const std::scoped_lock lock(impl_->lock_); + impl_->data_callbacks_.erase(id); old_thread = extractDataReaderThreadLocked(id); } if (old_thread.joinable()) { @@ -328,11 +373,11 @@ void SubscriptionThreadDispatcher::handleDataTrackPublished( std::vector old_threads; { - const std::scoped_lock lock(lock_); + const std::scoped_lock lock(impl_->lock_); const DataCallbackKey key{track->publisherIdentity(), track->info().name}; - remote_data_tracks_[key] = track; + impl_->remote_data_tracks_[key] = track; - for (auto &[id, reg] : data_callbacks_) { + for (auto &[id, reg] : impl_->data_callbacks_) { if (reg.key == key) { auto t = startDataReaderLocked(id, key, track, reg.callback); if (t.joinable()) { @@ -352,9 +397,9 @@ void SubscriptionThreadDispatcher::handleDataTrackUnpublished( std::vector old_threads; { - const std::scoped_lock lock(lock_); - for (auto it = active_data_readers_.begin(); - it != active_data_readers_.end();) { + const std::scoped_lock lock(impl_->lock_); + for (auto it = impl_->active_data_readers_.begin(); + it != impl_->active_data_readers_.end();) { auto &reader = it->second; if (reader->remote_track && reader->remote_track->info().sid == sid) { { @@ -366,15 +411,15 @@ void SubscriptionThreadDispatcher::handleDataTrackUnpublished( if (reader->thread.joinable()) { old_threads.push_back(std::move(reader->thread)); } - it = active_data_readers_.erase(it); + it = impl_->active_data_readers_.erase(it); } else { ++it; } } - for (auto it = remote_data_tracks_.begin(); it != remote_data_tracks_.end(); - ++it) { + for (auto it = impl_->remote_data_tracks_.begin(); + it != impl_->remote_data_tracks_.end(); ++it) { if (it->second && it->second->info().sid == sid) { - remote_data_tracks_.erase(it); + impl_->remote_data_tracks_.erase(it); break; } } @@ -387,15 +432,16 @@ void SubscriptionThreadDispatcher::handleDataTrackUnpublished( void SubscriptionThreadDispatcher::stopAll() { std::vector threads; { - const std::scoped_lock lock(lock_); + const std::scoped_lock lock(impl_->lock_); LK_LOG_DEBUG("Stopping all subscription readers active_readers={} " "active_data_readers={} audio_callbacks={} " "video_callbacks={} data_callbacks={}", - active_readers_.size(), active_data_readers_.size(), - audio_callbacks_.size(), video_callbacks_.size(), - data_callbacks_.size()); + impl_->active_readers_.size(), + impl_->active_data_readers_.size(), + impl_->audio_callbacks_.size(), impl_->video_callbacks_.size(), + impl_->data_callbacks_.size()); - for (auto &[key, reader] : active_readers_) { + for (auto &[key, reader] : impl_->active_readers_) { if (reader.audio_stream) { reader.audio_stream->close(); } @@ -406,11 +452,11 @@ void SubscriptionThreadDispatcher::stopAll() { threads.push_back(std::move(reader.thread)); } } - active_readers_.clear(); - audio_callbacks_.clear(); - video_callbacks_.clear(); + impl_->active_readers_.clear(); + impl_->audio_callbacks_.clear(); + impl_->video_callbacks_.clear(); - for (auto &[id, reader] : active_data_readers_) { + for (auto &[id, reader] : impl_->active_data_readers_) { { const std::scoped_lock sub_guard(reader->sub_mutex); if (reader->stream) { @@ -421,9 +467,9 @@ void SubscriptionThreadDispatcher::stopAll() { threads.push_back(std::move(reader->thread)); } } - active_data_readers_.clear(); - data_callbacks_.clear(); - remote_data_tracks_.clear(); + impl_->active_data_readers_.clear(); + impl_->data_callbacks_.clear(); + impl_->remote_data_tracks_.clear(); } for (auto &thread : threads) { thread.join(); @@ -431,10 +477,52 @@ void SubscriptionThreadDispatcher::stopAll() { LK_LOG_DEBUG("Stopped {} subscription reader threads", threads.size()); } +std::size_t SubscriptionThreadDispatcher::audioCallbackCountForTest() const { + const std::scoped_lock lock(impl_->lock_); + return impl_->audio_callbacks_.size(); +} + +std::size_t SubscriptionThreadDispatcher::videoCallbackCountForTest() const { + const std::scoped_lock lock(impl_->lock_); + return impl_->video_callbacks_.size(); +} + +std::size_t SubscriptionThreadDispatcher::activeReaderCountForTest() const { + const std::scoped_lock lock(impl_->lock_); + return impl_->active_readers_.size(); +} + +std::size_t SubscriptionThreadDispatcher::dataCallbackCountForTest() const { + const std::scoped_lock lock(impl_->lock_); + return impl_->data_callbacks_.size(); +} + +std::size_t SubscriptionThreadDispatcher::activeDataReaderCountForTest() const { + const std::scoped_lock lock(impl_->lock_); + return impl_->active_data_readers_.size(); +} + +std::size_t SubscriptionThreadDispatcher::remoteDataTrackCountForTest() const { + const std::scoped_lock lock(impl_->lock_); + return impl_->remote_data_tracks_.size(); +} + +bool SubscriptionThreadDispatcher::hasAudioCallbackForTest( + const CallbackKey &key) const { + const std::scoped_lock lock(impl_->lock_); + return impl_->audio_callbacks_.find(key) != impl_->audio_callbacks_.end(); +} + +bool SubscriptionThreadDispatcher::hasVideoCallbackForTest( + const CallbackKey &key) const { + const std::scoped_lock lock(impl_->lock_); + return impl_->video_callbacks_.find(key) != impl_->video_callbacks_.end(); +} + std::thread SubscriptionThreadDispatcher::extractReaderThreadLocked( const CallbackKey &key) { - auto it = active_readers_.find(key); - if (it == active_readers_.end()) { + auto it = impl_->active_readers_.find(key); + if (it == impl_->active_readers_.end()) { LK_LOG_TRACE("No active reader to extract for participant={} source={} " "track_name={}", key.participant_identity, static_cast(key.source), @@ -446,8 +534,8 @@ std::thread SubscriptionThreadDispatcher::extractReaderThreadLocked( "track_name={}", key.participant_identity, static_cast(key.source), key.track_name); - ActiveReader reader = std::move(it->second); - active_readers_.erase(it); + Impl::ActiveReader reader = std::move(it->second); + impl_->active_readers_.erase(it); if (reader.audio_stream) { reader.audio_stream->close(); @@ -461,8 +549,8 @@ std::thread SubscriptionThreadDispatcher::extractReaderThreadLocked( std::thread SubscriptionThreadDispatcher::startReaderLocked( const CallbackKey &key, const std::shared_ptr &track) { if (track->kind() == TrackKind::KIND_AUDIO) { - auto it = audio_callbacks_.find(key); - if (it == audio_callbacks_.end()) { + auto it = impl_->audio_callbacks_.find(key); + if (it == impl_->audio_callbacks_.end()) { LK_LOG_TRACE("Skipping audio reader start for participant={} source={} " "because no audio callback is registered", key.participant_identity, static_cast(key.source)); @@ -472,8 +560,8 @@ std::thread SubscriptionThreadDispatcher::startReaderLocked( it->second.options); } if (track->kind() == TrackKind::KIND_VIDEO) { - auto it = video_callbacks_.find(key); - if (it == video_callbacks_.end()) { + auto it = impl_->video_callbacks_.find(key); + if (it == impl_->video_callbacks_.end()) { LK_LOG_TRACE("Skipping video reader start for participant={} source={} " "because no video callback is registered", key.participant_identity, static_cast(key.source)); @@ -503,7 +591,7 @@ std::thread SubscriptionThreadDispatcher::startAudioReaderLocked( key.participant_identity, static_cast(key.source)); auto old_thread = extractReaderThreadLocked(key); - if (static_cast(active_readers_.size()) >= kMaxActiveReaders) { + if (static_cast(impl_->active_readers_.size()) >= kMaxActiveReaders) { LK_LOG_ERROR( "Cannot start audio reader for {} source={}: active reader limit ({}) " "reached", @@ -519,7 +607,7 @@ std::thread SubscriptionThreadDispatcher::startAudioReaderLocked( return old_thread; } - ActiveReader reader; + Impl::ActiveReader reader; reader.audio_stream = stream; const std::string participant_identity = key.participant_identity; const TrackSource source = key.source; @@ -551,11 +639,11 @@ std::thread SubscriptionThreadDispatcher::startAudioReaderLocked( } }); // NOLINTEND(bugprone-lambda-function-name,bugprone-exception-escape) - active_readers_[key] = std::move(reader); + impl_->active_readers_[key] = std::move(reader); LK_LOG_DEBUG("Started audio reader for participant={} source={} " "active_readers={}", key.participant_identity, static_cast(key.source), - active_readers_.size()); + impl_->active_readers_.size()); return old_thread; } @@ -566,7 +654,7 @@ std::thread SubscriptionThreadDispatcher::startVideoReaderLocked( key.participant_identity, static_cast(key.source)); auto old_thread = extractReaderThreadLocked(key); - if (static_cast(active_readers_.size()) >= kMaxActiveReaders) { + if (static_cast(impl_->active_readers_.size()) >= kMaxActiveReaders) { LK_LOG_ERROR( "Cannot start video reader for {} source={}: active reader limit ({}) " "reached", @@ -582,7 +670,7 @@ std::thread SubscriptionThreadDispatcher::startVideoReaderLocked( return old_thread; } - ActiveReader reader; + Impl::ActiveReader reader; reader.video_stream = stream; auto legacy_cb = callback.legacy_callback; auto event_cb = callback.event_callback; @@ -619,11 +707,11 @@ std::thread SubscriptionThreadDispatcher::startVideoReaderLocked( } }); // NOLINTEND(bugprone-lambda-function-name,bugprone-exception-escape) - active_readers_[key] = std::move(reader); + impl_->active_readers_[key] = std::move(reader); LK_LOG_DEBUG("Started video reader for participant={} source={} " "active_readers={}", key.participant_identity, static_cast(key.source), - active_readers_.size()); + impl_->active_readers_.size()); return old_thread; } @@ -633,12 +721,12 @@ std::thread SubscriptionThreadDispatcher::startVideoReaderLocked( std::thread SubscriptionThreadDispatcher::extractDataReaderThreadLocked( DataFrameCallbackId id) { - auto it = active_data_readers_.find(id); - if (it == active_data_readers_.end()) { + auto it = impl_->active_data_readers_.find(id); + if (it == impl_->active_data_readers_.end()) { return {}; } auto reader = std::move(it->second); - active_data_readers_.erase(it); + impl_->active_data_readers_.erase(it); { const std::scoped_lock guard(reader->sub_mutex); if (reader->stream) { @@ -650,14 +738,14 @@ std::thread SubscriptionThreadDispatcher::extractDataReaderThreadLocked( std::thread SubscriptionThreadDispatcher::extractDataReaderThreadLocked( const DataCallbackKey &key) { - for (auto it = active_data_readers_.begin(); it != active_data_readers_.end(); - ++it) { + for (auto it = impl_->active_data_readers_.begin(); + it != impl_->active_data_readers_.end(); ++it) { if (it->second && it->second->remote_track && it->second->remote_track->publisherIdentity() == key.participant_identity && it->second->remote_track->info().name == key.track_name) { auto reader = std::move(it->second); - active_data_readers_.erase(it); + impl_->active_data_readers_.erase(it); { const std::scoped_lock guard(reader->sub_mutex); if (reader->stream) { @@ -676,8 +764,8 @@ std::thread SubscriptionThreadDispatcher::startDataReaderLocked( const DataFrameCallback &cb) { auto old_thread = extractDataReaderThreadLocked(id); - const int total_active = static_cast(active_readers_.size()) + - static_cast(active_data_readers_.size()); + const int total_active = static_cast(impl_->active_readers_.size()) + + static_cast(impl_->active_data_readers_.size()); if (total_active >= kMaxActiveReaders) { LK_LOG_ERROR("Cannot start data reader for {} track={}: active reader " "limit ({}) reached", @@ -688,7 +776,7 @@ std::thread SubscriptionThreadDispatcher::startDataReaderLocked( LK_LOG_INFO("Starting data reader for \"{}\" track=\"{}\"", key.participant_identity, key.track_name); - auto reader = std::make_shared(); + auto reader = std::make_shared(); reader->remote_track = track; auto identity = key.participant_identity; auto track_name = key.track_name; @@ -728,7 +816,7 @@ std::thread SubscriptionThreadDispatcher::startDataReaderLocked( track_name); }); // NOLINTEND(bugprone-lambda-function-name) - active_data_readers_[id] = reader; + impl_->active_data_readers_[id] = reader; return old_thread; } diff --git a/src/tests/unit/test_subscription_thread_dispatcher.cpp b/src/tests/unit/test_subscription_thread_dispatcher.cpp index a3864c64..e530c522 100644 --- a/src/tests/unit/test_subscription_thread_dispatcher.cpp +++ b/src/tests/unit/test_subscription_thread_dispatcher.cpp @@ -39,23 +39,37 @@ class SubscriptionThreadDispatcherTest : public ::testing::Test { using DataCallbackKey = SubscriptionThreadDispatcher::DataCallbackKey; using DataCallbackKeyHash = SubscriptionThreadDispatcher::DataCallbackKeyHash; - static auto &audioCallbacks(SubscriptionThreadDispatcher &dispatcher) { - return dispatcher.audio_callbacks_; + static std::size_t + audioCallbackCount(const SubscriptionThreadDispatcher &dispatcher) { + return dispatcher.audioCallbackCountForTest(); } - static auto &videoCallbacks(SubscriptionThreadDispatcher &dispatcher) { - return dispatcher.video_callbacks_; + static std::size_t + videoCallbackCount(const SubscriptionThreadDispatcher &dispatcher) { + return dispatcher.videoCallbackCountForTest(); } - static auto &activeReaders(SubscriptionThreadDispatcher &dispatcher) { - return dispatcher.active_readers_; + static std::size_t + activeReaderCount(const SubscriptionThreadDispatcher &dispatcher) { + return dispatcher.activeReaderCountForTest(); } - static auto &dataCallbacks(SubscriptionThreadDispatcher &dispatcher) { - return dispatcher.data_callbacks_; + static std::size_t + dataCallbackCount(const SubscriptionThreadDispatcher &dispatcher) { + return dispatcher.dataCallbackCountForTest(); } - static auto &activeDataReaders(SubscriptionThreadDispatcher &dispatcher) { - return dispatcher.active_data_readers_; + static std::size_t + activeDataReaderCount(const SubscriptionThreadDispatcher &dispatcher) { + return dispatcher.activeDataReaderCountForTest(); } - static auto &remoteDataTracks(SubscriptionThreadDispatcher &dispatcher) { - return dispatcher.remote_data_tracks_; + static std::size_t + remoteDataTrackCount(const SubscriptionThreadDispatcher &dispatcher) { + return dispatcher.remoteDataTrackCountForTest(); + } + static bool hasAudioCallback(const SubscriptionThreadDispatcher &dispatcher, + const CallbackKey &key) { + return dispatcher.hasAudioCallbackForTest(key); + } + static bool hasVideoCallback(const SubscriptionThreadDispatcher &dispatcher, + const CallbackKey &key) { + return dispatcher.hasVideoCallbackForTest(key); } static int maxActiveReaders() { return SubscriptionThreadDispatcher::kMaxActiveReaders; @@ -166,7 +180,7 @@ TEST_F(SubscriptionThreadDispatcherTest, SetAudioCallbackStoresRegistration) { dispatcher.setOnAudioFrameCallback("alice", TrackSource::SOURCE_MICROPHONE, [](const AudioFrame &) {}); - EXPECT_EQ(audioCallbacks(dispatcher).size(), 1u); + EXPECT_EQ(audioCallbackCount(dispatcher), 1u); } TEST_F(SubscriptionThreadDispatcherTest, @@ -175,11 +189,10 @@ TEST_F(SubscriptionThreadDispatcherTest, dispatcher.setOnAudioFrameCallback("alice", "mic-main", [](const AudioFrame &) {}); - EXPECT_EQ(audioCallbacks(dispatcher).size(), 1u); - EXPECT_EQ( - audioCallbacks(dispatcher) - .count(CallbackKey{"alice", TrackSource::SOURCE_UNKNOWN, "mic-main"}), - 1u); + EXPECT_EQ(audioCallbackCount(dispatcher), 1u); + EXPECT_TRUE(hasAudioCallback( + dispatcher, + CallbackKey{"alice", TrackSource::SOURCE_UNKNOWN, "mic-main"})); } TEST_F(SubscriptionThreadDispatcherTest, SetVideoCallbackStoresRegistration) { @@ -187,7 +200,7 @@ TEST_F(SubscriptionThreadDispatcherTest, SetVideoCallbackStoresRegistration) { dispatcher.setOnVideoFrameCallback("alice", TrackSource::SOURCE_CAMERA, [](const VideoFrame &, std::int64_t) {}); - EXPECT_EQ(videoCallbacks(dispatcher).size(), 1u); + EXPECT_EQ(videoCallbackCount(dispatcher), 1u); } TEST_F(SubscriptionThreadDispatcherTest, @@ -196,11 +209,10 @@ TEST_F(SubscriptionThreadDispatcherTest, dispatcher.setOnVideoFrameCallback("alice", "cam-main", [](const VideoFrame &, std::int64_t) {}); - EXPECT_EQ(videoCallbacks(dispatcher).size(), 1u); - EXPECT_EQ( - videoCallbacks(dispatcher) - .count(CallbackKey{"alice", TrackSource::SOURCE_UNKNOWN, "cam-main"}), - 1u); + EXPECT_EQ(videoCallbackCount(dispatcher), 1u); + EXPECT_TRUE(hasVideoCallback( + dispatcher, + CallbackKey{"alice", TrackSource::SOURCE_UNKNOWN, "cam-main"})); } TEST_F(SubscriptionThreadDispatcherTest, @@ -208,10 +220,10 @@ TEST_F(SubscriptionThreadDispatcherTest, SubscriptionThreadDispatcher dispatcher; dispatcher.setOnAudioFrameCallback("alice", TrackSource::SOURCE_MICROPHONE, [](const AudioFrame &) {}); - ASSERT_EQ(audioCallbacks(dispatcher).size(), 1u); + ASSERT_EQ(audioCallbackCount(dispatcher), 1u); dispatcher.clearOnAudioFrameCallback("alice", TrackSource::SOURCE_MICROPHONE); - EXPECT_EQ(audioCallbacks(dispatcher).size(), 0u); + EXPECT_EQ(audioCallbackCount(dispatcher), 0u); } TEST_F(SubscriptionThreadDispatcherTest, @@ -219,10 +231,10 @@ TEST_F(SubscriptionThreadDispatcherTest, SubscriptionThreadDispatcher dispatcher; dispatcher.setOnAudioFrameCallback("alice", "mic-main", [](const AudioFrame &) {}); - ASSERT_EQ(audioCallbacks(dispatcher).size(), 1u); + ASSERT_EQ(audioCallbackCount(dispatcher), 1u); dispatcher.clearOnAudioFrameCallback("alice", "mic-main"); - EXPECT_EQ(audioCallbacks(dispatcher).size(), 0u); + EXPECT_EQ(audioCallbackCount(dispatcher), 0u); } TEST_F(SubscriptionThreadDispatcherTest, @@ -230,10 +242,10 @@ TEST_F(SubscriptionThreadDispatcherTest, SubscriptionThreadDispatcher dispatcher; dispatcher.setOnVideoFrameCallback("alice", TrackSource::SOURCE_CAMERA, [](const VideoFrame &, std::int64_t) {}); - ASSERT_EQ(videoCallbacks(dispatcher).size(), 1u); + ASSERT_EQ(videoCallbackCount(dispatcher), 1u); dispatcher.clearOnVideoFrameCallback("alice", TrackSource::SOURCE_CAMERA); - EXPECT_EQ(videoCallbacks(dispatcher).size(), 0u); + EXPECT_EQ(videoCallbackCount(dispatcher), 0u); } TEST_F(SubscriptionThreadDispatcherTest, @@ -241,10 +253,10 @@ TEST_F(SubscriptionThreadDispatcherTest, SubscriptionThreadDispatcher dispatcher; dispatcher.setOnVideoFrameCallback("alice", "cam-main", [](const VideoFrame &, std::int64_t) {}); - ASSERT_EQ(videoCallbacks(dispatcher).size(), 1u); + ASSERT_EQ(videoCallbackCount(dispatcher), 1u); dispatcher.clearOnVideoFrameCallback("alice", "cam-main"); - EXPECT_EQ(videoCallbacks(dispatcher).size(), 0u); + EXPECT_EQ(videoCallbackCount(dispatcher), 0u); } TEST_F(SubscriptionThreadDispatcherTest, ClearNonExistentCallbackIsNoOp) { @@ -270,7 +282,7 @@ TEST_F(SubscriptionThreadDispatcherTest, "alice", TrackSource::SOURCE_MICROPHONE, [&counter2](const AudioFrame &) { counter2++; }); - EXPECT_EQ(audioCallbacks(dispatcher).size(), 1u) + EXPECT_EQ(audioCallbackCount(dispatcher), 1u) << "Re-registering with the same key should overwrite, not add"; } @@ -282,7 +294,7 @@ TEST_F(SubscriptionThreadDispatcherTest, dispatcher.setOnVideoFrameCallback("alice", TrackSource::SOURCE_CAMERA, [](const VideoFrame &, std::int64_t) {}); - EXPECT_EQ(videoCallbacks(dispatcher).size(), 1u); + EXPECT_EQ(videoCallbackCount(dispatcher), 1u); } TEST_F(SubscriptionThreadDispatcherTest, @@ -293,7 +305,7 @@ TEST_F(SubscriptionThreadDispatcherTest, dispatcher.setOnAudioFrameCallback("alice", "mic-main", [](const AudioFrame &) {}); - EXPECT_EQ(audioCallbacks(dispatcher).size(), 1u); + EXPECT_EQ(audioCallbackCount(dispatcher), 1u); } TEST_F(SubscriptionThreadDispatcherTest, @@ -308,12 +320,12 @@ TEST_F(SubscriptionThreadDispatcherTest, dispatcher.setOnVideoFrameCallback("bob", TrackSource::SOURCE_CAMERA, [](const VideoFrame &, std::int64_t) {}); - EXPECT_EQ(audioCallbacks(dispatcher).size(), 2u); - EXPECT_EQ(videoCallbacks(dispatcher).size(), 2u); + EXPECT_EQ(audioCallbackCount(dispatcher), 2u); + EXPECT_EQ(videoCallbackCount(dispatcher), 2u); dispatcher.clearOnAudioFrameCallback("alice", TrackSource::SOURCE_MICROPHONE); - EXPECT_EQ(audioCallbacks(dispatcher).size(), 1u); - EXPECT_EQ(videoCallbacks(dispatcher).size(), 2u); + EXPECT_EQ(audioCallbackCount(dispatcher), 1u); + EXPECT_EQ(videoCallbackCount(dispatcher), 2u); } TEST_F(SubscriptionThreadDispatcherTest, ClearingOneSourceDoesNotAffectOther) { @@ -323,13 +335,13 @@ TEST_F(SubscriptionThreadDispatcherTest, ClearingOneSourceDoesNotAffectOther) { dispatcher.setOnAudioFrameCallback("alice", TrackSource::SOURCE_SCREENSHARE_AUDIO, [](const AudioFrame &) {}); - ASSERT_EQ(audioCallbacks(dispatcher).size(), 2u); + ASSERT_EQ(audioCallbackCount(dispatcher), 2u); dispatcher.clearOnAudioFrameCallback("alice", TrackSource::SOURCE_MICROPHONE); - EXPECT_EQ(audioCallbacks(dispatcher).size(), 1u); + EXPECT_EQ(audioCallbackCount(dispatcher), 1u); CallbackKey remaining{"alice", TrackSource::SOURCE_SCREENSHARE_AUDIO, ""}; - EXPECT_EQ(audioCallbacks(dispatcher).count(remaining), 1u); + EXPECT_TRUE(hasAudioCallback(dispatcher, remaining)); } TEST_F(SubscriptionThreadDispatcherTest, @@ -339,14 +351,12 @@ TEST_F(SubscriptionThreadDispatcherTest, [](const AudioFrame &) {}); dispatcher.setOnAudioFrameCallback("alice", "mic-main", [](const AudioFrame &) {}); - ASSERT_EQ(audioCallbacks(dispatcher).size(), 2u); + ASSERT_EQ(audioCallbackCount(dispatcher), 2u); dispatcher.clearOnAudioFrameCallback("alice", "mic-main"); - EXPECT_EQ(audioCallbacks(dispatcher).size(), 1u); - EXPECT_EQ( - audioCallbacks(dispatcher) - .count(CallbackKey{"alice", TrackSource::SOURCE_MICROPHONE, ""}), - 1u); + EXPECT_EQ(audioCallbackCount(dispatcher), 1u); + EXPECT_TRUE(hasAudioCallback( + dispatcher, CallbackKey{"alice", TrackSource::SOURCE_MICROPHONE, ""})); } // ============================================================================ @@ -355,7 +365,7 @@ TEST_F(SubscriptionThreadDispatcherTest, TEST_F(SubscriptionThreadDispatcherTest, NoActiveReadersInitially) { SubscriptionThreadDispatcher dispatcher; - EXPECT_TRUE(activeReaders(dispatcher).empty()); + EXPECT_EQ(activeReaderCount(dispatcher), 0u); } TEST_F(SubscriptionThreadDispatcherTest, @@ -363,7 +373,7 @@ TEST_F(SubscriptionThreadDispatcherTest, SubscriptionThreadDispatcher dispatcher; dispatcher.setOnAudioFrameCallback("alice", TrackSource::SOURCE_MICROPHONE, [](const AudioFrame &) {}); - EXPECT_TRUE(activeReaders(dispatcher).empty()) + EXPECT_EQ(activeReaderCount(dispatcher), 0u) << "Registering a callback without a subscribed track should not spawn " "readers"; } @@ -422,7 +432,7 @@ TEST_F(SubscriptionThreadDispatcherTest, ConcurrentRegistrationDoesNotCrash) { thread.join(); } - EXPECT_TRUE(audioCallbacks(dispatcher).empty()) + EXPECT_EQ(audioCallbackCount(dispatcher), 0u) << "All callbacks should be cleared after concurrent register/clear"; } @@ -451,8 +461,8 @@ TEST_F(SubscriptionThreadDispatcherTest, thread.join(); } - EXPECT_EQ(audioCallbacks(dispatcher).size(), static_cast(kThreads)); - EXPECT_EQ(videoCallbacks(dispatcher).size(), static_cast(kThreads)); + EXPECT_EQ(audioCallbackCount(dispatcher), static_cast(kThreads)); + EXPECT_EQ(videoCallbackCount(dispatcher), static_cast(kThreads)); } // ============================================================================ @@ -469,14 +479,14 @@ TEST_F(SubscriptionThreadDispatcherTest, ManyDistinctCallbacksCanBeRegistered) { [](const AudioFrame &) {}); } - EXPECT_EQ(audioCallbacks(dispatcher).size(), static_cast(kCount)); + EXPECT_EQ(audioCallbackCount(dispatcher), static_cast(kCount)); for (int i = 0; i < kCount; ++i) { dispatcher.clearOnAudioFrameCallback("participant-" + std::to_string(i), TrackSource::SOURCE_MICROPHONE); } - EXPECT_EQ(audioCallbacks(dispatcher).size(), 0u); + EXPECT_EQ(audioCallbackCount(dispatcher), 0u); } // ============================================================================ @@ -563,14 +573,14 @@ TEST_F(SubscriptionThreadDispatcherTest, [](const std::vector &, std::optional) {}); EXPECT_EQ(id, 0u); - EXPECT_EQ(dataCallbacks(dispatcher).size(), 1u); + EXPECT_EQ(dataCallbackCount(dispatcher), 1u); // Add a second one to confirm size and IDs are correct auto id2 = dispatcher.addOnDataFrameCallback( "alice", "my-track", [](const std::vector &, std::optional) {}); EXPECT_EQ(id2, 1u); - EXPECT_EQ(dataCallbacks(dispatcher).size(), 2u); + EXPECT_EQ(dataCallbackCount(dispatcher), 2u); } TEST_F(SubscriptionThreadDispatcherTest, @@ -579,10 +589,10 @@ TEST_F(SubscriptionThreadDispatcherTest, auto id = dispatcher.addOnDataFrameCallback( "alice", "my-track", [](const std::vector &, std::optional) {}); - ASSERT_EQ(dataCallbacks(dispatcher).size(), 1u); + ASSERT_EQ(dataCallbackCount(dispatcher), 1u); dispatcher.removeOnDataFrameCallback(id); - EXPECT_EQ(dataCallbacks(dispatcher).size(), 0u); + EXPECT_EQ(dataCallbackCount(dispatcher), 0u); } TEST_F(SubscriptionThreadDispatcherTest, RemoveNonExistentDataCallbackIsNoOp) { @@ -599,10 +609,10 @@ TEST_F(SubscriptionThreadDispatcherTest, auto id2 = dispatcher.addOnDataFrameCallback("alice", "track", cb); EXPECT_NE(id1, id2); - EXPECT_EQ(dataCallbacks(dispatcher).size(), 2u); + EXPECT_EQ(dataCallbackCount(dispatcher), 2u); dispatcher.removeOnDataFrameCallback(id1); - EXPECT_EQ(dataCallbacks(dispatcher).size(), 1u); + EXPECT_EQ(dataCallbackCount(dispatcher), 1u); } TEST_F(SubscriptionThreadDispatcherTest, @@ -624,7 +634,7 @@ TEST_F(SubscriptionThreadDispatcherTest, TEST_F(SubscriptionThreadDispatcherTest, NoActiveDataReadersInitially) { SubscriptionThreadDispatcher dispatcher; - EXPECT_TRUE(activeDataReaders(dispatcher).empty()); + EXPECT_EQ(activeDataReaderCount(dispatcher), 0u); } TEST_F(SubscriptionThreadDispatcherTest, @@ -633,14 +643,14 @@ TEST_F(SubscriptionThreadDispatcherTest, dispatcher.addOnDataFrameCallback( "alice", "my-track", [](const std::vector &, std::optional) {}); - EXPECT_TRUE(activeDataReaders(dispatcher).empty()) + EXPECT_EQ(activeDataReaderCount(dispatcher), 0u) << "Registering a callback without a published track should not spawn " "readers"; } TEST_F(SubscriptionThreadDispatcherTest, NoRemoteDataTracksInitially) { SubscriptionThreadDispatcher dispatcher; - EXPECT_TRUE(remoteDataTracks(dispatcher).empty()); + EXPECT_EQ(remoteDataTrackCount(dispatcher), 0u); } // ============================================================================ @@ -686,9 +696,9 @@ TEST_F(SubscriptionThreadDispatcherTest, "alice", "data-track", [](const std::vector &, std::optional) {}); - EXPECT_EQ(audioCallbacks(dispatcher).size(), 1u); - EXPECT_EQ(videoCallbacks(dispatcher).size(), 1u); - EXPECT_EQ(dataCallbacks(dispatcher).size(), 1u); + EXPECT_EQ(audioCallbackCount(dispatcher), 1u); + EXPECT_EQ(videoCallbackCount(dispatcher), 1u); + EXPECT_EQ(dataCallbackCount(dispatcher), 1u); } TEST_F(SubscriptionThreadDispatcherTest, StopAllClearsDataCallbacksAndReaders) { @@ -702,9 +712,9 @@ TEST_F(SubscriptionThreadDispatcherTest, StopAllClearsDataCallbacksAndReaders) { dispatcher.stopAll(); - EXPECT_EQ(dataCallbacks(dispatcher).size(), 0u); - EXPECT_TRUE(activeDataReaders(dispatcher).empty()); - EXPECT_TRUE(remoteDataTracks(dispatcher).empty()); + EXPECT_EQ(dataCallbackCount(dispatcher), 0u); + EXPECT_EQ(activeDataReaderCount(dispatcher), 0u); + EXPECT_EQ(remoteDataTrackCount(dispatcher), 0u); } // ============================================================================ @@ -736,7 +746,7 @@ TEST_F(SubscriptionThreadDispatcherTest, thread.join(); } - EXPECT_TRUE(dataCallbacks(dispatcher).empty()) + EXPECT_EQ(dataCallbackCount(dispatcher), 0u) << "All data callbacks should be cleared after concurrent " "register/remove"; } diff --git a/src/video_stream.cpp b/src/video_stream.cpp index a3bce022..4e46be0d 100644 --- a/src/video_stream.cpp +++ b/src/video_stream.cpp @@ -16,10 +16,16 @@ #include "livekit/video_stream.h" +#include +#include +#include +#include #include #include "ffi.pb.h" #include "ffi_client.h" +#include "livekit/ffi_handle.h" +#include "livekit/participant.h" #include "livekit/track.h" #include "lk_log.h" #include "video_frame.pb.h" @@ -31,6 +37,150 @@ using proto::FfiEvent; using proto::FfiRequest; using proto::VideoStreamEvent; +struct VideoStream::Impl { + ~Impl() { close(); } + + bool read(VideoFrameEvent &out) { + std::unique_lock lock(mutex_); + cv_.wait(lock, [this] { return !queue_.empty() || eof_ || closed_; }); + + if (closed_ || (queue_.empty() && eof_)) { + return false; + } + + out = std::move(queue_.front()); + queue_.pop_front(); + return true; + } + + void close() { + FfiHandle stream_handle; + std::int32_t listener_id = 0; + { + const std::scoped_lock lock(mutex_); + if (closed_) { + return; + } + closed_ = true; + stream_handle = std::move(stream_handle_); + listener_id = listener_id_; + listener_id_ = 0; + } + + if (stream_handle.get() != 0) { + stream_handle.reset(); + } + if (listener_id != 0) { + FfiClient::instance().RemoveListener(listener_id); + } + + cv_.notify_all(); + } + + void initFromTrack(const std::shared_ptr &track, + const Options &options) { + capacity_ = options.capacity; + + listener_id_ = FfiClient::instance().AddListener( + [this](const proto::FfiEvent &e) { this->onFfiEvent(e); }); + + FfiRequest req; + auto *new_video_stream = req.mutable_new_video_stream(); + new_video_stream->set_track_handle( + static_cast(track->ffi_handle_id())); + new_video_stream->set_type(proto::VideoStreamType::VIDEO_STREAM_NATIVE); + new_video_stream->set_normalize_stride(true); + new_video_stream->set_format(toProto(options.format)); + + auto resp = FfiClient::instance().sendRequest(req); + if (!resp.has_new_video_stream()) { + LK_LOG_ERROR("VideoStream::initFromTrack: FFI response missing " + "new_video_stream()"); + throw std::runtime_error("new_video_stream FFI request failed"); + } + const auto &stream = resp.new_video_stream().stream(); + stream_handle_ = FfiHandle(static_cast(stream.handle().id())); + } + + void initFromParticipant(Participant &participant, TrackSource track_source, + const Options &options) { + capacity_ = options.capacity; + + listener_id_ = FfiClient::instance().AddListener( + [this](const FfiEvent &e) { this->onFfiEvent(e); }); + + FfiRequest req; + auto *vs = req.mutable_video_stream_from_participant(); + vs->set_participant_handle(participant.ffiHandleId()); + vs->set_type(proto::VideoStreamType::VIDEO_STREAM_NATIVE); + vs->set_track_source(static_cast(track_source)); + vs->set_normalize_stride(true); + vs->set_format(toProto(options.format)); + + auto resp = FfiClient::instance().sendRequest(req); + const auto &stream = resp.video_stream_from_participant().stream(); + stream_handle_ = FfiHandle(static_cast(stream.handle().id())); + } + + void onFfiEvent(const proto::FfiEvent &event) { + if (event.message_case() != FfiEvent::kVideoStreamEvent) { + return; + } + const auto &vse = event.video_stream_event(); + if (vse.stream_handle() != + static_cast(stream_handle_.get())) { + return; + } + if (vse.has_frame_received()) { + const auto &fr = vse.frame_received(); + VideoFrameEvent ev; + ev.frame = VideoFrame::fromOwnedInfo(fr.buffer()); + ev.timestamp_us = fr.timestamp_us(); + ev.rotation = static_cast(fr.rotation()); + if (fr.has_metadata()) { + ev.metadata = fromProto(fr.metadata()); + } + pushFrame(std::move(ev)); + } else if (vse.has_eos()) { + pushEos(); + } + } + + void pushFrame(VideoFrameEvent &&ev) { + { + const std::scoped_lock lock(mutex_); + if (closed_ || eof_) { + return; + } + if (capacity_ > 0 && queue_.size() >= capacity_) { + queue_.pop_front(); + } + queue_.push_back(std::move(ev)); + } + cv_.notify_one(); + } + + void pushEos() { + { + const std::scoped_lock lock(mutex_); + if (eof_) { + return; + } + eof_ = true; + } + cv_.notify_all(); + } + + mutable std::mutex mutex_; + std::condition_variable cv_; + std::deque queue_; + std::size_t capacity_{0}; + bool eof_{false}; + bool closed_{false}; + FfiHandle stream_handle_; + std::int32_t listener_id_{0}; +}; + std::shared_ptr VideoStream::fromTrack(const std::shared_ptr &track, const Options &options) { @@ -47,197 +197,37 @@ VideoStream::fromParticipant(Participant &participant, TrackSource track_source, return stream; } -VideoStream::~VideoStream() { close(); } - -VideoStream::VideoStream(VideoStream &&other) noexcept { - const std::scoped_lock lock(other.mutex_); - queue_ = std::move(other.queue_); - capacity_ = other.capacity_; - eof_ = other.eof_; - closed_ = other.closed_; - stream_handle_ = std::move(other.stream_handle_); - listener_id_ = other.listener_id_; - - other.listener_id_ = 0; - other.closed_ = true; -} +VideoStream::VideoStream() : impl_(std::make_unique()) {} -VideoStream &VideoStream::operator=(VideoStream &&other) noexcept { - if (this == &other) - return *this; +VideoStream::~VideoStream() = default; - close(); +VideoStream::VideoStream(VideoStream &&other) noexcept = default; - { - const std::scoped_lock lock_this(mutex_); - const std::scoped_lock lock_other(other.mutex_); - - queue_ = std::move(other.queue_); - capacity_ = other.capacity_; - eof_ = other.eof_; - closed_ = other.closed_; - stream_handle_ = std::move(other.stream_handle_); - listener_id_ = other.listener_id_; - - other.listener_id_ = 0; - other.closed_ = true; - } - - return *this; -} +VideoStream &VideoStream::operator=(VideoStream &&other) noexcept = default; // --------------------- Public API --------------------- bool VideoStream::read(VideoFrameEvent &out) { - std::unique_lock lock(mutex_); - - cv_.wait(lock, [this] { return !queue_.empty() || eof_ || closed_; }); - - if (closed_ || (queue_.empty() && eof_)) { - return false; // EOS / closed - } - - out = std::move(queue_.front()); - queue_.pop_front(); - return true; + return impl_ ? impl_->read(out) : false; } void VideoStream::close() { - { - const std::scoped_lock lock(mutex_); - if (closed_) { - return; - } - closed_ = true; - } - - // Dispose FFI handle - if (stream_handle_.get() != 0) { - stream_handle_.reset(); - } - - // Remove listener - if (listener_id_ != 0) { - FfiClient::instance().RemoveListener(listener_id_); - listener_id_ = 0; + if (impl_) { + impl_->close(); } - - // Wake any waiting readers - cv_.notify_all(); } // --------------------- Internal helpers --------------------- void VideoStream::initFromTrack(const std::shared_ptr &track, const Options &options) { - capacity_ = options.capacity; - - // Subscribe to FFI events, this is essential to get video frames from FFI. - listener_id_ = FfiClient::instance().AddListener( - [this](const proto::FfiEvent &e) { this->onFfiEvent(e); }); - - // Send FFI request to create a new video stream bound to this track - FfiRequest req; - auto *new_video_stream = req.mutable_new_video_stream(); - new_video_stream->set_track_handle( - static_cast(track->ffi_handle_id())); - new_video_stream->set_type(proto::VideoStreamType::VIDEO_STREAM_NATIVE); - new_video_stream->set_normalize_stride(true); - new_video_stream->set_format(toProto(options.format)); - - auto resp = FfiClient::instance().sendRequest(req); - if (!resp.has_new_video_stream()) { - LK_LOG_ERROR( - "VideoStream::initFromTrack: FFI response missing new_video_stream()"); - throw std::runtime_error("new_video_stream FFI request failed"); - } - // Adjust field names to match your proto exactly: - const auto &stream = resp.new_video_stream().stream(); - stream_handle_ = FfiHandle(static_cast(stream.handle().id())); - // TODO, do we need to cache the metadata from stream.info ? + impl_->initFromTrack(track, options); } void VideoStream::initFromParticipant(Participant &participant, TrackSource track_source, const Options &options) { - capacity_ = options.capacity; - - // 1) Subscribe to FFI events - listener_id_ = FfiClient::instance().AddListener( - [this](const FfiEvent &e) { this->onFfiEvent(e); }); - - // 2) Send FFI request to create a video stream from participant + track - // source - FfiRequest req; - auto *vs = req.mutable_video_stream_from_participant(); - vs->set_participant_handle(participant.ffiHandleId()); - vs->set_type(proto::VideoStreamType::VIDEO_STREAM_NATIVE); - vs->set_track_source(static_cast(track_source)); - vs->set_normalize_stride(true); - vs->set_format(toProto(options.format)); - - auto resp = FfiClient::instance().sendRequest(req); - // Adjust field names to match your proto exactly: - const auto &stream = resp.video_stream_from_participant().stream(); - stream_handle_ = FfiHandle(static_cast(stream.handle().id())); -} - -void VideoStream::onFfiEvent(const proto::FfiEvent &event) { - // Filter for video_stream_event first. - if (event.message_case() != FfiEvent::kVideoStreamEvent) { - return; - } - const auto &vse = event.video_stream_event(); - // Check if this event is for our stream handle. - if (vse.stream_handle() != static_cast(stream_handle_.get())) { - return; - } - // Handle frame_received or eos. - if (vse.has_frame_received()) { - const auto &fr = vse.frame_received(); - - // Convert owned buffer->VideoFrame via a helper. - // You should implement this static function in your VideoFrame class. - VideoFrameEvent ev; - ev.frame = VideoFrame::fromOwnedInfo(fr.buffer()); - ev.timestamp_us = fr.timestamp_us(); - ev.rotation = static_cast(fr.rotation()); - if (fr.has_metadata()) { - ev.metadata = fromProto(fr.metadata()); - } - pushFrame(std::move(ev)); - } else if (vse.has_eos()) { - pushEos(); - } -} - -void VideoStream::pushFrame(VideoFrameEvent &&ev) { - { - const std::scoped_lock lock(mutex_); - - if (closed_ || eof_) { - return; - } - - if (capacity_ > 0 && queue_.size() >= capacity_) { - // Ring behavior: drop oldest frame. - queue_.pop_front(); - } - - queue_.push_back(std::move(ev)); - } - cv_.notify_one(); -} - -void VideoStream::pushEos() { - { - const std::scoped_lock lock(mutex_); - if (eof_) { - return; - } - eof_ = true; - } - cv_.notify_all(); + impl_->initFromParticipant(participant, track_source, options); } } // namespace livekit