From 74db29693dd8922cf8e4e3292879bc79deefb033 Mon Sep 17 00:00:00 2001 From: geofmureithi Date: Wed, 20 May 2026 23:58:05 +0300 Subject: [PATCH] chore (api): metadata as a key-value store --- CHANGELOG.md | 5 +- Cargo.lock | 16 +- apalis-core/src/backend/impls/memory.rs | 35 +- apalis-core/src/task/builder.rs | 21 +- apalis-core/src/task/extensions.rs | 59 +- apalis-core/src/task/metadata.rs | 621 +++++++++++++++++- apalis-core/src/task/mod.rs | 34 +- apalis-sql/src/context.rs | 38 +- apalis-sql/src/from_row.rs | 25 +- apalis-workflow/src/dag/context.rs | 185 +++++- apalis-workflow/src/dag/error.rs | 6 + apalis-workflow/src/dag/executor.rs | 15 +- apalis-workflow/src/dag/service.rs | 32 +- apalis-workflow/src/lib.rs | 4 +- .../src/sequential/and_then/mod.rs | 6 +- apalis-workflow/src/sequential/context.rs | 43 +- apalis-workflow/src/sequential/delay/mod.rs | 8 +- .../src/sequential/filter_map/mod.rs | 181 ++++- apalis-workflow/src/sequential/fold/mod.rs | 79 ++- .../src/sequential/repeat_until/mod.rs | 98 ++- apalis-workflow/src/sequential/service.rs | 9 +- apalis-workflow/src/sequential/workflow.rs | 4 +- apalis-workflow/src/sink.rs | 16 +- apalis/src/layers/retry/mod.rs | 53 +- apalis/src/layers/tracing/contextual_span.rs | 11 +- apalis/src/layers/tracing/mod.rs | 11 +- apalis/src/lib.rs | 6 +- apalis/tests/otel_context_propagation.rs | 6 +- examples/monitor/src/main.rs | 2 +- examples/tracing/src/main.rs | 2 +- supply-chain/config.toml | 8 +- utils/apalis-file-storage/src/meta.rs | 25 +- 32 files changed, 1371 insertions(+), 293 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 39a1fb3d..1321b102 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,14 +4,15 @@ All notable changes to this project are documented in this file. ## [Unreleased] -- **chore**: improve `FromRequest` api adding `Option` ([#745](https://github.com/geofmureithi/apalis/pull/745)) +- **chore (api)!**: metadata as a key-value store ([#747](https://github.com/geofmureithi/apalis/pull/747)) +- **chore**: improve `FromRequest` api adding `Option` ([#746](https://github.com/geofmureithi/apalis/pull/746)) - **chore**: bump to v1.0.0 rc.9 ([#744](https://github.com/geofmureithi/apalis/pull/744)) - **feat (api)!**: remove queue input in expose endpoints ([#741](https://github.com/geofmureithi/apalis/pull/741)) - **feat**: idempotency for sql tasks ([#736](https://github.com/geofmureithi/apalis/pull/736)) - **chore**: bump to v1.0.0 rc.8 ([#734](https://github.com/geofmureithi/apalis/pull/734)) - **feat**: idempotency for tasks ([#726](https://github.com/apalis-dev/apalis/pull/726)) - **fix(tracing)**: improve OpenTelemetry context propagation across worker tracing layers ([#716](https://github.com/apalis-dev/apalis/pull/716)) -- **deps(deps)**: bump sentry-* from 0.46.2 to 0.47.0 ([#715](https://github.com/apalis-dev/apalis/pull/715)) +- **deps(deps)**: bump sentry-\* from 0.46.2 to 0.47.0 ([#715](https://github.com/apalis-dev/apalis/pull/715)) - **chore**: bump to v1.0.0 rc.7 ([#714](https://github.com/apalis-dev/apalis/pull/714)) - **chore**: bump to v1.0.0 rc.6 ([#705](https://github.com/apalis-dev/apalis/pull/705)) - **fix (workflow)**: remove Display constraints in workflow service ([#704](https://github.com/apalis-dev/apalis/pull/704)) diff --git a/Cargo.lock b/Cargo.lock index c7eed335..b6b4b6ce 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -790,9 +790,9 @@ dependencies = [ [[package]] name = "either" -version = "1.15.0" +version = "1.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" +checksum = "91622ff5e7162018101f2fea40d6ebf4a78bbe5a49736a2020649edf9693679e" [[package]] name = "email-service" @@ -1050,9 +1050,9 @@ checksum = "037711b3d59c33004d3856fbdc83b99d4ff37a24768fa1be9ce3538a1cde4393" [[package]] name = "futures-timer" -version = "3.0.3" +version = "3.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24" +checksum = "af43fadb8a98512d547e37b4e92e0ced13e205c061b87b4623eff01d918d6968" [[package]] name = "futures-util" @@ -1816,9 +1816,9 @@ dependencies = [ [[package]] name = "nix" -version = "0.30.1" +version = "0.31.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "74523f3a35e05aba87a1d978330aef40f67b0304ac79c1c00b294c9830543db6" +checksum = "cf20d2fde8ff38632c426f1165ed7436270b44f199fc55284c38276f9db47c3d" dependencies = [ "bitflags", "cfg-if", @@ -2177,9 +2177,9 @@ dependencies = [ [[package]] name = "os_info" -version = "3.14.0" +version = "3.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e4022a17595a00d6a369236fdae483f0de7f0a339960a53118b818238e132224" +checksum = "9cf20a545b305cf1da722b236b5155c9bb35f1d5ceb28c048bd96ca842f41b5b" dependencies = [ "android_system_properties", "log", diff --git a/apalis-core/src/backend/impls/memory.rs b/apalis-core/src/backend/impls/memory.rs index dcdcb9a5..b4147bb7 100644 --- a/apalis-core/src/backend/impls/memory.rs +++ b/apalis-core/src/backend/impls/memory.rs @@ -40,7 +40,7 @@ use crate::backend::BackendExt; use crate::backend::codec::IdentityCodec; use crate::features_table; -use crate::task::extensions::Extensions; +use crate::task::metadata::{MetadataExt, MetadataStore}; use crate::{ backend::{Backend, TaskStream}, task::{ @@ -109,25 +109,44 @@ pub type BoxedReceiver = Pin not_supported("List all workers registered with the backend"), ListTasks => not_supported("List all tasks in the backend"), }] -pub struct MemoryStorage { +pub struct MemoryStorage { pub(super) sender: MemorySink, pub(super) receiver: BoxedReceiver, } -impl Default for MemoryStorage { +impl Default for MemoryStorage { fn default() -> Self { Self::new() } } -impl MemoryStorage { +/// Store extra context related to task metadata +#[derive(Debug, Clone, Default)] +pub struct MemoryContext { + metadata: MetadataStore, +} + +impl MetadataExt for MemoryContext { + fn metadata(&self) -> &MetadataStore { + &self.metadata + } + + fn metadata_mut(&mut self) -> &mut MetadataStore { + &mut self.metadata + } +} + +impl MemoryStorage { /// Create a new in-memory storage #[must_use] pub fn new() -> Self { let (sender, receiver) = unbounded(); let sender = Box::new(sender) as Box< - dyn Sink, Error = SendError> + Send + Sync + Unpin, + dyn Sink, Error = SendError> + + Send + + Sync + + Unpin, >; Self { sender: MemorySink { @@ -170,7 +189,7 @@ impl Sink> for MemoryStorage { } } -type ArcMemorySink = Arc< +type ArcMemorySink = Arc< Mutex< Box, Error = SendError> + Send + Sync + Unpin + 'static>, >, @@ -179,7 +198,7 @@ type ArcMemorySink = Arc< type ArcIdempotencySet = Arc>>; /// Memory sink for sending tasks to the in-memory backend -pub struct MemorySink { +pub struct MemorySink { pub(super) inner: ArcMemorySink, pub(super) idempotency_keys: ArcIdempotencySet, } @@ -226,6 +245,8 @@ impl Sink> for MemorySink { ) -> Result<(), Self::Error> { let this = self.get_mut(); + let _ = item.parts.data.get_or_insert(MetadataStore::default()); + // Ensure task id exists item.parts .task_id diff --git a/apalis-core/src/task/builder.rs b/apalis-core/src/task/builder.rs index 7a49883d..175780aa 100644 --- a/apalis-core/src/task/builder.rs +++ b/apalis-core/src/task/builder.rs @@ -25,10 +25,17 @@ //! ``` //! use crate::task::{ - Parts, Task, attempt::Attempt, extensions::Extensions, metadata::MetadataExt, status::Status, + Parts, Task, + attempt::Attempt, + extensions::Extensions, + metadata::{Metadata, MetadataExt}, + status::Status, task_id::TaskId, }; -use std::time::{Duration, SystemTime, UNIX_EPOCH}; +use std::{ + fmt::Debug, + time::{Duration, SystemTime, UNIX_EPOCH}, +}; /// Builder for creating [`Task`] instances with optional configuration #[derive(Debug)] @@ -87,13 +94,13 @@ impl TaskBuilder { /// Insert a value into the task's ctx context #[must_use] - pub fn meta(mut self, value: M) -> Self + pub fn meta(mut self, value: &M) -> Self where - Ctx: MetadataExt, + Ctx: MetadataExt, + M: Metadata, + M::Error: Debug, { - self.ctx - .inject(value) - .unwrap_or_else(|_| panic!("Failed to inject item into context")); + self.ctx.inject(value).expect("Could not add Metadata"); self } diff --git a/apalis-core/src/task/extensions.rs b/apalis-core/src/task/extensions.rs index fa826ee6..aa42d9a7 100644 --- a/apalis-core/src/task/extensions.rs +++ b/apalis-core/src/task/extensions.rs @@ -27,7 +27,6 @@ use std::fmt; use std::hash::{BuildHasherDefault, Hasher}; use crate::task::data::MissingDataError; -use crate::task::metadata::MetadataExt; type AnyMap = HashMap, BuildHasherDefault>; @@ -256,6 +255,53 @@ impl Extensions { } } } + + /// Get a mutable reference to the value of type `T`, + /// inserting `value` if it does not already exist. + /// + /// # Example + /// + /// ``` + /// # use apalis_core::task::extensions::Extensions; + /// let mut ext = Extensions::new(); + /// + /// let value = ext.get_or_insert(String::from("Hello")); + /// value.push_str(" World"); + /// + /// assert_eq!(ext.get::().unwrap(), "Hello World"); + /// ``` + pub fn get_or_insert(&mut self, value: T) -> &mut T + where + T: Clone + Send + Sync + 'static, + { + self.get_or_insert_with(|| value) + } + + /// Get a mutable reference to the value of type `T`, + /// inserting the result of `f` if it does not already exist. + /// + /// # Example + /// + /// ``` + /// # use apalis_core::task::extensions::Extensions; + /// let mut ext = Extensions::new(); + /// + /// let value = ext.get_or_insert_with(|| String::from("Hello")); + /// value.push_str(" World"); + /// + /// assert_eq!(ext.get::().unwrap(), "Hello World"); + /// ``` + pub fn get_or_insert_with(&mut self, f: F) -> &mut T + where + T: Clone + Send + Sync + 'static, + F: FnOnce() -> T, + { + if self.get::().is_none() { + self.insert(f()); + } + + self.get_mut::().expect("value was just inserted") + } } impl fmt::Debug for Extensions { @@ -295,17 +341,6 @@ impl Clone for Box { } } -impl MetadataExt for Extensions { - type Error = MissingDataError; - fn inject(&mut self, value: T) -> Result<(), Self::Error> { - self.insert(value); - Ok(()) - } - fn extract(&self) -> Result { - Ok(self.get_checked::()?.clone()) - } -} - #[test] fn test_extensions() { #[derive(Clone, Debug, PartialEq)] diff --git a/apalis-core/src/task/metadata.rs b/apalis-core/src/task/metadata.rs index 3c3f7a9f..52dd6fa4 100644 --- a/apalis-core/src/task/metadata.rs +++ b/apalis-core/src/task/metadata.rs @@ -4,15 +4,20 @@ //! It includes implementations for common metadata types. //! //! ## Overview -//! - `MetadataExt`: A trait for extracting and injecting metadata of type `T`. +//! - `Metadata`: A trait for extracting and injecting metadata. //! //! # Usage -//! Implement the `MetadataExt` trait for your metadata types to enable easy extraction and injection +//! Implement the `MetadataExt` trait for your context types to enable easy extraction and injection //! from task contexts. This allows middleware and services to access and modify task metadata in a //! type-safe manner. use crate::task::Task; use crate::task_fn::FromRequest; +use std::collections::HashMap; +use std::convert::Infallible; +use std::fmt; use std::ops::Deref; +#[cfg(feature = "tracing")] +use std::str::FromStr; /// Metadata wrapper for task contexts. #[derive(Debug, Clone)] @@ -25,27 +30,466 @@ impl Deref for Meta { } } -/// Task metadata extension trait and implementations. -/// This trait allows for injecting and extracting metadata associated with tasks. -pub trait MetadataExt { - /// The error type that can occur during extraction or injection. +/// Extension trait for types that expose task metadata. +/// +/// `MetadataExt` provides a uniform interface for accessing, mutating, +/// injecting, and extracting strongly typed metadata values backed by a +/// [`MetadataStore`]. +/// +/// This trait is commonly implemented by task contexts, job payloads, +/// execution environments, or workflow state containers that need to +/// carry metadata across system boundaries. +/// +/// # Provided Methods +/// +/// - [`metadata`](Self::metadata): Returns an immutable reference to the +/// underlying [`MetadataStore`]. +/// - [`metadata_mut`](Self::metadata_mut): Returns a mutable reference to +/// the underlying [`MetadataStore`]. +/// - [`extract`](Self::extract): Extracts a typed metadata value from the store. +/// - [`inject`](Self::inject): Injects a typed metadata value into the store. +/// +/// # Examples +/// +/// ```rust +/// # use std::collections::HashMap; +/// # use apalis_core::task::metadata::Metadata; +/// # use apalis_core::task::metadata::MetadataStore; +/// # use apalis_core::task::metadata::MetadataExt; +/// # +/// struct TaskContext { +/// metadata: MetadataStore, +/// } +/// +/// impl MetadataExt for TaskContext { +/// fn metadata(&self) -> &MetadataStore { +/// &self.metadata +/// } +/// +/// fn metadata_mut(&mut self) -> &mut MetadataStore { +/// &mut self.metadata +/// } +/// } +/// ``` +/// +/// Injecting and extracting typed metadata: +/// +/// ```rust +/// # use std::collections::HashMap; +/// # use apalis_core::task::metadata::Metadata; +/// # use apalis_core::task::metadata::MetadataStore; +/// # use apalis_core::task::metadata::MetadataExt; +/// # +/// #[derive(Debug, PartialEq)] +/// struct RequestId(String); +/// +/// impl Metadata for RequestId { +/// type Error = std::convert::Infallible; +/// +/// fn inject(&self, metadata: &mut MetadataStore) -> Result<(), Self::Error> { +/// let _ = metadata.insert("request_id", self.0.clone()); +/// Ok(()) +/// } +/// +/// fn extract(metadata: &MetadataStore) -> Result { +/// Ok(Self( +/// metadata +/// .get("request_id") +/// .cloned() +/// .unwrap_or_default(), +/// )) +/// } +/// } +/// +/// struct Context { +/// metadata: MetadataStore, +/// } +/// +/// impl MetadataExt for Context { +/// fn metadata(&self) -> &MetadataStore { +/// &self.metadata +/// } +/// +/// fn metadata_mut(&mut self) -> &mut MetadataStore { +/// &mut self.metadata +/// } +/// } +/// +/// let mut ctx = Context { +/// metadata: MetadataStore::new(), +/// }; +/// +/// ctx.inject(&RequestId("req-123".into())); +/// +/// let request_id = ctx.extract::()?; +/// +/// assert_eq!(request_id, RequestId("req-123".into())); +/// +/// # Ok::<(), std::convert::Infallible>(()) +/// ``` +pub trait MetadataExt { + /// Returns an immutable reference to the underlying metadata store. + fn metadata(&self) -> &MetadataStore; + + /// Returns a mutable reference to the underlying metadata store. + fn metadata_mut(&mut self) -> &mut MetadataStore; + + /// Extracts a strongly typed metadata value from the underlying store. + /// + /// This method delegates extraction logic to the [`Metadata`] implementation + /// for `T`. + /// + /// # Errors + /// + /// Returns `T::Error` if extraction fails. + /// + /// # Examples + /// + /// ```rust + /// # use std::convert::Infallible; + /// # use apalis_core::task::metadata::Metadata; + /// # use apalis_core::task::metadata::MetadataStore; + /// # + /// struct UserId(String); + /// + /// impl Metadata for UserId { + /// type Error = Infallible; + /// + /// fn inject(&self, metadata: &mut MetadataStore) -> Result<(), Self::Error> { + /// let _ = metadata.insert("user_id", self.0.clone()); + /// Ok(()) + /// } + /// + /// fn extract(metadata: &MetadataStore) -> Result { + /// Ok(Self( + /// metadata + /// .get("user_id") + /// .cloned() + /// .unwrap_or_default(), + /// )) + /// } + /// } + /// ``` + fn extract(&self) -> Result { + T::extract(self.metadata()) + } + + /// Injects a strongly typed metadata value into the underlying store. + /// + /// This method delegates serialization/injection logic to the [`Metadata`] + /// implementation for `T`. + /// + /// # Examples + /// + /// ```rust + /// # use apalis_core::task::metadata::Metadata; + /// # use apalis_core::task::metadata::MetadataStore; + /// struct CorrelationId(String); + /// + /// impl Metadata for CorrelationId { + /// type Error = std::convert::Infallible; + /// + /// fn inject(&self, metadata: &mut MetadataStore) -> Result<(), Self::Error> { + /// let _ = metadata.insert("correlation_id", self.0.clone()); + /// Ok(()) + /// } + /// + /// fn extract(_: &MetadataStore) -> Result { + /// unreachable!() + /// } + /// } + /// ``` + fn inject(&mut self, value: &T) -> Result<(), T::Error> { + value.inject(self.metadata_mut()) + } +} + +/// A lightweight key-value metadata store backed by a `HashMap`. +/// +/// `MetadataStore` is designed for storing arbitrary string metadata such as +/// task attributes, labels, annotations, headers, or contextual information. +/// +/// Keys are unique within the store. Attempting to insert a duplicate key +/// returns a [`MetadataError::DuplicateKey`] error. +/// +/// # Examples +/// +/// ```rust +/// # use std::collections::HashMap; +/// # use apalis_core::task::metadata::Metadata; +/// # use apalis_core::task::metadata::MetadataStore; +/// # use apalis_core::task::metadata::MetadataError; +/// # +/// let mut metadata = MetadataStore::new(); +/// +/// metadata.insert("request_id", "abc-123")?; +/// metadata.insert("environment", "production")?; +/// +/// assert_eq!( +/// metadata.get("request_id"), +/// Some(&"abc-123".to_string()) +/// ); +/// +/// assert!(metadata.contains_key("environment")); +/// +/// # Ok::<(), MetadataError>(()) +/// ``` +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[derive(Debug, Clone, Default, PartialEq, Eq)] +pub struct MetadataStore(HashMap); + +/// Errors returned by [`MetadataStore`]. +#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)] +pub enum MetadataError { + /// Returned when attempting to insert a key that already exists. + #[error("The key already exists in the store")] + DuplicateKey(String), +} + +impl MetadataStore { + /// Creates an empty [`MetadataStore`]. + /// + /// # Examples + /// + /// ```rust + /// # use apalis_core::task::metadata::Metadata; + /// # use apalis_core::task::metadata::MetadataStore; + /// let metadata = MetadataStore::new(); + /// + /// assert_eq!(metadata.iter().count(), 0); + /// ``` + #[must_use] + pub fn new() -> Self { + Self(HashMap::new()) + } + + /// Inserts a key-value pair into the store. + /// + /// Returns an error if the key already exists. + /// + /// # Errors + /// + /// Returns [`MetadataError::DuplicateKey`] if the provided key is already + /// present in the store. + /// + /// # Examples + /// + /// ```rust + /// # use apalis_core::task::metadata::Metadata; + /// # use apalis_core::task::metadata::MetadataStore; + /// # use apalis_core::task::metadata::MetadataError; + /// let mut metadata = MetadataStore::new(); + /// + /// metadata.insert("region", "us-east-1")?; + /// + /// assert_eq!( + /// metadata.get("region"), + /// Some(&"us-east-1".to_string()) + /// ); + /// + /// # Ok::<(), MetadataError>(()) + /// ``` + /// + /// Duplicate keys are rejected: + /// + /// ```rust + /// # use apalis_core::task::metadata::Metadata; + /// # use apalis_core::task::metadata::MetadataStore; + /// # use apalis_core::task::metadata::MetadataError; + /// let mut metadata = MetadataStore::new(); + /// + /// metadata.insert("service", "api")?; + /// + /// let err = metadata.insert("service", "worker").unwrap_err(); + /// + /// assert_eq!( + /// err, + /// MetadataError::DuplicateKey("service".to_string()) + /// ); + /// + /// # Ok::<(), MetadataError>(()) + /// ``` + pub fn insert(&mut self, key: K, value: V) -> Result<(), MetadataError> + where + K: Into, + V: Into, + { + let key = key.into(); + + if self.0.contains_key(&key) { + return Err(MetadataError::DuplicateKey(key)); + } + + self.0.insert(key, value.into()); + + Ok(()) + } + + /// Returns a reference to the value corresponding to the given key. + /// + /// Returns `None` if the key does not exist. + /// + /// # Examples + /// + /// ```rust + /// # use apalis_core::task::metadata::Metadata; + /// # use apalis_core::task::metadata::MetadataStore; + /// # use apalis_core::task::metadata::MetadataError; + /// let mut metadata = MetadataStore::new(); + /// + /// metadata.insert("version", "1.0")?; + /// + /// assert_eq!( + /// metadata.get("version"), + /// Some(&"1.0".to_string()) + /// ); + /// + /// assert_eq!(metadata.get("missing"), None); + /// + /// # Ok::<(), MetadataError>(()) + /// ``` + #[must_use] + pub fn get(&self, key: &str) -> Option<&String> { + self.0.get(key) + } + + /// Removes a key from the store, returning the stored value if it existed. + /// + /// # Examples + /// + /// ```rust + /// # use apalis_core::task::metadata::Metadata; + /// # use apalis_core::task::metadata::MetadataStore; + /// # use apalis_core::task::metadata::MetadataError; + /// let mut metadata = MetadataStore::new(); + /// + /// metadata.insert("token", "secret")?; + /// + /// assert_eq!( + /// metadata.remove("token"), + /// Some("secret".to_string()) + /// ); + /// + /// assert!(!metadata.contains_key("token")); + /// + /// # Ok::<(), MetadataError>(()) + /// ``` + pub fn remove(&mut self, key: &str) -> Option { + self.0.remove(key) + } + + /// Returns `true` if the store contains the specified key. + /// + /// # Examples + /// + /// ```rust + /// # use apalis_core::task::metadata::Metadata; + /// # use apalis_core::task::metadata::MetadataStore; + /// # use apalis_core::task::metadata::MetadataError; + /// let mut metadata = MetadataStore::new(); + /// + /// metadata.insert("owner", "alice")?; + /// + /// assert!(metadata.contains_key("owner")); + /// assert!(!metadata.contains_key("missing")); + /// + /// # Ok::<(), MetadataError>(()) + /// ``` + #[must_use] + pub fn contains_key(&self, key: &str) -> bool { + self.0.contains_key(key) + } + + /// Returns an iterator over all key-value pairs in the store. + /// + /// The iterator yields `(&String, &String)` pairs. + /// + /// # Examples + /// + /// ```rust + /// # use apalis_core::task::metadata::Metadata; + /// # use apalis_core::task::metadata::MetadataStore; + /// # use apalis_core::task::metadata::MetadataError; + /// let mut metadata = MetadataStore::new(); + /// + /// metadata.insert("a", "1")?; + /// metadata.insert("b", "2")?; + /// + /// let items: Vec<_> = metadata.iter().collect(); + /// + /// assert_eq!(items.len(), 2); + /// + /// # Ok::<(), MetadataError>(()) + /// ``` + pub fn iter(&self) -> impl Iterator { + self.0.iter() + } + + /// Consumes the store and returns the underlying `HashMap`. + /// + /// # Examples + /// + /// ```rust + /// # use apalis_core::task::metadata::Metadata; + /// # use apalis_core::task::metadata::MetadataStore; + /// # use apalis_core::task::metadata::MetadataError; + /// let mut metadata = MetadataStore::new(); + /// + /// metadata.insert("key", "value")?; + /// + /// let inner = metadata.into_inner(); + /// + /// assert_eq!( + /// inner.get("key"), + /// Some(&"value".to_string()) + /// ); + /// + /// # Ok::<(), MetadataError>(()) + /// ``` + #[must_use] + pub fn into_inner(self) -> HashMap { + self.0 + } + + /// Get a typed metadata entry. + pub fn extract_as(&self) -> Result { + M::extract(self) + } +} + +/// Implemented by types that can be stored as metadata. +/// Provides a stable key and string-based serialization. +pub trait Metadata: Sized { + /// The error produced when extracting the Metadata type Error; - /// Extract metadata of type `T`. - fn extract(&self) -> Result; - /// Inject metadata of type `T`. - fn inject(&mut self, value: T) -> Result<(), Self::Error>; + + /// Extract `Metadata` from the store + fn extract(store: &MetadataStore) -> Result; + + /// Inject [`Self`] into the store + fn inject(&self, map: &mut MetadataStore) -> Result<(), Self::Error>; } -impl + Send + Sync, IdType: Send + Sync> +impl FromRequest> for Meta { - type Error = Ctx::Error; + type Error = T::Error; async fn from_request(task: &Task) -> Result { task.parts.ctx.extract().map(Meta) } } +impl + FromRequest> for MetadataStore +{ + type Error = Infallible; + + async fn from_request(task: &Task) -> Result { + Ok(task.parts.ctx.metadata().clone()) + } +} + /// Metadata used specifically for storing tracing context. #[cfg(feature = "tracing")] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] @@ -118,16 +562,126 @@ impl TracingContext { } } +#[cfg(feature = "tracing")] +/// Error provided by parsing TracingContext +#[derive(Debug, thiserror::Error)] +pub enum TracingContextParseError { + /// Missing Field + #[error("Missing Field: {0}")] + MissingField(&'static str), + /// Invalid flags + #[error("Invalid flags: {0}")] + InvalidTraceFlags(std::num::ParseIntError), + /// Invalid Format + #[error("Invalid Format")] + InvalidFormat, + /// Key {apalis_core.tracing.context} not found in Metadata + #[error("Key {{apalis_core.tracing.context}} not found in Metadata")] + MissingKey, + /// Duplicate entry + #[error("Duplicate entry: {0}")] + DuplicateEntry(#[from] MetadataError), +} + +// Serialization format: a single W3C traceparent-style string. +// +// ;;; +// +// Each field is either its value or `-` if None. +// +// Example: +// "4bf92f3577b34da6a3ce929d0e0e4736;00f067aa0ba902b7;01;congo=t61rcWkgMzE" +// "4bf92f3577b34da6a3ce929d0e0e4736;00f067aa0ba902b7;-;-" +#[cfg(feature = "tracing")] +impl fmt::Display for TracingContext { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "{};{};{};{}", + self.trace_id.as_deref().unwrap_or("-"), + self.span_id.as_deref().unwrap_or("-"), + self.trace_flags + .map(|v| v.to_string()) + .as_deref() + .unwrap_or("-"), + self.trace_state.as_deref().unwrap_or("-"), + ) + } +} + +#[cfg(feature = "tracing")] +impl FromStr for TracingContext { + type Err = TracingContextParseError; + + fn from_str(s: &str) -> Result { + let mut parts = s.splitn(4, ';'); + + let mut next = |field| { + parts + .next() + .ok_or(TracingContextParseError::MissingField(field)) + }; + + let trace_id = match next("trace_id")? { + "-" => None, + v => Some(v.to_owned()), + }; + + let span_id = match next("span_id")? { + "-" => None, + v => Some(v.to_owned()), + }; + + let trace_flags = match next("trace_flags")? { + "-" => None, + v => Some( + v.parse::() + .map_err(TracingContextParseError::InvalidTraceFlags)?, + ), + }; + + let trace_state = match next("trace_state")? { + "-" => None, + v => Some(v.to_owned()), + }; + + Ok(Self { + trace_id, + span_id, + trace_flags, + trace_state, + }) + } +} + +#[cfg(feature = "tracing")] +const TRACING_CONTENT_KEY: &str = "apalis_core.tracing.context"; + +#[cfg(feature = "tracing")] +impl Metadata for TracingContext { + type Error = TracingContextParseError; + fn extract(store: &MetadataStore) -> Result { + store + .get(TRACING_CONTENT_KEY) + .ok_or(TracingContextParseError::InvalidFormat)? + .parse() + } + + fn inject(&self, map: &mut MetadataStore) -> Result<(), Self::Error> { + Ok(map.insert(TRACING_CONTENT_KEY, self.to_string())?) + } +} + #[cfg(test)] #[allow(unused)] mod tests { - use std::{convert::Infallible, fmt::Debug, task::Poll, time::Duration}; + use std::{convert::Infallible, fmt::Debug, num::ParseIntError, task::Poll, time::Duration}; use crate::{ error::BoxDynError, task::{ Task, - metadata::{Meta, MetadataExt}, + metadata::{Meta, Metadata, MetadataExt, MetadataStore}, }, task_fn::FromRequest, }; @@ -143,17 +697,35 @@ mod tests { timeout: Duration, } - struct SampleStore; + const EXAMPLE_CONFIG: &str = "apalis_core.example.config"; - impl MetadataExt for SampleStore { - type Error = Infallible; - fn extract(&self) -> Result { - Ok(ExampleConfig { - timeout: Duration::from_secs(1), - }) + impl Metadata for ExampleConfig { + type Error = ParseIntError; + fn extract(store: &MetadataStore) -> Result { + let timeout = store + .get(EXAMPLE_CONFIG) + .unwrap() + .parse::() + .map(Duration::from_secs)?; + Ok(ExampleConfig { timeout }) + } + + fn inject(&self, map: &mut MetadataStore) -> Result<(), ParseIntError> { + let value = self.timeout.as_secs().to_string(); + map.insert(EXAMPLE_CONFIG, value).unwrap(); + Ok(()) } - fn inject(&mut self, _: ExampleConfig) -> Result<(), Self::Error> { - unreachable!() + } + + struct SampleStore(MetadataStore); + + impl MetadataExt for SampleStore { + fn metadata(&self) -> &MetadataStore { + &self.0 + } + + fn metadata_mut(&mut self) -> &mut MetadataStore { + &mut self.0 } } @@ -161,8 +733,7 @@ mod tests { Service> for ExampleService where S: Service> + Clone + Send + 'static, - Ctx: MetadataExt + Send, - Ctx::Error: Debug, + Ctx: MetadataExt + Send, S::Future: Send + 'static, { type Response = S::Response; diff --git a/apalis-core/src/task/mod.rs b/apalis-core/src/task/mod.rs index 83192c38..57be362c 100644 --- a/apalis-core/src/task/mod.rs +++ b/apalis-core/src/task/mod.rs @@ -14,7 +14,7 @@ //! //! The [`Task`] struct is generic over: //! - `Args`: The type of arguments or payload for the task. -//! - `Ctx`: Ctxdata associated with the task, such as custom fields or backend-specific information. +//! - `Ctx`: Data associated with the task, such as custom fields or backend-specific information. //! - `IdType`: The type used for uniquely identifying the task (defaults to [`RandomId`]). //! //! ## [`Parts`] @@ -56,13 +56,35 @@ //! ```rust //! # use apalis_core::task::{Task, Parts}; //! # use apalis_core::task::builder::TaskBuilder; -//! # use apalis_core::task::extensions::Extensions; //! # use apalis_core::task::task_id::RandomId; +//! # use apalis_core::task::metadata::Metadata; +//! # use apalis_core::task::metadata::MetadataStore; +//! # use apalis_core::task::metadata::MetadataExt; +//! # use apalis_core::backend::memory::MemoryContext; +//! # +//! #[derive(Debug, PartialEq)] +//! struct RequestId(String); //! -//! #[derive(Default, Clone)] -//! struct MyCtx { priority: u8 } -//! let task: Task = TaskBuilder::new("important work".to_string()) -//! .meta(MyCtx { priority: 5 }) +//! impl Metadata for RequestId { +//! type Error = std::convert::Infallible; +//! +//! fn inject(&self, metadata: &mut MetadataStore) -> Result<(), Self::Error> { +//! let _ = metadata.insert("request_id", self.0.clone()); +//! Ok(()) +//! } +//! +//! fn extract(metadata: &MetadataStore) -> Result { +//! Ok(Self( +//! metadata +//! .get("request_id") +//! .cloned() +//! .unwrap_or_default(), +//! )) +//! } +//! } +//! +//! let task: Task = TaskBuilder::new("important work".to_string()) +//! .meta(&RequestId("user_id".to_string())) //! .build(); //! ``` //! diff --git a/apalis-sql/src/context.rs b/apalis-sql/src/context.rs index 15f589d7..05f9c717 100644 --- a/apalis-sql/src/context.rs +++ b/apalis-sql/src/context.rs @@ -1,16 +1,14 @@ use std::convert::Infallible; use apalis_core::{ - task::{Task, metadata::MetadataExt}, + task::{ + Task, + metadata::{MetadataExt, MetadataStore}, + }, task_fn::FromRequest, }; -type JsonMapMetadata = serde_json::Map; - -use serde::{ - Deserialize, Serialize, - de::{DeserializeOwned, Error}, -}; +use serde::{Deserialize, Serialize}; /// The SQL context used for jobs stored in a SQL database #[derive(Debug, Serialize, Deserialize)] @@ -22,7 +20,7 @@ pub struct SqlContext { done_at: Option, priority: i32, queue: Option, - meta: JsonMapMetadata, + meta: MetadataStore, // Marker to hold the Pool type // Used to associate the context with a specific database pool type _pool: std::marker::PhantomData, @@ -160,13 +158,13 @@ impl SqlContext { /// Get the metadata map #[must_use] - pub fn meta(&self) -> &JsonMapMetadata { + pub fn meta(&self) -> &MetadataStore { &self.meta } /// Set the metadata map #[must_use] - pub fn with_meta(mut self, meta: JsonMapMetadata) -> Self { + pub fn with_meta(mut self, meta: MetadataStore) -> Self { self.meta = meta; self } @@ -181,19 +179,11 @@ impl FromRequest> } } -impl MetadataExt for SqlContext { - type Error = serde_json::Error; - fn extract(&self) -> Result { - self.meta - .get(std::any::type_name::()) - .and_then(|v| T::deserialize(v).ok()) - .ok_or(serde_json::Error::custom("Failed to extract metadata")) - } - fn inject(&mut self, value: T) -> Result<(), Self::Error> { - self.meta.insert( - std::any::type_name::().to_owned(), - serde_json::to_value(&value).unwrap(), - ); - Ok(()) +impl MetadataExt for SqlContext { + fn metadata(&self) -> &MetadataStore { + &self.meta + } + fn metadata_mut(&mut self) -> &mut MetadataStore { + &mut self.meta } } diff --git a/apalis-sql/src/from_row.rs b/apalis-sql/src/from_row.rs index 1bd87e28..75fbd5e6 100644 --- a/apalis-sql/src/from_row.rs +++ b/apalis-sql/src/from_row.rs @@ -3,7 +3,10 @@ use std::str::FromStr; use apalis_core::{ backend::codec::Codec, error::BoxDynError, - task::{Task, attempt::Attempt, builder::TaskBuilder, status::Status, task_id::TaskId}, + task::{ + Task, attempt::Attempt, builder::TaskBuilder, metadata::MetadataStore, status::Status, + task_id::TaskId, + }, }; use crate::context::SqlContext; @@ -53,7 +56,7 @@ pub struct TaskRow { /// Priority level of the task (higher values indicate higher priority) pub priority: Option, /// Additional metadata associated with the task, stored as JSON - pub metadata: Option, + pub metadata: Option, /// Idempotency key for enforcing uniqueness pub idempotency_key: Option, } @@ -76,17 +79,7 @@ impl TaskRow { .with_max_attempts(self.max_attempts.unwrap_or(25) as i32) .with_last_result(self.last_result) .with_priority(self.priority.unwrap_or(0) as i32) - .with_meta( - self.metadata - .map(|m| { - serde_json::to_value(&m) - .unwrap_or_default() - .as_object() - .cloned() - .unwrap_or_default() - }) - .unwrap_or_default(), - ) + .with_meta(self.metadata.unwrap_or_default()) .with_queue(self.job_type) .with_lock_at(self.lock_at.map(|dt| dt.to_unix_timestamp())); @@ -126,11 +119,7 @@ impl TaskRow { .with_max_attempts(self.max_attempts.unwrap_or(25) as i32) .with_last_result(self.last_result) .with_priority(self.priority.unwrap_or(0) as i32) - .with_meta( - self.metadata - .map(|m| m.as_object().cloned().unwrap()) - .unwrap_or_default(), - ) + .with_meta(self.metadata.unwrap_or_default()) .with_queue(self.job_type) .with_lock_at(self.lock_at.map(|dt| dt.to_unix_timestamp())); diff --git a/apalis-workflow/src/dag/context.rs b/apalis-workflow/src/dag/context.rs index 8a9cf464..f09a20d7 100644 --- a/apalis-workflow/src/dag/context.rs +++ b/apalis-workflow/src/dag/context.rs @@ -1,6 +1,17 @@ -use std::collections::{HashMap, HashSet}; +use std::{ + collections::{HashMap, HashSet}, + fmt::Display, + num::ParseIntError, + str::FromStr, +}; -use apalis_core::task::task_id::TaskId; +use apalis_core::{ + error::BoxDynError, + task::{ + metadata::{Metadata, MetadataError, MetadataStore}, + task_id::TaskId, + }, +}; use petgraph::graph::NodeIndex; use serde::{Deserialize, Serialize}; @@ -72,3 +83,173 @@ impl DagFlowContext { .collect() } } + +const DAG_FLOW_PREV_NODE_KEY: &str = "apalis_workflow.dag.prev_node"; + +const DAG_FLOW_CURRENT_NODE_KEY: &str = "apalis_workflow.dag.current_node"; + +const DAG_FLOW_COMPLETED_NODES_KEY: &str = "apalis_workflow.dag.completed_nodes"; + +const DAG_FLOW_NODE_TASK_IDS_KEY: &str = "apalis_workflow.dag.node_task_ids"; + +const DAG_FLOW_CURRENT_POSITION_KEY: &str = "apalis_workflow.dag.current_position"; + +const DAG_FLOW_IS_INITIAL_KEY: &str = "apalis_workflow.dag.is_initial"; + +const DAG_FLOW_ROOT_TASK_ID_KEY: &str = "apalis_workflow.dag.root_task_id"; + +/// An error representing an invalid [`DagFlowContext`] +#[derive(Debug, thiserror::Error)] +pub enum DagFlowContextError { + /// Missing current node key + #[error("missing key {DAG_FLOW_CURRENT_NODE_KEY}")] + MissingCurrentNode, + + /// Missing current position key + #[error("missing key {DAG_FLOW_CURRENT_POSITION_KEY}")] + MissingCurrentPosition, + + /// Could not parse a node index + #[error("could not parse node index")] + ParseNodeIndex(#[from] ParseIntError), + + /// Could not parse a task_id + #[error("could not parse task id: {0}")] + ParseTaskId(BoxDynError), + + /// Duplicate entry + #[error("Duplicate entry: {0}")] + DuplicateEntry(#[from] MetadataError), +} + +impl Metadata for DagFlowContext +where + IdType: FromStr + Display, + ::Err: std::error::Error + Send + Sync + 'static, +{ + type Error = DagFlowContextError; + + fn extract(map: &MetadataStore) -> Result { + let prev_node = map + .get(DAG_FLOW_PREV_NODE_KEY) + .map(|v| v.parse::()) + .transpose()? + .map(NodeIndex::new); + + let current_node = map + .get(DAG_FLOW_CURRENT_NODE_KEY) + .ok_or(DagFlowContextError::MissingCurrentNode)? + .parse::()?; + + let completed_nodes = map + .get(DAG_FLOW_COMPLETED_NODES_KEY) + .map(|v| { + v.split(',') + .filter(|s| !s.is_empty()) + .map(|s| s.parse::().map(NodeIndex::new)) + .collect::, _>>() + }) + .transpose()? + .unwrap_or_default(); + + let node_task_ids = map + .get(DAG_FLOW_NODE_TASK_IDS_KEY) + .map(|v| { + v.split(',') + .filter(|s| !s.is_empty()) + .map(|s| { + s.split_once('=') + .ok_or(DagFlowContextError::ParseTaskId("Invalid delimiter".into())) + .and_then(|(k, v)| { + let node = k + .parse::() + .map(NodeIndex::new) + .map_err(DagFlowContextError::ParseNodeIndex)?; + let task_id = v + .parse::>() + .map_err(|e| DagFlowContextError::ParseTaskId(e.into()))?; + Ok((node, task_id)) + }) + }) + .collect::, _>>() + }) + .transpose()? + .unwrap_or_default(); + + let current_position = map + .get(DAG_FLOW_CURRENT_POSITION_KEY) + .ok_or(DagFlowContextError::MissingCurrentPosition)? + .parse::()?; + + let is_initial = map + .get(DAG_FLOW_IS_INITIAL_KEY) + .map(|v| v.parse::()) + .transpose() + .unwrap_or(None) + .unwrap_or(true); + + let root_task_id = map + .get(DAG_FLOW_ROOT_TASK_ID_KEY) + .map(|v| v.parse::().map(TaskId::new)) + .transpose() + .map_err(|e| DagFlowContextError::ParseTaskId(e.into()))?; + + Ok(Self { + prev_node, + current_node: NodeIndex::new(current_node), + completed_nodes, + node_task_ids, + current_position, + is_initial, + root_task_id, + }) + } + + fn inject(&self, map: &mut MetadataStore) -> Result<(), DagFlowContextError> { + if let Some(prev_node) = self.prev_node { + map.insert(DAG_FLOW_PREV_NODE_KEY, prev_node.index().to_string())?; + } + + map.insert( + DAG_FLOW_CURRENT_NODE_KEY, + self.current_node.index().to_string(), + )?; + + let completed_nodes = self + .completed_nodes + .iter() + .map(|n| n.index()) + .collect::>(); + + map.insert( + DAG_FLOW_COMPLETED_NODES_KEY, + completed_nodes + .iter() + .map(ToString::to_string) + .collect::>() + .join(","), + )?; + + let node_task_ids = self + .node_task_ids + .iter() + .map(|(k, v)| format!("{}={v}", k.index())) + .collect::>() + .join(","); + + map.insert(DAG_FLOW_NODE_TASK_IDS_KEY, node_task_ids)?; + + map.insert( + DAG_FLOW_CURRENT_POSITION_KEY, + self.current_position.to_string(), + )?; + + map.insert(DAG_FLOW_IS_INITIAL_KEY, self.is_initial.to_string()) + .expect("A value already exists"); + + if let Some(root_task_id) = &self.root_task_id { + map.insert(DAG_FLOW_ROOT_TASK_ID_KEY, root_task_id.to_string())?; + } + Ok(()) + } +} diff --git a/apalis-workflow/src/dag/error.rs b/apalis-workflow/src/dag/error.rs index 32dd8b09..8a07a1ba 100644 --- a/apalis-workflow/src/dag/error.rs +++ b/apalis-workflow/src/dag/error.rs @@ -3,6 +3,8 @@ use petgraph::{algo::Cycle, graph::NodeIndex}; use std::fmt::Debug; use thiserror::Error; +use crate::dag::context::DagFlowContextError; + /// Errors that can occur during DAG workflow execution. #[derive(Error, Debug)] pub enum DagFlowError { @@ -64,6 +66,10 @@ pub enum DagFlowError { /// DAG contains cycles. #[error("DAG contains cycles involving nodes: {0:?}")] CyclicDAG(Cycle), + + /// Metadata error + #[error("Metadata error: {0}")] + MetadataError(#[from] DagFlowContextError), } /// Error encountered by Service Error diff --git a/apalis-workflow/src/dag/executor.rs b/apalis-workflow/src/dag/executor.rs index c6563158..60e436c0 100644 --- a/apalis-workflow/src/dag/executor.rs +++ b/apalis-workflow/src/dag/executor.rs @@ -1,13 +1,13 @@ use std::{ collections::{HashMap, VecDeque}, - fmt::Debug, + fmt::{Debug, Display}, pin::Pin, + str::FromStr, task::{Context, Poll}, }; use apalis_core::{ backend::{BackendExt, codec::RawDataBackend}, - error::BoxDynError, task::{ Task, metadata::{Meta, MetadataExt}, @@ -71,14 +71,13 @@ where } } -impl Service> for DagExecutor +impl Service> for DagExecutor where B: BackendExt, - B::Context: - Send + Sync + 'static + MetadataExt, Error = MetaError> + Default, - B::IdType: Clone + Send + Sync + 'static + GenerateId + Debug, + B::Context: Send + Sync + 'static + MetadataExt + Default, + B::IdType: Clone + Send + Sync + 'static + GenerateId + Debug + FromStr + Display, B::Compact: Send + Sync + 'static, - MetaError: Into, + ::Err: std::error::Error + Send + Sync + 'static, { type Response = B::Compact; type Error = DagFlowError; @@ -136,7 +135,7 @@ impl IntoWorkerService, B::Compact, B::Con where B: BackendExt + Clone, Err: std::error::Error + Send + Sync + 'static, - B::Context: MetadataExt> + Send + Sync + 'static, + B::Context: MetadataExt + Send + Sync + 'static, B::IdType: Send + Sync + 'static + Default + GenerateId + PartialEq + Debug, B::Compact: Send + Sync + 'static + Clone, RootDagService: Service>, diff --git a/apalis-workflow/src/dag/service.rs b/apalis-workflow/src/dag/service.rs index 5d30368f..e89ce6fe 100644 --- a/apalis-workflow/src/dag/service.rs +++ b/apalis-workflow/src/dag/service.rs @@ -13,7 +13,8 @@ use futures::{FutureExt, Sink, SinkExt, StreamExt}; use petgraph::Direction; use petgraph::graph::NodeIndex; use std::collections::HashMap; -use std::fmt::Debug; +use std::fmt::{Debug, Display}; +use std::str::FromStr; use tower::Service; use crate::DagExecutor; @@ -72,8 +73,7 @@ fn find_designated_fan_in_handler( designated_handler.ok_or(DagFlowError::Service(DagServiceError::MissingFaninHandler)) } -impl Service> - for RootDagService +impl Service> for RootDagService where B: BackendExt + Send @@ -81,17 +81,16 @@ where + 'static + Clone + WaitForCompletion>, - IdType: GenerateId + Send + Sync + 'static + PartialEq + Debug + Clone, + IdType: GenerateId + Send + Sync + 'static + PartialEq + Debug + Clone + FromStr + Display, B::Compact: Send + Sync + 'static + Clone, - B::Context: - Send + Sync + Default + MetadataExt, Error = MetaError> + 'static, + B::Context: Send + Sync + Default + MetadataExt + 'static, Err: std::error::Error + Send + Sync + 'static, B: Sink, Error = Err> + Unpin, B::Codec: Codec, Compact = B::Compact, Error = CdcErr> + 'static + Codec, Compact = B::Compact, Error = CdcErr>, CdcErr: Into, - MetaError: Into + Send + Sync + 'static, + ::Err: std::error::Error + Send + Sync + 'static, { type Response = DagExecutionResponse; type Error = DagFlowError; @@ -244,10 +243,7 @@ where #[cfg(feature = "tracing")] tracing::debug!("Single start node detected, proceeding with execution"); let context = DagFlowContext::new(req.parts.task_id.clone()); - req.parts - .ctx - .inject(context.clone()) - .map_err(|e| DagFlowError::Metadata(e.into()))?; + req.parts.ctx.inject(&context)?; let response = executor.call(req).await?; #[cfg(feature = "tracing")] tracing::debug!(node = ?context.current_node, "Execution complete at node"); @@ -293,7 +289,7 @@ where let task = TaskBuilder::new(response.clone()) .with_task_id(TaskId::new(B::IdType::generate())) - .meta(new_context) + .meta(&new_context) .build(); backend .send(task) @@ -337,12 +333,14 @@ async fn fan_out_next_nodes( where B::IdType: GenerateId + Send + Sync + 'static + PartialEq, B::Compact: Send + Sync + 'static + Clone, - B::Context: Send + Sync + Default + MetadataExt> + 'static, + B::Context: Send + Sync + Default + MetadataExt + 'static, B: Sink, Error = Err> + Unpin, Err: std::error::Error + Send + Sync + 'static, B: BackendExt + Send + Sync + 'static + Clone, B::Codec: Codec, Compact = B::Compact, Error = CdcErr>, CdcErr: Into, + B::IdType: FromStr + Display, + ::Err: std::error::Error + Send + Sync + 'static, { let mut enqueue_futures = vec![]; let next_nodes = outgoing_nodes @@ -358,7 +356,7 @@ where .clone(); let task = TaskBuilder::new(input.clone()) .with_task_id(task_id) - .meta(DagFlowContext { + .meta(&DagFlowContext { prev_node: context.prev_node, current_node: outgoing_node, completed_nodes: context.completed_nodes.clone(), @@ -392,12 +390,14 @@ async fn fan_out_entry_nodes( where B::IdType: GenerateId + Send + Sync + 'static + PartialEq + Debug, B::Compact: Send + Sync + 'static + Clone, - B::Context: Send + Sync + Default + MetadataExt> + 'static, + B::Context: Send + Sync + Default + MetadataExt + 'static, B: Sink, Error = Err> + Unpin, Err: std::error::Error + Send + Sync + 'static, B: BackendExt + Send + Sync + 'static + Clone, B::Codec: Codec, Compact = B::Compact, Error = CdcErr>, CdcErr: Into, + B::IdType: FromStr + Display, + ::Err: std::error::Error + Send + Sync + 'static, { let values: Vec = B::Codec::decode(input).map_err(|e: CdcErr| DagFlowError::Codec(e.into()))?; @@ -421,7 +421,7 @@ where .ok_or(DagFlowError::Service(DagServiceError::MissingNextNode))?; let task = TaskBuilder::new(input) .with_task_id(task_id.clone()) - .meta(DagFlowContext { + .meta(&DagFlowContext { prev_node: None, current_node: outgoing_node, completed_nodes: Default::default(), diff --git a/apalis-workflow/src/lib.rs b/apalis-workflow/src/lib.rs index d5c73d82..eb4aca4f 100644 --- a/apalis-workflow/src/lib.rs +++ b/apalis-workflow/src/lib.rs @@ -38,6 +38,7 @@ mod tests { use apalis_core::{ task::{ builder::TaskBuilder, + metadata::Meta, task_id::{RandomId, TaskId}, }, task_fn::task_fn, @@ -56,13 +57,14 @@ mod tests { #[tokio::test] async fn basic_workflow() { + type RepeatUntilState = Meta>; let workflow = Workflow::new("and-then-workflow") .and_then(async |input: u32| (input) as usize) .delay_for(Duration::from_secs(1)) .and_then(async |input: usize| (input) as usize) .delay_for(Duration::from_secs(1)) .delay_with(|_| Duration::from_secs(1)) - .repeat_until(|res: usize, state: RepeaterState| async move { + .repeat_until(|res: usize, state: RepeatUntilState| async move { println!("Iteration {}: got result {}", state.iterations(), res); // Repeat until we have iterated 3 times // Of course, in a real-world scenario, the condition would be based on `res` diff --git a/apalis-workflow/src/sequential/and_then/mod.rs b/apalis-workflow/src/sequential/and_then/mod.rs index 34f91774..624c6efa 100644 --- a/apalis-workflow/src/sequential/and_then/mod.rs +++ b/apalis-workflow/src/sequential/and_then/mod.rs @@ -15,7 +15,7 @@ use tower::{Service, ServiceBuilder, layer::layer_fn}; use crate::{ SteppedService, id_generator::GenerateId, - sequential::context::{StepContext, WorkflowContext}, + sequential::context::StepContext, sequential::router::{GoTo, StepResult, WorkflowRouter}, sequential::service::handle_step_result, sequential::step::{Layer, Stack, Step}, @@ -82,7 +82,7 @@ where B::IdType: GenerateId + Send + 'static, S::Response: Send + 'static, B::Compact: Send + 'static, - B::Context: Send + MetadataExt + 'static, + B::Context: Send + MetadataExt + 'static, SinkError: std::error::Error + Send + Sync + 'static, F::Response: Send + 'static, { @@ -150,7 +150,7 @@ where SinkError: std::error::Error + Send + Sync + 'static, Res: Send + 'static, B::Compact: Send + 'static, - B::Context: Send + MetadataExt + 'static, + B::Context: Send + MetadataExt + 'static, { type Response = GoTo>; type Error = BoxDynError; diff --git a/apalis-workflow/src/sequential/context.rs b/apalis-workflow/src/sequential/context.rs index f4b5c488..bf1ab0e5 100644 --- a/apalis-workflow/src/sequential/context.rs +++ b/apalis-workflow/src/sequential/context.rs @@ -1,3 +1,6 @@ +use std::num::ParseIntError; + +use apalis_core::task::metadata::{Metadata, MetadataError, MetadataStore}; use serde::{Deserialize, Serialize}; /// Context information for the current step in the workflow @@ -26,8 +29,40 @@ impl StepContext { pub struct WorkflowContext { /// Index of the step in the workflow pub step_index: usize, - // / Additional fields can be added as needed - // / name: String, - // / version: String, - // / parent_workflow_id: Option, +} + +/// Represents an invalid [`WorkflowContext`] state +#[derive(Debug, thiserror::Error)] +pub enum WorkflowContextError { + /// An entry for the key is missing + #[error("the data for key {WORKFLOW_CONTEXT_KEY} is missing")] + MissingKey, + /// Could not parse the value provided + #[error("Could not parse key {WORKFLOW_CONTEXT_KEY}")] + Parse(#[from] ParseIntError), + + /// Duplicate entry + #[error("Duplicate entry: {0}")] + DuplicateEntry(#[from] MetadataError), +} + +const WORKFLOW_CONTEXT_KEY: &str = "apalis_workflow.context.step_index"; + +impl Metadata for WorkflowContext { + type Error = WorkflowContextError; + + fn extract(map: &MetadataStore) -> Result { + let step_index = map + .get(WORKFLOW_CONTEXT_KEY) + .ok_or(WorkflowContextError::MissingKey)? + .parse::() + .map_err(WorkflowContextError::Parse)?; + + Ok(Self { step_index }) + } + + fn inject(&self, map: &mut MetadataStore) -> Result<(), WorkflowContextError> { + map.insert(WORKFLOW_CONTEXT_KEY, self.step_index.to_string())?; + Ok(()) + } } diff --git a/apalis-workflow/src/sequential/delay/mod.rs b/apalis-workflow/src/sequential/delay/mod.rs index 0a05ec9c..96eadf73 100644 --- a/apalis-workflow/src/sequential/delay/mod.rs +++ b/apalis-workflow/src/sequential/delay/mod.rs @@ -59,7 +59,7 @@ where S::Response: Send + 'static, B::Codec: Codec + Codec + 'static, >::Error: Into, - B::Context: Send + 'static + MetadataExt, + B::Context: Send + 'static + MetadataExt, Input: Send + Sync + 'static, >::Error: Into, B: BackendExt, @@ -133,7 +133,7 @@ where S::Response: Send + 'static, B::Codec: Codec + Codec + 'static, >::Error: Into, - B::Context: Send + 'static + MetadataExt, + B::Context: Send + 'static + MetadataExt, Input: Send + Sync + 'static, >::Error: Into, B: BackendExt, @@ -166,7 +166,7 @@ where B::Codec: Codec + Codec + 'static, >::Error: Into, >::Error: Into, - B::Context: Send + 'static + MetadataExt, + B::Context: Send + 'static + MetadataExt, { type Response = GoTo>; type Error = BoxDynError; @@ -194,7 +194,7 @@ where }); let task = TaskBuilder::new(args) .with_task_id(task_id.clone()) - .meta(WorkflowContext { + .meta(&WorkflowContext { step_index: ctx.current_step + 1, }) .run_after(delay_duration) diff --git a/apalis-workflow/src/sequential/filter_map/mod.rs b/apalis-workflow/src/sequential/filter_map/mod.rs index abe1f890..3bc85f4c 100644 --- a/apalis-workflow/src/sequential/filter_map/mod.rs +++ b/apalis-workflow/src/sequential/filter_map/mod.rs @@ -1,9 +1,14 @@ -use std::marker::PhantomData; +use std::{fmt::Display, marker::PhantomData, str::FromStr}; use apalis_core::{ backend::{BackendExt, TaskSinkError, WaitForCompletion, codec::Codec}, error::BoxDynError, - task::{Task, builder::TaskBuilder, metadata::MetadataExt, task_id::TaskId}, + task::{ + Task, + builder::TaskBuilder, + metadata::{Metadata, MetadataError, MetadataExt, MetadataStore}, + task_id::TaskId, + }, task_fn::{TaskFn, task_fn}, }; use futures::{FutureExt, Sink, StreamExt, future::BoxFuture}; @@ -80,21 +85,146 @@ impl Clone for FilterService) -> std::fmt::Result { + match self { + Self::Init => write!(f, "Init"), + Self::SingleStep => write!(f, "SingleStep"), + Self::Collector => write!(f, "Collector"), + } + } +} + +/// Represents an invalid FilterState +#[derive(Debug, thiserror::Error)] +pub enum FilterStateParseError { + /// Invalid filter state + #[error("invalid filter state: {0}")] + InvalidState(String), +} + +impl std::str::FromStr for FilterState { + type Err = FilterStateParseError; + + fn from_str(s: &str) -> Result { + match s { + "Init" => Ok(Self::Init), + "SingleStep" => Ok(Self::SingleStep), + "Collector" => Ok(Self::Collector), + _ => Err(FilterStateParseError::InvalidState(s.to_owned())), + } + } +} + +/// Represents an invalid [`FilterState`] +#[derive(Debug, thiserror::Error)] +pub enum FilterStateError { + /// The filter state is missing + #[error("the data for key {FILTER_STATE_KEY} is missing")] + MissingKey, + + /// Could not parse the filter state + #[error(transparent)] + Parse(#[from] FilterStateParseError), + + /// Duplicate entry + #[error("Duplicate entry: {0}")] + DuplicateEntry(#[from] MetadataError), +} + +impl Metadata for FilterState { + type Error = FilterStateError; + + fn extract(map: &MetadataStore) -> Result { + let value = map + .get(FILTER_STATE_KEY) + .ok_or(FilterStateError::MissingKey)?; + + Ok(value.parse::()?) + } + + fn inject(&self, map: &mut MetadataStore) -> Result<(), FilterStateError> { + map.insert(FILTER_STATE_KEY, self.to_string())?; + Ok(()) + } +} + /// The context for the filter operation #[derive(Debug, Clone, Deserialize, Serialize)] pub struct FilterContext { task_ids: Vec>, } -impl +const FILTER_CONTEXT_TASK_IDS_KEY: &str = "apalis_workflow.filter.task_ids"; + +/// Error representing an invalid [`FilterContext`] state +#[derive(Debug, thiserror::Error)] +pub enum FilterContextError { + /// The entry for key {FILTER_CONTEXT_TASK_IDS_KEY} is missing" + #[error("the entry for key {FILTER_CONTEXT_TASK_IDS_KEY} is missing")] + MissingKey, + + /// Could not parse the provided task_id + #[error("could not parse task id")] + ParseTaskId, + + /// Duplicate entry + #[error("Duplicate entry: {0}")] + DuplicateEntry(#[from] MetadataError), +} + +impl Metadata for FilterContext +where + IdType: std::str::FromStr + Display, +{ + type Error = FilterContextError; + + fn extract(map: &MetadataStore) -> Result { + let value = map + .get(FILTER_CONTEXT_TASK_IDS_KEY) + .ok_or(FilterContextError::MissingKey)?; + + let task_ids = if value.is_empty() { + Vec::new() + } else { + value + .split(',') + .map(|id| { + id.parse::() + .map(TaskId::new) + .map_err(|_| FilterContextError::ParseTaskId) + }) + .collect::, _>>()? + }; + + Ok(Self { task_ids }) + } + + fn inject(&self, map: &mut MetadataStore) -> Result<(), FilterContextError> { + let value = self + .task_ids + .iter() + .map(ToString::to_string) + .collect::>() + .join(","); + + map.insert(FILTER_CONTEXT_TASK_IDS_KEY, value)?; + + Ok(()) + } +} + +impl Service> for FilterService where F: Service, Response = Option>, @@ -106,23 +236,17 @@ where + Sink, Error = Err> + WaitForCompletion>> + Unpin, - B::Context: MetadataExt, B::Codec: Codec, Error = CodecError, Compact = B::Compact> + Codec + Codec + Codec + Codec, Error = CodecError, Compact = B::Compact> + 'static, - IdType: GenerateId + Send + 'static + Clone, - B::Context: MetadataExt - + MetadataExt, Error = MetaErr> - + Send - + Sync - + 'static, + IdType: GenerateId + FromStr + Display + Send + 'static + Clone, + B::Context: MetadataExt + Send + Sync + 'static, Err: std::error::Error + Send + Sync + 'static, CodecError: std::error::Error + Send + Sync + 'static, F::Error: Into + Send + 'static, - MetaErr: std::error::Error + Send + Sync + 'static, F::Future: Send + 'static, B::Compact: Send + 'static, Input: Send + 'static, @@ -141,11 +265,11 @@ where } fn call(&mut self, request: Task) -> Self::Future { - let filter_state: FilterState = request.parts.ctx.extract().unwrap_or(FilterState::Unknown); + let filter_state: FilterState = request.parts.ctx.extract().unwrap_or(FilterState::Init); let mut ctx = request.parts.data.get::>().cloned().unwrap(); use futures::SinkExt; match filter_state { - FilterState::Unknown => { + FilterState::Init => { // Handle unknown state async move { let main_args: Vec = vec![]; @@ -158,11 +282,11 @@ where let task_id = TaskId::new(B::IdType::generate()); let task = TaskBuilder::new(B::Codec::encode(&step)?) - .meta(WorkflowContext { + .meta(&WorkflowContext { step_index: ctx.current_step, }) .with_task_id(task_id.clone()) - .meta(FilterState::SingleStep) + .meta(&FilterState::SingleStep) .build(); ctx.backend .send(task) @@ -174,11 +298,11 @@ where let task_id = TaskId::new(B::IdType::generate()); let task = TaskBuilder::new(B::Codec::encode(&main_args)?) .with_task_id(task_id.clone()) - .meta(WorkflowContext { + .meta(&WorkflowContext { step_index: ctx.current_step, }) - .meta(FilterContext { task_ids }) - .meta(FilterState::Collector) + .meta(&FilterContext { task_ids }) + .meta(&FilterState::Collector) .build(); ctx.backend @@ -245,7 +369,7 @@ where } } -impl Step +impl Step for FilterMapStep where I: IntoIterator + Send + Sync + 'static, @@ -273,14 +397,8 @@ where B::IdType: GenerateId + Send + 'static + Clone, S::Response: Send + 'static, B::Compact: Send + 'static, - B::Context: Send - + MetadataExt - + MetadataExt - + MetadataExt> - + 'static, SinkError: std::error::Error + Send + Sync + 'static, F::Response: Send + 'static, - B::Context: MetadataExt, B::Codec: Codec, Error = CodecError, Compact = B::Compact> + Codec + Codec @@ -288,17 +406,12 @@ where + Codec, Error = CodecError, Compact = B::Compact> + 'static, B::IdType: GenerateId + Send + 'static, - B::Context: MetadataExt - + MetadataExt, Error = MetaErr> - + Send - + Sync - + 'static, + B::Context: MetadataExt + Send + Sync + 'static, CodecError: std::error::Error + Send + Sync + 'static, - MetaErr: std::error::Error + Send + Sync + 'static, F::Future: Send + 'static, B::Compact: Send + 'static, - Input: Send + 'static, Output: Send + 'static, + IdType: FromStr + Display, { type Response = Vec; type Error = F::Error; diff --git a/apalis-workflow/src/sequential/fold/mod.rs b/apalis-workflow/src/sequential/fold/mod.rs index bedea332..9450705d 100644 --- a/apalis-workflow/src/sequential/fold/mod.rs +++ b/apalis-workflow/src/sequential/fold/mod.rs @@ -3,7 +3,12 @@ use std::{marker::PhantomData, task::Context}; use apalis_core::{ backend::{BackendExt, TaskSinkError, codec::Codec}, error::BoxDynError, - task::{Task, builder::TaskBuilder, metadata::MetadataExt, task_id::TaskId}, + task::{ + Task, + builder::TaskBuilder, + metadata::{Metadata, MetadataError, MetadataExt, MetadataStore}, + task_id::TaskId, + }, task_fn::{TaskFn, task_fn}, }; use futures::{FutureExt, Sink, SinkExt, future::BoxFuture}; @@ -66,7 +71,7 @@ pub struct FoldStep { _marker: std::marker::PhantomData, } -impl, Init, B, MetaErr, Err, CodecError> Step +impl, Init, B, Err, CodecError> Step for FoldStep where F: Service, Response = Init> @@ -83,10 +88,7 @@ where + Unpin + 'static, I: IntoIterator + Send + Sync + 'static, - B::Context: MetadataExt - + MetadataExt - + Send - + 'static, + B::Context: MetadataExt + Send + 'static, B::Codec: Codec<(Init, Vec), Error = CodecError, Compact = B::Compact> + Codec + Codec @@ -97,7 +99,6 @@ where Err: std::error::Error + Send + Sync + 'static, CodecError: std::error::Error + Send + Sync + 'static, F::Error: Into + Send + 'static, - MetaErr: std::error::Error + Send + Sync + 'static, F::Future: Send + 'static, B::Compact: Send + 'static, Input: Send + 'static, @@ -141,8 +142,8 @@ impl FoldService { } } -impl - Service> for FoldService +impl Service> + for FoldService where F: Service, Response = Init> + Send @@ -156,10 +157,7 @@ where + Unpin + 'static, I: IntoIterator + Send + 'static, - B::Context: MetadataExt - + MetadataExt - + Send - + 'static, + B::Context: MetadataExt + Send + 'static, B::Codec: Codec<(Init, Vec), Error = CodecError, Compact = B::Compact> + Codec + Codec @@ -170,7 +168,6 @@ where Err: std::error::Error + Send + Sync + 'static, CodecError: std::error::Error + Send + Sync + 'static, F::Error: Into + Send + 'static, - MetaErr: std::error::Error + Send + Sync + 'static, F::Future: Send + 'static, B::Compact: Send + 'static, Input: Send + 'static, @@ -184,21 +181,21 @@ where } fn call(&mut self, task: Task) -> Self::Future { - let state = task.parts.ctx.extract().unwrap_or(FoldState::Unknown); + let state = task.parts.ctx.extract().unwrap_or(FoldState::Init); let mut ctx = task.parts.data.get::>().cloned().unwrap(); let mut fold = self.fold.clone(); match state { - FoldState::Unknown => async move { + FoldState::Init => async move { let task_id = TaskId::new(B::IdType::generate()); let steps: Task = task.try_map(|arg| B::Codec::decode(&arg))?; let steps = steps.args.into_iter().collect::>(); let task = TaskBuilder::new(B::Codec::encode(&(Init::default(), steps))?) - .meta(WorkflowContext { + .meta(&WorkflowContext { step_index: ctx.current_step, }) .with_task_id(task_id.clone()) - .meta(FoldState::Collection) + .meta(&FoldState::Collection) .build(); ctx.backend .send(task) @@ -226,7 +223,7 @@ where let result = B::Codec::encode(&response)?; let next_step = TaskBuilder::new(result) .with_task_id(task_id.clone()) - .meta(WorkflowContext { + .meta(&WorkflowContext { step_index: ctx.current_step + 1, }) .build(); @@ -249,10 +246,10 @@ where let result = B::Codec::encode(&response)?; let steps = TaskBuilder::new(B::Codec::encode(&(response, rest))?) .with_task_id(task_id.clone()) - .meta(WorkflowContext { + .meta(&WorkflowContext { step_index: ctx.current_step, }) - .meta(FoldState::Collection) + .meta(&FoldState::Collection) .build(); ctx.backend .send(steps) @@ -273,8 +270,44 @@ where /// The state of the fold operation #[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)] pub enum FoldState { - /// Unknown - Unknown, + /// Initializing state + Init, /// Collection has started Collection, } + +const FOLD_STATE_KEY: &str = "apalis_workflow.fold.state"; + +/// An error representing an invalid [`FoldState`] +#[derive(Debug, thiserror::Error)] +pub enum FoldStateError { + /// The fold state key is missing + #[error("the data for key {FOLD_STATE_KEY} is missing")] + MissingKey, + + /// Duplicate entry + #[error("Duplicate entry: {0}")] + DuplicateEntry(#[from] MetadataError), +} + +impl Metadata for FoldState { + type Error = FoldStateError; + + fn extract(map: &MetadataStore) -> Result { + let value = map.get(FOLD_STATE_KEY).ok_or(FoldStateError::MissingKey)?; + + match value.as_str() { + "Collection" => Ok(Self::Collection), + _ => Ok(Self::Init), + } + } + + fn inject(&self, map: &mut MetadataStore) -> Result<(), FoldStateError> { + let value = match self { + Self::Init => "Init", + Self::Collection => "Collection", + }; + map.insert(FOLD_STATE_KEY, value)?; + Ok(()) + } +} diff --git a/apalis-workflow/src/sequential/repeat_until/mod.rs b/apalis-workflow/src/sequential/repeat_until/mod.rs index efaf289a..1c279094 100644 --- a/apalis-workflow/src/sequential/repeat_until/mod.rs +++ b/apalis-workflow/src/sequential/repeat_until/mod.rs @@ -1,14 +1,16 @@ -use std::convert::Infallible; +use std::fmt::Display; use std::marker::PhantomData; +use std::num::ParseIntError; +use std::str::FromStr; use std::task::Context; use apalis_core::backend::TaskSinkError; use apalis_core::backend::codec::Codec; use apalis_core::error::BoxDynError; use apalis_core::task::builder::TaskBuilder; -use apalis_core::task::metadata::MetadataExt; +use apalis_core::task::metadata::{Metadata, MetadataError, MetadataExt, MetadataStore}; use apalis_core::task::task_id::TaskId; -use apalis_core::task_fn::{FromRequest, TaskFn, task_fn}; +use apalis_core::task_fn::{TaskFn, task_fn}; use apalis_core::{backend::BackendExt, task::Task}; use futures::future::BoxFuture; use futures::{FutureExt, Sink, SinkExt}; @@ -90,7 +92,7 @@ where } } -impl Service> +impl Service> for RepeatUntilService where F: Service, Response = Option> + Send + 'static + Clone, @@ -101,19 +103,15 @@ where + Sink, Error = Err> + Unpin + 'static, - B::Context: MetadataExt, Error = MetaErr> - + MetadataExt - + Send - + 'static, + B::Context: MetadataExt + Send + 'static, B::Codec: Codec + Codec + Codec, Error = CodecError, Compact = B::Compact> + 'static, - B::IdType: GenerateId + Send + 'static, + B::IdType: GenerateId + Send + Display + FromStr + 'static, Err: std::error::Error + Send + Sync + 'static, CodecError: std::error::Error + Send + Sync + 'static, F::Error: Into + Send + 'static, - MetaErr: std::error::Error + Send + Sync + 'static, F::Future: Send + 'static, B::Compact: Send + 'static, Input: Send + 'static, // We don't need Clone because decoding just needs a reference @@ -149,7 +147,7 @@ where let task_id = TaskId::new(B::IdType::generate()); let next_step = TaskBuilder::new(B::Codec::encode(&res)?) .with_task_id(task_id.clone()) - .meta(WorkflowContext { + .meta(&WorkflowContext { step_index: ctx.current_step + 1, }) .build(); @@ -171,10 +169,10 @@ where let next_step = TaskBuilder::new(compact.take().expect("Compact args should be set")) .with_task_id(task_id.clone()) - .meta(WorkflowContext { + .meta(&WorkflowContext { step_index: ctx.current_step, }) - .meta(RepeaterState { + .meta(&RepeaterState { iterations: state.iterations + 1, prev_task_id, }) @@ -194,7 +192,7 @@ where } } -/// The state of the fold operation +/// The state of the repeat operation #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] pub struct RepeaterState { iterations: usize, @@ -222,21 +220,68 @@ impl RepeaterState { } } -impl + Sync, IdType: Sync> FromRequest> - for RepeaterState +/// An error representing an invalid [`RepeaterState`] +#[derive(Debug, thiserror::Error)] +pub enum RepeaterStateError { + /// Missing iterations key + #[error("the data for key {REPEATER_ITERATIONS_KEY} is missing")] + MissingIterations, + + /// Could not parse iterations + #[error("could not parse key {REPEATER_ITERATIONS_KEY}")] + ParseIterations(#[from] ParseIntError), + + /// Could not parse a task id + #[error("could not parse key {REPEATER_PREV_TASK_ID_KEY}")] + ParseTaskId, + + /// Duplicate entry + #[error("Duplicate entry: {0}")] + DuplicateEntry(#[from] MetadataError), +} + +const REPEATER_ITERATIONS_KEY: &str = "apalis_workflow.repeater.iterations"; +const REPEATER_PREV_TASK_ID_KEY: &str = "apalis_workflow.repeater.prev_task_id"; + +impl Metadata for RepeaterState +where + IdType: std::str::FromStr + ToString, { - type Error = Infallible; - async fn from_request(task: &Task) -> Result { - let state: Self = task.parts.ctx.extract().unwrap_or_default(); + type Error = RepeaterStateError; + + fn extract(map: &MetadataStore) -> Result { + let iterations = map + .get(REPEATER_ITERATIONS_KEY) + .ok_or(RepeaterStateError::MissingIterations)? + .parse::()?; + + let prev_task_id = map + .get(REPEATER_PREV_TASK_ID_KEY) + .map(|value| { + value + .parse::() + .map(TaskId::new) + .map_err(|_| RepeaterStateError::ParseTaskId) + }) + .transpose()?; + Ok(Self { - iterations: state.iterations, - prev_task_id: state.prev_task_id, + iterations, + prev_task_id, }) } + + fn inject(&self, map: &mut MetadataStore) -> Result<(), RepeaterStateError> { + map.insert(REPEATER_ITERATIONS_KEY, self.iterations.to_string())?; + + if let Some(task_id) = &self.prev_task_id { + map.insert(REPEATER_PREV_TASK_ID_KEY, task_id.to_string())?; + } + Ok(()) + } } -impl Step - for RepeatUntilStep +impl Step for RepeatUntilStep where F: Service, Response = Option> + Send @@ -250,10 +295,7 @@ where + Sink, Error = Err> + Unpin + 'static, - B::Context: MetadataExt, Error = MetaErr> - + MetadataExt - + Send - + 'static, + B::Context: MetadataExt + Send + 'static, B::Codec: Codec + Codec + Codec, Error = CodecError, Compact = B::Compact> @@ -262,12 +304,12 @@ where Err: std::error::Error + Send + Sync + 'static, CodecError: std::error::Error + Send + Sync + 'static, F::Error: Into + Send + 'static, - MetaErr: std::error::Error + Send + Sync + 'static, F::Future: Send + 'static, B::Compact: Send + 'static, Input: Send + Sync + 'static, // We don't need Clone because decoding just needs a reference Res: Send + Sync + 'static, S: Step + Send + 'static, + B::IdType: FromStr + Display, { type Response = Res; type Error = F::Error; diff --git a/apalis-workflow/src/sequential/service.rs b/apalis-workflow/src/sequential/service.rs index 55148685..b0d55769 100644 --- a/apalis-workflow/src/sequential/service.rs +++ b/apalis-workflow/src/sequential/service.rs @@ -52,10 +52,9 @@ impl Service> for Workflo where B::Compact: Send + 'static, B: Sync, - B::Context: Send + Default + MetadataExt, + B::Context: Send + Default + MetadataExt, Err: std::error::Error + Send + Sync + 'static, B::IdType: GenerateId + Send + 'static, - >::Error: Into, B: Sink, Error = Err> + Unpin, B: Clone + Send + Sync + 'static + BackendExt, { @@ -119,7 +118,7 @@ where + BackendExt + Send + Unpin, - B::Context: MetadataExt, + B::Context: MetadataExt, B::Codec: Codec, >::Error: Into, Compact: 'static, @@ -134,7 +133,7 @@ where B::Codec::encode(&next).map_err(|e| TaskSinkError::CodecError(e.into()))?, ) .with_task_id(task_id.clone()) - .meta(WorkflowContext { + .meta(&WorkflowContext { step_index: ctx.current_step + 1, }) .build(); @@ -152,7 +151,7 @@ where ) .run_after(delay) .with_task_id(task_id.clone()) - .meta(WorkflowContext { + .meta(&WorkflowContext { step_index: ctx.current_step + 1, }) .build(); diff --git a/apalis-workflow/src/sequential/workflow.rs b/apalis-workflow/src/sequential/workflow.rs index a7a2caa8..7655ab02 100644 --- a/apalis-workflow/src/sequential/workflow.rs +++ b/apalis-workflow/src/sequential/workflow.rs @@ -10,7 +10,6 @@ use futures::Sink; use crate::{ id_generator::GenerateId, - sequential::context::WorkflowContext, sequential::router::WorkflowRouter, sequential::service::WorkflowService, sequential::step::{Identity, Layer, Stack, Step}, @@ -111,11 +110,10 @@ where + Unpin + Clone, Err: std::error::Error + Send + Sync + 'static, - B::Context: MetadataExt + Send + Sync + 'static, + B::Context: MetadataExt + Send + Sync + 'static, B::IdType: Send + 'static + Default + GenerateId, B: Sync + Backend, B::Compact: Send + Sync + 'static, - >::Error: Into, L: Layer>, L::Step: Step, { diff --git a/apalis-workflow/src/sink.rs b/apalis-workflow/src/sink.rs index a3194329..d57eb837 100644 --- a/apalis-workflow/src/sink.rs +++ b/apalis-workflow/src/sink.rs @@ -1,3 +1,5 @@ +use std::{fmt::Display, str::FromStr}; + use apalis_core::{ backend::{BackendExt, TaskSinkError, codec::Codec}, error::BoxDynError, @@ -57,20 +59,18 @@ where ) -> impl Future>> + Send; } -impl WorkflowSink for S +impl WorkflowSink for S where S: Sink, Error = Err> + BackendExt + Unpin, - S::IdType: GenerateId + Send, + S::IdType: GenerateId + Send + FromStr + Display, S::Codec: Codec, - S::Context: MetadataExt - + MetadataExt, Error = MetaErr> - + Send, + S::Context: MetadataExt + Send, Err: std::error::Error + Send + Sync + 'static, >::Error: Into + Send + Sync + 'static, - MetaErr: Into + Send + Sync + 'static, Compact: Send + 'static, + ::Err: std::error::Error + Send + Sync + 'static, { async fn push_start(&mut self, args: Args) -> Result<(), TaskSinkError> { use futures::SinkExt; @@ -110,7 +110,7 @@ where let task_id = TaskId::new(S::IdType::generate()); let compact = S::Codec::encode(&step).map_err(|e| TaskSinkError::CodecError(e.into()))?; let task = TaskBuilder::new(compact) - .meta(WorkflowContext { step_index: index }) + .meta(&WorkflowContext { step_index: index }) .with_task_id(task_id.clone()) .build(); self.send(task) @@ -127,7 +127,7 @@ where let task_id = TaskId::new(S::IdType::generate()); let compact = S::Codec::encode(&node).map_err(|e| TaskSinkError::CodecError(e.into()))?; let task = TaskBuilder::new(compact) - .meta(DagFlowContext { + .meta(&DagFlowContext { current_node: index, completed_nodes: Default::default(), current_position: index.index(), diff --git a/apalis/src/layers/retry/mod.rs b/apalis/src/layers/retry/mod.rs index 5e353b02..afd01366 100644 --- a/apalis/src/layers/retry/mod.rs +++ b/apalis/src/layers/retry/mod.rs @@ -80,11 +80,12 @@ use apalis_core::error::AbortError; use apalis_core::task::Task; use apalis_core::task::builder::TaskBuilder; -use apalis_core::task::metadata::MetadataExt; +use apalis_core::task::metadata::{Metadata, MetadataError, MetadataExt, MetadataStore}; use apalis_core::task::status::Status; use apalis_core::worker::context::WorkerContext; use std::any::Any; use std::fmt::Debug; +use std::num::ParseIntError; use tower::retry::backoff::Backoff; /// Re-exports from [`tower::retry`] @@ -304,6 +305,43 @@ pub struct RetryConfig { pub retries: usize, } +/// An error that represents an invalid [`RetryConfig`] +#[derive(Debug, thiserror::Error)] +pub enum RetryConfigError { + /// The retry config key is missing + #[error("the data for key {RETRY_CONFIG_KEY} is missing")] + MissingKey, + + /// Could not parse the retry config key + #[error("Could not parse key {RETRY_CONFIG_KEY}")] + Parse(#[from] ParseIntError), + + /// Duplicate entry + #[error("Duplicate entry: {0}")] + DuplicateEntry(#[from] MetadataError), +} + +const RETRY_CONFIG_KEY: &str = "apalis.retries.config"; + +impl Metadata for RetryConfig { + type Error = RetryConfigError; + + fn extract(map: &MetadataStore) -> Result { + let retries = map + .get(RETRY_CONFIG_KEY) + .ok_or(RetryConfigError::MissingKey)? + .parse::() + .map_err(RetryConfigError::Parse)?; + + Ok(RetryConfig { retries }) + } + + fn inject(&self, map: &mut MetadataStore) -> Result<(), RetryConfigError> { + map.insert(RETRY_CONFIG_KEY, self.retries.to_string())?; + Ok(()) + } +} + /// Retry the task based on the [`RetryConfig`] metadata #[derive(Debug, Clone)] pub struct FromTaskConfigPolicy

{ @@ -338,7 +376,7 @@ where T: Clone, Ctx: Clone, P: Policy, Res, Err>, - Ctx: MetadataExt, + Ctx: MetadataExt, { type Future = P::Future; @@ -356,7 +394,7 @@ where Err(_) => { let attempt = req.parts.attempt.current(); // If we have a retry config, we need to respect it - if let Ok(cfg) = req.parts.ctx.extract() { + if let Ok(cfg) = req.parts.ctx.extract::() { if cfg.retries <= attempt { return None; } @@ -380,12 +418,11 @@ pub trait RetryMetadataExt { impl RetryMetadataExt for TaskBuilder where - Ctx: MetadataExt, - Ctx::Error: Debug, + Ctx: MetadataExt, { /// Set number of retries in the metadata fn retries(self, retries: usize) -> Self { - self.meta(RetryConfig { retries }) + self.meta(&RetryConfig { retries }) } } @@ -411,7 +448,9 @@ mod tests { async fn basic_worker_retries() { let mut in_memory = MemoryStorage::new(); - let task1 = TaskBuilder::new(1).meta(RetryConfig { retries: 3 }).build(); + let task1 = TaskBuilder::new(1) + .meta(&RetryConfig { retries: 3 }) + .build(); let task2 = TaskBuilder::new(2).retries(5).build(); let task3 = TaskBuilder::new(3).build(); diff --git a/apalis/src/layers/tracing/contextual_span.rs b/apalis/src/layers/tracing/contextual_span.rs index b9e24138..44955239 100644 --- a/apalis/src/layers/tracing/contextual_span.rs +++ b/apalis/src/layers/tracing/contextual_span.rs @@ -1,9 +1,6 @@ use std::fmt::Display; -use apalis_core::task::{ - Task, - metadata::{MetadataExt, TracingContext}, -}; +use apalis_core::task::{Task, metadata::MetadataExt}; use tracing::{Level, Span}; #[cfg(feature = "opentelemetry")] @@ -48,7 +45,7 @@ impl Default for ContextualTaskSpan { impl MakeSpan for ContextualTaskSpan where - Ctx: MetadataExt, + Ctx: MetadataExt, IdType: Display, { fn make_span(&mut self, req: &Task) -> Span { @@ -58,8 +55,10 @@ where .as_ref() .expect("A task must have an ID") .to_string(); + println!("Fetching"); #[cfg(feature = "opentelemetry")] - let tracing_ctx: TracingContext = req.parts.ctx.extract().unwrap_or_default(); + let tracing_ctx: apalis_core::task::metadata::TracingContext = + req.parts.ctx.extract().unwrap_or_default(); let attempt = &req.parts.attempt; let span = Span::current(); diff --git a/apalis/src/layers/tracing/mod.rs b/apalis/src/layers/tracing/mod.rs index 36df1968..439f8be8 100644 --- a/apalis/src/layers/tracing/mod.rs +++ b/apalis/src/layers/tracing/mod.rs @@ -382,9 +382,12 @@ mod tests { use super::*; use apalis_core::{ - backend::{TaskSink, memory::MemoryStorage}, + backend::{ + TaskSink, + memory::{MemoryContext, MemoryStorage}, + }, error::BoxDynError, - task::{extensions::Extensions, task_id::RandomId}, + task::task_id::RandomId, worker::{ builder::WorkerBuilder, context::WorkerContext, ext::event_listener::EventListenerExt, }, @@ -430,14 +433,14 @@ mod tests { .backend(in_memory) .layer( TraceLayer::new() - .make_span_with(|req: &Task| { + .make_span_with(|req: &Task| { tracing::span!( tracing::Level::INFO, "custom_span", task_id = req.parts.task_id.as_ref().unwrap().to_string() ) }) - .on_request(|task: &Task, span: &tracing::Span| { + .on_request(|task: &Task, span: &tracing::Span| { tracing::info!(parent: span, "Custom OnRequest: Received task: {:?}", task); }) .on_response(|_: &() , duration: Duration, span: &tracing::Span| { diff --git a/apalis/src/lib.rs b/apalis/src/lib.rs index 1f6bb8ea..05e41bb3 100644 --- a/apalis/src/lib.rs +++ b/apalis/src/lib.rs @@ -23,6 +23,10 @@ pub mod layers; /// Common imports pub mod prelude { pub use crate::layers::WorkerBuilderExt; + #[cfg(feature = "retry")] + pub use crate::layers::retry::{ + BackoffRetryPolicy, FromTaskConfigPolicy, RetryIfPolicy, RetryPolicy, + }; pub use apalis_core::{ backend::{ Backend, BackendExt, Expose, FetchById, Filter, ListAllTasks, ListQueues, ListTasks, @@ -40,7 +44,7 @@ pub mod prelude { task::builder::TaskBuilder, task::data::{AddExtension, Data, MissingDataError}, task::extensions::Extensions, - task::metadata::MetadataExt, + task::metadata::{Meta, Metadata, MetadataError, MetadataExt, MetadataStore}, task::status::Status, task::task_id::RandomId, task::task_id::TaskId, diff --git a/apalis/tests/otel_context_propagation.rs b/apalis/tests/otel_context_propagation.rs index 9345edd5..e72b0c06 100644 --- a/apalis/tests/otel_context_propagation.rs +++ b/apalis/tests/otel_context_propagation.rs @@ -61,7 +61,7 @@ async fn otel_context_propagates_producer_to_consumer() { }; backend - .send(Task::builder(1_u8).meta(metadata).build()) + .send(Task::builder(1_u8).meta(&metadata).build()) .await .unwrap(); @@ -134,7 +134,7 @@ async fn otel_context_invalid_traceparent_is_ignored_safely() { .with_trace_flags(1); backend - .send(Task::builder(1_u8).meta(invalid_context).build()) + .send(Task::builder(1_u8).meta(&invalid_context).build()) .await .unwrap(); @@ -168,7 +168,7 @@ async fn otel_tracestate_roundtrip_is_preserved() { TracingContext::from(OtelTraceContext::current()).with_trace_state("vendor=acme"); backend - .send(Task::builder(1_u8).meta(metadata).build()) + .send(Task::builder(1_u8).meta(&metadata).build()) .await .unwrap(); diff --git a/examples/monitor/src/main.rs b/examples/monitor/src/main.rs index 989e4c0d..ca4c0ddf 100644 --- a/examples/monitor/src/main.rs +++ b/examples/monitor/src/main.rs @@ -51,7 +51,7 @@ async fn produce_task_with_ctx(storage: &mut JsonStorage) -> Result<()> { subject: "Welcome Sentry Email".to_string(), }; let context = TracingContext::from(OtelTraceContext::current()); - let task = Task::builder(email).meta(context).build(); + let task = Task::builder(email).meta(&context).build(); storage.push_task(task).await?; Ok(()) } diff --git a/examples/tracing/src/main.rs b/examples/tracing/src/main.rs index a6f452fe..73fc735b 100644 --- a/examples/tracing/src/main.rs +++ b/examples/tracing/src/main.rs @@ -74,7 +74,7 @@ async fn produce_task_with_ctx(storage: &mut MemoryStorage) -> Result<()> subject: "Welcome Sentry Email".to_string(), }; let task = Task::builder(email) - .meta(TracingContext::from(OtelTraceContext::current())) + .meta(&TracingContext::from(OtelTraceContext::current())) .build(); storage.send(task).await?; Ok(()) diff --git a/supply-chain/config.toml b/supply-chain/config.toml index 79eb1d54..aa974ff3 100644 --- a/supply-chain/config.toml +++ b/supply-chain/config.toml @@ -261,7 +261,7 @@ version = "1.0.5" criteria = "safe-to-deploy" [[exemptions.either]] -version = "1.15.0" +version = "1.16.0" criteria = "safe-to-deploy" [[exemptions.email_address]] @@ -373,7 +373,7 @@ version = "0.3.32" criteria = "safe-to-deploy" [[exemptions.futures-timer]] -version = "3.0.3" +version = "3.0.4" criteria = "safe-to-deploy" [[exemptions.futures-util]] @@ -653,7 +653,7 @@ version = "0.2.18" criteria = "safe-to-deploy" [[exemptions.nix]] -version = "0.30.1" +version = "0.31.3" criteria = "safe-to-deploy" [[exemptions.nu-ansi-term]] @@ -777,7 +777,7 @@ version = "0.31.0" criteria = "safe-to-deploy" [[exemptions.os_info]] -version = "3.14.0" +version = "3.15.0" criteria = "safe-to-deploy" [[exemptions.parking_lot]] diff --git a/utils/apalis-file-storage/src/meta.rs b/utils/apalis-file-storage/src/meta.rs index c9c4bb2c..9488e584 100644 --- a/utils/apalis-file-storage/src/meta.rs +++ b/utils/apalis-file-storage/src/meta.rs @@ -1,27 +1,16 @@ -use apalis_core::task::metadata::MetadataExt; +use apalis_core::task::metadata::{MetadataExt, MetadataStore}; use serde::{Deserialize, Serialize}; /// A simple wrapper around a JSON map to represent task metadata #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)] -pub struct JsonMapMetadata(serde_json::Map); +pub struct JsonMapMetadata(MetadataStore); -impl Deserialize<'de>> MetadataExt for JsonMapMetadata { - type Error = serde_json::Error; - fn extract(&self) -> Result { - use serde::de::Error as _; - let key = std::any::type_name::(); - match self.0.get(key) { - Some(value) => T::deserialize(value), - None => Err(serde_json::Error::custom(format!( - "No entry for type `{key}` in metadata" - ))), - } +impl MetadataExt for JsonMapMetadata { + fn metadata(&self) -> &MetadataStore { + &self.0 } - fn inject(&mut self, value: T) -> Result<(), serde_json::Error> { - let key = std::any::type_name::(); - let json_value = serde_json::to_value(value)?; - self.0.insert(key.to_owned(), json_value); - Ok(()) + fn metadata_mut(&mut self) -> &mut MetadataStore { + &mut self.0 } }