27 changes: 9 additions & 18 deletions pypaimon/py4j/java_implementation.py
@@ -177,18 +177,15 @@ def file_paths(self) -> List[str]:
 class TableRead(table_read.TableRead):
 
     def __init__(self, j_table_read, j_read_type, catalog_options):
-        self._j_table_read = j_table_read
-        self._j_read_type = j_read_type
-        self._catalog_options = catalog_options
-        self._j_bytes_reader = None
         self._arrow_schema = java_utils.to_arrow_schema(j_read_type)
+        self._j_bytes_reader = get_gateway().jvm.InvocationUtil.createParallelBytesReader(
+            j_table_read, j_read_type, TableRead._get_max_workers(catalog_options))
 
     def to_arrow(self, splits):
         record_batch_reader = self.to_arrow_batch_reader(splits)
         return pa.Table.from_batches(record_batch_reader, schema=self._arrow_schema)
 
     def to_arrow_batch_reader(self, splits):
-        self._init()
         j_splits = list(map(lambda s: s.to_j_split(), splits))
         self._j_bytes_reader.setSplits(j_splits)
         batch_iterator = self._batch_generator()
@@ -213,19 +210,13 @@ def to_ray(self, splits: List[Split]) -> "ray.data.dataset.Dataset":

         return ray.data.from_arrow(self.to_arrow(splits))
 
-    def _init(self):
-        if self._j_bytes_reader is None:
-            # get thread num
-            max_workers = self._catalog_options.get(constants.MAX_WORKERS)
-            if max_workers is None:
-                # default is sequential
-                max_workers = 1
-            else:
-                max_workers = int(max_workers)
-            if max_workers <= 0:
-                raise ValueError("max_workers must be greater than 0")
-            self._j_bytes_reader = get_gateway().jvm.InvocationUtil.createParallelBytesReader(
-                self._j_table_read, self._j_read_type, max_workers)
+    @staticmethod
+    def _get_max_workers(catalog_options):
+        # default is sequential
+        max_workers = int(catalog_options.get(constants.MAX_WORKERS, 1))
+        if max_workers <= 0:
+            raise ValueError("max_workers must be greater than 0")
+        return max_workers
 
     def _batch_generator(self) -> Iterator[pa.RecordBatch]:
         while True:
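Note on the refactor: the parallel bytes reader is now created eagerly in __init__ through the new _get_max_workers static helper, which replaces the lazy _init() path and removes the need to keep _j_table_read, _j_read_type, and _catalog_options as fields. Below is a minimal, self-contained sketch of the option parsing, under the assumption that constants.MAX_WORKERS resolves to a string key such as 'max-workers'; the literal key and the plain dicts are illustrative stand-ins, not pypaimon's actual API:

def get_max_workers(catalog_options):
    # Default is sequential (one worker). Option values arrive as strings,
    # so int() handles both the absent case and string-typed settings.
    max_workers = int(catalog_options.get('max-workers', 1))  # 'max-workers': assumed key
    if max_workers <= 0:
        raise ValueError("max_workers must be greater than 0")
    return max_workers

assert get_max_workers({}) == 1                     # no option set: sequential read
assert get_max_workers({'max-workers': '4'}) == 4   # four parallel workers

One behavioral consequence of moving this into __init__ is that an invalid max-workers value now raises when the TableRead is constructed rather than on the first to_arrow_batch_reader call.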