|
47 | 47 | TaskStatus, |
48 | 48 | TaskStatusUpdateEvent, |
49 | 49 | ) |
| 50 | +from a2a.helpers.proto_helpers import new_task_from_user_message |
50 | 51 | from a2a.utils import TransportProtocol |
51 | | -from a2a.utils.task import new_task |
52 | 52 | from a2a.utils.errors import ( |
53 | 53 | InvalidParamsError, |
54 | 54 | TaskNotCancelableError, |
@@ -249,7 +249,7 @@ class DummyAgentExecutor(AgentExecutor): |
249 | 249 | async def execute( |
250 | 250 | self, context: RequestContext, event_queue: EventQueue |
251 | 251 | ): |
252 | | - task = new_task(context.message) |
| 252 | + task = new_task_from_user_message(context.message) |
253 | 253 | task.status.state = TaskState.TASK_STATE_WORKING |
254 | 254 | await event_queue.enqueue_event(task) |
255 | 255 | await event_queue.enqueue_event( |
@@ -294,7 +294,7 @@ class DummyAgentExecutor(AgentExecutor): |
294 | 294 | async def execute( |
295 | 295 | self, context: RequestContext, event_queue: EventQueue |
296 | 296 | ): |
297 | | - task = new_task(context.message) |
| 297 | + task = new_task_from_user_message(context.message) |
298 | 298 | task.status.state = TaskState.TASK_STATE_WORKING |
299 | 299 | await event_queue.enqueue_event(task) |
300 | 300 | await event_queue.enqueue_event( |
@@ -349,7 +349,7 @@ class DummyAgentExecutor(AgentExecutor): |
349 | 349 | async def execute( |
350 | 350 | self, context: RequestContext, event_queue: EventQueue |
351 | 351 | ): |
352 | | - task = new_task(context.message) |
| 352 | + task = new_task_from_user_message(context.message) |
353 | 353 | task.status.state = TaskState.TASK_STATE_COMPLETED |
354 | 354 | await event_queue.enqueue_event(task) |
355 | 355 |
|
@@ -478,7 +478,7 @@ class ErrorAfterAgent(AgentExecutor): |
478 | 478 | async def execute( |
479 | 479 | self, context: RequestContext, event_queue: EventQueue |
480 | 480 | ): |
481 | | - task = new_task(context.message) |
| 481 | + task = new_task_from_user_message(context.message) |
482 | 482 | task.status.state = TaskState.TASK_STATE_WORKING |
483 | 483 | await event_queue.enqueue_event(task) |
484 | 484 | started_event.set() |
@@ -543,7 +543,7 @@ class ErrorCancelAgent(AgentExecutor): |
543 | 543 | async def execute( |
544 | 544 | self, context: RequestContext, event_queue: EventQueue |
545 | 545 | ): |
546 | | - task = new_task(context.message) |
| 546 | + task = new_task_from_user_message(context.message) |
547 | 547 | task.status.state = TaskState.TASK_STATE_WORKING |
548 | 548 | await event_queue.enqueue_event(task) |
549 | 549 | started_event.set() |
@@ -599,7 +599,7 @@ class ErrorAfterAgent(AgentExecutor): |
599 | 599 | async def execute( |
600 | 600 | self, context: RequestContext, event_queue: EventQueue |
601 | 601 | ): |
602 | | - task = new_task(context.message) |
| 602 | + task = new_task_from_user_message(context.message) |
603 | 603 | task.status.state = TaskState.TASK_STATE_WORKING |
604 | 604 | await event_queue.enqueue_event(task) |
605 | 605 | started_event.set() |
@@ -725,7 +725,7 @@ class DummyCancelAgent(AgentExecutor): |
725 | 725 | async def execute( |
726 | 726 | self, context: RequestContext, event_queue: EventQueue |
727 | 727 | ): |
728 | | - task = new_task(context.message) |
| 728 | + task = new_task_from_user_message(context.message) |
729 | 729 | task.status.state = TaskState.TASK_STATE_WORKING |
730 | 730 | await event_queue.enqueue_event(task) |
731 | 731 | started_event.set() |
@@ -789,7 +789,7 @@ class ComplexAgent(AgentExecutor): |
789 | 789 | async def execute( |
790 | 790 | self, context: RequestContext, event_queue: EventQueue |
791 | 791 | ): |
792 | | - task = new_task(context.message) |
| 792 | + task = new_task_from_user_message(context.message) |
793 | 793 | task.status.state = TaskState.TASK_STATE_WORKING |
794 | 794 | await event_queue.enqueue_event(task) |
795 | 795 | started_event.set() |
@@ -904,7 +904,7 @@ async def execute( |
904 | 904 | ) |
905 | 905 | return |
906 | 906 |
|
907 | | - task = new_task(context.message) |
| 907 | + task = new_task_from_user_message(context.message) |
908 | 908 | task.status.state = TaskState.TASK_STATE_WORKING |
909 | 909 | await event_queue.enqueue_event(task) |
910 | 910 | started_event.set() |
@@ -1028,7 +1028,7 @@ class ImmediateAgent(AgentExecutor): |
1028 | 1028 | async def execute( |
1029 | 1029 | self, context: RequestContext, event_queue: EventQueue |
1030 | 1030 | ): |
1031 | | - task = new_task(context.message) |
| 1031 | + task = new_task_from_user_message(context.message) |
1032 | 1032 | task.status.state = TaskState.TASK_STATE_WORKING |
1033 | 1033 | await event_queue.enqueue_event(task) |
1034 | 1034 | await event_queue.enqueue_event( |
@@ -1085,15 +1085,15 @@ async def execute( |
1085 | 1085 | ): |
1086 | 1086 | message = context.message |
1087 | 1087 | if message and message.parts and message.parts[0].text == 'start': |
1088 | | - task = new_task(message) |
| 1088 | + task = new_task_from_user_message(message) |
1089 | 1089 | task.status.state = TaskState.TASK_STATE_INPUT_REQUIRED |
1090 | 1090 | await event_queue.enqueue_event(task) |
1091 | 1091 | elif ( |
1092 | 1092 | message |
1093 | 1093 | and message.parts |
1094 | 1094 | and message.parts[0].text == 'here is input' |
1095 | 1095 | ): |
1096 | | - task = new_task(message) |
| 1096 | + task = new_task_from_user_message(message) |
1097 | 1097 | task.status.state = TaskState.TASK_STATE_COMPLETED |
1098 | 1098 | await event_queue.enqueue_event(task) |
1099 | 1099 | else: |
@@ -1164,7 +1164,7 @@ class AuthAgent(AgentExecutor): |
1164 | 1164 | async def execute( |
1165 | 1165 | self, context: RequestContext, event_queue: EventQueue |
1166 | 1166 | ): |
1167 | | - task = new_task(context.message) |
| 1167 | + task = new_task_from_user_message(context.message) |
1168 | 1168 | task.status.state = TaskState.TASK_STATE_WORKING |
1169 | 1169 | await event_queue.enqueue_event(task) |
1170 | 1170 | await event_queue.enqueue_event( |
@@ -1246,7 +1246,7 @@ async def execute( |
1246 | 1246 | ): |
1247 | 1247 | message = context.message |
1248 | 1248 | if message and message.parts and message.parts[0].text == 'start': |
1249 | | - task = new_task(message) |
| 1249 | + task = new_task_from_user_message(message) |
1250 | 1250 | task.status.state = TaskState.TASK_STATE_AUTH_REQUIRED |
1251 | 1251 | await event_queue.enqueue_event(task) |
1252 | 1252 | elif ( |
@@ -1326,7 +1326,7 @@ class EmitAgent(AgentExecutor): |
1326 | 1326 | async def execute( |
1327 | 1327 | self, context: RequestContext, event_queue: EventQueue |
1328 | 1328 | ): |
1329 | | - task = new_task(context.message) |
| 1329 | + task = new_task_from_user_message(context.message) |
1330 | 1330 | task.status.state = TaskState.TASK_STATE_WORKING |
1331 | 1331 | await event_queue.enqueue_event(task) |
1332 | 1332 |
|
@@ -1544,7 +1544,7 @@ class ArtifactAgent(AgentExecutor): |
1544 | 1544 | async def execute( |
1545 | 1545 | self, context: RequestContext, event_queue: EventQueue |
1546 | 1546 | ): |
1547 | | - task = new_task(context.message) |
| 1547 | + task = new_task_from_user_message(context.message) |
1548 | 1548 | task.status.state = TaskState.TASK_STATE_WORKING |
1549 | 1549 | await event_queue.enqueue_event(task) |
1550 | 1550 | await event_queue.enqueue_event( |
@@ -1712,7 +1712,7 @@ class TerminalAgent(AgentExecutor): |
1712 | 1712 | async def execute( |
1713 | 1713 | self, context: RequestContext, event_queue: EventQueue |
1714 | 1714 | ): |
1715 | | - task = new_task(context.message) |
| 1715 | + task = new_task_from_user_message(context.message) |
1716 | 1716 | task.status.state = TaskState.TASK_STATE_COMPLETED |
1717 | 1717 | await event_queue.enqueue_event(task) |
1718 | 1718 |
|
@@ -1809,7 +1809,7 @@ async def execute( |
1809 | 1809 | ): |
1810 | 1810 | message = context.message |
1811 | 1811 | if message and message.parts and message.parts[0].text == 'start': |
1812 | | - task = new_task(message) |
| 1812 | + task = new_task_from_user_message(message) |
1813 | 1813 | task.status.state = TaskState.TASK_STATE_INPUT_REQUIRED |
1814 | 1814 | await event_queue.enqueue_event(task) |
1815 | 1815 | elif message and message.parts and message.parts[0].text == 'input': |
@@ -1944,7 +1944,7 @@ async def execute( |
1944 | 1944 | ): |
1945 | 1945 | if initial_task_type == 'new_task': |
1946 | 1946 | # Create with new_task |
1947 | | - task = new_task(context.message) |
| 1947 | + task = new_task_from_user_message(context.message) |
1948 | 1948 | task.status.state = TaskState.TASK_STATE_WORKING |
1949 | 1949 | await event_queue.enqueue_event(task) |
1950 | 1950 | else: |
@@ -2092,7 +2092,9 @@ class TaskMessageAgent(AgentExecutor): |
2092 | 2092 | async def execute( |
2093 | 2093 | self, context: RequestContext, event_queue: EventQueue |
2094 | 2094 | ): |
2095 | | - await event_queue.enqueue_event(new_task(context.message)) |
| 2095 | + await event_queue.enqueue_event( |
| 2096 | + new_task_from_user_message(context.message) |
| 2097 | + ) |
2096 | 2098 | await event_queue.enqueue_event( |
2097 | 2099 | Message(message_id='m1', parts=[Part(text='m1')]) |
2098 | 2100 | ) |
|
0 commit comments