diff --git a/pypaimon/py4j/java_implementation.py b/pypaimon/py4j/java_implementation.py
index 9f378b7..07721f0 100644
--- a/pypaimon/py4j/java_implementation.py
+++ b/pypaimon/py4j/java_implementation.py
@@ -177,18 +177,15 @@ def file_paths(self) -> List[str]:
 
 class TableRead(table_read.TableRead):
     def __init__(self, j_table_read, j_read_type, catalog_options):
-        self._j_table_read = j_table_read
-        self._j_read_type = j_read_type
-        self._catalog_options = catalog_options
-        self._j_bytes_reader = None
         self._arrow_schema = java_utils.to_arrow_schema(j_read_type)
+        self._j_bytes_reader = get_gateway().jvm.InvocationUtil.createParallelBytesReader(
+            j_table_read, j_read_type, TableRead._get_max_workers(catalog_options))
 
     def to_arrow(self, splits):
         record_batch_reader = self.to_arrow_batch_reader(splits)
         return pa.Table.from_batches(record_batch_reader, schema=self._arrow_schema)
 
     def to_arrow_batch_reader(self, splits):
-        self._init()
         j_splits = list(map(lambda s: s.to_j_split(), splits))
         self._j_bytes_reader.setSplits(j_splits)
         batch_iterator = self._batch_generator()
@@ -213,19 +210,13 @@ def to_ray(self, splits: List[Split]) -> "ray.data.dataset.Dataset":
 
         return ray.data.from_arrow(self.to_arrow(splits))
 
-    def _init(self):
-        if self._j_bytes_reader is None:
-            # get thread num
-            max_workers = self._catalog_options.get(constants.MAX_WORKERS)
-            if max_workers is None:
-                # default is sequential
-                max_workers = 1
-            else:
-                max_workers = int(max_workers)
-            if max_workers <= 0:
-                raise ValueError("max_workers must be greater than 0")
-            self._j_bytes_reader = get_gateway().jvm.InvocationUtil.createParallelBytesReader(
-                self._j_table_read, self._j_read_type, max_workers)
+    @staticmethod
+    def _get_max_workers(catalog_options):
+        # default is sequential
+        max_workers = int(catalog_options.get(constants.MAX_WORKERS, 1))
+        if max_workers <= 0:
+            raise ValueError("max_workers must be greater than 0")
+        return max_workers
 
     def _batch_generator(self) -> Iterator[pa.RecordBatch]:
         while True:
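
For reference, a minimal standalone sketch of the extracted helper's behavior (not part of the patch): it substitutes a plain dict for the catalog options and a stand-in "max-workers" key for constants.MAX_WORKERS, both assumptions made purely for illustration.

    # Illustrative sketch only: mirrors the new TableRead._get_max_workers
    # logic against a plain dict instead of real catalog options.
    MAX_WORKERS = "max-workers"  # stand-in for constants.MAX_WORKERS

    def get_max_workers(catalog_options: dict) -> int:
        # A missing key defaults to 1, i.e. a sequential read.
        max_workers = int(catalog_options.get(MAX_WORKERS, 1))
        if max_workers <= 0:
            raise ValueError("max_workers must be greater than 0")
        return max_workers

    assert get_max_workers({}) == 1                  # default: sequential
    assert get_max_workers({MAX_WORKERS: "4"}) == 4  # string values coerced by int()
    try:
        get_max_workers({MAX_WORKERS: "0"})
    except ValueError as err:
        print(err)  # max_workers must be greater than 0

One behavioral note on the refactor, visible in the diff itself: the old lazy _init() built the parallel bytes reader on the first call to to_arrow_batch_reader, so an invalid max-workers value only surfaced at read time; constructing the reader eagerly in __init__ surfaces that error as soon as the TableRead is created.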