Skip to content

Commit 4d3de53

Browse files
committed
Support "segment" cursor style
1 parent 6739d40 commit 4d3de53

File tree

3 files changed

+71
-6
lines changed

3 files changed

+71
-6
lines changed

tests/integration/test_dbapi_integration.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1838,6 +1838,26 @@ def test_select_query_spooled_segments(trino_connection):
18381838
assert len(rows) > 0
18391839

18401840

1841+
@pytest.mark.skipif(
1842+
trino_version() <= '464',
1843+
reason="spooled protocol was introduced in version 464"
1844+
)
1845+
def test_segments_cursor(trino_connection):
1846+
if trino_connection._client_session.encoding is None:
1847+
with pytest.raises(ValueError, match=".*encoding.*"):
1848+
trino_connection.cursor("segment")
1849+
return
1850+
cur = trino_connection.cursor("segment")
1851+
cur.execute("""SELECT l.*
1852+
FROM tpch.tiny.lineitem l, TABLE(sequence(
1853+
start => 1,
1854+
stop => 5,
1855+
step => 1)) n""")
1856+
rows = cur.fetchall()
1857+
# TODO: improve test
1858+
assert len(rows) > 0
1859+
1860+
18411861
def get_cursor(legacy_prepared_statements, run_trino):
18421862
host, port = run_trino
18431863

trino/client.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,16 @@
6363
from trino._version import __version__
6464
from trino.mapper import RowMapper, RowMapperFactory
6565

66-
__all__ = ["ClientSession", "TrinoQuery", "TrinoRequest", "PROXIES"]
66+
__all__ = [
67+
"ClientSession",
68+
"TrinoQuery",
69+
"TrinoRequest",
70+
"PROXIES",
71+
"SpooledData",
72+
"SpooledSegment",
73+
"InlineSegment",
74+
"Segment"
75+
]
6776

6877
logger = trino.logging.get_logger(__name__)
6978

@@ -753,6 +762,7 @@ def __init__(
753762
request: TrinoRequest,
754763
query: str,
755764
legacy_primitive_types: bool = False,
765+
fetch_mode: Literal["mapped", "segments"] = "mapped"
756766
) -> None:
757767
self._query_id: Optional[str] = None
758768
self._stats: Dict[Any, Any] = {}
@@ -769,6 +779,7 @@ def __init__(
769779
self._result: Optional[TrinoResult] = None
770780
self._legacy_primitive_types = legacy_primitive_types
771781
self._row_mapper: Optional[RowMapper] = None
782+
self._fetch_mode = fetch_mode
772783

773784
@property
774785
def query_id(self) -> Optional[str]:
@@ -869,10 +880,12 @@ def fetch(self) -> List[Union[List[Any]], Any]:
869880
return []
870881

871882
rows = status.rows
872-
if isinstance(rows, dict):
883+
if isinstance(status.rows, dict):
873884
# spooled protocol
874885
rows = cast(_SpooledProtocolResponseTO, rows)
875886
segments = self._to_segments(rows)
887+
if self._fetch_mode == "segments":
888+
return segments
876889
return list(SegmentIterator(segments))
877890
else:
878891
return self._row_mapper.map(rows)

trino/dbapi.py

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ def _create_request(self):
251251
self.request_timeout,
252252
)
253253

254-
def cursor(self, legacy_primitive_types: bool = None):
254+
def cursor(self, cursor_style: str = "row", legacy_primitive_types: bool = None):
255255
"""Return a new :py:class:`Cursor` object using the connection."""
256256
if self.isolation_level != IsolationLevel.AUTOCOMMIT:
257257
if self.transaction is None:
@@ -260,11 +260,21 @@ def cursor(self, legacy_primitive_types: bool = None):
260260
request = self.transaction.request
261261
else:
262262
request = self._create_request()
263-
return Cursor(
263+
264+
cursor_class = {
265+
# Add any custom Cursor classes here
266+
"segment": SegmentCursor,
267+
"row": Cursor
268+
}.get(cursor_style.lower(), Cursor)
269+
270+
return cursor_class(
264271
self,
265272
request,
266-
# if legacy params are not explicitly set in Cursor, take them from Connection
267-
legacy_primitive_types if legacy_primitive_types is not None else self.legacy_primitive_types
273+
legacy_primitive_types=(
274+
legacy_primitive_types
275+
if legacy_primitive_types is not None
276+
else self.legacy_primitive_types
277+
)
268278
)
269279

270280
def _use_legacy_prepared_statements(self):
@@ -697,6 +707,28 @@ def close(self):
697707
# but also any other outstanding queries executed through this cursor.
698708

699709

710+
class SegmentCursor(Cursor):
711+
def __init__(
712+
self,
713+
connection,
714+
request,
715+
legacy_primitive_types: bool = False):
716+
super().__init__(connection, request, legacy_primitive_types=legacy_primitive_types)
717+
if self.connection._client_session.encoding is None:
718+
raise ValueError("SegmentCursor can only be used if encoding is set on the connection")
719+
720+
def execute(self, operation, params=None):
721+
if params:
722+
# TODO: refactor code to allow for params to be supported
723+
raise ValueError("params not supported")
724+
725+
self._query = trino.client.TrinoQuery(self._request, query=operation,
726+
legacy_primitive_types=self._legacy_primitive_types,
727+
fetch_mode="segments")
728+
self._iterator = iter(self._query.execute())
729+
return self
730+
731+
700732
Date = datetime.date
701733
Time = datetime.time
702734
Timestamp = datetime.datetime

0 commit comments

Comments
 (0)