import os

import torch
import torch.nn as nn
from torch.nn import init

from util.util import to_device
from .networks import *
from params import *


class BidirectionalLSTM(nn.Module):
    """Bidirectional LSTM followed by a per-time-step linear projection.

    Args:
        nIn (int): size of each input feature vector.
        nHidden (int): hidden size of each LSTM direction.
        nOut (int): size of each output feature vector.
    """

    def __init__(self, nIn, nHidden, nOut):
        super(BidirectionalLSTM, self).__init__()
        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
        # Forward and backward hidden states are concatenated, hence nHidden * 2.
        self.embedding = nn.Linear(nHidden * 2, nOut)

    def forward(self, input):
        """Run the LSTM and project every time step.

        Args:
            input: tensor of shape [T, b, nIn] (the LSTM is built with the
                default batch_first=False).

        Returns:
            Tensor of shape [T, b, nOut].
        """
        recurrent, _ = self.rnn(input)
        T, b, h = recurrent.size()
        t_rec = recurrent.view(T * b, h)
        output = self.embedding(t_rec)  # [T * b, nOut]
        output = output.view(T, b, -1)
        return output


class CRNN(nn.Module):
    """Convolutional recognizer producing per-column class scores.

    A stack of 7 conv layers (with pooling) must reduce the feature-map height
    to 1. Each remaining column becomes one time step whose 512-channel feature
    is mapped to ``VOCAB_SIZE`` scores, either by a linear layer or, when
    ``use_rnn`` is enabled, by two stacked bidirectional LSTMs.

    Args:
        leakyRelu (bool): use LeakyReLU(0.2) instead of ReLU after each conv.
    """

    def __init__(self, leakyRelu=False):
        super(CRNN, self).__init__()
        self.name = 'OCR'
        # Per-layer conv hyperparameters: kernel sizes, paddings, strides and
        # output channel counts for conv0..conv6.
        ks = [3, 3, 3, 3, 3, 3, 2]
        ps = [1, 1, 1, 1, 1, 1, 0]
        ss = [1, 1, 1, 1, 1, 1, 1]
        nm = [64, 128, 256, 256, 512, 512, 512]
        nh = 256  # hidden size of the optional BiLSTM head
        dealwith_lossnone = False  # whether to replace all nan/inf in gradients to zero

        cnn = nn.Sequential()

        def convRelu(i, batchNormalization=False):
            """Append conv layer ``i`` (optionally batch-normalized) and its activation."""
            # conv0 takes a single input channel; later layers chain from nm[i - 1].
            nIn = 1 if i == 0 else nm[i - 1]
            nOut = nm[i]
            cnn.add_module('conv{0}'.format(i),
                           nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i]))
            if batchNormalization:
                cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(nOut))
            if leakyRelu:
                cnn.add_module('relu{0}'.format(i),
                               nn.LeakyReLU(0.2, inplace=True))
            else:
                cnn.add_module('relu{0}'.format(i), nn.ReLU(True))

        # NOTE(review): the CxHxW shape comments below assume one particular
        # input size; TODO confirm against the configured resolution.
        convRelu(0)
        cnn.add_module('pooling{0}'.format(0), nn.MaxPool2d(2, 2))  # 64x16x64
        convRelu(1)
        cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d(2, 2))  # 128x8x32
        convRelu(2, True)
        convRelu(3)
        cnn.add_module('pooling{0}'.format(2),
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 256x4x16
        convRelu(4, True)
        if resolution == 63:
            # Extra height-halving pool for this resolution (presumably so the
            # final feature height is 1 -- TODO confirm).
            cnn.add_module('pooling{0}'.format(3),
                           nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 256x4x16
        convRelu(5)
        cnn.add_module('pooling{0}'.format(4),
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 512x2x16
        convRelu(6, True)  # 512x1x16

        self.cnn = cnn
        self.use_rnn = False
        if self.use_rnn:
            # The second LSTM must emit VOCAB_SIZE scores, matching the linear
            # head below (previously nOut was missing -> TypeError).
            self.rnn = nn.Sequential(
                BidirectionalLSTM(512, nh, nh),
                BidirectionalLSTM(nh, nh, VOCAB_SIZE))
        else:
            self.linear = nn.Linear(512, VOCAB_SIZE)

        # replace all nan/inf in gradients to zero
        if dealwith_lossnone:
            self.register_backward_hook(self.backward_hook)

        self.device = torch.device('cuda:{}'.format(0))
        self.init = 'N02'
        # Initialize weights. init_weights comes from .networks and is assumed
        # to initialize the module in place; its return value is not needed.
        init_weights(self, self.init)

    def forward(self, input):
        """Compute per-column class scores.

        Args:
            input: image batch [b, 1, H, W]; H must be such that the conv
                stack reduces the feature height to exactly 1.

        Returns:
            Tensor of shape [w, b, VOCAB_SIZE], where w is the feature-map width.

        Raises:
            AssertionError: if the conv feature height is not 1.
        """
        # conv features
        conv = self.cnn(input)
        b, c, h, w = conv.size()
        assert h == 1, "the height of conv must be 1"
        conv = conv.squeeze(2)
        conv = conv.permute(2, 0, 1)  # [w, b, c]

        if self.use_rnn:
            # rnn features
            output = self.rnn(conv)
        else:
            output = self.linear(conv)
        return output

    def backward_hook(self, module, grad_input, grad_output):
        """Zero out NaN and +/-inf entries of the incoming gradients in place.

        Entries of ``grad_input`` that are None (no gradient) are skipped.
        """
        for g in grad_input:
            if g is not None:
                g[~torch.isfinite(g)] = 0  # replace all nan/inf in gradients to zero


class _LabelConverterBase(object):
    """Shared alphabet handling and CTC decoding for the label converters.

    NOTE: index 0 is reserved for the CTC `blank`; character ``alphabet[k]``
    is encoded as ``k + 1``.

    Args:
        alphabet (str): set of the possible characters.
        ignore_case (bool, default=False): whether or not to ignore all of the case.
    """

    def __init__(self, alphabet, ignore_case=False):
        self._ignore_case = ignore_case
        if self._ignore_case:
            alphabet = alphabet.lower()
        # Label k decodes to self.alphabet[k - 1], so blank (0) decodes to the
        # trailing '-' via index -1.
        self.alphabet = alphabet + '-'  # for `-1` index
        # NOTE: 0 is reserved for 'blank' required by wrap_ctc
        self.dict = {char: i + 1 for i, char in enumerate(alphabet)}

    @staticmethod
    def _as_batch(text):
        """Wrap a single text (str or bytes) into a one-element batch."""
        if isinstance(text, (str, bytes)):
            return [text]
        return text

    def _encode_item(self, item):
        """Return the label indices for one text given as bytes (UTF-8) or str.

        Raises:
            KeyError: if a character is not in the alphabet.
        """
        if isinstance(item, bytes):
            item = item.decode('utf-8', 'strict')
        if self._ignore_case:
            # The alphabet was lowercased in __init__, so the input must be too.
            item = item.lower()
        return [self.dict[char] for char in item]

    def decode(self, t, length, raw=False):
        """Decode encoded texts back into strs.

        Args:
            t (torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]):
                flattened label indices.
            length (torch.IntTensor [n]): length of each text.
            raw (bool): if True, map every index to its character (blank -> '-')
                without collapsing. Otherwise drop blanks (index 0) and merge
                consecutive repeated indices.

        Raises:
            AssertionError: when the texts and its length does not match.

        Returns:
            A str when ``length`` has a single element, else a list of str.
        """
        if length.numel() == 1:
            length = length[0]
            assert t.numel() == length, "text with length: {} does not match declared length: {}".format(t.numel(), length)
            if raw:
                return ''.join([self.alphabet[i - 1] for i in t])
            char_list = []
            for i in range(length):
                # Keep a label only if it is not blank and differs from its predecessor.
                if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])):
                    char_list.append(self.alphabet[t[i] - 1])
            return ''.join(char_list)

        # batch mode
        assert t.numel() == length.sum(), "texts with length: {} does not match declared length: {}".format(
            t.numel(), length.sum())
        texts = []
        index = 0
        for i in range(length.numel()):
            l = length[i]
            texts.append(
                self.decode(
                    t[index:index + l], torch.IntTensor([l]), raw=raw))
            index += l
        return texts


class OCRLabelConverter(_LabelConverterBase):
    """Convert between str and label; encoded labels are concatenated (flat).

    NOTE: Insert `blank` to the alphabet for CTC.

    Args:
        alphabet (str): set of the possible characters.
        ignore_case (bool, default=False): whether or not to ignore all of the case.
    """

    def encode(self, text):
        """Support batch or single text.

        Args:
            text (bytes, str, or iterable of bytes/str): texts to convert;
                bytes are decoded as UTF-8.

        Returns:
            torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
            torch.IntTensor [n]: length of each text.

        Raises:
            KeyError: if a character is not in the alphabet.
        """
        encoded = [self._encode_item(item) for item in self._as_batch(text)]
        flat = [index for indices in encoded for index in indices]
        lengths = [len(indices) for indices in encoded]
        return (torch.IntTensor(flat), torch.IntTensor(lengths))


class strLabelConverter(_LabelConverterBase):
    """Convert between str and label; encoded labels are padded per text.

    NOTE: Insert `blank` to the alphabet for CTC.

    Args:
        alphabet (str): set of the possible characters.
        ignore_case (bool, default=False): whether or not to ignore all of the case.
    """

    def encode(self, text):
        """Support batch or single text.

        Args:
            text (bytes, str, or iterable of bytes/str): texts to convert;
                bytes are decoded as UTF-8.

        Returns:
            torch.LongTensor [n, max_length]: encoded texts, right-padded with 0
                (the same value as the CTC blank index).
            torch.IntTensor [n]: length of each text.

        Raises:
            KeyError: if a character is not in the alphabet.
        """
        encoded = [self._encode_item(item) for item in self._as_batch(text)]
        lengths = torch.IntTensor([len(indices) for indices in encoded])
        if not encoded:
            # pad_sequence cannot handle an empty list.
            return (torch.zeros((0, 0), dtype=torch.long), lengths)
        padded = torch.nn.utils.rnn.pad_sequence(
            [torch.LongTensor(indices) for indices in encoded], batch_first=True)
        return (padded, lengths)