torch_geometric.nn.models.to_captum_input
- to_captum_input(x: Union[Tensor, Dict[str, Tensor]], edge_index: Union[Tensor, Dict[Tuple[str, str, str], Tensor]], mask_type: Union[str, MaskLevelType], *args)
Given x, edge_index and mask_type, converts them into a format that can be used with Captum attribution methods. Returns inputs and additional_forward_args as required by Captum's attribute functions. See to_captum_model() for example usage.
- Parameters
x (torch.Tensor or Dict[NodeType, torch.Tensor]) – The node features. For heterogeneous graphs this is a dictionary holding the node features for each node type.
edge_index (torch.Tensor or Dict[EdgeType, torch.Tensor]) – The edge indices. For heterogeneous graphs this is a dictionary holding the edge index for each edge type.
mask_type (str) – Denotes the type of mask to be created with a Captum explainer. Valid inputs are "edge", "node", and "node_and_edge".
*args – Additional forward arguments of the model being explained, which will be added to additional_forward_args.
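As a rough illustration of the conversion described above, the sketch below mimics what the returned pair could look like for a homogeneous graph. The helper sketch_to_captum_input is hypothetical and is not part of PyTorch Geometric. The tensor shapes (a leading batch dimension of 1, as Captum expects) and the ordering of additional_forward_args (features when not masked, then edge_index, then *args) are assumptions modeled on the to_captum_model() usage pattern. Heterogeneous dictionaries are deliberately not handled here.

```python
from typing import Tuple

import torch
from torch import Tensor


def sketch_to_captum_input(
    x: Tensor,
    edge_index: Tensor,
    mask_type: str,
    *args,
) -> Tuple[Tuple[Tensor, ...], tuple]:
    """Hypothetical, homogeneous-only illustration; not the PyG implementation."""
    if mask_type not in ("node", "edge", "node_and_edge"):
        raise ValueError(f"Invalid mask_type: {mask_type!r}")

    inputs = []
    additional_forward_args = []

    if mask_type in ("node", "node_and_edge"):
        # Node features become an attributable input with a batch dim of 1.
        inputs.append(x.unsqueeze(0))  # shape [1, num_nodes, num_features]
    else:
        # Edge-only masks: node features are forwarded unchanged.
        additional_forward_args.append(x)

    if mask_type in ("edge", "node_and_edge"):
        # One weight per edge, initialised to 1 (every edge fully present).
        num_edges = edge_index.size(1)
        inputs.append(
            torch.ones(1, num_edges, requires_grad=True, device=edge_index.device)
        )

    # Connectivity is never attributed, so it always goes into the extra args,
    # followed by any model-specific *args (this ordering is an assumption).
    additional_forward_args.append(edge_index)
    additional_forward_args.extend(args)
    return tuple(inputs), tuple(additional_forward_args)
```

For example, with mask_type="edge" the only attributable input is a [1, num_edges] tensor of ones, while x and edge_index travel through additional_forward_args untouched; that pair is what would then be passed to a Captum method's attribute(inputs, additional_forward_args=...) call.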