diff --git a/src/hyperactive/experiment/integrations/torch_lightning_experiment.py b/src/hyperactive/experiment/integrations/torch_lightning_experiment.py
index 0bdd1f80..f0407c24 100644
--- a/src/hyperactive/experiment/integrations/torch_lightning_experiment.py
+++ b/src/hyperactive/experiment/integrations/torch_lightning_experiment.py
@@ -21,13 +21,17 @@ class TorchExperiment(BaseExperiment):
 
     Parameters
     ----------
-    datamodule : L.LightningDataModule
-        A PyTorch Lightning DataModule that handles data loading and preparation.
+    data_module : type
+        A PyTorch Lightning DataModule class (not an instance) that
+        handles data loading and preparation. It is instantiated with
+        ``dm_kwargs`` at each evaluation.
     lightning_module : type
         A PyTorch Lightning Module class (not an instance) that will be
         instantiated with hyperparameters during optimization.
     trainer_kwargs : dict, optional (default=None)
         A dictionary of keyword arguments to pass to the PyTorch Lightning Trainer.
+    dm_kwargs : dict, optional (default=None)
+        A dictionary of keyword arguments to pass to the DataModule upon instantiation.
     objective_metric : str, optional (default='val_loss')
         The metric used to evaluate the model's performance. This should correspond
         to a metric logged in the LightningModule during validation.
@@ -93,14 +97,12 @@ class TorchExperiment(BaseExperiment):
     ...     def val_dataloader(self):
     ...         return DataLoader(self.val, batch_size=self.batch_size)
     >>>
-    >>> datamodule = RandomDataModule(batch_size=16)
-    >>> datamodule.setup()
-    >>>
     >>> # Create Experiment
     >>> experiment = TorchExperiment(
-    ...     datamodule=datamodule,
+    ...     data_module=RandomDataModule,
     ...     lightning_module=SimpleLightningModule,
     ...     trainer_kwargs={'max_epochs': 3},
+    ...     dm_kwargs={'batch_size': 16},
     ...     objective_metric="val_loss"
     ... )
     >>>
@@ -118,14 +120,16 @@ class TorchExperiment(BaseExperiment):
 
     def __init__(
         self,
-        datamodule,
+        data_module,
         lightning_module,
         trainer_kwargs=None,
+        dm_kwargs=None,
         objective_metric: str = "val_loss",
     ):
-        self.datamodule = datamodule
+        self.data_module = data_module
         self.lightning_module = lightning_module
         self.trainer_kwargs = trainer_kwargs or {}
+        self.dm_kwargs = dm_kwargs or {}
         self.objective_metric = objective_metric
         super().__init__()
 
@@ -174,7 +178,8 @@ def _evaluate(self, params):
         try:
             model = self.lightning_module(**params)
             trainer = L.Trainer(**self._trainer_kwargs)
-            trainer.fit(model, self.datamodule)
+            data = self.data_module(**self.dm_kwargs)
+            trainer.fit(model, data)
 
             val_result = trainer.callback_metrics.get(self.objective_metric)
             metadata = {}
@@ -265,10 +270,8 @@ def train_dataloader(self):
            def val_dataloader(self):
                return DataLoader(self.val, batch_size=self.batch_size)
 
-        datamodule = RandomDataModule(batch_size=16)
-
         params = {
-            "datamodule": datamodule,
+            "data_module": RandomDataModule,
             "lightning_module": SimpleLightningModule,
             "trainer_kwargs": {
                 "max_epochs": 1,
@@ -276,6 +279,7 @@ def val_dataloader(self):
                 "enable_model_summary": False,
                 "logger": False,
             },
+            "dm_kwargs": {"batch_size": 16},
             "objective_metric": "val_loss",
         }
 
@@ -339,10 +343,8 @@ def train_dataloader(self):
            def val_dataloader(self):
                return DataLoader(self.val, batch_size=self.batch_size)
 
-        datamodule2 = RegressionDataModule(batch_size=16, num_samples=150)
-
         params2 = {
-            "datamodule": datamodule2,
+            "data_module": RegressionDataModule,
             "lightning_module": RegressionModule,
             "trainer_kwargs": {
                 "max_epochs": 1,
@@ -350,6 +352,7 @@ def val_dataloader(self):
                 "enable_model_summary": False,
                 "logger": False,
             },
+            "dm_kwargs": {"batch_size": 8, "num_samples": 200},
             "objective_metric": "val_loss",
         }
 
@@ -370,4 +373,5 @@ def _get_score_params(cls):
 
         """
        score_params1 = {"input_dim": 10, "hidden_dim": 20, "lr": 0.001}
        score_params2 = {"num_layers": 3, "hidden_size": 64, "dropout": 0.2}
+        return [score_params1, score_params2]
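Reviewer note, not part of the diff: a minimal, self-contained sketch of the new signature end to end. TinyModule and TinyDataModule are illustrative stand-ins invented for this sketch, and the closing score call assumes the BaseExperiment entry point that _get_score_params targets; only the TorchExperiment constructor arguments mirror the data_module / dm_kwargs API introduced above.

import lightning as L
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

from hyperactive.experiment.integrations.torch_lightning_experiment import (
    TorchExperiment,
)


class TinyModule(L.LightningModule):
    # Illustrative stand-in: search params arrive through __init__.
    def __init__(self, input_dim=10, hidden_dim=20, lr=1e-3):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, 1)
        )
        self.lr = lr

    def training_step(self, batch, batch_idx):
        x, y = batch
        return nn.functional.mse_loss(self.net(x), y)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        # Logged under the name that objective_metric looks up in
        # trainer.callback_metrics after fitting.
        self.log("val_loss", nn.functional.mse_loss(self.net(x), y))

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)


class TinyDataModule(L.LightningDataModule):
    # Illustrative stand-in: dm_kwargs arrive through __init__.
    def __init__(self, batch_size=16):
        super().__init__()
        self.batch_size = batch_size

    def setup(self, stage=None):
        x, y = torch.randn(128, 10), torch.randn(128, 1)
        self.train = TensorDataset(x[:96], y[:96])
        self.val = TensorDataset(x[96:], y[96:])

    def train_dataloader(self):
        return DataLoader(self.train, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.val, batch_size=self.batch_size)


# Classes, not instances: the experiment builds a fresh DataModule from
# dm_kwargs and a fresh LightningModule from the search params on every
# evaluation, so trials no longer share one pre-built DataModule.
experiment = TorchExperiment(
    data_module=TinyDataModule,
    lightning_module=TinyModule,
    trainer_kwargs={"max_epochs": 1, "logger": False, "enable_progress_bar": False},
    dm_kwargs={"batch_size": 16},
    objective_metric="val_loss",
)

# Mirrors score_params1 from _get_score_params; score is assumed to be
# the public evaluation entry point those params are written for.
result = experiment.score({"input_dim": 10, "hidden_dim": 20, "lr": 1e-3})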