3434"""
3535from __future__ import annotations
3636
37+ import base64
3738import copy
3839import functools
40+ import json
3941import os
4042import random
4143import re
4244import threading
4345import urllib .parse
4446import warnings
47+ from abc import abstractmethod
4548from dataclasses import dataclass
4649from datetime import datetime
4750from email .utils import parsedate_to_datetime
4851from time import sleep
49- from typing import Any , Dict , List , Optional , Tuple , Union
52+ from typing import Any , Dict , Iterator , List , Literal , Optional , Tuple , Union , cast
5053from zoneinfo import ZoneInfo
5154
55+ import lz4 .block
5256import requests
57+ import zstandard
58+ from typing_extensions import TypedDict
5359from tzlocal import get_localzone_name # type: ignore
5460
5561import trino .logging
@@ -107,6 +113,7 @@ class ClientSession:
107113 :param roles: roles for the current session. Some connectors do not
108114 support role management. See connector documentation for more details.
109115 :param timezone: The timezone for query processing. Defaults to the system's local timezone.
116+ :param encoding: The encoding for the spooled protocol. Defaults to None.
110117 """
111118
112119 def __init__ (
@@ -123,6 +130,7 @@ def __init__(
123130 client_tags : List [str ] = None ,
124131 roles : Union [Dict [str , str ], str ] = None ,
125132 timezone : str = None ,
133+ encoding : str = None ,
126134 ):
127135 self ._user = user
128136 self ._authorization_user = authorization_user
@@ -140,6 +148,7 @@ def __init__(
140148 self ._timezone = timezone or get_localzone_name ()
141149 if timezone : # Check timezone validity
142150 ZoneInfo (timezone )
151+ self ._encoding = encoding
143152
144153 @property
145154 def user (self ):
@@ -235,6 +244,11 @@ def timezone(self):
235244 with self ._object_lock :
236245 return self ._timezone
237246
247+ @property
248+ def encoding (self ):
249+ with self ._object_lock :
250+ return self ._encoding
251+
238252 def _format_roles (self , roles ):
239253 if isinstance (roles , str ):
240254 roles = {"system" : roles }
@@ -299,7 +313,7 @@ class TrinoStatus:
299313 next_uri : Optional [str ]
300314 update_type : Optional [str ]
301315 update_count : Optional [int ]
302- rows : List [Any ]
316+ rows : Union [ List [Any ], Dict [ str , Any ] ]
303317 columns : List [Any ]
304318
305319 def __repr__ (self ):
@@ -462,6 +476,7 @@ def http_headers(self) -> Dict[str, str]:
462476 headers [constants .HEADER_USER ] = self ._client_session .user
463477 headers [constants .HEADER_AUTHORIZATION_USER ] = self ._client_session .authorization_user
464478 headers [constants .HEADER_TIMEZONE ] = self ._client_session .timezone
479+ headers [constants .HEADER_ENCODING ] = self ._client_session .encoding
465480 headers [constants .HEADER_CLIENT_CAPABILITIES ] = 'PARAMETRIC_DATETIME'
466481 headers ["user-agent" ] = f"{ constants .CLIENT_NAME } /{ __version__ } "
467482 if len (self ._client_session .roles .values ()):
@@ -522,6 +537,8 @@ def max_attempts(self, value) -> None:
522537 self ._get = self ._http_session .get
523538 self ._post = self ._http_session .post
524539 self ._delete = self ._http_session .delete
540+ self ._prepare_request = self ._http_session .prepare_request
541+ self ._send = self ._http_session .send
525542 return
526543
527544 with_retry = _retry_with (
@@ -537,6 +554,8 @@ def max_attempts(self, value) -> None:
537554 self ._get = with_retry (self ._http_session .get )
538555 self ._post = with_retry (self ._http_session .post )
539556 self ._delete = with_retry (self ._http_session .delete )
557+ self ._prepare_request = self ._http_session .prepare_request
558+ self ._send = with_retry (self ._http_session .send )
540559
541560 def get_url (self , path ) -> str :
542561 return "{protocol}://{host}:{port}{path}" .format (
@@ -835,7 +854,7 @@ def _update_state(self, status):
835854 if status .columns :
836855 self ._columns = status .columns
837856
838- def fetch (self ) -> List [List [Any ]]:
857+ def fetch (self ) -> List [Union [ List [Any ]], Any ]:
839858 """Continue fetching data for the current query_id"""
840859 try :
841860 response = self ._request .get (self ._request .next_uri )
@@ -849,7 +868,30 @@ def fetch(self) -> List[List[Any]]:
849868 if not self ._row_mapper :
850869 return []
851870
852- return self ._row_mapper .map (status .rows )
871+ rows = status .rows
872+ if isinstance (rows , dict ):
873+ # spooled protocol
874+ rows = cast (_SpooledProtocolResponseTO , rows )
875+ segments = self ._to_segments (rows )
876+ return list (SegmentIterator (segments ))
877+ else :
878+ return self ._row_mapper .map (rows )
879+
880+ def _to_segments (self , rows : _SpooledProtocolResponseTO ) -> SpooledData :
881+ encoding = rows ["encoding" ]
882+ segments = []
883+ for segment in rows ["segments" ]:
884+ segment_type = segment ["type" ]
885+ if segment_type == "inline" :
886+ inline_segment = cast (_InlineSegmentTO , segment )
887+ segments .append (InlineSegment (inline_segment , self ._row_mapper ))
888+ elif segment_type == "spooled" :
889+ spooled_segment = cast (_SpooledSegmentTO , segment )
890+ segments .append (SpooledSegment (rows , spooled_segment , self ._row_mapper , self ._request ))
891+ else :
892+ raise ValueError (f"Unsupported segment type: { segment_type } " )
893+
894+ return SpooledData (encoding , segments )
853895
854896 def cancel (self ) -> None :
855897 """Cancel the current query"""
@@ -925,3 +967,149 @@ def _parse_retry_after_header(retry_after):
925967 retry_date = parsedate_to_datetime (retry_after )
926968 now = datetime .utcnow ()
927969 return (retry_date - now ).total_seconds ()
970+
971+
972+ # Trino Spooled protocol transfer objects
973+ class _SpooledProtocolResponseTO (TypedDict ):
974+ encoding : Literal ["json" , "json+std" , "json+lz4" ]
975+ segments : List [_SegmentTO ]
976+
977+
978+ class _SegmentMetadataTO (TypedDict ):
979+ uncompressedSize : str
980+
981+
982+ class _SegmentTO (_SegmentMetadataTO ):
983+ type : Literal ["spooled" , "inline" ]
984+ metadata : _SegmentMetadataTO
985+
986+
987+ class _SpooledSegmentTO (_SegmentTO ):
988+ uri : str
989+ ackUri : str
990+ headers : Dict [str , List [str ]]
991+
992+
993+ class _InlineSegmentTO (_SegmentTO ):
994+ data : str
995+
996+
997+ class Segment :
998+ def __init__ (self , segment : _SegmentTO ) -> None :
999+ self ._segment = segment
1000+ self .metadata = segment ["metadata" ]
1001+
1002+ @property
1003+ @abstractmethod
1004+ def rows (self ):
1005+ pass
1006+
1007+
1008+ class InlineSegment (Segment ):
1009+ def __init__ (self , segment : _InlineSegmentTO , row_mapper : RowMapper ) -> None :
1010+ super ().__init__ (segment )
1011+ self ._segment = cast (_InlineSegmentTO , segment )
1012+ self ._row_mapper = row_mapper
1013+
1014+ @property
1015+ def rows (self ) -> List [List [Any ]]:
1016+ inline_segment = self ._segment
1017+ data = inline_segment ["data" ]
1018+ decoded_string = base64 .b64decode (data )
1019+ rows = self ._row_mapper .map (json .loads (decoded_string ))
1020+ return list (iter (rows ))
1021+
1022+
1023+ class SpooledSegment (Segment ):
1024+ def __init__ (
1025+ self ,
1026+ response : _SpooledProtocolResponseTO ,
1027+ segment : _SpooledSegmentTO ,
1028+ row_mapper : RowMapper ,
1029+ request : TrinoRequest
1030+ ) -> None :
1031+ super ().__init__ (segment )
1032+ self ._segment = cast (_SpooledSegmentTO , segment )
1033+ self ._row_mapper = row_mapper
1034+ self ._request = request
1035+ self ._encoding = response ["encoding" ]
1036+ self ._headers = segment .get ("headers" )
1037+
1038+ @property
1039+ def rows (self ) -> List [List [Any ]]:
1040+ return self ._row_mapper .map (json .loads (self ._load_spooled_segment ()))
1041+
1042+ def acknowledge (self ):
1043+ http_response = self ._send_spooling_request (self ._segment ["ackUri" ], self ._segment )
1044+ if not http_response .ok :
1045+ self ._request .raise_response_error (http_response )
1046+
1047+ def _load_spooled_segment (self ) -> str :
1048+ segment = self ._segment
1049+ http_response = self ._send_spooling_request (segment ["uri" ], segment )
1050+ if not http_response .ok :
1051+ self ._request .raise_response_error (http_response )
1052+ content = http_response .content
1053+
1054+ encoding = self ._encoding
1055+ if encoding == "json+zstd" :
1056+ zstd_decompressor = zstandard .ZstdDecompressor ()
1057+ return zstd_decompressor .decompress (content ).decode ('utf-8' )
1058+ elif encoding == "json+lz4" :
1059+ expected_size = segment ["metadata" ]["uncompressedSize" ]
1060+ return lz4 .block .decompress (content , uncompressed_size = int (expected_size )).decode ('utf-8' )
1061+ elif encoding == "json" :
1062+ return content .decode ('utf-8' )
1063+ else :
1064+ raise ValueError (f"Unsupported encoding: { encoding } " )
1065+
1066+ def _send_spooling_request (self , uri : str , segment : _SpooledSegmentTO ) -> requests .Response :
1067+ req = requests .Request (
1068+ "GET" ,
1069+ uri ,
1070+ headers = list (map (lambda item : [(item [0 ], value ) for value in item [1 ]], segment .get ("headers" , {}).items ()))
1071+ )
1072+ prepared = self ._request ._prepare_request (req )
1073+ return self ._request ._send (prepared )
1074+
1075+
1076+ class SpooledData :
1077+ def __init__ (self , encoding : str , segments : List [Segment ]) -> None :
1078+ self ._encoding = encoding
1079+ self ._segments = iter (segments )
1080+
1081+ def __iter__ (self ) -> Iterator [Tuple ["SpooledData" , "Segment" ]]:
1082+ return self
1083+
1084+ def __next__ (self ) -> Tuple ["SpooledData" , "Segment" ]:
1085+ return self , next (self ._segments )
1086+
1087+
1088+ class SegmentIterator :
1089+ def __init__ (self , spooled_data : SpooledData ):
1090+ self ._segments = iter (spooled_data ._segments )
1091+ self ._rows : Iterator [List [List [Any ]]] = iter ([])
1092+ self ._finished = False
1093+ self ._current_segment : Optional [Segment ] = None
1094+
1095+ def __iter__ (self ) -> Iterator [List [Any ]]:
1096+ return self
1097+
1098+ def __next__ (self ) -> List [Any ]:
1099+ # If rows are exhausted, fetch the next segment
1100+ while True :
1101+ try :
1102+ return next (self ._rows )
1103+ except StopIteration :
1104+ if self ._current_segment and isinstance (self ._current_segment , SpooledSegment ):
1105+ self ._current_segment .acknowledge ()
1106+ if self ._finished :
1107+ raise StopIteration
1108+ self ._load_next_row_set ()
1109+
1110+ def _load_next_row_set (self ):
1111+ try :
1112+ self ._current_segment = segment = next (self ._segments )
1113+ self ._rows = iter (segment .rows )
1114+ except StopIteration :
1115+ self ._finished = True
0 commit comments