torch_geometric.data.lightning.LightningDataset
- class LightningDataset(train_dataset: Dataset, val_dataset: Optional[Dataset] = None, test_dataset: Optional[Dataset] = None, pred_dataset: Optional[Dataset] = None, **kwargs)
Bases: LightningDataModule

Converts a set of Dataset objects into a pytorch_lightning.LightningDataModule variant, which can then be used automatically as a datamodule for multi-GPU graph-level training via PyTorch Lightning. LightningDataset takes care of providing mini-batches via DataLoader.

Note
Currently only the pytorch_lightning.strategies.SingleDeviceStrategy and pytorch_lightning.strategies.DDPStrategy training strategies of PyTorch Lightning are supported, in order to correctly share data across all devices/processes:

    import pytorch_lightning as pl

    # model: your LightningModule; datamodule: a LightningDataset instance (both defined elsewhere)
    trainer = pl.Trainer(strategy="ddp_spawn", accelerator="gpu", devices=4)
    trainer.fit(model, datamodule)
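LightningDataset leaves the actual batching to PyG's DataLoader. For graph-level training, that loader merges the graphs of a mini-batch into one large disconnected graph: each graph's node indices are shifted by the number of nodes placed before it, and a batch vector records which graph every node came from. The sketch below illustrates only this index bookkeeping in plain Python. The function collate_graphs and the (num_nodes, edge_list) graph representation are hypothetical stand-ins; the real loader operates on PyG Data objects and also concatenates node and edge features.

```python
from typing import List, Tuple

Edge = Tuple[int, int]


def collate_graphs(graphs: List[Tuple[int, List[Edge]]]) -> Tuple[List[Edge], List[int]]:
    """Hypothetical helper: merge (num_nodes, edge_list) graphs into one
    disconnected graph.

    Returns the shifted edge list and a batch vector that maps each node
    of the merged graph to the index of the graph it came from.
    """
    edges: List[Edge] = []
    batch: List[int] = []
    offset = 0  # number of nodes already placed in the merged graph
    for graph_id, (num_nodes, edge_list) in enumerate(graphs):
        # Shift node indices so they do not collide with earlier graphs.
        edges.extend((src + offset, dst + offset) for src, dst in edge_list)
        # Every node of this graph is tagged with the graph's index.
        batch.extend([graph_id] * num_nodes)
        offset += num_nodes
    return edges, batch


# Two toy graphs: a 3-node path and a single 2-node edge.
edges, batch = collate_graphs([(3, [(0, 1), (1, 2)]), (2, [(0, 1)])])
# edges == [(0, 1), (1, 2), (3, 4)], batch == [0, 0, 0, 1, 1]
```

Because no edge crosses between graphs, a message-passing layer applied to the merged graph still processes each graph independently, and a graph-level readout can pool node outputs per graph using the batch vector.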
- Parameters
train_dataset (Dataset) – The training dataset.
val_dataset (Dataset, optional) – The validation dataset. (default: None)
test_dataset (Dataset, optional) – The test dataset. (default: None)
pred_dataset (Dataset, optional) – The prediction dataset. (default: None)
**kwargs (optional) – Additional arguments of torch_geometric.loader.DataLoader.
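The parameter list describes a simple contract: train_dataset is required, the other three splits default to None, and any extra keyword arguments are meant for the DataLoader. The following is a minimal sketch of that contract, assuming the same keyword arguments are applied to the loader of every split that is present. ToySplits and toy_loader are hypothetical names; toy_loader is a plain-Python stand-in that only borrows the batch_size and drop_last keyword names from PyTorch's DataLoader, and the sketch ignores any per-split behavior (such as shuffling only the training data) that the real class may add.

```python
from typing import Any, Dict, List, Optional, Sequence


def toy_loader(dataset: Sequence[Any], batch_size: int = 1,
               drop_last: bool = False) -> List[List[Any]]:
    """Hypothetical loader: split dataset into consecutive mini-batches."""
    batches = [list(dataset[i:i + batch_size])
               for i in range(0, len(dataset), batch_size)]
    if drop_last and batches and len(batches[-1]) < batch_size:
        batches.pop()  # discard the incomplete final batch
    return batches


class ToySplits:
    """Hypothetical class mirroring the LightningDataset signature."""

    def __init__(self, train_dataset: Sequence[Any],
                 val_dataset: Optional[Sequence[Any]] = None,
                 test_dataset: Optional[Sequence[Any]] = None,
                 pred_dataset: Optional[Sequence[Any]] = None,
                 **kwargs: Any) -> None:
        self.datasets = {"train": train_dataset, "val": val_dataset,
                         "test": test_dataset, "pred": pred_dataset}
        # Assumption: kwargs are applied unchanged to every split's loader.
        self.kwargs = kwargs

    def loaders(self) -> Dict[str, Optional[List[List[Any]]]]:
        # Missing (None) splits get no loader; present ones share self.kwargs.
        return {name: None if ds is None else toy_loader(ds, **self.kwargs)
                for name, ds in self.datasets.items()}


splits = ToySplits(list(range(7)), val_dataset=list(range(3)),
                   batch_size=3, drop_last=True)
result = splits.loaders()
# result["train"] == [[0, 1, 2], [3, 4, 5]], result["val"] == [[0, 1, 2]]
# result["test"] is None and result["pred"] is None
```

Collecting the loader options once in **kwargs keeps batch settings consistent across splits, at the cost of not being able to tune them per split through this constructor.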