sunnychenxiwang's picture
Upload 1595 files
0b4516f verified
raw
history blame
1.42 kB
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Union
import torch.nn as nn
from torch import Tensor
from mmocr.models.textrecog.encoders import BaseEncoder
from mmocr.registry import MODELS
from mmocr.structures import TextSpottingDataSample
@MODELS.register_module()
class SPTSEncoder(BaseEncoder):
    """SPTS Encoder.

    Projects a backbone feature map to the model dimension with a single
    1x1 convolution.

    Args:
        d_backbone (int): Backbone output dimension, i.e. the number of
            input channels of the projection. Defaults to 2048.
        d_model (int): Model output dimension, i.e. the number of output
            channels of the projection. Defaults to 256.
        init_cfg (dict or list[dict], optional): Initialization configs.
            Defaults to None.
    """

    def __init__(self,
                 d_backbone: int = 2048,
                 d_model: int = 256,
                 init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None:
        super().__init__(init_cfg=init_cfg)
        # A 1x1 convolution only remaps channels (d_backbone -> d_model);
        # the spatial size of its input is left unchanged.
        self.input_proj = nn.Conv2d(d_backbone, d_model, kernel_size=1)

    def forward(
        self,
        feat: Tensor,
        data_samples: Optional[List[TextSpottingDataSample]] = None
    ) -> Tensor:
        """Forward propagation of encoder.

        Args:
            feat (Tensor): Backbone output. Only ``feat[0]`` is used; it
                is passed to a convolution expecting ``d_backbone`` input
                channels.
                NOTE(review): presumably ``feat`` is a sequence of
                backbone feature maps whose first entry has shape
                :math:`(N, D_b, H, W)` - confirm against the backbone's
                output format.
            data_samples (list[TextSpottingDataSample], optional): Batch
                of TextSpottingDataSample. Not used by this encoder.
                Defaults to None.

        Returns:
            Tensor: ``feat[0]`` projected to ``d_model`` channels. For a
            4-D input of shape :math:`(N, D_b, H, W)` the output has shape
            :math:`(N, D_m, H, W)`.
        """
        # Only the first entry of the backbone output is projected.
        return self.input_proj(feat[0])