# Source code for torch_geometric.loader.prefetch

from contextlib import nullcontext
from functools import partial
from typing import Any, Optional

import torch
from torch.utils.data import DataLoader


class PrefetchLoader:
    r"""A GPU prefetcher class for asynchronously transferring data of a
    :class:`torch.utils.data.DataLoader` from host memory to device memory.

    While the consumer works on batch :obj:`i`, batch :obj:`i + 1` is already
    being copied to the device on a side CUDA stream, overlapping transfer
    with computation. On CPU-only setups the loader degrades gracefully to a
    plain pass-through iterator.

    Args:
        loader (torch.utils.data.DataLoader): The data loader.
        device (torch.device, optional): The device to load the data to.
            (default: :obj:`None`)
    """
    def __init__(
        self,
        loader: DataLoader,
        device: Optional[torch.device] = None,
    ):
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'

        self.loader = loader
        self.device = torch.device(device)
        # Asynchronous prefetching only makes sense when CUDA is present AND
        # the target device is a CUDA device; otherwise transfers are no-ops.
        self.is_cuda = torch.cuda.is_available() and self.device.type == 'cuda'

    def non_blocking_transfer(self, batch: Any) -> Any:
        """Recursively move ``batch`` to :obj:`self.device` without blocking.

        Containers (lists, tuples, dicts) are traversed element-wise; leaf
        objects are pinned first so that the host-to-device copy can be
        truly asynchronous. Returns ``batch`` unchanged on non-CUDA setups.
        """
        if not self.is_cuda:
            return batch
        if isinstance(batch, (list, tuple)):
            # NOTE: tuples are deliberately returned as lists (matches the
            # upstream behavior callers rely on).
            return [self.non_blocking_transfer(v) for v in batch]
        if isinstance(batch, dict):
            return {
                k: self.non_blocking_transfer(v)
                for k, v in batch.items()
            }
        # Pinning host memory is required for `non_blocking=True` to
        # actually overlap the copy with computation.
        batch = batch.pin_memory()
        return batch.to(self.device, non_blocking=True)

    def __iter__(self) -> Any:
        first = True
        if self.is_cuda:
            # All prefetch copies are enqueued on a dedicated side stream so
            # they overlap with work on the default (compute) stream.
            stream = torch.cuda.Stream()
            stream_context = partial(torch.cuda.stream, stream=stream)
        else:
            stream = None
            stream_context = nullcontext

        batch = None
        for next_batch in self.loader:
            with stream_context():
                next_batch = self.non_blocking_transfer(next_batch)

            if not first:
                yield batch  # noqa
            else:
                first = False

            if stream is not None:
                # Make the compute stream wait until the copy of
                # `next_batch` has finished before it may be consumed.
                torch.cuda.current_stream().wait_stream(stream)

            batch = next_batch

        # Yield the last prefetched batch — but only if the loader produced
        # anything at all (an empty loader must yield nothing, not raise).
        if not first:
            yield batch

    def __len__(self) -> int:
        return len(self.loader)

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}({self.loader})'