torch_geometric.nn.aggr.QuantileAggregation
- class QuantileAggregation(q: Union[float, List[float]], interpolation: str = 'linear', fill_value: float = 0.0)
Bases: Aggregation

An aggregation operator that returns the feature-wise \(q\)-th quantile of a set \(\mathcal{X}\). That is, for every feature \(d\), it computes

\[
{\mathrm{Q}_q(\mathcal{X})}_d = \begin{cases}
x_{\pi_i,d} & i = q \cdot n, \\
f(x_{\pi_i,d}, x_{\pi_{i+1},d}) & i < q \cdot n < i + 1,
\end{cases}
\]

where \(n = |\mathcal{X}|\), \(x_{\pi_1,d} \le \dots \le x_{\pi_i,d} \le \dots \le x_{\pi_n,d}\) are the values of feature \(d\) sorted in ascending order, and \(f(a, b)\) is the interpolation function selected by interpolation.
- Parameters
q (float or list) – The quantile value(s) \(q\). Can be a scalar or a list of scalars in the range \([0, 1]\). If more than one quantile is passed, the results are concatenated.
interpolation (str) – Interpolation method applied if the quantile point \(q \cdot n\) lies between two values \(a \le b\). Can be one of the following:
  - "lower": Returns the lower value, \(a\).
  - "higher": Returns the higher value, \(b\).
  - "midpoint": Returns the average of the two values, \((a + b)/2\).
  - "nearest": Returns the value whose index is nearest to the quantile point.
  - "linear": Returns a linear combination of the two values, defined as \(f(a, b) = a + (b - a)\cdot(q\cdot n - i)\).
(default: "linear")
fill_value (float, optional) – The value returned for an index that has no entries assigned to it. (default: 0.0)
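To make the five interpolation modes concrete, here is a minimal pure-Python sketch for a single feature. The helper name `quantile` is hypothetical and not part of PyG. It assumes the 0-based quantile point \(h = q \cdot (n - 1)\), the convention used by `torch.quantile`, which indexes the sorted values differently from the 1-based notation of the formula above; treat the exact indexing as an assumption.

```python
import math
from typing import List


def quantile(values: List[float], q: float, interpolation: str = "linear",
             fill_value: float = 0.0) -> float:
    """Hypothetical reference helper: q-th quantile of a list of floats.

    Assumption: uses the 0-based quantile point h = q * (n - 1),
    as torch.quantile does.
    """
    if not 0.0 <= q <= 1.0:
        raise ValueError("q must lie in [0, 1]")
    if not values:
        # Empty set: return fill_value, mirroring the fill_value parameter.
        return fill_value
    xs = sorted(values)              # ascending order: x_pi_1 <= ... <= x_pi_n
    h = q * (len(xs) - 1)            # fractional quantile point (0-based)
    lo, hi = math.floor(h), math.ceil(h)
    a, b = xs[lo], xs[hi]            # the two neighbouring values, a <= b
    if interpolation == "lower":
        return a
    if interpolation == "higher":
        return b
    if interpolation == "midpoint":
        return 0.5 * (a + b)
    if interpolation == "nearest":
        # Python's round() breaks exact .5 ties toward the even index.
        return xs[round(h)]
    if interpolation == "linear":
        # a + (b - a) * fractional part of the quantile point
        return a + (b - a) * (h - lo)
    raise ValueError(f"unknown interpolation: {interpolation!r}")


values = [4.0, 1.0, 3.0, 2.0]
for mode in ["lower", "higher", "midpoint", "nearest", "linear"]:
    print(mode, quantile(values, 0.4, mode))
```

With four values and \(q = 0.4\), the quantile point falls between the second and third sorted values (2.0 and 3.0), so the modes differ only in how they pick or blend those two neighbours.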
- forward(x: Tensor, index: Optional[Tensor] = None, ptr: Optional[Tensor] = None, dim_size: Optional[int] = None, dim: int = -2) → Tensor
- Parameters
x (torch.Tensor) – The source tensor.
index (torch.Tensor, optional) – The indices of elements for applying the aggregation, i.e., the output entry that each element of x along dim is assigned to. One of index or ptr must be defined. (default: None)
ptr (torch.Tensor, optional) – If given, computes the aggregation based on sorted inputs in CSR representation. One of index or ptr must be defined. (default: None)
dim_size (int, optional) – The size of the output tensor at dimension dim after aggregation. (default: None)
dim (int, optional) – The dimension in which to aggregate. (default: -2)
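The sketch below mimics in pure Python how the forward arguments fit together for a 2-D input aggregated over rows (dim=-2): rows are grouped by index (or by CSR offsets ptr, converted to an index first), dim_size fixes the number of output rows, empty groups receive fill_value, and the results for several quantiles are concatenated per group. All names are hypothetical. It implements only linear interpolation, assumes the 0-based \(q \cdot (n - 1)\) convention, and the q-major ordering of the concatenated features is an assumption, not a statement about PyG's actual output layout.

```python
import math
from typing import List, Optional, Sequence


def _linear_quantile(sorted_vals: List[float], q: float) -> float:
    # Linear interpolation at the 0-based point h = q * (n - 1) (assumed convention).
    h = q * (len(sorted_vals) - 1)
    lo, hi = math.floor(h), math.ceil(h)
    return sorted_vals[lo] + (sorted_vals[hi] - sorted_vals[lo]) * (h - lo)


def ptr_to_index(ptr: Sequence[int]) -> List[int]:
    # Hypothetical helper: CSR offsets [0, 2, 5] -> group index [0, 0, 1, 1, 1].
    return [g for g in range(len(ptr) - 1) for _ in range(ptr[g + 1] - ptr[g])]


def quantile_aggregate(x: List[List[float]], qs: Sequence[float],
                       index: Optional[Sequence[int]] = None,
                       ptr: Optional[Sequence[int]] = None,
                       dim_size: Optional[int] = None,
                       fill_value: float = 0.0) -> List[List[float]]:
    """Hypothetical sketch: per-group, per-feature quantiles of the rows of x."""
    if index is None:
        if ptr is None:
            raise ValueError("one of index or ptr must be defined")
        index = ptr_to_index(ptr)
        if dim_size is None:
            dim_size = len(ptr) - 1          # one output row per CSR segment
    if len(index) != len(x):
        raise ValueError("index must have one entry per row of x")
    num_groups = dim_size if dim_size is not None else max(index, default=-1) + 1
    num_feats = len(x[0]) if x else 0

    # Bucket the rows by their group index.
    groups: List[List[List[float]]] = [[] for _ in range(num_groups)]
    for row, g in zip(x, index):
        groups[g].append(row)

    out = []
    for rows in groups:
        result = []
        for q in qs:                          # concatenate results q by q
            for d in range(num_feats):
                col = sorted(r[d] for r in rows)
                # Empty group: use fill_value instead of a quantile.
                result.append(_linear_quantile(col, q) if col else fill_value)
        out.append(result)
    return out


x = [[1.0, 10.0], [3.0, 30.0], [2.0, 20.0], [5.0, 50.0]]
index = [0, 0, 0, 1]
print(quantile_aggregate(x, qs=[0.5], index=index, dim_size=3))
print(quantile_aggregate(x, qs=[0.5], ptr=[0, 3, 4, 4]))
print(quantile_aggregate(x, qs=[0.0, 1.0], index=index))
```

Here the third output row exists only because dim_size (or the last CSR segment) requests it, so it is filled with fill_value. A tensor implementation would sort along dim and gather instead of looping; the loops are kept for readability.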