File size: 1,418 Bytes
0b4516f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Union

import torch.nn as nn
from torch import Tensor

from mmocr.models.textrecog.encoders import BaseEncoder
from mmocr.registry import MODELS
from mmocr.structures import TextSpottingDataSample


@MODELS.register_module()
class SPTSEncoder(BaseEncoder):
    """SPTS Encoder.

    Args:
        d_backbone (int): Backbone output dimension.
        d_model (int): Model output dimension.
        init_cfg (dict or list[dict], optional): Initialization configs.
            Defaults to None.
    """

    def __init__(self,
                 d_backbone: int = 2048,
                 d_model: int = 256,
                 init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.input_proj = nn.Conv2d(d_backbone, d_model, kernel_size=1)

    def forward(self,
                feat: Tensor,
                data_samples: List[TextSpottingDataSample] = None) -> Tensor:
        """Forward propagation of encoder.

        Args:
            feat (Tensor): Feature tensor of shape :math:`(N, D_m, H, W)`.
            data_samples (list[TextSpottingDataSample]): Batch of
                TextSpottingDataSample.
                Defaults to None.

        Returns:
            Tensor: A tensor of shape :math:`(N, T, D_m)`.
        """
        return self.input_proj(feat[0])