torch_geometric.nn.norm.DiffGroupNorm
- class DiffGroupNorm(in_channels: int, groups: int, lamda: float = 0.01, eps: float = 1e-05, momentum: float = 0.1, affine: bool = True, track_running_stats: bool = True)
Bases: Module
The differentiable group normalization layer from the “Towards Deeper Graph Neural Networks with Differentiable Group Normalization” paper, which normalizes node features group-wise via a learnable soft cluster assignment
\[\mathbf{S} = \text{softmax} (\mathbf{X} \mathbf{W})\]
where \(\mathbf{W} \in \mathbb{R}^{F \times G}\) denotes a trainable weight matrix, and the row-wise softmax softly assigns each node to the \(G\) clusters. Normalization is then performed group-wise via:
\[\mathbf{X}^{\prime} = \mathbf{X} + \lambda \sum_{i = 1}^G \text{BatchNorm}(\mathbf{S}[:, i] \odot \mathbf{X})\]
- Parameters
in_channels (int) – Size of each input sample \(F\).
groups (int) – The number of groups \(G\).
lamda (float, optional) – The balancing factor \(\lambda\) between input embeddings and normalized embeddings. (default: 0.01)
eps (float, optional) – A value added to the denominator for numerical stability. (default: 1e-5)
momentum (float, optional) – The value used for the running mean and running variance computation. (default: 0.1)
affine (bool, optional) – If set to True, this module has learnable affine parameters \(\gamma\) and \(\beta\). (default: True)
track_running_stats (bool, optional) – If set to True, this module tracks the running mean and variance; when set to False, it does not track such statistics and always uses batch statistics in both training and eval modes. (default: True)
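The two equations for \(\mathbf{S}\) and \(\mathbf{X}^{\prime}\) can be traced by hand with a small NumPy sketch. This is a hypothetical re-implementation for illustration, not the library's code: it assumes training mode with batch statistics only (the behavior described for track_running_stats=False) and no learnable affine parameters (equivalently \(\gamma = 1\), \(\beta = 0\)). The helper names softmax, batch_norm and diff_group_norm are illustrative.

```python
import numpy as np


def softmax(z):
    # Row-wise softmax; subtracting the row max keeps exp() numerically stable.
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def batch_norm(h, eps=1e-5):
    # Training-mode BatchNorm without affine parameters (assumption):
    # normalize every column with the batch mean and biased batch variance.
    mean = h.mean(axis=0, keepdims=True)
    var = h.var(axis=0, keepdims=True)
    return (h - mean) / np.sqrt(var + eps)


def diff_group_norm(x, W, lamda=0.01, eps=1e-5):
    # x: [N, F] node features, W: [F, G] cluster weights (hypothetical inputs).
    s = softmax(x @ W)                    # S = softmax(XW), shape [N, G]
    out = np.zeros_like(x)
    for i in range(W.shape[1]):
        out += batch_norm(s[:, i:i + 1] * x, eps)  # BatchNorm(S[:, i] ⊙ X)
    return x + lamda * out                # X' = X + λ Σ_i BatchNorm(...)


rng = np.random.default_rng(0)
N, F, G = 32, 16, 4
x = rng.standard_normal((N, F))
W = rng.standard_normal((F, G))
out = diff_group_norm(x, W, lamda=0.01)
print(out.shape)  # (32, 16)
```

In this sketch every group term is batch-normalized to zero column mean, so \(\mathbf{X}^{\prime}\) keeps the column means of \(\mathbf{X}\), and \(\lambda\) only controls how far the output moves away from the input.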
- forward(x: Tensor) → Tensor
- Parameters
x (torch.Tensor) – The source tensor.
- static group_distance_ratio(x: Tensor, y: Tensor, eps: float = 1e-05) → float
Measures the ratio of inter-group distance over intra-group distance:
\[R_{\text{Group}} = \frac{\frac{1}{(C-1)^2} \sum_{i \neq j} \frac{1}{|\mathbf{X}_i||\mathbf{X}_j|} \sum_{\mathbf{x}_{iv} \in \mathbf{X}_i } \sum_{\mathbf{x}_{jv^{\prime}} \in \mathbf{X}_j} {\| \mathbf{x}_{iv} - \mathbf{x}_{jv^{\prime}} \|}_2 }{ \frac{1}{C} \sum_{i} \frac{1}{{|\mathbf{X}_i|}^2} \sum_{\mathbf{x}_{iv}, \mathbf{x}_{iv^{\prime}} \in \mathbf{X}_i } {\| \mathbf{x}_{iv} - \mathbf{x}_{iv^{\prime}} \|}_2 }\]
where \(\mathbf{X}_i\) denotes the set of all nodes (rows of x) that belong to class \(i\) according to the labels y, and \(C\) denotes the total number of classes in y.
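The ratio can be computed term by term from the formula for \(R_{\text{Group}}\). The NumPy sketch below is a hypothetical, literal transcription (it builds a full pairwise difference tensor, so memory grows quadratically with group size) and may differ in details from the library's own group_distance_ratio. It assumes that eps is added to the intra-group denominator, which is what the eps argument in the signature suggests, and that y contains at least two classes.

```python
import numpy as np


def mean_pairwise_distance(a, b):
    """Mean Euclidean distance over all |a| * |b| row pairs of a and b."""
    diff = a[:, None, :] - b[None, :, :]          # [|a|, |b|, F]
    return float(np.linalg.norm(diff, axis=-1).mean())


def group_distance_ratio(x, y, eps=1e-5):
    """Literal transcription of R_Group; x: [N, F] features, y: [N] labels."""
    groups = [x[y == c] for c in np.unique(y)]    # X_i for each class i
    C = len(groups)                               # assumes C >= 2
    inter = sum(mean_pairwise_distance(groups[i], groups[j])
                for i in range(C) for j in range(C) if i != j) / (C - 1) ** 2
    intra = sum(mean_pairwise_distance(g, g) for g in groups) / C
    return inter / (intra + eps)                  # eps assumed to guard the denominator


x = np.array([[0.0], [1.0], [10.0], [11.0]])
r_sep = group_distance_ratio(x, np.array([0, 0, 1, 1]))    # inter 20, intra 0.5
r_mixed = group_distance_ratio(x, np.array([0, 1, 0, 1]))  # inter 11, intra 5
print(r_sep, r_mixed)
```

With well-separated classes the ratio is large (about 40 here), while mixing the labels across the two clusters drops it to about 2.2, which is the behavior the metric is meant to expose.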