torch_geometric.nn.conv.FusedGATConv
- class FusedGATConv(*args, **kwargs)
Bases: GATConv
The fused graph attention operator from the “Understanding GNN Computational Graph: A Coordinated Computation, IO, and Memory Perspective” paper.
FusedGATConv is an optimized version of GATConv based on the dgNN package. It fuses the message passing computation into a single operator for faster execution and a lower memory footprint.
Note
This implementation is based on the dgNN package. See the dgNN project for installation instructions.
- forward(x: Tensor, csr: Tuple[Tensor, Tensor], csc: Tuple[Tensor, Tensor], perm: Tensor) → Tensor
Runs the forward pass of the module.
- Parameters
x (torch.Tensor) – The node features.
csr ((torch.Tensor, torch.Tensor)) – The CSR representation of the graph, given as a tuple (rowptr, col).
csc ((torch.Tensor, torch.Tensor)) – The CSC representation of the graph, given as a tuple (row, colptr).
perm (torch.Tensor) – Permutation tensor that maps the CSR representation to the CSC representation.
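For reference, the attention computation that the fused operator performs in one pass can be written out unfused. The sketch below is a single-head GAT layer in plain PyTorch that walks the CSC arrays (row, colptr) from the parameter list, assuming that edges colptr[j]:colptr[j + 1] are the incoming edges of target node j and that row holds their source nodes. The function name gat_reference and its arguments (weight, att_src, att_dst, negative_slope) are hypothetical; the sketch omits bias, dropout, and multiple heads, and it is neither the dgNN kernel nor GATConv's actual code.

```python
import torch
import torch.nn.functional as F


def gat_reference(x, row, colptr, weight, att_src, att_dst, negative_slope=0.2):
    """Hypothetical unfused single-head GAT layer over a CSC graph (row, colptr)."""
    h = x @ weight          # [N, F_out] projected node features
    a_src = h @ att_src     # [N] source-node part of each attention logit
    a_dst = h @ att_dst     # [N] target-node part of each attention logit
    out = torch.zeros_like(h)
    ptr = colptr.tolist()
    for j in range(len(ptr) - 1):
        src = row[ptr[j]:ptr[j + 1]]  # sources of the edges entering node j
        if src.numel() == 0:
            continue  # isolated target: leave its output at zero
        logits = F.leaky_relu(a_src[src] + a_dst[j], negative_slope)
        alpha = torch.softmax(logits, dim=0)  # normalize over incoming edges only
        out[j] = (alpha.unsqueeze(-1) * h[src]).sum(dim=0)
    return out


# Example: edges 0->1, 0->2, 1->2, 2->0 stored in CSC form (grouped by target).
row = torch.tensor([2, 0, 0, 1])
colptr = torch.tensor([0, 1, 2, 4])
x = torch.eye(3)
out = gat_reference(x, row, colptr, torch.eye(3), torch.zeros(3), torch.zeros(3))
# With zero attention vectors every incoming edge gets equal weight,
# so out[2] is the mean of h[0] and h[1].
```

An unfused version like this materializes per-edge logits and weights as separate tensors; avoiding that intermediate storage is the kind of saving the fused operator targets.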
Note
Use the to_graph_format() method to obtain the (csr, csc, perm) graph format from an existing edge_index representation.
- reset_parameters()
Resets all learnable parameters of the module.
- static to_graph_format(edge_index: Tensor, size: Optional[Tuple[int, int]] = None) → Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tensor]
Converts an edge_index representation of a graph to the input format expected by FusedGATConv.
- Parameters
edge_index (torch.Tensor) – The edge indices.
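To illustrate what the (csr, csc, perm) triple contains, here is a minimal pure-PyTorch sketch of such a conversion. The helper to_csr_csc_perm is hypothetical, not the library's to_graph_format: it assumes edge_index[0] indexes the N rows and edge_index[1] the M columns, that size=(N, M) defaults to a square graph inferred from the largest index, and it returns int64 tensors. The real method may use a different orientation, dtype, or backend.

```python
from typing import Optional, Tuple

import torch
from torch import Tensor


def to_csr_csc_perm(edge_index: Tensor, size: Optional[Tuple[int, int]] = None):
    """Hypothetical sketch: build ((rowptr, col), (row, colptr), perm)."""
    row, col = edge_index
    if size is None:
        n = int(edge_index.max()) + 1  # assumption: square graph
        size = (n, n)
    num_rows, num_cols = size

    # CSR order: sort edges by (row, col), then compress the row indices.
    csr_order = torch.argsort(row * num_cols + col)
    csr_row, csr_col = row[csr_order], col[csr_order]
    rowptr = torch.zeros(num_rows + 1, dtype=torch.long)
    rowptr[1:] = torch.cumsum(torch.bincount(csr_row, minlength=num_rows), dim=0)

    # CSC order: re-sort the CSR-ordered edges by (col, row). This argsort is
    # exactly the permutation taking CSR positions to CSC positions.
    perm = torch.argsort(csr_col * num_rows + csr_row)
    csc_row, csc_col = csr_row[perm], csr_col[perm]
    colptr = torch.zeros(num_cols + 1, dtype=torch.long)
    colptr[1:] = torch.cumsum(torch.bincount(csc_col, minlength=num_cols), dim=0)

    return (rowptr, csr_col), (csc_row, colptr), perm


edge_index = torch.tensor([[0, 0, 1, 2], [1, 2, 2, 0]])
(rowptr, col), (row, colptr), perm = to_csr_csc_perm(edge_index)
# rowptr = [0, 2, 3, 4], col = [1, 2, 2, 0]
# row = [2, 0, 0, 1], colptr = [0, 1, 2, 4], perm = [3, 0, 1, 2]

# Per-edge values stored in CSR order are moved to CSC order with perm.
csr_values = torch.tensor([10, 20, 30, 40])
csc_values = csr_values[perm]  # [40, 10, 20, 30]
```

Keeping perm alongside both layouts means per-edge quantities computed in one order can be reused in the other by a single gather, without re-sorting the edges.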