Skip to content

Commit 74ee056

Browse files
committed
Support spooled protocol
1 parent 59dc4db commit 74ee056

File tree

8 files changed

+374
-17
lines changed

8 files changed

+374
-17
lines changed

README.md

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -469,6 +469,31 @@ conn = connect(
469469
)
470470
```
471471

472+
## Spooled protocol
473+
474+
The client spooling protocol requires [a Trino server that supports the spooling protocol and
475+
has a valid spooling protocol configuration](https://trino.io/docs/current/client/client-protocol.html#spooling-protocol).
476+
477+
Enable the spooling protocol by specifying a supported encoding in the `encoding` parameter:
478+
479+
```python
480+
from trino.dbapi import connect
481+
482+
conn = connect(
483+
encoding="json+zstd"
484+
)
485+
```
486+
487+
or a list of supported encodings:
488+
489+
```python
490+
from trino.dbapi import connect
491+
492+
conn = connect(
493+
encoding=["json+zstd", "json"]
494+
)
495+
```
496+
472497
## Transactions
473498

474499
The client runs by default in *autocommit* mode. To enable transactions, set

setup.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,11 +83,13 @@
8383
],
8484
python_requires=">=3.9",
8585
install_requires=[
86+
"lz4",
8687
"python-dateutil",
8788
"pytz",
8889
# requests CVE https://github.com/advisories/GHSA-j8r2-6x86-q33q
8990
"requests>=2.31.0",
9091
"tzlocal",
92+
"zstandard",
9193
],
9294
extras_require={
9395
"all": all_require,

tests/integration/test_dbapi_integration.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,13 @@
3030
from trino.transaction import IsolationLevel
3131

3232

33-
@pytest.fixture
34-
def trino_connection(run_trino):
33+
@pytest.fixture(params=[None, "json+zstd", "json+lz4", "json"])
34+
def trino_connection(request, run_trino):
3535
host, port = run_trino
36+
encoding = request.param
3637

3738
yield trino.dbapi.Connection(
38-
host=host, port=port, user="test", source="test", max_attempts=1
39+
host=host, port=port, user="test", source="test", max_attempts=1, encoding=encoding
3940
)
4041

4142

@@ -1823,8 +1824,8 @@ def test_prepared_statement_capability_autodetection(legacy_prepared_statements,
18231824

18241825

18251826
@pytest.mark.skipif(
1826-
trino_version() <= '464',
1827-
reason="spooled protocol was introduced in version 464"
1827+
trino_version() <= 466,
1828+
reason="spooling protocol was introduced in version 466"
18281829
)
18291830
def test_select_query_spooled_segments(trino_connection):
18301831
cur = trino_connection.cursor()
@@ -1834,8 +1835,22 @@ def test_select_query_spooled_segments(trino_connection):
18341835
stop => 5,
18351836
step => 1)) n""")
18361837
rows = cur.fetchall()
1837-
# TODO: improve test
1838-
assert len(rows) > 0
1838+
assert len(rows) == 300875
1839+
for row in rows:
1840+
assert isinstance(row[0], int), f"Expected integer for orderkey, got {type(row[0])}"
1841+
assert isinstance(row[1], int), f"Expected integer for partkey, got {type(row[1])}"
1842+
assert isinstance(row[2], int), f"Expected integer for suppkey, got {type(row[2])}"
1843+
assert isinstance(row[3], int), f"Expected int for linenumber, got {type(row[3])}"
1844+
assert isinstance(row[4], float), f"Expected float for quantity, got {type(row[4])}"
1845+
assert isinstance(row[5], float), f"Expected float for extendedprice, got {type(row[5])}"
1846+
assert isinstance(row[6], float), f"Expected float for discount, got {type(row[6])}"
1847+
assert isinstance(row[7], float), f"Expected string for tax, got {type(row[7])}"
1848+
assert isinstance(row[8], str), f"Expected string for returnflag, got {type(row[8])}"
1849+
assert isinstance(row[9], str), f"Expected string for linestatus, got {type(row[9])}"
1850+
assert isinstance(row[10], date), f"Expected date for shipdate, got {type(row[10])}"
1851+
assert isinstance(row[11], date), f"Expected date for commitdate, got {type(row[11])}"
1852+
assert isinstance(row[12], date), f"Expected date for receiptdate, got {type(row[12])}"
1853+
assert isinstance(row[13], str), f"Expected string for shipinstruct, got {type(row[13])}"
18391854

18401855

18411856
def get_cursor(legacy_prepared_statements, run_trino):

tests/integration/test_types_integration.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,18 @@
1212
from tests.integration.conftest import trino_version
1313

1414

15-
@pytest.fixture
16-
def trino_connection(run_trino):
15+
@pytest.fixture(params=[None, "json+zstd", "json+lz4", "json"])
16+
def trino_connection(request, run_trino):
1717
host, port = run_trino
18+
encoding = request.param
1819

1920
yield trino.dbapi.Connection(
20-
host=host, port=port, user="test", source="test", max_attempts=1
21+
host=host,
22+
port=port,
23+
user="test",
24+
source="test",
25+
max_attempts=1,
26+
encoding=encoding
2127
)
2228

2329

tests/unit/test_client.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ def test_request_headers(mock_get_and_post):
9999
accept_encoding_value = "identity,deflate,gzip"
100100
client_info_header = constants.HEADER_CLIENT_INFO
101101
client_info_value = "some_client_info"
102+
encoding = "json+zstd"
102103

103104
with pytest.deprecated_call():
104105
req = TrinoRequest(
@@ -111,6 +112,7 @@ def test_request_headers(mock_get_and_post):
111112
catalog=catalog,
112113
schema=schema,
113114
timezone=timezone,
115+
encoding=encoding,
114116
headers={
115117
accept_encoding_header: accept_encoding_value,
116118
client_info_header: client_info_value,
@@ -145,7 +147,8 @@ def assert_headers(headers):
145147
"catalog2=" + urllib.parse.quote("ROLE{catalog2_role}")
146148
)
147149
assert headers["User-Agent"] == f"{constants.CLIENT_NAME}/{__version__}"
148-
assert len(headers.keys()) == 13
150+
assert headers[constants.HEADER_ENCODING] == encoding
151+
assert len(headers.keys()) == 14
149152

150153
req.post("URL")
151154
_, post_kwargs = post.call_args

0 commit comments

Comments
 (0)