Skip to content

Commit 6f3e023

Browse files
authored
Merge pull request #729 from jsjasonseba/feature/pre-execution-queries
Add pre_execution_queries parameter to run setup queries before the main query on Postgres and MySQL sources
2 parents f633123 + 0925e85 commit 6f3e023

File tree

19 files changed

+470
-49
lines changed

19 files changed

+470
-49
lines changed

connectorx-cpp/src/lib.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ pub unsafe extern "C" fn connectorx_scan(conn: *const c_char, query: *const c_ch
176176
let conn_str = unsafe { CStr::from_ptr(conn) }.to_str().unwrap();
177177
let query_str = unsafe { CStr::from_ptr(query) }.to_str().unwrap();
178178
let source_conn = SourceConn::try_from(conn_str).unwrap();
179-
let record_batches = get_arrow(&source_conn, None, &[CXQuery::from(query_str)])
179+
let record_batches = get_arrow(&source_conn, None, &[CXQuery::from(query_str)], None)
180180
.unwrap()
181181
.arrow()
182182
.unwrap();
@@ -281,7 +281,7 @@ pub unsafe extern "C" fn connectorx_scan_iter(
281281
}
282282

283283
let arrow_iter: Box<dyn RecordBatchIterator> =
284-
new_record_batch_iter(&source_conn, None, query_vec.as_slice(), batch_size);
284+
new_record_batch_iter(&source_conn, None, query_vec.as_slice(), batch_size, None);
285285

286286
Box::into_raw(Box::new(arrow_iter))
287287
}

connectorx-python/connectorx/__init__.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ def read_sql_pandas(
131131
partition_on: str | None = None,
132132
partition_range: tuple[int, int] | None = None,
133133
partition_num: int | None = None,
134+
pre_execution_queries: list[str] | str | None = None,
134135
) -> pd.DataFrame:
135136
"""
136137
Run the SQL query, download the data from database into a dataframe.
@@ -160,6 +161,7 @@ def read_sql_pandas(
160161
partition_range=partition_range,
161162
partition_num=partition_num,
162163
index_col=index_col,
164+
pre_execution_queries=pre_execution_queries,
163165
)
164166

165167

@@ -174,6 +176,7 @@ def read_sql(
174176
partition_range: tuple[int, int] | None = None,
175177
partition_num: int | None = None,
176178
index_col: str | None = None,
179+
pre_execution_query: list[str] | str | None = None,
177180
) -> pd.DataFrame: ...
178181

179182

@@ -188,6 +191,7 @@ def read_sql(
188191
partition_range: tuple[int, int] | None = None,
189192
partition_num: int | None = None,
190193
index_col: str | None = None,
194+
pre_execution_query: list[str] | str | None = None,
191195
) -> pd.DataFrame: ...
192196

193197

@@ -202,6 +206,7 @@ def read_sql(
202206
partition_range: tuple[int, int] | None = None,
203207
partition_num: int | None = None,
204208
index_col: str | None = None,
209+
pre_execution_query: list[str] | str | None = None,
205210
) -> pa.Table: ...
206211

207212

@@ -216,6 +221,7 @@ def read_sql(
216221
partition_range: tuple[int, int] | None = None,
217222
partition_num: int | None = None,
218223
index_col: str | None = None,
224+
pre_execution_query: list[str] | str | None = None,
219225
) -> mpd.DataFrame: ...
220226

221227

@@ -230,6 +236,7 @@ def read_sql(
230236
partition_range: tuple[int, int] | None = None,
231237
partition_num: int | None = None,
232238
index_col: str | None = None,
239+
pre_execution_query: list[str] | str | None = None,
233240
) -> dd.DataFrame: ...
234241

235242

@@ -244,6 +251,7 @@ def read_sql(
244251
partition_range: tuple[int, int] | None = None,
245252
partition_num: int | None = None,
246253
index_col: str | None = None,
254+
pre_execution_query: list[str] | str | None = None,
247255
) -> pl.DataFrame: ...
248256

249257

@@ -260,6 +268,7 @@ def read_sql(
260268
partition_num: int | None = None,
261269
index_col: str | None = None,
262270
strategy: str | None = None,
271+
pre_execution_query: list[str] | str | None = None,
263272
) -> pd.DataFrame | mpd.DataFrame | dd.DataFrame | pl.DataFrame | pa.Table:
264273
"""
265274
Run the SQL query, download the data from database into a dataframe.
@@ -285,6 +294,9 @@ def read_sql(
285294
the index column to set; only applicable for return type "pandas", "modin", "dask".
286295
strategy
287296
strategy of rewriting the federated query for join pushdown
297+
pre_execution_query
298+
SQL query or list of SQL queries executed before main query; can be used to set runtime
299+
configurations using SET statements; only applicable for source "Postgres" and "MySQL".
288300

289301
Examples
290302
========
@@ -358,6 +370,13 @@ def read_sql(
358370
raise ValueError("Partition on multiple queries is not supported.")
359371
else:
360372
raise ValueError("query must be either str or a list of str")
373+
374+
if isinstance(pre_execution_query, list):
375+
pre_execution_queries = [remove_ending_semicolon(subquery) for subquery in pre_execution_query]
376+
elif isinstance(pre_execution_query, str):
377+
pre_execution_queries = [remove_ending_semicolon(pre_execution_query)]
378+
else:
379+
pre_execution_queries = None
361380

362381
conn, protocol = rewrite_conn(conn, protocol)
363382

@@ -370,6 +389,7 @@ def read_sql(
370389
queries=queries,
371390
protocol=protocol,
372391
partition_query=partition_query,
392+
pre_execution_queries=pre_execution_queries,
373393
)
374394
df = reconstruct_pandas(result)
375395

@@ -392,6 +412,7 @@ def read_sql(
392412
queries=queries,
393413
protocol=protocol,
394414
partition_query=partition_query,
415+
pre_execution_queries=pre_execution_queries,
395416
)
396417
df = reconstruct_arrow(result)
397418
if return_type in {"polars"}:

connectorx-python/connectorx/connectorx.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def read_sql(
2525
protocol: str | None,
2626
queries: list[str] | None,
2727
partition_query: dict[str, Any] | None,
28+
pre_execution_queries: list[str] | None,
2829
) -> _DataframeInfos: ...
2930
@overload
3031
def read_sql(
@@ -33,6 +34,7 @@ def read_sql(
3334
protocol: str | None,
3435
queries: list[str] | None,
3536
partition_query: dict[str, Any] | None,
37+
pre_execution_queries: list[str] | None,
3638
) -> _ArrowInfos: ...
3739
def partition_sql(conn: str, partition_query: dict[str, Any]) -> list[str]: ...
3840
def read_sql2(sql: str, db_map: dict[str, str]) -> _ArrowInfos: ...

connectorx-python/connectorx/tests/test_mysql.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -478,3 +478,52 @@ def test_mysql_cte(mysql_url: str) -> None:
478478

479479
def test_connection_url(mysql_url: str) -> None:
480480
test_mysql_cte(ConnectionUrl(mysql_url))
481+
482+
def test_mysql_single_pre_execution_queries(mysql_url: str) -> None:
483+
pre_execution_query = "SET SESSION max_execution_time = 2151"
484+
query = "SELECT @@SESSION.max_execution_time AS max_execution_time"
485+
df = read_sql(mysql_url, query, pre_execution_query=pre_execution_query)
486+
expected = pd.DataFrame(
487+
index=range(1),
488+
data={
489+
"max_execution_time": pd.Series([2151], dtype="float64")
490+
},
491+
)
492+
assert_frame_equal(df, expected, check_names=True)
493+
494+
495+
def test_mysql_multiple_pre_execution_queries(mysql_url: str) -> None:
496+
pre_execution_query = [
497+
"SET SESSION max_execution_time = 2151",
498+
"SET SESSION wait_timeout = 2252",
499+
]
500+
query = "SELECT @@SESSION.max_execution_time AS max_execution_time, @@SESSION.wait_timeout AS wait_timeout"
501+
df = read_sql(mysql_url, query, pre_execution_query=pre_execution_query)
502+
expected = pd.DataFrame(
503+
index=range(1),
504+
data={
505+
"max_execution_time": pd.Series([2151], dtype="float64"),
506+
"wait_timeout": pd.Series([2252], dtype="float64")
507+
},
508+
)
509+
assert_frame_equal(df, expected, check_names=True)
510+
511+
def test_mysql_partitioned_pre_execution_queries(mysql_url: str) -> None:
512+
pre_execution_query = [
513+
"SET SESSION max_execution_time = 2151",
514+
"SET SESSION wait_timeout = 2252",
515+
]
516+
query = [
517+
'SELECT "max_execution_time" AS name, @@SESSION.max_execution_time AS setting',
518+
'SELECT "wait_timeout" AS name, @@SESSION.wait_timeout AS setting'
519+
]
520+
df = read_sql(mysql_url, query, pre_execution_query=pre_execution_query).sort_values(by=['name'])
521+
expected = pd.DataFrame(
522+
index=range(2),
523+
data={
524+
"name": pd.Series(["max_execution_time", "wait_timeout"], dtype="str"),
525+
"setting": pd.Series([2151, 2252], dtype="float64"),
526+
},
527+
).sort_values(by=['name'])
528+
529+
assert_frame_equal(df, expected, check_like=False)

connectorx-python/connectorx/tests/test_postgres.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1008,4 +1008,53 @@ def test_postgres_partition_with_orderby_limit_desc(postgres_url: str) -> None:
10081008
},
10091009
)
10101010
df.sort_values(by="test_int", inplace=True, ignore_index=True)
1011-
assert_frame_equal(df, expected, check_names=True)
1011+
assert_frame_equal(df, expected, check_names=True)
1012+
1013+
def test_postgres_single_pre_execution_queries(postgres_url: str) -> None:
1014+
pre_execution_query = "SET SESSION statement_timeout = 2151"
1015+
query = "SELECT CAST(name AS TEXT) AS name, CAST(setting AS INTEGER) FROM pg_settings WHERE name = 'statement_timeout'"
1016+
df = read_sql(postgres_url, query, pre_execution_query=pre_execution_query)
1017+
expected = pd.DataFrame(
1018+
index=range(1),
1019+
data={
1020+
"name": pd.Series(["statement_timeout"], dtype="str"),
1021+
"setting": pd.Series([2151], dtype="Int64"),
1022+
},
1023+
)
1024+
assert_frame_equal(df, expected, check_names=True)
1025+
1026+
def test_postgres_multiple_pre_execution_queries(postgres_url: str) -> None:
1027+
pre_execution_query = [
1028+
"SET SESSION statement_timeout = 2151",
1029+
"SET SESSION idle_in_transaction_session_timeout = 2252",
1030+
]
1031+
query = "SELECT CAST(name AS TEXT) AS name, CAST(setting AS INTEGER) FROM pg_settings WHERE name IN ('statement_timeout', 'idle_in_transaction_session_timeout') ORDER BY name"
1032+
df = read_sql(postgres_url, query, pre_execution_query=pre_execution_query)
1033+
expected = pd.DataFrame(
1034+
index=range(2),
1035+
data={
1036+
"name": pd.Series(["idle_in_transaction_session_timeout", "statement_timeout"], dtype="str"),
1037+
"setting": pd.Series([2252, 2151], dtype="Int64"),
1038+
},
1039+
)
1040+
assert_frame_equal(df, expected, check_names=True)
1041+
1042+
def test_postgres_partitioned_pre_execution_queries(postgres_url: str) -> None:
1043+
pre_execution_query = [
1044+
"SET SESSION statement_timeout = 2151",
1045+
"SET SESSION idle_in_transaction_session_timeout = 2252",
1046+
]
1047+
query = [
1048+
"SELECT CAST(name AS TEXT) AS name, CAST(setting AS INTEGER) AS setting FROM pg_settings WHERE name = 'statement_timeout'",
1049+
"SELECT CAST(name AS TEXT) AS name, CAST(setting AS INTEGER) AS setting FROM pg_settings WHERE name = 'idle_in_transaction_session_timeout'"
1050+
]
1051+
1052+
df = read_sql(postgres_url, query, pre_execution_query=pre_execution_query).sort_values(by=['name'])
1053+
expected = pd.DataFrame(
1054+
index=range(2),
1055+
data={
1056+
"name": pd.Series(["statement_timeout", "idle_in_transaction_session_timeout"], dtype="str"),
1057+
"setting": pd.Series([2151, 2252], dtype="Int64"),
1058+
},
1059+
).sort_values(by=['name'])
1060+
assert_frame_equal(df, expected, check_names=True)

connectorx-python/src/arrow.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,11 @@ pub fn write_arrow<'py>(
1515
source_conn: &SourceConn,
1616
origin_query: Option<String>,
1717
queries: &[CXQuery<String>],
18+
pre_execution_queries: Option<&[String]>,
1819
) -> Bound<'py, PyAny> {
1920
let ptrs = py.allow_threads(
2021
|| -> Result<(Vec<String>, Vec<Vec<(uintptr_t, uintptr_t)>>), ConnectorXPythonError> {
21-
let destination = get_arrow(source_conn, origin_query, queries)?;
22+
let destination = get_arrow(source_conn, origin_query, queries, pre_execution_queries)?;
2223
let rbs = destination.arrow()?;
2324
Ok(to_ptrs(rbs))
2425
},

connectorx-python/src/cx_read_sql.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ pub fn read_sql<'py>(
3838
protocol: Option<&str>,
3939
queries: Option<Vec<String>>,
4040
partition_query: Option<PyPartitionQuery>,
41+
pre_execution_queries: Option<Vec<String>>,
4142
) -> PyResult<Bound<'py, PyAny>> {
4243
let source_conn = parse_source(conn, protocol).map_err(|e| ConnectorXPythonError::from(e))?;
4344
let (queries, origin_query) = match (queries, partition_query) {
@@ -62,12 +63,14 @@ pub fn read_sql<'py>(
6263
&source_conn,
6364
origin_query,
6465
&queries,
66+
pre_execution_queries.as_deref(),
6567
)?),
6668
"arrow" => Ok(crate::arrow::write_arrow(
6769
py,
6870
&source_conn,
6971
origin_query,
7072
&queries,
73+
pre_execution_queries.as_deref(),
7174
)?),
7275
_ => Err(PyValueError::new_err(format!(
7376
"return type should be 'pandas' or 'arrow', got '{}'",

connectorx-python/src/lib.rs

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,16 +39,25 @@ fn connectorx(_: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
3939
}
4040

4141
#[pyfunction]
42-
#[pyo3(signature = (conn, return_type, protocol=None, queries=None, partition_query=None))]
42+
#[pyo3(signature = (conn, return_type, protocol=None, queries=None, partition_query=None, pre_execution_queries=None))]
4343
pub fn read_sql<'py>(
4444
py: Python<'py>,
4545
conn: &str,
4646
return_type: &str,
4747
protocol: Option<&str>,
4848
queries: Option<Vec<String>>,
4949
partition_query: Option<cx_read_sql::PyPartitionQuery>,
50+
pre_execution_queries: Option<Vec<String>>,
5051
) -> PyResult<Bound<'py, PyAny>> {
51-
cx_read_sql::read_sql(py, conn, return_type, protocol, queries, partition_query)
52+
cx_read_sql::read_sql(
53+
py,
54+
conn,
55+
return_type,
56+
protocol,
57+
queries,
58+
partition_query,
59+
pre_execution_queries,
60+
)
5261
}
5362

5463
#[pyfunction]

connectorx-python/src/pandas/dispatcher.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@ where
4141
}
4242
}
4343

44+
pub fn set_pre_execution_queries(&mut self, pre_execution_queries: Option<&[String]>) {
45+
self.src.set_pre_execution_queries(pre_execution_queries);
46+
}
47+
4448
/// Start the data loading process.
4549
pub fn run(mut self, py: Python<'py>) -> Result<Bound<'py, PyAny>, TP::Error> {
4650
debug!("Run dispatcher");

0 commit comments

Comments
 (0)