Skip to content

Commit 4dc4813

Browse files
50U10FCA7tyranron
andauthored
Close connection when keep-alive timeout is reached in graphql-transport-ws protocol of juniper_graphql_ws crate (#1367)
- remake `ConnectionConfig::keep_alive_interval` option as `ConnectionConfig::keep_alive` represented by `KeepAliveConfig` - consider `ConnectionConfig::keep_alive::timeout` in `graphql-transport-ws` protocol Co-authored-by: Kai Ren <tyranron@gmail.com>
1 parent 2ed9d9d commit 4dc4813

5 files changed

Lines changed: 159 additions & 27 deletions

File tree

juniper_graphql_ws/CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,18 @@ All user visible changes to `juniper_graphql_ws` crate will be documented in thi
1212

1313
- `Schema::Context` now requires `Clone` bound for ability to have a "fresh" context value each time a new [GraphQL] operation is started in a [WebSocket] connection. ([#1369])
1414
> **COMPATIBILITY**: Previously, it was `Arc`ed inside, sharing the same context value across all [GraphQL] operations of a [WebSocket] connection. To preserve the previous behavior, the `Schema::Context` type should be either wrapped into `Arc` or made `Arc`-based internally.
15+
- Replaced `ConnectionConfig::keep_alive_interval` option with `ConnectionConfig::keep_alive` one as `KeepAliveConfig`. ([#1367])
16+
- Made [WebSocket] connection closed once `ConnectionConfig::keep_alive::timeout` is reached in [`graphql-transport-ws` GraphQL over WebSocket Protocol][proto-6.0.7]. ([#1367])
17+
> **COMPATIBILITY**: Previously, a [WebSocket] connection was kept alive, even when clients do not respond to server's `Pong` messages at all. To preserve the previous behavior, the `ConnectionConfig::keep_alive::timeout` should be set to `Duration:::ZERO`.
1518
1619
### Fixed
1720

1821
- Inability to re-subscribe with the same operation `id` after subscription was completed by server. ([#1368])
1922

23+
[#1367]: /../../pull/1367
2024
[#1368]: /../../pull/1368
2125
[#1369]: /../../pull/1369
26+
[proto-6.0.7]: https://github.com/enisdenjo/graphql-ws/blob/v6.0.7/PROTOCOL.md
2227

2328

2429

juniper_graphql_ws/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ derive_more = { version = "2.0", features = ["debug", "from"] }
3030
juniper = { version = "0.17", path = "../juniper", default-features = false }
3131
juniper_subscriptions = { version = "0.18", path = "../juniper_subscriptions" }
3232
serde = { version = "1.0.122", features = ["derive"], default-features = false }
33-
tokio = { version = "1.0", features = ["macros", "rt", "time"], default-features = false }
33+
tokio = { version = "1.0", features = ["macros", "rt", "sync", "time"], default-features = false }
3434

3535
[dev-dependencies]
3636
serde_json = "1.0.18"

juniper_graphql_ws/src/graphql_transport_ws/mod.rs

Lines changed: 81 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ use juniper::{
2929
task::{Context, Poll, Waker},
3030
},
3131
};
32+
use tokio::{sync::Notify, time};
3233

3334
use super::{ConnectionConfig, Init, Schema};
3435

@@ -83,6 +84,7 @@ enum ConnectionState<S: Schema, I: Init<S::ScalarValue, S::Context>> {
8384
Active {
8485
config: ConnectionConfig<S::Context>,
8586
stoppers: HashMap<String, oneshot::Sender<()>>,
87+
ping: Arc<Notify>,
8688
schema: S,
8789
},
8890
/// Terminated is the state after a ConnectionInit message has been rejected.
@@ -100,60 +102,89 @@ impl<S: Schema, I: Init<S::ScalarValue, S::Context>> ConnectionState<S, I> {
100102
Self::PreInit { init, schema } => match msg {
101103
ClientMessage::ConnectionInit { payload } => match init.init(payload).await {
102104
Ok(config) => {
103-
let keep_alive_interval = config.keep_alive_interval;
105+
let keep_alive_interval = config.keep_alive.interval;
106+
let keep_alive_timeout = config.keep_alive.timeout;
104107

105-
let mut s =
106-
stream::iter(vec![Output::Message(ServerMessage::ConnectionAck)])
107-
.boxed();
108+
let ping = Arc::new(Notify::new());
109+
110+
let mut s = Output::Message(ServerMessage::ConnectionAck)
111+
.into_stream()
112+
.boxed();
108113

109-
#[expect(closure_returning_async_block, reason = "not possible")]
110114
if keep_alive_interval > Duration::from_secs(0) {
111115
s = s
112116
.chain(Output::Message(ServerMessage::Pong).into_stream())
113117
.boxed();
114118
s = s
115-
.chain(stream::unfold((), move |_| async move {
116-
tokio::time::sleep(keep_alive_interval).await;
117-
Some((Output::Message(ServerMessage::Pong), ()))
119+
.chain(stream::repeat(()).then(move |()| {
120+
tokio::time::sleep(keep_alive_interval)
121+
.map(|()| Output::Message(ServerMessage::Pong))
118122
}))
119123
.boxed();
120124
}
121125

126+
if keep_alive_timeout > Duration::from_secs(0) {
127+
let ping_rx = ping.clone();
128+
s = stream::select_all([
129+
s,
130+
stream::repeat(())
131+
.then(move |()| {
132+
let ping_rx = ping_rx.clone();
133+
async move {
134+
time::timeout(keep_alive_timeout, ping_rx.notified())
135+
.await
136+
.is_err()
137+
.then(|| Output::Close {
138+
code: 1000,
139+
message: "Connection lost unexpectedly".into(),
140+
})
141+
}
142+
})
143+
.filter_map(future::ready)
144+
.boxed(),
145+
])
146+
.boxed();
147+
}
148+
122149
(
123150
Self::Active {
124151
config,
125152
stoppers: HashMap::new(),
153+
ping,
126154
schema,
127155
},
128156
s,
129157
)
130158
}
131159
Err(e) => (
132160
Self::Terminated,
133-
stream::iter(vec![Output::Close {
161+
Output::Close {
134162
code: 4403,
135163
message: e.to_string(),
136-
}])
164+
}
165+
.into_stream()
137166
.boxed(),
138167
),
139168
},
140169
ClientMessage::Ping { .. } => (
141170
Self::PreInit { init, schema },
142-
stream::iter(vec![Output::Message(ServerMessage::Pong)]).boxed(),
171+
Output::Message(ServerMessage::Pong).into_stream().boxed(),
143172
),
144173
ClientMessage::Subscribe { .. } => (
145174
Self::PreInit { init, schema },
146-
stream::iter(vec![Output::Close {
175+
Output::Close {
147176
code: 4401,
148177
message: "Unauthorized".to_string(),
149-
}])
178+
}
179+
.into_stream()
150180
.boxed(),
151181
),
152182
_ => (Self::PreInit { init, schema }, stream::empty().boxed()),
153183
},
154184
Self::Active {
155185
config,
156186
mut stoppers,
187+
ping,
157188
schema,
158189
} => {
159190
let reactions = match msg {
@@ -225,14 +256,16 @@ impl<S: Schema, I: Init<S::ScalarValue, S::Context>> ConnectionState<S, I> {
225256
stream::empty().boxed()
226257
}
227258
ClientMessage::Ping { .. } => {
228-
stream::iter(vec![Output::Message(ServerMessage::Pong)]).boxed()
259+
ping.notify_waiters();
260+
Output::Message(ServerMessage::Pong).into_stream().boxed()
229261
}
230262
_ => stream::empty().boxed(),
231263
};
232264
(
233265
Self::Active {
234266
config,
235267
stoppers,
268+
ping,
236269
schema,
237270
},
238271
reactions,
@@ -956,10 +989,12 @@ mod test {
956989
}
957990

958991
#[tokio::test]
959-
async fn test_keep_alives() {
992+
async fn test_keep_alive_interval() {
960993
let mut conn = Connection::new(
961994
new_test_schema(),
962-
ConnectionConfig::new(Context(1)).with_keep_alive_interval(Duration::from_millis(20)),
995+
ConnectionConfig::new(Context(1))
996+
.with_keep_alive_interval(Duration::from_millis(20))
997+
.with_keep_alive_timeout(Duration::from_secs(0)),
963998
);
964999

9651000
conn.send(ClientMessage::ConnectionInit {
@@ -981,6 +1016,35 @@ mod test {
9811016
}
9821017
}
9831018

1019+
#[tokio::test]
1020+
async fn test_keep_alive_timeout() {
1021+
let mut conn = Connection::new(
1022+
new_test_schema(),
1023+
ConnectionConfig::new(Context(1))
1024+
.with_keep_alive_interval(Duration::from_millis(0))
1025+
.with_keep_alive_timeout(Duration::from_millis(20)),
1026+
);
1027+
1028+
conn.send(ClientMessage::ConnectionInit {
1029+
payload: graphql_vars! {},
1030+
})
1031+
.await
1032+
.unwrap();
1033+
1034+
assert_eq!(
1035+
Output::Message(ServerMessage::ConnectionAck),
1036+
conn.next().await.unwrap()
1037+
);
1038+
1039+
assert_eq!(
1040+
Output::Close {
1041+
code: 1000,
1042+
message: "Connection lost unexpectedly".into(),
1043+
},
1044+
conn.next().await.unwrap(),
1045+
);
1046+
}
1047+
9841048
#[tokio::test]
9851049
async fn test_slow_init() {
9861050
let mut conn = Connection::new(
@@ -1009,7 +1073,7 @@ mod test {
10091073

10101074
assert_eq!(
10111075
Output::Message(ServerMessage::ConnectionAck),
1012-
conn.next().await.unwrap()
1076+
conn.next().await.unwrap(),
10131077
);
10141078

10151079
assert_eq!(

juniper_graphql_ws/src/graphql_ws/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ impl<S: Schema, I: Init<S::ScalarValue, S::Context>> ConnectionState<S, I> {
8484
Self::PreInit { init, schema } => match msg {
8585
ClientMessage::ConnectionInit { payload } => match init.init(payload).await {
8686
Ok(config) => {
87-
let keep_alive_interval = config.keep_alive_interval;
87+
let keep_alive_interval = config.keep_alive.interval;
8888

8989
let mut s = stream::iter(vec![Reaction::ServerMessage(
9090
ServerMessage::ConnectionAck,

juniper_graphql_ws/src/lib.rs

Lines changed: 71 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,8 @@ pub struct ConnectionConfig<CtxT> {
3434
/// By default, there is no limit to in-flight operations.
3535
pub max_in_flight_operations: usize,
3636

37-
/// Interval at which to send keep-alives.
38-
///
39-
/// Specifying a [`Duration::ZERO`] will disable keep-alives.
40-
///
41-
/// By default, keep-alives are sent every 15 seconds.
42-
pub keep_alive_interval: Duration,
37+
/// Keep-alive configuration.
38+
pub keep_alive: KeepAliveConfig,
4339
}
4440

4541
impl<CtxT> ConnectionConfig<CtxT> {
@@ -48,7 +44,7 @@ impl<CtxT> ConnectionConfig<CtxT> {
4844
Self {
4945
context,
5046
max_in_flight_operations: 0,
51-
keep_alive_interval: Duration::from_secs(15),
47+
keep_alive: KeepAliveConfig::default(),
5248
}
5349
}
5450

@@ -66,10 +62,37 @@ impl<CtxT> ConnectionConfig<CtxT> {
6662
///
6763
/// Specifying a [`Duration::ZERO`] will disable keep-alives.
6864
///
65+
/// Also, sets a keep-alive timeout to the provided [`Duration`].
66+
///
6967
/// By default, keep-alives are sent every 15 seconds.
7068
#[must_use]
7169
pub fn with_keep_alive_interval(mut self, interval: Duration) -> Self {
72-
self.keep_alive_interval = interval;
70+
self.keep_alive.interval = interval;
71+
#[cfg(feature = "graphql-transport-ws")]
72+
{
73+
self.keep_alive.timeout = interval;
74+
}
75+
self
76+
}
77+
78+
#[cfg(feature = "graphql-transport-ws")]
79+
/// Specifies the timeout for waiting a keep-alive response from clients after sending them a
80+
/// keep-alive message.
81+
///
82+
/// Once the timeout is hit, the connection is closed by the server.
83+
///
84+
/// Specifying a [`Duration::ZERO`] disables timeout checking.
85+
///
86+
/// Applicable only for the [new `graphql-transport-ws` GraphQL over WebSocket Protocol][new],
87+
/// and does nothing for the [legacy `graphql-ws` GraphQL over WebSocket Protocol][old].
88+
///
89+
/// By default, timeout equals to the [`KeepAliveConfig::interval`].
90+
///
91+
/// [new]: https://github.com/enisdenjo/graphql-ws/blob/v5.14.0/PROTOCOL.md
92+
/// [old]: https://github.com/apollographql/subscriptions-transport-ws/blob/v0.11.0/PROTOCOL.md
93+
#[must_use]
94+
pub fn with_keep_alive_timeout(mut self, timeout: Duration) -> Self {
95+
self.keep_alive.timeout = timeout;
7396
self
7497
}
7598
}
@@ -83,6 +106,46 @@ impl<S: ScalarValue, CtxT: Unpin + Send + 'static> Init<S, CtxT> for ConnectionC
83106
}
84107
}
85108

109+
/// Config for keeping a connection alive.
110+
#[derive(Clone, Copy, Debug)]
111+
pub struct KeepAliveConfig {
112+
/// Interval at which to send keep-alives.
113+
///
114+
/// Specifying a [`Duration::ZERO`] disables keep-alives.
115+
///
116+
/// By default, keep-alives are sent every 15 seconds.
117+
pub interval: Duration,
118+
119+
#[cfg(feature = "graphql-transport-ws")]
120+
/// Timeout for waiting a keep-alive response from clients after sending them a keep-alive
121+
/// message.
122+
///
123+
/// Once the timeout is hit, the connection is closed by the server.
124+
///
125+
/// Specifying a [`Duration::ZERO`] disables timeout checking.
126+
///
127+
/// Applicable only for the [new `graphql-transport-ws` GraphQL over WebSocket Protocol][new],
128+
/// and does nothing for the [legacy `graphql-ws` GraphQL over WebSocket Protocol][old].
129+
///
130+
/// By default, timeout equals to the [`interval`].
131+
///
132+
/// [`interval`]: Self::interval
133+
/// [new]: https://github.com/enisdenjo/graphql-ws/blob/v5.14.0/PROTOCOL.md
134+
/// [old]: https://github.com/apollographql/subscriptions-transport-ws/blob/v0.11.0/PROTOCOL.md
135+
pub timeout: Duration,
136+
}
137+
138+
impl Default for KeepAliveConfig {
139+
fn default() -> Self {
140+
let interval = Duration::from_secs(15);
141+
Self {
142+
interval,
143+
#[cfg(feature = "graphql-transport-ws")]
144+
timeout: interval,
145+
}
146+
}
147+
}
148+
86149
/// Init defines the requirements for types that can provide connection configurations when
87150
/// ConnectionInit messages are received. Implementations are provided for `ConnectionConfig` and
88151
/// closures that meet the requirements.

0 commit comments

Comments
 (0)