55from fastapi .concurrency import AsyncExitStack
66from fastapi .dependencies .utils import solve_dependencies
77from sqlalchemy import event
8- from sqlalchemy .engine import Connection , Engine
9- from sqlalchemy .engine .default import DefaultExecutionContext
8+ from sqlalchemy .engine import Connection , Engine , ExecutionContext
109from sqlalchemy .orm import Session
1110
1211from debug_toolbar .panels .sql import SQLPanel
@@ -20,40 +19,26 @@ def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
2019 self .engines : t .Set [Engine ] = set ()
2120
2221 def register (self , engine : Engine ) -> None :
23- event .listen (engine , "before_cursor_execute" , self .before_execute )
24- event .listen (engine , "after_cursor_execute" , self .after_execute )
22+ event .listen (engine , "before_cursor_execute" , self .before_execute , named = True )
23+ event .listen (engine , "after_cursor_execute" , self .after_execute , named = True )
2524
2625 def unregister (self , engine : Engine ) -> None :
2726 event .remove (engine , "before_cursor_execute" , self .before_execute )
2827 event .remove (engine , "after_cursor_execute" , self .after_execute )
2928
30- def before_execute (
31- self ,
32- conn : Connection ,
33- cursor : t .Any ,
34- statement : str ,
35- parameters : t .Union [t .Sequence , t .Dict ],
36- context : DefaultExecutionContext ,
37- executemany : bool ,
38- ) -> None :
39- conn .info .setdefault ("start_time" , []).append (perf_counter ())
29+ def before_execute (self , context : ExecutionContext , ** kwargs : t .Any ) -> None :
30+ context ._start_time = perf_counter () # type: ignore[attr-defined]
4031
41- def after_execute (
42- self ,
43- conn : Connection ,
44- cursor : t .Any ,
45- statement : str ,
46- parameters : t .Union [t .Sequence , t .Dict ],
47- context : DefaultExecutionContext ,
48- executemany : bool ,
49- ) -> None :
32+ def after_execute (self , context : ExecutionContext , ** kwargs : t .Any ) -> None :
5033 query = {
51- "duration" : (perf_counter () - conn .info ["start_time" ].pop (- 1 )) * 1000 ,
52- "sql" : statement ,
53- "params" : parameters ,
54- "is_select" : context .invoked_statement .is_select ,
34+ "duration" : (
35+ perf_counter () - context ._start_time # type: ignore[attr-defined]
36+ )
37+ * 1000 ,
38+ "sql" : context .statement ,
39+ "params" : context .parameters ,
5540 }
56- self .add_query (str (conn .engine .url ), query )
41+ self .add_query (str (context .engine .url ), query )
5742
5843 async def add_engines (self , request : Request ):
5944 route = request ["route" ]
@@ -70,7 +55,12 @@ async def add_engines(self, request: Request):
7055 )
7156 for value in solved_result [0 ].values ():
7257 if isinstance (value , Session ):
73- self .engines .add (value .get_bind ())
58+ bind = value .get_bind ()
59+
60+ if isinstance (bind , Connection ):
61+ self .engines .add (bind .engine )
62+ else :
63+ self .engines .add (bind )
7464
7565 async def process_request (self , request : Request ) -> Response :
7666 await self .add_engines (request )
0 commit comments