1- import inspect
21import logging
32
3+ from collections .abc import Callable
44from datetime import datetime , timezone
5- from typing import TYPE_CHECKING
65
76
87try :
2221 "'pip install a2a-sdk[sqlite]', "
2322 "or 'pip install a2a-sdk[sql]'"
2423 ) from e
25-
26- if TYPE_CHECKING :
27- from collections .abc import Callable
28-
2924from google .protobuf .json_format import MessageToDict , ParseDict
3025
31- from a2a .compat .v0_3 import conversions
26+ from a2a .compat .v0_3 .conversions import (
27+ compat_task_model_to_core ,
28+ )
3229from a2a .server .context import ServerCallContext
3330from a2a .server .models import Base , TaskModel , create_task_model
3431from a2a .server .owner_resolver import OwnerResolver , resolve_user_scope
@@ -55,16 +52,18 @@ class DatabaseTaskStore(TaskStore):
5552 _initialized : bool
5653 task_model : type [TaskModel ]
5754 owner_resolver : OwnerResolver
55+ core_to_model_conversion : Callable [[Task , str ], TaskModel ] | None = None
56+ model_to_core_conversion : Callable [[TaskModel ], Task ] | None = None
5857
59- core_to_model_conversion : 'Callable[[Task, str], TaskModel] | None' = None
60- model_to_core_conversion : 'Callable[[TaskModel], Task] | None' = None
61-
62- def __init__ (
58+ def __init__ ( # noqa: PLR0913
6359 self ,
6460 engine : AsyncEngine ,
6561 create_table : bool = True ,
6662 table_name : str = 'tasks' ,
6763 owner_resolver : OwnerResolver = resolve_user_scope ,
64+ core_to_model_conversion : Callable [[Task , str ], TaskModel ]
65+ | None = None ,
66+ model_to_core_conversion : Callable [[TaskModel ], Task ] | None = None ,
6867 ) -> None :
6968 """Initializes the DatabaseTaskStore.
7069
@@ -73,6 +72,8 @@ def __init__(
7372 create_table: If true, create tasks table on initialization.
7473 table_name: Name of the database table. Defaults to 'tasks'.
7574 owner_resolver: Function to resolve the owner from the context.
75+ core_to_model_conversion: Optional function to convert a Task to a TaskModel.
76+ model_to_core_conversion: Optional function to convert a TaskModel to a Task.
7677 """
7778 logger .debug (
7879 'Initializing DatabaseTaskStore with existing engine, table: %s' ,
@@ -85,6 +86,8 @@ def __init__(
8586 self .create_table = create_table
8687 self ._initialized = False
8788 self .owner_resolver = owner_resolver
89+ self .core_to_model_conversion = core_to_model_conversion
90+ self .model_to_core_conversion = model_to_core_conversion
8891
8992 self .task_model = (
9093 TaskModel
@@ -117,12 +120,8 @@ async def _ensure_initialized(self) -> None:
117120
118121 def _to_orm (self , task : Task , owner : str ) -> TaskModel :
119122 """Maps a Proto Task to a SQLAlchemy TaskModel instance."""
120- if conversion := self .core_to_model_conversion :
121- # If it's a bound method of this instance, call the underlying function
122- # to avoid passing 'self' twice.
123- if inspect .ismethod (conversion ):
124- return conversion .__func__ (task , owner )
125- return conversion (task , owner )
123+ if self .core_to_model_conversion :
124+ return self .core_to_model_conversion (task , owner )
126125
127126 return self .task_model (
128127 id = task .id ,
@@ -145,12 +144,8 @@ def _to_orm(self, task: Task, owner: str) -> TaskModel:
145144
146145 def _from_orm (self , task_model : TaskModel ) -> Task :
147146 """Maps a SQLAlchemy TaskModel to a Proto Task instance."""
148- if conversion := self .model_to_core_conversion :
149- # If it's a bound method of this instance, call the underlying function
150- # to avoid passing 'self' twice.
151- if inspect .ismethod (conversion ):
152- return conversion .__func__ (task_model )
153- return conversion (task_model )
147+ if self .model_to_core_conversion :
148+ return self .model_to_core_conversion (task_model )
154149
155150 if task_model .protocol_version == '1.0' :
156151 task = Task (
@@ -172,7 +167,7 @@ def _from_orm(self, task_model: TaskModel) -> Task:
172167 return task
173168
174169 # Legacy conversion
175- return conversions . compat_task_model_to_core (task_model )
170+ return compat_task_model_to_core (task_model )
176171
177172 async def save (
178173 self , task : Task , context : ServerCallContext | None = None
0 commit comments