torch_geometric.explain.metric.groundtruth_metrics
- groundtruth_metrics(pred_mask: Tensor, target_mask: Tensor, metrics: Optional[Union[str, List[str]]] = None, threshold: float = 0.5) → Union[float, Tuple[float, ...]]
Compares an explanation mask against the ground-truth explanation mask and evaluates it with the selected metrics.
- Parameters
pred_mask (torch.Tensor) – The prediction mask to evaluate.
target_mask (torch.Tensor) – The ground-truth target mask.
metrics (str or List[str], optional) – The metric(s) to compute ("accuracy", "recall", "precision", "f1_score", "auroc"). (default: ["accuracy", "recall", "precision", "f1_score", "auroc"])
threshold (float, optional) – The threshold value used to perform hard thresholding of pred_mask and target_mask. (default: 0.5)
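To make the five metric names concrete, here is a minimal pure-Python sketch of what they measure on two flattened masks. It is not PyG's implementation, and it rests on stated assumptions: both masks are binarized with `value >= threshold`, a ratio whose denominator is zero is reported as 0.0, and AUROC is computed from the raw (unthresholded) prediction scores against the binarized target. The helper names `binary_mask_metrics`, `_auroc`, and `_safe_div` are hypothetical.

```python
from typing import Dict, List, Sequence


def _safe_div(num: float, den: float) -> float:
    # Assumption: report 0.0 when a ratio is undefined (e.g. no predicted positives).
    return num / den if den else 0.0


def _auroc(scores: Sequence[float], labels: List[bool]) -> float:
    # AUROC on raw scores: fraction of (positive, negative) pairs in which the
    # positive entry scores higher; ties count as half a correct ranking.
    pos = [s for s, y in zip(scores, labels) if y]
    neg = [s for s, y in zip(scores, labels) if not y]
    if not pos or not neg:
        raise ValueError("AUROC needs at least one positive and one negative target")
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def binary_mask_metrics(pred: Sequence[float], target: Sequence[float],
                        threshold: float = 0.5) -> Dict[str, float]:
    """Hypothetical illustration of the five metrics on flat masks (not the PyG code)."""
    if len(pred) != len(target) or not pred:
        raise ValueError("masks must be non-empty and have the same number of entries")
    # Hard thresholding (assumed ">= threshold") of both masks.
    p_bin = [p >= threshold for p in pred]
    t_bin = [t >= threshold for t in target]

    # Confusion-matrix counts over all mask entries.
    tp = sum(p and t for p, t in zip(p_bin, t_bin))
    fp = sum(p and not t for p, t in zip(p_bin, t_bin))
    fn = sum(t and not p for p, t in zip(p_bin, t_bin))
    tn = len(pred) - tp - fp - fn

    precision = _safe_div(tp, tp + fp)
    recall = _safe_div(tp, tp + fn)
    return {
        "accuracy": (tp + tn) / len(pred),
        "recall": recall,
        "precision": precision,
        "f1_score": _safe_div(2 * precision * recall, precision + recall),
        "auroc": _auroc(pred, t_bin),
    }


if __name__ == "__main__":
    pred_mask = [0.9, 0.8, 0.3, 0.6, 0.1]  # soft explanation scores
    target_mask = [1, 1, 0, 0, 1]          # ground-truth mask
    scores = binary_mask_metrics(pred_mask, target_mask)
    print({k: round(v, 3) for k, v in scores.items()})
    # {'accuracy': 0.6, 'recall': 0.667, 'precision': 0.667, 'f1_score': 0.667, 'auroc': 0.667}
```

The sketch returns a dict only for readability. Judging by the return annotation Union[float, Tuple[float, ...]], the real function presumably returns a single float when one metric is requested and a tuple when several are.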