Spaces:
Sleeping
Sleeping
from collections.abc import Mapping, Sequence | |
from typing import List, Optional, Union | |
import torch.utils.data | |
from torch.utils.data.dataloader import default_collate | |
from torch_geometric.data import Batch, Dataset | |
from torch_geometric.data.data import BaseData | |
class Collater:
    r"""Callable that merges a list of samples into one mini-batch.

    Entries that are :obj:`None` are dropped before merging. How the
    remaining samples are merged is decided by the type of the first one.

    Args:
        follow_batch: Passed through unchanged as the second argument of
            :meth:`Batch.from_data_list`.
        exclude_keys: Passed through unchanged as the third argument of
            :meth:`Batch.from_data_list`.
    """

    def __init__(self, follow_batch, exclude_keys):
        self.follow_batch = follow_batch
        self.exclude_keys = exclude_keys

    def __call__(self, batch):
        r"""Collate :obj:`batch` into a single object.

        Args:
            batch: An iterable of samples; :obj:`None` entries are ignored.

        Returns:
            A :class:`Batch` for :class:`BaseData` samples, a tensor for
            tensors, floats and ints, the plain list for strings, and a
            recursively collated dict / namedtuple / list for mappings,
            namedtuples and other sequences.

        Raises:
            IndexError: If no sample is left after dropping :obj:`None`
                entries.
            TypeError: If the first sample has an unsupported type.
        """
        batch = [x for x in batch if x is not None]
        if not batch:
            # Same exception type that `batch[0]` would raise on an empty
            # list, but with a message that points at the actual cause.
            raise IndexError(
                'DataLoader found an empty batch: no samples left to '
                'collate after dropping None entries')

        elem = batch[0]
        if isinstance(elem, BaseData):
            return Batch.from_data_list(batch, self.follow_batch,
                                        self.exclude_keys)
        elif isinstance(elem, torch.Tensor):
            return default_collate(batch)
        elif isinstance(elem, float):
            return torch.tensor(batch, dtype=torch.float)
        elif isinstance(elem, int):
            return torch.tensor(batch)
        elif isinstance(elem, str):
            return batch
        elif isinstance(elem, Mapping):
            # Collate field-by-field, using the keys of the first sample.
            return {key: self([data[key] for data in batch]) for key in elem}
        elif isinstance(elem, tuple) and hasattr(elem, '_fields'):
            # Namedtuple: collate each field and rebuild the same type.
            return type(elem)(*(self(s) for s in zip(*batch)))
        elif isinstance(elem, Sequence):
            # Strings never reach this branch; they are returned above.
            return [self(s) for s in zip(*batch)]

        raise TypeError(f'DataLoader found invalid type: {type(elem)}')

    def collate(self, batch):  # Deprecated...
        r"""Deprecated alias that forwards to :meth:`__call__`."""
        return self(batch)
class DataLoader(torch.utils.data.DataLoader):
    r"""Mini-batch loader for data objects of a
    :class:`torch_geometric.data.Dataset`.

    Samples may be of type :class:`~torch_geometric.data.Data` or
    :class:`~torch_geometric.data.HeteroData`. Batches are always assembled
    by a :class:`Collater`; a user-supplied :obj:`collate_fn` is discarded.

    Args:
        dataset (Dataset): The dataset from which to load the data.
        batch_size (int, optional): How many samples per batch to load.
            (default: :obj:`1`)
        shuffle (bool, optional): If set to :obj:`True`, the data will be
            reshuffled at every epoch. (default: :obj:`False`)
        follow_batch (List[str], optional): Creates assignment batch
            vectors for each key in the list. (default: :obj:`None`)
        exclude_keys (List[str], optional): Will exclude each key in the
            list. (default: :obj:`None`)
        **kwargs (optional): Additional arguments of
            :class:`torch.utils.data.DataLoader`.
    """

    def __init__(
        self,
        dataset: Union[Dataset, List[BaseData]],
        batch_size: int = 1,
        shuffle: bool = False,
        follow_batch: Optional[List[str]] = None,
        exclude_keys: Optional[List[str]] = None,
        **kwargs,
    ):
        # Any `collate_fn` passed by the caller is ignored in favour of the
        # Collater built below.
        kwargs.pop('collate_fn', None)

        # Kept as attributes for PyTorch Lightning:
        self.follow_batch = follow_batch
        self.exclude_keys = exclude_keys

        collater = Collater(follow_batch, exclude_keys)
        super().__init__(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=collater,
            **kwargs,
        )
def collate_fn(data_list):
    """Return a new list holding the entries of *data_list* that are not None."""
    return list(filter(lambda item: item is not None, data_list))
class DataListLoader(torch.utils.data.DataLoader):
    r"""Data loader whose batches are plain Python lists of samples.

    Batches are always assembled by the module-level :func:`collate_fn`,
    which drops :obj:`None` entries and returns the remaining samples as a
    list. A user-supplied :obj:`collate_fn` keyword argument is discarded.

    Args:
        dataset (Dataset): The dataset from which to load the data.
        batch_size (int, optional): How many samples per batch to load.
            (default: :obj:`1`)
        shuffle (bool, optional): Forwarded to
            :class:`torch.utils.data.DataLoader`. (default: :obj:`False`)
        **kwargs (optional): Additional arguments of
            :class:`torch.utils.data.DataLoader`.
    """

    def __init__(
        self,
        dataset: Union[Dataset, List[BaseData]],
        batch_size: int = 1,
        shuffle: bool = False,
        **kwargs,
    ):
        # Any `collate_fn` passed by the caller is ignored.
        kwargs.pop('collate_fn', None)
        super().__init__(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=collate_fn,
            **kwargs,
        )