
Conversation

@daidahao (Contributor) commented Sep 18, 2025

Checklist before merging this PR:

  • Mentioned all issues that this PR fixes or addresses.
  • Summarized the updates of this PR under Summary.
  • Added an entry under Unreleased in the Changelog.

Closes #2318.

Summary

Add `scale_batch_size()` to `TorchForecastingModel`:

  • A wrapper around Lightning Tuner's method of the same name, `scale_batch_size()` finds the largest batch size that fits before an out-of-memory error.
  • Options of the Tuner method are supported, including `mode`, `steps_per_trial`, `init_val`, and `max_trials`.
  • The Tuner requires a `batch_size` attribute on the `LightningDataModule` or the model, and does not accept pre-built `train_loader` and `val_loader`. Because of that, I implemented `_CustomDataModule` and `_CustomDataModuleWithVal` to build the dataloaders according to `batch_size` (see the sketch after this list).
  • The previous behaviour of `dataloader_kwargs` is preserved with the new datamodules.
  • Update `_setup_for_train()`, `_train()`, `fit_from_dataset()`, and `lr_find()` to use datamodules instead of direct data loaders.
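For context, this is roughly the Lightning mechanism being wrapped. A minimal sketch assuming Lightning >= 2.0's `Tuner` API; `ToyDataModule` and `ToyModule` are hypothetical stand-ins, not Darts' `_CustomDataModule` or its forecasting modules:

import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl
from pytorch_lightning.tuner import Tuner


class ToyDataModule(pl.LightningDataModule):
    """Hypothetical datamodule exposing the `batch_size` attribute the Tuner needs."""

    def __init__(self, dataset, batch_size=32):
        super().__init__()
        # Tuner.scale_batch_size() reads and mutates this attribute between trials.
        self.batch_size = batch_size
        self.dataset = dataset

    def train_dataloader(self):
        # Rebuilt on every trial, so it always reflects the current `batch_size`.
        return DataLoader(self.dataset, batch_size=self.batch_size)


class ToyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())


dm = ToyDataModule(TensorDataset(torch.randn(1024, 8), torch.randn(1024, 1)))
trainer = pl.Trainer(max_epochs=1)
# "power" mode doubles the batch size until an out-of-memory error occurs, then
# writes the largest working value back to `dm.batch_size` and returns it.
new_batch_size = Tuner(trainer).scale_batch_size(
    ToyModule(),
    datamodule=dm,
    mode="power",
    steps_per_trial=3,
    init_val=2,
    max_trials=25,
)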

Testing:

  • Add `test_scale_batch_size` to validate the `scale_batch_size()` method.
  • Add `test_scale_batch_size_no_updates` to validate that batch size scaling does not update the model weights (see the sketch after this list).
  • Update `test_dataloader_kwargs_setup` to validate the `datamodule` instead of `train_dataloaders` and `val_dataloaders`, due to these changes.
  • Update `helper_check_val_set` (used in `test_val_set`) to also validate the `datamodule`.
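The weight-invariance check behind `test_scale_batch_size_no_updates` boils down to comparing parameter snapshots before and after scaling. A minimal sketch of that idea; the helper and the commented usage are hypothetical, not the actual test code:

import torch


def weights_unchanged(state_before: dict, state_after: dict) -> bool:
    """Return True if two state_dict snapshots hold identical tensors."""
    if state_before.keys() != state_after.keys():
        return False
    return all(
        torch.equal(state_before[name], state_after[name]) for name in state_before
    )


# Hypothetical usage (assuming the internal Lightning module is already initialized):
# before = {k: v.clone() for k, v in model.model.state_dict().items()}
# model.scale_batch_size(series)
# after = model.model.state_dict()
# assert weights_unchanged(before, after)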

Other Information

Should we remove arguments like `val_*` and `sample_weight` from `scale_batch_size()`? They do not affect the batch size scaling, and removing them could simplify the method call.
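For illustration, a call might look roughly like this; the exact signature of `scale_batch_size()` is my assumption based on the options described above, not the final API:

from darts.datasets import AirPassengersDataset
from darts.models import TFTModel

# Small dataset just to illustrate the call, not a meaningful benchmark.
series = AirPassengersDataset().load()

model = TFTModel(
    input_chunk_length=12,
    output_chunk_length=6,
    add_relative_index=True,  # so no future covariates are needed
)

# Hypothetical call, assuming the options described in this PR (`mode`,
# `steps_per_trial`, `init_val`, `max_trials`). The return value would be the
# largest batch size found before an out-of-memory error, and per the PR the
# model can be trained directly afterwards.
suggested_batch_size = model.scale_batch_size(
    series,
    mode="power",
    steps_per_trial=3,
    init_val=2,
    max_trials=25,
)

model.fit(series)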

@codecov (bot) commented Sep 18, 2025

Codecov Report

❌ Patch coverage is 96.96970% with 1 line in your changes missing coverage. Please review.
✅ Project coverage is 95.30%. Comparing base (2274e94) to head (7ede1c9).

Files with missing lines | Patch % | Lines
...arts/models/forecasting/torch_forecasting_model.py | 96.96% | 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #2905      +/-   ##
==========================================
- Coverage   95.37%   95.30%   -0.08%     
==========================================
  Files         146      146              
  Lines       15656    15681      +25     
==========================================
+ Hits        14932    14944      +12     
- Misses        724      737      +13     

☔ View full report in Codecov by Sentry.

- When `val_dataset` is `None`, `_CustomDataModule` still needs to implement
  `val_dataloader()` for Lightning to work.
- Since `val_dataloader()` can return any iterable but not `None` (as per
  Lightning's `EVAL_DATALOADERS` type), we return an empty list here (sketched below).
- Batch size scaling does not update the model weights, so there is no
  need to re-initialize the model after scaling.
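A minimal sketch of that fallback; the datamodule below is a hypothetical stand-in, not the PR's `_CustomDataModule`:

from torch.utils.data import DataLoader
import pytorch_lightning as pl


class ToyValFallbackDataModule(pl.LightningDataModule):
    def __init__(self, train_dataset, val_dataset=None, batch_size=32):
        super().__init__()
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.batch_size = batch_size

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size)

    def val_dataloader(self):
        # EVAL_DATALOADERS must be an iterable, never None; an empty list
        # tells Lightning there is nothing to validate.
        if self.val_dataset is None:
            return []
        return DataLoader(self.val_dataset, batch_size=self.batch_size)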
As per the previous commit, batch size scaling does not update the model
weights, so the model can be used for training directly.

- Update `test_scale_batch_size()` to NOT re-initialize the model after
  scaling.
- Add `test_scale_batch_size_no_updates()` to validate that the model
  weights do not change after scaling.
@dennisbader (Collaborator) commented:

Thanks for this PR, @daidahao. I remember from experimenting with the other PR, and with the feature in general, that I wasn't fully convinced of the functionality.
Even if it might give us the largest possible batch size, I'm not sure that training with that batch size is always what we want in the end (in most use cases it just ended up putting all of the training data into one batch).

Have you used it yourself and if yes, what's your perspective on it? I'm happy to discuss :)

@daidahao (Contributor, Author) commented Sep 25, 2025

Hi Dennis @dennisbader, I think it depends on the use case, particularly the dataset size and the hardware. The main benefit of scaling the batch size is to maximise GPU usage when the dataset is too large to fit into one batch and would otherwise take a very long time to train. In those cases, we often scale the batch size manually by powers of 2 until the GPU is fully utilised (~100%) to speed up training, which is similar to what this feature does.

You are right that in many cases scaling the batch size is less helpful, because many datasets are simply too small to see the benefits. Even with larger datasets, users might not benefit from this functionality if they are using only CPUs or less powerful GPUs: a small batch size would already saturate the hardware.

The issue is that Lightning has no tuner that scales the batch size based on CPU/GPU utilisation. This tuner provides a close approximation by scaling until it hits an out-of-memory error, as sketched below.
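For context, the idea behind the "power" mode is roughly the following loop; `run_trial` is a hypothetical callable standing in for a few training steps at a given batch size:

import torch


def largest_fitting_batch_size(run_trial, init_val=2, max_trials=25):
    """Double the batch size until `run_trial` raises a CUDA out-of-memory
    error, then return the last size that worked."""
    batch_size = init_val
    last_successful = None
    for _ in range(max_trials):
        try:
            run_trial(batch_size)
            last_successful = batch_size
            batch_size *= 2
        except torch.cuda.OutOfMemoryError:
            # Free cached memory so training can continue at the last
            # successful batch size.
            torch.cuda.empty_cache()
            break
    return last_successful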

To your second question, we are now using Darts on a large dataset with 1M+ time points and often find ourselves manually tuning the batch size whenever a hyper-parameter has changed. I can see the implementation of this feature benefiting our model training process greatly. But unlike mixed precision or skip-interpolation TFT, we could not easily override the methods on an older version of Darts (from conda). I will try to update to a newer version and test this feature in the coming days.

Publicly, what would be sensible is perhaps to find a large public dataset for testing this feature and for best reproducibility. I reckon that datasets like AirPassengersDataset are only a few hundred time points long and too small for testing. What would be your suggestions on this?

We could reuse the script from #2898 for benchmarking.

import logging
import time

import numpy as np
import pandas as pd

from darts import TimeSeries
from darts.models import TFTModel

logging.basicConfig(level=logging.INFO)

# Load the dataset
series, future_covariates = ...

# Split the dataset into training and validation sets
train_series, val_series = series.split_after(0.8)

model_kwargs = {
    "batch_size": 512,
    "n_epochs": 1,
    "pl_trainer_kwargs": {
        "accelerator": "gpu",
    },
    "optimizer_kwargs": {"lr": 1e-2},
}
fit_kwargs = {
    "dataloader_kwargs": {
        "num_workers": 0,
    }
}

# Define the TFT model
model = TFTModel(
    input_chunk_length=100,
    output_chunk_length=100,
    skip_interpolation=True,  # `True` enables the interpolation skip for speed-up
    **model_kwargs
)

# Train the TFT model
start_time = time.time()
model.fit(
    train_series,
    future_covariates=future_covariates,
    val_series=val_series,
    val_future_covariates=future_covariates,
    verbose=True,
    **fit_kwargs
)
print(f"Training time: {time.time() - start_time:.4f} seconds")

# Test the TFT model
# TODO
