Skip to content

Commit c12d9d7

Browse files
Merge pull request #815 from RafalSkolasinski/get-or-create-timeout
feat: make get_or_create timeout configurable
2 parents 94bdeda + 408d30e commit c12d9d7

File tree

3 files changed

+40
-10
lines changed

3 files changed

+40
-10
lines changed

graphdatascience/session/aura_api.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import logging
4+
import math
45
import time
56
import warnings
67
from collections import defaultdict
@@ -165,7 +166,7 @@ def wait_for_session_running(
165166
session_id: str,
166167
sleep_time: float = 0.2,
167168
max_sleep_time: float = 10,
168-
max_wait_time: float = 300,
169+
max_wait_time: float = math.inf,
169170
) -> WaitResult:
170171
waited_time = 0.0
171172
while waited_time < max_wait_time:
@@ -186,7 +187,12 @@ def wait_for_session_running(
186187
time.sleep(sleep_time)
187188
sleep_time = min(sleep_time * 2, max_sleep_time, max_wait_time - waited_time)
188189

189-
return WaitResult.from_error(f"Session `{session_id}` is not running after {waited_time} seconds")
190+
return WaitResult.from_error(
191+
f"Session `{session_id}` is not running after {waited_time} seconds.\n"
192+
"\tThe session may become available at a later time.\n"
193+
f'\tConsider running `sessions.delete(session_id="{session_id}")` '
194+
"to avoid resource leakage."
195+
)
190196

191197
def delete_session(self, session_id: str) -> bool:
192198
response = self._request_session.delete(

graphdatascience/session/dedicated_sessions.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import hashlib
4+
import math
45
import warnings
56
from datetime import datetime, timedelta, timezone
67
from typing import Optional
@@ -24,7 +25,10 @@ def __init__(self, aura_api: AuraApi) -> None:
2425
self._aura_api = aura_api
2526

2627
def estimate(
27-
self, node_count: int, relationship_count: int, algorithm_categories: Optional[list[AlgorithmCategory]] = None
28+
self,
29+
node_count: int,
30+
relationship_count: int,
31+
algorithm_categories: Optional[list[AlgorithmCategory]] = None,
2832
) -> SessionMemory:
2933
if algorithm_categories is None:
3034
algorithm_categories = []
@@ -56,6 +60,7 @@ def get_or_create(
5660
db_connection: DbmsConnectionInfo,
5761
ttl: Optional[timedelta] = None,
5862
cloud_location: Optional[CloudLocation] = None,
63+
timeout: Optional[int] = None,
5964
) -> AuraGraphDataScience:
6065
db_runner = Neo4jQueryRunner.create_for_db(
6166
endpoint=db_connection.uri,
@@ -83,7 +88,8 @@ def get_or_create(
8388

8489
connection_url = session_details.bolt_connection_url()
8590
if session_details.status != "Ready":
86-
wait_result = self._aura_api.wait_for_session_running(session_id)
91+
max_wait_time = float(timeout) if timeout is not None else math.inf
92+
wait_result = self._aura_api.wait_for_session_running(session_id, max_wait_time=max_wait_time)
8793
if err := wait_result.error:
8894
raise RuntimeError(f"Failed to get or create session `{session_name}`: {err}")
8995

@@ -93,7 +99,11 @@ def get_or_create(
9399
password=password,
94100
)
95101

96-
return self._construct_client(session_id=session_id, session_connection=session_connection, db_runner=db_runner)
102+
return self._construct_client(
103+
session_id=session_id,
104+
session_connection=session_connection,
105+
db_runner=db_runner,
106+
)
97107

98108
def delete(self, *, session_name: Optional[str] = None, session_id: Optional[str] = None) -> bool:
99109
if not session_name and not session_id:
@@ -160,13 +170,20 @@ def _get_or_create_session(
160170
# If cloud location is provided we go for self managed DBs path
161171
if cloud_location:
162172
return self._aura_api.get_or_create_session(
163-
name=session_name, pwd=pwd, memory=memory, ttl=ttl, cloud_location=cloud_location
173+
name=session_name,
174+
pwd=pwd,
175+
memory=memory,
176+
ttl=ttl,
177+
cloud_location=cloud_location,
164178
)
165179
else:
166180
return self._aura_api.get_or_create_session(name=session_name, dbid=dbid, pwd=pwd, memory=memory, ttl=ttl)
167181

168182
def _construct_client(
169-
self, session_id: str, session_connection: DbmsConnectionInfo, db_runner: Neo4jQueryRunner
183+
self,
184+
session_id: str,
185+
session_connection: DbmsConnectionInfo,
186+
db_runner: Neo4jQueryRunner,
170187
) -> AuraGraphDataScience:
171188
return AuraGraphDataScience.create(
172189
gds_session_connection_info=session_connection,

graphdatascience/session/gds_sessions.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,10 @@ def __init__(self, api_credentials: AuraAPICredentials) -> None:
5353
self._impl: DedicatedSessions = DedicatedSessions(aura_api)
5454

5555
def estimate(
56-
self, node_count: int, relationship_count: int, algorithm_categories: Optional[list[AlgorithmCategory]] = None
56+
self,
57+
node_count: int,
58+
relationship_count: int,
59+
algorithm_categories: Optional[list[AlgorithmCategory]] = None,
5760
) -> SessionMemory:
5861
"""
5962
Estimates the memory required for a session with the given node and relationship counts.
@@ -86,6 +89,7 @@ def get_or_create(
8689
db_connection: DbmsConnectionInfo,
8790
ttl: Optional[timedelta] = None,
8891
cloud_location: Optional[CloudLocation] = None,
92+
timeout: Optional[int] = None,
8993
) -> AuraGraphDataScience:
9094
"""
9195
Retrieves an existing session with the given session name and database connection,
@@ -98,13 +102,16 @@ def get_or_create(
98102
session_name (str): The name of the session.
99103
memory (SessionMemory): The size of the session specified by memory.
100104
db_connection (DbmsConnectionInfo): The database connection information.
101-
ttl: Optional[timedelta]: The sessions time to live after inactivity in seconds.
105+
ttl: (Optional[timedelta]): The sessions time to live after inactivity in seconds.
102106
cloud_location (Optional[CloudLocation]): The cloud location. Required if the GDS session is for a self-managed database.
107+
timeout (Optional[int]): Optional timeout (in seconds) when waiting for session to become ready. If unset the method will wait forever. If set and session does not become ready an exception will be raised. It is user responsibility to ensure resource gets cleaned up in this situation.
103108
104109
Returns:
105110
AuraGraphDataScience: The session.
106111
"""
107-
return self._impl.get_or_create(session_name, memory, db_connection, ttl=ttl, cloud_location=cloud_location)
112+
return self._impl.get_or_create(
113+
session_name, memory, db_connection, ttl=ttl, cloud_location=cloud_location, timeout=timeout
114+
)
108115

109116
def delete(self, *, session_name: Optional[str] = None, session_id: Optional[str] = None) -> bool:
110117
"""

0 commit comments

Comments
 (0)