Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# Collate extensions
# --------------------------------------------------------
import torch | |
import collections | |
from torch.utils.data._utils.collate import default_collate_fn_map, default_collate_err_msg_format | |
from typing import Callable, Dict, Optional, Tuple, Type, Union, List | |
def cat_collate_tensor_fn(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None):
    """Collate a batch of tensors by concatenating them along dimension 0.

    Args:
        batch: sequence of tensors, forwarded as-is to ``torch.cat``.
        collate_fn_map: unused by this function. It is accepted (and now
            optional, like in the sibling collate functions) so the function
            can be called with the same keyword as every other entry of a
            collate registry.

    Returns:
        The result of ``torch.cat(batch, dim=0)``.
    """
    return torch.cat(batch, dim=0)
def cat_collate_list_fn(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None):
    """Collate a batch of lists by joining them end to end into one flat list.

    Args:
        batch: iterable of per-sample iterables (typically lists).
        collate_fn_map: unused; accepted so the signature matches the other
            collate functions.

    Returns:
        A new list holding the items of every entry of ``batch``, in order.
    """
    flattened = []
    for chunk in batch:
        # Append the items of this sample after those already gathered.
        flattened.extend(chunk)
    return flattened
# Type -> collate-function registry: a copy of PyTorch's default registry in
# which tensors and lists are concatenated and None values collapse to None.
cat_collate_fn_map = default_collate_fn_map.copy()
cat_collate_fn_map.update({
    torch.Tensor: cat_collate_tensor_fn,
    # NOTE(review): `List` is the typing alias, not the builtin `list`, so it
    # presumably never equals `type(elem)` and lists are matched through the
    # isinstance fallback of `cat_collate` - confirm before changing the key.
    List: cat_collate_list_fn,
    type(None): lambda _, **kw: None,  # When some Nones, simply return a single None
})
def cat_collate(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None):
    r"""Recursively collate ``batch``, dispatching on the type of its first element.

    When ``collate_fn_map`` is given, the function registered for the element
    type (exact match first, then the first ``isinstance`` match in registry
    order) collates the whole batch. Otherwise containers are traversed:
    mappings are collated key by key, namedtuples field by field, and other
    sequences position by position, each time by calling ``cat_collate`` again
    with the same ``collate_fn_map``. What happens to leaf values (e.g.
    concatenating tensors, handling ``None``) is decided entirely by the
    functions in ``collate_fn_map``.

    Args:
        batch: non-empty sequence of samples; ``batch[0]`` selects the branch.
        collate_fn_map: optional registry mapping a type (or tuple of types)
            to a callable invoked as ``fn(batch, collate_fn_map=collate_fn_map)``.

    Returns:
        The collated structure. Mappings and non-tuple sequences are rebuilt
        with their original type when that type accepts the collated
        dict/list as its single constructor argument, and are returned as a
        plain ``dict``/``list`` otherwise. Plain tuples are returned as lists,
        namedtuples keep their type.

    Raises:
        TypeError: if the element type is neither registered nor a mapping or
            sequence.
        IndexError: if ``batch`` is empty (from ``batch[0]``).

    Note:
        Sequences are transposed with ``zip``, so samples of unequal length
        are silently truncated to the shortest one.
    """
    elem = batch[0]
    elem_type = type(elem)

    if collate_fn_map is not None:
        # Exact type hit first; otherwise the first registered type the
        # element is an instance of, in registry iteration order.
        if elem_type in collate_fn_map:
            return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map)
        for collate_type in collate_fn_map:
            if isinstance(elem, collate_type):
                return collate_fn_map[collate_type](batch, collate_fn_map=collate_fn_map)

    if isinstance(elem, collections.abc.Mapping):
        # Collate once, outside the try: only the container construction may
        # legitimately fail with TypeError. A TypeError coming from the nested
        # collation must propagate directly instead of triggering a second
        # full collation of the same data.
        collated = {key: cat_collate([d[key] for d in batch], collate_fn_map=collate_fn_map) for key in elem}
        try:
            return elem_type(collated)
        except TypeError:
            # The mapping type may not support `__init__(iterable)`.
            return collated

    if isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(cat_collate(samples, collate_fn_map=collate_fn_map) for samples in zip(*batch)))

    if isinstance(elem, collections.abc.Sequence):
        collated = [cat_collate(samples, collate_fn_map=collate_fn_map) for samples in zip(*batch)]
        if isinstance(elem, tuple):
            # Backwards compatibility: plain tuples come back as lists.
            return collated
        try:
            return elem_type(collated)
        except TypeError:
            # The sequence type may not support `__init__(iterable)` (e.g., `range`).
            return collated

    raise TypeError(default_collate_err_msg_format.format(elem_type))