@@ -46,6 +46,8 @@ def create_for_db(
4646
4747 query_runner = Neo4jQueryRunner (
4848 driver ,
49+ Neo4jQueryRunner .parse_protocol (endpoint ),
50+ auth ,
4951 auto_close = True ,
5052 bookmarks = bookmarks ,
5153 config = config ,
@@ -54,8 +56,14 @@ def create_for_db(
5456 )
5557
5658 elif isinstance (endpoint , neo4j .Driver ):
59+ protocol = "neo4j+s" if endpoint .encrypted else "bolt"
5760 query_runner = Neo4jQueryRunner (
58- endpoint , auto_close = False , bookmarks = bookmarks , database = database , show_progress = show_progress
61+ endpoint ,
62+ protocol ,
63+ auto_close = False ,
64+ bookmarks = bookmarks ,
65+ database = database ,
66+ show_progress = show_progress ,
5967 )
6068 else :
6169 raise ValueError (f"Invalid endpoint type: { type (endpoint )} " )
@@ -76,6 +84,8 @@ def create_for_session(
7684
7785 query_runner = Neo4jQueryRunner (
7886 driver ,
87+ Neo4jQueryRunner .parse_protocol (endpoint ),
88+ auth ,
7989 auto_close = True ,
8090 show_progress = show_progress ,
8191 bookmarks = None ,
@@ -94,9 +104,18 @@ def _configure_aura(config: dict[str, Any]) -> None:
94104 config ["keep_alive" ] = True
95105 config ["max_connection_pool_size" ] = 50
96106
107+ @staticmethod
108+ def parse_protocol (endpoint : str ) -> str :
109+ protocol_match = re .match (r"^([^:]+)://" , endpoint )
110+ if not protocol_match :
111+ raise ValueError (f"Invalid endpoint URI format: { endpoint } " )
112+ return protocol_match .group (1 )
113+
97114 def __init__ (
98115 self ,
99116 driver : neo4j .Driver ,
117+ protocol : str ,
118+ auth : Optional [tuple [str , str ]] = None ,
100119 config : dict [str , Any ] = {},
101120 database : Optional [str ] = neo4j .DEFAULT_DATABASE ,
102121 auto_close : bool = False ,
@@ -105,6 +124,8 @@ def __init__(
105124 instance_description : str = "Neo4j DBMS" ,
106125 ):
107126 self ._driver = driver
127+ self ._protocol = protocol
128+ self ._auth = auth
108129 self ._config = config
109130 self ._auto_close = auto_close
110131 self ._database = database
@@ -279,6 +300,22 @@ def create_graph_constructor(
279300 def set_show_progress (self , show_progress : bool ) -> None :
280301 self ._show_progress = show_progress
281302
303+ def clone (self , host : str , port : int ) -> QueryRunner :
304+ endpoint = "{}://{}:{}" .format (self ._protocol , host , port )
305+ driver = neo4j .GraphDatabase .driver (endpoint , auth = self ._auth , ** self .driver_config ())
306+
307+ return Neo4jQueryRunner (
308+ driver ,
309+ self ._protocol ,
310+ self ._auth ,
311+ self ._config ,
312+ self ._database ,
313+ self ._auto_close ,
314+ self ._bookmarks ,
315+ self ._show_progress ,
316+ self ._instance_description ,
317+ )
318+
282319 @staticmethod
283320 def handle_driver_exception (session : neo4j .Session , e : Exception ) -> None :
284321 reg_gds_hit = re .search (
0 commit comments