@@ -61,37 +61,36 @@ class Table(table.Table):
     def __init__(self, j_table, catalog_options: dict):
         self._j_table = j_table
         self._catalog_options = catalog_options
-        # init arrow schema
-        schema_bytes = get_gateway().jvm.SchemaUtil.getArrowSchema(j_table.rowType())
-        schema_reader = pa.RecordBatchStreamReader(pa.BufferReader(schema_bytes))
-        self._arrow_schema = schema_reader.schema
-        schema_reader.close()
 
     def new_read_builder(self) -> 'ReadBuilder':
         j_read_builder = get_gateway().jvm.InvocationUtil.getReadBuilder(self._j_table)
-        return ReadBuilder(
-            j_read_builder, self._j_table.rowType(), self._catalog_options, self._arrow_schema)
+        return ReadBuilder(j_read_builder, self._j_table.rowType(), self._catalog_options)
 
     def new_batch_write_builder(self) -> 'BatchWriteBuilder':
         java_utils.check_batch_write(self._j_table)
         j_batch_write_builder = get_gateway().jvm.InvocationUtil.getBatchWriteBuilder(self._j_table)
-        return BatchWriteBuilder(j_batch_write_builder, self._j_table.rowType(), self._arrow_schema)
+        return BatchWriteBuilder(j_batch_write_builder)
 
 
 class ReadBuilder(read_builder.ReadBuilder):
 
-    def __init__(self, j_read_builder, j_row_type, catalog_options: dict, arrow_schema: pa.Schema):
+    def __init__(self, j_read_builder, j_row_type, catalog_options: dict):
         self._j_read_builder = j_read_builder
         self._j_row_type = j_row_type
         self._catalog_options = catalog_options
-        self._arrow_schema = arrow_schema
 
     def with_filter(self, predicate: 'Predicate'):
         self._j_read_builder.withFilter(predicate.to_j_predicate())
         return self
 
-    def with_projection(self, projection: List[List[int]]) -> 'ReadBuilder':
-        self._j_read_builder.withProjection(projection)
+    def with_projection(self, projection: List[str]) -> 'ReadBuilder':
+        field_names = list(map(lambda field: field.name(), self._j_row_type.getFields()))
+        int_projection = list(map(lambda p: field_names.index(p), projection))
+        gateway = get_gateway()
+        int_projection_arr = gateway.new_array(gateway.jvm.int, len(projection))
+        for i in range(len(projection)):
+            int_projection_arr[i] = int_projection[i]
+        self._j_read_builder.withProjection(int_projection_arr)
         return self
 
     def with_limit(self, limit: int) -> 'ReadBuilder':
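Note on the projection change above: with_projection now accepts column names (List[str]) and resolves them to field indices against the table row type before calling the Java withProjection. A hedged usage sketch, assuming a table object already obtained from the Python Catalog API with columns named "order_id" and "amount" (illustrative names, not part of this change):

    # Project by column name; index resolution happens inside with_projection.
    read_builder = table.new_read_builder().with_projection(["order_id", "amount"])
    splits = read_builder.new_scan().plan().splits()   # scan/plan flow as exposed elsewhere in this module
    table_read = read_builder.new_read()
    result = table_read.to_arrow(splits)               # pa.Table containing only the projected columns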
@@ -104,7 +103,7 @@ def new_scan(self) -> 'TableScan':
 
     def new_read(self) -> 'TableRead':
         j_table_read = self._j_read_builder.newRead().executeFilter()
-        return TableRead(j_table_read, self._j_row_type, self._catalog_options, self._arrow_schema)
+        return TableRead(j_table_read, self._j_read_builder.readType(), self._catalog_options)
 
     def new_predicate_builder(self) -> 'PredicateBuilder':
         return PredicateBuilder(self._j_row_type)
@@ -141,12 +140,12 @@ def to_j_split(self):
 
 class TableRead(table_read.TableRead):
 
-    def __init__(self, j_table_read, j_row_type, catalog_options, arrow_schema):
+    def __init__(self, j_table_read, j_read_type, catalog_options):
         self._j_table_read = j_table_read
-        self._j_row_type = j_row_type
+        self._j_read_type = j_read_type
         self._catalog_options = catalog_options
         self._j_bytes_reader = None
-        self._arrow_schema = arrow_schema
+        self._arrow_schema = java_utils.to_arrow_schema(j_read_type)
 
     def to_arrow(self, splits):
         record_batch_reader = self.to_arrow_batch_reader(splits)
@@ -174,7 +173,7 @@ def _init(self):
         if max_workers <= 0:
             raise ValueError("max_workers must be greater than 0")
         self._j_bytes_reader = get_gateway().jvm.InvocationUtil.createParallelBytesReader(
-            self._j_table_read, self._j_row_type, max_workers)
+            self._j_table_read, self._j_read_type, max_workers)
 
     def _batch_generator(self) -> Iterator[pa.RecordBatch]:
         while True:
@@ -188,10 +187,8 @@ def _batch_generator(self) -> Iterator[pa.RecordBatch]:
 
 class BatchWriteBuilder(write_builder.BatchWriteBuilder):
 
-    def __init__(self, j_batch_write_builder, j_row_type, arrow_schema: pa.Schema):
+    def __init__(self, j_batch_write_builder):
         self._j_batch_write_builder = j_batch_write_builder
-        self._j_row_type = j_row_type
-        self._arrow_schema = arrow_schema
 
     def overwrite(self, static_partition: Optional[dict] = None) -> 'BatchWriteBuilder':
         if static_partition is None:
@@ -201,7 +198,7 @@ def overwrite(self, static_partition: Optional[dict] = None) -> 'BatchWriteBuilder':
 
     def new_write(self) -> 'BatchTableWrite':
         j_batch_table_write = self._j_batch_write_builder.newWrite()
-        return BatchTableWrite(j_batch_table_write, self._j_row_type, self._arrow_schema)
+        return BatchTableWrite(j_batch_table_write, self._j_batch_write_builder.rowType())
 
     def new_commit(self) -> 'BatchTableCommit':
         j_batch_table_commit = self._j_batch_write_builder.newCommit()
@@ -210,11 +207,11 @@ def new_commit(self) -> 'BatchTableCommit':
 
 class BatchTableWrite(table_write.BatchTableWrite):
 
-    def __init__(self, j_batch_table_write, j_row_type, arrow_schema: pa.Schema):
+    def __init__(self, j_batch_table_write, j_row_type):
         self._j_batch_table_write = j_batch_table_write
         self._j_bytes_writer = get_gateway().jvm.InvocationUtil.createBytesWriter(
             j_batch_table_write, j_row_type)
-        self._arrow_schema = arrow_schema
+        self._arrow_schema = java_utils.to_arrow_schema(j_row_type)
 
     def write_arrow(self, table):
         for record_batch in table.to_reader():
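Write-path counterpart, also hedged: BatchWriteBuilder no longer carries an Arrow schema, and BatchTableWrite derives it from the Java row type via java_utils.to_arrow_schema. The sketch below assumes a table from the Catalog API and an existing pyarrow Table arrow_table whose schema matches the Paimon row type; the prepare_commit/commit flow is assumed to mirror the Java batch-write API.

    write_builder = table.new_batch_write_builder()
    table_write = write_builder.new_write()              # Arrow schema is derived from rowType() internally
    table_commit = write_builder.new_commit()
    table_write.write_arrow(arrow_table)                 # writes the Arrow record batches
    table_commit.commit(table_write.prepare_commit())    # assumed commit flow, mirroring the Java API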