Skip to content

Commit bfa300a

Browse files
committed
Support spooled protocol
1 parent 1e6555d commit bfa300a

File tree

7 files changed

+215
-11
lines changed

7 files changed

+215
-11
lines changed

setup.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,11 +83,14 @@
8383
],
8484
python_requires=">=3.9",
8585
install_requires=[
86+
"lz4",
8687
"python-dateutil",
8788
"pytz",
8889
# requests CVE https://github.com/advisories/GHSA-j8r2-6x86-q33q
8990
"requests>=2.31.0",
91+
"typing_extensions",
9092
"tzlocal",
93+
"zstandard",
9194
],
9295
extras_require={
9396
"all": all_require,

tests/integration/test_dbapi_integration.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,13 @@
2929
from trino.transaction import IsolationLevel
3030

3131

32-
@pytest.fixture
33-
def trino_connection(run_trino):
32+
@pytest.fixture(params=[None, "json+zstd", "json+lz4", "json"])
33+
def trino_connection(request, run_trino):
3434
host, port = run_trino
35+
encoding = request.param
3536

3637
yield trino.dbapi.Connection(
37-
host=host, port=port, user="test", source="test", max_attempts=1
38+
host=host, port=port, user="test", source="test", max_attempts=1, encoding=encoding
3839
)
3940

4041

tests/integration/test_types_integration.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,18 @@
1212
from tests.integration.conftest import trino_version
1313

1414

15-
@pytest.fixture
16-
def trino_connection(run_trino):
15+
@pytest.fixture(params=[None, "json+zstd", "json+lz4", "json"])
16+
def trino_connection(request, run_trino):
1717
host, port = run_trino
18+
encoding = request.param
1819

1920
yield trino.dbapi.Connection(
20-
host=host, port=port, user="test", source="test", max_attempts=1
21+
host=host,
22+
port=port,
23+
user="test",
24+
source="test",
25+
max_attempts=1,
26+
encoding=encoding
2127
)
2228

2329

tests/unit/test_client.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ def test_request_headers(mock_get_and_post):
9999
accept_encoding_value = "identity,deflate,gzip"
100100
client_info_header = constants.HEADER_CLIENT_INFO
101101
client_info_value = "some_client_info"
102+
encoding = "json+zstd"
102103

103104
with pytest.deprecated_call():
104105
req = TrinoRequest(
@@ -111,6 +112,7 @@ def test_request_headers(mock_get_and_post):
111112
catalog=catalog,
112113
schema=schema,
113114
timezone=timezone,
115+
encoding=encoding,
114116
headers={
115117
accept_encoding_header: accept_encoding_value,
116118
client_info_header: client_info_value,
@@ -145,7 +147,8 @@ def assert_headers(headers):
145147
"catalog2=" + urllib.parse.quote("ROLE{catalog2_role}")
146148
)
147149
assert headers["User-Agent"] == f"{constants.CLIENT_NAME}/{__version__}"
148-
assert len(headers.keys()) == 13
150+
assert headers[constants.HEADER_ENCODING] == encoding
151+
assert len(headers.keys()) == 14
149152

150153
req.post("URL")
151154
_, post_kwargs = post.call_args

trino/client.py

Lines changed: 192 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,22 +34,28 @@
3434
"""
3535
from __future__ import annotations
3636

37+
import base64
3738
import copy
3839
import functools
40+
import json
3941
import os
4042
import random
4143
import re
4244
import threading
4345
import urllib.parse
4446
import warnings
47+
from abc import abstractmethod
4548
from dataclasses import dataclass
4649
from datetime import datetime
4750
from email.utils import parsedate_to_datetime
4851
from time import sleep
49-
from typing import Any, Dict, List, Optional, Tuple, Union
52+
from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Union, cast
5053
from zoneinfo import ZoneInfo
5154

55+
import lz4.block
5256
import requests
57+
import zstandard
58+
from typing_extensions import TypedDict
5359
from tzlocal import get_localzone_name # type: ignore
5460

5561
import trino.logging
@@ -107,6 +113,7 @@ class ClientSession:
107113
:param roles: roles for the current session. Some connectors do not
108114
support role management. See connector documentation for more details.
109115
:param timezone: The timezone for query processing. Defaults to the system's local timezone.
116+
:param encoding: The encoding for the spooled protocol. Defaults to None.
110117
"""
111118

112119
def __init__(
@@ -123,6 +130,7 @@ def __init__(
123130
client_tags: List[str] = None,
124131
roles: Union[Dict[str, str], str] = None,
125132
timezone: str = None,
133+
encoding: str = None,
126134
):
127135
self._user = user
128136
self._authorization_user = authorization_user
@@ -140,6 +148,7 @@ def __init__(
140148
self._timezone = timezone or get_localzone_name()
141149
if timezone: # Check timezone validity
142150
ZoneInfo(timezone)
151+
self._encoding = encoding
143152

144153
@property
145154
def user(self):
@@ -235,6 +244,11 @@ def timezone(self):
235244
with self._object_lock:
236245
return self._timezone
237246

247+
@property
def encoding(self):
    """Return the spooled-protocol encoding stored on this session.

    The value is whatever was passed as ``encoding`` at construction time
    and is read while holding the session's object lock.
    """
    with self._object_lock:
        current = self._encoding
    return current
251+
238252
def _format_roles(self, roles):
239253
if isinstance(roles, str):
240254
roles = {"system": roles}
@@ -299,7 +313,7 @@ class TrinoStatus:
299313
next_uri: Optional[str]
300314
update_type: Optional[str]
301315
update_count: Optional[int]
302-
rows: List[Any]
316+
rows: Union[List[Any], Dict[str, Any]]
303317
columns: List[Any]
304318

305319
def __repr__(self):
@@ -462,6 +476,7 @@ def http_headers(self) -> Dict[str, str]:
462476
headers[constants.HEADER_USER] = self._client_session.user
463477
headers[constants.HEADER_AUTHORIZATION_USER] = self._client_session.authorization_user
464478
headers[constants.HEADER_TIMEZONE] = self._client_session.timezone
479+
headers[constants.HEADER_ENCODING] = self._client_session.encoding
465480
headers[constants.HEADER_CLIENT_CAPABILITIES] = 'PARAMETRIC_DATETIME'
466481
headers["user-agent"] = f"{constants.CLIENT_NAME}/{__version__}"
467482
if len(self._client_session.roles.values()):
@@ -522,6 +537,8 @@ def max_attempts(self, value) -> None:
522537
self._get = self._http_session.get
523538
self._post = self._http_session.post
524539
self._delete = self._http_session.delete
540+
self._prepare_request = self._http_session.prepare_request
541+
self._send = self._http_session.send
525542
return
526543

527544
with_retry = _retry_with(
@@ -537,6 +554,8 @@ def max_attempts(self, value) -> None:
537554
self._get = with_retry(self._http_session.get)
538555
self._post = with_retry(self._http_session.post)
539556
self._delete = with_retry(self._http_session.delete)
557+
self._prepare_request = self._http_session.prepare_request
558+
self._send = with_retry(self._http_session.send)
540559

541560
def get_url(self, path) -> str:
542561
return "{protocol}://{host}:{port}{path}".format(
@@ -835,7 +854,7 @@ def _update_state(self, status):
835854
if status.columns:
836855
self._columns = status.columns
837856

838-
def fetch(self) -> List[List[Any]]:
857+
def fetch(self) -> List[Union[List[Any]], Any]:
839858
"""Continue fetching data for the current query_id"""
840859
try:
841860
response = self._request.get(self._request.next_uri)
@@ -849,7 +868,30 @@ def fetch(self) -> List[List[Any]]:
849868
if not self._row_mapper:
850869
return []
851870

852-
return self._row_mapper.map(status.rows)
871+
rows = status.rows
872+
if isinstance(rows, dict):
873+
# spooled protocol
874+
rows = cast(_SpooledProtocolResponseTO, rows)
875+
segments = self._to_segments(rows)
876+
return list(SegmentIterator(segments))
877+
else:
878+
return self._row_mapper.map(rows)
879+
880+
def _to_segments(self, rows: _SpooledProtocolResponseTO) -> SpooledData:
881+
encoding = rows["encoding"]
882+
segments = []
883+
for segment in rows["segments"]:
884+
segment_type = segment["type"]
885+
if segment_type == "inline":
886+
inline_segment = cast(_InlineSegmentTO, segment)
887+
segments.append(InlineSegment(inline_segment, self._row_mapper))
888+
elif segment_type == "spooled":
889+
spooled_segment = cast(_SpooledSegmentTO, segment)
890+
segments.append(SpooledSegment(rows, spooled_segment, self._row_mapper, self._request))
891+
else:
892+
raise ValueError(f"Unsupported segment type: {segment_type}")
893+
894+
return SpooledData(encoding, segments)
853895

854896
def cancel(self) -> None:
855897
"""Cancel the current query"""
@@ -925,3 +967,149 @@ def _parse_retry_after_header(retry_after):
925967
retry_date = parsedate_to_datetime(retry_after)
926968
now = datetime.utcnow()
927969
return (retry_date - now).total_seconds()
970+
971+
972+
# Trino Spooled protocol transfer objects
973+
class _SpooledProtocolResponseTO(TypedDict):
974+
encoding: Literal["json", "json+std", "json+lz4"]
975+
segments: List[_SegmentTO]
976+
977+
978+
class _SegmentMetadataTO(TypedDict):
    """Metadata dictionary stored under a segment's ``metadata`` key."""

    # Declared as a string; SpooledSegment converts it with int() and passes
    # it to lz4.block.decompress as the expected uncompressed size.
    uncompressedSize: str
980+
981+
982+
# NOTE(review): inheriting from _SegmentMetadataTO also declares
# ``uncompressedSize`` directly on the segment, while the code in this file
# only reads it from ``segment["metadata"]`` — confirm the inheritance is
# intended.
class _SegmentTO(_SegmentMetadataTO):
    """Fields shared by every segment of a spooled protocol response."""

    # Discriminator used to decide whether the segment is wrapped in an
    # InlineSegment or a SpooledSegment.
    type: Literal["spooled", "inline"]
    # Per-segment metadata (see _SegmentMetadataTO).
    metadata: _SegmentMetadataTO
985+
986+
987+
class _SpooledSegmentTO(_SegmentTO):
    """Segment whose payload must be downloaded from ``uri``."""

    # Location the segment payload is fetched from (GET).
    uri: str
    # Location requested by SpooledSegment.acknowledge().
    ackUri: str
    # Header name -> list of values to send with the spooling requests.
    headers: Dict[str, List[str]]
991+
992+
993+
class _InlineSegmentTO(_SegmentTO):
    """Segment whose payload is embedded directly in the response."""

    # Base64 text; InlineSegment.rows base64-decodes it and parses the
    # result as JSON.
    data: str
995+
996+
997+
class Segment:
    """Base class for one segment of a spooled protocol response.

    Stores the raw segment dictionary and exposes its ``metadata`` entry.
    Subclasses must override :attr:`rows`.
    """

    def __init__(self, segment: _SegmentTO) -> None:
        """
        :param segment: raw segment dictionary; must contain a ``metadata`` key.
        """
        self._segment = segment
        self.metadata = segment["metadata"]

    @property
    @abstractmethod
    def rows(self):
        """Return the rows carried by this segment.

        :raises NotImplementedError: always, on the base class.
        """
        # ``abstractmethod`` is only enforced on classes whose metaclass is
        # ABCMeta, which this class does not use. Fail loudly instead of
        # silently returning None when a subclass forgets to override this.
        raise NotImplementedError(f"{type(self).__name__} must implement 'rows'")
1006+
1007+
1008+
class InlineSegment(Segment):
    """Segment whose rows are embedded in the response as base64-encoded JSON."""

    def __init__(self, segment: _InlineSegmentTO, row_mapper: RowMapper) -> None:
        """
        :param segment: raw inline segment dictionary (read: ``metadata``, ``data``).
        :param row_mapper: object whose ``map`` converts the decoded JSON rows.
        """
        super().__init__(segment)
        self._row_mapper = row_mapper
        self._segment = cast(_InlineSegmentTO, segment)

    @property
    def rows(self) -> List[List[Any]]:
        """Decode the ``data`` field and return the mapped rows as a list."""
        raw_json = base64.b64decode(self._segment["data"])
        mapped = self._row_mapper.map(json.loads(raw_json))
        return [row for row in mapped]
1021+
1022+
1023+
class SpooledSegment(Segment):
    """Segment whose payload is stored remotely and downloaded on demand.

    The payload is fetched with a GET to the segment's ``uri``, decoded
    according to the response-level ``encoding`` and parsed as JSON.
    ``acknowledge`` issues a GET to the segment's ``ackUri``.
    """

    def __init__(
        self,
        response: _SpooledProtocolResponseTO,
        segment: _SpooledSegmentTO,
        row_mapper: RowMapper,
        request: TrinoRequest
    ) -> None:
        """
        :param response: spooled protocol response; only ``encoding`` is read.
        :param segment: raw spooled segment dictionary.
        :param row_mapper: object whose ``map`` converts the decoded JSON rows.
        :param request: request object used to prepare and send HTTP requests
            and to raise errors for unsuccessful responses.
        """
        super().__init__(segment)
        self._segment = cast(_SpooledSegmentTO, segment)
        self._row_mapper = row_mapper
        self._request = request
        self._encoding = response["encoding"]
        self._headers = segment.get("headers")

    @property
    def rows(self) -> List[List[Any]]:
        """Download, decode and map the rows of this segment.

        Every access performs a new download of the segment.
        """
        return self._row_mapper.map(json.loads(self._load_spooled_segment()))

    def acknowledge(self):
        """Send a GET to the segment's ``ackUri``.

        Delegates to ``raise_response_error`` of the request object when the
        response is not OK.
        """
        http_response = self._send_spooling_request(self._segment["ackUri"], self._segment)
        if not http_response.ok:
            self._request.raise_response_error(http_response)

    def _load_spooled_segment(self) -> str:
        """Fetch the segment payload and return it as decoded UTF-8 text.

        :raises ValueError: if the response encoding is not one of
            ``json``, ``json+zstd`` or ``json+lz4``.
        """
        segment = self._segment
        http_response = self._send_spooling_request(segment["uri"], segment)
        if not http_response.ok:
            self._request.raise_response_error(http_response)
        content = http_response.content

        encoding = self._encoding
        if encoding == "json+zstd":
            zstd_decompressor = zstandard.ZstdDecompressor()
            return zstd_decompressor.decompress(content).decode('utf-8')
        elif encoding == "json+lz4":
            # lz4 block decompression is given the expected output size,
            # which the segment metadata provides as a string.
            expected_size = segment["metadata"]["uncompressedSize"]
            return lz4.block.decompress(content, uncompressed_size=int(expected_size)).decode('utf-8')
        elif encoding == "json":
            return content.decode('utf-8')
        else:
            raise ValueError(f"Unsupported encoding: {encoding}")

    def _send_spooling_request(self, uri: str, segment: _SpooledSegmentTO) -> requests.Response:
        """Send a GET to ``uri`` carrying the headers listed in ``segment``.

        :param uri: target of the request.
        :param segment: segment whose ``headers`` entry (header name -> list
            of values) is sent along; a missing or empty entry sends no
            segment-specific headers.
        :return: the HTTP response.
        """
        # requests expects headers as a mapping of name -> single value, so
        # the values of a multi-valued header are joined with ", " into one
        # field value. (A list of lists of tuples, as built previously, is
        # not a mapping and fails during header preparation.)
        headers = {
            name: ", ".join(values)
            for name, values in (segment.get("headers") or {}).items()
        }
        req = requests.Request("GET", uri, headers=headers)
        prepared = self._request._prepare_request(req)
        return self._request._send(prepared)
1074+
1075+
1076+
class SpooledData:
    """One-shot iterator over the segments of a spooled protocol response.

    Each step yields a ``(spooled_data, segment)`` pair, where the first
    element is this object itself.
    """

    def __init__(self, encoding: str, segments: List[Segment]) -> None:
        """
        :param encoding: encoding name of the response.
        :param segments: segments to iterate over, in order.
        """
        self._segments = iter(segments)
        self._encoding = encoding

    def __iter__(self) -> Iterator[Tuple["SpooledData", "Segment"]]:
        return self

    def __next__(self) -> Tuple["SpooledData", "Segment"]:
        # StopIteration from the underlying iterator ends this iteration too.
        upcoming = next(self._segments)
        return (self, upcoming)
1087+
1088+
class SegmentIterator:
    """Iterate row by row over all segments of a ``SpooledData``.

    A segment's ``rows`` are only read once every row of the previous
    segment has been returned. A ``SpooledSegment`` is acknowledged exactly
    once, after its last row has been handed out.
    """

    def __init__(self, spooled_data: SpooledData):
        """
        :param spooled_data: source of segments; its ``_segments`` iterable
            is consumed directly.
        """
        self._segments = iter(spooled_data._segments)
        self._rows: Iterator[List[List[Any]]] = iter([])
        self._finished = False
        self._current_segment: Optional[Segment] = None

    def __iter__(self) -> Iterator[List[Any]]:
        return self

    def __next__(self) -> List[Any]:
        """Return the next row, moving on to the next segment when needed.

        :raises StopIteration: when all segments are exhausted.
        """
        while True:
            try:
                return next(self._rows)
            except StopIteration:
                # The rows of the current segment (if any) are exhausted.
                self._acknowledge_current_segment()
                if self._finished:
                    raise StopIteration
                self._load_next_row_set()

    def _acknowledge_current_segment(self) -> None:
        """Acknowledge the exhausted segment, at most once per segment."""
        segment = self._current_segment
        if segment is not None and isinstance(segment, SpooledSegment):
            segment.acknowledge()
        # Forget the segment so that it is not acknowledged again when the
        # end of the data is reached or when next() is called after
        # exhaustion.
        self._current_segment = None

    def _load_next_row_set(self):
        """Start iterating the next segment, or mark the iterator finished."""
        try:
            segment = next(self._segments)
        except StopIteration:
            self._finished = True
            return
        self._current_segment = segment
        self._rows = iter(segment.rows)

trino/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
HEADER_CLIENT_TAGS = "X-Trino-Client-Tags"
3838
HEADER_EXTRA_CREDENTIAL = "X-Trino-Extra-Credential"
3939
HEADER_TIMEZONE = "X-Trino-Time-Zone"
40+
HEADER_ENCODING = "X-Trino-Query-Data-Encoding"
4041

4142
HEADER_SESSION = "X-Trino-Session"
4243
HEADER_SET_SESSION = "X-Trino-Set-Session"

trino/dbapi.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@ def __init__(
153153
legacy_prepared_statements=None,
154154
roles=None,
155155
timezone=None,
156+
encoding=None,
156157
):
157158
# Automatically assign http_schema, port based on hostname
158159
parsed_host = urlparse(host, allow_fragments=False)
@@ -176,6 +177,7 @@ def __init__(
176177
client_tags=client_tags,
177178
roles=roles,
178179
timezone=timezone,
180+
encoding=encoding,
179181
)
180182
# mypy cannot follow module import
181183
if http_session is None:

0 commit comments

Comments
 (0)