Skip to content

Commit 87fbb74

Browse files
committed
fix test
1 parent 7f626ea commit 87fbb74

1 file changed

Lines changed: 36 additions & 32 deletions

File tree

tests/contrib/tasks/test_vertex_task_store.py

Lines changed: 36 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -543,56 +543,60 @@ async def test_update_task_status_details(
543543
original_task = Task(
544544
id=task_id,
545545
context_id='session-update',
546-
status=TaskStatus(state=TaskState.submitted),
547-
kind='task',
546+
status=TaskStatus(state=TaskState.TASK_STATE_SUBMITTED),
548547
metadata=None,
549548
artifacts=[],
550549
history=[],
551550
)
552-
await vertex_store.save(original_task)
551+
await vertex_store.save(original_task, ServerCallContext())
553552

554-
retrieved_before_update = await vertex_store.get(task_id)
553+
retrieved_before_update = await vertex_store.get(
554+
task_id, ServerCallContext()
555+
)
555556
assert retrieved_before_update is not None
556-
assert retrieved_before_update.status.message is None
557-
558-
updated_task = original_task.model_copy(deep=True)
559-
updated_task.status.state = TaskState.failed
560-
updated_task.status.timestamp = '2023-01-02T11:00:00Z'
561-
updated_task.status.message = Message(
562-
message_id='msg-error-1',
563-
role=Role.agent,
564-
parts=[
565-
Part(
566-
root=TextPart(
557+
assert (
558+
retrieved_before_update.status.state == TaskState.TASK_STATE_SUBMITTED
559+
)
560+
561+
updated_task = Task()
562+
updated_task.CopyFrom(original_task)
563+
updated_task.status.state = TaskState.TASK_STATE_FAILED
564+
updated_task.status.timestamp.FromJsonString('2023-01-02T11:00:00Z')
565+
updated_task.status.message.CopyFrom(
566+
Message(
567+
message_id='msg-error-1',
568+
role=Role.ROLE_AGENT,
569+
parts=[
570+
Part(
567571
text='Task failed due to an unknown error',
568572
metadata={'error_code': 'UNKNOWN', 'retryable': False},
569573
)
570-
)
571-
],
574+
],
575+
)
572576
)
573577

574-
await vertex_store.save(updated_task)
578+
await vertex_store.save(updated_task, ServerCallContext())
575579

576-
retrieved_after_update = await vertex_store.get(task_id)
580+
retrieved_after_update = await vertex_store.get(
581+
task_id, ServerCallContext()
582+
)
577583
assert retrieved_after_update is not None
578-
assert retrieved_after_update.status.state == TaskState.failed
584+
assert retrieved_after_update.status.state == TaskState.TASK_STATE_FAILED
579585
assert retrieved_after_update.status.message is not None
580586
assert retrieved_after_update.status.message.message_id == 'msg-error-1'
581-
assert retrieved_after_update.status.message.role == Role.agent
587+
assert retrieved_after_update.status.message.role == Role.ROLE_AGENT
582588
assert len(retrieved_after_update.status.message.parts) == 1
583589

584-
assert isinstance(
585-
retrieved_after_update.status.message.parts[0].root, TextPart
586-
)
587-
text_part = retrieved_after_update.status.message.parts[0].root
588-
assert text_part.text == 'Task failed due to an unknown error'
589-
assert text_part.metadata == {'error_code': 'UNKNOWN', 'retryable': False}
590+
part = retrieved_after_update.status.message.parts[0]
591+
assert part.text == 'Task failed due to an unknown error'
592+
assert part.metadata == {'error_code': 'UNKNOWN', 'retryable': False}
590593

591594
# Also test clearing the message
592-
cleared_task = updated_task.model_copy(deep=True)
593-
cleared_task.status.message = None
595+
cleared_task = Task()
596+
cleared_task.CopyFrom(updated_task)
597+
cleared_task.status.ClearField('message')
594598

595-
await vertex_store.save(cleared_task)
596-
retrieved_cleared = await vertex_store.get(task_id)
599+
await vertex_store.save(cleared_task, ServerCallContext())
600+
retrieved_cleared = await vertex_store.get(task_id, ServerCallContext())
597601
assert retrieved_cleared is not None
598-
assert retrieved_cleared.status.message is None
602+
assert not retrieved_cleared.status.HasField('message')

0 commit comments

Comments
 (0)