torch_geometric.nn.models.captum_output_to_dicts
captum_output_to_dicts(captum_attrs: Tuple[Tensor, ...], mask_type: Union[str, MaskLevelType], metadata: Tuple[List[str], List[Tuple[str, str, str]]])
Convert the output of Captum attribution methods, which is a tuple of attributions, to two dictionaries with node and edge attribution tensors. This function is used while explaining HeteroData objects. See to_captum_model() for example usage.
Parameters
captum_attrs (tuple[torch.Tensor]) – The output of attribution methods.
mask_type (str) – Denotes the type of mask to be created with a Captum explainer. Valid inputs are "edge", "node", and "node_and_edge":
- "edge": captum_attrs contains only edge attributions. The returned tuple has no node attributions, and an edge attribution dictionary with edge types as keys and edge mask tensors of shape [num_edges] as values.
- "node": captum_attrs contains only node attributions. The returned tuple has a node attribution dictionary with node types as keys and node mask tensors of shape [num_nodes, num_features] as values, and no edge attributions.
- "node_and_edge": captum_attrs contains node and edge attributions.
metadata (Metadata) – The metadata of the heterogeneous graph.
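The mapping described by the three mask_type values can be illustrated with a minimal sketch. This is not the library's implementation: split_attributions is a hypothetical helper, and it assumes that the attributions in captum_attrs follow the order of the node types and then the edge types in metadata. Plain strings stand in for attribution tensors so the example runs without PyTorch.

```python
from typing import Any, Dict, List, Optional, Sequence, Tuple

EdgeType = Tuple[str, str, str]
Metadata = Tuple[List[str], List[EdgeType]]


def split_attributions(
    attrs: Sequence[Any],
    mask_type: str,
    metadata: Metadata,
) -> Tuple[Optional[Dict[str, Any]], Optional[Dict[EdgeType, Any]]]:
    """Hypothetical helper mirroring the documented behavior.

    Assumption: `attrs` is ordered like the node types, followed by the
    edge types, as listed in `metadata`.
    """
    node_types, edge_types = metadata
    node_dict = None
    edge_dict = None

    if mask_type == "node":
        # Only node attributions: one entry per node type.
        node_dict = dict(zip(node_types, attrs))
    elif mask_type == "edge":
        # Only edge attributions: one entry per edge type.
        edge_dict = dict(zip(edge_types, attrs))
    elif mask_type == "node_and_edge":
        # Assumed layout: node attributions first, then edge attributions.
        n = len(node_types)
        node_dict = dict(zip(node_types, attrs[:n]))
        edge_dict = dict(zip(edge_types, attrs[n:]))
    else:
        raise ValueError(f"Unsupported mask_type: {mask_type!r}")

    return node_dict, edge_dict


# Placeholder strings stand in for attribution tensors.
metadata = (["paper", "author"], [("author", "writes", "paper")])
attrs = ("paper_attr", "author_attr", "writes_attr")
node_dict, edge_dict = split_attributions(attrs, "node_and_edge", metadata)
```

In real use the values would be torch.Tensor objects returned by a Captum attribution method, and the dictionaries would be keyed by the node and edge types of the HeteroData object being explained.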