@@ -535,3 +535,39 @@ async def test_concurrent_updates_race_condition(event_queue):
535535 assert len (successes ) == 1
536536 assert len (failures ) == 1
537537 assert event_queue .enqueue_event .call_count == 1
538+
539+
540+ @pytest .mark .asyncio
541+ async def test_reject_invalid_task_id (event_queue ):
542+ """Test rejecting a task with an invalid ID is handled gracefully."""
543+ pass
544+
545+
546+ @pytest .mark .asyncio
547+ async def test_reject_concurrently_with_complete (event_queue ):
548+ """Test for race conditions when reject and complete are called concurrently."""
549+ task_updater = TaskUpdater (
550+ event_queue = event_queue ,
551+ task_id = 'concurrent-task' ,
552+ context_id = 'concurrent-context' ,
553+ )
554+
555+ tasks = [
556+ task_updater .reject (),
557+ task_updater .complete (),
558+ ]
559+
560+ results = await asyncio .gather (* tasks , return_exceptions = True )
561+
562+ successes = [r for r in results if not isinstance (r , Exception )]
563+ failures = [r for r in results if isinstance (r , RuntimeError )]
564+
565+ assert len (successes ) == 1
566+ assert len (failures ) == 1
567+
568+ assert event_queue .enqueue_event .call_count == 1
569+
570+ event = event_queue .enqueue_event .call_args [0 ][0 ]
571+ assert isinstance (event , TaskStatusUpdateEvent )
572+ assert event .final is True
573+ assert event .status .state in [TaskState .rejected , TaskState .completed ]
0 commit comments