torch_geometric.transforms.RemoveTrainingClasses
class RemoveTrainingClasses(classes: List[int])
Bases: BaseTransform

Removes classes from the node-level training set given by data.train_mask, e.g., to obtain a zero-shot label scenario (functional name: remove_training_classes). Every node whose label in data.y belongs to one of the given classes is excluded from data.train_mask.

Parameters
classes (List[int]) – The classes to remove from the training set.
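The masking rule can be sketched in plain Python, without PyTorch, to make its effect concrete. This is a minimal sketch under stated assumptions: `remove_training_classes` is a hypothetical helper written for illustration (not the library's implementation), and node labels `y` and the boolean `train_mask` are assumed to be plain lists aligned by node index.

```python
from typing import List, Sequence


def remove_training_classes(
    y: Sequence[int], train_mask: Sequence[bool], classes: List[int]
) -> List[bool]:
    """Hypothetical sketch: return a new training mask in which every node
    whose label is in `classes` is set to False."""
    removed = set(classes)
    # Build a new list instead of mutating, so the input mask stays intact.
    return [keep and label not in removed for label, keep in zip(y, train_mask)]


# Six nodes with labels 0..2; the first four are in the training set.
y = [0, 1, 2, 0, 1, 2]
train_mask = [True, True, True, True, False, False]

new_mask = remove_training_classes(y, train_mask, classes=[1])
print(new_mask)  # [True, False, True, True, False, False]
```

In PyG itself the transform works on tensors: assuming `import torch_geometric.transforms as T` and a `Data` object with `y` and `train_mask`, `data = T.RemoveTrainingClasses([1])(data)` applies the same rule. Only the training mask is touched; nodes of the removed classes can still appear in the validation or test masks, which is what produces the zero-shot setting.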