Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
from typing import Dict, List | |
import torch | |
from mmpretrain.registry import MODELS | |
from mmpretrain.structures import DataSample | |
from .base import BaseSelfSupervisor | |
@MODELS.register_module()
class SwAV(BaseSelfSupervisor):
    """SwAV.

    Implementation of `Unsupervised Learning of Visual Features by Contrasting
    Cluster Assignments <https://arxiv.org/abs/2006.09882>`_.

    The queue is built in ``mmpretrain/engine/hooks/swav_hook.py``.
    """

    def loss(self, inputs: List[torch.Tensor], data_samples: List[DataSample],
             **kwargs) -> Dict[str, torch.Tensor]:
        """Forward computation during training.

        Consecutive crops whose last dimension has the same size are
        concatenated along dim 0 and sent through the backbone in a single
        call. The list of per-group backbone outputs is passed to the neck,
        and the neck output is scored by the head.

        Args:
            inputs (List[torch.Tensor]): The input images, one tensor per
                crop view. Only *consecutive* crops with an equal last
                dimension are batched together, so crops of the same
                resolution should be adjacent in this list.
            data_samples (List[DataSample]): All elements required
                during the forward function. Not read by this method.

        Returns:
            Dict[str, torch.Tensor]: A dictionary of loss components; the
            head loss is stored under the key ``'loss'``.
        """
        assert isinstance(inputs, list)
        # Group consecutive crops by their last dimension (presumably the
        # width): ``unique_consecutive`` yields the length of each run of
        # equal sizes, and the cumulative sum turns those run lengths into
        # exclusive end indices into ``inputs``. ``tolist()`` converts them to
        # plain Python ints so they can be used directly as slice bounds.
        crop_sizes = torch.tensor([img.shape[-1] for img in inputs])
        _, run_lengths = torch.unique_consecutive(
            crop_sizes, return_counts=True)
        idx_crops = torch.cumsum(run_lengths, 0).tolist()

        # multi-res forward passes: one backbone call per resolution group
        start_idx = 0
        output = []
        for end_idx in idx_crops:
            _out = self.backbone(torch.cat(inputs[start_idx:end_idx]))
            output.append(_out)
            start_idx = end_idx
        output = self.neck(output)
        loss = self.head.loss(output)
        losses = dict(loss=loss)
        return losses