@@ -26,10 +26,13 @@ enum OomState {
2626
2727 /// We are inside an OOM test and should inject an OOM when the counter
2828 /// reaches zero.
29- OomOnAlloc ( u32 ) ,
29+ OomOnAlloc {
30+ counter : u32 ,
31+ allow_alloc_after : bool ,
32+ } ,
3033
3134 /// We are inside an OOM test and we already injected an OOM.
32- DidOom ,
35+ DidOom { allow_alloc : bool } ,
3336}
3437
3538thread_local ! {
@@ -107,22 +110,43 @@ unsafe impl GlobalAlloc for OomTestAllocator {
107110 match old_state {
108111 OomState :: OutsideOomTest => unreachable ! ( "handled above" ) ,
109112
110- OomState :: OomOnAlloc ( 0 ) => {
113+ OomState :: OomOnAlloc {
114+ counter : 0 ,
115+ allow_alloc_after,
116+ } => {
111117 log:: trace!(
112118 "injecting OOM for allocation: {layout:?}\n Allocation backtrace:\n {bt}"
113119 ) ;
114- new_state = OomState :: DidOom ;
120+ new_state = OomState :: DidOom {
121+ allow_alloc : allow_alloc_after,
122+ } ;
115123 ptr = ptr:: null_mut ( ) ;
116124 }
117125
118- OomState :: OomOnAlloc ( c) => {
119- new_state = OomState :: OomOnAlloc ( c - 1 ) ;
126+ OomState :: OomOnAlloc {
127+ counter : c,
128+ allow_alloc_after,
129+ } => {
130+ new_state = OomState :: OomOnAlloc {
131+ counter : c - 1 ,
132+ allow_alloc_after,
133+ } ;
120134 ptr = unsafe { std:: alloc:: System . alloc ( layout) } ;
121135 }
122136
123- OomState :: DidOom => {
137+ OomState :: DidOom { allow_alloc } => {
124138 log:: trace!( "Attempt to allocate {layout:?} after OOM:\n {bt}" ) ;
125- panic ! ( "OOM test attempted to allocate after OOM: {layout:?}" )
139+ if allow_alloc {
140+ new_state = OomState :: DidOom { allow_alloc : true } ;
141+ ptr = ptr:: null_mut ( ) ;
142+ } else {
143+ panic ! (
144+ "OOM test attempted to allocate after OOM: {layout:?}\n \
145+ \n \
146+ Hint: if this is acceptable, configure the OOM test to allow allocation \
147+ after OOM with `OomTest::allow_alloc_after_oom`"
148+ )
149+ }
126150 }
127151 }
128152 }
@@ -161,6 +185,7 @@ unsafe impl GlobalAlloc for OomTestAllocator {
161185/// OomTest::new()
162186/// .max_iters(1_000_000)
163187/// .max_duration(Duration::from_secs(5))
188+ /// .allow_alloc_after_oom(true)
164189/// .test(|| {
165190/// todo!("insert code here that should handle OOM here...")
166191/// })
@@ -169,6 +194,7 @@ unsafe impl GlobalAlloc for OomTestAllocator {
169194pub struct OomTest {
170195 max_iters : Option < u32 > ,
171196 max_duration : Option < time:: Duration > ,
197+ allow_alloc_after_oom : bool ,
172198}
173199
174200impl OomTest {
@@ -187,6 +213,7 @@ impl OomTest {
187213 OomTest {
188214 max_iters : None ,
189215 max_duration : None ,
216+ allow_alloc_after_oom : false ,
190217 }
191218 }
192219
@@ -202,6 +229,15 @@ impl OomTest {
202229 self
203230 }
204231
232+ /// Configure whether to allow allocation attempts after an OOM has already
233+ /// been injected.
234+ ///
235+ /// The default is `false`.
236+ pub fn allow_alloc_after_oom ( & mut self , allow : bool ) -> & mut Self {
237+ self . allow_alloc_after_oom = allow;
238+ self
239+ }
240+
205241 /// Repeatedly run the given test function, injecting OOMs at different
206242 /// times and checking that it correctly handles them.
207243 ///
@@ -225,7 +261,10 @@ impl OomTest {
225261
226262 log:: trace!( "=== Injecting OOM after {i} allocations ===" ) ;
227263 let ( result, old_state) = {
228- let guard = ScopedOomState :: new ( OomState :: OomOnAlloc ( i) ) ;
264+ let guard = ScopedOomState :: new ( OomState :: OomOnAlloc {
265+ counter : i,
266+ allow_alloc_after : self . allow_alloc_after_oom ,
267+ } ) ;
229268 assert_eq ! ( guard. prev_state, OomState :: OutsideOomTest ) ;
230269
231270 let result = test_func ( ) ;
@@ -238,24 +277,24 @@ impl OomTest {
238277
239278 // The test function completed successfully before we ran out of
240279 // allocation fuel, so we're done.
241- ( Ok ( ( ) ) , OomState :: OomOnAlloc ( _ ) ) => break ,
280+ ( Ok ( ( ) ) , OomState :: OomOnAlloc { .. } ) => break ,
242281
243282 // We injected an OOM and the test function handled it
244283 // correctly; continue to the next iteration.
245- ( Err ( e) , OomState :: DidOom ) if self . is_oom_error ( & e) => { }
284+ ( Err ( e) , OomState :: DidOom { .. } ) if self . is_oom_error ( & e) => { }
246285
247286 // Missed OOMs.
248- ( Ok ( ( ) ) , OomState :: DidOom ) => {
287+ ( Ok ( ( ) ) , OomState :: DidOom { .. } ) => {
249288 bail ! ( "OOM test function missed an OOM: returned Ok(())" ) ;
250289 }
251- ( Err ( e) , OomState :: DidOom ) => {
290+ ( Err ( e) , OomState :: DidOom { .. } ) => {
252291 return Err (
253292 e. context ( "OOM test function missed an OOM: returned non-OOM error" )
254293 ) ;
255294 }
256295
257296 // Unexpected error.
258- ( Err ( e) , OomState :: OomOnAlloc ( _ ) ) => {
297+ ( Err ( e) , OomState :: OomOnAlloc { .. } ) => {
259298 return Err (
260299 e. context ( "OOM test function returned an error when there was no OOM" )
261300 ) ;
0 commit comments