Skip to content

Commit 0ee1532

Browse files
authored
Merge pull request #887 from DarthMax/session_remote_projection_fix
Session remote projection fix
2 parents 964b156 + af5bc45 commit 0ee1532

File tree

8 files changed

+83
-5
lines changed

8 files changed

+83
-5
lines changed

changelog.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77

88
## Bug fixes
99

10+
* Fixed a bug where remote projections would fail when the database is clustered
11+
1012
## Improvements
1113

1214
* Allow creating sessions of size `512GB`.

graphdatascience/query_runner/arrow_query_runner.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,13 @@ def close(self) -> None:
211211
self._fallback_query_runner.close()
212212
self._gds_arrow_client.close()
213213

214+
def clone(self, host: str, port: int) -> "QueryRunner":
215+
return ArrowQueryRunner(
216+
self._gds_arrow_client,
217+
self._fallback_query_runner.clone(host, port),
218+
self._server_version,
219+
)
220+
214221
def fallback_query_runner(self) -> QueryRunner:
215222
return self._fallback_query_runner
216223

graphdatascience/query_runner/neo4j_query_runner.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ def create_for_db(
4646

4747
query_runner = Neo4jQueryRunner(
4848
driver,
49+
Neo4jQueryRunner.parse_protocol(endpoint),
50+
auth,
4951
auto_close=True,
5052
bookmarks=bookmarks,
5153
config=config,
@@ -54,8 +56,14 @@ def create_for_db(
5456
)
5557

5658
elif isinstance(endpoint, neo4j.Driver):
59+
protocol = "neo4j+s" if endpoint.encrypted else "bolt"
5760
query_runner = Neo4jQueryRunner(
58-
endpoint, auto_close=False, bookmarks=bookmarks, database=database, show_progress=show_progress
61+
endpoint,
62+
protocol,
63+
auto_close=False,
64+
bookmarks=bookmarks,
65+
database=database,
66+
show_progress=show_progress,
5967
)
6068
else:
6169
raise ValueError(f"Invalid endpoint type: {type(endpoint)}")
@@ -76,6 +84,8 @@ def create_for_session(
7684

7785
query_runner = Neo4jQueryRunner(
7886
driver,
87+
Neo4jQueryRunner.parse_protocol(endpoint),
88+
auth,
7989
auto_close=True,
8090
show_progress=show_progress,
8191
bookmarks=None,
@@ -94,9 +104,18 @@ def _configure_aura(config: dict[str, Any]) -> None:
94104
config["keep_alive"] = True
95105
config["max_connection_pool_size"] = 50
96106

107+
@staticmethod
108+
def parse_protocol(endpoint: str) -> str:
109+
protocol_match = re.match(r"^([^:]+)://", endpoint)
110+
if not protocol_match:
111+
raise ValueError(f"Invalid endpoint URI format: {endpoint}")
112+
return protocol_match.group(1)
113+
97114
def __init__(
98115
self,
99116
driver: neo4j.Driver,
117+
protocol: str,
118+
auth: Optional[tuple[str, str]] = None,
100119
config: dict[str, Any] = {},
101120
database: Optional[str] = neo4j.DEFAULT_DATABASE,
102121
auto_close: bool = False,
@@ -105,6 +124,8 @@ def __init__(
105124
instance_description: str = "Neo4j DBMS",
106125
):
107126
self._driver = driver
127+
self._protocol = protocol
128+
self._auth = auth
108129
self._config = config
109130
self._auto_close = auto_close
110131
self._database = database
@@ -279,6 +300,22 @@ def create_graph_constructor(
279300
def set_show_progress(self, show_progress: bool) -> None:
280301
self._show_progress = show_progress
281302

303+
def clone(self, host: str, port: int) -> QueryRunner:
304+
endpoint = "{}://{}:{}".format(self._protocol, host, port)
305+
driver = neo4j.GraphDatabase.driver(endpoint, auth=self._auth, **self.driver_config())
306+
307+
return Neo4jQueryRunner(
308+
driver,
309+
self._protocol,
310+
self._auth,
311+
self._config,
312+
self._database,
313+
self._auto_close,
314+
self._bookmarks,
315+
self._show_progress,
316+
self._instance_description,
317+
)
318+
282319
@staticmethod
283320
def handle_driver_exception(session: neo4j.Session, e: Exception) -> None:
284321
reg_gds_hit = re.search(

graphdatascience/query_runner/protocol/project_protocols.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def run_projection(
121121
query_runner: QueryRunner,
122122
endpoint: str,
123123
params: CallParameters,
124-
terminationFlag: TerminationFlag,
124+
termination_flag: TerminationFlag,
125125
yields: Optional[list[str]] = None,
126126
database: Optional[str] = None,
127127
logging: bool = False,
@@ -132,16 +132,28 @@ def is_not_done(result: DataFrame) -> bool:
132132

133133
logger = getLogger()
134134

135+
# We need to pin the driver to a specific cluster member
136+
response = query_runner.call_procedure(
137+
ProtocolVersion.V3.versioned_procedure_name(endpoint), params, yields, database, logging, False
138+
).squeeze()
139+
member_host = response["host"]
140+
member_port = response["port"] if ("port" in response.index) else 7687
141+
projection_query_runner = query_runner.clone(member_host, member_port)
142+
135143
@retry(
136144
reraise=True,
137145
before=before_log(f"Projection (graph: `{params['graph_name']}`)", logger, DEBUG),
138146
retry=retry_if_result(is_not_done),
139147
wait=wait_incrementing(start=0.2, increment=0.2, max=2),
140148
)
141149
def project_fn() -> DataFrame:
142-
terminationFlag.assert_running()
143-
return query_runner.call_procedure(
150+
termination_flag.assert_running()
151+
return projection_query_runner.call_procedure(
144152
ProtocolVersion.V3.versioned_procedure_name(endpoint), params, yields, database, logging, False
145153
)
146154

147-
return project_fn()
155+
projection_result = project_fn()
156+
157+
projection_query_runner.close()
158+
159+
return projection_result

graphdatascience/query_runner/query_runner.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,5 +80,9 @@ def last_bookmarks(self) -> Optional[Any]:
8080
def set_show_progress(self, show_progress: bool) -> None:
8181
pass
8282

83+
@abstractmethod
84+
def clone(self, host: str, port: int) -> "QueryRunner":
85+
pass
86+
8387
def set_server_version(self, _: ServerVersion) -> None:
8488
pass

graphdatascience/query_runner/session_query_runner.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,14 @@ def set_show_progress(self, show_progress: bool) -> None:
120120
self._show_progress = show_progress
121121
self._gds_query_runner.set_show_progress(show_progress)
122122

123+
def clone(self, host: str, port: int) -> QueryRunner:
124+
return SessionQueryRunner(
125+
self._gds_query_runner,
126+
self._db_query_runner.clone(host, port),
127+
self._gds_arrow_client,
128+
self._show_progress,
129+
)
130+
123131
def close(self) -> None:
124132
self._gds_arrow_client.close()
125133
self._gds_query_runner.close()

graphdatascience/query_runner/standalone_session_query_runner.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
from typing import Any, Optional
24

35
from pandas import DataFrame
@@ -74,3 +76,6 @@ def last_bookmarks(self) -> Optional[Any]:
7476

7577
def set_server_version(self, _: ServerVersion) -> None:
7678
super().set_server_version(_)
79+
80+
def clone(self, host: str, port: int) -> QueryRunner:
81+
return self

graphdatascience/tests/unit/conftest.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,9 @@ def create_graph_constructor(
126126
self, graph_name, concurrency, undirected_relationship_types, self._server_version
127127
)
128128

129+
def clone(self, host: str, port: int) -> QueryRunner:
130+
return self
131+
129132
def set__mock_result(self, result: DataFrame) -> None:
130133
self._result_map.clear()
131134
self._result_map[""] = result

0 commit comments

Comments
 (0)