Feature/ Find largest batch size for TorchForecastingModel #2905
Conversation
- A wrapper around the Lightning Tuner's method of the same name, `scale_batch_size()` finds the largest batch size before an out-of-memory error (see the sketch below).
- Options for the Tuner method are supported, including `mode`, `steps_per_trial`, `init_val`, and `max_trials`.
- The Tuner requires a `batch_size` attribute on the `LightningDataModule` or the model, and disallows the previously used `train_loader` and `val_loader`.
- Because of that, I implemented `_CustomDataModule` and `_CustomDataModuleWithVal` to return dataloaders according to `batch_size`.
- The previous behaviour of `dataloader_kwargs` is preserved with the new datamodules.
- Update the `_setup_for_train()`, `_train()`, `fit_from_dataset()`, and `lr_find()` methods to use datamodules instead of direct data loaders.
- Add `test_scale_batch_size` for validating the `scale_batch_size()` method.
- Update `test_dataloader_kwargs_setup` to validate the `datamodule` instead of `train_dataloaders` and `val_dataloaders` due to the changes.
- Update `helper_check_val_set`, used in `test_val_set`, to also validate the `datamodule`.
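For context, here is a minimal sketch of the Lightning Tuner API being wrapped (Lightning >= 2.0). `ToyModel`, `ToyDataModule`, and the toy dataset are illustrative stand-ins, not the PR's actual `_CustomDataModule`; only `Tuner.scale_batch_size()` and its options come from Lightning itself:

```python
# Minimal sketch of Lightning's batch size Tuner; ToyModel/ToyDataModule are
# illustrative stand-ins, not the classes added by this PR.
import torch
import pytorch_lightning as pl
from pytorch_lightning.tuner.tuning import Tuner
from torch.utils.data import DataLoader, TensorDataset


class ToyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


class ToyDataModule(pl.LightningDataModule):
    def __init__(self, dataset, batch_size=32):
        super().__init__()
        self.dataset = dataset
        # The Tuner looks for (and rescales) this attribute between trials.
        self.batch_size = batch_size

    def train_dataloader(self):
        return DataLoader(self.dataset, batch_size=self.batch_size, shuffle=True)


dataset = TensorDataset(torch.randn(1024, 8), torch.randn(1024, 1))
trainer = pl.Trainer(max_epochs=1, logger=False, enable_checkpointing=False)
tuner = Tuner(trainer)

# "power" mode doubles the batch size each trial until an OOM error occurs,
# running `steps_per_trial` steps per candidate, for at most `max_trials` candidates.
new_batch_size = tuner.scale_batch_size(
    ToyModel(),
    datamodule=ToyDataModule(dataset),
    mode="power",
    steps_per_trial=3,
    init_val=2,
    max_trials=25,
)
print(f"Suggested batch size: {new_batch_size}")
```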
Codecov Report

```
@@            Coverage Diff             @@
##           master    #2905      +/-   ##
==========================================
- Coverage   95.37%   95.30%   -0.08%
==========================================
  Files         146      146
  Lines       15656    15681      +25
==========================================
+ Hits        14932    14944      +12
- Misses        724      737      +13
```
- When `val_dataset` is `None`, `_CustomDataModule` still needs to implement `val_dataloader()` for Lightning to work.
- Since `val_dataloader()` may return any iterable but not `None` (as per Lightning's `EVAL_DATALOADERS` type), we return an empty list here.
- Batch size scaling does not update the model weights, so there is no need to re-initialize the model after scaling.
As per the previous commit, batch size scaling does not update the model weights, so the model can be used for training directly.
- Update `test_scale_batch_size()` to NOT re-initialize the model after scaling.
- Add `test_scale_batch_size_no_updates()` to validate that the model weights do not change after scaling.
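For illustration, the empty-list pattern described above could look roughly like this on a custom datamodule (the class and attribute names are placeholders, not the PR's actual `_CustomDataModule`):

```python
# Illustrative sketch of the "empty val_dataloader" pattern; names are placeholders.
from typing import Optional

import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset


class ToyDataModuleWithOptionalVal(pl.LightningDataModule):
    def __init__(self, train_dataset: Dataset, val_dataset: Optional[Dataset] = None, batch_size: int = 32):
        super().__init__()
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        # Exposed so Lightning's batch size Tuner can rescale it in place.
        self.batch_size = batch_size

    def train_dataloader(self) -> DataLoader:
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        # EVAL_DATALOADERS may be any iterable but not None, so return an
        # empty list when no validation set was provided.
        if self.val_dataset is None:
            return []
        return DataLoader(self.val_dataset, batch_size=self.batch_size)
```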
Thanks for this PR @daidahao. I remember, when experimenting with the other PR and the feature in general, that I wasn't fully convinced of the functionality. Have you used it yourself, and if yes, what's your perspective on it? I'm happy to discuss :)
Hi Dennis @dennisbader, I think it would depend on the use cases, particularly the dataset size and the hardware. The main benefit of scaling the batch size is to maximise GPU usage when the dataset is too large to fit into one batch and could take a very long time to train. In those cases, we often scale the batch size manually by powers of 2 until the GPU is fully utilised (~100%) to speed up training, similar to what this feature could do.

You are right that in many cases scaling the batch size is less helpful, because many datasets are simply too small to see the benefits. Even with larger datasets, the user might not benefit from this functionality if they are using only CPUs or less powerful GPUs; a small batch size would easily saturate the hardware. The issue is that there is no tuner from Lightning that could scale the batch size by CPU/GPU utilisation. This tuner provides a close approximation by scaling until the out-of-memory error is reached.

To your second question, we are now using Darts on a large dataset with 1M+ time points and often find ourselves manually tuning the batch size whenever a hyper-parameter has been changed. I could see the implementation of this feature benefiting our model training process greatly. But unlike mixed precision or the skip-interpolation TFT, we could not easily overwrite the methods on an older version of Darts (from conda). I will try to update to a newer version and test this feature in the coming days.

Publicly, what would be sensible is perhaps to find a large public dataset for testing this feature and for reproducibility. I reckon that datasets like ... We could reuse the script from #2898 for benchmarking.

```python
import logging
import time
import numpy as np
import pandas as pd
from darts import TimeSeries
from darts.models import TFTModel
logging.basicConfig(level=logging.INFO)
# Load the dataset
series, future_covariates = ...
# Split the dataset into training and validation sets
train_series, val_series = series.split_after(0.8)
model_kwargs = {
    "batch_size": 512,
    "n_epochs": 1,
    "pl_trainer_kwargs": {
        "accelerator": "gpu",
    },
    "optimizer_kwargs": {"lr": 1e-2},
}
fit_kwargs = {
    "dataloader_kwargs": {
        "num_workers": 0,
    }
}
# Define the TFT model
model = TFTModel(
    input_chunk_length=100,
    output_chunk_length=100,
    skip_interpolation=True,  # set to `True` for speed-up
    **model_kwargs
)
# Train the TFT model
start_time = time.time()
model.fit(
    train_series,
    future_covariates=future_covariates,
    val_series=val_series,
    val_future_covariates=future_covariates,
    verbose=True,
    **fit_kwargs
)
print(f"Training time: {time.time() - start_time:.4f} seconds")
# Test the TFT model
# TODO
```
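For the benchmark above, a call to the new method could be added right before `model.fit()`. This is only a hypothetical sketch: the argument names are assumed to mirror `fit()`/`lr_find()` plus the Tuner options listed in the summary, and may not match the final API of this PR.

```python
# Hypothetical: scale the batch size before model.fit() in the benchmark above.
# The arguments below are assumptions based on the PR summary; the final
# signature of TorchForecastingModel.scale_batch_size() may differ.
suggested_batch_size = model.scale_batch_size(
    train_series,
    future_covariates=future_covariates,
    val_series=val_series,
    val_future_covariates=future_covariates,
    mode="power",        # double the batch size each trial
    steps_per_trial=3,   # training steps per candidate batch size
    max_trials=25,       # upper bound on candidates to try
)
print(f"Suggested batch size: {suggested_batch_size}")
# Since scaling does not update the model weights (see the commits above),
# the same model instance can then be trained directly with model.fit(...).
```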
Co-authored-by: Zhihao Dai <[email protected]>
Checklist before merging this PR:
Closes #2318.
Summary
Add `scale_batch_size()` to `TorchForecastingModel`:

- `scale_batch_size()` finds the largest batch size before an out-of-memory error.
- Options for the Tuner method are supported, including `mode`, `steps_per_trial`, `init_val`, and `max_trials`.
- The Tuner requires a `batch_size` attribute within the `LightningDataModule` or model and disallows previous `train_loader` and `val_loader`. Because of that, I implemented `_CustomDataModule` to return dataloaders as per `batch_size`.
- The previous behaviours of `dataloader_kwargs` are preserved with the new datamodules.
- Update `_setup_for_train()`, `_train()`, `fit_from_dataset()`, and `lr_find()` methods to use datamodules instead of direct data loaders.

Testing:

- Add `test_scale_batch_size` for validating the `scale_batch_size()` method.
- Add `test_scale_batch_size_no_updates` for validating that batch size scaling does not update model weights.
- Update `test_dataloader_kwargs_setup` to validate `datamodule` instead of `train_dataloaders` and `val_dataloaders` due to the changes.
- Update `helper_check_val_set`, used in `test_val_set`, to also validate `datamodule`.

Other Information
Should we remove arguments like `val_*` and `sample_weight` from `scale_batch_size()`? They do not affect the batch size scaling, and removing them could simplify the method call.