From c0e7e24ee1022b8c55246ab25b05448ed237fad2 Mon Sep 17 00:00:00 2001 From: BartolomeyKant Date: Tue, 2 Jun 2026 18:53:45 +0500 Subject: [PATCH 1/2] new poller interface --- .../channels/ethernet_transport_factory.cpp | 4 +- aether/poller/epoll_poller.cpp | 66 +++++++++++----- aether/poller/epoll_poller.h | 21 +++-- aether/poller/freertos_poller.cpp | 45 +++++++++-- aether/poller/freertos_poller.h | 44 +++++++++-- aether/poller/kqueue_poller.cpp | 71 +++++++++++------ aether/poller/kqueue_poller.h | 19 +++-- aether/poller/poller.h | 12 ++- aether/poller/poller_types.h | 17 ++-- aether/poller/unix_poller.h | 78 +++++++++++++++++-- aether/poller/win_poller.cpp | 6 +- aether/poller/win_poller.h | 6 +- aether/serial_ports/unix_serial_port.cpp | 46 +++++------ aether/serial_ports/unix_serial_port.h | 9 +-- aether/serial_ports/win_serial_port.cpp | 2 +- aether/serial_ports/win_serial_port.h | 2 +- .../sockets/lwip_cb_tcp_socket.h | 2 - .../sockets/lwip_cb_udp_socket.h | 2 - .../system_sockets/sockets/lwip_socket.cpp | 62 ++++++++------- .../system_sockets/sockets/lwip_socket.h | 16 ++-- .../sockets/lwip_tcp_socket.cpp | 30 ++++--- .../system_sockets/sockets/lwip_tcp_socket.h | 8 +- .../sockets/lwip_udp_socket.cpp | 11 ++- .../system_sockets/sockets/lwip_udp_socket.h | 2 + .../system_sockets/sockets/unix_socket.cpp | 58 +++++++------- .../system_sockets/sockets/unix_socket.h | 20 ++--- .../sockets/unix_tcp_socket.cpp | 31 ++++---- .../system_sockets/sockets/unix_tcp_socket.h | 6 +- .../sockets/unix_udp_socket.cpp | 13 ++-- .../system_sockets/sockets/win_socket.cpp | 2 +- .../system_sockets/sockets/win_socket.h | 7 +- 31 files changed, 456 insertions(+), 262 deletions(-) diff --git a/aether/channels/ethernet_transport_factory.cpp b/aether/channels/ethernet_transport_factory.cpp index 942bdeba..4c376576 100644 --- a/aether/channels/ethernet_transport_factory.cpp +++ b/aether/channels/ethernet_transport_factory.cpp @@ -91,9 +91,9 @@ std::unique_ptr EthernetTransportFactory::BuildUdp( [[maybe_unused]] Ptr const& poller, [[maybe_unused]] Endpoint address_port_protocol) { # ifdef SYSTEM_SOCKET_UDP_TRANSPORT_ENABLED -# if LWIP_CB_TCP_SOCKET_ENABLED +# if LWIP_CB_UDP_SOCKET_ENABLED using SocketType = LwipCBUdpSocket; -# elif LWIP_TCP_SOCKET_ENABLED +# elif LWIP_UDP_SOCKET_ENABLED using SocketType = LwipUdpSocket; # elif UNIX_SOCKET_ENABLED using SocketType = UnixUdpSocket; diff --git a/aether/poller/epoll_poller.cpp b/aether/poller/epoll_poller.cpp index 11a1f733..fa90f057 100644 --- a/aether/poller/epoll_poller.cpp +++ b/aether/poller/epoll_poller.cpp @@ -67,9 +67,11 @@ EpollImpl::EpollImpl() thread_(&EpollImpl::Loop, this) { AE_TELE_INFO(kEpollWorkerCreate); + auto lock = std::scoped_lock{poller_mutex_}; // add wake up fd to epoll if (event_fd_ != -1) { - Event(DescriptorType{event_fd_}, EventType::kRead, [](auto) {}); + Callback(DescriptorType{event_fd_}, [](auto, auto) {}); + Event(DescriptorType{event_fd_}, EventType::kRead | EventType::kError); } } @@ -101,40 +103,62 @@ EpollImpl::~EpollImpl() { AE_TELE_INFO(kEpollWorkerDestroyed); } -void EpollImpl::Event(DescriptorType fd, EventType event, EventCb cb) { - AE_TELE_DEBUG(kEpollAddDescriptor, "Poller event fd:{} event:{}", fd, event); +void EpollImpl::lock() { poller_mutex_.lock(); } +void EpollImpl::unlock() { poller_mutex_.unlock(); } + +void EpollImpl::Callback(DescriptorType fd, EventCb cb) { + AE_TELE_DEBUG(kEpollAddDescriptor, "Poller callback for fd:{}", fd); + event_map_.emplace(fd, EventHandler{.cb = std::move(cb), .events = {}}); +} + +void EpollImpl::Event(DescriptorType fd, EventType events) { + AE_TELED_DEBUG("Poller event for fd:{} events: {}", fd, + static_cast(events)); + auto it = event_map_.find(fd); + if (it == event_map_.end()) { + assert(false && "Callback should setup first"); + return; + } + struct epoll_event epoll_event; - epoll_event.events = epoll_poller_internal::PollerEventsToEpol(event); + epoll_event.events = epoll_poller_internal::PollerEventsToEpol(events); // watch only edge triggered events epoll_event.events |= EPOLLET; epoll_event.data.fd = fd; - auto lock = std::scoped_lock{poller_mutex_}; - auto [_, new_event] = event_map_.insert_or_assign(fd, std::move(cb)); - int op = new_event ? EPOLL_CTL_ADD : EPOLL_CTL_MOD; + int op = (it->second.events == EventType{}) ? EPOLL_CTL_ADD : EPOLL_CTL_MOD; auto res = epoll_ctl(epoll_fd_, op, fd, &epoll_event); if (res < 0) { AE_TELE_ERROR(kEpollAddFailed, "Failed to add to epoll {} {}", errno, strerror(errno)); assert(false); } + it->second.events = events; } void EpollImpl::Remove(DescriptorType fd) { AE_TELE_DEBUG(kEpollRemoveDescriptor, "Remove poller event {}", fd); - auto lock = std::scoped_lock{poller_mutex_}; - event_map_.erase(fd); + auto it = event_map_.find(fd); + if (it == event_map_.end()) { + // nothing to remove + return; + } - struct epoll_event epoll_event{}; - auto res = epoll_ctl(epoll_fd_, EPOLL_CTL_DEL, fd, &epoll_event); - if (res < 0) { - if (errno != ENOENT) { - AE_TELE_ERROR(kEpollRemoveFailed, "Failed to remove from epoll {} {}", - errno, strerror(errno)); - assert(false); + if (it->second.events != EventType{}) { + // if events not empty, remove from epol_ctl also + struct epoll_event epoll_event{}; + auto res = epoll_ctl(epoll_fd_, EPOLL_CTL_DEL, fd, &epoll_event); + if (res < 0) { + if (errno != ENOENT) { + AE_TELE_ERROR(kEpollRemoveFailed, "Failed to remove from epoll {} {}", + errno, strerror(errno)); + assert(false); + } } } + + event_map_.erase(fd); } int EpollImpl::MakeEventFd() { @@ -161,7 +185,7 @@ int EpollImpl::InitEpoll() { } void EpollImpl::Loop() { - static constexpr auto kMaxEvents = 10; + static constexpr std::size_t kMaxEvents = 10; std::array events; while (!stop_requested_) { @@ -190,7 +214,7 @@ void EpollImpl::Loop() { continue; } auto ev_type = epoll_poller_internal::EpollEventToEventType(event.events); - poller_event->second(ev_type); + poller_event->second.cb(fd, ev_type); } } } @@ -199,11 +223,11 @@ EpollPoller::EpollPoller() = default; EpollPoller::EpollPoller(ObjProp prop) : IPoller{prop} {} -NativePoller* EpollPoller::Native() { +std::shared_ptr EpollPoller::Native() { if (!impl_) { - impl_.emplace(); + impl_ = std::make_shared(); } - return static_cast(&*impl_); + return impl_; } } // namespace ae diff --git a/aether/poller/epoll_poller.h b/aether/poller/epoll_poller.h index 68ef9037..22c2eeb2 100644 --- a/aether/poller/epoll_poller.h +++ b/aether/poller/epoll_poller.h @@ -23,21 +23,30 @@ # include # include # include -# include +# include # include "aether/poller/poller.h" # include "aether/poller/unix_poller.h" namespace ae { class EpollImpl final : public UnixPollerImpl { + struct EventHandler { + EventCb cb; + EventType events; + }; + public: EpollImpl(); ~EpollImpl() override; - void Event(DescriptorType fd, EventType event, EventCb cb) override; + private: + void lock() override; + void unlock() override; + + void Callback(DescriptorType fd, EventCb cb) override; + void Event(DescriptorType fd, EventType events) override; void Remove(DescriptorType fd) override; - private: static int InitEpoll(); static int MakeEventFd(); void EmptyWakeUpPipe(EventType event); @@ -46,7 +55,7 @@ class EpollImpl final : public UnixPollerImpl { int epoll_fd_; int event_fd_; std::recursive_mutex poller_mutex_; - std::map event_map_; + std::map event_map_; std::atomic_bool stop_requested_{false}; std::thread thread_; @@ -62,10 +71,10 @@ class EpollPoller : public IPoller { AE_OBJECT_REFLECT() - NativePoller* Native() override; + std::shared_ptr Native() override; private: - std::optional impl_; + std::shared_ptr impl_; }; } // namespace ae diff --git a/aether/poller/freertos_poller.cpp b/aether/poller/freertos_poller.cpp index 5a822c8c..25b8e50a 100644 --- a/aether/poller/freertos_poller.cpp +++ b/aether/poller/freertos_poller.cpp @@ -136,9 +136,12 @@ FreeRtosLwipPollerImpl::~FreeRtosLwipPollerImpl() { AE_TELE_DEBUG(kFreertosWorkerDestroyed, "Poll worker has been destroyed"); } +void FreeRtosLwipPollerImpl::lock() { ctl_mutex_.lock(); } + +void FreeRtosLwipPollerImpl::unlock() { ctl_mutex_.unlock(); } + void FreeRtosLwipPollerImpl::Event(DescriptorType fd, EventType event_type, EventCb cb) { - auto lock = std::scoped_lock{ctl_mutex_}; AE_TELE_DEBUG(kFreertosAddDescriptor, "Added descriptor {} event {}", fd, event_type); event_map_.insert_or_assign(fd, PollEvent{event_type, std::move(cb)}); @@ -147,7 +150,6 @@ void FreeRtosLwipPollerImpl::Event(DescriptorType fd, EventType event_type, } void FreeRtosLwipPollerImpl::Remove(DescriptorType fd) { - auto lock = std::scoped_lock{ctl_mutex_}; event_map_.erase(fd); freertos_poller_internal::WritePipe(wake_up_pipe_); AE_TELE_DEBUG(kFreertosRemoveDescriptor, "Removed descriptor {}", fd); @@ -190,7 +192,7 @@ void FreeRtosLwipPollerImpl::Loop() { continue; } poll_event->second.cb( - freertos_poller_internal::FromEpollEvent(v.revents)); + v.fd, freertos_poller_internal::FromEpollEvent(v.revents)); } } } @@ -242,15 +244,46 @@ std::vector FreeRtosLwipPollerImpl::FillFdsVector() { return fds; } +FreeRtosPolledFd::FreeRtosPolledFd(DescriptorType fd, + std::shared_ptr const& poller) + : fd_{fd}, + poller_{std::static_pointer_cast(poller)} {} + +FreeRtosPolledFd::~FreeRtosPolledFd() { + if (fd_ != kInvalidDescriptor) { + auto lock = std::scoped_lock(*poller_); + poller_->Remove(fd_); + } +} + +void FreeRtosPolledFd::Event(EventType event_type, + FreeRtosLwipPollerImpl::EventCb event_cb) { + auto lock = std::scoped_lock(*poller_); + poller_->Event(fd_, event_type, std::move(event_cb)); +} + +FreeRtosPolledFd::Fd FreeRtosPolledFd::fd() const noexcept { + return Fd{std::unique_lock{*poller_}, fd_}; +} + +FreeRtosPolledFd::Fd FreeRtosPolledFd::Remove() noexcept { + auto fd = Fd{std::unique_lock{*poller_}, fd_}; + if (fd_ != kInvalidDescriptor) { + poller_->Remove(fd_); + fd_ = kInvalidDescriptor; + } + return fd; +} + FreertosPoller::FreertosPoller() = default; FreertosPoller::FreertosPoller(ObjProp prop) : IPoller{prop} {} -NativePoller* FreertosPoller::Native() { +std::shared_ptr FreertosPoller::Native() { if (!impl_) { - impl_.emplace(); + impl_ = std::make_shared(); } - return static_cast(&*impl_); + return impl_; } } // namespace ae diff --git a/aether/poller/freertos_poller.h b/aether/poller/freertos_poller.h index 75ae523b..76c612a2 100644 --- a/aether/poller/freertos_poller.h +++ b/aether/poller/freertos_poller.h @@ -27,7 +27,7 @@ # include # include # include -# include +# include # include "aether/poller/poller.h" # include "aether/poller/poller_types.h" @@ -38,7 +38,7 @@ class FreeRtosLwipPollerImpl : public NativePoller { friend void vTaskFunction(void* pvParameters); public: - using EventCb = SmallFunction; + using EventCb = SmallFunction; struct PollEvent { EventType event_type; @@ -46,9 +46,12 @@ class FreeRtosLwipPollerImpl : public NativePoller { }; FreeRtosLwipPollerImpl(); - ~FreeRtosLwipPollerImpl(); + ~FreeRtosLwipPollerImpl() override; - void Event(DescriptorType fd, EventType event_type, EventCb cb); + void lock(); + void unlock(); + + void Event(DescriptorType fd, EventType event_type, EventCb event_cb); void Remove(DescriptorType fd); private: @@ -62,6 +65,35 @@ class FreeRtosLwipPollerImpl : public NativePoller { std::recursive_mutex ctl_mutex_; }; +class FreeRtosPolledFd { + public: + class Fd { + public: + Fd(std::unique_lock&& lock, + DescriptorType fd) noexcept + : lock_{std::move(lock)}, fd_{fd} {} + + DescriptorType operator*() const noexcept { return fd_; } + + private: + std::unique_lock lock_; + DescriptorType fd_; + }; + + FreeRtosPolledFd(DescriptorType fd, + std::shared_ptr const& poller); + + ~FreeRtosPolledFd(); + + void Event(EventType event_type, FreeRtosLwipPollerImpl::EventCb event_cb); + Fd fd() const noexcept; + Fd Remove() noexcept; + + private: + DescriptorType fd_; + mutable std::shared_ptr poller_; +}; + class FreertosPoller : public IPoller { AE_OBJECT(FreertosPoller, IPoller, 0) @@ -72,10 +104,10 @@ class FreertosPoller : public IPoller { AE_OBJECT_REFLECT() - NativePoller* Native() override; + std::shared_ptr Native() override; private: - std::optional impl_; + std::shared_ptr impl_; }; } // namespace ae diff --git a/aether/poller/kqueue_poller.cpp b/aether/poller/kqueue_poller.cpp index 09ae84d7..502fc499 100644 --- a/aether/poller/kqueue_poller.cpp +++ b/aether/poller/kqueue_poller.cpp @@ -119,36 +119,59 @@ KqueuePollerImpl::~KqueuePollerImpl() { AE_TELE_DEBUG(kKqueueWorkerDestroyed); } -void KqueuePollerImpl::Event(DescriptorType fd, EventType event, EventCb cb) { - AE_TELE_DEBUG(kKqueueAddDescriptor, "Add descriptor {} event {}", fd, event); - auto lock = std::scoped_lock{poller_mutex_}; - event_map_[fd] = std::move(cb); +void KqueuePollerImpl::lock() { poller_mutex_.lock(); } - std::array events; +void KqueuePollerImpl::unlock() { poller_mutex_.unlock(); } + +void KqueuePollerImpl::Callback(DescriptorType fd, EventCb cb) { + AE_TELE_DEBUG(kKqueueAddDescriptor, "Add descriptor {}", fd); + event_map_[fd] = EventHandler{.cb = std::move(cb), .events = EventType{}}; +} + +void KqueuePollerImpl::Event(DescriptorType fd, EventType events) { + AE_TELED_DEBUG("Add events for {} event {}", fd, events); + + auto it = event_map_.find(fd); + if (it == event_map_.end()) { + assert(false && "Callback should be setup first"); + return; + } + + std::array kqueu_events; auto count = kqueue_poller_internal::FillKQueueFilter( - events.data(), events.size(), fd, event); - auto res = kevent(kqueue_fd_, events.data(), count, nullptr, 0, nullptr); + kqueu_events.data(), kqueu_events.size(), fd, events); + auto res = + kevent(kqueue_fd_, kqueu_events.data(), count, nullptr, 0, nullptr); if (res == -1) { AE_TELE_ERROR(kKqueueAddFailed, "Add event with error {} {}", errno, strerror(errno)); assert(false); } + it->second.events = events; } void KqueuePollerImpl::Remove(DescriptorType fd) { AE_TELE_DEBUG(kKqueueRemoveDescriptor, "Remove event descriptor {}", fd); - auto lock = std::scoped_lock{poller_mutex_}; - event_map_.erase(fd); + auto it = event_map_.find(fd); + if (it == event_map_.end()) { + return; + } - // remove all kind of events - std::array events; - auto count = kqueue_poller_internal::FillKQueueFilter( - events.data(), events.size(), fd, EventType::kRead | EventType::kWrite); - auto res = kevent(kqueue_fd_, events.data(), count, nullptr, 0, nullptr); - if (res == -1) { - AE_TELE_ERROR(kKqueueRemoveFailed, "Remove event error {} {}", errno, - strerror(errno)); + if (it->second.events != EventType{}) { + // remove all kind of events + std::array kqueu_events; + auto count = kqueue_poller_internal::FillKQueueFilter( + kqueu_events.data(), kqueu_events.size(), fd, + EventType::kRead | EventType::kWrite); + auto res = + kevent(kqueue_fd_, kqueu_events.data(), count, nullptr, 0, nullptr); + if (res == -1) { + AE_TELE_ERROR(kKqueueRemoveFailed, "Remove event error {} {}", errno, + strerror(errno)); + } } + + event_map_.erase(fd); } int KqueuePollerImpl::InitKqueue() { @@ -183,13 +206,13 @@ void KqueuePollerImpl::Loop() { // user event continue; } - auto poller_event = - event_map_.find(DescriptorType{static_cast(ev.ident)}); + auto fd = DescriptorType{static_cast(ev.ident)}; + auto poller_event = event_map_.find(fd); if (poller_event == event_map_.end()) { continue; } - poller_event->second( - kqueue_poller_internal::FilterTypeToEventType(ev.filter)); + poller_event->second.cb( + fd, kqueue_poller_internal::FilterTypeToEventType(ev.filter)); } } } @@ -198,11 +221,11 @@ KqueuePoller::KqueuePoller() = default; KqueuePoller::KqueuePoller(ObjProp prop) : IPoller{prop} {} -NativePoller* KqueuePoller::Native() { +std::shared_ptr KqueuePoller::Native() { if (!impl_) { - impl_.emplace(); + impl_ = std::make_shared(); } - return static_cast(&*impl_); + return impl_; } } // namespace ae diff --git a/aether/poller/kqueue_poller.h b/aether/poller/kqueue_poller.h index 6495b9bc..4ecc3121 100644 --- a/aether/poller/kqueue_poller.h +++ b/aether/poller/kqueue_poller.h @@ -30,14 +30,23 @@ namespace ae { class KqueuePollerImpl : public UnixPollerImpl { + struct EventHandler { + EventCb cb; + EventType events; + }; + public: KqueuePollerImpl(); ~KqueuePollerImpl() override; - void Event(DescriptorType fd, EventType event, EventCb cb) override; + private: + void lock() override; + void unlock() override; + + void Callback(DescriptorType fd, EventCb cb) override; + void Event(DescriptorType fd, EventType events) override; void Remove(DescriptorType descriptor) override; - private: static int InitKqueue(); void Loop(); @@ -46,7 +55,7 @@ class KqueuePollerImpl : public UnixPollerImpl { int kqueue_fd_; std::recursive_mutex poller_mutex_; - std::map event_map_; + std::map event_map_; std::atomic_bool stop_requested_{false}; std::thread thread_; }; @@ -61,10 +70,10 @@ class KqueuePoller : public IPoller { AE_OBJECT_REFLECT() - NativePoller* Native() override; + std::shared_ptr Native() override; private: - std::optional impl_; + std::shared_ptr impl_; }; } // namespace ae diff --git a/aether/poller/poller.h b/aether/poller/poller.h index eaea3e47..fb3d2185 100644 --- a/aether/poller/poller.h +++ b/aether/poller/poller.h @@ -17,10 +17,18 @@ #ifndef AETHER_POLLER_POLLER_H_ #define AETHER_POLLER_POLLER_H_ +#include + #include "aether/obj/obj.h" namespace ae { -class NativePoller {}; +/** + * \brief Base pointer for platform native poller implementation + */ +class NativePoller { + public: + virtual ~NativePoller() = default; +}; class IPoller : public Obj { AE_OBJECT(IPoller, Obj, 0) @@ -36,7 +44,7 @@ class IPoller : public Obj { /** * \brief Return native poller implementation. */ - virtual NativePoller* Native() = 0; + virtual std::shared_ptr Native() = 0; }; } // namespace ae diff --git a/aether/poller/poller_types.h b/aether/poller/poller_types.h index 7a299572..99ed0a55 100644 --- a/aether/poller/poller_types.h +++ b/aether/poller/poller_types.h @@ -28,7 +28,7 @@ struct EventType { static constexpr std::uint8_t kError = 0x4; EventType() = default; - EventType(std::uint8_t v) : value{v} {} + EventType(std::uint8_t v) : value{v} {} // NOLINT(*explicit-constructor) std::uint8_t operator&(std::uint8_t other) const { return value & other; } std::uint8_t operator|(std::uint8_t other) const { return value | other; } @@ -44,9 +44,9 @@ struct EventType { struct DescriptorType { // Add our own defines to prevent windows.h in public header -#if defined _WIN32 +#ifdef _WIN32 using Handle = void*; -# if defined _WIN64 +# ifdef _WIN64 using Socket = std::uint64_t; # else using Socket = std::uint32_t; @@ -74,14 +74,21 @@ struct DescriptorType { Handle descriptor; #else - DescriptorType(int des) : descriptor{des} {} + DescriptorType(int des) : descriptor{des} {} // NOLINT(*explicit-constructor) - operator int() const { return descriptor; } + operator int() const { return descriptor; } // NOLINT(*explicit-constructor) int descriptor; #endif }; +#ifdef _WIN32 +static constexpr auto kInvalidDescriptor = + static_cast(~0); +#else +static constexpr auto kInvalidDescriptor = -1; +#endif + template <> struct Formatter { template diff --git a/aether/poller/unix_poller.h b/aether/poller/unix_poller.h index 665d9a43..d43e0129 100644 --- a/aether/poller/unix_poller.h +++ b/aether/poller/unix_poller.h @@ -20,6 +20,9 @@ #if defined(__linux__) || defined(__unix__) || defined(__APPLE__) || \ defined(__FreeBSD__) +# include +# include + # include "aether/poller/poller.h" # include "aether/poller/poller_types.h" # include "aether/types/small_function.h" @@ -30,23 +33,84 @@ namespace ae { */ class UnixPollerImpl : public NativePoller { public: - using EventCb = SmallFunction; + using EventCb = SmallFunction; + + private: + friend class UnixPolledFd; + friend class std::scoped_lock; + friend class std::unique_lock; - virtual ~UnixPollerImpl() = default; + virtual void lock() = 0; + virtual void unlock() = 0; /** - * \brief Add or change file descriptor to the poller. - * fd - file descriptor to add or change. - * event - event to set. It contains file descriptor and ORed event list to - * wait for. + * \brief Add file descriptor to the poller with event callback. + * fd - file descriptor to add. * cb - callback to call when event occurs. */ - virtual void Event(DescriptorType fd, EventType event, EventCb cb) = 0; + virtual void Callback(DescriptorType fd, EventCb cb) = 0; + /** + * \brief Setup event poller for file descriptor + */ + virtual void Event(DescriptorType fd, EventType events) = 0; /** * \brief Remove file descriptor from the poller. */ virtual void Remove(DescriptorType descriptor) = 0; }; + +class UnixPolledFd { + public: + class Fd { + public: + Fd(std::unique_lock&& lock, DescriptorType fd) noexcept + : lock_{std::move(lock)}, fd_{fd} {} + + DescriptorType operator*() const noexcept { return fd_; } + + private: + std::unique_lock lock_; + DescriptorType fd_; + }; + + UnixPolledFd(DescriptorType fd, std::shared_ptr const& poller, + UnixPollerImpl::EventCb cb) + : fd_{fd}, poller_{std::static_pointer_cast(poller)} { + auto lock = std::scoped_lock{*poller_}; + poller_->Callback(fd_, std::move(cb)); + } + + ~UnixPolledFd() { + auto lock = std::scoped_lock{*poller_}; + if (fd_ != kInvalidDescriptor) { + poller_->Remove(fd_); + } + } + + void Events(EventType events) { + auto lock = std::scoped_lock{*poller_}; + poller_->Event(fd_, events); + } + + auto fd() const noexcept { return Fd{std::unique_lock{*poller_}, fd_}; } + + /** + * \brief Use Remove to remove fd from the poller and return the descriptor + */ + auto Remove() noexcept { + auto fd = Fd{std::unique_lock{*poller_}, fd_}; + if (fd_ != kInvalidDescriptor) { + poller_->Remove(fd_); + fd_ = kInvalidDescriptor; + } + return fd; + } + + private: + DescriptorType fd_; + mutable std::shared_ptr poller_; +}; + } // namespace ae #endif diff --git a/aether/poller/win_poller.cpp b/aether/poller/win_poller.cpp index f2171648..f79f8b0b 100644 --- a/aether/poller/win_poller.cpp +++ b/aether/poller/win_poller.cpp @@ -114,11 +114,11 @@ WinPoller::WinPoller() = default; WinPoller::WinPoller(ObjProp prop) : IPoller{prop} {} WinPoller::~WinPoller() = default; -NativePoller* WinPoller::Native() { +std::shared_ptr WinPoller::Native() { if (!impl_) { - impl_.emplace(); + impl_ = std::make_shared(); } - return static_cast(&*impl_); + return impl_; } } // namespace ae #endif diff --git a/aether/poller/win_poller.h b/aether/poller/win_poller.h index 23a3bf54..692ece84 100644 --- a/aether/poller/win_poller.h +++ b/aether/poller/win_poller.h @@ -31,7 +31,7 @@ # include # include # include -# include +# include # include "aether/poller/poller.h" # include "aether/poller/poller_types.h" @@ -72,10 +72,10 @@ class WinPoller : public IPoller { AE_OBJECT_REFLECT() - NativePoller* Native() override; + std::shared_ptr Native() override; private: - std::optional impl_; + std::shared_ptr impl_; }; } // namespace ae #endif diff --git a/aether/serial_ports/unix_serial_port.cpp b/aether/serial_ports/unix_serial_port.cpp index 2de1e5e3..c532b435 100644 --- a/aether/serial_ports/unix_serial_port.cpp +++ b/aether/serial_ports/unix_serial_port.cpp @@ -26,20 +26,25 @@ # include "aether/serial_ports/serial_ports_tele.h" namespace ae { -static constexpr int kInvalidPort = -1; - UnixSerialPort::UnixSerialPort(AeContext const& ae_context, SerialInit serial_init, IPoller::ptr const& poller) : ae_context_{ae_context}, serial_init_{std::move(serial_init)}, - poller_{static_cast(poller->Native())}, - fd_{OpenPort(serial_init_)} {} + poller_fd_{OpenPort(serial_init_), poller->Native(), + MethodPtr<&UnixSerialPort::PolleEvent>{this}} { + poller_fd_.Events(EventType::kRead | EventType::kError); +} -UnixSerialPort::~UnixSerialPort() { Close(); } +UnixSerialPort::~UnixSerialPort() { + auto fd = poller_fd_.Remove(); + if (*fd != kInvalidDescriptor) { + close(*fd); + } +} void UnixSerialPort::Write(std::span data) { - auto bytes_written = write(fd_, data.data(), data.size()); + auto bytes_written = write(*poller_fd_.fd(), data.data(), data.size()); if (bytes_written < 0) { AE_TELE_ERROR(kAdapterSerialWriteFailed, "Write serial port error {}", strerror(errno)); @@ -55,7 +60,7 @@ UnixSerialPort::DataReadEvent::Subscriber UnixSerialPort::read_event() { return EventSubscriber{read_event_}; } -bool UnixSerialPort::IsOpen() { return fd_ != kInvalidPort; } +bool UnixSerialPort::IsOpen() { return *poller_fd_.fd() != kInvalidDescriptor; } int UnixSerialPort::OpenPort(SerialInit const& serial_init) { /* open the port */ @@ -63,12 +68,12 @@ int UnixSerialPort::OpenPort(SerialInit const& serial_init) { if (fd < 0) { AE_TELED_ERROR("Open serial port at {} error {}", serial_init.port_name, strerror(errno)); - return kInvalidPort; + return kInvalidDescriptor; } auto close_on_exit = ae_defer_at[&] { close(fd); }; if (!SetOptions(fd, serial_init)) { - return kInvalidPort; + return kInvalidDescriptor; } close_on_exit.Reset(); return fd; @@ -106,22 +111,20 @@ bool UnixSerialPort::SetOptions(int fd, SerialInit const& serial_init) { return true; } -void UnixSerialPort::PolleEvent(EventType event) { +void UnixSerialPort::PolleEvent(DescriptorType fd, EventType event) { auto event_type = event & EventType::kRead; switch (event_type) { case EventType::kRead: - ReadData(); + ReadData(fd); break; default: break; } } -void UnixSerialPort::ReadData() { - auto lock = std::scoped_lock{fd_lock_}; - - static std::uint8_t buffer[1024]; - ssize_t bytes_read = read(fd_, buffer, sizeof(buffer)); +void UnixSerialPort::ReadData(DescriptorType fd) { + static std::uint8_t buffer[1024]; // NOLINT(*avoid-c-arrays, *magic-numbers) + ssize_t bytes_read = read(fd, buffer, sizeof(buffer)); if (bytes_read < 0) { AE_TELED_ERROR("Read serial port error {}", strerror(errno)); return; @@ -129,11 +132,12 @@ void UnixSerialPort::ReadData() { DataBuffer data(static_cast(bytes_read)); std::copy(buffer, buffer + bytes_read, data.begin()); + auto lock = std::scoped_lock{buffers_lock_}; buffers_.emplace_back(std::move(data)); if (!read_flag_.exchange(true)) { scheduler_sub_ = ae_context_.scheduler().Task([this]() noexcept { - auto lock = std::scoped_lock{fd_lock_}; + auto lock = std::scoped_lock{buffers_lock_}; EmitData(); read_flag_ = false; }); @@ -146,14 +150,6 @@ void UnixSerialPort::EmitData() { } buffers_.clear(); } - -void UnixSerialPort::Close() { - if (fd_ != kInvalidPort) { - close(fd_); - fd_ = kInvalidPort; - } -} - } // namespace ae #endif diff --git a/aether/serial_ports/unix_serial_port.h b/aether/serial_ports/unix_serial_port.h index 2bb537b1..8130f87b 100644 --- a/aether/serial_ports/unix_serial_port.h +++ b/aether/serial_ports/unix_serial_port.h @@ -50,20 +50,19 @@ class UnixSerialPort final : public ISerialPort { static int OpenPort(SerialInit const& serial_init); static bool SetOptions(int fd, SerialInit const& serial_init); - void PolleEvent(EventType event); - void ReadData(); + void PolleEvent(DescriptorType fd, EventType event); + void ReadData(DescriptorType fd); void EmitData(); void Close(); AeContext ae_context_; SerialInit serial_init_; - UnixPollerImpl* poller_; + UnixPolledFd poller_fd_; - std::mutex fd_lock_; - int fd_; DataReadEvent read_event_; + std::mutex buffers_lock_; std::list buffers_; std::atomic_bool read_flag_; TaskSubscription scheduler_sub_; diff --git a/aether/serial_ports/win_serial_port.cpp b/aether/serial_ports/win_serial_port.cpp index d867d128..f4eb6e2d 100644 --- a/aether/serial_ports/win_serial_port.cpp +++ b/aether/serial_ports/win_serial_port.cpp @@ -27,7 +27,7 @@ WinSerialPort::WinSerialPort(AeContext const& ae_context, SerialInit serial_init, IPoller::ptr const& poller) : ae_context_{ae_context}, serial_init_{std::move(serial_init)}, - poller_{static_cast(poller->Native())}, + poller_{std::static_pointer_cast(poller->Native())}, fd_{OpenPort(serial_init_)}, read_buffer_(kReadBufSize) { if (fd_ != INVALID_HANDLE_VALUE) { diff --git a/aether/serial_ports/win_serial_port.h b/aether/serial_ports/win_serial_port.h index 4c0ae50e..517c015e 100644 --- a/aether/serial_ports/win_serial_port.h +++ b/aether/serial_ports/win_serial_port.h @@ -60,7 +60,7 @@ class WinSerialPort final : public ISerialPort { AeContext ae_context_; SerialInit serial_init_; - IoCpPoller* poller_; + std::shared_ptr poller_; std::mutex fd_lock_; void* fd_; diff --git a/aether/transport/system_sockets/sockets/lwip_cb_tcp_socket.h b/aether/transport/system_sockets/sockets/lwip_cb_tcp_socket.h index 06c15b3a..0e4304f5 100644 --- a/aether/transport/system_sockets/sockets/lwip_cb_tcp_socket.h +++ b/aether/transport/system_sockets/sockets/lwip_cb_tcp_socket.h @@ -37,8 +37,6 @@ class IPoller; */ class LwipCBTcpSocket : public ISocket { public: - static constexpr int kInvalidSocket = -1; - explicit LwipCBTcpSocket(Ptr const& poller); ~LwipCBTcpSocket() override; diff --git a/aether/transport/system_sockets/sockets/lwip_cb_udp_socket.h b/aether/transport/system_sockets/sockets/lwip_cb_udp_socket.h index 8ccffcff..7ccf8020 100644 --- a/aether/transport/system_sockets/sockets/lwip_cb_udp_socket.h +++ b/aether/transport/system_sockets/sockets/lwip_cb_udp_socket.h @@ -37,8 +37,6 @@ class IPoller; */ class LwipCBUdpSocket : public ISocket { public: - static constexpr int kInvalidSocket = -1; - explicit LwipCBUdpSocket(Ptr const& poller); ~LwipCBUdpSocket() override; diff --git a/aether/transport/system_sockets/sockets/lwip_socket.cpp b/aether/transport/system_sockets/sockets/lwip_socket.cpp index 7aede961..e4ebb3ac 100644 --- a/aether/transport/system_sockets/sockets/lwip_socket.cpp +++ b/aether/transport/system_sockets/sockets/lwip_socket.cpp @@ -24,8 +24,7 @@ namespace ae { LwipSocket::LwipSocket(IPoller& poller, int socket) - : poller_{static_cast(poller.Native())}, - socket_{socket} {} + : socket_{std::in_place, socket, poller.Native()} {} LwipSocket::~LwipSocket() { Disconnect(); } @@ -45,16 +44,17 @@ ISocket& LwipSocket::Error(ErrorCb error_cb) { } std::optional LwipSocket::Send(Span data) { - auto lock = std::scoped_lock{socket_lock_}; + if (!socket_) { + return std::nullopt; + } auto size_to_send = data.size(); // add nosignal to prevent throw SIGPIPE and handle it manually int flags = MSG_NOSIGNAL; - auto res = send(socket_, data.data(), size_to_send, flags); + auto res = send(*socket_->fd(), data.data(), size_to_send, flags); if (res == -1) { if ((errno == EAGAIN) || (errno == EWOULDBLOCK)) { // add wait for kWrite - poller_->Event(socket_, - EventType::kRead | EventType::kWrite | EventType::kError, + socket_->Event(EventType::kRead | EventType::kWrite | EventType::kError, MethodPtr<&LwipSocket::OnPollerEvent>{this}); } @@ -69,34 +69,35 @@ std::optional LwipSocket::Send(Span data) { } void LwipSocket::Disconnect() { - auto lock = std::scoped_lock{socket_lock_}; - if (socket_ == kInvalidSocket) { + if (!socket_) { return; } - auto s = socket_; - socket_ = kInvalidSocket; - - poller_->Remove(s); - shutdown(s, SHUT_RDWR); - if (close(s) != 0) { - return; + { + auto s = socket_->Remove(); + shutdown(*s, SHUT_RDWR); + if (close(*s) != 0) { + return; + } } + socket_.reset(); } void LwipSocket::Poll() { - poller_->Event(socket_, EventType::kRead | EventType::kError, - MethodPtr<&LwipSocket::OnPollerEvent>{this}); + if (socket_) { + socket_->Event(EventType::kRead | EventType::kError, + MethodPtr<&LwipSocket::OnPollerEvent>{this}); + } } -void LwipSocket::OnPollerEvent(EventType event) { - AE_TELED_DEBUG("Poll event desc={},event={}", socket_, event); +void LwipSocket::OnPollerEvent(DescriptorType fd, EventType event) { + AE_TELED_DEBUG("Poll event desc={}, event={}", fd, event); for (auto e : {EventType::kRead, EventType::kWrite, EventType::kError}) { if ((event & e) == 0) { continue; } switch (e) { case ae::EventType::kRead: - OnReadEvent(); + OnReadEvent(fd); break; case ae::EventType::kWrite: OnWriteEvent(); @@ -110,11 +111,11 @@ void LwipSocket::OnPollerEvent(EventType event) { } } -void LwipSocket::OnReadEvent() { +void LwipSocket::OnReadEvent(DescriptorType fd) { // read all data while (true) { auto buffer = Span{recv_buffer_.data(), recv_buffer_.size()}; - auto res = Receive(buffer); + auto res = Receive(fd, buffer); if (!res) { OnErrorEvent(); return; @@ -126,6 +127,9 @@ void LwipSocket::OnReadEvent() { buffer = buffer.sub(0, *res); if (recv_data_cb_) { recv_data_cb_(buffer); + } else { + printf("fd %d, Received bytes=%zu but no callback set\n", + static_cast(fd), *res); } return; } @@ -145,10 +149,9 @@ void LwipSocket::OnErrorEvent() { } } -std::optional LwipSocket::Receive(Span buffer) { - auto lock = std::scoped_lock{socket_lock_}; - - auto res = recv(socket_, buffer.data(), buffer.size(), 0); +std::optional LwipSocket::Receive(DescriptorType fd, + Span buffer) { + auto res = recv(fd, buffer.data(), buffer.size(), 0); if (res < 0) { // No data if ((errno == EWOULDBLOCK) || (errno == EAGAIN)) { @@ -166,13 +169,12 @@ std::optional LwipSocket::Receive(Span buffer) { return static_cast(res); } -std::optional LwipSocket::GetSocketError() { - auto lock = std::scoped_lock{socket_lock_}; +std::optional LwipSocket::GetSocketError(DescriptorType fd) { int err{}; socklen_t len = sizeof(len); - if (getsockopt(socket_, SOL_SOCKET, SO_ERROR, static_cast(&err), - &len) != 0) { + if (getsockopt(fd, SOL_SOCKET, SO_ERROR, static_cast(&err), &len) != + 0) { AE_TELED_ERROR("Getsockopt error: {}, {}", static_cast(errno), strerror(errno)); return std::nullopt; diff --git a/aether/transport/system_sockets/sockets/lwip_socket.h b/aether/transport/system_sockets/sockets/lwip_socket.h index a0740330..72b5e70b 100644 --- a/aether/transport/system_sockets/sockets/lwip_socket.h +++ b/aether/transport/system_sockets/sockets/lwip_socket.h @@ -26,7 +26,6 @@ # include "aether/poller/poller.h" # include "aether/types/data_buffer.h" # include "aether/poller/freertos_poller.h" -# include "aether/events/event_subscription.h" # include "aether/transport/system_sockets/sockets/isocket.h" namespace ae { @@ -35,8 +34,6 @@ namespace ae { */ class LwipSocket : public ISocket { public: - static constexpr int kInvalidSocket = -1; - explicit LwipSocket(IPoller& poller, int socket); ~LwipSocket() override; @@ -49,19 +46,18 @@ class LwipSocket : public ISocket { protected: void Poll(); - void OnPollerEvent(EventType event); + virtual void OnPollerEvent(DescriptorType fd, EventType event); - void OnReadEvent(); + void OnReadEvent(DescriptorType fd); void OnWriteEvent(); void OnErrorEvent(); - std::optional Receive(Span buffer); + std::optional Receive(DescriptorType fd, + Span buffer); - std::optional GetSocketError(); + std::optional GetSocketError(DescriptorType fd); - FreeRtosLwipPollerImpl* poller_; - int socket_; - std::mutex socket_lock_; + std::optional socket_; ReadyToWriteCb ready_to_write_cb_; RecvDataCb recv_data_cb_; diff --git a/aether/transport/system_sockets/sockets/lwip_tcp_socket.cpp b/aether/transport/system_sockets/sockets/lwip_tcp_socket.cpp index b8537994..7310fc2f 100644 --- a/aether/transport/system_sockets/sockets/lwip_tcp_socket.cpp +++ b/aether/transport/system_sockets/sockets/lwip_tcp_socket.cpp @@ -87,7 +87,7 @@ int LwipTcpSocket::MakeSocket() { if (sock < 0) { AE_TELED_ERROR("LwIp TCP socket creation error {} {}", errno, strerror(errno)); - return kInvalidSocket; + return kInvalidDescriptor; } // close the socket if not created @@ -98,18 +98,18 @@ int LwipTcpSocket::MakeSocket() { }; if (!lwip_tcp_socket_internal::SetNonblocking(sock)) { - return kInvalidSocket; + return kInvalidDescriptor; } if (!lwip_tcp_socket_internal::SetTcpNoDelay(sock, on)) { - return kInvalidSocket; + return kInvalidDescriptor; } if (!lwip_tcp_socket_internal::SetReuseAddress(sock, on)) { - return kInvalidSocket; + return kInvalidDescriptor; } if (!lwip_tcp_socket_internal::SetReciveTimeouts(sock, kRcvTimeoutSec, kRcvTimeoutUsec)) { - return kInvalidSocket; + return kInvalidDescriptor; } AE_TELED_DEBUG("LwIp TCP socket created"); @@ -119,21 +119,19 @@ int LwipTcpSocket::MakeSocket() { ISocket& LwipTcpSocket::Connect(AddressPort const& destination, ConnectedCb connected_cb) { - assert((socket_ != kInvalidSocket) && "Socket is not initialized"); + assert(socket_ && "Socket is not initialized"); connected_cb_ = std::move(connected_cb); ae_defer[&]() { // wait for all events to detect connection - poller_->Event(socket_, - EventType::kRead | EventType::kWrite | EventType::kError, + socket_->Event(EventType::kRead | EventType::kWrite | EventType::kError, MethodPtr<&LwipTcpSocket::OnPollerEvent>{this}); connected_cb_(connection_state_); }; - auto lock = std::scoped_lock{socket_lock_}; - auto addr = GetSockAddr(destination); - auto res = connect(socket_, addr.addr(), static_cast(addr.size)); + auto res = + connect(*socket_->fd(), addr.addr(), static_cast(addr.size)); if (res == -1) { if ((errno == EAGAIN) || (errno == EINPROGRESS)) { AE_TELED_DEBUG("Wait connection"); @@ -151,15 +149,15 @@ ISocket& LwipTcpSocket::Connect(AddressPort const& destination, return *this; } -void LwipTcpSocket::OnPollerEvent(EventType event) { +void LwipTcpSocket::OnPollerEvent(DescriptorType fd, EventType event) { if (connection_state_ == ConnectionState::kConnecting) { - OnConnectionEvent(); + OnConnectionEvent(fd); return; } - LwipSocket::OnPollerEvent(event); + LwipSocket::OnPollerEvent(fd, event); } -void LwipTcpSocket::OnConnectionEvent() { +void LwipTcpSocket::OnConnectionEvent(DescriptorType fd) { ae_defer[&]() { if (connected_cb_) { AE_TELED_DEBUG("LwIp TCP socket connectioin event {}", connection_state_); @@ -168,7 +166,7 @@ void LwipTcpSocket::OnConnectionEvent() { }; // check socket status - auto sock_err = GetSocketError(); + auto sock_err = GetSocketError(fd); if (!sock_err || (sock_err.value() != 0)) { AE_TELED_ERROR("Connect error {}, {}", sock_err, strerror(sock_err.value_or(0))); diff --git a/aether/transport/system_sockets/sockets/lwip_tcp_socket.h b/aether/transport/system_sockets/sockets/lwip_tcp_socket.h index 5231415d..d9cdb679 100644 --- a/aether/transport/system_sockets/sockets/lwip_tcp_socket.h +++ b/aether/transport/system_sockets/sockets/lwip_tcp_socket.h @@ -23,6 +23,8 @@ #if AE_SUPPORT_TCP && LWIP_SOCKET_ENABLED # include "aether/poller/poller.h" +# define LWIP_TCP_SOCKET_ENABLED 1 + namespace ae { class LwipTcpSocket final : public LwipSocket { static constexpr int kRcvTimeoutSec = 0; @@ -34,11 +36,13 @@ class LwipTcpSocket final : public LwipSocket { ISocket& Connect(AddressPort const& destination, ConnectedCb connected_cb) override; + protected: + void OnPollerEvent(DescriptorType fd, EventType event) override; + private: static int MakeSocket(); - void OnPollerEvent(EventType event); - void OnConnectionEvent(); + void OnConnectionEvent(DescriptorType fd); ConnectionState connection_state_; ConnectedCb connected_cb_; diff --git a/aether/transport/system_sockets/sockets/lwip_udp_socket.cpp b/aether/transport/system_sockets/sockets/lwip_udp_socket.cpp index 41c981ce..63e70ceb 100644 --- a/aether/transport/system_sockets/sockets/lwip_udp_socket.cpp +++ b/aether/transport/system_sockets/sockets/lwip_udp_socket.cpp @@ -38,7 +38,7 @@ int LwipUdpSocket::MakeSocket() { auto sock = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP); if (sock < 0) { AE_TELED_ERROR("LwIp UDP socket not created {}", strerror(errno)); - return kInvalidSocket; + return kInvalidDescriptor; } // close the socket if not created @@ -52,7 +52,7 @@ int LwipUdpSocket::MakeSocket() { if (lwip_fcntl(sock, F_SETFL, O_NONBLOCK) != ESP_OK) { AE_TELED_ERROR("lwip_fcntl set nonblocking mode error {} {}", errno, strerror(errno)); - return kInvalidSocket; + return kInvalidDescriptor; } AE_TELED_DEBUG("LwIp UDP socket created"); @@ -63,7 +63,7 @@ int LwipUdpSocket::MakeSocket() { ISocket& LwipUdpSocket::Connect(AddressPort const& destination, ConnectedCb connected_cb) { // UDP connection means binding socket to a destination address - assert((socket_ != kInvalidSocket) && "Socket is not initialized"); + assert(socket_ && "Socket is not initialized"); ae_defer[&]() { Poll(); @@ -71,10 +71,9 @@ ISocket& LwipUdpSocket::Connect(AddressPort const& destination, connected_cb(connection_state_); }; - auto lock = std::scoped_lock{socket_lock_}; - auto addr = GetSockAddr(destination); - auto res = connect(socket_, addr.addr(), static_cast(addr.size)); + auto res = + connect(*socket_->fd(), addr.addr(), static_cast(addr.size)); if (res == -1) { AE_TELED_ERROR("Not connected {} {}", errno, strerror(errno)); connection_state_ = ConnectionState::kConnectionFailed; diff --git a/aether/transport/system_sockets/sockets/lwip_udp_socket.h b/aether/transport/system_sockets/sockets/lwip_udp_socket.h index fc42be3f..e4ba6497 100644 --- a/aether/transport/system_sockets/sockets/lwip_udp_socket.h +++ b/aether/transport/system_sockets/sockets/lwip_udp_socket.h @@ -23,6 +23,8 @@ #if AE_SUPPORT_UDP && LWIP_SOCKET_ENABLED # include "aether/poller/poller.h" +# define LWIP_UDP_SOCKET_ENABLED 1 + namespace ae { class LwipUdpSocket final : public LwipSocket { public: diff --git a/aether/transport/system_sockets/sockets/unix_socket.cpp b/aether/transport/system_sockets/sockets/unix_socket.cpp index 426d4e58..94249294 100644 --- a/aether/transport/system_sockets/sockets/unix_socket.cpp +++ b/aether/transport/system_sockets/sockets/unix_socket.cpp @@ -36,7 +36,8 @@ namespace ae { UnixSocket::UnixSocket(IPoller& poller, int socket) - : poller_{static_cast(poller.Native())}, socket_{socket} {} + : socket_{std::in_place, socket, poller.Native(), + MethodPtr<&UnixSocket::OnPollerEvent>{this}} {} UnixSocket::~UnixSocket() { Disconnect(); } @@ -56,16 +57,17 @@ ISocket& UnixSocket::Error(ErrorCb error_cb) { } std::optional UnixSocket::Send(Span data) { - auto lock = std::scoped_lock{socket_lock_}; + if (!socket_) { + return std::nullopt; + } auto size_to_send = data.size(); // add nosignal to prevent throw SIGPIPE and handle it manually int flags = MSG_NOSIGNAL; - auto res = send(socket_, data.data(), size_to_send, flags); + auto res = send(*socket_->fd(), data.data(), size_to_send, flags); if (res == -1) { if ((errno == EAGAIN) || (errno == EWOULDBLOCK)) { - poller_->Event(socket_, - EventType::kRead | EventType::kError | EventType::kWrite, - MethodPtr<&UnixSocket::OnPollerEvent>{this}); + // poll read and write + socket_->Events(EventType::kRead | EventType::kError | EventType::kWrite); } else { AE_TELED_ERROR("Send to socket error {} {}", errno, strerror(errno)); return std::nullopt; @@ -76,31 +78,32 @@ std::optional UnixSocket::Send(Span data) { } void UnixSocket::Disconnect() { - auto lock = std::scoped_lock{socket_lock_}; - if (socket_ == kInvalidSocket) { + if (!socket_) { return; } - auto s = socket_; - socket_ = kInvalidSocket; - - poller_->Remove(s); - shutdown(s, SHUT_RDWR); - if (close(s) != 0) { - return; + { + auto s = socket_->Remove(); + shutdown(*s, SHUT_RDWR); + if (close(*s) != 0) { + return; + } } + socket_.reset(); } void UnixSocket::Poll() { - poller_->Event(socket_, EventType::kRead | EventType::kError, - MethodPtr<&UnixSocket::OnPollerEvent>{this}); + if (socket_) { + // normally poll without write + socket_->Events(EventType::kRead | EventType::kError); + } } -void UnixSocket::OnPollerEvent(EventType event) { +void UnixSocket::OnPollerEvent(DescriptorType fd, EventType event) { for (auto e : {EventType::kRead, EventType::kWrite, EventType::kWrite}) { auto event_type = event & e; switch (event_type) { case EventType::kRead: - OnReadEvent(); + OnReadEvent(fd); break; case EventType::kWrite: OnWriteEvent(); @@ -114,12 +117,11 @@ void UnixSocket::OnPollerEvent(EventType event) { } } -void UnixSocket::OnReadEvent() { +void UnixSocket::OnReadEvent(DescriptorType fd) { // read all data - auto lock = std::scoped_lock{socket_lock_}; while (true) { auto buffer = Span{recv_buffer_.data(), recv_buffer_.size()}; - auto res = Receive(buffer); + auto res = Receive(fd, buffer); if (!res) { OnErrorEvent(); return; @@ -150,8 +152,9 @@ void UnixSocket::OnErrorEvent() { } // call on locked socket -std::optional UnixSocket::Receive(Span buffer) { - auto res = recv(socket_, buffer.data(), buffer.size(), 0); +std::optional UnixSocket::Receive(DescriptorType fd, + Span buffer) { + auto res = recv(fd, buffer.data(), buffer.size(), 0); if (res < 0) { // No data if ((errno == EWOULDBLOCK) || (errno == EAGAIN)) { @@ -168,12 +171,11 @@ std::optional UnixSocket::Receive(Span buffer) { return static_cast(res); } -std::optional UnixSocket::GetSocketError() { - auto lock = std::scoped_lock{socket_lock_}; +std::optional UnixSocket::GetSocketError(DescriptorType fd) { int err{}; socklen_t len = sizeof(len); - if (getsockopt(socket_, SOL_SOCKET, SO_ERROR, static_cast(&err), - &len) != 0) { + if (getsockopt(fd, SOL_SOCKET, SO_ERROR, static_cast(&err), &len) != + 0) { AE_TELED_ERROR("Getsockopt error {}, {}", errno, strerror(errno)); return std::nullopt; } diff --git a/aether/transport/system_sockets/sockets/unix_socket.h b/aether/transport/system_sockets/sockets/unix_socket.h index afd09071..f99b4b3b 100644 --- a/aether/transport/system_sockets/sockets/unix_socket.h +++ b/aether/transport/system_sockets/sockets/unix_socket.h @@ -22,12 +22,11 @@ # define UNIX_SOCKET_ENABLED 1 -# include +# include # include "aether/poller/poller.h" -# include "aether/poller/unix_poller.h" # include "aether/types/data_buffer.h" -# include "aether/events/event_subscription.h" +# include "aether/poller/unix_poller.h" # include "aether/transport/system_sockets/sockets/isocket.h" namespace ae { @@ -36,8 +35,6 @@ namespace ae { */ class UnixSocket : public ISocket { public: - static constexpr int kInvalidSocket = -1; - explicit UnixSocket(IPoller& poller, int socket); ~UnixSocket() override; @@ -50,19 +47,18 @@ class UnixSocket : public ISocket { protected: void Poll(); - virtual void OnPollerEvent(EventType event); + virtual void OnPollerEvent(DescriptorType fd, EventType event); - void OnReadEvent(); + void OnReadEvent(DescriptorType fd); void OnWriteEvent(); void OnErrorEvent(); - std::optional Receive(Span buffer); + std::optional Receive(DescriptorType fd, + Span buffer); - std::optional GetSocketError(); + std::optional GetSocketError(DescriptorType fd); - UnixPollerImpl* poller_; - int socket_; - std::mutex socket_lock_; + std::optional socket_; ReadyToWriteCb ready_to_write_cb_; RecvDataCb recv_data_cb_; diff --git a/aether/transport/system_sockets/sockets/unix_tcp_socket.cpp b/aether/transport/system_sockets/sockets/unix_tcp_socket.cpp index 96b87494..b4302b50 100644 --- a/aether/transport/system_sockets/sockets/unix_tcp_socket.cpp +++ b/aether/transport/system_sockets/sockets/unix_tcp_socket.cpp @@ -83,9 +83,9 @@ int UnixTcpSocket::MakeSocket() { bool created = false; // TCP socket int sock = socket(AF_INET, SOCK_STREAM, 0); - if (sock == kInvalidSocket) { + if (sock == kInvalidDescriptor) { AE_TELED_DEBUG("Socket creation error {} {}", errno, strerror(errno)); - return kInvalidSocket; + return kInvalidDescriptor; } // close the socket if it fails to setup @@ -96,13 +96,13 @@ int UnixTcpSocket::MakeSocket() { }; if (!unix_tcp_socket_internal::SetNonblocking(sock)) { - return kInvalidSocket; + return kInvalidDescriptor; } if (!unix_tcp_socket_internal::SetTcpNoDelay(sock)) { - return kInvalidSocket; + return kInvalidDescriptor; } if (!unix_tcp_socket_internal::SetNoSigpipe(sock)) { - return kInvalidSocket; + return kInvalidDescriptor; } created = true; return sock; @@ -110,21 +110,18 @@ int UnixTcpSocket::MakeSocket() { ISocket& UnixTcpSocket::Connect(AddressPort const& destination, ConnectedCb connected_cb) { - assert((socket_ != kInvalidSocket) && "Socket is not initialized"); + assert(socket_.has_value() && "Socket is not initialized"); connected_cb_ = std::move(connected_cb); ae_defer[&]() { // add poll for all events to detect connection - poller_->Event(socket_, - EventType::kRead | EventType::kError | EventType::kWrite, - MethodPtr<&UnixTcpSocket::OnPollerEvent>{this}); + socket_->Events(EventType::kRead | EventType::kError | EventType::kWrite); connected_cb_(connection_state_); }; - auto lock = std::scoped_lock{socket_lock_}; - auto addr = GetSockAddr(destination); - auto res = connect(socket_, addr.addr(), static_cast(addr.size)); + auto res = + connect(*socket_->fd(), addr.addr(), static_cast(addr.size)); if (res == -1) { if ((errno == EAGAIN) || (errno == EINPROGRESS)) { AE_TELED_DEBUG("Wait connection"); @@ -140,15 +137,15 @@ ISocket& UnixTcpSocket::Connect(AddressPort const& destination, return *this; } -void UnixTcpSocket::OnPollerEvent(EventType event) { +void UnixTcpSocket::OnPollerEvent(DescriptorType fd, EventType event) { if (connection_state_ == ConnectionState::kConnecting) { - OnConnectionEvent(); + OnConnectionEvent(fd); return; } - UnixSocket::OnPollerEvent(event); + UnixSocket::OnPollerEvent(fd, event); } -void UnixTcpSocket::OnConnectionEvent() { +void UnixTcpSocket::OnConnectionEvent(DescriptorType fd) { ae_defer[&]() { if (connected_cb_) { connected_cb_(connection_state_); @@ -156,7 +153,7 @@ void UnixTcpSocket::OnConnectionEvent() { }; // check socket status - auto sock_err = GetSocketError(); + auto sock_err = GetSocketError(fd); if (!sock_err || (sock_err.value() != 0)) { AE_TELED_ERROR("Connect error {}, {}", sock_err, strerror(sock_err.value_or(0))); diff --git a/aether/transport/system_sockets/sockets/unix_tcp_socket.h b/aether/transport/system_sockets/sockets/unix_tcp_socket.h index 81c29a33..2ee8089e 100644 --- a/aether/transport/system_sockets/sockets/unix_tcp_socket.h +++ b/aether/transport/system_sockets/sockets/unix_tcp_socket.h @@ -31,11 +31,13 @@ class UnixTcpSocket final : public UnixSocket { ISocket& Connect(AddressPort const& destination, ConnectedCb connected_cb) override; + protected: + void OnPollerEvent(DescriptorType fd, EventType event) override; + private: static int MakeSocket(); - void OnPollerEvent(EventType event) override; - void OnConnectionEvent(); + void OnConnectionEvent(DescriptorType fd); ConnectionState connection_state_; ConnectedCb connected_cb_; diff --git a/aether/transport/system_sockets/sockets/unix_udp_socket.cpp b/aether/transport/system_sockets/sockets/unix_udp_socket.cpp index 5f5693e9..5fb07b30 100644 --- a/aether/transport/system_sockets/sockets/unix_udp_socket.cpp +++ b/aether/transport/system_sockets/sockets/unix_udp_socket.cpp @@ -40,9 +40,9 @@ UnixUdpSocket::UnixUdpSocket(Ptr const& poller) int UnixUdpSocket::MakeSocket() { bool created = false; auto sock = socket(AF_INET, SOCK_DGRAM, 0); - if (sock == kInvalidSocket) { + if (sock == kInvalidDescriptor) { AE_TELED_ERROR("Socket creation error {} {}", errno, strerror(errno)); - return kInvalidSocket; + return kInvalidDescriptor; } // close socket on error @@ -55,7 +55,7 @@ int UnixUdpSocket::MakeSocket() { // make socket non-blocking if (fcntl(sock, F_SETFL, O_NONBLOCK) != 0) { AE_TELED_ERROR("Socket set O_NONBLOCK error {} {}", errno, strerror(errno)); - return kInvalidSocket; + return kInvalidDescriptor; } created = true; return sock; @@ -64,17 +64,16 @@ int UnixUdpSocket::MakeSocket() { ISocket& UnixUdpSocket::Connect(AddressPort const& destination, ConnectedCb connected_cb) { // UDP connection means binding socket to a destination address - assert((socket_ != kInvalidSocket) && "Socket is not initialized"); + assert(socket_.has_value() && "Socket is not initialized"); ConnectionState connection_state{ConnectionState::kNone}; ae_defer[&]() { Poll(); connected_cb(connection_state); }; - auto lock = std::scoped_lock{socket_lock_}; - auto addr = GetSockAddr(destination); - auto res = connect(socket_, addr.addr(), static_cast(addr.size)); + auto res = + connect(*socket_->fd(), addr.addr(), static_cast(addr.size)); if (res == -1) { AE_TELED_ERROR("Not connected {} {}", errno, strerror(errno)); connection_state = ConnectionState::kConnectionFailed; diff --git a/aether/transport/system_sockets/sockets/win_socket.cpp b/aether/transport/system_sockets/sockets/win_socket.cpp index fe77ab49..8bf19064 100644 --- a/aether/transport/system_sockets/sockets/win_socket.cpp +++ b/aether/transport/system_sockets/sockets/win_socket.cpp @@ -28,7 +28,7 @@ namespace ae { WinSocket::WinSocket(IPoller& poller, std::size_t max_packet_size) - : poller_{static_cast(poller.Native())}, + : poller_{std::static_pointer_cast(poller.Native())}, recv_overlapped_{}, send_overlapped_{}, recv_buffer_(max_packet_size) {} diff --git a/aether/transport/system_sockets/sockets/win_socket.h b/aether/transport/system_sockets/sockets/win_socket.h index ab30af1c..141bdcb9 100644 --- a/aether/transport/system_sockets/sockets/win_socket.h +++ b/aether/transport/system_sockets/sockets/win_socket.h @@ -34,9 +34,6 @@ namespace ae { class WinSocket : public ISocket { public: - static constexpr auto kInvalidSocketValue = - static_cast(~0); - WinSocket(IPoller& poller, std::size_t max_packet_size); ~WinSocket() override; @@ -58,9 +55,9 @@ class WinSocket : public ISocket { bool RequestRecv(); std::optional HandleRecv(); - IoCpPoller* poller_; + std::shared_ptr poller_; SocketInitializer socket_initializer_; - DescriptorType::Socket socket_ = kInvalidSocketValue; + DescriptorType::Socket socket_ = kInvalidDescriptor; std::mutex socket_lock_; ReadyToWriteCb ready_to_write_cb_; From c1326e871ae5398f154080617c91a2fb5a06c058 Mon Sep 17 00:00:00 2001 From: BartolomeyKant Date: Tue, 2 Jun 2026 18:53:52 +0500 Subject: [PATCH 2/2] fix udp buffer swap --- aether/transport/system_sockets/udp/udp.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/aether/transport/system_sockets/udp/udp.cpp b/aether/transport/system_sockets/udp/udp.cpp index 6af8dac6..d2b6fc91 100644 --- a/aether/transport/system_sockets/udp/udp.cpp +++ b/aether/transport/system_sockets/udp/udp.cpp @@ -72,8 +72,10 @@ void UdpBase::OnRecvData(Span data) { read_event_sub_ = ae_context_.scheduler().Task([&]() { auto buffers = std::invoke([&]() { auto lock = std::scoped_lock{socket_mutex_}; + auto rb = std::vector(); + std::swap(rb, read_buffers_); read_event_ = false; - return std::move(read_buffers_); + return rb; }); for (auto const& d : buffers) {