pypaimon/api/table_read.py (8 changes: 4 additions & 4 deletions)
@@ -31,14 +31,14 @@
 class TableRead(ABC):
     """To read data from data splits."""
 
-    @abstractmethod
-    def to_arrow(self, splits: List[Split]) -> pa.Table:
-        """Read data from splits and convert to pyarrow.Table format."""
-
     @abstractmethod
     def to_arrow_batch_reader(self, splits: List[Split]) -> pa.RecordBatchReader:
         """Read data from splits and convert to pyarrow.RecordBatchReader format."""
 
+    @abstractmethod
+    def to_arrow(self, splits: List[Split]) -> pa.Table:
+        """Read data from splits and convert to pyarrow.Table format."""
+
     @abstractmethod
     def to_pandas(self, splits: List[Split]) -> pd.DataFrame:
         """Read data from splits and convert to pandas.DataFrame format."""
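For context, a minimal sketch of how a caller might consume this abstract API, assuming the usual pypaimon read flow (`new_read_builder()` → `new_scan()`/`new_read()` → `plan().splits()`); the catalog options, table name, import path, and the `process` callback are illustrative, not part of this diff:

```python
from pypaimon.py4j import Catalog  # import path is an assumption

catalog = Catalog.create({"warehouse": "/tmp/warehouse"})  # options illustrative
table = catalog.get_table("default.T")                     # table name illustrative

read_builder = table.new_read_builder()
splits = read_builder.new_scan().plan().splits()
table_read = read_builder.new_read()

# Streaming path: the RecordBatchReader yields batches lazily.
for batch in table_read.to_arrow_batch_reader(splits):
    process(batch)  # hypothetical per-batch consumer

# Materializing paths: pull everything into memory at once.
arrow_table = table_read.to_arrow(splits)
df = table_read.to_pandas(splits)
```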
pypaimon/py4j/java_implementation.py (37 changes: 14 additions & 23 deletions)
@@ -177,23 +177,20 @@ def file_paths(self) -> List[str]:
 class TableRead(table_read.TableRead):
 
     def __init__(self, j_table_read, j_read_type, catalog_options):
-        self._j_table_read = j_table_read
-        self._j_read_type = j_read_type
-        self._catalog_options = catalog_options
-        self._j_bytes_reader = None
         self._arrow_schema = java_utils.to_arrow_schema(j_read_type)
+        self._j_bytes_reader = get_gateway().jvm.InvocationUtil.createParallelBytesReader(
+            j_table_read, j_read_type, TableRead._get_max_workers(catalog_options))
 
-    def to_arrow(self, splits):
-        record_batch_reader = self.to_arrow_batch_reader(splits)
-        return pa.Table.from_batches(record_batch_reader, schema=self._arrow_schema)
-
-    def to_arrow_batch_reader(self, splits):
-        self._init()
+    def to_arrow_batch_reader(self, splits) -> pa.RecordBatchReader:
         j_splits = list(map(lambda s: s.to_j_split(), splits))
         self._j_bytes_reader.setSplits(j_splits)
         batch_iterator = self._batch_generator()
         return pa.RecordBatchReader.from_batches(self._arrow_schema, batch_iterator)
 
+    def to_arrow(self, splits) -> pa.Table:
+        record_batch_reader = self.to_arrow_batch_reader(splits)
+        return pa.Table.from_batches(record_batch_reader, schema=self._arrow_schema)
+
     def to_pandas(self, splits: List[Split]) -> pd.DataFrame:
         return self.to_arrow(splits).to_pandas()
 
@@ -213,19 +210,13 @@ def to_ray(self, splits: List[Split]) -> "ray.data.dataset.Dataset":

         return ray.data.from_arrow(self.to_arrow(splits))
 
-    def _init(self):
-        if self._j_bytes_reader is None:
-            # get thread num
-            max_workers = self._catalog_options.get(constants.MAX_WORKERS)
-            if max_workers is None:
-                # default is sequential
-                max_workers = 1
-            else:
-                max_workers = int(max_workers)
-            if max_workers <= 0:
-                raise ValueError("max_workers must be greater than 0")
-            self._j_bytes_reader = get_gateway().jvm.InvocationUtil.createParallelBytesReader(
-                self._j_table_read, self._j_read_type, max_workers)
+    @staticmethod
+    def _get_max_workers(catalog_options):
+        # default is sequential
+        max_workers = int(catalog_options.get(constants.MAX_WORKERS, 1))
+        if max_workers <= 0:
+            raise ValueError("max_workers must be greater than 0")
+        return max_workers
 
     def _batch_generator(self) -> Iterator[pa.RecordBatch]:
     while True:
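The net effect of this file's changes: worker-count parsing moves from the lazy `_init` into the static `_get_max_workers`, and the parallel bytes reader is now created eagerly in `__init__`, so an invalid `max_workers` option fails at construction time instead of on the first read. A standalone sketch of the parsing rule (the key string and sample dicts are illustrative; the real code reads `constants.MAX_WORKERS` from the catalog options):

```python
def get_max_workers(catalog_options: dict) -> int:
    # Missing option -> 1 (sequential); string values are coerced with int();
    # zero or negative values are rejected, mirroring _get_max_workers above.
    max_workers = int(catalog_options.get("max-workers", 1))  # key name illustrative
    if max_workers <= 0:
        raise ValueError("max_workers must be greater than 0")
    return max_workers

print(get_max_workers({}))                    # 1: default is sequential
print(get_max_workers({"max-workers": "4"}))  # 4: string option coerced to int
```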