Skip to content

Commit bd1ba90

Browse files
committed
Fixes after #978
1 parent 819f3d8 commit bd1ba90

5 files changed

Lines changed: 29 additions & 28 deletions

File tree

tests/integration/cross_version/client_server/server_0_3.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
from starlette.requests import Request
3939
from starlette.concurrency import iterate_in_threadpool
4040
import time
41-
from a2a.utils.task import new_task
41+
from a2a.helpers.proto_helpers import new_task_from_user_message
4242
from server_common import CustomLoggingMiddleware
4343

4444

@@ -49,7 +49,7 @@ def __init__(self):
4949
async def execute(self, context: RequestContext, event_queue: EventQueue):
5050
print(f'SERVER: execute called for task {context.task_id}')
5151

52-
task = new_task(context.message)
52+
task = new_task_from_user_message(context.message)
5353
task.id = context.task_id
5454
task.context_id = context.context_id
5555
task.status.state = TaskState.working

tests/integration/cross_version/client_server/server_1_0.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from a2a.utils import TransportProtocol
2929
from server_common import CustomLoggingMiddleware
3030
from google.protobuf.struct_pb2 import Struct, Value
31-
from a2a.utils.task import new_task
31+
from a2a.helpers.proto_helpers import new_task_from_user_message
3232

3333

3434
class MockAgentExecutor(AgentExecutor):
@@ -37,7 +37,7 @@ def __init__(self):
3737

3838
async def execute(self, context: RequestContext, event_queue: EventQueue):
3939
print(f'SERVER: execute called for task {context.task_id}')
40-
task = new_task(context.message)
40+
task = new_task_from_user_message(context.message)
4141
task.id = context.task_id
4242
task.context_id = context.context_id
4343
task.status.state = TaskState.TASK_STATE_WORKING

tests/integration/test_copying_observability.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@
2525
SendMessageRequest,
2626
TaskState,
2727
)
28+
from a2a.helpers.proto_helpers import new_task_from_user_message
2829
from a2a.utils import TransportProtocol
29-
from a2a.utils.task import new_task
3030

3131

3232
class MockMutatingAgentExecutor(AgentExecutor):
@@ -43,7 +43,7 @@ async def execute(self, context: RequestContext, event_queue: EventQueue):
4343

4444
if user_input == 'Init task':
4545
# Explicitly save status change to ensure task exists with some state
46-
task = new_task(context.message)
46+
task = new_task_from_user_message(context.message)
4747
task.id = context.task_id
4848
task.context_id = context.context_id
4949
task.status.state = TaskState.TASK_STATE_WORKING

tests/integration/test_scenarios.py

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,8 @@
4747
TaskStatus,
4848
TaskStatusUpdateEvent,
4949
)
50+
from a2a.helpers.proto_helpers import new_task_from_user_message
5051
from a2a.utils import TransportProtocol
51-
from a2a.utils.task import new_task
5252
from a2a.utils.errors import (
5353
InvalidParamsError,
5454
TaskNotCancelableError,
@@ -249,7 +249,7 @@ class DummyAgentExecutor(AgentExecutor):
249249
async def execute(
250250
self, context: RequestContext, event_queue: EventQueue
251251
):
252-
task = new_task(context.message)
252+
task = new_task_from_user_message(context.message)
253253
task.status.state = TaskState.TASK_STATE_WORKING
254254
await event_queue.enqueue_event(task)
255255
await event_queue.enqueue_event(
@@ -294,7 +294,7 @@ class DummyAgentExecutor(AgentExecutor):
294294
async def execute(
295295
self, context: RequestContext, event_queue: EventQueue
296296
):
297-
task = new_task(context.message)
297+
task = new_task_from_user_message(context.message)
298298
task.status.state = TaskState.TASK_STATE_WORKING
299299
await event_queue.enqueue_event(task)
300300
await event_queue.enqueue_event(
@@ -349,7 +349,7 @@ class DummyAgentExecutor(AgentExecutor):
349349
async def execute(
350350
self, context: RequestContext, event_queue: EventQueue
351351
):
352-
task = new_task(context.message)
352+
task = new_task_from_user_message(context.message)
353353
task.status.state = TaskState.TASK_STATE_COMPLETED
354354
await event_queue.enqueue_event(task)
355355

@@ -478,7 +478,7 @@ class ErrorAfterAgent(AgentExecutor):
478478
async def execute(
479479
self, context: RequestContext, event_queue: EventQueue
480480
):
481-
task = new_task(context.message)
481+
task = new_task_from_user_message(context.message)
482482
task.status.state = TaskState.TASK_STATE_WORKING
483483
await event_queue.enqueue_event(task)
484484
started_event.set()
@@ -543,7 +543,7 @@ class ErrorCancelAgent(AgentExecutor):
543543
async def execute(
544544
self, context: RequestContext, event_queue: EventQueue
545545
):
546-
task = new_task(context.message)
546+
task = new_task_from_user_message(context.message)
547547
task.status.state = TaskState.TASK_STATE_WORKING
548548
await event_queue.enqueue_event(task)
549549
started_event.set()
@@ -599,7 +599,7 @@ class ErrorAfterAgent(AgentExecutor):
599599
async def execute(
600600
self, context: RequestContext, event_queue: EventQueue
601601
):
602-
task = new_task(context.message)
602+
task = new_task_from_user_message(context.message)
603603
task.status.state = TaskState.TASK_STATE_WORKING
604604
await event_queue.enqueue_event(task)
605605
started_event.set()
@@ -725,7 +725,7 @@ class DummyCancelAgent(AgentExecutor):
725725
async def execute(
726726
self, context: RequestContext, event_queue: EventQueue
727727
):
728-
task = new_task(context.message)
728+
task = new_task_from_user_message(context.message)
729729
task.status.state = TaskState.TASK_STATE_WORKING
730730
await event_queue.enqueue_event(task)
731731
started_event.set()
@@ -789,7 +789,7 @@ class ComplexAgent(AgentExecutor):
789789
async def execute(
790790
self, context: RequestContext, event_queue: EventQueue
791791
):
792-
task = new_task(context.message)
792+
task = new_task_from_user_message(context.message)
793793
task.status.state = TaskState.TASK_STATE_WORKING
794794
await event_queue.enqueue_event(task)
795795
started_event.set()
@@ -904,7 +904,7 @@ async def execute(
904904
)
905905
return
906906

907-
task = new_task(context.message)
907+
task = new_task_from_user_message(context.message)
908908
task.status.state = TaskState.TASK_STATE_WORKING
909909
await event_queue.enqueue_event(task)
910910
started_event.set()
@@ -1028,7 +1028,7 @@ class ImmediateAgent(AgentExecutor):
10281028
async def execute(
10291029
self, context: RequestContext, event_queue: EventQueue
10301030
):
1031-
task = new_task(context.message)
1031+
task = new_task_from_user_message(context.message)
10321032
task.status.state = TaskState.TASK_STATE_WORKING
10331033
await event_queue.enqueue_event(task)
10341034
await event_queue.enqueue_event(
@@ -1085,15 +1085,15 @@ async def execute(
10851085
):
10861086
message = context.message
10871087
if message and message.parts and message.parts[0].text == 'start':
1088-
task = new_task(message)
1088+
task = new_task_from_user_message(message)
10891089
task.status.state = TaskState.TASK_STATE_INPUT_REQUIRED
10901090
await event_queue.enqueue_event(task)
10911091
elif (
10921092
message
10931093
and message.parts
10941094
and message.parts[0].text == 'here is input'
10951095
):
1096-
task = new_task(message)
1096+
task = new_task_from_user_message(message)
10971097
task.status.state = TaskState.TASK_STATE_COMPLETED
10981098
await event_queue.enqueue_event(task)
10991099
else:
@@ -1164,7 +1164,7 @@ class AuthAgent(AgentExecutor):
11641164
async def execute(
11651165
self, context: RequestContext, event_queue: EventQueue
11661166
):
1167-
task = new_task(context.message)
1167+
task = new_task_from_user_message(context.message)
11681168
task.status.state = TaskState.TASK_STATE_WORKING
11691169
await event_queue.enqueue_event(task)
11701170
await event_queue.enqueue_event(
@@ -1246,7 +1246,7 @@ async def execute(
12461246
):
12471247
message = context.message
12481248
if message and message.parts and message.parts[0].text == 'start':
1249-
task = new_task(message)
1249+
task = new_task_from_user_message(message)
12501250
task.status.state = TaskState.TASK_STATE_AUTH_REQUIRED
12511251
await event_queue.enqueue_event(task)
12521252
elif (
@@ -1326,7 +1326,7 @@ class EmitAgent(AgentExecutor):
13261326
async def execute(
13271327
self, context: RequestContext, event_queue: EventQueue
13281328
):
1329-
task = new_task(context.message)
1329+
task = new_task_from_user_message(context.message)
13301330
task.status.state = TaskState.TASK_STATE_WORKING
13311331
await event_queue.enqueue_event(task)
13321332

@@ -1544,7 +1544,7 @@ class ArtifactAgent(AgentExecutor):
15441544
async def execute(
15451545
self, context: RequestContext, event_queue: EventQueue
15461546
):
1547-
task = new_task(context.message)
1547+
task = new_task_from_user_message(context.message)
15481548
task.status.state = TaskState.TASK_STATE_WORKING
15491549
await event_queue.enqueue_event(task)
15501550
await event_queue.enqueue_event(
@@ -1712,7 +1712,7 @@ class TerminalAgent(AgentExecutor):
17121712
async def execute(
17131713
self, context: RequestContext, event_queue: EventQueue
17141714
):
1715-
task = new_task(context.message)
1715+
task = new_task_from_user_message(context.message)
17161716
task.status.state = TaskState.TASK_STATE_COMPLETED
17171717
await event_queue.enqueue_event(task)
17181718

@@ -1809,7 +1809,7 @@ async def execute(
18091809
):
18101810
message = context.message
18111811
if message and message.parts and message.parts[0].text == 'start':
1812-
task = new_task(message)
1812+
task = new_task_from_user_message(message)
18131813
task.status.state = TaskState.TASK_STATE_INPUT_REQUIRED
18141814
await event_queue.enqueue_event(task)
18151815
elif message and message.parts and message.parts[0].text == 'input':
@@ -1944,7 +1944,7 @@ async def execute(
19441944
):
19451945
if initial_task_type == 'new_task':
19461946
# Create with new_task
1947-
task = new_task(context.message)
1947+
task = new_task_from_user_message(context.message)
19481948
task.status.state = TaskState.TASK_STATE_WORKING
19491949
await event_queue.enqueue_event(task)
19501950
else:
@@ -2092,7 +2092,9 @@ class TaskMessageAgent(AgentExecutor):
20922092
async def execute(
20932093
self, context: RequestContext, event_queue: EventQueue
20942094
):
2095-
await event_queue.enqueue_event(new_task(context.message))
2095+
await event_queue.enqueue_event(
2096+
new_task_from_user_message(context.message)
2097+
)
20962098
await event_queue.enqueue_event(
20972099
Message(message_id='m1', parts=[Part(text='m1')])
20982100
)

tests/server/agent_execution/test_active_task.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
Part,
2424
)
2525
from a2a.utils.errors import InvalidParamsError
26-
from a2a.utils.task import new_task
2726

2827

2928
logger = logging.getLogger(__name__)

0 commit comments

Comments
 (0)