Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ but cannot always guarantee backwards compatibility. Changes that may **break co

## [Unreleased](https://github.com/unit8co/darts/tree/master)

- Added utility functions for Hugging Face Hub integration: upload and download Darts `TimeSeries` and `ForecastingModel` instances. [#2201](https://github.com/unit8co/darts/pull/2201)
by [Ivelin Ivanov](https://github.com/ivelin).


[Full Changelog](https://github.com/unit8co/darts/compare/0.27.2...master)

### For users of the library:
Expand Down
93 changes: 93 additions & 0 deletions darts/utils/hfhub.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import pandas as pd
from dotenv import load_dotenv
import os
import tempfile
from typing import Optional
from darts import TimeSeries
from darts.models.forecasting.forecasting_model import ForecastingModel
from huggingface_hub import snapshot_download, upload_folder, create_repo


class HFHub:
    """
    Hugging Face Hub integration using the official HF API.
    https://huggingface.co/docs/huggingface_hub/v0.20.3/en/guides/integrations

    Uploads and downloads Darts ``ForecastingModel`` and ``TimeSeries`` objects
    to and from Hub repositories. Files are staged in a temporary directory
    that is removed once the transfer has finished.
    """

    def __init__(self, api_key: Optional[str] = None):
        """
        Parameters
        ----------
        api_key
            Hugging Face access token. If ``None``, ``load_dotenv(override=True)``
            is called and the token is read from the ``HF_TOKEN`` environment
            variable.

        Raises
        ------
        ValueError
            If no token was given and ``HF_TOKEN`` is not set.
        """
        if api_key is None:
            # load from .env file or OS vars if available
            load_dotenv(override=True)
            api_key = os.getenv("HF_TOKEN")
        if api_key is None:
            # Raise explicitly: an `assert` would be stripped under `python -O`
            # and the missing token would only surface later as an opaque error.
            raise ValueError(
                "Could not find HF_TOKEN in OS environment. "
                "Cannot interact with HF Hub."
            )
        self.HF_TOKEN = api_key

    @staticmethod
    def _require(**arguments) -> None:
        """Raise ``ValueError`` naming every keyword argument that is ``None``."""
        missing = [name for name, value in arguments.items() if value is None]
        if missing:
            raise ValueError(f"Missing required argument(s): {', '.join(missing)}")

    def upload_model(
        self,
        repo_id: Optional[str] = None,
        model: "ForecastingModel" = None,
        private: Optional[bool] = True,
    ):
        """
        Save ``model`` and upload it to the model repository ``repo_id``.

        The repository is created if it does not exist yet. The model is saved
        under its ``model_name`` attribute; models without that attribute are
        saved under their class name. Pass the same name as ``model_name`` to
        ``download_model``.

        Parameters
        ----------
        repo_id
            Hub repository id. Required.
        model
            The model to upload. Required.
        private
            Whether a newly created repository is private.

        Raises
        ------
        ValueError
            If ``repo_id`` or ``model`` is ``None``.
        """
        self._require(repo_id=repo_id, model=model)
        # Create repo if not existing yet. The token is passed explicitly so the
        # repo is created with the same credentials used for the upload below.
        create_repo(
            repo_id=repo_id,
            repo_type="model",
            private=private,
            exist_ok=True,
            token=self.HF_TOKEN,
        )
        file_name = getattr(model, "model_name", None) or type(model).__name__
        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save(path=os.path.join(tmpdirname, file_name))
            upload_folder(repo_id=repo_id, folder_path=tmpdirname, token=self.HF_TOKEN)

    def download_model(
        self,
        repo_id: Optional[str] = None,
        model_name: Optional[str] = None,
        model_class: object = None,
    ) -> "ForecastingModel":
        """
        Download the repository ``repo_id`` and load the model stored in it.

        Parameters
        ----------
        repo_id
            Hub repository id. Required.
        model_name
            File name the model was saved under inside the repository. Required.
        model_class
            Class whose ``load(path=...)`` is used to restore the model. Required.

        Returns
        -------
        ForecastingModel
            The value returned by ``model_class.load``.

        Raises
        ------
        ValueError
            If any of the arguments is ``None``.
        """
        self._require(repo_id=repo_id, model_name=model_name, model_class=model_class)
        with tempfile.TemporaryDirectory() as tmpdirname:
            snapshot_download(
                repo_id=repo_id, local_dir=tmpdirname, token=self.HF_TOKEN
            )
            # Load while the temporary directory still exists.
            model = model_class.load(path=os.path.join(tmpdirname, model_name))
        return model

    def upload_timeseries(
        self,
        repo_id: Optional[str] = None,
        series: "TimeSeries" = None,
        series_name: Optional[str] = None,
        private: Optional[bool] = True,
    ):
        """
        Upload ``series`` as ``<series_name>.parquet`` to the dataset repository
        ``repo_id``.

        The repository is created if it does not exist yet.

        Parameters
        ----------
        repo_id
            Hub repository id. Required.
        series
            The series to upload; it is converted with ``pd_dataframe()``. Required.
        series_name
            Base name (without extension) of the parquet file. Required.
        private
            Whether a newly created repository is private.

        Raises
        ------
        ValueError
            If ``repo_id``, ``series`` or ``series_name`` is ``None``.
        """
        self._require(repo_id=repo_id, series=series, series_name=series_name)
        # Create repo if not existing yet, using the same token as the upload.
        create_repo(
            repo_id=repo_id,
            repo_type="dataset",
            private=private,
            exist_ok=True,
            token=self.HF_TOKEN,
        )
        df = series.pd_dataframe()
        with tempfile.TemporaryDirectory() as tmpdirname:
            df.to_parquet(path=os.path.join(tmpdirname, f"{series_name}.parquet"))
            upload_folder(
                repo_id=repo_id,
                repo_type="dataset",
                folder_path=tmpdirname,
                token=self.HF_TOKEN,
            )

    def download_timeseries(
        self,
        repo_id: Optional[str] = None,
        series_name: Optional[str] = None,
    ) -> "TimeSeries":
        """
        Download the dataset repository ``repo_id`` and rebuild a series from
        ``<series_name>.parquet``.

        Parameters
        ----------
        repo_id
            Hub repository id. Required.
        series_name
            Base name (without extension) of the parquet file. Required.

        Returns
        -------
        TimeSeries
            The result of ``TimeSeries.from_dataframe`` on the downloaded data.

        Raises
        ------
        ValueError
            If ``repo_id`` or ``series_name`` is ``None``.
        """
        self._require(repo_id=repo_id, series_name=series_name)
        with tempfile.TemporaryDirectory() as tmpdirname:
            snapshot_download(
                repo_id=repo_id,
                repo_type="dataset",
                local_dir=tmpdirname,
                token=self.HF_TOKEN,
            )
            # Read fully into memory before the temporary directory is removed.
            df = pd.read_parquet(
                os.path.join(tmpdirname, f"{series_name}.parquet"), engine="pyarrow"
            )
        ts = TimeSeries.from_dataframe(df)
        return ts