torch_geometric.data.lightning.LightningNodeData
class LightningNodeData(data: Union[Data, HeteroData], input_train_nodes: Union[Tensor, None, str, Tuple[str, Optional[Tensor]]] = None, input_train_time: Optional[Tensor] = None, input_val_nodes: Union[Tensor, None, str, Tuple[str, Optional[Tensor]]] = None, input_val_time: Optional[Tensor] = None, input_test_nodes: Union[Tensor, None, str, Tuple[str, Optional[Tensor]]] = None, input_test_time: Optional[Tensor] = None, input_pred_nodes: Union[Tensor, None, str, Tuple[str, Optional[Tensor]]] = None, input_pred_time: Optional[Tensor] = None, loader: str = 'neighbor', node_sampler: Optional[BaseSampler] = None, eval_loader_kwargs: Optional[Dict[str, Any]] = None, **kwargs)
Bases: LightningData

Converts a Data or HeteroData object into a pytorch_lightning.LightningDataModule variant. The result can be passed directly as the datamodule for multi-GPU node-level training with PyTorch Lightning. LightningNodeData takes care of providing mini-batches via NeighborLoader (when loader="neighbor").

Note
Currently, only the pytorch_lightning.strategies.SingleDeviceStrategy and pytorch_lightning.strategies.DDPStrategy training strategies of PyTorch Lightning are supported, since these are the strategies under which data can be shared correctly across all devices/processes:

    import pytorch_lightning as pl

    # `model` is a LightningModule; `datamodule` is a LightningNodeData instance.
    trainer = pl.Trainer(strategy="ddp_spawn", accelerator="gpu", devices=4)
    trainer.fit(model, datamodule)
Parameters
data (Data or HeteroData) – The Data or HeteroData graph object.
input_train_nodes (torch.Tensor or str or (str, torch.Tensor)) – The indices (or boolean mask) of training nodes. If not given, LightningNodeData tries to infer them automatically from the data object by searching for train_mask, train_idx, or train_index attributes. (default: None)
input_train_time (torch.Tensor, optional) – The timestamps of training nodes. (default: None)
input_val_nodes (torch.Tensor or str or (str, torch.Tensor)) – The indices (or boolean mask) of validation nodes. If not given, LightningNodeData tries to infer them automatically from the data object by searching for val_mask, valid_mask, val_idx, valid_idx, val_index, or valid_index attributes. (default: None)
input_val_time (torch.Tensor, optional) – The timestamps of validation nodes. (default: None)
input_test_nodes (torch.Tensor or str or (str, torch.Tensor)) – The indices (or boolean mask) of test nodes. If not given, LightningNodeData tries to infer them automatically from the data object by searching for test_mask, test_idx, or test_index attributes. (default: None)
input_test_time (torch.Tensor, optional) – The timestamps of test nodes. (default: None)
input_pred_nodes (torch.Tensor or str or (str, torch.Tensor)) – The indices (or boolean mask) of prediction nodes. If not given, LightningNodeData tries to infer them automatically from the data object by searching for pred_mask, pred_idx, or pred_index attributes. (default: None)
input_pred_time (torch.Tensor, optional) – The timestamps of prediction nodes. (default: None)
loader (str) – The scalability technique to use ("full" or "neighbor"). (default: "neighbor")
node_sampler (BaseSampler, optional) – A custom sampler object used to generate mini-batches. If set, the loader option is ignored. (default: None)
eval_loader_kwargs (Dict[str, Any], optional) – Custom keyword arguments that override the torch_geometric.loader.NeighborLoader configuration during evaluation. (default: None)
**kwargs (optional) – Additional arguments of torch_geometric.loader.NeighborLoader.
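The automatic lookup described for the input_*_nodes parameters can be sketched in plain Python. This is a simplified model, not PyG's implementation: the helper infer_split_nodes and the dict-of-dicts graph representation are hypothetical stand-ins for Data and HeteroData, candidate attribute names are tried in the order the parameter descriptions list them (the library's own precedence may differ), and for heterogeneous graphs it assumes exactly one node type may carry the attribute, returning the (node_type, value) form that the input_*_nodes parameters accept.

```python
from typing import Any, Dict, Optional, Tuple

# Candidate attribute names per split, in the order the parameter
# descriptions list them (assumption: the real precedence may differ).
SPLIT_ATTRS: Dict[str, Tuple[str, ...]] = {
    "train": ("train_mask", "train_idx", "train_index"),
    "val": ("val_mask", "valid_mask", "val_idx", "valid_idx",
            "val_index", "valid_index"),
    "test": ("test_mask", "test_idx", "test_index"),
    "pred": ("pred_mask", "pred_idx", "pred_index"),
}


def infer_split_nodes(
    node_stores: Dict[Optional[str], Dict[str, Any]],
    split: str,
) -> Any:
    """Hypothetical helper mimicking the documented attribute search.

    `node_stores` maps each node type to its attributes. A homogeneous
    graph (like Data) is modelled as one store under the key None; a
    heterogeneous graph (like HeteroData) uses node-type names as keys.
    Returns the attribute value, a (node_type, value) tuple, or None.
    """
    for attr in SPLIT_ATTRS[split]:
        # Collect every node store that carries this attribute name.
        matches = {t: s[attr] for t, s in node_stores.items() if attr in s}
        if not matches:
            continue  # try the next candidate name
        if None in matches:
            return matches[None]  # homogeneous graph: value as-is
        if len(matches) > 1:
            # Assumption: an ambiguous split is an error, not a merge.
            raise ValueError(f"{attr!r} found on node types {sorted(matches)}")
        return next(iter(matches.items()))  # (node_type, value)
    return None  # nothing found; the split is left unset


# Usage on toy stand-ins for Data and HeteroData:
homo = {None: {"train_mask": [True, False, True], "valid_idx": [1]}}
print(infer_split_nodes(homo, "train"))  # [True, False, True]
print(infer_split_nodes(homo, "val"))    # [1]
print(infer_split_nodes(homo, "test"))   # None

hetero = {"paper": {"train_mask": [True, False]}, "author": {}}
print(infer_split_nodes(hetero, "train"))  # ('paper', [True, False])
```

Returning None when no attribute matches mirrors the "(default: None)" behavior above: the split simply stays unset rather than raising.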