torch_geometric.transforms.RandomNodeSplit
class RandomNodeSplit(split: str = 'train_rest', num_splits: int = 1, num_train_per_class: int = 20, num_val: Union[int, float] = 500, num_test: Union[int, float] = 1000, key: Optional[str] = 'y')
Bases: BaseTransform

Performs a node-level random split by adding train_mask, val_mask and test_mask attributes to the Data or HeteroData object (functional name: random_node_split).

Parameters:
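The following plain-Python sketch gives a rough idea of what the three masks look like under the default "train_rest" strategy. Validation and test nodes are drawn from one random permutation and every remaining node becomes a training node, so each node lands in exactly one mask. This is not the PyG implementation: it uses lists of booleans instead of torch.bool tensors, and the helper names train_rest_split and resolve_count, the seed argument, and rounding float ratios to the nearest integer are all assumptions made for illustration.

```python
import random
from typing import List, Tuple, Union


def resolve_count(n: Union[int, float], num_nodes: int) -> int:
    """Interpret an int as an absolute count and a float as a ratio of num_nodes.

    Assumption: ratios are rounded to the nearest integer.
    """
    return round(num_nodes * n) if isinstance(n, float) else n


def train_rest_split(
    num_nodes: int,
    num_val: Union[int, float] = 500,
    num_test: Union[int, float] = 1000,
    seed: int = 0,
) -> Tuple[List[bool], List[bool], List[bool]]:
    """Hypothetical "train_rest" sketch: sample val and test, train on the rest."""
    n_val = resolve_count(num_val, num_nodes)
    n_test = resolve_count(num_test, num_nodes)

    # One random permutation of node indices, consumed front to back.
    perm = list(range(num_nodes))
    random.Random(seed).shuffle(perm)

    train_mask = [False] * num_nodes
    val_mask = [False] * num_nodes
    test_mask = [False] * num_nodes
    for i in perm[:n_val]:
        val_mask[i] = True
    for i in perm[n_val:n_val + n_test]:
        test_mask[i] = True
    for i in perm[n_val + n_test:]:  # everything not yet used goes to training
        train_mask[i] = True
    return train_mask, val_mask, test_mask


train, val, test = train_rest_split(num_nodes=100, num_val=0.1, num_test=0.2)
print(sum(train), sum(val), sum(test))  # 70 10 20
```

With PyG installed, the transform itself is applied as RandomNodeSplit(num_val=0.1, num_test=0.2)(data), which returns the graph with data.train_mask, data.val_mask and data.test_mask attached.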
- split (str, optional) – The type of dataset split ("train_rest", "test_rest", "random"). If set to "train_rest", all nodes except those in the validation and test sets will be used for training (as in the “FastGCN: Fast Learning with Graph Convolutional Networks via Importance Sampling” paper). If set to "test_rest", all nodes except those in the training and validation sets will be used for testing (as in the “Pitfalls of Graph Neural Network Evaluation” paper). If set to "random", the train, validation, and test sets will be randomly generated according to num_train_per_class, num_val and num_test (as in the “Semi-supervised Classification with Graph Convolutional Networks” paper). (default: "train_rest")
- num_splits (int, optional) – The number of splits to add. If greater than 1, the masks will have shape [num_nodes, num_splits]; otherwise they will have shape [num_nodes]. (default: 1)
- num_train_per_class (int, optional) – The number of training samples per class in case of a "test_rest" or "random" split. (default: 20)
- num_val (int or float, optional) – The number of validation samples. If a float, it represents the ratio of samples to include in the validation set. (default: 500)
- num_test (int or float, optional) – The number of test samples in case of a "train_rest" or "random" split. If a float, it represents the ratio of samples to include in the test set. (default: 1000)
- key (str, optional) – The name of the attribute holding ground-truth labels. By default, node-level splits will only be added to node-level storages in which key is present. (default: "y")
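The "random" and "test_rest" strategies first pick num_train_per_class training nodes for every class in the labels and only then split the leftover nodes. The plain-Python sketch below mimics that description. The function name per_class_split, the seed argument, taking every node of a class that is smaller than num_train_per_class, and drawing validation nodes before test nodes from one shuffled remainder are assumptions of this sketch, not details taken from the PyG source. For brevity it accepts only integer counts for num_val and num_test.

```python
import random
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple


def per_class_split(
    y: Sequence[int],
    split: str = "random",
    num_train_per_class: int = 20,
    num_val: int = 500,
    num_test: int = 1000,
    seed: int = 0,
) -> Tuple[List[bool], List[bool], List[bool]]:
    """Hypothetical sketch of the "random" and "test_rest" strategies."""
    if split not in ("random", "test_rest"):
        raise ValueError(f"unsupported split: {split!r}")
    rng = random.Random(seed)
    num_nodes = len(y)
    train_mask = [False] * num_nodes
    val_mask = [False] * num_nodes
    test_mask = [False] * num_nodes

    # Group node indices by label and draw num_train_per_class nodes per class
    # (all of them if a class is smaller; an assumption of this sketch).
    by_class: Dict[int, List[int]] = defaultdict(list)
    for node, label in enumerate(y):
        by_class[label].append(node)
    for nodes in by_class.values():
        for node in rng.sample(nodes, min(num_train_per_class, len(nodes))):
            train_mask[node] = True

    # Shuffle the non-training nodes; validation nodes are taken first.
    remaining = [i for i in range(num_nodes) if not train_mask[i]]
    rng.shuffle(remaining)
    for i in remaining[:num_val]:
        val_mask[i] = True

    # "test_rest" tests on everything left; "random" takes only num_test nodes.
    if split == "test_rest":
        test_nodes = remaining[num_val:]
    else:
        test_nodes = remaining[num_val:num_val + num_test]
    for i in test_nodes:
        test_mask[i] = True
    return train_mask, val_mask, test_mask


labels = [0] * 50 + [1] * 50 + [2] * 50
train, val, test = per_class_split(labels, "random", num_train_per_class=20,
                                   num_val=30, num_test=40)
print(sum(train), sum(val), sum(test))  # 60 30 40
```

Under "random", nodes that fit in none of the three budgets stay unassigned (20 of the 150 here), whereas "test_rest" would put those 20 nodes into the test set as well.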
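When num_splits is greater than 1, each mask gains a second dimension with one column per split. The sketch below builds several independent random validation masks and lays them out as a [num_nodes, num_splits] nested list, returning a flat [num_nodes] list when num_splits is 1. The helpers random_val_mask and stack_splits are hypothetical, and only the validation mask is shown; under this assumption the train and test masks would follow the same layout.

```python
import random
from typing import List, Union


def random_val_mask(num_nodes: int, num_val: int, rng: random.Random) -> List[bool]:
    """Hypothetical helper: one random validation mask of length num_nodes."""
    chosen = set(rng.sample(range(num_nodes), num_val))
    return [i in chosen for i in range(num_nodes)]


def stack_splits(
    num_nodes: int, num_val: int, num_splits: int, seed: int = 0
) -> Union[List[bool], List[List[bool]]]:
    """Lay out num_splits independent masks as [num_nodes, num_splits],
    or as a flat [num_nodes] list when num_splits == 1."""
    rng = random.Random(seed)
    columns = [random_val_mask(num_nodes, num_val, rng) for _ in range(num_splits)]
    if num_splits == 1:
        return columns[0]
    # Transpose so that row i holds node i's membership in every split.
    return [list(row) for row in zip(*columns)]


masks = stack_splits(num_nodes=10, num_val=3, num_splits=4)
print(len(masks), len(masks[0]))  # 10 4
second_split = [row[1] for row in masks]  # column 1 = the second split
print(sum(second_split))  # 3
```

Selecting one column (mask[:, k] on a tensor) gives the k-th split, which is convenient for averaging results over several random splits. With tensors, torch.stack(per_split_masks, dim=-1) would produce the same layout.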
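For a HeteroData object, key decides which node types receive masks at all. The sketch below models node-level storages as plain dictionaries, a hypothetical stand-in for PyG's storage objects, and returns only the node types whose storage holds the label attribute named by key.

```python
from typing import Any, Dict, List


def stores_to_split(node_stores: Dict[str, Dict[str, Any]], key: str = "y") -> List[str]:
    """Return the node types whose storage holds `key` (these would get masks)."""
    return [node_type for node_type, store in node_stores.items() if key in store]


# Hypothetical heterogeneous graph in which only "paper" nodes are labeled.
node_stores = {
    "paper": {"num_nodes": 4, "y": [0, 1, 1, 0]},
    "author": {"num_nodes": 3},
}
print(stores_to_split(node_stores))  # ['paper']
```

Passing a different key, for example key="label", would make the same check look for that attribute instead.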