torch_geometric.nn.norm.HeteroBatchNorm
- class HeteroBatchNorm(in_channels: int, num_types: int, eps: float = 1e-05, momentum: Optional[float] = 0.1, affine: bool = True, track_running_stats: bool = True)
Bases: Module

Applies batch normalization over a batch of heterogeneous features as described in the “Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift” paper. Compared to BatchNorm, HeteroBatchNorm applies normalization individually for each node or edge type, so every entry is normalized with the statistics of its own type.
- Parameters
in_channels (int) – Size of each input sample.
num_types (int) – The number of node or edge types.
eps (float, optional) – A value added to the denominator for numerical stability. (default: 1e-5)
momentum (float, optional) – The value used for the running mean and running variance computation. (default: 0.1)
affine (bool, optional) – If set to True, this module has learnable affine parameters \(\gamma\) and \(\beta\). (default: True)
track_running_stats (bool, optional) – If set to True, this module tracks the running mean and variance; when set to False, this module does not track such statistics and always uses batch statistics in both training and eval modes. (default: True)
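The interaction of `momentum` and `track_running_stats` can be sketched in plain Python. This is a minimal sketch under assumptions, not PyG's implementation: it assumes HeteroBatchNorm mirrors `torch.nn.BatchNorm1d` by keeping one running mean and variance per type (initialized to 0 and 1), updating them with an exponential moving average weighted by `momentum`, and updating only the types present in the current batch. The helper names `update_running_stats` and `stats_for_forward` are hypothetical, and a single channel is used to keep the example short.

```python
from typing import List, Sequence, Tuple

Stats = Tuple[Sequence[float], Sequence[float]]  # (mean per type, var per type)


def update_running_stats(
    running_mean: List[float],
    running_var: List[float],
    batch_mean: Sequence[float],
    batch_var: Sequence[float],
    present_types: Sequence[int],
    momentum: float = 0.1,
) -> None:
    """Hypothetical helper: in-place exponential moving average per type.

    Assumption: like torch.nn.BatchNorm1d, new = (1 - momentum) * old
    + momentum * batch. Types absent from the batch keep their old values.
    """
    for t in present_types:
        running_mean[t] = (1.0 - momentum) * running_mean[t] + momentum * batch_mean[t]
        running_var[t] = (1.0 - momentum) * running_var[t] + momentum * batch_var[t]


def stats_for_forward(
    training: bool,
    track_running_stats: bool,
    batch_stats: Stats,
    running_stats: Stats,
) -> Stats:
    """Hypothetical helper: pick which (mean, var) pair normalizes the batch."""
    # Running estimates are only used in eval mode with tracking enabled;
    # with track_running_stats=False, batch statistics are always used.
    if track_running_stats and not training:
        return running_stats
    return batch_stats


# Two types, one channel; running stats start at mean 0 and variance 1.
running_mean = [0.0, 0.0]
running_var = [1.0, 1.0]
batch_mean = [2.0, 15.0]
batch_var = [1.0, 25.0]

# Only type 0 occurs in this batch, so only its estimates move.
update_running_stats(running_mean, running_var, batch_mean, batch_var, present_types=[0])
print(running_mean, running_var)  # [0.2, 0.0] [1.0, 1.0]
```

A smaller `momentum` makes the running estimates change more slowly between batches; whether the library skips absent types exactly as this sketch does is an assumption worth checking against the source.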
- forward(x: Tensor, type_vec: Tensor) → Tensor
- Parameters
x (torch.Tensor) – The input features.
type_vec (torch.Tensor) – A vector that maps each entry (row) of x to its type index.
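What forward computes in training mode, where batch statistics are used, can be sketched in plain Python. This is an illustrative sketch, not PyG's code: it assumes the textbook BatchNorm formula \((x - \mu_t) / \sqrt{\sigma_t^2 + \epsilon} \cdot \gamma_t + \beta_t\) with biased variance, where \(t\) = type_vec[i] selects the statistics and affine parameters used for row i. The function name `hetero_batch_norm_train` and the nested lists standing in for tensors are hypothetical.

```python
import math
from typing import List, Sequence


def hetero_batch_norm_train(
    x: Sequence[Sequence[float]],
    type_vec: Sequence[int],
    num_types: int,
    gamma: Sequence[Sequence[float]],
    beta: Sequence[Sequence[float]],
    eps: float = 1e-5,
) -> List[List[float]]:
    """Hypothetical sketch: normalize each row of x with its own type's batch statistics."""
    in_channels = len(x[0])
    counts = [0] * num_types
    sums = [[0.0] * in_channels for _ in range(num_types)]
    for row, t in zip(x, type_vec):
        counts[t] += 1
        for c, v in enumerate(row):
            sums[t][c] += v
    # Per-type, per-channel mean (types absent from the batch get 0.0 and are never read).
    means = [[s / counts[t] if counts[t] else 0.0 for s in sums[t]] for t in range(num_types)]
    # Second pass: biased (population) variance per type and channel.
    sq = [[0.0] * in_channels for _ in range(num_types)]
    for row, t in zip(x, type_vec):
        for c, v in enumerate(row):
            sq[t][c] += (v - means[t][c]) ** 2
    variances = [[s / counts[t] if counts[t] else 0.0 for s in sq[t]] for t in range(num_types)]
    # Normalize, then apply the affine parameters of the row's type.
    return [
        [
            (v - means[t][c]) / math.sqrt(variances[t][c] + eps) * gamma[t][c] + beta[t][c]
            for c, v in enumerate(row)
        ]
        for row, t in zip(x, type_vec)
    ]


# Rows 0-1 belong to type 0, rows 2-3 to type 1; one channel.
x = [[1.0], [3.0], [10.0], [20.0]]
type_vec = [0, 0, 1, 1]
gamma = [[1.0], [1.0]]  # identity affine transform
beta = [[0.0], [0.0]]
out = hetero_batch_norm_train(x, type_vec, num_types=2, gamma=gamma, beta=beta)
print([round(row[0], 3) for row in out])  # [-1.0, 1.0, -1.0, 1.0]
```

The two types here live on very different scales (means 2 and 15), yet each is mapped to roughly zero mean and unit variance on its own. A single shared BatchNorm would instead normalize all four rows with their pooled statistics.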