torch_geometric.nn.models.SchNet
- class SchNet(hidden_channels: int = 128, num_filters: int = 128, num_interactions: int = 6, num_gaussians: int = 50, cutoff: float = 10.0, interaction_graph: Optional[Callable] = None, max_num_neighbors: int = 32, readout: str = 'add', dipole: bool = False, mean: Optional[float] = None, std: Optional[float] = None, atomref: Optional[Tensor] = None)
Bases: Module
The continuous-filter convolutional neural network SchNet from the “SchNet: A Continuous-filter Convolutional Neural Network for Modeling Quantum Interactions” paper, which uses interaction blocks of the form
\[\mathbf{x}^{\prime}_i = \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \odot h_{\mathbf{\Theta}} \left( \exp\left(-\gamma(\mathbf{e}_{j,i} - \mathbf{\mu})^2\right) \right),\]
where \(h_{\mathbf{\Theta}}\) denotes an MLP, \(\mathbf{e}_{j,i}\) denotes the interatomic distance between atoms \(j\) and \(i\), \(\mathbf{\mu}\) holds the centers of the Gaussians, and \(\gamma\) controls their width.
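To make the interaction block concrete, here is a minimal pure-Python sketch of one continuous-filter convolution, written with the squared distance in the exponent as in a standard Gaussian radial basis. It assumes that the centers \(\mu_k\) are evenly spaced on [0, cutoff], that \(\gamma = 0.5 / \text{spacing}^2\), and that \(h_{\mathbf{\Theta}}\) is a two-layer MLP with a shifted softplus activation and random, untrained weights. The helpers `gaussian_expansion`, `make_filter_net` and `cfconv` are hypothetical names for illustration, not part of `torch_geometric`, and the sketch follows only the formula above; a full implementation may add further details such as linear layers or a smooth cutoff envelope.

```python
import math
import random


def gaussian_expansion(distance, cutoff=10.0, num_gaussians=50):
    """Expand one interatomic distance into num_gaussians RBF features.

    Assumption: the centers mu_k are evenly spaced on [0, cutoff] and
    gamma = 0.5 / spacing**2, where spacing is the gap between centers.
    """
    spacing = cutoff / (num_gaussians - 1)
    gamma = 0.5 / spacing ** 2
    return [math.exp(-gamma * (distance - k * spacing) ** 2)
            for k in range(num_gaussians)]


def shifted_softplus(x):
    # Numerically stable log(1 + e^x) - log(2); zero at x = 0.
    return max(x, 0.0) + math.log1p(math.exp(-abs(x))) - math.log(2.0)


def make_filter_net(in_dim, hidden_dim, out_dim, seed=0):
    """Hypothetical h_Theta: a two-layer MLP with random, untrained weights."""
    rng = random.Random(seed)
    w1 = [[rng.gauss(0.0, 0.1) for _ in range(in_dim)] for _ in range(hidden_dim)]
    w2 = [[rng.gauss(0.0, 0.1) for _ in range(hidden_dim)] for _ in range(out_dim)]

    def filter_net(features):
        hidden = [shifted_softplus(sum(w * f for w, f in zip(row, features)))
                  for row in w1]
        return [sum(w * h for w, h in zip(row, hidden)) for row in w2]

    return filter_net


def cfconv(x, edges, filter_net, cutoff=10.0, num_gaussians=50):
    """One continuous-filter convolution.

    x: list of per-atom feature vectors (all of length num_filters).
    edges: list of (j, i, distance) triples for messages from atom j to atom i.
    """
    num_filters = len(x[0])
    out = [[0.0] * num_filters for _ in x]
    for j, i, dist in edges:
        # Filter weights depend only on the expanded distance e_{j,i}.
        w = filter_net(gaussian_expansion(dist, cutoff, num_gaussians))
        for f in range(num_filters):
            out[i][f] += x[j][f] * w[f]  # elementwise product, summed over j
    return out


x = [[1.0, 0.5, -0.2, 0.3], [0.1, -1.0, 0.7, 0.0], [0.4, 0.4, 0.4, 0.4]]
edges = [(1, 0, 1.2), (2, 0, 2.5), (0, 1, 1.2), (0, 2, 2.5)]
filter_net = make_filter_net(in_dim=50, hidden_dim=16, out_dim=4)
out = cfconv(x, edges, filter_net)
```

Because the filter is generated from the distance rather than stored per neighbor position, the same weights apply to any number of neighbors at arbitrary (continuous) distances.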
Note
For an example of using a pretrained SchNet variant, see examples/qm9_pretrained_schnet.py.
- Parameters
hidden_channels (int, optional) – Hidden embedding size. (default: 128)
num_filters (int, optional) – The number of filters to use. (default: 128)
num_interactions (int, optional) – The number of interaction blocks. (default: 6)
num_gaussians (int, optional) – The number of Gaussians \(\mu\). (default: 50)
interaction_graph (callable, optional) – The function used to compute the pairwise interaction graph and interatomic distances. If set to None, a graph is constructed based on the cutoff and max_num_neighbors properties. If provided, this function takes in pos and batch tensors and should return (edge_index, edge_weight) tensors. (default: None)
cutoff (float, optional) – Cutoff distance for interatomic interactions. (default: 10.0)
max_num_neighbors (int, optional) – The maximum number of neighbors to collect for each node within the cutoff distance. (default: 32)
readout (str, optional) – Whether to apply "add" or "mean" global aggregation. (default: "add")
dipole (bool, optional) – If set to True, will use the magnitude of the dipole moment to make the final prediction, e.g., for target 0 of torch_geometric.datasets.QM9. (default: False)
mean (float, optional) – The mean of the property to predict. (default: None)
std (float, optional) – The standard deviation of the property to predict. (default: None)
atomref (torch.Tensor, optional) – The reference of single-atom properties. Expects a vector of shape (max_atomic_number, ). (default: None)
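The interaction_graph contract (take pos and batch, return edge_index and edge_weight) can be illustrated with a plain-Python sketch of a radius graph limited by cutoff and max_num_neighbors. This is an assumption-laden stand-in, not the library's construction: `radius_interaction_graph` is a hypothetical name, it uses lists instead of torch.Tensor objects (a real callable passed to SchNet must accept and return tensors), and when more than max_num_neighbors atoms fall within the cutoff it keeps the closest ones, which may differ from how the built-in graph selects neighbors.

```python
import math


def radius_interaction_graph(pos, batch, cutoff=10.0, max_num_neighbors=32):
    """Hypothetical stand-in for the default interaction graph.

    pos: list of (x, y, z) coordinates; batch: molecule index of each atom.
    Returns (edge_index, edge_weight): edge_index = [sources, targets] and
    edge_weight holds the distance of each edge.
    """
    sources, targets, weights = [], [], []
    for i, (pos_i, mol_i) in enumerate(zip(pos, batch)):
        candidates = []
        for j, (pos_j, mol_j) in enumerate(zip(pos, batch)):
            if j == i or mol_j != mol_i:
                continue  # no self-loops and no edges between molecules
            dist = math.dist(pos_i, pos_j)
            if dist <= cutoff:
                candidates.append((dist, j))
        # Assumption: keep the closest neighbors when the limit is exceeded.
        candidates.sort()
        for dist, j in candidates[:max_num_neighbors]:
            sources.append(j)
            targets.append(i)
            weights.append(dist)
    return [sources, targets], weights


pos = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (5.0, 0.0, 0.0),
       (0.0, 0.0, 0.0), (0.5, 0.0, 0.0)]
batch = [0, 0, 0, 1, 1]
edge_index, edge_weight = radius_interaction_graph(pos, batch, cutoff=2.0)
```

Here atom 2 is farther than the cutoff from the rest of its molecule and gets no edges, while atoms 3 and 4 belong to a second molecule and never connect to the first one.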
- forward(z: Tensor, pos: Tensor, batch: Optional[Tensor] = None) → Tensor
- Parameters
z (torch.Tensor) – Atomic number of each atom with shape [num_atoms].
pos (torch.Tensor) – Coordinates of each atom with shape [num_atoms, 3].
batch (torch.Tensor, optional) – Batch indices assigning each atom to a separate molecule with shape [num_atoms]. (default: None)
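The forward inputs are flat per-atom arrays, with batch recording which molecule each atom belongs to. The plain-Python sketch below shows how several molecules could be concatenated into z, pos and batch, and how per-atom scalar outputs could then be reduced to one value per molecule using the readout, mean, std and atomref options. The helpers `collate` and `readout` are hypothetical, and the order of operations (rescale with std and mean, add atomref[z], then aggregate) is an assumption of this sketch; the dipole path is not covered. In practice, torch_geometric.loader.DataLoader builds the batch vector when it collates graphs into a mini-batch.

```python
def collate(molecules):
    """Flatten a list of (atomic_numbers, coordinates) molecules.

    Returns z, pos and the batch vector that maps every atom to its molecule.
    """
    z, pos, batch = [], [], []
    for mol_idx, (numbers, coords) in enumerate(molecules):
        z.extend(numbers)
        pos.extend(coords)
        batch.extend([mol_idx] * len(numbers))
    return z, pos, batch


def readout(h, z, batch, reduce="add", mean=None, std=None, atomref=None):
    """Turn per-atom scalar outputs h into one value per molecule.

    Assumption: de-standardize and add atomref[z] per atom, then aggregate.
    """
    values = list(h)
    if mean is not None and std is not None:
        values = [v * std + mean for v in values]
    if atomref is not None:
        values = [v + atomref[zi] for v, zi in zip(values, z)]
    num_molecules = max(batch) + 1
    sums = [0.0] * num_molecules
    counts = [0] * num_molecules
    for v, b in zip(values, batch):
        sums[b] += v
        counts[b] += 1
    if reduce == "add":
        return sums
    if reduce == "mean":
        return [s / c for s, c in zip(sums, counts)]
    raise ValueError(f"unsupported readout: {reduce!r}")


# H2 followed by H2O (coordinates in Angstrom, illustrative only).
molecules = [
    ([1, 1], [(0.0, 0.0, 0.0), (0.74, 0.0, 0.0)]),
    ([8, 1, 1], [(0.0, 0.0, 0.0), (0.96, 0.0, 0.0), (-0.24, 0.93, 0.0)]),
]
z, pos, batch = collate(molecules)
h = [0.5, 0.5, 1.0, 0.25, 0.25]  # made-up per-atom outputs
print(batch)                         # [0, 0, 1, 1, 1]
print(readout(h, z, batch, "add"))   # [1.0, 1.5]
print(readout(h, z, batch, "mean"))  # [0.5, 0.5]
```

"add" makes the prediction grow with molecule size, which suits extensive properties such as energies, while "mean" yields a size-independent average per atom.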