Skip to content

Commit 9510041

Browse files
committed
Support spooled protocol
1 parent 40c3a4e commit 9510041

File tree

8 files changed

+321
-13
lines changed

8 files changed

+321
-13
lines changed

README.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -469,6 +469,19 @@ conn = connect(
469469
)
470470
```
471471

472+
## Spooled protocol
473+
474+
Enable the spooled protocol by specifying a supported encoding in the `encoding` parameter (equires a
475+
Trino server that supports the spooled protocol).
476+
477+
```python
478+
from trino.dbapi import connect
479+
480+
conn = connect(
481+
encoding="json+zstd"
482+
)
483+
```
484+
472485
## Transactions
473486

474487
The client runs by default in *autocommit* mode. To enable transactions, set

setup.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,11 +83,13 @@
8383
],
8484
python_requires=">=3.9",
8585
install_requires=[
86+
"lz4",
8687
"python-dateutil",
8788
"pytz",
8889
# requests CVE https://github.com/advisories/GHSA-j8r2-6x86-q33q
8990
"requests>=2.31.0",
9091
"tzlocal",
92+
"zstandard",
9193
],
9294
extras_require={
9395
"all": all_require,

tests/integration/test_dbapi_integration.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,13 @@
2929
from trino.transaction import IsolationLevel
3030

3131

32-
@pytest.fixture
33-
def trino_connection(run_trino):
32+
@pytest.fixture(params=[None, "json+zstd", "json+lz4", "json"])
33+
def trino_connection(request, run_trino):
3434
host, port = run_trino
35+
encoding = request.param
3536

3637
yield trino.dbapi.Connection(
37-
host=host, port=port, user="test", source="test", max_attempts=1
38+
host=host, port=port, user="test", source="test", max_attempts=1, encoding=encoding
3839
)
3940

4041

@@ -1833,8 +1834,22 @@ def test_select_query_spooled_segments(trino_connection):
18331834
stop => 5,
18341835
step => 1)) n""")
18351836
rows = cur.fetchall()
1836-
# TODO: improve test
1837-
assert len(rows) > 0
1837+
assert len(rows) == 300875
1838+
for row in rows:
1839+
assert isinstance(row[0], int), f"Expected integer for orderkey, got {type(row[0])}"
1840+
assert isinstance(row[1], int), f"Expected integer for partkey, got {type(row[1])}"
1841+
assert isinstance(row[2], int), f"Expected integer for suppkey, got {type(row[2])}"
1842+
assert isinstance(row[3], int), f"Expected int for linenumber, got {type(row[3])}"
1843+
assert isinstance(row[4], float), f"Expected float for quantity, got {type(row[4])}"
1844+
assert isinstance(row[5], float), f"Expected float for extendedprice, got {type(row[5])}"
1845+
assert isinstance(row[6], float), f"Expected float for discount, got {type(row[6])}"
1846+
assert isinstance(row[7], float), f"Expected string for tax, got {type(row[7])}"
1847+
assert isinstance(row[8], str), f"Expected string for returnflag, got {type(row[8])}"
1848+
assert isinstance(row[9], str), f"Expected string for linestatus, got {type(row[9])}"
1849+
assert isinstance(row[10], date), f"Expected date for shipdate, got {type(row[10])}"
1850+
assert isinstance(row[11], date), f"Expected date for commitdate, got {type(row[11])}"
1851+
assert isinstance(row[12], date), f"Expected date for receiptdate, got {type(row[12])}"
1852+
assert isinstance(row[13], str), f"Expected string for shipinstruct, got {type(row[13])}"
18381853

18391854

18401855
def get_cursor(legacy_prepared_statements, run_trino):

tests/integration/test_types_integration.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,18 @@
1212
from tests.integration.conftest import trino_version
1313

1414

15-
@pytest.fixture
16-
def trino_connection(run_trino):
15+
@pytest.fixture(params=[None, "json+zstd", "json+lz4", "json"])
16+
def trino_connection(request, run_trino):
1717
host, port = run_trino
18+
encoding = request.param
1819

1920
yield trino.dbapi.Connection(
20-
host=host, port=port, user="test", source="test", max_attempts=1
21+
host=host,
22+
port=port,
23+
user="test",
24+
source="test",
25+
max_attempts=1,
26+
encoding=encoding
2127
)
2228

2329

tests/unit/test_client.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ def test_request_headers(mock_get_and_post):
9999
accept_encoding_value = "identity,deflate,gzip"
100100
client_info_header = constants.HEADER_CLIENT_INFO
101101
client_info_value = "some_client_info"
102+
encoding = "json+zstd"
102103

103104
with pytest.deprecated_call():
104105
req = TrinoRequest(
@@ -111,6 +112,7 @@ def test_request_headers(mock_get_and_post):
111112
catalog=catalog,
112113
schema=schema,
113114
timezone=timezone,
115+
encoding=encoding,
114116
headers={
115117
accept_encoding_header: accept_encoding_value,
116118
client_info_header: client_info_value,
@@ -145,7 +147,8 @@ def assert_headers(headers):
145147
"catalog2=" + urllib.parse.quote("ROLE{catalog2_role}")
146148
)
147149
assert headers["User-Agent"] == f"{constants.CLIENT_NAME}/{__version__}"
148-
assert len(headers.keys()) == 13
150+
assert headers[constants.HEADER_ENCODING] == encoding
151+
assert len(headers.keys()) == 14
149152

150153
req.post("URL")
151154
_, post_kwargs = post.call_args

0 commit comments

Comments
 (0)