Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 174 additions & 0 deletions hamilton/plugins/polars_lazyframe_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,166 @@ def name(cls) -> str:
return "feather"



@dataclasses.dataclass
class PolarsSinkParquetWriter(DataLoader):
"""
Class specifically to handle writing parquet files with Polars LazyFrame.
Should map to https://docs.pola.rs/api/python/stable/reference/lazyframe/api/polars.LazyFrame.sink_parquet.html
"""
path: Union[str, Path]
# kwargs:
compression: str = "zstd"
compression_level: Optional[int] = None
statistics: bool = False
row_group_size: Optional[int] = None
data_page_size: Optional[int] = None

@classmethod
def applicable_types(cls) -> Collection[Type]:
return [DATAFRAME_TYPE]

def _get_writing_kwargs(self):
kwargs = {}
if self.compression is not None:
kwargs["compression"] = self.compression
if self.compression_level is not None:
kwargs["compression_level"] = self.compression_level
if self.statistics is not None:
kwargs["statistics"] = self.statistics
if self.row_group_size is not None:
kwargs["row_group_size"] = self.row_group_size
if self.data_page_size is not None:
kwargs["data_page_size"] = self.data_page_size
return kwargs

def save_data(self, data: DATAFRAME_TYPE) -> Dict[str, Any]:
data.sink_parquet(self.path, **self._get_writing_kwargs())
metadata = utils.get_file_metadata(self.path)
return metadata

@classmethod
def name(cls) -> str:
return "parquet"


@dataclasses.dataclass
class PolarsSinkCSVWriter(DataLoader):
"""
Class specifically to handle writing CSV files with Polars LazyFrame.
Should map to https://docs.pola.rs/api/python/stable/reference/lazyframe/api/polars.LazyFrame.sink_csv.html
"""
path: Union[str, Path]
# kwargs:
include_bom: bool = False
include_header: bool = True
separator: str = ","
line_terminator: str = "\n"
quote_char: str = '"'
batch_size: int = 1024
datetime_format: Optional[str] = None
date_format: Optional[str] = None
time_format: Optional[str] = None
float_precision: Optional[int] = None
null_value: Optional[str] = None
quote_style: Optional[str] = None

@classmethod
def applicable_types(cls) -> Collection[Type]:
return [DATAFRAME_TYPE]

def _get_writing_kwargs(self):
kwargs = {}
if self.include_bom is not None:
kwargs["include_bom"] = self.include_bom
if self.include_header is not None:
kwargs["include_header"] = self.include_header
if self.separator is not None:
kwargs["separator"] = self.separator
if self.line_terminator is not None:
kwargs["line_terminator"] = self.line_terminator
if self.quote_char is not None:
kwargs["quote_char"] = self.quote_char
if self.batch_size is not None:
kwargs["batch_size"] = self.batch_size
if self.datetime_format is not None:
kwargs["datetime_format"] = self.datetime_format
if self.date_format is not None:
kwargs["date_format"] = self.date_format
if self.time_format is not None:
kwargs["time_format"] = self.time_format
if self.float_precision is not None:
kwargs["float_precision"] = self.float_precision
if self.null_value is not None:
kwargs["null_value"] = self.null_value
if self.quote_style is not None:
kwargs["quote_style"] = self.quote_style
return kwargs

def save_data(self, data: DATAFRAME_TYPE) -> Dict[str, Any]:
data.sink_csv(self.path, **self._get_writing_kwargs())
metadata = utils.get_file_metadata(self.path)
return metadata

@classmethod
def name(cls) -> str:
return "csv"


@dataclasses.dataclass
class PolarsSinkIPCWriter(DataLoader):
"""
Class specifically to handle writing IPC/Feather files with Polars LazyFrame.
Should map to https://docs.pola.rs/api/python/stable/reference/lazyframe/api/polars.LazyFrame.sink_ipc.html
"""
path: Union[str, Path]
# kwargs:
compression: Optional[str] = "zstd"

@classmethod
def applicable_types(cls) -> Collection[Type]:
return [DATAFRAME_TYPE]

def _get_writing_kwargs(self):
kwargs = {}
if self.compression is not None:
kwargs["compression"] = self.compression
return kwargs

def save_data(self, data: DATAFRAME_TYPE) -> Dict[str, Any]:
data.sink_ipc(self.path, **self._get_writing_kwargs())
metadata = utils.get_file_metadata(self.path)
return metadata

@classmethod
def name(cls) -> str:
return "ipc"


@dataclasses.dataclass
class PolarsSinkNDJSONWriter(DataLoader):
"""
Class specifically to handle writing NDJSON files with Polars LazyFrame.
Should map to https://docs.pola.rs/api/python/stable/reference/lazyframe/api/polars.LazyFrame.sink_ndjson.html
Note: Load support for NDJSON is not yet implemented.
"""
path: Union[str, Path]

@classmethod
def applicable_types(cls) -> Collection[Type]:
return [DATAFRAME_TYPE]

def save_data(self, data: DATAFRAME_TYPE) -> Dict[str, Any]:
data.sink_ndjson(self.path)
metadata = utils.get_file_metadata(self.path)
return metadata

@classmethod
def name(cls) -> str:
return "ndjson"



def register_data_loaders():
"""Function to register the data loaders for this extension."""
for loader in [
Expand All @@ -308,3 +468,17 @@ def register_data_loaders():


register_data_loaders()


def register_data_savers():
"""Function to register the data savers for this extension."""
for saver in [
PolarsSinkParquetWriter,
PolarsSinkCSVWriter,
PolarsSinkIPCWriter,
PolarsSinkNDJSONWriter,
]:
registry.register_adapter(saver)


register_data_savers()