3636logger = logging .getLogger ("databases" )
3737
3838
39- _ACTIVE_CONNECTIONS : ContextVar [
40- typing .Optional ["weakref.WeakKeyDictionary['Database', 'Connection']" ]
41- ] = ContextVar ("databases:open_connections" , default = None )
4239_ACTIVE_TRANSACTIONS : ContextVar [
4340 typing .Optional ["weakref.WeakKeyDictionary['Transaction', 'TransactionBackend']" ]
44- ] = ContextVar ("databases:open_transactions " , default = None )
41+ ] = ContextVar ("databases:active_transactions " , default = None )
4542
4643
4744class Database :
@@ -54,6 +51,8 @@ class Database:
5451 "sqlite" : "databases.backends.sqlite:SQLiteBackend" ,
5552 }
5653
54+ _connection_map : "weakref.WeakKeyDictionary[asyncio.Task, 'Connection']"
55+
5756 def __init__ (
5857 self ,
5958 url : typing .Union [str , "DatabaseURL" ],
@@ -64,6 +63,7 @@ def __init__(
6463 self .url = DatabaseURL (url )
6564 self .options = options
6665 self .is_connected = False
66+ self ._connection_map = weakref .WeakKeyDictionary ()
6767
6868 self ._force_rollback = force_rollback
6969
@@ -78,28 +78,28 @@ def __init__(
7878 self ._global_transaction : typing .Optional [Transaction ] = None
7979
8080 @property
81- def _connection (self ) -> typing .Optional ["Connection" ]:
82- connections = _ACTIVE_CONNECTIONS .get ()
83- if connections is None :
84- return None
81+ def _current_task (self ) -> asyncio .Task :
82+ task = asyncio .current_task ()
83+ if not task :
84+ raise RuntimeError ("No currently active asyncio.Task found" )
85+ return task
8586
86- return connections .get (self , None )
87+ @property
88+ def _connection (self ) -> typing .Optional ["Connection" ]:
89+ return self ._connection_map .get (self ._current_task )
8790
8891 @_connection .setter
8992 def _connection (
9093 self , connection : typing .Optional ["Connection" ]
9194 ) -> typing .Optional ["Connection" ]:
92- connections = _ACTIVE_CONNECTIONS .get ()
93- if connections is None :
94- connections = weakref .WeakKeyDictionary ()
95- _ACTIVE_CONNECTIONS .set (connections )
95+ task = self ._current_task
9696
9797 if connection is None :
98- connections . pop (self , None )
98+ self . _connection_map . pop (task , None )
9999 else :
100- connections [ self ] = connection
100+ self . _connection_map [ task ] = connection
101101
102- return connections . get ( self , None )
102+ return self . _connection
103103
104104 async def connect (self ) -> None :
105105 """
@@ -119,7 +119,7 @@ async def connect(self) -> None:
119119 assert self ._global_connection is None
120120 assert self ._global_transaction is None
121121
122- self ._global_connection = Connection (self ._backend )
122+ self ._global_connection = Connection (self , self ._backend )
123123 self ._global_transaction = self ._global_connection .transaction (
124124 force_rollback = True
125125 )
@@ -218,7 +218,7 @@ def connection(self) -> "Connection":
218218 return self ._global_connection
219219
220220 if not self ._connection :
221- self ._connection = Connection (self ._backend )
221+ self ._connection = Connection (self , self ._backend )
222222
223223 return self ._connection
224224
@@ -243,7 +243,8 @@ def _get_backend(self) -> str:
243243
244244
245245class Connection :
246- def __init__ (self , backend : DatabaseBackend ) -> None :
246+ def __init__ (self , database : Database , backend : DatabaseBackend ) -> None :
247+ self ._database = database
247248 self ._backend = backend
248249
249250 self ._connection_lock = asyncio .Lock ()
@@ -277,6 +278,7 @@ async def __aexit__(
277278 self ._connection_counter -= 1
278279 if self ._connection_counter == 0 :
279280 await self ._connection .release ()
281+ self ._database ._connection = None
280282
281283 async def fetch_all (
282284 self ,
@@ -393,13 +395,15 @@ def _transaction(
393395 transactions = _ACTIVE_TRANSACTIONS .get ()
394396 if transactions is None :
395397 transactions = weakref .WeakKeyDictionary ()
396- _ACTIVE_TRANSACTIONS .set (transactions )
398+ else :
399+ transactions = transactions .copy ()
397400
398401 if transaction is None :
399402 transactions .pop (self , None )
400403 else :
401404 transactions [self ] = transaction
402405
406+ _ACTIVE_TRANSACTIONS .set (transactions )
403407 return transactions .get (self , None )
404408
405409 async def __aenter__ (self ) -> "Transaction" :
0 commit comments