Skip to content

Commit 24af67e

Browse files
committed
Support "segment" cursor style
1 parent 74ee056 commit 24af67e

File tree

3 files changed

+63
-4
lines changed

3 files changed

+63
-4
lines changed

tests/integration/test_dbapi_integration.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1853,6 +1853,29 @@ def test_select_query_spooled_segments(trino_connection):
18531853
assert isinstance(row[13], str), f"Expected string for shipinstruct, got {type(row[13])}"
18541854

18551855

1856+
@pytest.mark.skipif(
1857+
trino_version() <= 466,
1858+
reason="spooling protocol was introduced in version 466"
1859+
)
1860+
def test_segments_cursor(trino_connection):
1861+
if trino_connection._client_session.encoding is None:
1862+
with pytest.raises(ValueError, match=".*encoding.*"):
1863+
trino_connection.cursor("segment")
1864+
return
1865+
cur = trino_connection.cursor("segment")
1866+
cur.execute("""SELECT l.*
1867+
FROM tpch.tiny.lineitem l, TABLE(sequence(
1868+
start => 1,
1869+
stop => 5,
1870+
step => 1)) n""")
1871+
rows = cur.fetchall()
1872+
assert len(rows) > 0
1873+
for spooled_data, spooled_segment in rows:
1874+
assert spooled_data.encoding == trino_connection._client_session.encoding
1875+
assert isinstance(spooled_segment.uri, str), f"Expected string for uri, got {spooled_segment.uri}"
1876+
assert isinstance(spooled_segment.ack_uri, str), f"Expected string for ack_uri, got {spooled_segment.ack_uri}"
1877+
1878+
18561879
def get_cursor(legacy_prepared_statements, run_trino):
18571880
host, port = run_trino
18581881

trino/client.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -779,6 +779,7 @@ def __init__(
779779
request: TrinoRequest,
780780
query: str,
781781
legacy_primitive_types: bool = False,
782+
fetch_mode: Literal["mapped", "segments"] = "mapped"
782783
) -> None:
783784
self._query_id: Optional[str] = None
784785
self._stats: Dict[Any, Any] = {}
@@ -795,6 +796,7 @@ def __init__(
795796
self._result: Optional[TrinoResult] = None
796797
self._legacy_primitive_types = legacy_primitive_types
797798
self._row_mapper: Optional[RowMapper] = None
799+
self._fetch_mode = fetch_mode
798800

799801
@property
800802
def query_id(self) -> Optional[str]:
@@ -899,6 +901,8 @@ def fetch(self) -> List[Union[List[Any]], Any]:
899901
# spooling protocol
900902
rows = cast(_SpooledProtocolResponseTO, rows)
901903
segments = self._to_segments(rows)
904+
if self._fetch_mode == "segments":
905+
return segments
902906
return list(SegmentIterator(segments))
903907
else:
904908
return self._row_mapper.map(rows)

trino/dbapi.py

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ def _create_request(self):
261261
self.request_timeout,
262262
)
263263

264-
def cursor(self, legacy_primitive_types: bool = None):
264+
def cursor(self, cursor_style: str = "row", legacy_primitive_types: bool = None):
265265
"""Return a new :py:class:`Cursor` object using the connection."""
266266
if self.isolation_level != IsolationLevel.AUTOCOMMIT:
267267
if self.transaction is None:
@@ -270,11 +270,21 @@ def cursor(self, legacy_primitive_types: bool = None):
270270
request = self.transaction.request
271271
else:
272272
request = self._create_request()
273-
return Cursor(
273+
274+
cursor_class = {
275+
# Add any custom Cursor classes here
276+
"segment": SegmentCursor,
277+
"row": Cursor
278+
}.get(cursor_style.lower(), Cursor)
279+
280+
return cursor_class(
274281
self,
275282
request,
276-
# if legacy params are not explicitly set in Cursor, take them from Connection
277-
legacy_primitive_types if legacy_primitive_types is not None else self.legacy_primitive_types
283+
legacy_primitive_types=(
284+
legacy_primitive_types
285+
if legacy_primitive_types is not None
286+
else self.legacy_primitive_types
287+
)
278288
)
279289

280290
def _use_legacy_prepared_statements(self):
@@ -707,6 +717,28 @@ def close(self):
707717
# but also any other outstanding queries executed through this cursor.
708718

709719

720+
class SegmentCursor(Cursor):
721+
def __init__(
722+
self,
723+
connection,
724+
request,
725+
legacy_primitive_types: bool = False):
726+
super().__init__(connection, request, legacy_primitive_types=legacy_primitive_types)
727+
if self.connection._client_session.encoding is None:
728+
raise ValueError("SegmentCursor can only be used if encoding is set on the connection")
729+
730+
def execute(self, operation, params=None):
731+
if params:
732+
# TODO: refactor code to allow for params to be supported
733+
raise ValueError("params not supported")
734+
735+
self._query = trino.client.TrinoQuery(self._request, query=operation,
736+
legacy_primitive_types=self._legacy_primitive_types,
737+
fetch_mode="segments")
738+
self._iterator = iter(self._query.execute())
739+
return self
740+
741+
710742
Date = datetime.date
711743
Time = datetime.time
712744
Timestamp = datetime.datetime

0 commit comments

Comments
 (0)