torch_geometric.nn.models.DimeNetPlusPlus
- class DimeNetPlusPlus(hidden_channels: int, out_channels: int, num_blocks: int, int_emb_size: int, basis_emb_size: int, out_emb_channels: int, num_spherical: int, num_radial: int, cutoff: float = 5.0, max_num_neighbors: int = 32, envelope_exponent: int = 5, num_before_skip: int = 1, num_after_skip: int = 2, num_output_layers: int = 3, act: Union[str, Callable] = 'swish')
Bases: DimeNet
The DimeNet++ from the “Fast and Uncertainty-Aware Directional Message Passing for Non-Equilibrium Molecules” paper.
DimeNetPlusPlus is an upgrade to the DimeNet model that is reported to be 8x faster and 10% more accurate than DimeNet.
- Parameters
hidden_channels (int) – Hidden embedding size.
out_channels (int) – Size of each output sample.
num_blocks (int) – Number of building blocks.
int_emb_size (int) – Size of embedding in the interaction block.
basis_emb_size (int) – Size of basis embedding in the interaction block.
out_emb_channels (int) – Size of embedding in the output block.
num_spherical (int) – Number of spherical harmonics.
num_radial (int) – Number of radial basis functions.
cutoff (float, optional) – Cutoff distance for interatomic interactions. (default: 5.0)
max_num_neighbors (int, optional) – The maximum number of neighbors to collect for each node within the cutoff distance. (default: 32)
envelope_exponent (int, optional) – Shape of the smooth cutoff. (default: 5)
num_before_skip (int, optional) – Number of residual layers in the interaction blocks before the skip connection. (default: 1)
num_after_skip (int, optional) – Number of residual layers in the interaction blocks after the skip connection. (default: 2)
num_output_layers (int, optional) – Number of linear layers for the output blocks. (default: 3)
act (str or Callable, optional) – The activation function. (default: "swish")
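To illustrate what the envelope_exponent "shape of the smooth cutoff" looks like, the sketch below evaluates the polynomial envelope described in the original DimeNet paper, u(x) = 1 - (p+1)(p+2)/2 * x^p + p(p+2) * x^(p+1) - p(p+1)/2 * x^(p+2), with x = d / cutoff. Two points are assumptions here: the mapping p = envelope_exponent + 1 follows the reference DimeNet code, and the library's internal implementation may fold in extra factors (for example the 1/d term of the radial basis). The function name polynomial_envelope is hypothetical; treat this as an illustration, not the library's code.

```python
def polynomial_envelope(d: float, cutoff: float = 5.0, envelope_exponent: int = 5) -> float:
    """Smooth cutoff u(d / cutoff) from the DimeNet paper (illustrative sketch).

    Assumption: polynomial degree p = envelope_exponent + 1, as in the
    reference DimeNet code. Hypothetical helper, not a PyG function.
    """
    x = d / cutoff
    if x >= 1.0:
        # Beyond the cutoff the interaction is switched off entirely.
        return 0.0
    p = envelope_exponent + 1
    a = -(p + 1) * (p + 2) / 2
    b = p * (p + 2)
    c = -p * (p + 1) / 2
    return 1.0 + a * x**p + b * x ** (p + 1) + c * x ** (p + 2)


# Sample the envelope at a few distances (in the same units as the cutoff).
for dist in [0.0, 1.0, 2.5, 4.0, 4.9, 5.0]:
    print(f"d = {dist:.1f}  u = {polynomial_envelope(dist):.5f}")
```

The polynomial falls from 1 at d = 0 to 0 at d = cutoff, and its first and second derivatives vanish at the cutoff, so contributions fade out smoothly instead of jumping when an atom crosses the cutoff radius. Because the leading term is x^p, a larger exponent keeps u close to 1 over more of the range and pushes the drop-off toward the cutoff.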
- forward(z: Tensor, pos: Tensor, batch: Optional[Tensor] = None) → Tensor
- Parameters
z (torch.Tensor) – Atomic number of each atom with shape [num_atoms].
pos (torch.Tensor) – Coordinates of each atom with shape [num_atoms, 3].
batch (torch.Tensor, optional) – Batch indices assigning each atom to a separate molecule with shape [num_atoms]. (default: None)
- reset_parameters()
Resets all learnable parameters of the module.