Spaces:
Running
Running
added textbpn++ detection module
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- IndicPhotoOCR/detection/textbpn/__init__.py +0 -0
- IndicPhotoOCR/detection/textbpn/cfglib/config.py +90 -0
- IndicPhotoOCR/detection/textbpn/cfglib/option.py +123 -0
- IndicPhotoOCR/detection/textbpn/models/TextBPN_resnet50_300.pth +3 -0
- IndicPhotoOCR/detection/textbpn/network/Reg_loss.py +196 -0
- IndicPhotoOCR/detection/textbpn/network/Seg_loss.py +107 -0
- IndicPhotoOCR/detection/textbpn/network/__init__.py +1 -0
- IndicPhotoOCR/detection/textbpn/network/backbone/__init__.py +1 -0
- IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/Makefile +6 -0
- IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/Makefile.sh +6 -0
- IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/__init__.py +13 -0
- IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/functions/__init__.py +0 -0
- IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/functions/deform_conv.py +181 -0
- IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/functions/deform_pool.py +69 -0
- IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/modules/__init__.py +0 -0
- IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/modules/deform_conv.py +157 -0
- IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/modules/deform_pool.py +172 -0
- IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/setup.py +19 -0
- IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/src/deform_conv_cuda.cpp +695 -0
- IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/src/deform_conv_cuda_kernel.cu +866 -0
- IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/src/deform_pool_cuda.cpp +87 -0
- IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/src/deform_pool_cuda_kernel.cu +364 -0
- IndicPhotoOCR/detection/textbpn/network/backbone/resnet.py +336 -0
- IndicPhotoOCR/detection/textbpn/network/backbone/vgg.py +60 -0
- IndicPhotoOCR/detection/textbpn/network/layers/Adaptive_Deformation.py +88 -0
- IndicPhotoOCR/detection/textbpn/network/layers/CircConv.py +91 -0
- IndicPhotoOCR/detection/textbpn/network/layers/GCN.py +77 -0
- IndicPhotoOCR/detection/textbpn/network/layers/GraphConv.py +45 -0
- IndicPhotoOCR/detection/textbpn/network/layers/RNN.py +35 -0
- IndicPhotoOCR/detection/textbpn/network/layers/Transformer.py +140 -0
- IndicPhotoOCR/detection/textbpn/network/layers/Transformer_old.py +171 -0
- IndicPhotoOCR/detection/textbpn/network/layers/__init__.py +0 -0
- IndicPhotoOCR/detection/textbpn/network/layers/gcn_utils.py +150 -0
- IndicPhotoOCR/detection/textbpn/network/layers/model_block.py +149 -0
- IndicPhotoOCR/detection/textbpn/network/layers/position_encoding.py +89 -0
- IndicPhotoOCR/detection/textbpn/network/layers/resnet.py +73 -0
- IndicPhotoOCR/detection/textbpn/network/layers/resnet_dcn.py +59 -0
- IndicPhotoOCR/detection/textbpn/network/layers/vgg.py +62 -0
- IndicPhotoOCR/detection/textbpn/network/loss.py +187 -0
- IndicPhotoOCR/detection/textbpn/network/loss_org.py +136 -0
- IndicPhotoOCR/detection/textbpn/network/textnet.py +216 -0
- IndicPhotoOCR/detection/textbpn/output.png +3 -0
- IndicPhotoOCR/detection/textbpn/textbpnpp_detector.py +197 -0
- IndicPhotoOCR/detection/textbpn/util/__init__.py +2 -0
- IndicPhotoOCR/detection/textbpn/util/augmentation.py +794 -0
- IndicPhotoOCR/detection/textbpn/util/canvas.py +55 -0
- IndicPhotoOCR/detection/textbpn/util/detection.py +48 -0
- IndicPhotoOCR/detection/textbpn/util/eval.py +228 -0
- IndicPhotoOCR/detection/textbpn/util/graph.py +309 -0
- IndicPhotoOCR/detection/textbpn/util/io.py +233 -0
IndicPhotoOCR/detection/textbpn/__init__.py
ADDED
File without changes
|
IndicPhotoOCR/detection/textbpn/cfglib/config.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from easydict import EasyDict
|
2 |
+
import torch
|
3 |
+
import os
|
4 |
+
|
5 |
+
config = EasyDict()
|
6 |
+
|
7 |
+
|
8 |
+
# Normalize image
|
9 |
+
config.means = (0.485, 0.456, 0.406)
|
10 |
+
config.stds = (0.229, 0.224, 0.225)
|
11 |
+
|
12 |
+
config.gpu = "1"
|
13 |
+
|
14 |
+
# Experiment name #
|
15 |
+
config.exp_name = "Synthtext"
|
16 |
+
|
17 |
+
# dataloader jobs number
|
18 |
+
config.num_workers = 24
|
19 |
+
|
20 |
+
# batch_size
|
21 |
+
config.batch_size = 12
|
22 |
+
|
23 |
+
# training epoch number
|
24 |
+
config.max_epoch = 200
|
25 |
+
|
26 |
+
config.start_epoch = 0
|
27 |
+
|
28 |
+
# learning rate
|
29 |
+
config.lr = 1e-4
|
30 |
+
|
31 |
+
# using GPU
|
32 |
+
config.cuda = False
|
33 |
+
|
34 |
+
config.output_dir = 'output'
|
35 |
+
|
36 |
+
config.input_size = 640
|
37 |
+
|
38 |
+
# max polygon per image
|
39 |
+
# synText, total-text:64; CTW1500: 64; icdar: 64; MLT: 32; TD500: 64.
|
40 |
+
config.max_annotation = 64
|
41 |
+
|
42 |
+
# adj num for graph
|
43 |
+
config.adj_num = 4
|
44 |
+
|
45 |
+
# control points number
|
46 |
+
config.num_points = 20
|
47 |
+
|
48 |
+
# use hard examples (annotated as '#')
|
49 |
+
config.use_hard = True
|
50 |
+
|
51 |
+
# Load data into memory at one time
|
52 |
+
config.load_memory = False
|
53 |
+
|
54 |
+
# prediction on 1/scale feature map
|
55 |
+
config.scale = 1
|
56 |
+
|
57 |
+
# # clip gradient of loss
|
58 |
+
config.grad_clip = 25
|
59 |
+
|
60 |
+
# demo tcl threshold
|
61 |
+
config.dis_threshold = 0.4
|
62 |
+
|
63 |
+
config.cls_threshold = 0.8
|
64 |
+
|
65 |
+
# Contour approximation factor
|
66 |
+
config.approx_factor = 0.004
|
67 |
+
|
68 |
+
|
69 |
+
def update_config(config, extra_config):
|
70 |
+
for k, v in vars(extra_config).items():
|
71 |
+
config[k] = v
|
72 |
+
# print(config.gpu)
|
73 |
+
# config.device = torch.device('cuda') if config.cuda else torch.device('cpu')
|
74 |
+
config.device = torch.device('cpu')
|
75 |
+
|
76 |
+
|
77 |
+
def print_config(config):
|
78 |
+
print('==========Options============')
|
79 |
+
for k, v in config.items():
|
80 |
+
print('{}: {}'.format(k, v))
|
81 |
+
print('=============End=============')
|
82 |
+
|
83 |
+
|
84 |
+
|
85 |
+
################### MY Settings ##################
|
86 |
+
config.resume=True
|
87 |
+
|
88 |
+
config.device="cpu"
|
89 |
+
|
90 |
+
# config.test_size = [224, 224]
|
IndicPhotoOCR/detection/textbpn/cfglib/option.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import torch
|
3 |
+
import os
|
4 |
+
import torch.backends.cudnn as cudnn
|
5 |
+
|
6 |
+
from datetime import datetime
|
7 |
+
|
8 |
+
|
9 |
+
def str2bool(v):
|
10 |
+
return v.lower() in ("yes", "true", "t", "1")
|
11 |
+
|
12 |
+
|
13 |
+
def arg2str(args):
|
14 |
+
args_dict = vars(args)
|
15 |
+
option_str = datetime.now().strftime('%b%d_%H-%M-%S') + '\n'
|
16 |
+
|
17 |
+
for k, v in sorted(args_dict.items()):
|
18 |
+
option_str += ('{}: {}\n'.format(str(k), str(v)))
|
19 |
+
|
20 |
+
return option_str
|
21 |
+
|
22 |
+
|
23 |
+
class BaseOptions(object):
|
24 |
+
|
25 |
+
def __init__(self):
|
26 |
+
|
27 |
+
self.parser = argparse.ArgumentParser()
|
28 |
+
|
29 |
+
# basic opts
|
30 |
+
self.parser.add_argument('--exp_name', default="TD500", type=str,
|
31 |
+
choices=['Synthtext', 'Totaltext', 'Ctw1500','Icdar2015',
|
32 |
+
"MLT2017", 'TD500', "MLT2019", "ArT", "ALL"], help='Experiment name')
|
33 |
+
self.parser.add_argument("--gpu", default="1", help="set gpu id", type=str)
|
34 |
+
self.parser.add_argument('--resume', default=None, type=str, help='Path to target resume checkpoint')
|
35 |
+
self.parser.add_argument('--num_workers', default=24, type=int, help='Number of workers used in dataloading')
|
36 |
+
self.parser.add_argument('--cuda', default=True, type=str2bool, help='Use cuda to train model')
|
37 |
+
self.parser.add_argument('--mgpu', action='store_true', help='Use multi-gpu to train model')
|
38 |
+
self.parser.add_argument('--save_dir', default='./model/', help='Path to save checkpoint models')
|
39 |
+
self.parser.add_argument('--vis_dir', default='./vis/', help='Path to save visualization images')
|
40 |
+
self.parser.add_argument('--log_dir', default='./logs/', help='Path to tensorboard log')
|
41 |
+
self.parser.add_argument('--loss', default='CrossEntropyLoss', type=str, help='Training Loss')
|
42 |
+
# self.parser.add_argument('--input_channel', default=1, type=int, help='number of input channels' )
|
43 |
+
self.parser.add_argument('--pretrain', default=False, type=str2bool, help='Pretrained AutoEncoder model')
|
44 |
+
self.parser.add_argument('--verbose', '-v', default=True, type=str2bool, help='Whether to output debug info')
|
45 |
+
self.parser.add_argument('--viz', action='store_true', help='Whether to output debug info')
|
46 |
+
# self.parser.add_argument('--viz', default=True, type=str2bool, help='Whether to output debug info')
|
47 |
+
|
48 |
+
# train opts
|
49 |
+
self.parser.add_argument('--max_epoch', default=250, type=int, help='Max epochs')
|
50 |
+
self.parser.add_argument('--lr', '--learning-rate', default=1e-3, type=float, help='initial learning rate')
|
51 |
+
self.parser.add_argument('--lr_adjust', default='fix',
|
52 |
+
choices=['fix', 'poly'], type=str, help='Learning Rate Adjust Strategy')
|
53 |
+
self.parser.add_argument('--stepvalues', default=[], nargs='+', type=int, help='# of iter to change lr')
|
54 |
+
self.parser.add_argument('--weight_decay', '--wd', default=0., type=float, help='Weight decay for SGD')
|
55 |
+
self.parser.add_argument('--gamma', default=0.1, type=float, help='Gamma update for SGD lr')
|
56 |
+
self.parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
|
57 |
+
self.parser.add_argument('--batch_size', default=6, type=int, help='Batch size for training')
|
58 |
+
self.parser.add_argument('--optim', default='Adam', type=str, choices=['SGD', 'Adam'], help='Optimizer')
|
59 |
+
self.parser.add_argument('--save_freq', default=5, type=int, help='save weights every # epoch')
|
60 |
+
self.parser.add_argument('--display_freq', default=10, type=int, help='display training metrics every # iter')
|
61 |
+
self.parser.add_argument('--viz_freq', default=50, type=int, help='visualize training process every # iter')
|
62 |
+
self.parser.add_argument('--log_freq', default=10000, type=int, help='log to tensorboard every # iterations')
|
63 |
+
self.parser.add_argument('--val_freq', default=1000, type=int, help='do validation every # iterations')
|
64 |
+
|
65 |
+
# backbone
|
66 |
+
self.parser.add_argument('--scale', default=1, type=int, help='prediction on 1/scale feature map')
|
67 |
+
self.parser.add_argument('--net', default='resnet50', type=str,
|
68 |
+
choices=['vgg', 'resnet50', 'resnet18',
|
69 |
+
"deformable_resnet18", "deformable_resnet50"],
|
70 |
+
help='Network architecture')
|
71 |
+
# data args
|
72 |
+
self.parser.add_argument('--load_memory', default=False, type=str2bool, help='Load data into memory')
|
73 |
+
self.parser.add_argument('--rescale', type=float, default=255.0, help='rescale factor')
|
74 |
+
self.parser.add_argument('--input_size', default=640, type=int, help='model input size')
|
75 |
+
self.parser.add_argument('--test_size', default=[640, 960], type=int, nargs='+', help='test size')
|
76 |
+
|
77 |
+
# eval args00
|
78 |
+
self.parser.add_argument('--checkepoch', default=1070, type=int, help='Load checkpoint number')
|
79 |
+
self.parser.add_argument('--start_epoch', default=0, type=int, help='start epoch number')
|
80 |
+
self.parser.add_argument('--cls_threshold', default=0.875, type=float, help='threshold of pse')
|
81 |
+
self.parser.add_argument('--dis_threshold', default=0.35, type=float, help='filter the socre < score_i')
|
82 |
+
|
83 |
+
# demo args
|
84 |
+
self.parser.add_argument('--img_root', default=None, type=str, help='Path to deploy images')
|
85 |
+
|
86 |
+
def parse(self, fixed=None):
|
87 |
+
|
88 |
+
if fixed is not None:
|
89 |
+
args = self.parser.parse_args(fixed)
|
90 |
+
else:
|
91 |
+
args = self.parser.parse_args()
|
92 |
+
|
93 |
+
return args
|
94 |
+
|
95 |
+
def initialize(self, fixed=None):
|
96 |
+
|
97 |
+
# Parse options
|
98 |
+
self.args = self.parse(fixed)
|
99 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = self.args.gpu
|
100 |
+
|
101 |
+
# Setting default torch Tensor type
|
102 |
+
if self.args.cuda and torch.cuda.is_available():
|
103 |
+
torch.set_default_tensor_type('torch.cuda.FloatTensor')
|
104 |
+
cudnn.benchmark = True
|
105 |
+
else:
|
106 |
+
torch.set_default_tensor_type('torch.FloatTensor')
|
107 |
+
|
108 |
+
# Create weights saving directory
|
109 |
+
if not os.path.exists(self.args.save_dir):
|
110 |
+
os.mkdir(self.args.save_dir)
|
111 |
+
|
112 |
+
# Create weights saving directory of target model
|
113 |
+
model_save_path = os.path.join(self.args.save_dir, self.args.exp_name)
|
114 |
+
|
115 |
+
if not os.path.exists(model_save_path):
|
116 |
+
os.mkdir(model_save_path)
|
117 |
+
|
118 |
+
return self.args
|
119 |
+
|
120 |
+
def update(self, args, extra_options):
|
121 |
+
|
122 |
+
for k, v in extra_options.items():
|
123 |
+
setattr(args, k, v)
|
IndicPhotoOCR/detection/textbpn/models/TextBPN_resnet50_300.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b735b9c93c8758972d3b8cfd3ef8e1c09afa8cd9106f4cb11406300b141b1d78
|
3 |
+
size 145703602
|
IndicPhotoOCR/detection/textbpn/network/Reg_loss.py
ADDED
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# @Time : 10/1/21
|
3 |
+
# @Author : GXYM
|
4 |
+
import torch
|
5 |
+
from torch import nn
|
6 |
+
import numpy as np
|
7 |
+
import torch.nn.functional as F
|
8 |
+
|
9 |
+
|
10 |
+
class PolyMatchingLoss(nn.Module):
|
11 |
+
def __init__(self, pnum, device, loss_type="L1"):
|
12 |
+
super(PolyMatchingLoss, self).__init__()
|
13 |
+
|
14 |
+
self.pnum = pnum
|
15 |
+
self.device = device
|
16 |
+
self.loss_type = loss_type
|
17 |
+
self.smooth_L1 = F.smooth_l1_loss
|
18 |
+
self.L2_loss = torch.nn.MSELoss(reduce=False, size_average=False)
|
19 |
+
|
20 |
+
batch_size = 1
|
21 |
+
pidxall = np.zeros(shape=(batch_size, pnum, pnum), dtype=np.int32)
|
22 |
+
for b in range(batch_size):
|
23 |
+
for i in range(pnum):
|
24 |
+
pidx = (np.arange(pnum) + i) % pnum
|
25 |
+
pidxall[b, i] = pidx
|
26 |
+
|
27 |
+
pidxall = torch.from_numpy(np.reshape(pidxall, newshape=(batch_size, -1))).to(device)
|
28 |
+
self.feature_id = pidxall.unsqueeze_(2).long().expand(pidxall.size(0), pidxall.size(1), 2).detach()
|
29 |
+
print(self.feature_id.shape)
|
30 |
+
|
31 |
+
def match_loss(self, pred, gt):
|
32 |
+
batch_size = pred.shape[0]
|
33 |
+
feature_id = self.feature_id.expand(batch_size, self.feature_id.size(1), 2)
|
34 |
+
|
35 |
+
gt_expand = torch.gather(gt, 1, feature_id).view(batch_size, self.pnum, self.pnum, 2)
|
36 |
+
pred_expand = pred.unsqueeze(1)
|
37 |
+
|
38 |
+
if self.loss_type == "L2":
|
39 |
+
dis = self.L2_loss(pred_expand, gt_expand)
|
40 |
+
dis = dis.sum(3).sqrt().mean(2)
|
41 |
+
elif self.loss_type == "L1":
|
42 |
+
dis = self.smooth_L1(pred_expand, gt_expand, reduction='none')
|
43 |
+
dis = dis.sum(3).mean(2)
|
44 |
+
|
45 |
+
min_dis, min_id = torch.min(dis, dim=1, keepdim=True)
|
46 |
+
|
47 |
+
return min_dis
|
48 |
+
|
49 |
+
def forward(self, pred_list, gt):
|
50 |
+
loss = torch.tensor(0.)
|
51 |
+
for pred in pred_list:
|
52 |
+
loss += torch.mean(self.match_loss(pred, gt))
|
53 |
+
|
54 |
+
return loss / torch.tensor(len(pred_list))
|
55 |
+
|
56 |
+
# los = []
|
57 |
+
# for pred in pred_list:
|
58 |
+
# los.append(self.match_loss(pred, gt))
|
59 |
+
#
|
60 |
+
# los_b = torch.tensor(0.)
|
61 |
+
# loss_c = torch.tensor(0.)
|
62 |
+
# for i, _ in enumerate(los):
|
63 |
+
# los_b += torch.mean(los[i])
|
64 |
+
# loss_c += (torch.mean(torch.clamp(los[i] - los[i - 1], min=0.0)) if i > 0 else torch.tensor(0.))
|
65 |
+
# loss = los_b / torch.tensor(len(los)) + 0.5*loss_c / torch.tensor(len(los)-1)
|
66 |
+
#
|
67 |
+
# return loss
|
68 |
+
|
69 |
+
|
70 |
+
class AttentionLoss(nn.Module):
|
71 |
+
def __init__(self, beta=4, gamma=0.5):
|
72 |
+
super(AttentionLoss, self).__init__()
|
73 |
+
|
74 |
+
self.beta = beta
|
75 |
+
self.gamma = gamma
|
76 |
+
|
77 |
+
def forward(self, pred, gt):
|
78 |
+
num_pos = torch.sum(gt)
|
79 |
+
num_neg = torch.sum(1 - gt)
|
80 |
+
alpha = num_neg / (num_pos + num_neg)
|
81 |
+
edge_beta = torch.pow(self.beta, torch.pow(1 - pred, self.gamma))
|
82 |
+
bg_beta = torch.pow(self.beta, torch.pow(pred, self.gamma))
|
83 |
+
|
84 |
+
loss = 0
|
85 |
+
loss = loss - alpha * edge_beta * torch.log(pred) * gt
|
86 |
+
loss = loss - (1 - alpha) * bg_beta * torch.log(1 - pred) * (1 - gt)
|
87 |
+
return torch.mean(loss)
|
88 |
+
|
89 |
+
|
90 |
+
class GeoCrossEntropyLoss(nn.Module):
|
91 |
+
def __init__(self):
|
92 |
+
super(GeoCrossEntropyLoss, self).__init__()
|
93 |
+
|
94 |
+
def forward(self, output, target, poly):
|
95 |
+
output = torch.nn.functional.softmax(output, dim=1)
|
96 |
+
output = torch.log(torch.clamp(output, min=1e-4))
|
97 |
+
poly = poly.view(poly.size(0), 4, poly.size(1) // 4, 2)
|
98 |
+
target = target[..., None, None].expand(poly.size(0), poly.size(1), 1, poly.size(3))
|
99 |
+
target_poly = torch.gather(poly, 2, target)
|
100 |
+
sigma = (poly[:, :, 0] - poly[:, :, 1]).pow(2).sum(2, keepdim=True)
|
101 |
+
kernel = torch.exp(-(poly - target_poly).pow(2).sum(3) / (sigma / 3))
|
102 |
+
loss = -(output * kernel.transpose(2, 1)).sum(1).mean()
|
103 |
+
return loss
|
104 |
+
|
105 |
+
|
106 |
+
class AELoss(nn.Module):
|
107 |
+
def __init__(self):
|
108 |
+
super(AELoss, self).__init__()
|
109 |
+
|
110 |
+
def forward(self, ae, ind, ind_mask):
|
111 |
+
"""
|
112 |
+
ae: [b, 1, h, w]
|
113 |
+
ind: [b, max_objs, max_parts]
|
114 |
+
ind_mask: [b, max_objs, max_parts]
|
115 |
+
obj_mask: [b, max_objs]
|
116 |
+
"""
|
117 |
+
# first index
|
118 |
+
b, _, h, w = ae.shape
|
119 |
+
b, max_objs, max_parts = ind.shape
|
120 |
+
obj_mask = torch.sum(ind_mask, dim=2) != 0
|
121 |
+
|
122 |
+
ae = ae.view(b, h * w, 1)
|
123 |
+
seed_ind = ind.view(b, max_objs * max_parts, 1)
|
124 |
+
tag = ae.gather(1, seed_ind).view(b, max_objs, max_parts)
|
125 |
+
|
126 |
+
# compute the mean
|
127 |
+
tag_mean = tag * ind_mask
|
128 |
+
tag_mean = tag_mean.sum(2) / (ind_mask.sum(2) + 1e-4)
|
129 |
+
|
130 |
+
# pull ae of the same object to their mean
|
131 |
+
pull_dist = (tag - tag_mean.unsqueeze(2)).pow(2) * ind_mask
|
132 |
+
obj_num = obj_mask.sum(dim=1).float()
|
133 |
+
pull = (pull_dist.sum(dim=(1, 2)) / (obj_num + 1e-4)).sum()
|
134 |
+
pull /= b
|
135 |
+
|
136 |
+
# push away the mean of different objects
|
137 |
+
push_dist = torch.abs(tag_mean.unsqueeze(1) - tag_mean.unsqueeze(2))
|
138 |
+
push_dist = 1 - push_dist
|
139 |
+
push_dist = nn.functional.relu(push_dist, inplace=True)
|
140 |
+
obj_mask = (obj_mask.unsqueeze(1) + obj_mask.unsqueeze(2)) == 2
|
141 |
+
push_dist = push_dist * obj_mask.float()
|
142 |
+
push = ((push_dist.sum(dim=(1, 2)) - obj_num) / (obj_num * (obj_num - 1) + 1e-4)).sum()
|
143 |
+
push /= b
|
144 |
+
return pull, push
|
145 |
+
|
146 |
+
|
147 |
+
def smooth_l1_loss(inputs, target, sigma=9.0):
|
148 |
+
try:
|
149 |
+
diff = torch.abs(inputs - target)
|
150 |
+
less_one = (diff < 1.0 / sigma).float()
|
151 |
+
loss = less_one * 0.5 * diff ** 2 * sigma \
|
152 |
+
+ torch.abs(torch.tensor(1.0) - less_one) * (diff - 0.5 / sigma)
|
153 |
+
loss = torch.mean(loss) if loss.numel() > 0 else torch.tensor(0.0)
|
154 |
+
except Exception as e:
|
155 |
+
print('RPN_REGR_Loss Exception:', e)
|
156 |
+
loss = torch.tensor(0.0)
|
157 |
+
|
158 |
+
return loss
|
159 |
+
|
160 |
+
|
161 |
+
def _neg_loss(pred, gt):
|
162 |
+
''' Modified focal loss. Exactly the same as CornerNet.
|
163 |
+
Runs faster and costs a little bit more memory
|
164 |
+
Arguments:
|
165 |
+
pred (batch x c x h x w)
|
166 |
+
gt_regr (batch x c x h x w)
|
167 |
+
'''
|
168 |
+
pos_inds = gt.eq(1).float()
|
169 |
+
neg_inds = gt.lt(1).float()
|
170 |
+
|
171 |
+
neg_weights = torch.pow(1 - gt, 4)
|
172 |
+
|
173 |
+
loss = 0
|
174 |
+
|
175 |
+
pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds
|
176 |
+
neg_loss = torch.log(1 - pred) * torch.pow(pred, 2) * neg_weights * neg_inds
|
177 |
+
|
178 |
+
num_pos = pos_inds.float().sum()
|
179 |
+
pos_loss = pos_loss.sum()
|
180 |
+
neg_loss = neg_loss.sum()
|
181 |
+
|
182 |
+
if num_pos == 0:
|
183 |
+
loss = loss - neg_loss
|
184 |
+
else:
|
185 |
+
loss = loss - (pos_loss + neg_loss) / num_pos
|
186 |
+
return loss
|
187 |
+
|
188 |
+
|
189 |
+
class FocalLoss(nn.Module):
|
190 |
+
'''nn.Module warpper for focal loss'''
|
191 |
+
def __init__(self):
|
192 |
+
super(FocalLoss, self).__init__()
|
193 |
+
self.neg_loss = _neg_loss
|
194 |
+
|
195 |
+
def forward(self, out, target):
|
196 |
+
return self.neg_loss(out, target)
|
IndicPhotoOCR/detection/textbpn/network/Seg_loss.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# @Time : 10/1/21
|
3 |
+
# @Author : GXYM
|
4 |
+
import torch
|
5 |
+
from torch import nn
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
|
9 |
+
class SegmentLoss(nn.Module):
|
10 |
+
def __init__(self, Lambda, ratio=3, reduction='mean'):
|
11 |
+
"""Implement PSE Loss.
|
12 |
+
"""
|
13 |
+
super(SegmentLoss, self).__init__()
|
14 |
+
assert reduction in ['mean', 'sum'], " reduction must in ['mean','sum']"
|
15 |
+
self.Lambda = Lambda
|
16 |
+
self.ratio = ratio
|
17 |
+
self.reduction = reduction
|
18 |
+
|
19 |
+
def forward(self, outputs, labels, training_masks, th=0.5):
|
20 |
+
texts = outputs[:, -1, :, :]
|
21 |
+
kernels = outputs[:, :-1, :, :]
|
22 |
+
gt_texts = labels[:, -1, :, :]
|
23 |
+
gt_kernels = labels[:, :-1, :, :]
|
24 |
+
|
25 |
+
selected_masks = self.ohem_batch(texts, gt_texts, training_masks)
|
26 |
+
selected_masks = selected_masks.to(outputs.device)
|
27 |
+
|
28 |
+
loss_text = self.dice_loss(texts, gt_texts, selected_masks)
|
29 |
+
|
30 |
+
loss_kernels = []
|
31 |
+
# mask0 = torch.sigmoid(texts).data.cpu().numpy()
|
32 |
+
mask0 = texts.data.cpu().numpy()
|
33 |
+
mask1 = training_masks.data.cpu().numpy()
|
34 |
+
selected_masks = ((mask0 > th) & (mask1 > th)).astype('float32')
|
35 |
+
selected_masks = torch.from_numpy(selected_masks).float()
|
36 |
+
selected_masks = selected_masks.to(outputs.device)
|
37 |
+
kernels_num = gt_kernels.size()[1]
|
38 |
+
for i in range(kernels_num):
|
39 |
+
kernel_i = kernels[:, i, :, :]
|
40 |
+
gt_kernel_i = gt_kernels[:, i, :, :]
|
41 |
+
loss_kernel_i = self.dice_loss(kernel_i, gt_kernel_i, selected_masks)
|
42 |
+
loss_kernels.append(loss_kernel_i)
|
43 |
+
loss_kernels = torch.stack(loss_kernels).mean(0)
|
44 |
+
if self.reduction == 'mean':
|
45 |
+
loss_text = loss_text.mean()
|
46 |
+
loss_kernels = loss_kernels.mean()
|
47 |
+
elif self.reduction == 'sum':
|
48 |
+
loss_text = loss_text.sum()
|
49 |
+
loss_kernels = loss_kernels.sum()
|
50 |
+
|
51 |
+
loss = self.Lambda *loss_text + (1-self.Lambda)*loss_kernels
|
52 |
+
return loss_text, loss_kernels, loss
|
53 |
+
|
54 |
+
def dice_loss(self, input, target, mask):
|
55 |
+
# input = torch.sigmoid(input)
|
56 |
+
|
57 |
+
input = input.contiguous().view(input.size()[0], -1)
|
58 |
+
target = target.contiguous().view(target.size()[0], -1)
|
59 |
+
mask = mask.contiguous().view(mask.size()[0], -1)
|
60 |
+
|
61 |
+
input = input * mask
|
62 |
+
target = (target.float()) * mask
|
63 |
+
|
64 |
+
a = torch.sum(input * target, 1)
|
65 |
+
b = torch.sum(input * input, 1) + 0.001
|
66 |
+
c = torch.sum(target * target, 1) + 0.001
|
67 |
+
d = (2 * a) / (b + c)
|
68 |
+
return 1 - d
|
69 |
+
|
70 |
+
def ohem_single(self, score, gt_text, training_mask, th=0.5):
|
71 |
+
pos_num = (int)(np.sum(gt_text > th)) - (int)(np.sum((gt_text > th) & (training_mask <= th)))
|
72 |
+
|
73 |
+
if pos_num == 0:
|
74 |
+
# selected_mask = gt_text.copy() * 0 # may be not good
|
75 |
+
selected_mask = training_mask
|
76 |
+
selected_mask = selected_mask.reshape(1, selected_mask.shape[0], selected_mask.shape[1]).astype('float32')
|
77 |
+
return selected_mask
|
78 |
+
|
79 |
+
neg_num = (int)(np.sum(gt_text <= th))
|
80 |
+
neg_num = (int)(min(pos_num * 3, neg_num))
|
81 |
+
|
82 |
+
if neg_num == 0:
|
83 |
+
selected_mask = training_mask
|
84 |
+
selected_mask = selected_mask.reshape(1, selected_mask.shape[0], selected_mask.shape[1]).astype('float32')
|
85 |
+
return selected_mask
|
86 |
+
|
87 |
+
neg_score = score[gt_text <= th]
|
88 |
+
# 将负样本得分从高到低排序
|
89 |
+
neg_score_sorted = np.sort(-neg_score)
|
90 |
+
threshold = -neg_score_sorted[neg_num - 1]
|
91 |
+
# 选出 得分高的 负样本 和正样本 的 mask
|
92 |
+
selected_mask = ((score >= threshold) | (gt_text > th)) & (training_mask > th)
|
93 |
+
selected_mask = selected_mask.reshape(1, selected_mask.shape[0], selected_mask.shape[1]).astype('float32')
|
94 |
+
return selected_mask
|
95 |
+
|
96 |
+
def ohem_batch(self, scores, gt_texts, training_masks):
|
97 |
+
scores = scores.data.cpu().numpy()
|
98 |
+
gt_texts = gt_texts.data.cpu().numpy()
|
99 |
+
training_masks = training_masks.data.cpu().numpy()
|
100 |
+
selected_masks = []
|
101 |
+
for i in range(scores.shape[0]):
|
102 |
+
selected_masks.append(self.ohem_single(scores[i, :, :], gt_texts[i, :, :], training_masks[i, :, :]))
|
103 |
+
|
104 |
+
selected_masks = np.concatenate(selected_masks, 0)
|
105 |
+
selected_masks = torch.from_numpy(selected_masks).float()
|
106 |
+
|
107 |
+
return selected_masks
|
IndicPhotoOCR/detection/textbpn/network/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from . import *
|
IndicPhotoOCR/detection/textbpn/network/backbone/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .resnet import resnet18, resnet34, resnet50, resnet101, deformable_resnet50, deformable_resnet18
|
IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/Makefile
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
rm *.so
|
3 |
+
python setup.py build_ext --inplace
|
4 |
+
rm -rf ./build
|
5 |
+
|
6 |
+
|
IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/Makefile.sh
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
rm *.so
|
3 |
+
python setup.py build_ext --inplace
|
4 |
+
rm -rf ./build
|
5 |
+
|
6 |
+
|
IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/__init__.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .functions.deform_conv import deform_conv, modulated_deform_conv
|
2 |
+
from .functions.deform_pool import deform_roi_pooling
|
3 |
+
from .modules.deform_conv import (DeformConv, ModulatedDeformConv,
|
4 |
+
DeformConvPack, ModulatedDeformConvPack)
|
5 |
+
from .modules.deform_pool import (DeformRoIPooling, DeformRoIPoolingPack,
|
6 |
+
ModulatedDeformRoIPoolingPack)
|
7 |
+
|
8 |
+
__all__ = [
|
9 |
+
'DeformConv', 'DeformConvPack', 'ModulatedDeformConv',
|
10 |
+
'ModulatedDeformConvPack', 'DeformRoIPooling', 'DeformRoIPoolingPack',
|
11 |
+
'ModulatedDeformRoIPoolingPack', 'deform_conv', 'modulated_deform_conv',
|
12 |
+
'deform_roi_pooling'
|
13 |
+
]
|
IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/functions/__init__.py
ADDED
File without changes
|
IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/functions/deform_conv.py
ADDED
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch.autograd import Function
|
3 |
+
from torch.nn.modules.utils import _pair
|
4 |
+
|
5 |
+
from .. import deform_conv_cuda
|
6 |
+
|
7 |
+
|
8 |
+
class DeformConvFunction(Function):
|
9 |
+
|
10 |
+
@staticmethod
|
11 |
+
def forward(ctx,
|
12 |
+
input,
|
13 |
+
offset,
|
14 |
+
weight,
|
15 |
+
stride=1,
|
16 |
+
padding=0,
|
17 |
+
dilation=1,
|
18 |
+
groups=1,
|
19 |
+
deformable_groups=1,
|
20 |
+
im2col_step=64):
|
21 |
+
if input is not None and input.dim() != 4:
|
22 |
+
raise ValueError(
|
23 |
+
"Expected 4D tensor as input, got {}D tensor instead.".format(
|
24 |
+
input.dim()))
|
25 |
+
ctx.stride = _pair(stride)
|
26 |
+
ctx.padding = _pair(padding)
|
27 |
+
ctx.dilation = _pair(dilation)
|
28 |
+
ctx.groups = groups
|
29 |
+
ctx.deformable_groups = deformable_groups
|
30 |
+
ctx.im2col_step = im2col_step
|
31 |
+
|
32 |
+
ctx.save_for_backward(input, offset, weight)
|
33 |
+
|
34 |
+
output = input.new_empty(
|
35 |
+
DeformConvFunction._output_size(input, weight, ctx.padding,
|
36 |
+
ctx.dilation, ctx.stride))
|
37 |
+
|
38 |
+
ctx.bufs_ = [input.new_empty(0), input.new_empty(0)] # columns, ones
|
39 |
+
|
40 |
+
if not input.is_cuda:
|
41 |
+
raise NotImplementedError
|
42 |
+
else:
|
43 |
+
cur_im2col_step = min(ctx.im2col_step, input.shape[0])
|
44 |
+
assert (input.shape[0] %
|
45 |
+
cur_im2col_step) == 0, 'im2col step must divide batchsize'
|
46 |
+
deform_conv_cuda.deform_conv_forward_cuda(
|
47 |
+
input, weight, offset, output, ctx.bufs_[0], ctx.bufs_[1],
|
48 |
+
weight.size(3), weight.size(2), ctx.stride[1], ctx.stride[0],
|
49 |
+
ctx.padding[1], ctx.padding[0], ctx.dilation[1],
|
50 |
+
ctx.dilation[0], ctx.groups, ctx.deformable_groups,
|
51 |
+
cur_im2col_step)
|
52 |
+
return output
|
53 |
+
|
54 |
+
@staticmethod
|
55 |
+
def backward(ctx, grad_output):
|
56 |
+
input, offset, weight = ctx.saved_tensors
|
57 |
+
|
58 |
+
grad_input = grad_offset = grad_weight = None
|
59 |
+
|
60 |
+
if not grad_output.is_cuda:
|
61 |
+
raise NotImplementedError
|
62 |
+
else:
|
63 |
+
cur_im2col_step = min(ctx.im2col_step, input.shape[0])
|
64 |
+
assert (input.shape[0] %
|
65 |
+
cur_im2col_step) == 0, 'im2col step must divide batchsize'
|
66 |
+
|
67 |
+
if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
|
68 |
+
grad_input = torch.zeros_like(input)
|
69 |
+
grad_offset = torch.zeros_like(offset)
|
70 |
+
deform_conv_cuda.deform_conv_backward_input_cuda(
|
71 |
+
input, offset, grad_output, grad_input,
|
72 |
+
grad_offset, weight, ctx.bufs_[0], weight.size(3),
|
73 |
+
weight.size(2), ctx.stride[1], ctx.stride[0],
|
74 |
+
ctx.padding[1], ctx.padding[0], ctx.dilation[1],
|
75 |
+
ctx.dilation[0], ctx.groups, ctx.deformable_groups,
|
76 |
+
cur_im2col_step)
|
77 |
+
|
78 |
+
if ctx.needs_input_grad[2]:
|
79 |
+
grad_weight = torch.zeros_like(weight)
|
80 |
+
deform_conv_cuda.deform_conv_backward_parameters_cuda(
|
81 |
+
input, offset, grad_output,
|
82 |
+
grad_weight, ctx.bufs_[0], ctx.bufs_[1], weight.size(3),
|
83 |
+
weight.size(2), ctx.stride[1], ctx.stride[0],
|
84 |
+
ctx.padding[1], ctx.padding[0], ctx.dilation[1],
|
85 |
+
ctx.dilation[0], ctx.groups, ctx.deformable_groups, 1,
|
86 |
+
cur_im2col_step)
|
87 |
+
|
88 |
+
return (grad_input, grad_offset, grad_weight, None, None, None, None,
|
89 |
+
None)
|
90 |
+
|
91 |
+
@staticmethod
|
92 |
+
def _output_size(input, weight, padding, dilation, stride):
|
93 |
+
channels = weight.size(0)
|
94 |
+
output_size = (input.size(0), channels)
|
95 |
+
for d in range(input.dim() - 2):
|
96 |
+
in_size = input.size(d + 2)
|
97 |
+
pad = padding[d]
|
98 |
+
kernel = dilation[d] * (weight.size(d + 2) - 1) + 1
|
99 |
+
stride_ = stride[d]
|
100 |
+
output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, )
|
101 |
+
if not all(map(lambda s: s > 0, output_size)):
|
102 |
+
raise ValueError(
|
103 |
+
"convolution input is too small (output would be {})".format(
|
104 |
+
'x'.join(map(str, output_size))))
|
105 |
+
return output_size
|
106 |
+
|
107 |
+
|
108 |
+
class ModulatedDeformConvFunction(Function):
|
109 |
+
|
110 |
+
@staticmethod
|
111 |
+
def forward(ctx,
|
112 |
+
input,
|
113 |
+
offset,
|
114 |
+
mask,
|
115 |
+
weight,
|
116 |
+
bias=None,
|
117 |
+
stride=1,
|
118 |
+
padding=0,
|
119 |
+
dilation=1,
|
120 |
+
groups=1,
|
121 |
+
deformable_groups=1):
|
122 |
+
ctx.stride = stride
|
123 |
+
ctx.padding = padding
|
124 |
+
ctx.dilation = dilation
|
125 |
+
ctx.groups = groups
|
126 |
+
ctx.deformable_groups = deformable_groups
|
127 |
+
ctx.with_bias = bias is not None
|
128 |
+
if not ctx.with_bias:
|
129 |
+
bias = input.new_empty(1) # fake tensor
|
130 |
+
if not input.is_cuda:
|
131 |
+
raise NotImplementedError
|
132 |
+
if weight.requires_grad or mask.requires_grad or offset.requires_grad \
|
133 |
+
or input.requires_grad:
|
134 |
+
ctx.save_for_backward(input, offset, mask, weight, bias)
|
135 |
+
output = input.new_empty(
|
136 |
+
ModulatedDeformConvFunction._infer_shape(ctx, input, weight))
|
137 |
+
ctx._bufs = [input.new_empty(0), input.new_empty(0)]
|
138 |
+
deform_conv_cuda.modulated_deform_conv_cuda_forward(
|
139 |
+
input, weight, bias, ctx._bufs[0], offset, mask, output,
|
140 |
+
ctx._bufs[1], weight.shape[2], weight.shape[3], ctx.stride,
|
141 |
+
ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
|
142 |
+
ctx.groups, ctx.deformable_groups, ctx.with_bias)
|
143 |
+
return output
|
144 |
+
|
145 |
+
@staticmethod
|
146 |
+
def backward(ctx, grad_output):
|
147 |
+
if not grad_output.is_cuda:
|
148 |
+
raise NotImplementedError
|
149 |
+
input, offset, mask, weight, bias = ctx.saved_tensors
|
150 |
+
grad_input = torch.zeros_like(input)
|
151 |
+
grad_offset = torch.zeros_like(offset)
|
152 |
+
grad_mask = torch.zeros_like(mask)
|
153 |
+
grad_weight = torch.zeros_like(weight)
|
154 |
+
grad_bias = torch.zeros_like(bias)
|
155 |
+
deform_conv_cuda.modulated_deform_conv_cuda_backward(
|
156 |
+
input, weight, bias, ctx._bufs[0], offset, mask, ctx._bufs[1],
|
157 |
+
grad_input, grad_weight, grad_bias, grad_offset, grad_mask,
|
158 |
+
grad_output, weight.shape[2], weight.shape[3], ctx.stride,
|
159 |
+
ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
|
160 |
+
ctx.groups, ctx.deformable_groups, ctx.with_bias)
|
161 |
+
if not ctx.with_bias:
|
162 |
+
grad_bias = None
|
163 |
+
|
164 |
+
return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias,
|
165 |
+
None, None, None, None, None)
|
166 |
+
|
167 |
+
@staticmethod
|
168 |
+
def _infer_shape(ctx, input, weight):
|
169 |
+
n = input.size(0)
|
170 |
+
channels_out = weight.size(0)
|
171 |
+
height, width = input.shape[2:4]
|
172 |
+
kernel_h, kernel_w = weight.shape[2:4]
|
173 |
+
height_out = (height + 2 * ctx.padding -
|
174 |
+
(ctx.dilation * (kernel_h - 1) + 1)) // ctx.stride + 1
|
175 |
+
width_out = (width + 2 * ctx.padding -
|
176 |
+
(ctx.dilation * (kernel_w - 1) + 1)) // ctx.stride + 1
|
177 |
+
return n, channels_out, height_out, width_out
|
178 |
+
|
179 |
+
|
180 |
+
deform_conv = DeformConvFunction.apply
|
181 |
+
modulated_deform_conv = ModulatedDeformConvFunction.apply
|
IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/functions/deform_pool.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch.autograd import Function
|
3 |
+
|
4 |
+
from .. import deform_pool_cuda
|
5 |
+
|
6 |
+
|
7 |
+
class DeformRoIPoolingFunction(Function):
|
8 |
+
|
9 |
+
@staticmethod
|
10 |
+
def forward(ctx,
|
11 |
+
data,
|
12 |
+
rois,
|
13 |
+
offset,
|
14 |
+
spatial_scale,
|
15 |
+
out_size,
|
16 |
+
out_channels,
|
17 |
+
no_trans,
|
18 |
+
group_size=1,
|
19 |
+
part_size=None,
|
20 |
+
sample_per_part=4,
|
21 |
+
trans_std=.0):
|
22 |
+
ctx.spatial_scale = spatial_scale
|
23 |
+
ctx.out_size = out_size
|
24 |
+
ctx.out_channels = out_channels
|
25 |
+
ctx.no_trans = no_trans
|
26 |
+
ctx.group_size = group_size
|
27 |
+
ctx.part_size = out_size if part_size is None else part_size
|
28 |
+
ctx.sample_per_part = sample_per_part
|
29 |
+
ctx.trans_std = trans_std
|
30 |
+
|
31 |
+
assert 0.0 <= ctx.trans_std <= 1.0
|
32 |
+
if not data.is_cuda:
|
33 |
+
raise NotImplementedError
|
34 |
+
|
35 |
+
n = rois.shape[0]
|
36 |
+
output = data.new_empty(n, out_channels, out_size, out_size)
|
37 |
+
output_count = data.new_empty(n, out_channels, out_size, out_size)
|
38 |
+
deform_pool_cuda.deform_psroi_pooling_cuda_forward(
|
39 |
+
data, rois, offset, output, output_count, ctx.no_trans,
|
40 |
+
ctx.spatial_scale, ctx.out_channels, ctx.group_size, ctx.out_size,
|
41 |
+
ctx.part_size, ctx.sample_per_part, ctx.trans_std)
|
42 |
+
|
43 |
+
if data.requires_grad or rois.requires_grad or offset.requires_grad:
|
44 |
+
ctx.save_for_backward(data, rois, offset)
|
45 |
+
ctx.output_count = output_count
|
46 |
+
|
47 |
+
return output
|
48 |
+
|
49 |
+
@staticmethod
|
50 |
+
def backward(ctx, grad_output):
|
51 |
+
if not grad_output.is_cuda:
|
52 |
+
raise NotImplementedError
|
53 |
+
|
54 |
+
data, rois, offset = ctx.saved_tensors
|
55 |
+
output_count = ctx.output_count
|
56 |
+
grad_input = torch.zeros_like(data)
|
57 |
+
grad_rois = None
|
58 |
+
grad_offset = torch.zeros_like(offset)
|
59 |
+
|
60 |
+
deform_pool_cuda.deform_psroi_pooling_cuda_backward(
|
61 |
+
grad_output, data, rois, offset, output_count, grad_input,
|
62 |
+
grad_offset, ctx.no_trans, ctx.spatial_scale, ctx.out_channels,
|
63 |
+
ctx.group_size, ctx.out_size, ctx.part_size, ctx.sample_per_part,
|
64 |
+
ctx.trans_std)
|
65 |
+
return (grad_input, grad_rois, grad_offset, None, None, None, None,
|
66 |
+
None, None, None, None)
|
67 |
+
|
68 |
+
|
69 |
+
deform_roi_pooling = DeformRoIPoolingFunction.apply
|
IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/modules/__init__.py
ADDED
File without changes
|
IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/modules/deform_conv.py
ADDED
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
from torch.nn.modules.utils import _pair
|
6 |
+
|
7 |
+
from ..functions.deform_conv import deform_conv, modulated_deform_conv
|
8 |
+
|
9 |
+
|
10 |
+
class DeformConv(nn.Module):
|
11 |
+
|
12 |
+
def __init__(self,
|
13 |
+
in_channels,
|
14 |
+
out_channels,
|
15 |
+
kernel_size,
|
16 |
+
stride=1,
|
17 |
+
padding=0,
|
18 |
+
dilation=1,
|
19 |
+
groups=1,
|
20 |
+
deformable_groups=1,
|
21 |
+
bias=False):
|
22 |
+
super(DeformConv, self).__init__()
|
23 |
+
|
24 |
+
assert not bias
|
25 |
+
assert in_channels % groups == 0, \
|
26 |
+
'in_channels {} cannot be divisible by groups {}'.format(
|
27 |
+
in_channels, groups)
|
28 |
+
assert out_channels % groups == 0, \
|
29 |
+
'out_channels {} cannot be divisible by groups {}'.format(
|
30 |
+
out_channels, groups)
|
31 |
+
|
32 |
+
self.in_channels = in_channels
|
33 |
+
self.out_channels = out_channels
|
34 |
+
self.kernel_size = _pair(kernel_size)
|
35 |
+
self.stride = _pair(stride)
|
36 |
+
self.padding = _pair(padding)
|
37 |
+
self.dilation = _pair(dilation)
|
38 |
+
self.groups = groups
|
39 |
+
self.deformable_groups = deformable_groups
|
40 |
+
|
41 |
+
self.weight = nn.Parameter(
|
42 |
+
torch.Tensor(out_channels, in_channels // self.groups,
|
43 |
+
*self.kernel_size))
|
44 |
+
|
45 |
+
self.reset_parameters()
|
46 |
+
|
47 |
+
def reset_parameters(self):
|
48 |
+
n = self.in_channels
|
49 |
+
for k in self.kernel_size:
|
50 |
+
n *= k
|
51 |
+
stdv = 1. / math.sqrt(n)
|
52 |
+
self.weight.data.uniform_(-stdv, stdv)
|
53 |
+
|
54 |
+
def forward(self, x, offset):
|
55 |
+
return deform_conv(x, offset, self.weight, self.stride, self.padding,
|
56 |
+
self.dilation, self.groups, self.deformable_groups)
|
57 |
+
|
58 |
+
|
59 |
+
class DeformConvPack(DeformConv):
|
60 |
+
|
61 |
+
def __init__(self, *args, **kwargs):
|
62 |
+
super(DeformConvPack, self).__init__(*args, **kwargs)
|
63 |
+
|
64 |
+
self.conv_offset = nn.Conv2d(
|
65 |
+
self.in_channels,
|
66 |
+
self.deformable_groups * 2 * self.kernel_size[0] *
|
67 |
+
self.kernel_size[1],
|
68 |
+
kernel_size=self.kernel_size,
|
69 |
+
stride=_pair(self.stride),
|
70 |
+
padding=_pair(self.padding),
|
71 |
+
bias=True)
|
72 |
+
self.init_offset()
|
73 |
+
|
74 |
+
def init_offset(self):
|
75 |
+
self.conv_offset.weight.data.zero_()
|
76 |
+
self.conv_offset.bias.data.zero_()
|
77 |
+
|
78 |
+
def forward(self, x):
|
79 |
+
offset = self.conv_offset(x)
|
80 |
+
return deform_conv(x, offset, self.weight, self.stride, self.padding,
|
81 |
+
self.dilation, self.groups, self.deformable_groups)
|
82 |
+
|
83 |
+
|
84 |
+
class ModulatedDeformConv(nn.Module):
|
85 |
+
|
86 |
+
def __init__(self,
|
87 |
+
in_channels,
|
88 |
+
out_channels,
|
89 |
+
kernel_size,
|
90 |
+
stride=1,
|
91 |
+
padding=0,
|
92 |
+
dilation=1,
|
93 |
+
groups=1,
|
94 |
+
deformable_groups=1,
|
95 |
+
bias=True):
|
96 |
+
super(ModulatedDeformConv, self).__init__()
|
97 |
+
self.in_channels = in_channels
|
98 |
+
self.out_channels = out_channels
|
99 |
+
self.kernel_size = _pair(kernel_size)
|
100 |
+
self.stride = stride
|
101 |
+
self.padding = padding
|
102 |
+
self.dilation = dilation
|
103 |
+
self.groups = groups
|
104 |
+
self.deformable_groups = deformable_groups
|
105 |
+
self.with_bias = bias
|
106 |
+
|
107 |
+
self.weight = nn.Parameter(
|
108 |
+
torch.Tensor(out_channels, in_channels // groups,
|
109 |
+
*self.kernel_size))
|
110 |
+
if bias:
|
111 |
+
self.bias = nn.Parameter(torch.Tensor(out_channels))
|
112 |
+
else:
|
113 |
+
self.register_parameter('bias', None)
|
114 |
+
self.reset_parameters()
|
115 |
+
|
116 |
+
def reset_parameters(self):
|
117 |
+
n = self.in_channels
|
118 |
+
for k in self.kernel_size:
|
119 |
+
n *= k
|
120 |
+
stdv = 1. / math.sqrt(n)
|
121 |
+
self.weight.data.uniform_(-stdv, stdv)
|
122 |
+
if self.bias is not None:
|
123 |
+
self.bias.data.zero_()
|
124 |
+
|
125 |
+
def forward(self, x, offset, mask):
|
126 |
+
return modulated_deform_conv(x, offset, mask, self.weight, self.bias,
|
127 |
+
self.stride, self.padding, self.dilation,
|
128 |
+
self.groups, self.deformable_groups)
|
129 |
+
|
130 |
+
|
131 |
+
class ModulatedDeformConvPack(ModulatedDeformConv):
|
132 |
+
|
133 |
+
def __init__(self, *args, **kwargs):
|
134 |
+
super(ModulatedDeformConvPack, self).__init__(*args, **kwargs)
|
135 |
+
|
136 |
+
self.conv_offset_mask = nn.Conv2d(
|
137 |
+
self.in_channels,
|
138 |
+
self.deformable_groups * 3 * self.kernel_size[0] *
|
139 |
+
self.kernel_size[1],
|
140 |
+
kernel_size=self.kernel_size,
|
141 |
+
stride=_pair(self.stride),
|
142 |
+
padding=_pair(self.padding),
|
143 |
+
bias=True)
|
144 |
+
self.init_offset()
|
145 |
+
|
146 |
+
def init_offset(self):
|
147 |
+
self.conv_offset_mask.weight.data.zero_()
|
148 |
+
self.conv_offset_mask.bias.data.zero_()
|
149 |
+
|
150 |
+
def forward(self, x):
|
151 |
+
out = self.conv_offset_mask(x)
|
152 |
+
o1, o2, mask = torch.chunk(out, 3, dim=1)
|
153 |
+
offset = torch.cat((o1, o2), dim=1)
|
154 |
+
mask = torch.sigmoid(mask)
|
155 |
+
return modulated_deform_conv(x, offset, mask, self.weight, self.bias,
|
156 |
+
self.stride, self.padding, self.dilation,
|
157 |
+
self.groups, self.deformable_groups)
|
IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/modules/deform_pool.py
ADDED
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch import nn
|
2 |
+
|
3 |
+
from ..functions.deform_pool import deform_roi_pooling
|
4 |
+
|
5 |
+
|
6 |
+
class DeformRoIPooling(nn.Module):
|
7 |
+
|
8 |
+
def __init__(self,
|
9 |
+
spatial_scale,
|
10 |
+
out_size,
|
11 |
+
out_channels,
|
12 |
+
no_trans,
|
13 |
+
group_size=1,
|
14 |
+
part_size=None,
|
15 |
+
sample_per_part=4,
|
16 |
+
trans_std=.0):
|
17 |
+
super(DeformRoIPooling, self).__init__()
|
18 |
+
self.spatial_scale = spatial_scale
|
19 |
+
self.out_size = out_size
|
20 |
+
self.out_channels = out_channels
|
21 |
+
self.no_trans = no_trans
|
22 |
+
self.group_size = group_size
|
23 |
+
self.part_size = out_size if part_size is None else part_size
|
24 |
+
self.sample_per_part = sample_per_part
|
25 |
+
self.trans_std = trans_std
|
26 |
+
|
27 |
+
def forward(self, data, rois, offset):
|
28 |
+
if self.no_trans:
|
29 |
+
offset = data.new_empty(0)
|
30 |
+
return deform_roi_pooling(
|
31 |
+
data, rois, offset, self.spatial_scale, self.out_size,
|
32 |
+
self.out_channels, self.no_trans, self.group_size, self.part_size,
|
33 |
+
self.sample_per_part, self.trans_std)
|
34 |
+
|
35 |
+
|
36 |
+
class DeformRoIPoolingPack(DeformRoIPooling):
|
37 |
+
|
38 |
+
def __init__(self,
|
39 |
+
spatial_scale,
|
40 |
+
out_size,
|
41 |
+
out_channels,
|
42 |
+
no_trans,
|
43 |
+
group_size=1,
|
44 |
+
part_size=None,
|
45 |
+
sample_per_part=4,
|
46 |
+
trans_std=.0,
|
47 |
+
num_offset_fcs=3,
|
48 |
+
deform_fc_channels=1024):
|
49 |
+
super(DeformRoIPoolingPack,
|
50 |
+
self).__init__(spatial_scale, out_size, out_channels, no_trans,
|
51 |
+
group_size, part_size, sample_per_part, trans_std)
|
52 |
+
|
53 |
+
self.num_offset_fcs = num_offset_fcs
|
54 |
+
self.deform_fc_channels = deform_fc_channels
|
55 |
+
|
56 |
+
if not no_trans:
|
57 |
+
seq = []
|
58 |
+
ic = self.out_size * self.out_size * self.out_channels
|
59 |
+
for i in range(self.num_offset_fcs):
|
60 |
+
if i < self.num_offset_fcs - 1:
|
61 |
+
oc = self.deform_fc_channels
|
62 |
+
else:
|
63 |
+
oc = self.out_size * self.out_size * 2
|
64 |
+
seq.append(nn.Linear(ic, oc))
|
65 |
+
ic = oc
|
66 |
+
if i < self.num_offset_fcs - 1:
|
67 |
+
seq.append(nn.ReLU(inplace=True))
|
68 |
+
self.offset_fc = nn.Sequential(*seq)
|
69 |
+
self.offset_fc[-1].weight.data.zero_()
|
70 |
+
self.offset_fc[-1].bias.data.zero_()
|
71 |
+
|
72 |
+
def forward(self, data, rois):
|
73 |
+
assert data.size(1) == self.out_channels
|
74 |
+
if self.no_trans:
|
75 |
+
offset = data.new_empty(0)
|
76 |
+
return deform_roi_pooling(
|
77 |
+
data, rois, offset, self.spatial_scale, self.out_size,
|
78 |
+
self.out_channels, self.no_trans, self.group_size,
|
79 |
+
self.part_size, self.sample_per_part, self.trans_std)
|
80 |
+
else:
|
81 |
+
n = rois.shape[0]
|
82 |
+
offset = data.new_empty(0)
|
83 |
+
x = deform_roi_pooling(data, rois, offset, self.spatial_scale,
|
84 |
+
self.out_size, self.out_channels, True,
|
85 |
+
self.group_size, self.part_size,
|
86 |
+
self.sample_per_part, self.trans_std)
|
87 |
+
offset = self.offset_fc(x.view(n, -1))
|
88 |
+
offset = offset.view(n, 2, self.out_size, self.out_size)
|
89 |
+
return deform_roi_pooling(
|
90 |
+
data, rois, offset, self.spatial_scale, self.out_size,
|
91 |
+
self.out_channels, self.no_trans, self.group_size,
|
92 |
+
self.part_size, self.sample_per_part, self.trans_std)
|
93 |
+
|
94 |
+
|
95 |
+
class ModulatedDeformRoIPoolingPack(DeformRoIPooling):
|
96 |
+
|
97 |
+
def __init__(self,
|
98 |
+
spatial_scale,
|
99 |
+
out_size,
|
100 |
+
out_channels,
|
101 |
+
no_trans,
|
102 |
+
group_size=1,
|
103 |
+
part_size=None,
|
104 |
+
sample_per_part=4,
|
105 |
+
trans_std=.0,
|
106 |
+
num_offset_fcs=3,
|
107 |
+
num_mask_fcs=2,
|
108 |
+
deform_fc_channels=1024):
|
109 |
+
super(ModulatedDeformRoIPoolingPack, self).__init__(
|
110 |
+
spatial_scale, out_size, out_channels, no_trans, group_size,
|
111 |
+
part_size, sample_per_part, trans_std)
|
112 |
+
|
113 |
+
self.num_offset_fcs = num_offset_fcs
|
114 |
+
self.num_mask_fcs = num_mask_fcs
|
115 |
+
self.deform_fc_channels = deform_fc_channels
|
116 |
+
|
117 |
+
if not no_trans:
|
118 |
+
offset_fc_seq = []
|
119 |
+
ic = self.out_size * self.out_size * self.out_channels
|
120 |
+
for i in range(self.num_offset_fcs):
|
121 |
+
if i < self.num_offset_fcs - 1:
|
122 |
+
oc = self.deform_fc_channels
|
123 |
+
else:
|
124 |
+
oc = self.out_size * self.out_size * 2
|
125 |
+
offset_fc_seq.append(nn.Linear(ic, oc))
|
126 |
+
ic = oc
|
127 |
+
if i < self.num_offset_fcs - 1:
|
128 |
+
offset_fc_seq.append(nn.ReLU(inplace=True))
|
129 |
+
self.offset_fc = nn.Sequential(*offset_fc_seq)
|
130 |
+
self.offset_fc[-1].weight.data.zero_()
|
131 |
+
self.offset_fc[-1].bias.data.zero_()
|
132 |
+
|
133 |
+
mask_fc_seq = []
|
134 |
+
ic = self.out_size * self.out_size * self.out_channels
|
135 |
+
for i in range(self.num_mask_fcs):
|
136 |
+
if i < self.num_mask_fcs - 1:
|
137 |
+
oc = self.deform_fc_channels
|
138 |
+
else:
|
139 |
+
oc = self.out_size * self.out_size
|
140 |
+
mask_fc_seq.append(nn.Linear(ic, oc))
|
141 |
+
ic = oc
|
142 |
+
if i < self.num_mask_fcs - 1:
|
143 |
+
mask_fc_seq.append(nn.ReLU(inplace=True))
|
144 |
+
else:
|
145 |
+
mask_fc_seq.append(nn.Sigmoid())
|
146 |
+
self.mask_fc = nn.Sequential(*mask_fc_seq)
|
147 |
+
self.mask_fc[-2].weight.data.zero_()
|
148 |
+
self.mask_fc[-2].bias.data.zero_()
|
149 |
+
|
150 |
+
def forward(self, data, rois):
|
151 |
+
assert data.size(1) == self.out_channels
|
152 |
+
if self.no_trans:
|
153 |
+
offset = data.new_empty(0)
|
154 |
+
return deform_roi_pooling(
|
155 |
+
data, rois, offset, self.spatial_scale, self.out_size,
|
156 |
+
self.out_channels, self.no_trans, self.group_size,
|
157 |
+
self.part_size, self.sample_per_part, self.trans_std)
|
158 |
+
else:
|
159 |
+
n = rois.shape[0]
|
160 |
+
offset = data.new_empty(0)
|
161 |
+
x = deform_roi_pooling(data, rois, offset, self.spatial_scale,
|
162 |
+
self.out_size, self.out_channels, True,
|
163 |
+
self.group_size, self.part_size,
|
164 |
+
self.sample_per_part, self.trans_std)
|
165 |
+
offset = self.offset_fc(x.view(n, -1))
|
166 |
+
offset = offset.view(n, 2, self.out_size, self.out_size)
|
167 |
+
mask = self.mask_fc(x.view(n, -1))
|
168 |
+
mask = mask.view(n, 1, self.out_size, self.out_size)
|
169 |
+
return deform_roi_pooling(
|
170 |
+
data, rois, offset, self.spatial_scale, self.out_size,
|
171 |
+
self.out_channels, self.no_trans, self.group_size,
|
172 |
+
self.part_size, self.sample_per_part, self.trans_std) * mask
|
IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/setup.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
PATH ="{}:{}".format(os.environ['PATH'], "/opt/cuda/bin")
|
3 |
+
# os.environ['CUDA_VISIBLE_DEVICES'] = "1"
|
4 |
+
os.environ['PATH'] = PATH
|
5 |
+
from setuptools import setup
|
6 |
+
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
7 |
+
|
8 |
+
setup(
|
9 |
+
name='deform_conv',
|
10 |
+
ext_modules=[
|
11 |
+
CUDAExtension('deform_conv_cuda', [
|
12 |
+
'src/deform_conv_cuda.cpp',
|
13 |
+
'src/deform_conv_cuda_kernel.cu',
|
14 |
+
]),
|
15 |
+
CUDAExtension('deform_pool_cuda', [
|
16 |
+
'src/deform_pool_cuda.cpp', 'src/deform_pool_cuda_kernel.cu'
|
17 |
+
]),
|
18 |
+
],
|
19 |
+
cmdclass={'build_ext': BuildExtension})
|
IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/src/deform_conv_cuda.cpp
ADDED
@@ -0,0 +1,695 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// modify from
|
2 |
+
// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c
|
3 |
+
|
4 |
+
#include <torch/extension.h>
|
5 |
+
|
6 |
+
#include <cmath>
|
7 |
+
#include <vector>
|
8 |
+
|
9 |
+
void deformable_im2col(const at::Tensor data_im, const at::Tensor data_offset,
|
10 |
+
const int channels, const int height, const int width,
|
11 |
+
const int ksize_h, const int ksize_w, const int pad_h,
|
12 |
+
const int pad_w, const int stride_h, const int stride_w,
|
13 |
+
const int dilation_h, const int dilation_w,
|
14 |
+
const int parallel_imgs, const int deformable_group,
|
15 |
+
at::Tensor data_col);
|
16 |
+
|
17 |
+
void deformable_col2im(const at::Tensor data_col, const at::Tensor data_offset,
|
18 |
+
const int channels, const int height, const int width,
|
19 |
+
const int ksize_h, const int ksize_w, const int pad_h,
|
20 |
+
const int pad_w, const int stride_h, const int stride_w,
|
21 |
+
const int dilation_h, const int dilation_w,
|
22 |
+
const int parallel_imgs, const int deformable_group,
|
23 |
+
at::Tensor grad_im);
|
24 |
+
|
25 |
+
void deformable_col2im_coord(
|
26 |
+
const at::Tensor data_col, const at::Tensor data_im,
|
27 |
+
const at::Tensor data_offset, const int channels, const int height,
|
28 |
+
const int width, const int ksize_h, const int ksize_w, const int pad_h,
|
29 |
+
const int pad_w, const int stride_h, const int stride_w,
|
30 |
+
const int dilation_h, const int dilation_w, const int parallel_imgs,
|
31 |
+
const int deformable_group, at::Tensor grad_offset);
|
32 |
+
|
33 |
+
void modulated_deformable_im2col_cuda(
|
34 |
+
const at::Tensor data_im, const at::Tensor data_offset,
|
35 |
+
const at::Tensor data_mask, const int batch_size, const int channels,
|
36 |
+
const int height_im, const int width_im, const int height_col,
|
37 |
+
const int width_col, const int kernel_h, const int kenerl_w,
|
38 |
+
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
|
39 |
+
const int dilation_h, const int dilation_w, const int deformable_group,
|
40 |
+
at::Tensor data_col);
|
41 |
+
|
42 |
+
void modulated_deformable_col2im_cuda(
|
43 |
+
const at::Tensor data_col, const at::Tensor data_offset,
|
44 |
+
const at::Tensor data_mask, const int batch_size, const int channels,
|
45 |
+
const int height_im, const int width_im, const int height_col,
|
46 |
+
const int width_col, const int kernel_h, const int kenerl_w,
|
47 |
+
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
|
48 |
+
const int dilation_h, const int dilation_w, const int deformable_group,
|
49 |
+
at::Tensor grad_im);
|
50 |
+
|
51 |
+
void modulated_deformable_col2im_coord_cuda(
|
52 |
+
const at::Tensor data_col, const at::Tensor data_im,
|
53 |
+
const at::Tensor data_offset, const at::Tensor data_mask,
|
54 |
+
const int batch_size, const int channels, const int height_im,
|
55 |
+
const int width_im, const int height_col, const int width_col,
|
56 |
+
const int kernel_h, const int kenerl_w, const int pad_h, const int pad_w,
|
57 |
+
const int stride_h, const int stride_w, const int dilation_h,
|
58 |
+
const int dilation_w, const int deformable_group, at::Tensor grad_offset,
|
59 |
+
at::Tensor grad_mask);
|
60 |
+
|
61 |
+
void shape_check(at::Tensor input, at::Tensor offset, at::Tensor *gradOutput,
|
62 |
+
at::Tensor weight, int kH, int kW, int dH, int dW, int padH,
|
63 |
+
int padW, int dilationH, int dilationW, int group,
|
64 |
+
int deformable_group) {
|
65 |
+
TORCH_CHECK(weight.ndimension() == 4,
|
66 |
+
"4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, "
|
67 |
+
"but got: %s",
|
68 |
+
weight.ndimension());
|
69 |
+
|
70 |
+
TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
|
71 |
+
|
72 |
+
TORCH_CHECK(kW > 0 && kH > 0,
|
73 |
+
"kernel size should be greater than zero, but got kH: %d kW: %d", kH,
|
74 |
+
kW);
|
75 |
+
|
76 |
+
TORCH_CHECK((weight.size(2) == kH && weight.size(3) == kW),
|
77 |
+
"kernel size should be consistent with weight, ",
|
78 |
+
"but got kH: %d kW: %d weight.size(2): %d, weight.size(3): %d", kH,
|
79 |
+
kW, weight.size(2), weight.size(3));
|
80 |
+
|
81 |
+
TORCH_CHECK(dW > 0 && dH > 0,
|
82 |
+
"stride should be greater than zero, but got dH: %d dW: %d", dH, dW);
|
83 |
+
|
84 |
+
TORCH_CHECK(
|
85 |
+
dilationW > 0 && dilationH > 0,
|
86 |
+
"dilation should be greater than 0, but got dilationH: %d dilationW: %d",
|
87 |
+
dilationH, dilationW);
|
88 |
+
|
89 |
+
int ndim = input.ndimension();
|
90 |
+
int dimf = 0;
|
91 |
+
int dimh = 1;
|
92 |
+
int dimw = 2;
|
93 |
+
|
94 |
+
if (ndim == 4) {
|
95 |
+
dimf++;
|
96 |
+
dimh++;
|
97 |
+
dimw++;
|
98 |
+
}
|
99 |
+
|
100 |
+
TORCH_CHECK(ndim == 3 || ndim == 4, "3D or 4D input tensor expected but got: %s",
|
101 |
+
ndim);
|
102 |
+
|
103 |
+
long nInputPlane = weight.size(1) * group;
|
104 |
+
long inputHeight = input.size(dimh);
|
105 |
+
long inputWidth = input.size(dimw);
|
106 |
+
long nOutputPlane = weight.size(0);
|
107 |
+
long outputHeight =
|
108 |
+
(inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
|
109 |
+
long outputWidth =
|
110 |
+
(inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
|
111 |
+
|
112 |
+
TORCH_CHECK(nInputPlane % deformable_group == 0,
|
113 |
+
"input channels must divide deformable group size");
|
114 |
+
|
115 |
+
if (outputWidth < 1 || outputHeight < 1)
|
116 |
+
AT_ERROR(
|
117 |
+
"Given input size: (%ld x %ld x %ld). "
|
118 |
+
"Calculated output size: (%ld x %ld x %ld). Output size is too small",
|
119 |
+
nInputPlane, inputHeight, inputWidth, nOutputPlane, outputHeight,
|
120 |
+
outputWidth);
|
121 |
+
|
122 |
+
TORCH_CHECK(input.size(1) == nInputPlane,
|
123 |
+
"invalid number of input planes, expected: %d, but got: %d",
|
124 |
+
nInputPlane, input.size(1));
|
125 |
+
|
126 |
+
TORCH_CHECK((inputHeight >= kH && inputWidth >= kW),
|
127 |
+
"input image is smaller than kernel");
|
128 |
+
|
129 |
+
TORCH_CHECK((offset.size(2) == outputHeight && offset.size(3) == outputWidth),
|
130 |
+
"invalid spatial size of offset, expected height: %d width: %d, but "
|
131 |
+
"got height: %d width: %d",
|
132 |
+
outputHeight, outputWidth, offset.size(2), offset.size(3));
|
133 |
+
|
134 |
+
TORCH_CHECK((offset.size(1) == deformable_group * 2 * kH * kW),
|
135 |
+
"invalid number of channels of offset");
|
136 |
+
|
137 |
+
if (gradOutput != NULL) {
|
138 |
+
TORCH_CHECK(gradOutput->size(dimf) == nOutputPlane,
|
139 |
+
"invalid number of gradOutput planes, expected: %d, but got: %d",
|
140 |
+
nOutputPlane, gradOutput->size(dimf));
|
141 |
+
|
142 |
+
TORCH_CHECK((gradOutput->size(dimh) == outputHeight &&
|
143 |
+
gradOutput->size(dimw) == outputWidth),
|
144 |
+
"invalid size of gradOutput, expected height: %d width: %d , but "
|
145 |
+
"got height: %d width: %d",
|
146 |
+
outputHeight, outputWidth, gradOutput->size(dimh),
|
147 |
+
gradOutput->size(dimw));
|
148 |
+
}
|
149 |
+
}
|
150 |
+
|
151 |
+
int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
|
152 |
+
at::Tensor offset, at::Tensor output,
|
153 |
+
at::Tensor columns, at::Tensor ones, int kW,
|
154 |
+
int kH, int dW, int dH, int padW, int padH,
|
155 |
+
int dilationW, int dilationH, int group,
|
156 |
+
int deformable_group, int im2col_step) {
|
157 |
+
// todo: resize columns to include im2col: done
|
158 |
+
// todo: add im2col_step as input
|
159 |
+
// todo: add new output buffer and transpose it to output (or directly
|
160 |
+
// transpose output) todo: possibly change data indexing because of
|
161 |
+
// parallel_imgs
|
162 |
+
|
163 |
+
shape_check(input, offset, NULL, weight, kH, kW, dH, dW, padH, padW,
|
164 |
+
dilationH, dilationW, group, deformable_group);
|
165 |
+
|
166 |
+
input = input.contiguous();
|
167 |
+
offset = offset.contiguous();
|
168 |
+
weight = weight.contiguous();
|
169 |
+
|
170 |
+
int batch = 1;
|
171 |
+
if (input.ndimension() == 3) {
|
172 |
+
// Force batch
|
173 |
+
batch = 0;
|
174 |
+
input.unsqueeze_(0);
|
175 |
+
offset.unsqueeze_(0);
|
176 |
+
}
|
177 |
+
|
178 |
+
// todo: assert batchsize dividable by im2col_step
|
179 |
+
|
180 |
+
long batchSize = input.size(0);
|
181 |
+
long nInputPlane = input.size(1);
|
182 |
+
long inputHeight = input.size(2);
|
183 |
+
long inputWidth = input.size(3);
|
184 |
+
|
185 |
+
long nOutputPlane = weight.size(0);
|
186 |
+
|
187 |
+
long outputWidth =
|
188 |
+
(inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
|
189 |
+
long outputHeight =
|
190 |
+
(inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
|
191 |
+
|
192 |
+
TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
|
193 |
+
|
194 |
+
output = output.view({batchSize / im2col_step, im2col_step, nOutputPlane,
|
195 |
+
outputHeight, outputWidth});
|
196 |
+
columns = at::zeros(
|
197 |
+
{nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
|
198 |
+
input.options());
|
199 |
+
|
200 |
+
if (ones.ndimension() != 2 ||
|
201 |
+
ones.size(0) * ones.size(1) < outputHeight * outputWidth) {
|
202 |
+
ones = at::ones({outputHeight, outputWidth}, input.options());
|
203 |
+
}
|
204 |
+
|
205 |
+
input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
|
206 |
+
inputHeight, inputWidth});
|
207 |
+
offset =
|
208 |
+
offset.view({batchSize / im2col_step, im2col_step,
|
209 |
+
deformable_group * 2 * kH * kW, outputHeight, outputWidth});
|
210 |
+
|
211 |
+
at::Tensor output_buffer =
|
212 |
+
at::zeros({batchSize / im2col_step, nOutputPlane,
|
213 |
+
im2col_step * outputHeight, outputWidth},
|
214 |
+
output.options());
|
215 |
+
|
216 |
+
output_buffer = output_buffer.view(
|
217 |
+
{output_buffer.size(0), group, output_buffer.size(1) / group,
|
218 |
+
output_buffer.size(2), output_buffer.size(3)});
|
219 |
+
|
220 |
+
for (int elt = 0; elt < batchSize / im2col_step; elt++) {
|
221 |
+
deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
|
222 |
+
inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
|
223 |
+
dilationW, im2col_step, deformable_group, columns);
|
224 |
+
|
225 |
+
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
|
226 |
+
weight = weight.view({group, weight.size(0) / group, weight.size(1),
|
227 |
+
weight.size(2), weight.size(3)});
|
228 |
+
|
229 |
+
for (int g = 0; g < group; g++) {
|
230 |
+
output_buffer[elt][g] = output_buffer[elt][g]
|
231 |
+
.flatten(1)
|
232 |
+
.addmm_(weight[g].flatten(1), columns[g])
|
233 |
+
.view_as(output_buffer[elt][g]);
|
234 |
+
}
|
235 |
+
}
|
236 |
+
|
237 |
+
output_buffer = output_buffer.view(
|
238 |
+
{output_buffer.size(0), output_buffer.size(1) * output_buffer.size(2),
|
239 |
+
output_buffer.size(3), output_buffer.size(4)});
|
240 |
+
|
241 |
+
output_buffer = output_buffer.view({batchSize / im2col_step, nOutputPlane,
|
242 |
+
im2col_step, outputHeight, outputWidth});
|
243 |
+
output_buffer.transpose_(1, 2);
|
244 |
+
output.copy_(output_buffer);
|
245 |
+
output = output.view({batchSize, nOutputPlane, outputHeight, outputWidth});
|
246 |
+
|
247 |
+
input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
|
248 |
+
offset = offset.view(
|
249 |
+
{batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
|
250 |
+
|
251 |
+
if (batch == 0) {
|
252 |
+
output = output.view({nOutputPlane, outputHeight, outputWidth});
|
253 |
+
input = input.view({nInputPlane, inputHeight, inputWidth});
|
254 |
+
offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
|
255 |
+
}
|
256 |
+
|
257 |
+
return 1;
|
258 |
+
}
|
259 |
+
|
260 |
+
int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset,
|
261 |
+
at::Tensor gradOutput, at::Tensor gradInput,
|
262 |
+
at::Tensor gradOffset, at::Tensor weight,
|
263 |
+
at::Tensor columns, int kW, int kH, int dW,
|
264 |
+
int dH, int padW, int padH, int dilationW,
|
265 |
+
int dilationH, int group,
|
266 |
+
int deformable_group, int im2col_step) {
|
267 |
+
shape_check(input, offset, &gradOutput, weight, kH, kW, dH, dW, padH, padW,
|
268 |
+
dilationH, dilationW, group, deformable_group);
|
269 |
+
|
270 |
+
input = input.contiguous();
|
271 |
+
offset = offset.contiguous();
|
272 |
+
gradOutput = gradOutput.contiguous();
|
273 |
+
weight = weight.contiguous();
|
274 |
+
|
275 |
+
int batch = 1;
|
276 |
+
|
277 |
+
if (input.ndimension() == 3) {
|
278 |
+
// Force batch
|
279 |
+
batch = 0;
|
280 |
+
input = input.view({1, input.size(0), input.size(1), input.size(2)});
|
281 |
+
offset = offset.view({1, offset.size(0), offset.size(1), offset.size(2)});
|
282 |
+
gradOutput = gradOutput.view(
|
283 |
+
{1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
|
284 |
+
}
|
285 |
+
|
286 |
+
long batchSize = input.size(0);
|
287 |
+
long nInputPlane = input.size(1);
|
288 |
+
long inputHeight = input.size(2);
|
289 |
+
long inputWidth = input.size(3);
|
290 |
+
|
291 |
+
long nOutputPlane = weight.size(0);
|
292 |
+
|
293 |
+
long outputWidth =
|
294 |
+
(inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
|
295 |
+
long outputHeight =
|
296 |
+
(inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
|
297 |
+
|
298 |
+
TORCH_CHECK((offset.size(0) == batchSize), 3, "invalid batch size of offset");
|
299 |
+
gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
|
300 |
+
columns = at::zeros(
|
301 |
+
{nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
|
302 |
+
input.options());
|
303 |
+
|
304 |
+
// change order of grad output
|
305 |
+
gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
|
306 |
+
nOutputPlane, outputHeight, outputWidth});
|
307 |
+
gradOutput.transpose_(1, 2);
|
308 |
+
|
309 |
+
gradInput = gradInput.view({batchSize / im2col_step, im2col_step, nInputPlane,
|
310 |
+
inputHeight, inputWidth});
|
311 |
+
input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
|
312 |
+
inputHeight, inputWidth});
|
313 |
+
gradOffset = gradOffset.view({batchSize / im2col_step, im2col_step,
|
314 |
+
deformable_group * 2 * kH * kW, outputHeight,
|
315 |
+
outputWidth});
|
316 |
+
offset =
|
317 |
+
offset.view({batchSize / im2col_step, im2col_step,
|
318 |
+
deformable_group * 2 * kH * kW, outputHeight, outputWidth});
|
319 |
+
|
320 |
+
for (int elt = 0; elt < batchSize / im2col_step; elt++) {
|
321 |
+
// divide into groups
|
322 |
+
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
|
323 |
+
weight = weight.view({group, weight.size(0) / group, weight.size(1),
|
324 |
+
weight.size(2), weight.size(3)});
|
325 |
+
gradOutput = gradOutput.view(
|
326 |
+
{gradOutput.size(0), group, gradOutput.size(1) / group,
|
327 |
+
gradOutput.size(2), gradOutput.size(3), gradOutput.size(4)});
|
328 |
+
|
329 |
+
for (int g = 0; g < group; g++) {
|
330 |
+
columns[g] = columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
|
331 |
+
gradOutput[elt][g].flatten(1), 0.0f, 1.0f);
|
332 |
+
}
|
333 |
+
|
334 |
+
columns =
|
335 |
+
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
|
336 |
+
gradOutput = gradOutput.view(
|
337 |
+
{gradOutput.size(0), gradOutput.size(1) * gradOutput.size(2),
|
338 |
+
gradOutput.size(3), gradOutput.size(4), gradOutput.size(5)});
|
339 |
+
|
340 |
+
deformable_col2im_coord(columns, input[elt], offset[elt], nInputPlane,
|
341 |
+
inputHeight, inputWidth, kH, kW, padH, padW, dH, dW,
|
342 |
+
dilationH, dilationW, im2col_step, deformable_group,
|
343 |
+
gradOffset[elt]);
|
344 |
+
|
345 |
+
deformable_col2im(columns, offset[elt], nInputPlane, inputHeight,
|
346 |
+
inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
|
347 |
+
dilationW, im2col_step, deformable_group, gradInput[elt]);
|
348 |
+
}
|
349 |
+
|
350 |
+
gradOutput.transpose_(1, 2);
|
351 |
+
gradOutput =
|
352 |
+
gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});
|
353 |
+
|
354 |
+
gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
|
355 |
+
input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
|
356 |
+
gradOffset = gradOffset.view(
|
357 |
+
{batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
|
358 |
+
offset = offset.view(
|
359 |
+
{batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
|
360 |
+
|
361 |
+
if (batch == 0) {
|
362 |
+
gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
|
363 |
+
input = input.view({nInputPlane, inputHeight, inputWidth});
|
364 |
+
gradInput = gradInput.view({nInputPlane, inputHeight, inputWidth});
|
365 |
+
offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
|
366 |
+
gradOffset =
|
367 |
+
gradOffset.view({offset.size(1), offset.size(2), offset.size(3)});
|
368 |
+
}
|
369 |
+
|
370 |
+
return 1;
|
371 |
+
}
|
372 |
+
|
373 |
+
int deform_conv_backward_parameters_cuda(
|
374 |
+
at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
|
375 |
+
at::Tensor gradWeight, // at::Tensor gradBias,
|
376 |
+
at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
|
377 |
+
int padW, int padH, int dilationW, int dilationH, int group,
|
378 |
+
int deformable_group, float scale, int im2col_step) {
|
379 |
+
// todo: transpose and reshape outGrad
|
380 |
+
// todo: reshape columns
|
381 |
+
// todo: add im2col_step as input
|
382 |
+
|
383 |
+
shape_check(input, offset, &gradOutput, gradWeight, kH, kW, dH, dW, padH,
|
384 |
+
padW, dilationH, dilationW, group, deformable_group);
|
385 |
+
|
386 |
+
input = input.contiguous();
|
387 |
+
offset = offset.contiguous();
|
388 |
+
gradOutput = gradOutput.contiguous();
|
389 |
+
|
390 |
+
int batch = 1;
|
391 |
+
|
392 |
+
if (input.ndimension() == 3) {
|
393 |
+
// Force batch
|
394 |
+
batch = 0;
|
395 |
+
input = input.view(
|
396 |
+
at::IntList({1, input.size(0), input.size(1), input.size(2)}));
|
397 |
+
gradOutput = gradOutput.view(
|
398 |
+
{1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
|
399 |
+
}
|
400 |
+
|
401 |
+
long batchSize = input.size(0);
|
402 |
+
long nInputPlane = input.size(1);
|
403 |
+
long inputHeight = input.size(2);
|
404 |
+
long inputWidth = input.size(3);
|
405 |
+
|
406 |
+
long nOutputPlane = gradWeight.size(0);
|
407 |
+
|
408 |
+
long outputWidth =
|
409 |
+
(inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
|
410 |
+
long outputHeight =
|
411 |
+
(inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
|
412 |
+
|
413 |
+
TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
|
414 |
+
|
415 |
+
columns = at::zeros(
|
416 |
+
{nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
|
417 |
+
input.options());
|
418 |
+
|
419 |
+
gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
|
420 |
+
nOutputPlane, outputHeight, outputWidth});
|
421 |
+
gradOutput.transpose_(1, 2);
|
422 |
+
|
423 |
+
at::Tensor gradOutputBuffer = at::zeros_like(gradOutput);
|
424 |
+
gradOutputBuffer =
|
425 |
+
gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, im2col_step,
|
426 |
+
outputHeight, outputWidth});
|
427 |
+
gradOutputBuffer.copy_(gradOutput);
|
428 |
+
gradOutputBuffer =
|
429 |
+
gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane,
|
430 |
+
im2col_step * outputHeight, outputWidth});
|
431 |
+
|
432 |
+
gradOutput.transpose_(1, 2);
|
433 |
+
gradOutput =
|
434 |
+
gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});
|
435 |
+
|
436 |
+
input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
|
437 |
+
inputHeight, inputWidth});
|
438 |
+
offset =
|
439 |
+
offset.view({batchSize / im2col_step, im2col_step,
|
440 |
+
deformable_group * 2 * kH * kW, outputHeight, outputWidth});
|
441 |
+
|
442 |
+
for (int elt = 0; elt < batchSize / im2col_step; elt++) {
|
443 |
+
deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
|
444 |
+
inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
|
445 |
+
dilationW, im2col_step, deformable_group, columns);
|
446 |
+
|
447 |
+
// divide into group
|
448 |
+
gradOutputBuffer = gradOutputBuffer.view(
|
449 |
+
{gradOutputBuffer.size(0), group, gradOutputBuffer.size(1) / group,
|
450 |
+
gradOutputBuffer.size(2), gradOutputBuffer.size(3)});
|
451 |
+
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
|
452 |
+
gradWeight =
|
453 |
+
gradWeight.view({group, gradWeight.size(0) / group, gradWeight.size(1),
|
454 |
+
gradWeight.size(2), gradWeight.size(3)});
|
455 |
+
|
456 |
+
for (int g = 0; g < group; g++) {
|
457 |
+
gradWeight[g] = gradWeight[g]
|
458 |
+
.flatten(1)
|
459 |
+
.addmm_(gradOutputBuffer[elt][g].flatten(1),
|
460 |
+
columns[g].transpose(1, 0), 1.0, scale)
|
461 |
+
.view_as(gradWeight[g]);
|
462 |
+
}
|
463 |
+
gradOutputBuffer = gradOutputBuffer.view(
|
464 |
+
{gradOutputBuffer.size(0),
|
465 |
+
gradOutputBuffer.size(1) * gradOutputBuffer.size(2),
|
466 |
+
gradOutputBuffer.size(3), gradOutputBuffer.size(4)});
|
467 |
+
columns =
|
468 |
+
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
|
469 |
+
gradWeight = gradWeight.view({gradWeight.size(0) * gradWeight.size(1),
|
470 |
+
gradWeight.size(2), gradWeight.size(3),
|
471 |
+
gradWeight.size(4)});
|
472 |
+
}
|
473 |
+
|
474 |
+
input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
|
475 |
+
offset = offset.view(
|
476 |
+
{batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
|
477 |
+
|
478 |
+
if (batch == 0) {
|
479 |
+
gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
|
480 |
+
input = input.view({nInputPlane, inputHeight, inputWidth});
|
481 |
+
}
|
482 |
+
|
483 |
+
return 1;
|
484 |
+
}
|
485 |
+
|
486 |
+
void modulated_deform_conv_cuda_forward(
|
487 |
+
at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
|
488 |
+
at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns,
|
489 |
+
int kernel_h, int kernel_w, const int stride_h, const int stride_w,
|
490 |
+
const int pad_h, const int pad_w, const int dilation_h,
|
491 |
+
const int dilation_w, const int group, const int deformable_group,
|
492 |
+
const bool with_bias) {
|
493 |
+
TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
|
494 |
+
TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
|
495 |
+
|
496 |
+
const int batch = input.size(0);
|
497 |
+
const int channels = input.size(1);
|
498 |
+
const int height = input.size(2);
|
499 |
+
const int width = input.size(3);
|
500 |
+
|
501 |
+
const int channels_out = weight.size(0);
|
502 |
+
const int channels_kernel = weight.size(1);
|
503 |
+
const int kernel_h_ = weight.size(2);
|
504 |
+
const int kernel_w_ = weight.size(3);
|
505 |
+
|
506 |
+
if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
|
507 |
+
AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
|
508 |
+
kernel_h_, kernel_w, kernel_h_, kernel_w_);
|
509 |
+
if (channels != channels_kernel * group)
|
510 |
+
AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).",
|
511 |
+
channels, channels_kernel * group);
|
512 |
+
|
513 |
+
const int height_out =
|
514 |
+
(height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
|
515 |
+
const int width_out =
|
516 |
+
(width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
|
517 |
+
|
518 |
+
if (ones.ndimension() != 2 ||
|
519 |
+
ones.size(0) * ones.size(1) < height_out * width_out) {
|
520 |
+
// Resize plane and fill with ones...
|
521 |
+
ones = at::ones({height_out, width_out}, input.options());
|
522 |
+
}
|
523 |
+
|
524 |
+
// resize output
|
525 |
+
output = output.view({batch, channels_out, height_out, width_out}).zero_();
|
526 |
+
// resize temporary columns
|
527 |
+
columns =
|
528 |
+
at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out},
|
529 |
+
input.options());
|
530 |
+
|
531 |
+
output = output.view({output.size(0), group, output.size(1) / group,
|
532 |
+
output.size(2), output.size(3)});
|
533 |
+
|
534 |
+
for (int b = 0; b < batch; b++) {
|
535 |
+
modulated_deformable_im2col_cuda(
|
536 |
+
input[b], offset[b], mask[b], 1, channels, height, width, height_out,
|
537 |
+
width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
|
538 |
+
dilation_h, dilation_w, deformable_group, columns);
|
539 |
+
|
540 |
+
// divide into group
|
541 |
+
weight = weight.view({group, weight.size(0) / group, weight.size(1),
|
542 |
+
weight.size(2), weight.size(3)});
|
543 |
+
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
|
544 |
+
|
545 |
+
for (int g = 0; g < group; g++) {
|
546 |
+
output[b][g] = output[b][g]
|
547 |
+
.flatten(1)
|
548 |
+
.addmm_(weight[g].flatten(1), columns[g])
|
549 |
+
.view_as(output[b][g]);
|
550 |
+
}
|
551 |
+
|
552 |
+
weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
|
553 |
+
weight.size(3), weight.size(4)});
|
554 |
+
columns =
|
555 |
+
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
|
556 |
+
}
|
557 |
+
|
558 |
+
output = output.view({output.size(0), output.size(1) * output.size(2),
|
559 |
+
output.size(3), output.size(4)});
|
560 |
+
|
561 |
+
if (with_bias) {
|
562 |
+
output += bias.view({1, bias.size(0), 1, 1});
|
563 |
+
}
|
564 |
+
}
|
565 |
+
|
566 |
+
void modulated_deform_conv_cuda_backward(
|
567 |
+
at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
|
568 |
+
at::Tensor offset, at::Tensor mask, at::Tensor columns,
|
569 |
+
at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
|
570 |
+
at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,
|
571 |
+
int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
|
572 |
+
int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
|
573 |
+
const bool with_bias) {
|
574 |
+
TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
|
575 |
+
TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
|
576 |
+
|
577 |
+
const int batch = input.size(0);
|
578 |
+
const int channels = input.size(1);
|
579 |
+
const int height = input.size(2);
|
580 |
+
const int width = input.size(3);
|
581 |
+
|
582 |
+
const int channels_kernel = weight.size(1);
|
583 |
+
const int kernel_h_ = weight.size(2);
|
584 |
+
const int kernel_w_ = weight.size(3);
|
585 |
+
if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
|
586 |
+
AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
|
587 |
+
kernel_h_, kernel_w, kernel_h_, kernel_w_);
|
588 |
+
if (channels != channels_kernel * group)
|
589 |
+
AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).",
|
590 |
+
channels, channels_kernel * group);
|
591 |
+
|
592 |
+
const int height_out =
|
593 |
+
(height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
|
594 |
+
const int width_out =
|
595 |
+
(width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
|
596 |
+
|
597 |
+
if (ones.ndimension() != 2 ||
|
598 |
+
ones.size(0) * ones.size(1) < height_out * width_out) {
|
599 |
+
// Resize plane and fill with ones...
|
600 |
+
ones = at::ones({height_out, width_out}, input.options());
|
601 |
+
}
|
602 |
+
|
603 |
+
grad_input = grad_input.view({batch, channels, height, width});
|
604 |
+
columns = at::zeros({channels * kernel_h * kernel_w, height_out * width_out},
|
605 |
+
input.options());
|
606 |
+
|
607 |
+
grad_output =
|
608 |
+
grad_output.view({grad_output.size(0), group, grad_output.size(1) / group,
|
609 |
+
grad_output.size(2), grad_output.size(3)});
|
610 |
+
|
611 |
+
for (int b = 0; b < batch; b++) {
|
612 |
+
// divide int group
|
613 |
+
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
|
614 |
+
weight = weight.view({group, weight.size(0) / group, weight.size(1),
|
615 |
+
weight.size(2), weight.size(3)});
|
616 |
+
|
617 |
+
for (int g = 0; g < group; g++) {
|
618 |
+
columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
|
619 |
+
grad_output[b][g].flatten(1), 0.0f, 1.0f);
|
620 |
+
}
|
621 |
+
|
622 |
+
columns =
|
623 |
+
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
|
624 |
+
weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
|
625 |
+
weight.size(3), weight.size(4)});
|
626 |
+
|
627 |
+
// gradient w.r.t. input coordinate data
|
628 |
+
modulated_deformable_col2im_coord_cuda(
|
629 |
+
columns, input[b], offset[b], mask[b], 1, channels, height, width,
|
630 |
+
height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h,
|
631 |
+
stride_w, dilation_h, dilation_w, deformable_group, grad_offset[b],
|
632 |
+
grad_mask[b]);
|
633 |
+
// gradient w.r.t. input data
|
634 |
+
modulated_deformable_col2im_cuda(
|
635 |
+
columns, offset[b], mask[b], 1, channels, height, width, height_out,
|
636 |
+
width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
|
637 |
+
dilation_h, dilation_w, deformable_group, grad_input[b]);
|
638 |
+
|
639 |
+
// gradient w.r.t. weight, dWeight should accumulate across the batch and
|
640 |
+
// group
|
641 |
+
modulated_deformable_im2col_cuda(
|
642 |
+
input[b], offset[b], mask[b], 1, channels, height, width, height_out,
|
643 |
+
width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
|
644 |
+
dilation_h, dilation_w, deformable_group, columns);
|
645 |
+
|
646 |
+
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
|
647 |
+
grad_weight = grad_weight.view({group, grad_weight.size(0) / group,
|
648 |
+
grad_weight.size(1), grad_weight.size(2),
|
649 |
+
grad_weight.size(3)});
|
650 |
+
if (with_bias)
|
651 |
+
grad_bias = grad_bias.view({group, grad_bias.size(0) / group});
|
652 |
+
|
653 |
+
for (int g = 0; g < group; g++) {
|
654 |
+
grad_weight[g] =
|
655 |
+
grad_weight[g]
|
656 |
+
.flatten(1)
|
657 |
+
.addmm_(grad_output[b][g].flatten(1), columns[g].transpose(0, 1))
|
658 |
+
.view_as(grad_weight[g]);
|
659 |
+
if (with_bias) {
|
660 |
+
grad_bias[g] =
|
661 |
+
grad_bias[g]
|
662 |
+
.view({-1, 1})
|
663 |
+
.addmm_(grad_output[b][g].flatten(1), ones.view({-1, 1}))
|
664 |
+
.view(-1);
|
665 |
+
}
|
666 |
+
}
|
667 |
+
|
668 |
+
columns =
|
669 |
+
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
|
670 |
+
grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1),
|
671 |
+
grad_weight.size(2), grad_weight.size(3),
|
672 |
+
grad_weight.size(4)});
|
673 |
+
if (with_bias)
|
674 |
+
grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)});
|
675 |
+
}
|
676 |
+
grad_output = grad_output.view({grad_output.size(0) * grad_output.size(1),
|
677 |
+
grad_output.size(2), grad_output.size(3),
|
678 |
+
grad_output.size(4)});
|
679 |
+
}
|
680 |
+
|
681 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
682 |
+
m.def("deform_conv_forward_cuda", &deform_conv_forward_cuda,
|
683 |
+
"deform forward (CUDA)");
|
684 |
+
m.def("deform_conv_backward_input_cuda", &deform_conv_backward_input_cuda,
|
685 |
+
"deform_conv_backward_input (CUDA)");
|
686 |
+
m.def("deform_conv_backward_parameters_cuda",
|
687 |
+
&deform_conv_backward_parameters_cuda,
|
688 |
+
"deform_conv_backward_parameters (CUDA)");
|
689 |
+
m.def("modulated_deform_conv_cuda_forward",
|
690 |
+
&modulated_deform_conv_cuda_forward,
|
691 |
+
"modulated deform conv forward (CUDA)");
|
692 |
+
m.def("modulated_deform_conv_cuda_backward",
|
693 |
+
&modulated_deform_conv_cuda_backward,
|
694 |
+
"modulated deform conv backward (CUDA)");
|
695 |
+
}
|
IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/src/deform_conv_cuda_kernel.cu
ADDED
@@ -0,0 +1,866 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/*!
|
2 |
+
******************* BEGIN Caffe Copyright Notice and Disclaimer ****************
|
3 |
+
*
|
4 |
+
* COPYRIGHT
|
5 |
+
*
|
6 |
+
* All contributions by the University of California:
|
7 |
+
* Copyright (c) 2014-2017 The Regents of the University of California (Regents)
|
8 |
+
* All rights reserved.
|
9 |
+
*
|
10 |
+
* All other contributions:
|
11 |
+
* Copyright (c) 2014-2017, the respective contributors
|
12 |
+
* All rights reserved.
|
13 |
+
*
|
14 |
+
* Caffe uses a shared copyright model: each contributor holds copyright over
|
15 |
+
* their contributions to Caffe. The project versioning records all such
|
16 |
+
* contribution and copyright details. If a contributor wants to further mark
|
17 |
+
* their specific copyright on a particular contribution, they should indicate
|
18 |
+
* their copyright solely in the commit message of the change when it is
|
19 |
+
* committed.
|
20 |
+
*
|
21 |
+
* LICENSE
|
22 |
+
*
|
23 |
+
* Redistribution and use in source and binary forms, with or without
|
24 |
+
* modification, are permitted provided that the following conditions are met:
|
25 |
+
*
|
26 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
27 |
+
* list of conditions and the following disclaimer.
|
28 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
29 |
+
* this list of conditions and the following disclaimer in the documentation
|
30 |
+
* and/or other materials provided with the distribution.
|
31 |
+
*
|
32 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
33 |
+
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
34 |
+
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
35 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
|
36 |
+
* ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
37 |
+
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
38 |
+
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
39 |
+
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
40 |
+
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
41 |
+
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
42 |
+
*
|
43 |
+
* CONTRIBUTION AGREEMENT
|
44 |
+
*
|
45 |
+
* By contributing to the BVLC/caffe repository through pull-request, comment,
|
46 |
+
* or otherwise, the contributor releases their content to the
|
47 |
+
* license and copyright terms herein.
|
48 |
+
*
|
49 |
+
***************** END Caffe Copyright Notice and Disclaimer ********************
|
50 |
+
*
|
51 |
+
* Copyright (c) 2018 Microsoft
|
52 |
+
* Licensed under The MIT License [see LICENSE for details]
|
53 |
+
* \file modulated_deformable_im2col.cuh
|
54 |
+
* \brief Function definitions of converting an image to
|
55 |
+
* column matrix based on kernel, padding, dilation, and offset.
|
56 |
+
* These functions are mainly used in deformable convolution operators.
|
57 |
+
* \ref: https://arxiv.org/abs/1703.06211
|
58 |
+
* \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng
|
59 |
+
*/
|
60 |
+
|
61 |
+
// modify from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu
|
62 |
+
|
63 |
+
#include <ATen/ATen.h>
|
64 |
+
#include <THC/THCAtomics.cuh>
|
65 |
+
#include <stdio.h>
|
66 |
+
#include <math.h>
|
67 |
+
#include <float.h>
|
68 |
+
|
69 |
+
using namespace at;
|
70 |
+
|
71 |
+
#define CUDA_KERNEL_LOOP(i, n) \
|
72 |
+
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
|
73 |
+
i += blockDim.x * gridDim.x)
|
74 |
+
|
75 |
+
const int CUDA_NUM_THREADS = 1024;
|
76 |
+
const int kMaxGridNum = 65535;
|
77 |
+
|
78 |
+
inline int GET_BLOCKS(const int N)
|
79 |
+
{
|
80 |
+
return std::min(kMaxGridNum, (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS);
|
81 |
+
}
|
82 |
+
|
83 |
+
template <typename scalar_t>
|
84 |
+
__device__ scalar_t deformable_im2col_bilinear(const scalar_t *bottom_data, const int data_width,
|
85 |
+
const int height, const int width, scalar_t h, scalar_t w)
|
86 |
+
{
|
87 |
+
|
88 |
+
int h_low = floor(h);
|
89 |
+
int w_low = floor(w);
|
90 |
+
int h_high = h_low + 1;
|
91 |
+
int w_high = w_low + 1;
|
92 |
+
|
93 |
+
scalar_t lh = h - h_low;
|
94 |
+
scalar_t lw = w - w_low;
|
95 |
+
scalar_t hh = 1 - lh, hw = 1 - lw;
|
96 |
+
|
97 |
+
scalar_t v1 = 0;
|
98 |
+
if (h_low >= 0 && w_low >= 0)
|
99 |
+
v1 = bottom_data[h_low * data_width + w_low];
|
100 |
+
scalar_t v2 = 0;
|
101 |
+
if (h_low >= 0 && w_high <= width - 1)
|
102 |
+
v2 = bottom_data[h_low * data_width + w_high];
|
103 |
+
scalar_t v3 = 0;
|
104 |
+
if (h_high <= height - 1 && w_low >= 0)
|
105 |
+
v3 = bottom_data[h_high * data_width + w_low];
|
106 |
+
scalar_t v4 = 0;
|
107 |
+
if (h_high <= height - 1 && w_high <= width - 1)
|
108 |
+
v4 = bottom_data[h_high * data_width + w_high];
|
109 |
+
|
110 |
+
scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
|
111 |
+
|
112 |
+
scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
|
113 |
+
return val;
|
114 |
+
}
|
115 |
+
|
116 |
+
template <typename scalar_t>
|
117 |
+
__device__ scalar_t get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w,
|
118 |
+
const int h, const int w, const int height, const int width)
|
119 |
+
{
|
120 |
+
|
121 |
+
if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
|
122 |
+
{
|
123 |
+
//empty
|
124 |
+
return 0;
|
125 |
+
}
|
126 |
+
|
127 |
+
int argmax_h_low = floor(argmax_h);
|
128 |
+
int argmax_w_low = floor(argmax_w);
|
129 |
+
int argmax_h_high = argmax_h_low + 1;
|
130 |
+
int argmax_w_high = argmax_w_low + 1;
|
131 |
+
|
132 |
+
scalar_t weight = 0;
|
133 |
+
if (h == argmax_h_low && w == argmax_w_low)
|
134 |
+
weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
|
135 |
+
if (h == argmax_h_low && w == argmax_w_high)
|
136 |
+
weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
|
137 |
+
if (h == argmax_h_high && w == argmax_w_low)
|
138 |
+
weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
|
139 |
+
if (h == argmax_h_high && w == argmax_w_high)
|
140 |
+
weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
|
141 |
+
return weight;
|
142 |
+
}
|
143 |
+
|
144 |
+
template <typename scalar_t>
|
145 |
+
__device__ scalar_t get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w,
|
146 |
+
const int height, const int width, const scalar_t *im_data,
|
147 |
+
const int data_width, const int bp_dir)
|
148 |
+
{
|
149 |
+
|
150 |
+
if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
|
151 |
+
{
|
152 |
+
//empty
|
153 |
+
return 0;
|
154 |
+
}
|
155 |
+
|
156 |
+
int argmax_h_low = floor(argmax_h);
|
157 |
+
int argmax_w_low = floor(argmax_w);
|
158 |
+
int argmax_h_high = argmax_h_low + 1;
|
159 |
+
int argmax_w_high = argmax_w_low + 1;
|
160 |
+
|
161 |
+
scalar_t weight = 0;
|
162 |
+
|
163 |
+
if (bp_dir == 0)
|
164 |
+
{
|
165 |
+
if (argmax_h_low >= 0 && argmax_w_low >= 0)
|
166 |
+
weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low];
|
167 |
+
if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
|
168 |
+
weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high];
|
169 |
+
if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
|
170 |
+
weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low];
|
171 |
+
if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
|
172 |
+
weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high];
|
173 |
+
}
|
174 |
+
else if (bp_dir == 1)
|
175 |
+
{
|
176 |
+
if (argmax_h_low >= 0 && argmax_w_low >= 0)
|
177 |
+
weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low];
|
178 |
+
if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
|
179 |
+
weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high];
|
180 |
+
if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
|
181 |
+
weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low];
|
182 |
+
if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
|
183 |
+
weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high];
|
184 |
+
}
|
185 |
+
|
186 |
+
return weight;
|
187 |
+
}
|
188 |
+
|
189 |
+
template <typename scalar_t>
|
190 |
+
__global__ void deformable_im2col_gpu_kernel(const int n, const scalar_t *data_im, const scalar_t *data_offset,
|
191 |
+
const int height, const int width, const int kernel_h, const int kernel_w,
|
192 |
+
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
|
193 |
+
const int dilation_h, const int dilation_w, const int channel_per_deformable_group,
|
194 |
+
const int batch_size, const int num_channels, const int deformable_group,
|
195 |
+
const int height_col, const int width_col,
|
196 |
+
scalar_t *data_col)
|
197 |
+
{
|
198 |
+
CUDA_KERNEL_LOOP(index, n)
|
199 |
+
{
|
200 |
+
// index index of output matrix
|
201 |
+
const int w_col = index % width_col;
|
202 |
+
const int h_col = (index / width_col) % height_col;
|
203 |
+
const int b_col = (index / width_col / height_col) % batch_size;
|
204 |
+
const int c_im = (index / width_col / height_col) / batch_size;
|
205 |
+
const int c_col = c_im * kernel_h * kernel_w;
|
206 |
+
|
207 |
+
// compute deformable group index
|
208 |
+
const int deformable_group_index = c_im / channel_per_deformable_group;
|
209 |
+
|
210 |
+
const int h_in = h_col * stride_h - pad_h;
|
211 |
+
const int w_in = w_col * stride_w - pad_w;
|
212 |
+
scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
|
213 |
+
//const scalar_t* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in;
|
214 |
+
const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width;
|
215 |
+
const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
|
216 |
+
|
217 |
+
for (int i = 0; i < kernel_h; ++i)
|
218 |
+
{
|
219 |
+
for (int j = 0; j < kernel_w; ++j)
|
220 |
+
{
|
221 |
+
const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
|
222 |
+
const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col;
|
223 |
+
const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
|
224 |
+
const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
|
225 |
+
scalar_t val = static_cast<scalar_t>(0);
|
226 |
+
const scalar_t h_im = h_in + i * dilation_h + offset_h;
|
227 |
+
const scalar_t w_im = w_in + j * dilation_w + offset_w;
|
228 |
+
if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
|
229 |
+
{
|
230 |
+
//const scalar_t map_h = i * dilation_h + offset_h;
|
231 |
+
//const scalar_t map_w = j * dilation_w + offset_w;
|
232 |
+
//const int cur_height = height - h_in;
|
233 |
+
//const int cur_width = width - w_in;
|
234 |
+
//val = deformable_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w);
|
235 |
+
val = deformable_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im);
|
236 |
+
}
|
237 |
+
*data_col_ptr = val;
|
238 |
+
data_col_ptr += batch_size * height_col * width_col;
|
239 |
+
}
|
240 |
+
}
|
241 |
+
}
|
242 |
+
}
|
243 |
+
|
244 |
+
void deformable_im2col(
|
245 |
+
const at::Tensor data_im, const at::Tensor data_offset, const int channels,
|
246 |
+
const int height, const int width, const int ksize_h, const int ksize_w,
|
247 |
+
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
|
248 |
+
const int dilation_h, const int dilation_w, const int parallel_imgs,
|
249 |
+
const int deformable_group, at::Tensor data_col)
|
250 |
+
{
|
251 |
+
// num_axes should be smaller than block size
|
252 |
+
// todo: check parallel_imgs is correctly passed in
|
253 |
+
int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
|
254 |
+
int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
|
255 |
+
int num_kernels = channels * height_col * width_col * parallel_imgs;
|
256 |
+
int channel_per_deformable_group = channels / deformable_group;
|
257 |
+
|
258 |
+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
259 |
+
data_im.type(), "deformable_im2col_gpu", ([&] {
|
260 |
+
const scalar_t *data_im_ = data_im.data<scalar_t>();
|
261 |
+
const scalar_t *data_offset_ = data_offset.data<scalar_t>();
|
262 |
+
scalar_t *data_col_ = data_col.data<scalar_t>();
|
263 |
+
|
264 |
+
deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
|
265 |
+
num_kernels, data_im_, data_offset_, height, width, ksize_h, ksize_w,
|
266 |
+
pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
|
267 |
+
channel_per_deformable_group, parallel_imgs, channels, deformable_group,
|
268 |
+
height_col, width_col, data_col_);
|
269 |
+
}));
|
270 |
+
|
271 |
+
cudaError_t err = cudaGetLastError();
|
272 |
+
if (err != cudaSuccess)
|
273 |
+
{
|
274 |
+
printf("error in deformable_im2col: %s\n", cudaGetErrorString(err));
|
275 |
+
}
|
276 |
+
}
|
277 |
+
|
278 |
+
template <typename scalar_t>
|
279 |
+
__global__ void deformable_col2im_gpu_kernel(
|
280 |
+
const int n, const scalar_t *data_col, const scalar_t *data_offset,
|
281 |
+
const int channels, const int height, const int width,
|
282 |
+
const int kernel_h, const int kernel_w,
|
283 |
+
const int pad_h, const int pad_w,
|
284 |
+
const int stride_h, const int stride_w,
|
285 |
+
const int dilation_h, const int dilation_w,
|
286 |
+
const int channel_per_deformable_group,
|
287 |
+
const int batch_size, const int deformable_group,
|
288 |
+
const int height_col, const int width_col,
|
289 |
+
scalar_t *grad_im)
|
290 |
+
{
|
291 |
+
CUDA_KERNEL_LOOP(index, n)
|
292 |
+
{
|
293 |
+
const int j = (index / width_col / height_col / batch_size) % kernel_w;
|
294 |
+
const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h;
|
295 |
+
const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;
|
296 |
+
// compute the start and end of the output
|
297 |
+
|
298 |
+
const int deformable_group_index = c / channel_per_deformable_group;
|
299 |
+
|
300 |
+
int w_out = index % width_col;
|
301 |
+
int h_out = (index / width_col) % height_col;
|
302 |
+
int b = (index / width_col / height_col) % batch_size;
|
303 |
+
int w_in = w_out * stride_w - pad_w;
|
304 |
+
int h_in = h_out * stride_h - pad_h;
|
305 |
+
|
306 |
+
const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) *
|
307 |
+
2 * kernel_h * kernel_w * height_col * width_col;
|
308 |
+
const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
|
309 |
+
const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
|
310 |
+
const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
|
311 |
+
const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
|
312 |
+
const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h;
|
313 |
+
const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w;
|
314 |
+
|
315 |
+
const scalar_t cur_top_grad = data_col[index];
|
316 |
+
const int cur_h = (int)cur_inv_h_data;
|
317 |
+
const int cur_w = (int)cur_inv_w_data;
|
318 |
+
for (int dy = -2; dy <= 2; dy++)
|
319 |
+
{
|
320 |
+
for (int dx = -2; dx <= 2; dx++)
|
321 |
+
{
|
322 |
+
if (cur_h + dy >= 0 && cur_h + dy < height &&
|
323 |
+
cur_w + dx >= 0 && cur_w + dx < width &&
|
324 |
+
abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
|
325 |
+
abs(cur_inv_w_data - (cur_w + dx)) < 1)
|
326 |
+
{
|
327 |
+
int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
|
328 |
+
scalar_t weight = get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width);
|
329 |
+
atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
|
330 |
+
}
|
331 |
+
}
|
332 |
+
}
|
333 |
+
}
|
334 |
+
}
|
335 |
+
|
336 |
+
void deformable_col2im(
|
337 |
+
const at::Tensor data_col, const at::Tensor data_offset, const int channels,
|
338 |
+
const int height, const int width, const int ksize_h,
|
339 |
+
const int ksize_w, const int pad_h, const int pad_w,
|
340 |
+
const int stride_h, const int stride_w,
|
341 |
+
const int dilation_h, const int dilation_w,
|
342 |
+
const int parallel_imgs, const int deformable_group,
|
343 |
+
at::Tensor grad_im)
|
344 |
+
{
|
345 |
+
|
346 |
+
// todo: make sure parallel_imgs is passed in correctly
|
347 |
+
int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
|
348 |
+
int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
|
349 |
+
int num_kernels = channels * ksize_h * ksize_w * height_col * width_col * parallel_imgs;
|
350 |
+
int channel_per_deformable_group = channels / deformable_group;
|
351 |
+
|
352 |
+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
353 |
+
data_col.type(), "deformable_col2im_gpu", ([&] {
|
354 |
+
const scalar_t *data_col_ = data_col.data<scalar_t>();
|
355 |
+
const scalar_t *data_offset_ = data_offset.data<scalar_t>();
|
356 |
+
scalar_t *grad_im_ = grad_im.data<scalar_t>();
|
357 |
+
|
358 |
+
deformable_col2im_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
|
359 |
+
num_kernels, data_col_, data_offset_, channels, height, width, ksize_h,
|
360 |
+
ksize_w, pad_h, pad_w, stride_h, stride_w,
|
361 |
+
dilation_h, dilation_w, channel_per_deformable_group,
|
362 |
+
parallel_imgs, deformable_group, height_col, width_col, grad_im_);
|
363 |
+
}));
|
364 |
+
|
365 |
+
cudaError_t err = cudaGetLastError();
|
366 |
+
if (err != cudaSuccess)
|
367 |
+
{
|
368 |
+
printf("error in deformable_col2im: %s\n", cudaGetErrorString(err));
|
369 |
+
}
|
370 |
+
}
|
371 |
+
|
372 |
+
template <typename scalar_t>
|
373 |
+
__global__ void deformable_col2im_coord_gpu_kernel(const int n, const scalar_t *data_col,
|
374 |
+
const scalar_t *data_im, const scalar_t *data_offset,
|
375 |
+
const int channels, const int height, const int width,
|
376 |
+
const int kernel_h, const int kernel_w,
|
377 |
+
const int pad_h, const int pad_w,
|
378 |
+
const int stride_h, const int stride_w,
|
379 |
+
const int dilation_h, const int dilation_w,
|
380 |
+
const int channel_per_deformable_group,
|
381 |
+
const int batch_size, const int offset_channels, const int deformable_group,
|
382 |
+
const int height_col, const int width_col, scalar_t *grad_offset)
|
383 |
+
{
|
384 |
+
CUDA_KERNEL_LOOP(index, n)
|
385 |
+
{
|
386 |
+
scalar_t val = 0;
|
387 |
+
int w = index % width_col;
|
388 |
+
int h = (index / width_col) % height_col;
|
389 |
+
int c = (index / width_col / height_col) % offset_channels;
|
390 |
+
int b = (index / width_col / height_col) / offset_channels;
|
391 |
+
// compute the start and end of the output
|
392 |
+
|
393 |
+
const int deformable_group_index = c / (2 * kernel_h * kernel_w);
|
394 |
+
const int col_step = kernel_h * kernel_w;
|
395 |
+
int cnt = 0;
|
396 |
+
const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group *
|
397 |
+
batch_size * width_col * height_col;
|
398 |
+
const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) *
|
399 |
+
channel_per_deformable_group / kernel_h / kernel_w * height * width;
|
400 |
+
const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 *
|
401 |
+
kernel_h * kernel_w * height_col * width_col;
|
402 |
+
|
403 |
+
const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;
|
404 |
+
|
405 |
+
for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step)
|
406 |
+
{
|
407 |
+
const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w;
|
408 |
+
const int bp_dir = offset_c % 2;
|
409 |
+
|
410 |
+
int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
|
411 |
+
int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
|
412 |
+
int w_out = col_pos % width_col;
|
413 |
+
int h_out = (col_pos / width_col) % height_col;
|
414 |
+
int w_in = w_out * stride_w - pad_w;
|
415 |
+
int h_in = h_out * stride_h - pad_h;
|
416 |
+
const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
|
417 |
+
const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out);
|
418 |
+
const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
|
419 |
+
const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
|
420 |
+
scalar_t inv_h = h_in + i * dilation_h + offset_h;
|
421 |
+
scalar_t inv_w = w_in + j * dilation_w + offset_w;
|
422 |
+
if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width)
|
423 |
+
{
|
424 |
+
inv_h = inv_w = -2;
|
425 |
+
}
|
426 |
+
const scalar_t weight = get_coordinate_weight(
|
427 |
+
inv_h, inv_w,
|
428 |
+
height, width, data_im_ptr + cnt * height * width, width, bp_dir);
|
429 |
+
val += weight * data_col_ptr[col_pos];
|
430 |
+
cnt += 1;
|
431 |
+
}
|
432 |
+
|
433 |
+
grad_offset[index] = val;
|
434 |
+
}
|
435 |
+
}
|
436 |
+
|
437 |
+
void deformable_col2im_coord(
|
438 |
+
const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset,
|
439 |
+
const int channels, const int height, const int width, const int ksize_h,
|
440 |
+
const int ksize_w, const int pad_h, const int pad_w, const int stride_h,
|
441 |
+
const int stride_w, const int dilation_h, const int dilation_w,
|
442 |
+
const int parallel_imgs, const int deformable_group, at::Tensor grad_offset)
|
443 |
+
{
|
444 |
+
|
445 |
+
int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
|
446 |
+
int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
|
447 |
+
int num_kernels = height_col * width_col * 2 * ksize_h * ksize_w * deformable_group * parallel_imgs;
|
448 |
+
int channel_per_deformable_group = channels * ksize_h * ksize_w / deformable_group;
|
449 |
+
|
450 |
+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
451 |
+
data_col.type(), "deformable_col2im_coord_gpu", ([&] {
|
452 |
+
const scalar_t *data_col_ = data_col.data<scalar_t>();
|
453 |
+
const scalar_t *data_im_ = data_im.data<scalar_t>();
|
454 |
+
const scalar_t *data_offset_ = data_offset.data<scalar_t>();
|
455 |
+
scalar_t *grad_offset_ = grad_offset.data<scalar_t>();
|
456 |
+
|
457 |
+
deformable_col2im_coord_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
|
458 |
+
num_kernels, data_col_, data_im_, data_offset_, channels, height, width,
|
459 |
+
ksize_h, ksize_w, pad_h, pad_w, stride_h, stride_w,
|
460 |
+
dilation_h, dilation_w, channel_per_deformable_group,
|
461 |
+
parallel_imgs, 2 * ksize_h * ksize_w * deformable_group, deformable_group,
|
462 |
+
height_col, width_col, grad_offset_);
|
463 |
+
}));
|
464 |
+
}
|
465 |
+
|
466 |
+
template <typename scalar_t>
|
467 |
+
__device__ scalar_t dmcn_im2col_bilinear(const scalar_t *bottom_data, const int data_width,
|
468 |
+
const int height, const int width, scalar_t h, scalar_t w)
|
469 |
+
{
|
470 |
+
int h_low = floor(h);
|
471 |
+
int w_low = floor(w);
|
472 |
+
int h_high = h_low + 1;
|
473 |
+
int w_high = w_low + 1;
|
474 |
+
|
475 |
+
scalar_t lh = h - h_low;
|
476 |
+
scalar_t lw = w - w_low;
|
477 |
+
scalar_t hh = 1 - lh, hw = 1 - lw;
|
478 |
+
|
479 |
+
scalar_t v1 = 0;
|
480 |
+
if (h_low >= 0 && w_low >= 0)
|
481 |
+
v1 = bottom_data[h_low * data_width + w_low];
|
482 |
+
scalar_t v2 = 0;
|
483 |
+
if (h_low >= 0 && w_high <= width - 1)
|
484 |
+
v2 = bottom_data[h_low * data_width + w_high];
|
485 |
+
scalar_t v3 = 0;
|
486 |
+
if (h_high <= height - 1 && w_low >= 0)
|
487 |
+
v3 = bottom_data[h_high * data_width + w_low];
|
488 |
+
scalar_t v4 = 0;
|
489 |
+
if (h_high <= height - 1 && w_high <= width - 1)
|
490 |
+
v4 = bottom_data[h_high * data_width + w_high];
|
491 |
+
|
492 |
+
scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
|
493 |
+
|
494 |
+
scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
|
495 |
+
return val;
|
496 |
+
}
|
497 |
+
|
498 |
+
template <typename scalar_t>
|
499 |
+
__device__ scalar_t dmcn_get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w,
|
500 |
+
const int h, const int w, const int height, const int width)
|
501 |
+
{
|
502 |
+
if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
|
503 |
+
{
|
504 |
+
//empty
|
505 |
+
return 0;
|
506 |
+
}
|
507 |
+
|
508 |
+
int argmax_h_low = floor(argmax_h);
|
509 |
+
int argmax_w_low = floor(argmax_w);
|
510 |
+
int argmax_h_high = argmax_h_low + 1;
|
511 |
+
int argmax_w_high = argmax_w_low + 1;
|
512 |
+
|
513 |
+
scalar_t weight = 0;
|
514 |
+
if (h == argmax_h_low && w == argmax_w_low)
|
515 |
+
weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
|
516 |
+
if (h == argmax_h_low && w == argmax_w_high)
|
517 |
+
weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
|
518 |
+
if (h == argmax_h_high && w == argmax_w_low)
|
519 |
+
weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
|
520 |
+
if (h == argmax_h_high && w == argmax_w_high)
|
521 |
+
weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
|
522 |
+
return weight;
|
523 |
+
}
|
524 |
+
|
525 |
+
template <typename scalar_t>
|
526 |
+
__device__ scalar_t dmcn_get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w,
|
527 |
+
const int height, const int width, const scalar_t *im_data,
|
528 |
+
const int data_width, const int bp_dir)
|
529 |
+
{
|
530 |
+
if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
|
531 |
+
{
|
532 |
+
//empty
|
533 |
+
return 0;
|
534 |
+
}
|
535 |
+
|
536 |
+
int argmax_h_low = floor(argmax_h);
|
537 |
+
int argmax_w_low = floor(argmax_w);
|
538 |
+
int argmax_h_high = argmax_h_low + 1;
|
539 |
+
int argmax_w_high = argmax_w_low + 1;
|
540 |
+
|
541 |
+
scalar_t weight = 0;
|
542 |
+
|
543 |
+
if (bp_dir == 0)
|
544 |
+
{
|
545 |
+
if (argmax_h_low >= 0 && argmax_w_low >= 0)
|
546 |
+
weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low];
|
547 |
+
if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
|
548 |
+
weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high];
|
549 |
+
if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
|
550 |
+
weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low];
|
551 |
+
if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
|
552 |
+
weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high];
|
553 |
+
}
|
554 |
+
else if (bp_dir == 1)
|
555 |
+
{
|
556 |
+
if (argmax_h_low >= 0 && argmax_w_low >= 0)
|
557 |
+
weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low];
|
558 |
+
if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
|
559 |
+
weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high];
|
560 |
+
if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
|
561 |
+
weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low];
|
562 |
+
if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
|
563 |
+
weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high];
|
564 |
+
}
|
565 |
+
|
566 |
+
return weight;
|
567 |
+
}
|
568 |
+
|
569 |
+
template <typename scalar_t>
|
570 |
+
__global__ void modulated_deformable_im2col_gpu_kernel(const int n,
|
571 |
+
const scalar_t *data_im, const scalar_t *data_offset, const scalar_t *data_mask,
|
572 |
+
const int height, const int width, const int kernel_h, const int kernel_w,
|
573 |
+
const int pad_h, const int pad_w,
|
574 |
+
const int stride_h, const int stride_w,
|
575 |
+
const int dilation_h, const int dilation_w,
|
576 |
+
const int channel_per_deformable_group,
|
577 |
+
const int batch_size, const int num_channels, const int deformable_group,
|
578 |
+
const int height_col, const int width_col,
|
579 |
+
scalar_t *data_col)
|
580 |
+
{
|
581 |
+
CUDA_KERNEL_LOOP(index, n)
|
582 |
+
{
|
583 |
+
// index index of output matrix
|
584 |
+
const int w_col = index % width_col;
|
585 |
+
const int h_col = (index / width_col) % height_col;
|
586 |
+
const int b_col = (index / width_col / height_col) % batch_size;
|
587 |
+
const int c_im = (index / width_col / height_col) / batch_size;
|
588 |
+
const int c_col = c_im * kernel_h * kernel_w;
|
589 |
+
|
590 |
+
// compute deformable group index
|
591 |
+
const int deformable_group_index = c_im / channel_per_deformable_group;
|
592 |
+
|
593 |
+
const int h_in = h_col * stride_h - pad_h;
|
594 |
+
const int w_in = w_col * stride_w - pad_w;
|
595 |
+
|
596 |
+
scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
|
597 |
+
//const float* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in;
|
598 |
+
const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width;
|
599 |
+
const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
|
600 |
+
|
601 |
+
const scalar_t *data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
|
602 |
+
|
603 |
+
for (int i = 0; i < kernel_h; ++i)
|
604 |
+
{
|
605 |
+
for (int j = 0; j < kernel_w; ++j)
|
606 |
+
{
|
607 |
+
const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
|
608 |
+
const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col;
|
609 |
+
const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col;
|
610 |
+
const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
|
611 |
+
const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
|
612 |
+
const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
|
613 |
+
scalar_t val = static_cast<scalar_t>(0);
|
614 |
+
const scalar_t h_im = h_in + i * dilation_h + offset_h;
|
615 |
+
const scalar_t w_im = w_in + j * dilation_w + offset_w;
|
616 |
+
//if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) {
|
617 |
+
if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
|
618 |
+
{
|
619 |
+
//const float map_h = i * dilation_h + offset_h;
|
620 |
+
//const float map_w = j * dilation_w + offset_w;
|
621 |
+
//const int cur_height = height - h_in;
|
622 |
+
//const int cur_width = width - w_in;
|
623 |
+
//val = dmcn_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w);
|
624 |
+
val = dmcn_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im);
|
625 |
+
}
|
626 |
+
*data_col_ptr = val * mask;
|
627 |
+
data_col_ptr += batch_size * height_col * width_col;
|
628 |
+
//data_col_ptr += height_col * width_col;
|
629 |
+
}
|
630 |
+
}
|
631 |
+
}
|
632 |
+
}
|
633 |
+
|
634 |
+
template <typename scalar_t>
|
635 |
+
__global__ void modulated_deformable_col2im_gpu_kernel(const int n,
|
636 |
+
const scalar_t *data_col, const scalar_t *data_offset, const scalar_t *data_mask,
|
637 |
+
const int channels, const int height, const int width,
|
638 |
+
const int kernel_h, const int kernel_w,
|
639 |
+
const int pad_h, const int pad_w,
|
640 |
+
const int stride_h, const int stride_w,
|
641 |
+
const int dilation_h, const int dilation_w,
|
642 |
+
const int channel_per_deformable_group,
|
643 |
+
const int batch_size, const int deformable_group,
|
644 |
+
const int height_col, const int width_col,
|
645 |
+
scalar_t *grad_im)
|
646 |
+
{
|
647 |
+
CUDA_KERNEL_LOOP(index, n)
|
648 |
+
{
|
649 |
+
const int j = (index / width_col / height_col / batch_size) % kernel_w;
|
650 |
+
const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h;
|
651 |
+
const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;
|
652 |
+
// compute the start and end of the output
|
653 |
+
|
654 |
+
const int deformable_group_index = c / channel_per_deformable_group;
|
655 |
+
|
656 |
+
int w_out = index % width_col;
|
657 |
+
int h_out = (index / width_col) % height_col;
|
658 |
+
int b = (index / width_col / height_col) % batch_size;
|
659 |
+
int w_in = w_out * stride_w - pad_w;
|
660 |
+
int h_in = h_out * stride_h - pad_h;
|
661 |
+
|
662 |
+
const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
|
663 |
+
const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
|
664 |
+
const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
|
665 |
+
const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
|
666 |
+
const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out;
|
667 |
+
const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
|
668 |
+
const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
|
669 |
+
const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
|
670 |
+
const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h;
|
671 |
+
const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w;
|
672 |
+
|
673 |
+
const scalar_t cur_top_grad = data_col[index] * mask;
|
674 |
+
const int cur_h = (int)cur_inv_h_data;
|
675 |
+
const int cur_w = (int)cur_inv_w_data;
|
676 |
+
for (int dy = -2; dy <= 2; dy++)
|
677 |
+
{
|
678 |
+
for (int dx = -2; dx <= 2; dx++)
|
679 |
+
{
|
680 |
+
if (cur_h + dy >= 0 && cur_h + dy < height &&
|
681 |
+
cur_w + dx >= 0 && cur_w + dx < width &&
|
682 |
+
abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
|
683 |
+
abs(cur_inv_w_data - (cur_w + dx)) < 1)
|
684 |
+
{
|
685 |
+
int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
|
686 |
+
scalar_t weight = dmcn_get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width);
|
687 |
+
atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
|
688 |
+
}
|
689 |
+
}
|
690 |
+
}
|
691 |
+
}
|
692 |
+
}
|
693 |
+
|
694 |
+
template <typename scalar_t>
|
695 |
+
__global__ void modulated_deformable_col2im_coord_gpu_kernel(const int n,
|
696 |
+
const scalar_t *data_col, const scalar_t *data_im,
|
697 |
+
const scalar_t *data_offset, const scalar_t *data_mask,
|
698 |
+
const int channels, const int height, const int width,
|
699 |
+
const int kernel_h, const int kernel_w,
|
700 |
+
const int pad_h, const int pad_w,
|
701 |
+
const int stride_h, const int stride_w,
|
702 |
+
const int dilation_h, const int dilation_w,
|
703 |
+
const int channel_per_deformable_group,
|
704 |
+
const int batch_size, const int offset_channels, const int deformable_group,
|
705 |
+
const int height_col, const int width_col,
|
706 |
+
scalar_t *grad_offset, scalar_t *grad_mask)
|
707 |
+
{
|
708 |
+
CUDA_KERNEL_LOOP(index, n)
|
709 |
+
{
|
710 |
+
scalar_t val = 0, mval = 0;
|
711 |
+
int w = index % width_col;
|
712 |
+
int h = (index / width_col) % height_col;
|
713 |
+
int c = (index / width_col / height_col) % offset_channels;
|
714 |
+
int b = (index / width_col / height_col) / offset_channels;
|
715 |
+
// compute the start and end of the output
|
716 |
+
|
717 |
+
const int deformable_group_index = c / (2 * kernel_h * kernel_w);
|
718 |
+
const int col_step = kernel_h * kernel_w;
|
719 |
+
int cnt = 0;
|
720 |
+
const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col;
|
721 |
+
const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width;
|
722 |
+
const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
|
723 |
+
const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
|
724 |
+
|
725 |
+
const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;
|
726 |
+
|
727 |
+
for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step)
|
728 |
+
{
|
729 |
+
const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w;
|
730 |
+
const int bp_dir = offset_c % 2;
|
731 |
+
|
732 |
+
int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
|
733 |
+
int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
|
734 |
+
int w_out = col_pos % width_col;
|
735 |
+
int h_out = (col_pos / width_col) % height_col;
|
736 |
+
int w_in = w_out * stride_w - pad_w;
|
737 |
+
int h_in = h_out * stride_h - pad_h;
|
738 |
+
const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
|
739 |
+
const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out);
|
740 |
+
const int data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out);
|
741 |
+
const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
|
742 |
+
const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
|
743 |
+
const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
|
744 |
+
scalar_t inv_h = h_in + i * dilation_h + offset_h;
|
745 |
+
scalar_t inv_w = w_in + j * dilation_w + offset_w;
|
746 |
+
if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width)
|
747 |
+
{
|
748 |
+
inv_h = inv_w = -2;
|
749 |
+
}
|
750 |
+
else
|
751 |
+
{
|
752 |
+
mval += data_col_ptr[col_pos] * dmcn_im2col_bilinear(data_im_ptr + cnt * height * width, width, height, width, inv_h, inv_w);
|
753 |
+
}
|
754 |
+
const scalar_t weight = dmcn_get_coordinate_weight(
|
755 |
+
inv_h, inv_w,
|
756 |
+
height, width, data_im_ptr + cnt * height * width, width, bp_dir);
|
757 |
+
val += weight * data_col_ptr[col_pos] * mask;
|
758 |
+
cnt += 1;
|
759 |
+
}
|
760 |
+
// KERNEL_ASSIGN(grad_offset[index], offset_req, val);
|
761 |
+
grad_offset[index] = val;
|
762 |
+
if (offset_c % 2 == 0)
|
763 |
+
// KERNEL_ASSIGN(grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w], mask_req, mval);
|
764 |
+
grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w] = mval;
|
765 |
+
}
|
766 |
+
}
|
767 |
+
|
768 |
+
void modulated_deformable_im2col_cuda(
|
769 |
+
const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask,
|
770 |
+
const int batch_size, const int channels, const int height_im, const int width_im,
|
771 |
+
const int height_col, const int width_col, const int kernel_h, const int kenerl_w,
|
772 |
+
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
|
773 |
+
const int dilation_h, const int dilation_w,
|
774 |
+
const int deformable_group, at::Tensor data_col)
|
775 |
+
{
|
776 |
+
// num_axes should be smaller than block size
|
777 |
+
const int channel_per_deformable_group = channels / deformable_group;
|
778 |
+
const int num_kernels = channels * batch_size * height_col * width_col;
|
779 |
+
|
780 |
+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
781 |
+
data_im.type(), "modulated_deformable_im2col_gpu", ([&] {
|
782 |
+
const scalar_t *data_im_ = data_im.data<scalar_t>();
|
783 |
+
const scalar_t *data_offset_ = data_offset.data<scalar_t>();
|
784 |
+
const scalar_t *data_mask_ = data_mask.data<scalar_t>();
|
785 |
+
scalar_t *data_col_ = data_col.data<scalar_t>();
|
786 |
+
|
787 |
+
modulated_deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
|
788 |
+
num_kernels, data_im_, data_offset_, data_mask_, height_im, width_im, kernel_h, kenerl_w,
|
789 |
+
pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group,
|
790 |
+
batch_size, channels, deformable_group, height_col, width_col, data_col_);
|
791 |
+
}));
|
792 |
+
|
793 |
+
cudaError_t err = cudaGetLastError();
|
794 |
+
if (err != cudaSuccess)
|
795 |
+
{
|
796 |
+
// printf("error in modulated_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
|
797 |
+
}
|
798 |
+
}
|
799 |
+
|
800 |
+
void modulated_deformable_col2im_cuda(
|
801 |
+
const at::Tensor data_col, const at::Tensor data_offset, const at::Tensor data_mask,
|
802 |
+
const int batch_size, const int channels, const int height_im, const int width_im,
|
803 |
+
const int height_col, const int width_col, const int kernel_h, const int kernel_w,
|
804 |
+
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
|
805 |
+
const int dilation_h, const int dilation_w,
|
806 |
+
const int deformable_group, at::Tensor grad_im)
|
807 |
+
{
|
808 |
+
|
809 |
+
const int channel_per_deformable_group = channels / deformable_group;
|
810 |
+
const int num_kernels = channels * kernel_h * kernel_w * batch_size * height_col * width_col;
|
811 |
+
|
812 |
+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
813 |
+
data_col.type(), "modulated_deformable_col2im_gpu", ([&] {
|
814 |
+
const scalar_t *data_col_ = data_col.data<scalar_t>();
|
815 |
+
const scalar_t *data_offset_ = data_offset.data<scalar_t>();
|
816 |
+
const scalar_t *data_mask_ = data_mask.data<scalar_t>();
|
817 |
+
scalar_t *grad_im_ = grad_im.data<scalar_t>();
|
818 |
+
|
819 |
+
modulated_deformable_col2im_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
|
820 |
+
num_kernels, data_col_, data_offset_, data_mask_, channels, height_im, width_im,
|
821 |
+
kernel_h, kernel_w, pad_h, pad_h, stride_h, stride_w,
|
822 |
+
dilation_h, dilation_w, channel_per_deformable_group,
|
823 |
+
batch_size, deformable_group, height_col, width_col, grad_im_);
|
824 |
+
}));
|
825 |
+
|
826 |
+
cudaError_t err = cudaGetLastError();
|
827 |
+
if (err != cudaSuccess)
|
828 |
+
{
|
829 |
+
printf("error in modulated_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
|
830 |
+
}
|
831 |
+
}
|
832 |
+
|
833 |
+
void modulated_deformable_col2im_coord_cuda(
|
834 |
+
const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask,
|
835 |
+
const int batch_size, const int channels, const int height_im, const int width_im,
|
836 |
+
const int height_col, const int width_col, const int kernel_h, const int kernel_w,
|
837 |
+
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
|
838 |
+
const int dilation_h, const int dilation_w,
|
839 |
+
const int deformable_group,
|
840 |
+
at::Tensor grad_offset, at::Tensor grad_mask)
|
841 |
+
{
|
842 |
+
const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h * kernel_w * deformable_group;
|
843 |
+
const int channel_per_deformable_group = channels * kernel_h * kernel_w / deformable_group;
|
844 |
+
|
845 |
+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
846 |
+
data_col.type(), "modulated_deformable_col2im_coord_gpu", ([&] {
|
847 |
+
const scalar_t *data_col_ = data_col.data<scalar_t>();
|
848 |
+
const scalar_t *data_im_ = data_im.data<scalar_t>();
|
849 |
+
const scalar_t *data_offset_ = data_offset.data<scalar_t>();
|
850 |
+
const scalar_t *data_mask_ = data_mask.data<scalar_t>();
|
851 |
+
scalar_t *grad_offset_ = grad_offset.data<scalar_t>();
|
852 |
+
scalar_t *grad_mask_ = grad_mask.data<scalar_t>();
|
853 |
+
|
854 |
+
modulated_deformable_col2im_coord_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
|
855 |
+
num_kernels, data_col_, data_im_, data_offset_, data_mask_, channels, height_im, width_im,
|
856 |
+
kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
|
857 |
+
dilation_h, dilation_w, channel_per_deformable_group,
|
858 |
+
batch_size, 2 * kernel_h * kernel_w * deformable_group, deformable_group, height_col, width_col,
|
859 |
+
grad_offset_, grad_mask_);
|
860 |
+
}));
|
861 |
+
cudaError_t err = cudaGetLastError();
|
862 |
+
if (err != cudaSuccess)
|
863 |
+
{
|
864 |
+
printf("error in modulated_deformable_col2im_coord_cuda: %s\n", cudaGetErrorString(err));
|
865 |
+
}
|
866 |
+
}
|
IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/src/deform_pool_cuda.cpp
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// modify from
|
2 |
+
// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/modulated_dcn_cuda.c
|
3 |
+
|
4 |
+
// based on
|
5 |
+
// author: Charles Shang
|
6 |
+
// https://github.com/torch/cunn/blob/master/lib/THCUNN/generic/SpatialConvolutionMM.cu
|
7 |
+
|
8 |
+
#include <torch/extension.h>
|
9 |
+
|
10 |
+
#include <cmath>
|
11 |
+
#include <vector>
|
12 |
+
|
13 |
+
void DeformablePSROIPoolForward(
|
14 |
+
const at::Tensor data, const at::Tensor bbox, const at::Tensor trans,
|
15 |
+
at::Tensor out, at::Tensor top_count, const int batch, const int channels,
|
16 |
+
const int height, const int width, const int num_bbox,
|
17 |
+
const int channels_trans, const int no_trans, const float spatial_scale,
|
18 |
+
const int output_dim, const int group_size, const int pooled_size,
|
19 |
+
const int part_size, const int sample_per_part, const float trans_std);
|
20 |
+
|
21 |
+
void DeformablePSROIPoolBackwardAcc(
|
22 |
+
const at::Tensor out_grad, const at::Tensor data, const at::Tensor bbox,
|
23 |
+
const at::Tensor trans, const at::Tensor top_count, at::Tensor in_grad,
|
24 |
+
at::Tensor trans_grad, const int batch, const int channels,
|
25 |
+
const int height, const int width, const int num_bbox,
|
26 |
+
const int channels_trans, const int no_trans, const float spatial_scale,
|
27 |
+
const int output_dim, const int group_size, const int pooled_size,
|
28 |
+
const int part_size, const int sample_per_part, const float trans_std);
|
29 |
+
|
30 |
+
void deform_psroi_pooling_cuda_forward(
|
31 |
+
at::Tensor input, at::Tensor bbox, at::Tensor trans, at::Tensor out,
|
32 |
+
at::Tensor top_count, const int no_trans, const float spatial_scale,
|
33 |
+
const int output_dim, const int group_size, const int pooled_size,
|
34 |
+
const int part_size, const int sample_per_part, const float trans_std) {
|
35 |
+
TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
|
36 |
+
|
37 |
+
const int batch = input.size(0);
|
38 |
+
const int channels = input.size(1);
|
39 |
+
const int height = input.size(2);
|
40 |
+
const int width = input.size(3);
|
41 |
+
const int channels_trans = no_trans ? 2 : trans.size(1);
|
42 |
+
|
43 |
+
const int num_bbox = bbox.size(0);
|
44 |
+
if (num_bbox != out.size(0))
|
45 |
+
AT_ERROR("Output shape and bbox number wont match: (%d vs %d).",
|
46 |
+
out.size(0), num_bbox);
|
47 |
+
|
48 |
+
DeformablePSROIPoolForward(
|
49 |
+
input, bbox, trans, out, top_count, batch, channels, height, width,
|
50 |
+
num_bbox, channels_trans, no_trans, spatial_scale, output_dim, group_size,
|
51 |
+
pooled_size, part_size, sample_per_part, trans_std);
|
52 |
+
}
|
53 |
+
|
54 |
+
void deform_psroi_pooling_cuda_backward(
|
55 |
+
at::Tensor out_grad, at::Tensor input, at::Tensor bbox, at::Tensor trans,
|
56 |
+
at::Tensor top_count, at::Tensor input_grad, at::Tensor trans_grad,
|
57 |
+
const int no_trans, const float spatial_scale, const int output_dim,
|
58 |
+
const int group_size, const int pooled_size, const int part_size,
|
59 |
+
const int sample_per_part, const float trans_std) {
|
60 |
+
TORCH_CHECK(out_grad.is_contiguous(), "out_grad tensor has to be contiguous");
|
61 |
+
TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
|
62 |
+
|
63 |
+
const int batch = input.size(0);
|
64 |
+
const int channels = input.size(1);
|
65 |
+
const int height = input.size(2);
|
66 |
+
const int width = input.size(3);
|
67 |
+
const int channels_trans = no_trans ? 2 : trans.size(1);
|
68 |
+
|
69 |
+
const int num_bbox = bbox.size(0);
|
70 |
+
if (num_bbox != out_grad.size(0))
|
71 |
+
AT_ERROR("Output shape and bbox number wont match: (%d vs %d).",
|
72 |
+
out_grad.size(0), num_bbox);
|
73 |
+
|
74 |
+
DeformablePSROIPoolBackwardAcc(
|
75 |
+
out_grad, input, bbox, trans, top_count, input_grad, trans_grad, batch,
|
76 |
+
channels, height, width, num_bbox, channels_trans, no_trans,
|
77 |
+
spatial_scale, output_dim, group_size, pooled_size, part_size,
|
78 |
+
sample_per_part, trans_std);
|
79 |
+
}
|
80 |
+
|
81 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
82 |
+
m.def("deform_psroi_pooling_cuda_forward", &deform_psroi_pooling_cuda_forward,
|
83 |
+
"deform psroi pooling forward(CUDA)");
|
84 |
+
m.def("deform_psroi_pooling_cuda_backward",
|
85 |
+
&deform_psroi_pooling_cuda_backward,
|
86 |
+
"deform psroi pooling backward(CUDA)");
|
87 |
+
}
|
IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/src/deform_pool_cuda_kernel.cu
ADDED
@@ -0,0 +1,364 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/*!
|
2 |
+
* Copyright (c) 2017 Microsoft
|
3 |
+
* Licensed under The MIT License [see LICENSE for details]
|
4 |
+
* \file deformable_psroi_pooling.cu
|
5 |
+
* \brief
|
6 |
+
* \author Yi Li, Guodong Zhang, Jifeng Dai
|
7 |
+
*/
|
8 |
+
/***************** Adapted by Charles Shang *********************/
|
9 |
+
// modify from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/cuda/deform_psroi_pooling_cuda.cu
|
10 |
+
|
11 |
+
#include <ATen/ATen.h>
|
12 |
+
#include <THC/THCAtomics.cuh>
|
13 |
+
#include <stdio.h>
|
14 |
+
#include <math.h>
|
15 |
+
#include <algorithm>
|
16 |
+
|
17 |
+
using namespace at;
|
18 |
+
|
19 |
+
#define CUDA_KERNEL_LOOP(i, n) \
|
20 |
+
for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
|
21 |
+
i < (n); \
|
22 |
+
i += blockDim.x * gridDim.x)
|
23 |
+
|
24 |
+
const int CUDA_NUM_THREADS = 1024;
|
25 |
+
inline int GET_BLOCKS(const int N)
|
26 |
+
{
|
27 |
+
return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
|
28 |
+
}
|
29 |
+
|
30 |
+
template <typename scalar_t>
|
31 |
+
__device__ scalar_t bilinear_interp(
|
32 |
+
const scalar_t *data,
|
33 |
+
const scalar_t x,
|
34 |
+
const scalar_t y,
|
35 |
+
const int width,
|
36 |
+
const int height)
|
37 |
+
{
|
38 |
+
int x1 = floor(x);
|
39 |
+
int x2 = ceil(x);
|
40 |
+
int y1 = floor(y);
|
41 |
+
int y2 = ceil(y);
|
42 |
+
scalar_t dist_x = (scalar_t)(x - x1);
|
43 |
+
scalar_t dist_y = (scalar_t)(y - y1);
|
44 |
+
scalar_t value11 = data[y1 * width + x1];
|
45 |
+
scalar_t value12 = data[y2 * width + x1];
|
46 |
+
scalar_t value21 = data[y1 * width + x2];
|
47 |
+
scalar_t value22 = data[y2 * width + x2];
|
48 |
+
scalar_t value = (1 - dist_x) * (1 - dist_y) * value11 + (1 - dist_x) * dist_y * value12 + dist_x * (1 - dist_y) * value21 + dist_x * dist_y * value22;
|
49 |
+
return value;
|
50 |
+
}
|
51 |
+
|
52 |
+
template <typename scalar_t>
|
53 |
+
__global__ void DeformablePSROIPoolForwardKernel(
|
54 |
+
const int count,
|
55 |
+
const scalar_t *bottom_data,
|
56 |
+
const scalar_t spatial_scale,
|
57 |
+
const int channels,
|
58 |
+
const int height, const int width,
|
59 |
+
const int pooled_height, const int pooled_width,
|
60 |
+
const scalar_t *bottom_rois, const scalar_t *bottom_trans,
|
61 |
+
const int no_trans,
|
62 |
+
const scalar_t trans_std,
|
63 |
+
const int sample_per_part,
|
64 |
+
const int output_dim,
|
65 |
+
const int group_size,
|
66 |
+
const int part_size,
|
67 |
+
const int num_classes,
|
68 |
+
const int channels_each_class,
|
69 |
+
scalar_t *top_data,
|
70 |
+
scalar_t *top_count)
|
71 |
+
{
|
72 |
+
CUDA_KERNEL_LOOP(index, count)
|
73 |
+
{
|
74 |
+
// The output is in order (n, ctop, ph, pw)
|
75 |
+
int pw = index % pooled_width;
|
76 |
+
int ph = (index / pooled_width) % pooled_height;
|
77 |
+
int ctop = (index / pooled_width / pooled_height) % output_dim;
|
78 |
+
int n = index / pooled_width / pooled_height / output_dim;
|
79 |
+
|
80 |
+
// [start, end) interval for spatial sampling
|
81 |
+
const scalar_t *offset_bottom_rois = bottom_rois + n * 5;
|
82 |
+
int roi_batch_ind = offset_bottom_rois[0];
|
83 |
+
scalar_t roi_start_w = (scalar_t)(round(offset_bottom_rois[1])) * spatial_scale - 0.5;
|
84 |
+
scalar_t roi_start_h = (scalar_t)(round(offset_bottom_rois[2])) * spatial_scale - 0.5;
|
85 |
+
scalar_t roi_end_w = (scalar_t)(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5;
|
86 |
+
scalar_t roi_end_h = (scalar_t)(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5;
|
87 |
+
|
88 |
+
// Force too small ROIs to be 1x1
|
89 |
+
scalar_t roi_width = max(roi_end_w - roi_start_w, 0.1); //avoid 0
|
90 |
+
scalar_t roi_height = max(roi_end_h - roi_start_h, 0.1);
|
91 |
+
|
92 |
+
// Compute w and h at bottom
|
93 |
+
scalar_t bin_size_h = roi_height / (scalar_t)(pooled_height);
|
94 |
+
scalar_t bin_size_w = roi_width / (scalar_t)(pooled_width);
|
95 |
+
|
96 |
+
scalar_t sub_bin_size_h = bin_size_h / (scalar_t)(sample_per_part);
|
97 |
+
scalar_t sub_bin_size_w = bin_size_w / (scalar_t)(sample_per_part);
|
98 |
+
|
99 |
+
int part_h = floor((scalar_t)(ph) / pooled_height * part_size);
|
100 |
+
int part_w = floor((scalar_t)(pw) / pooled_width * part_size);
|
101 |
+
int class_id = ctop / channels_each_class;
|
102 |
+
scalar_t trans_x = no_trans ? (scalar_t)(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * (scalar_t)trans_std;
|
103 |
+
scalar_t trans_y = no_trans ? (scalar_t)(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * (scalar_t)trans_std;
|
104 |
+
|
105 |
+
scalar_t wstart = (scalar_t)(pw)*bin_size_w + roi_start_w;
|
106 |
+
wstart += trans_x * roi_width;
|
107 |
+
scalar_t hstart = (scalar_t)(ph)*bin_size_h + roi_start_h;
|
108 |
+
hstart += trans_y * roi_height;
|
109 |
+
|
110 |
+
scalar_t sum = 0;
|
111 |
+
int count = 0;
|
112 |
+
int gw = floor((scalar_t)(pw)*group_size / pooled_width);
|
113 |
+
int gh = floor((scalar_t)(ph)*group_size / pooled_height);
|
114 |
+
gw = min(max(gw, 0), group_size - 1);
|
115 |
+
gh = min(max(gh, 0), group_size - 1);
|
116 |
+
|
117 |
+
const scalar_t *offset_bottom_data = bottom_data + (roi_batch_ind * channels) * height * width;
|
118 |
+
for (int ih = 0; ih < sample_per_part; ih++)
|
119 |
+
{
|
120 |
+
for (int iw = 0; iw < sample_per_part; iw++)
|
121 |
+
{
|
122 |
+
scalar_t w = wstart + iw * sub_bin_size_w;
|
123 |
+
scalar_t h = hstart + ih * sub_bin_size_h;
|
124 |
+
// bilinear interpolation
|
125 |
+
if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5)
|
126 |
+
{
|
127 |
+
continue;
|
128 |
+
}
|
129 |
+
w = min(max(w, 0.), width - 1.);
|
130 |
+
h = min(max(h, 0.), height - 1.);
|
131 |
+
int c = (ctop * group_size + gh) * group_size + gw;
|
132 |
+
scalar_t val = bilinear_interp(offset_bottom_data + c * height * width, w, h, width, height);
|
133 |
+
sum += val;
|
134 |
+
count++;
|
135 |
+
}
|
136 |
+
}
|
137 |
+
top_data[index] = count == 0 ? (scalar_t)(0) : sum / count;
|
138 |
+
top_count[index] = count;
|
139 |
+
}
|
140 |
+
}
|
141 |
+
|
142 |
+
template <typename scalar_t>
|
143 |
+
__global__ void DeformablePSROIPoolBackwardAccKernel(
|
144 |
+
const int count,
|
145 |
+
const scalar_t *top_diff,
|
146 |
+
const scalar_t *top_count,
|
147 |
+
const int num_rois,
|
148 |
+
const scalar_t spatial_scale,
|
149 |
+
const int channels,
|
150 |
+
const int height, const int width,
|
151 |
+
const int pooled_height, const int pooled_width,
|
152 |
+
const int output_dim,
|
153 |
+
scalar_t *bottom_data_diff, scalar_t *bottom_trans_diff,
|
154 |
+
const scalar_t *bottom_data,
|
155 |
+
const scalar_t *bottom_rois,
|
156 |
+
const scalar_t *bottom_trans,
|
157 |
+
const int no_trans,
|
158 |
+
const scalar_t trans_std,
|
159 |
+
const int sample_per_part,
|
160 |
+
const int group_size,
|
161 |
+
const int part_size,
|
162 |
+
const int num_classes,
|
163 |
+
const int channels_each_class)
|
164 |
+
{
|
165 |
+
CUDA_KERNEL_LOOP(index, count)
|
166 |
+
{
|
167 |
+
// The output is in order (n, ctop, ph, pw)
|
168 |
+
int pw = index % pooled_width;
|
169 |
+
int ph = (index / pooled_width) % pooled_height;
|
170 |
+
int ctop = (index / pooled_width / pooled_height) % output_dim;
|
171 |
+
int n = index / pooled_width / pooled_height / output_dim;
|
172 |
+
|
173 |
+
// [start, end) interval for spatial sampling
|
174 |
+
const scalar_t *offset_bottom_rois = bottom_rois + n * 5;
|
175 |
+
int roi_batch_ind = offset_bottom_rois[0];
|
176 |
+
scalar_t roi_start_w = (scalar_t)(round(offset_bottom_rois[1])) * spatial_scale - 0.5;
|
177 |
+
scalar_t roi_start_h = (scalar_t)(round(offset_bottom_rois[2])) * spatial_scale - 0.5;
|
178 |
+
scalar_t roi_end_w = (scalar_t)(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5;
|
179 |
+
scalar_t roi_end_h = (scalar_t)(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5;
|
180 |
+
|
181 |
+
// Force too small ROIs to be 1x1
|
182 |
+
scalar_t roi_width = max(roi_end_w - roi_start_w, 0.1); //avoid 0
|
183 |
+
scalar_t roi_height = max(roi_end_h - roi_start_h, 0.1);
|
184 |
+
|
185 |
+
// Compute w and h at bottom
|
186 |
+
scalar_t bin_size_h = roi_height / (scalar_t)(pooled_height);
|
187 |
+
scalar_t bin_size_w = roi_width / (scalar_t)(pooled_width);
|
188 |
+
|
189 |
+
scalar_t sub_bin_size_h = bin_size_h / (scalar_t)(sample_per_part);
|
190 |
+
scalar_t sub_bin_size_w = bin_size_w / (scalar_t)(sample_per_part);
|
191 |
+
|
192 |
+
int part_h = floor((scalar_t)(ph) / pooled_height * part_size);
|
193 |
+
int part_w = floor((scalar_t)(pw) / pooled_width * part_size);
|
194 |
+
int class_id = ctop / channels_each_class;
|
195 |
+
scalar_t trans_x = no_trans ? (scalar_t)(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * (scalar_t)trans_std;
|
196 |
+
scalar_t trans_y = no_trans ? (scalar_t)(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * (scalar_t)trans_std;
|
197 |
+
|
198 |
+
scalar_t wstart = (scalar_t)(pw)*bin_size_w + roi_start_w;
|
199 |
+
wstart += trans_x * roi_width;
|
200 |
+
scalar_t hstart = (scalar_t)(ph)*bin_size_h + roi_start_h;
|
201 |
+
hstart += trans_y * roi_height;
|
202 |
+
|
203 |
+
if (top_count[index] <= 0)
|
204 |
+
{
|
205 |
+
continue;
|
206 |
+
}
|
207 |
+
scalar_t diff_val = top_diff[index] / top_count[index];
|
208 |
+
const scalar_t *offset_bottom_data = bottom_data + roi_batch_ind * channels * height * width;
|
209 |
+
scalar_t *offset_bottom_data_diff = bottom_data_diff + roi_batch_ind * channels * height * width;
|
210 |
+
int gw = floor((scalar_t)(pw)*group_size / pooled_width);
|
211 |
+
int gh = floor((scalar_t)(ph)*group_size / pooled_height);
|
212 |
+
gw = min(max(gw, 0), group_size - 1);
|
213 |
+
gh = min(max(gh, 0), group_size - 1);
|
214 |
+
|
215 |
+
for (int ih = 0; ih < sample_per_part; ih++)
|
216 |
+
{
|
217 |
+
for (int iw = 0; iw < sample_per_part; iw++)
|
218 |
+
{
|
219 |
+
scalar_t w = wstart + iw * sub_bin_size_w;
|
220 |
+
scalar_t h = hstart + ih * sub_bin_size_h;
|
221 |
+
// bilinear interpolation
|
222 |
+
if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5)
|
223 |
+
{
|
224 |
+
continue;
|
225 |
+
}
|
226 |
+
w = min(max(w, 0.), width - 1.);
|
227 |
+
h = min(max(h, 0.), height - 1.);
|
228 |
+
int c = (ctop * group_size + gh) * group_size + gw;
|
229 |
+
// backward on feature
|
230 |
+
int x0 = floor(w);
|
231 |
+
int x1 = ceil(w);
|
232 |
+
int y0 = floor(h);
|
233 |
+
int y1 = ceil(h);
|
234 |
+
scalar_t dist_x = w - x0, dist_y = h - y0;
|
235 |
+
scalar_t q00 = (1 - dist_x) * (1 - dist_y);
|
236 |
+
scalar_t q01 = (1 - dist_x) * dist_y;
|
237 |
+
scalar_t q10 = dist_x * (1 - dist_y);
|
238 |
+
scalar_t q11 = dist_x * dist_y;
|
239 |
+
int bottom_index_base = c * height * width;
|
240 |
+
atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x0, q00 * diff_val);
|
241 |
+
atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x0, q01 * diff_val);
|
242 |
+
atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x1, q10 * diff_val);
|
243 |
+
atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x1, q11 * diff_val);
|
244 |
+
|
245 |
+
if (no_trans)
|
246 |
+
{
|
247 |
+
continue;
|
248 |
+
}
|
249 |
+
scalar_t U00 = offset_bottom_data[bottom_index_base + y0 * width + x0];
|
250 |
+
scalar_t U01 = offset_bottom_data[bottom_index_base + y1 * width + x0];
|
251 |
+
scalar_t U10 = offset_bottom_data[bottom_index_base + y0 * width + x1];
|
252 |
+
scalar_t U11 = offset_bottom_data[bottom_index_base + y1 * width + x1];
|
253 |
+
scalar_t diff_x = (U11 * dist_y + U10 * (1 - dist_y) - U01 * dist_y - U00 * (1 - dist_y)) * trans_std * diff_val;
|
254 |
+
diff_x *= roi_width;
|
255 |
+
scalar_t diff_y = (U11 * dist_x + U01 * (1 - dist_x) - U10 * dist_x - U00 * (1 - dist_x)) * trans_std * diff_val;
|
256 |
+
diff_y *= roi_height;
|
257 |
+
|
258 |
+
atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w, diff_x);
|
259 |
+
atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w, diff_y);
|
260 |
+
}
|
261 |
+
}
|
262 |
+
}
|
263 |
+
}
|
264 |
+
|
265 |
+
void DeformablePSROIPoolForward(const at::Tensor data,
|
266 |
+
const at::Tensor bbox,
|
267 |
+
const at::Tensor trans,
|
268 |
+
at::Tensor out,
|
269 |
+
at::Tensor top_count,
|
270 |
+
const int batch,
|
271 |
+
const int channels,
|
272 |
+
const int height,
|
273 |
+
const int width,
|
274 |
+
const int num_bbox,
|
275 |
+
const int channels_trans,
|
276 |
+
const int no_trans,
|
277 |
+
const float spatial_scale,
|
278 |
+
const int output_dim,
|
279 |
+
const int group_size,
|
280 |
+
const int pooled_size,
|
281 |
+
const int part_size,
|
282 |
+
const int sample_per_part,
|
283 |
+
const float trans_std)
|
284 |
+
{
|
285 |
+
const int pooled_height = pooled_size;
|
286 |
+
const int pooled_width = pooled_size;
|
287 |
+
const int count = num_bbox * output_dim * pooled_height * pooled_width;
|
288 |
+
const int num_classes = no_trans ? 1 : channels_trans / 2;
|
289 |
+
const int channels_each_class = no_trans ? output_dim : output_dim / num_classes;
|
290 |
+
|
291 |
+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
292 |
+
data.type(), "deformable_psroi_pool_forward", ([&] {
|
293 |
+
const scalar_t *bottom_data = data.data<scalar_t>();
|
294 |
+
const scalar_t *bottom_rois = bbox.data<scalar_t>();
|
295 |
+
const scalar_t *bottom_trans = no_trans ? NULL : trans.data<scalar_t>();
|
296 |
+
scalar_t *top_data = out.data<scalar_t>();
|
297 |
+
scalar_t *top_count_data = top_count.data<scalar_t>();
|
298 |
+
|
299 |
+
DeformablePSROIPoolForwardKernel<<<GET_BLOCKS(count), CUDA_NUM_THREADS>>>(
|
300 |
+
count, bottom_data, (scalar_t)spatial_scale, channels, height, width, pooled_height, pooled_width,
|
301 |
+
bottom_rois, bottom_trans, no_trans, (scalar_t)trans_std, sample_per_part, output_dim,
|
302 |
+
group_size, part_size, num_classes, channels_each_class, top_data, top_count_data);
|
303 |
+
}));
|
304 |
+
|
305 |
+
cudaError_t err = cudaGetLastError();
|
306 |
+
if (err != cudaSuccess)
|
307 |
+
{
|
308 |
+
printf("error in DeformablePSROIPoolForward: %s\n", cudaGetErrorString(err));
|
309 |
+
}
|
310 |
+
}
|
311 |
+
|
312 |
+
void DeformablePSROIPoolBackwardAcc(const at::Tensor out_grad,
|
313 |
+
const at::Tensor data,
|
314 |
+
const at::Tensor bbox,
|
315 |
+
const at::Tensor trans,
|
316 |
+
const at::Tensor top_count,
|
317 |
+
at::Tensor in_grad,
|
318 |
+
at::Tensor trans_grad,
|
319 |
+
const int batch,
|
320 |
+
const int channels,
|
321 |
+
const int height,
|
322 |
+
const int width,
|
323 |
+
const int num_bbox,
|
324 |
+
const int channels_trans,
|
325 |
+
const int no_trans,
|
326 |
+
const float spatial_scale,
|
327 |
+
const int output_dim,
|
328 |
+
const int group_size,
|
329 |
+
const int pooled_size,
|
330 |
+
const int part_size,
|
331 |
+
const int sample_per_part,
|
332 |
+
const float trans_std)
|
333 |
+
{
|
334 |
+
// LOG(INFO) << "DeformablePSROIPoolBackward";
|
335 |
+
const int num_rois = num_bbox;
|
336 |
+
const int pooled_height = pooled_size;
|
337 |
+
const int pooled_width = pooled_size;
|
338 |
+
const int count = num_bbox * output_dim * pooled_height * pooled_width;
|
339 |
+
const int num_classes = no_trans ? 1 : channels_trans / 2;
|
340 |
+
const int channels_each_class = no_trans ? output_dim : output_dim / num_classes;
|
341 |
+
|
342 |
+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
343 |
+
out_grad.type(), "deformable_psroi_pool_backward_acc", ([&] {
|
344 |
+
const scalar_t *top_diff = out_grad.data<scalar_t>();
|
345 |
+
const scalar_t *bottom_data = data.data<scalar_t>();
|
346 |
+
const scalar_t *bottom_rois = bbox.data<scalar_t>();
|
347 |
+
const scalar_t *bottom_trans = no_trans ? NULL : trans.data<scalar_t>();
|
348 |
+
scalar_t *bottom_data_diff = in_grad.data<scalar_t>();
|
349 |
+
scalar_t *bottom_trans_diff = no_trans ? NULL : trans_grad.data<scalar_t>();
|
350 |
+
const scalar_t *top_count_data = top_count.data<scalar_t>();
|
351 |
+
|
352 |
+
DeformablePSROIPoolBackwardAccKernel<<<GET_BLOCKS(count), CUDA_NUM_THREADS>>>(
|
353 |
+
count, top_diff, top_count_data, num_rois, (scalar_t)spatial_scale, channels, height, width,
|
354 |
+
pooled_height, pooled_width, output_dim, bottom_data_diff, bottom_trans_diff,
|
355 |
+
bottom_data, bottom_rois, bottom_trans, no_trans, (scalar_t)trans_std, sample_per_part,
|
356 |
+
group_size, part_size, num_classes, channels_each_class);
|
357 |
+
}));
|
358 |
+
|
359 |
+
cudaError_t err = cudaGetLastError();
|
360 |
+
if (err != cudaSuccess)
|
361 |
+
{
|
362 |
+
printf("error in DeformablePSROIPoolForward: %s\n", cudaGetErrorString(err));
|
363 |
+
}
|
364 |
+
}
|
IndicPhotoOCR/detection/textbpn/network/backbone/resnet.py
ADDED
@@ -0,0 +1,336 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
import math
|
3 |
+
import torch.utils.model_zoo as model_zoo
|
4 |
+
BatchNorm2d = nn.BatchNorm2d
|
5 |
+
|
6 |
+
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
|
7 |
+
'resnet152']
|
8 |
+
|
9 |
+
|
10 |
+
model_urls = {
|
11 |
+
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
|
12 |
+
'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
|
13 |
+
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
|
14 |
+
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
|
15 |
+
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
|
16 |
+
}
|
17 |
+
|
18 |
+
|
19 |
+
def constant_init(module, constant, bias=0):
|
20 |
+
nn.init.constant_(module.weight, constant)
|
21 |
+
if hasattr(module, 'bias'):
|
22 |
+
nn.init.constant_(module.bias, bias)
|
23 |
+
|
24 |
+
|
25 |
+
def conv3x3(in_planes, out_planes, stride=1):
|
26 |
+
"""3x3 convolution with padding"""
|
27 |
+
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
28 |
+
padding=1, bias=False)
|
29 |
+
|
30 |
+
|
31 |
+
class BasicBlock(nn.Module):
|
32 |
+
expansion = 1
|
33 |
+
|
34 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None, dcn=None):
|
35 |
+
super(BasicBlock, self).__init__()
|
36 |
+
self.with_dcn = dcn is not None
|
37 |
+
self.conv1 = conv3x3(inplanes, planes, stride)
|
38 |
+
self.bn1 = BatchNorm2d(planes)
|
39 |
+
self.relu = nn.ReLU(inplace=True)
|
40 |
+
self.with_modulated_dcn = False
|
41 |
+
if self.with_dcn:
|
42 |
+
fallback_on_stride = dcn.get('fallback_on_stride', False)
|
43 |
+
self.with_modulated_dcn = dcn.get('modulated', False)
|
44 |
+
# self.conv2 = conv3x3(planes, planes)
|
45 |
+
if not self.with_dcn or fallback_on_stride:
|
46 |
+
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
|
47 |
+
padding=1, bias=False)
|
48 |
+
else:
|
49 |
+
deformable_groups = dcn.get('deformable_groups', 1)
|
50 |
+
if not self.with_modulated_dcn:
|
51 |
+
from network.backbone.assets.dcn import DeformConv
|
52 |
+
conv_op = DeformConv
|
53 |
+
offset_channels = 18
|
54 |
+
else:
|
55 |
+
from network.backbone.assets.dcn import ModulatedDeformConv
|
56 |
+
conv_op = ModulatedDeformConv
|
57 |
+
offset_channels = 27
|
58 |
+
self.conv2_offset = nn.Conv2d(
|
59 |
+
planes,
|
60 |
+
deformable_groups * offset_channels,
|
61 |
+
kernel_size=3,
|
62 |
+
padding=1)
|
63 |
+
self.conv2 = conv_op(
|
64 |
+
planes,
|
65 |
+
planes,
|
66 |
+
kernel_size=3,
|
67 |
+
padding=1,
|
68 |
+
deformable_groups=deformable_groups,
|
69 |
+
bias=False)
|
70 |
+
self.bn2 = BatchNorm2d(planes)
|
71 |
+
self.downsample = downsample
|
72 |
+
self.stride = stride
|
73 |
+
|
74 |
+
def forward(self, x):
|
75 |
+
residual = x
|
76 |
+
|
77 |
+
out = self.conv1(x)
|
78 |
+
out = self.bn1(out)
|
79 |
+
out = self.relu(out)
|
80 |
+
|
81 |
+
# out = self.conv2(out)
|
82 |
+
if not self.with_dcn:
|
83 |
+
out = self.conv2(out)
|
84 |
+
elif self.with_modulated_dcn:
|
85 |
+
offset_mask = self.conv2_offset(out)
|
86 |
+
offset = offset_mask[:, :18, :, :]
|
87 |
+
mask = offset_mask[:, -9:, :, :].sigmoid()
|
88 |
+
out = self.conv2(out, offset, mask)
|
89 |
+
else:
|
90 |
+
offset = self.conv2_offset(out)
|
91 |
+
out = self.conv2(out, offset)
|
92 |
+
out = self.bn2(out)
|
93 |
+
|
94 |
+
if self.downsample is not None:
|
95 |
+
residual = self.downsample(x)
|
96 |
+
|
97 |
+
out += residual
|
98 |
+
out = self.relu(out)
|
99 |
+
|
100 |
+
return out
|
101 |
+
|
102 |
+
|
103 |
+
class Bottleneck(nn.Module):
|
104 |
+
expansion = 4
|
105 |
+
|
106 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None, dcn=None):
|
107 |
+
super(Bottleneck, self).__init__()
|
108 |
+
self.with_dcn = dcn is not None
|
109 |
+
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
|
110 |
+
self.bn1 = BatchNorm2d(planes)
|
111 |
+
fallback_on_stride = False
|
112 |
+
self.with_modulated_dcn = False
|
113 |
+
if self.with_dcn:
|
114 |
+
fallback_on_stride = dcn.get('fallback_on_stride', False)
|
115 |
+
self.with_modulated_dcn = dcn.get('modulated', False)
|
116 |
+
if not self.with_dcn or fallback_on_stride:
|
117 |
+
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
|
118 |
+
stride=stride, padding=1, bias=False)
|
119 |
+
else:
|
120 |
+
deformable_groups = dcn.get('deformable_groups', 1)
|
121 |
+
if not self.with_modulated_dcn:
|
122 |
+
from network.backbone.assets.dcn import DeformConv
|
123 |
+
conv_op = DeformConv
|
124 |
+
offset_channels = 18
|
125 |
+
else:
|
126 |
+
from network.backbone.assets.dcn import ModulatedDeformConv
|
127 |
+
conv_op = ModulatedDeformConv
|
128 |
+
offset_channels = 27
|
129 |
+
self.conv2_offset = nn.Conv2d(
|
130 |
+
planes, deformable_groups * offset_channels,
|
131 |
+
kernel_size=3,
|
132 |
+
padding=1)
|
133 |
+
self.conv2 = conv_op(
|
134 |
+
planes, planes, kernel_size=3, padding=1, stride=stride,
|
135 |
+
deformable_groups=deformable_groups, bias=False)
|
136 |
+
self.bn2 = BatchNorm2d(planes)
|
137 |
+
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
|
138 |
+
self.bn3 = BatchNorm2d(planes * 4)
|
139 |
+
self.relu = nn.ReLU(inplace=True)
|
140 |
+
self.downsample = downsample
|
141 |
+
self.stride = stride
|
142 |
+
self.dcn = dcn
|
143 |
+
self.with_dcn = dcn is not None
|
144 |
+
|
145 |
+
def forward(self, x):
|
146 |
+
residual = x
|
147 |
+
|
148 |
+
out = self.conv1(x)
|
149 |
+
out = self.bn1(out)
|
150 |
+
out = self.relu(out)
|
151 |
+
|
152 |
+
# out = self.conv2(out)
|
153 |
+
if not self.with_dcn:
|
154 |
+
out = self.conv2(out)
|
155 |
+
elif self.with_modulated_dcn:
|
156 |
+
offset_mask = self.conv2_offset(out)
|
157 |
+
offset = offset_mask[:, :18, :, :]
|
158 |
+
mask = offset_mask[:, -9:, :, :].sigmoid()
|
159 |
+
out = self.conv2(out, offset, mask)
|
160 |
+
else:
|
161 |
+
offset = self.conv2_offset(out)
|
162 |
+
out = self.conv2(out, offset)
|
163 |
+
out = self.bn2(out)
|
164 |
+
out = self.relu(out)
|
165 |
+
|
166 |
+
out = self.conv3(out)
|
167 |
+
out = self.bn3(out)
|
168 |
+
|
169 |
+
if self.downsample is not None:
|
170 |
+
residual = self.downsample(x)
|
171 |
+
|
172 |
+
out += residual
|
173 |
+
out = self.relu(out)
|
174 |
+
|
175 |
+
return out
|
176 |
+
|
177 |
+
|
178 |
+
class ResNet(nn.Module):
|
179 |
+
def __init__(self, block, layers, num_classes=1000,
|
180 |
+
dcn=None, stage_with_dcn=(False, False, False, False)):
|
181 |
+
self.dcn = dcn
|
182 |
+
self.stage_with_dcn = stage_with_dcn
|
183 |
+
self.inplanes = 64
|
184 |
+
super(ResNet, self).__init__()
|
185 |
+
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
|
186 |
+
bias=False)
|
187 |
+
self.bn1 = BatchNorm2d(64)
|
188 |
+
self.relu = nn.ReLU(inplace=True)
|
189 |
+
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
190 |
+
self.layer1 = self._make_layer(block, 64, layers[0])
|
191 |
+
self.layer2 = self._make_layer(
|
192 |
+
block, 128, layers[1], stride=2, dcn=dcn)
|
193 |
+
self.layer3 = self._make_layer(
|
194 |
+
block, 256, layers[2], stride=2, dcn=dcn)
|
195 |
+
self.layer4 = self._make_layer(
|
196 |
+
block, 512, layers[3], stride=2, dcn=dcn)
|
197 |
+
self.avgpool = nn.AvgPool2d(7, stride=1)
|
198 |
+
self.fc = nn.Linear(512 * block.expansion, num_classes)
|
199 |
+
|
200 |
+
self.smooth = nn.Conv2d(2048, 256, kernel_size=1, stride=1, padding=1)
|
201 |
+
|
202 |
+
for m in self.modules():
|
203 |
+
if isinstance(m, nn.Conv2d):
|
204 |
+
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
205 |
+
m.weight.data.normal_(0, math.sqrt(2. / n))
|
206 |
+
elif isinstance(m, BatchNorm2d):
|
207 |
+
m.weight.data.fill_(1)
|
208 |
+
m.bias.data.zero_()
|
209 |
+
if self.dcn is not None:
|
210 |
+
for m in self.modules():
|
211 |
+
if isinstance(m, Bottleneck) or isinstance(m, BasicBlock):
|
212 |
+
if hasattr(m, 'conv2_offset'):
|
213 |
+
constant_init(m.conv2_offset, 0)
|
214 |
+
|
215 |
+
def _make_layer(self, block, planes, blocks, stride=1, dcn=None):
|
216 |
+
downsample = None
|
217 |
+
if stride != 1 or self.inplanes != planes * block.expansion:
|
218 |
+
downsample = nn.Sequential(
|
219 |
+
nn.Conv2d(self.inplanes, planes * block.expansion,
|
220 |
+
kernel_size=1, stride=stride, bias=False),
|
221 |
+
BatchNorm2d(planes * block.expansion),
|
222 |
+
)
|
223 |
+
|
224 |
+
layers = []
|
225 |
+
layers.append(block(self.inplanes, planes,
|
226 |
+
stride, downsample, dcn=dcn))
|
227 |
+
self.inplanes = planes * block.expansion
|
228 |
+
for i in range(1, blocks):
|
229 |
+
layers.append(block(self.inplanes, planes, dcn=dcn))
|
230 |
+
|
231 |
+
return nn.Sequential(*layers)
|
232 |
+
|
233 |
+
def forward(self, x):
|
234 |
+
x = self.conv1(x)
|
235 |
+
x = self.bn1(x)
|
236 |
+
x = self.relu(x)
|
237 |
+
x1 = self.maxpool(x)
|
238 |
+
|
239 |
+
x2 = self.layer1(x1)
|
240 |
+
x3 = self.layer2(x2)
|
241 |
+
x4 = self.layer3(x3)
|
242 |
+
x5 = self.layer4(x4)
|
243 |
+
|
244 |
+
return x1, x2, x3, x4, x5
|
245 |
+
|
246 |
+
|
247 |
+
def resnet18(pretrained=True, **kwargs):
|
248 |
+
"""Constructs a ResNet-18 model.
|
249 |
+
Args:
|
250 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
251 |
+
"""
|
252 |
+
model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
|
253 |
+
if pretrained:
|
254 |
+
model.load_state_dict(model_zoo.load_url(
|
255 |
+
model_urls['resnet18']), strict=False)
|
256 |
+
return model
|
257 |
+
|
258 |
+
def deformable_resnet18(pretrained=True, **kwargs):
|
259 |
+
"""Constructs a ResNet-18 model.
|
260 |
+
Args:
|
261 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
262 |
+
"""
|
263 |
+
model = ResNet(BasicBlock, [2, 2, 2, 2],
|
264 |
+
dcn=dict(modulated=True,
|
265 |
+
deformable_groups=1,
|
266 |
+
fallback_on_stride=False),
|
267 |
+
stage_with_dcn=[False, True, True, True], **kwargs)
|
268 |
+
if pretrained:
|
269 |
+
model.load_state_dict(model_zoo.load_url(
|
270 |
+
model_urls['resnet18']), strict=False)
|
271 |
+
return model
|
272 |
+
|
273 |
+
|
274 |
+
def resnet34(pretrained=True, **kwargs):
|
275 |
+
"""Constructs a ResNet-34 model.
|
276 |
+
Args:
|
277 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
278 |
+
"""
|
279 |
+
model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
|
280 |
+
if pretrained:
|
281 |
+
model.load_state_dict(model_zoo.load_url(
|
282 |
+
model_urls['resnet34']), strict=False)
|
283 |
+
return model
|
284 |
+
|
285 |
+
|
286 |
+
def resnet50(pretrained=True, **kwargs):
|
287 |
+
"""Constructs a ResNet-50 model.
|
288 |
+
Args:
|
289 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
290 |
+
"""
|
291 |
+
model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
|
292 |
+
if pretrained:
|
293 |
+
model.load_state_dict(model_zoo.load_url(
|
294 |
+
model_urls['resnet50']), strict=False)
|
295 |
+
return model
|
296 |
+
|
297 |
+
|
298 |
+
def deformable_resnet50(pretrained=True, **kwargs):
|
299 |
+
"""Constructs a ResNet-50 model with deformable conv.
|
300 |
+
Args:
|
301 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
302 |
+
"""
|
303 |
+
model = ResNet(Bottleneck, [3, 4, 6, 3],
|
304 |
+
dcn=dict(modulated=True,
|
305 |
+
deformable_groups=1,
|
306 |
+
fallback_on_stride=False),
|
307 |
+
stage_with_dcn=[False, True, True, True],
|
308 |
+
**kwargs)
|
309 |
+
if pretrained:
|
310 |
+
model.load_state_dict(model_zoo.load_url(
|
311 |
+
model_urls['resnet50']), strict=False)
|
312 |
+
return model
|
313 |
+
|
314 |
+
|
315 |
+
def resnet101(pretrained=True, **kwargs):
|
316 |
+
"""Constructs a ResNet-101 model.
|
317 |
+
Args:
|
318 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
319 |
+
"""
|
320 |
+
model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
|
321 |
+
if pretrained:
|
322 |
+
model.load_state_dict(model_zoo.load_url(
|
323 |
+
model_urls['resnet101']), strict=False)
|
324 |
+
return model
|
325 |
+
|
326 |
+
|
327 |
+
def resnet152(pretrained=True, **kwargs):
|
328 |
+
"""Constructs a ResNet-152 model.
|
329 |
+
Args:
|
330 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
331 |
+
"""
|
332 |
+
model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
|
333 |
+
if pretrained:
|
334 |
+
model.load_state_dict(model_zoo.load_url(
|
335 |
+
model_urls['resnet152']), strict=False)
|
336 |
+
return model
|
IndicPhotoOCR/detection/textbpn/network/backbone/vgg.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
import torch.utils.model_zoo as model_zoo
|
3 |
+
import torchvision.models as models
|
4 |
+
|
5 |
+
model_urls = {
|
6 |
+
'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
|
7 |
+
'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
|
8 |
+
'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
|
9 |
+
'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
|
10 |
+
'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
|
11 |
+
'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
|
12 |
+
}
|
13 |
+
|
14 |
+
|
15 |
+
class VggNet(nn.Module):
|
16 |
+
def __init__(self, name="vgg16", pretrain=True):
|
17 |
+
super().__init__()
|
18 |
+
if name == "vgg16":
|
19 |
+
base_net = models.vgg16(pretrained=False)
|
20 |
+
elif name == "vgg16_bn":
|
21 |
+
base_net = models.vgg16_bn(pretrained=False)
|
22 |
+
else:
|
23 |
+
print(" base model is not support !")
|
24 |
+
if pretrain:
|
25 |
+
print("load the {} weight from ./cache".format(name))
|
26 |
+
base_net.load_state_dict(model_zoo.load_url(model_urls[name], model_dir="./cache"))
|
27 |
+
|
28 |
+
if name == "vgg16":
|
29 |
+
self.stage1 = nn.Sequential(*[base_net.features[layer] for layer in range(0, 5)])
|
30 |
+
self.stage2 = nn.Sequential(*[base_net.features[layer] for layer in range(5, 10)])
|
31 |
+
self.stage3 = nn.Sequential(*[base_net.features[layer] for layer in range(10, 17)])
|
32 |
+
self.stage4 = nn.Sequential(*[base_net.features[layer] for layer in range(17, 24)])
|
33 |
+
self.stage5 = nn.Sequential(*[base_net.features[layer] for layer in range(24, 31)])
|
34 |
+
elif name == "vgg16_bn":
|
35 |
+
self.stage1 = nn.Sequential(*[base_net.features[layer] for layer in range(0, 7)])
|
36 |
+
self.stage2 = nn.Sequential(*[base_net.features[layer] for layer in range(7, 14)])
|
37 |
+
self.stage3 = nn.Sequential(*[base_net.features[layer] for layer in range(14, 24)])
|
38 |
+
self.stage4 = nn.Sequential(*[base_net.features[layer] for layer in range(24, 34)])
|
39 |
+
self.stage5 = nn.Sequential(*[base_net.features[layer] for layer in range(34, 44)])
|
40 |
+
|
41 |
+
def forward(self, x):
|
42 |
+
C1 = self.stage1(x)
|
43 |
+
C2 = self.stage2(C1)
|
44 |
+
C3 = self.stage3(C2)
|
45 |
+
C4 = self.stage4(C3)
|
46 |
+
C5 = self.stage5(C4)
|
47 |
+
|
48 |
+
return C1, C2, C3, C4, C5
|
49 |
+
|
50 |
+
|
51 |
+
if __name__ == '__main__':
|
52 |
+
import torch
|
53 |
+
input = torch.randn((4, 3, 512, 512))
|
54 |
+
net = VggNet()
|
55 |
+
C1, C2, C3, C4, C5 = net(input)
|
56 |
+
print(C1.size())
|
57 |
+
print(C2.size())
|
58 |
+
print(C3.size())
|
59 |
+
print(C4.size())
|
60 |
+
print(C5.size())
|
IndicPhotoOCR/detection/textbpn/network/layers/Adaptive_Deformation.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
###################################################################
|
2 |
+
# File Name: AdaptiveDeformation.py
|
3 |
+
# Author: S.X.Zhang
|
4 |
+
###################################################################
|
5 |
+
|
6 |
+
from __future__ import print_function
|
7 |
+
from __future__ import division
|
8 |
+
from __future__ import absolute_import
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
import torch.nn.functional as F
|
13 |
+
from torch.nn import init
|
14 |
+
|
15 |
+
|
16 |
+
class MeanAggregator(nn.Module):
|
17 |
+
def __init__(self):
|
18 |
+
super(MeanAggregator, self).__init__()
|
19 |
+
|
20 |
+
def forward(self, features, A):
|
21 |
+
x = torch.bmm(A, features)
|
22 |
+
return x
|
23 |
+
|
24 |
+
|
25 |
+
class GraphConv(nn.Module):
|
26 |
+
def __init__(self, in_dim, out_dim, agg):
|
27 |
+
super(GraphConv, self).__init__()
|
28 |
+
self.in_dim = in_dim
|
29 |
+
self.out_dim = out_dim
|
30 |
+
self.weight = nn.Parameter(torch.FloatTensor(in_dim * 2, out_dim))
|
31 |
+
self.bias = nn.Parameter(torch.FloatTensor(out_dim))
|
32 |
+
init.xavier_uniform_(self.weight)
|
33 |
+
init.constant_(self.bias, 0)
|
34 |
+
self.agg = agg()
|
35 |
+
|
36 |
+
def forward(self, features, A):
|
37 |
+
b, n, d = features.shape
|
38 |
+
assert (d == self.in_dim)
|
39 |
+
agg_feats = self.agg(features, A)
|
40 |
+
cat_feats = torch.cat([features, agg_feats], dim=2)
|
41 |
+
out = torch.einsum('bnd,df->bnf', (cat_feats, self.weight))
|
42 |
+
out = F.relu(out + self.bias)
|
43 |
+
return out
|
44 |
+
|
45 |
+
|
46 |
+
class AdaptiveDeformation(nn.Module):
|
47 |
+
def __init__(self, input, state_dim):
|
48 |
+
super(AdaptiveDeformation, self).__init__()
|
49 |
+
self.bn0 = nn.BatchNorm1d(input, affine=False)
|
50 |
+
self.conv1 = nn.Conv1d(input, state_dim, 1)
|
51 |
+
self.rnn = nn.LSTM(input, state_dim, 1, bidirectional=True)
|
52 |
+
self.gconv1 = GraphConv(input, 256, MeanAggregator)
|
53 |
+
self.gconv2 = GraphConv(256, 1024, MeanAggregator)
|
54 |
+
self.gconv3 = GraphConv(1024, 512, MeanAggregator)
|
55 |
+
self.gconv4 = GraphConv(512, state_dim, MeanAggregator)
|
56 |
+
|
57 |
+
self.prediction = nn.Sequential(
|
58 |
+
nn.Conv1d(4*state_dim, 128, 1),
|
59 |
+
nn.ReLU(inplace=True),
|
60 |
+
nn.Dropout(0.1),
|
61 |
+
nn.Conv1d(128, 64, 1),
|
62 |
+
nn.ReLU(inplace=True),
|
63 |
+
nn.Dropout(0.1),
|
64 |
+
nn.Conv1d(64, 2, 1))
|
65 |
+
|
66 |
+
def forward(self, x, A):
|
67 |
+
x = self.bn0(x)
|
68 |
+
|
69 |
+
# # rnn block
|
70 |
+
yl = x.permute(2, 0, 1)
|
71 |
+
yl, _ = self.rnn(yl)
|
72 |
+
yl = yl.permute(1, 2, 0)
|
73 |
+
|
74 |
+
# # gcn block
|
75 |
+
yg = x.permute(0, 2, 1)
|
76 |
+
b, n, c = yg.shape
|
77 |
+
A = A.expand(b, n, n)
|
78 |
+
yg = self.gconv1(yg, A)
|
79 |
+
yg = self.gconv2(yg, A)
|
80 |
+
yg = self.gconv3(yg, A)
|
81 |
+
yg = self.gconv4(yg, A)
|
82 |
+
yg = yg.permute(0, 2, 1)
|
83 |
+
|
84 |
+
# res block
|
85 |
+
x = torch.cat([yl, yg, self.conv1(x)], dim=1)
|
86 |
+
pred = self.prediction(x)
|
87 |
+
|
88 |
+
return pred
|
IndicPhotoOCR/detection/textbpn/network/layers/CircConv.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
import torch
|
3 |
+
|
4 |
+
|
5 |
+
class CircConv(nn.Module):
|
6 |
+
def __init__(self, state_dim, out_state_dim=None, n_adj=4):
|
7 |
+
super(CircConv, self).__init__()
|
8 |
+
|
9 |
+
self.n_adj = n_adj
|
10 |
+
out_state_dim = state_dim if out_state_dim is None else out_state_dim
|
11 |
+
self.fc = nn.Conv1d(state_dim, out_state_dim, kernel_size=self.n_adj*2+1)
|
12 |
+
|
13 |
+
def forward(self, input, adj):
|
14 |
+
input = torch.cat([input[..., -self.n_adj:], input, input[..., :self.n_adj]], dim=2)
|
15 |
+
return self.fc(input)
|
16 |
+
|
17 |
+
|
18 |
+
class DilatedCircConv(nn.Module):
|
19 |
+
def __init__(self, state_dim, out_state_dim=None, n_adj=4, dilation=1):
|
20 |
+
super(DilatedCircConv, self).__init__()
|
21 |
+
|
22 |
+
self.n_adj = n_adj
|
23 |
+
self.dilation = dilation
|
24 |
+
out_state_dim = state_dim if out_state_dim is None else out_state_dim
|
25 |
+
self.fc = nn.Conv1d(state_dim, out_state_dim, kernel_size=self.n_adj*2+1, dilation=self.dilation)
|
26 |
+
|
27 |
+
def forward(self, input, adj):
|
28 |
+
if self.n_adj != 0:
|
29 |
+
input = torch.cat([input[..., -self.n_adj*self.dilation:], input, input[..., :self.n_adj*self.dilation]], dim=2)
|
30 |
+
return self.fc(input)
|
31 |
+
|
32 |
+
|
33 |
+
_conv_factory = {
|
34 |
+
'grid': CircConv,
|
35 |
+
'dgrid': DilatedCircConv
|
36 |
+
}
|
37 |
+
|
38 |
+
|
39 |
+
class BasicBlock(nn.Module):
|
40 |
+
def __init__(self, state_dim, out_state_dim, conv_type, n_adj=4, dilation=1):
|
41 |
+
super(BasicBlock, self).__init__()
|
42 |
+
|
43 |
+
self.conv = _conv_factory[conv_type](state_dim, out_state_dim, n_adj, dilation)
|
44 |
+
self.relu = nn.ReLU(inplace=True)
|
45 |
+
self.norm = nn.BatchNorm1d(out_state_dim)
|
46 |
+
|
47 |
+
def forward(self, x, adj=None):
|
48 |
+
x = self.conv(x, adj)
|
49 |
+
x = self.relu(x)
|
50 |
+
x = self.norm(x)
|
51 |
+
return x
|
52 |
+
|
53 |
+
|
54 |
+
class DeepSnake(nn.Module):
|
55 |
+
def __init__(self, state_dim, feature_dim, conv_type='dgrid'):
|
56 |
+
super(DeepSnake, self).__init__()
|
57 |
+
|
58 |
+
self.head = BasicBlock(feature_dim, state_dim, conv_type)
|
59 |
+
|
60 |
+
self.res_layer_num = 7
|
61 |
+
dilation = [1, 1, 1, 2, 2, 4, 4]
|
62 |
+
for i in range(self.res_layer_num):
|
63 |
+
conv = BasicBlock(state_dim, state_dim, conv_type, n_adj=4, dilation=dilation[i])
|
64 |
+
self.__setattr__('res'+str(i), conv)
|
65 |
+
|
66 |
+
fusion_state_dim = 256
|
67 |
+
self.fusion = nn.Conv1d(state_dim * (self.res_layer_num + 1), fusion_state_dim, 1)
|
68 |
+
self.prediction = nn.Sequential(
|
69 |
+
nn.Conv1d(state_dim * (self.res_layer_num + 1) + fusion_state_dim, 256, 1),
|
70 |
+
nn.ReLU(inplace=True),
|
71 |
+
nn.Conv1d(256, 64, 1),
|
72 |
+
nn.ReLU(inplace=True),
|
73 |
+
nn.Conv1d(64, 2, 1)
|
74 |
+
)
|
75 |
+
|
76 |
+
def forward(self, x, adj):
|
77 |
+
states = []
|
78 |
+
|
79 |
+
x = self.head(x, adj)
|
80 |
+
states.append(x)
|
81 |
+
for i in range(self.res_layer_num):
|
82 |
+
x = self.__getattr__('res'+str(i))(x, adj) + x
|
83 |
+
states.append(x)
|
84 |
+
|
85 |
+
state = torch.cat(states, dim=1)
|
86 |
+
global_state = torch.max(self.fusion(state), dim=2, keepdim=True)[0]
|
87 |
+
global_state = global_state.expand(global_state.size(0), global_state.size(1), state.size(2))
|
88 |
+
state = torch.cat([global_state, state], dim=1)
|
89 |
+
x = self.prediction(state)
|
90 |
+
|
91 |
+
return x
|
IndicPhotoOCR/detection/textbpn/network/layers/GCN.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
###################################################################
|
2 |
+
# File Name: GCN.py
|
3 |
+
# Author: S.X.Zhang
|
4 |
+
###################################################################
|
5 |
+
|
6 |
+
from __future__ import print_function
|
7 |
+
from __future__ import division
|
8 |
+
from __future__ import absolute_import
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
import torch.nn.functional as F
|
13 |
+
from torch.nn import init
|
14 |
+
|
15 |
+
|
16 |
+
class MeanAggregator(nn.Module):
|
17 |
+
def __init__(self):
|
18 |
+
super(MeanAggregator, self).__init__()
|
19 |
+
|
20 |
+
def forward(self, features, A):
|
21 |
+
x = torch.bmm(A, features)
|
22 |
+
return x
|
23 |
+
|
24 |
+
|
25 |
+
class GraphConv(nn.Module):
|
26 |
+
def __init__(self, in_dim, out_dim, agg):
|
27 |
+
super(GraphConv, self).__init__()
|
28 |
+
self.in_dim = in_dim
|
29 |
+
self.out_dim = out_dim
|
30 |
+
self.weight = nn.Parameter(torch.FloatTensor(in_dim * 2, out_dim))
|
31 |
+
self.bias = nn.Parameter(torch.FloatTensor(out_dim))
|
32 |
+
init.xavier_uniform_(self.weight)
|
33 |
+
init.constant_(self.bias, 0)
|
34 |
+
self.agg = agg()
|
35 |
+
|
36 |
+
def forward(self, features, A):
|
37 |
+
b, n, d = features.shape
|
38 |
+
assert (d == self.in_dim)
|
39 |
+
agg_feats = self.agg(features, A)
|
40 |
+
cat_feats = torch.cat([features, agg_feats], dim=2)
|
41 |
+
out = torch.einsum('bnd,df->bnf', (cat_feats, self.weight))
|
42 |
+
out = F.relu(out + self.bias)
|
43 |
+
return out
|
44 |
+
|
45 |
+
|
46 |
+
class GCN(nn.Module):
|
47 |
+
def __init__(self, in_dim, out_dim):
|
48 |
+
super(GCN, self).__init__()
|
49 |
+
self.bn0 = nn.BatchNorm1d(in_dim, affine=False)
|
50 |
+
|
51 |
+
self.conv1 = GraphConv(in_dim, 256, MeanAggregator)
|
52 |
+
self.conv2 = GraphConv(256, 1024, MeanAggregator)
|
53 |
+
self.conv3 = GraphConv(1024, 512, MeanAggregator)
|
54 |
+
self.conv4 = GraphConv(512, out_dim, MeanAggregator)
|
55 |
+
|
56 |
+
self.prediction = nn.Sequential(
|
57 |
+
nn.Conv1d(out_dim, 128, 1),
|
58 |
+
nn.ReLU(inplace=True),
|
59 |
+
nn.Conv1d(128, 64, 1),
|
60 |
+
nn.ReLU(inplace=True),
|
61 |
+
nn.Conv1d(64, 2, 1))
|
62 |
+
|
63 |
+
def forward(self, x, A):
|
64 |
+
x = self.bn0(x)
|
65 |
+
x = x.permute(0, 2, 1)
|
66 |
+
b, n, c = x.shape
|
67 |
+
A = A.expand(b, n, n)
|
68 |
+
|
69 |
+
x = self.conv1(x, A)
|
70 |
+
x = self.conv2(x, A)
|
71 |
+
x = self.conv3(x, A)
|
72 |
+
x = self.conv4(x, A)
|
73 |
+
|
74 |
+
x = x.permute(0, 2, 1)
|
75 |
+
pred = self.prediction(x)
|
76 |
+
|
77 |
+
return pred
|
IndicPhotoOCR/detection/textbpn/network/layers/GraphConv.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch.nn.parameter import Parameter
|
5 |
+
from torch.nn.modules.module import Module
|
6 |
+
from torch.nn import init
|
7 |
+
|
8 |
+
|
9 |
+
class GraphConvolution(Module):
|
10 |
+
"""
|
11 |
+
Simple GCN layer, similar to https://arxiv.org/abs/1609.02907
|
12 |
+
"""
|
13 |
+
|
14 |
+
def __init__(self, in_features, out_features, bias=True):
|
15 |
+
super(GraphConvolution, self).__init__()
|
16 |
+
self.in_features = in_features
|
17 |
+
self.out_features = out_features
|
18 |
+
self.weight = Parameter(torch.FloatTensor(in_features, out_features))
|
19 |
+
init.xavier_uniform_(self.weight)
|
20 |
+
if bias:
|
21 |
+
self.bias = Parameter(torch.FloatTensor(out_features))
|
22 |
+
init.constant_(self.bias, 0)
|
23 |
+
else:
|
24 |
+
self.register_parameter('bias', None)
|
25 |
+
|
26 |
+
self.reset_parameters()
|
27 |
+
|
28 |
+
def reset_parameters(self):
|
29 |
+
stdv = 1. / math.sqrt(self.weight.size(1))
|
30 |
+
self.weight.data.uniform_(-stdv, stdv)
|
31 |
+
if self.bias is not None:
|
32 |
+
self.bias.data.uniform_(-stdv, stdv)
|
33 |
+
|
34 |
+
def forward(self, input, adj):
|
35 |
+
support = torch.mm(input, self.weight)
|
36 |
+
output = torch.spmm(adj, support)
|
37 |
+
if self.bias is not None:
|
38 |
+
return output + self.bias
|
39 |
+
else:
|
40 |
+
return output
|
41 |
+
|
42 |
+
def __repr__(self):
|
43 |
+
return self.__class__.__name__ + ' (' \
|
44 |
+
+ str(self.in_features) + ' -> ' \
|
45 |
+
+ str(self.out_features) + ')'
|
IndicPhotoOCR/detection/textbpn/network/layers/RNN.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
###################################################################
|
2 |
+
# File Name: RNN.py
|
3 |
+
# Author: S.X.Zhang
|
4 |
+
###################################################################
|
5 |
+
|
6 |
+
from __future__ import print_function
|
7 |
+
from __future__ import division
|
8 |
+
from __future__ import absolute_import
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
import torch.nn.functional as F
|
13 |
+
from torch.nn import init
|
14 |
+
|
15 |
+
|
16 |
+
class RNN(nn.Module):
|
17 |
+
def __init__(self, input, state_dim):
|
18 |
+
super(RNN, self).__init__()
|
19 |
+
self.bn0 = nn.BatchNorm1d(input, affine=False)
|
20 |
+
self.rnn = nn.LSTM(input, state_dim, 1, dropout=0.1, bidirectional=True)
|
21 |
+
self.prediction = nn.Sequential(
|
22 |
+
nn.Conv1d(state_dim*2, 128, 1),
|
23 |
+
nn.ReLU(inplace=True),
|
24 |
+
nn.Conv1d(128, 64, 1),
|
25 |
+
nn.ReLU(inplace=True),
|
26 |
+
nn.Conv1d(64, 2, 1))
|
27 |
+
|
28 |
+
def forward(self, x, adj):
|
29 |
+
x = self.bn0(x)
|
30 |
+
x = x.permute(2, 0, 1)
|
31 |
+
x, _ = self.rnn(x)
|
32 |
+
x = x.permute(1, 2, 0)
|
33 |
+
pred = self.prediction(x)
|
34 |
+
|
35 |
+
return pred
|
IndicPhotoOCR/detection/textbpn/network/layers/Transformer.py
ADDED
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
###################################################################
|
2 |
+
# File Name: GCN.py
|
3 |
+
# Author: S.X.Zhang
|
4 |
+
###################################################################
|
5 |
+
import torch
|
6 |
+
from torch import nn, Tensor
|
7 |
+
import numpy as np
|
8 |
+
from IndicPhotoOCR.detection.textbpn.cfglib.config import config as cfg
|
9 |
+
|
10 |
+
|
11 |
+
class Positional_encoding(nn.Module):
|
12 |
+
def __init__(self, PE_size, n_position=256):
|
13 |
+
super(Positional_encoding, self).__init__()
|
14 |
+
self.PE_size = PE_size
|
15 |
+
self.n_position = n_position
|
16 |
+
self.register_buffer('pos_table', self.get_encoding_table(n_position, PE_size))
|
17 |
+
|
18 |
+
def get_encoding_table(self, n_position, PE_size):
|
19 |
+
position_table = np.array(
|
20 |
+
[[pos / np.power(10000, 2. * i / self.PE_size) for i in range(self.PE_size)] for pos in range(n_position)])
|
21 |
+
position_table[:, 0::2] = np.sin(position_table[:, 0::2])
|
22 |
+
position_table[:, 1::2] = np.cos(position_table[:, 1::2])
|
23 |
+
return torch.FloatTensor(position_table).unsqueeze(0)
|
24 |
+
|
25 |
+
def forward(self, inputs):
|
26 |
+
return inputs + self.pos_table[:, :inputs.size(1), :].clone().detach()
|
27 |
+
|
28 |
+
|
29 |
+
class MultiHeadAttention(nn.Module):
|
30 |
+
def __init__(self, num_heads, embed_dim, dropout=0.1, if_resi=True):
|
31 |
+
super(MultiHeadAttention, self).__init__()
|
32 |
+
self.layer_norm = nn.LayerNorm(embed_dim)
|
33 |
+
self.MultiheadAttention = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)
|
34 |
+
self.Q_proj = nn.Sequential(nn.Linear(embed_dim, embed_dim), nn.ReLU())
|
35 |
+
self.K_proj = nn.Sequential(nn.Linear(embed_dim, embed_dim), nn.ReLU())
|
36 |
+
self.V_proj = nn.Sequential(nn.Linear(embed_dim, embed_dim), nn.ReLU())
|
37 |
+
self.if_resi = if_resi
|
38 |
+
|
39 |
+
def forward(self, inputs):
|
40 |
+
query = self.layer_norm(inputs)
|
41 |
+
q = self.Q_proj(query)
|
42 |
+
k = self.K_proj(query)
|
43 |
+
v = self.V_proj(query)
|
44 |
+
attn_output, attn_output_weights = self.MultiheadAttention(q, k, v)
|
45 |
+
if self.if_resi:
|
46 |
+
attn_output += inputs
|
47 |
+
else:
|
48 |
+
attn_output = attn_output
|
49 |
+
|
50 |
+
return attn_output
|
51 |
+
|
52 |
+
|
53 |
+
class FeedForward(nn.Module):
|
54 |
+
def __init__(self, in_channel, FFN_channel, if_resi=True):
|
55 |
+
super(FeedForward, self).__init__()
|
56 |
+
"""
|
57 |
+
1024 2048
|
58 |
+
"""
|
59 |
+
output_channel = (FFN_channel, in_channel)
|
60 |
+
self.fc1 = nn.Sequential(nn.Linear(in_channel, output_channel[0]), nn.ReLU())
|
61 |
+
self.fc2 = nn.Linear(output_channel[0], output_channel[1])
|
62 |
+
self.layer_norm = nn.LayerNorm(in_channel)
|
63 |
+
self.if_resi = if_resi
|
64 |
+
|
65 |
+
def forward(self, inputs):
|
66 |
+
outputs = self.layer_norm(inputs)
|
67 |
+
outputs = self.fc1(outputs)
|
68 |
+
outputs = self.fc2(outputs)
|
69 |
+
if self.if_resi:
|
70 |
+
outputs += inputs
|
71 |
+
else:
|
72 |
+
outputs = outputs
|
73 |
+
return outputs
|
74 |
+
|
75 |
+
|
76 |
+
class TransformerLayer(nn.Module):
|
77 |
+
def __init__(self, out_dim, in_dim, num_heads, attention_size,
|
78 |
+
dim_feedforward=1024, drop_rate=0.1, if_resi=True, block_nums=3):
|
79 |
+
super(TransformerLayer, self).__init__()
|
80 |
+
self.block_nums = block_nums
|
81 |
+
self.if_resi = if_resi
|
82 |
+
self.linear = nn.Linear(in_dim, attention_size)
|
83 |
+
for i in range(self.block_nums):
|
84 |
+
self.__setattr__('MHA_self_%d' % i, MultiHeadAttention(num_heads, attention_size,
|
85 |
+
dropout=drop_rate, if_resi=if_resi))
|
86 |
+
self.__setattr__('FFN_%d' % i, FeedForward(out_dim, dim_feedforward, if_resi=if_resi))
|
87 |
+
|
88 |
+
def forward(self, query):
|
89 |
+
inputs = self.linear(query)
|
90 |
+
# outputs = inputs
|
91 |
+
for i in range(self.block_nums):
|
92 |
+
outputs = self.__getattr__('MHA_self_%d' % i)(inputs)
|
93 |
+
outputs = self.__getattr__('FFN_%d' % i)(outputs)
|
94 |
+
if self.if_resi:
|
95 |
+
inputs = inputs+outputs
|
96 |
+
else:
|
97 |
+
inputs = outputs
|
98 |
+
# outputs = inputs
|
99 |
+
return inputs
|
100 |
+
|
101 |
+
|
102 |
+
class Transformer(nn.Module):
|
103 |
+
|
104 |
+
def __init__(self, in_dim, out_dim, num_heads=8,
|
105 |
+
dim_feedforward=1024, drop_rate=0.1, if_resi=False, block_nums=3):
|
106 |
+
super().__init__()
|
107 |
+
|
108 |
+
self.bn0 = nn.BatchNorm1d(in_dim, affine=False)
|
109 |
+
self.conv1 = nn.Conv1d(in_dim, out_dim, 1, dilation=1)
|
110 |
+
|
111 |
+
# self.pos_embedding = Positional_encoding(in_dim)
|
112 |
+
self.transformer = TransformerLayer(out_dim, in_dim, num_heads, attention_size=out_dim,
|
113 |
+
dim_feedforward=dim_feedforward, drop_rate=drop_rate,
|
114 |
+
if_resi=if_resi, block_nums=block_nums)
|
115 |
+
|
116 |
+
self.prediction = nn.Sequential(
|
117 |
+
nn.Conv1d(2*out_dim, 128, 1),
|
118 |
+
nn.ReLU(inplace=True),
|
119 |
+
nn.Dropout(0.1),
|
120 |
+
nn.Conv1d(128, 64, 1),
|
121 |
+
nn.ReLU(inplace=True),
|
122 |
+
# nn.Dropout(0.1),
|
123 |
+
nn.Conv1d(64, 2, 1))
|
124 |
+
|
125 |
+
def forward(self, x, adj):
|
126 |
+
x = self.bn0(x)
|
127 |
+
|
128 |
+
x1 = x.permute(0, 2, 1)
|
129 |
+
# x1 = self.pos_embedding(x1)
|
130 |
+
x1 = self.transformer(x1)
|
131 |
+
x1 = x1.permute(0, 2, 1)
|
132 |
+
|
133 |
+
x = torch.cat([x1, self.conv1(x)], dim=1)
|
134 |
+
# x = x1+self.conv1(x)
|
135 |
+
pred = self.prediction(x)
|
136 |
+
|
137 |
+
return pred
|
138 |
+
|
139 |
+
|
140 |
+
|
IndicPhotoOCR/detection/textbpn/network/layers/Transformer_old.py
ADDED
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
###################################################################
|
2 |
+
# File Name: GCN.py
|
3 |
+
# Author: S.X.Zhang
|
4 |
+
###################################################################
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from torch import nn, Tensor
|
9 |
+
from torch.autograd import Variable
|
10 |
+
import numpy as np
|
11 |
+
from cfglib.config import config as cfg
|
12 |
+
|
13 |
+
|
14 |
+
class Positional_encoding(nn.Module):
|
15 |
+
def __init__(self, PE_size, n_position=200):
|
16 |
+
super(Positional_encoding, self).__init__()
|
17 |
+
self.PE_size = PE_size
|
18 |
+
self.n_position = n_position
|
19 |
+
self.register_buffer('pos_table', self.get_encoding_table(n_position, PE_size))
|
20 |
+
|
21 |
+
def get_encoding_table(self, n_position, PE_size):
|
22 |
+
position_table = np.array(
|
23 |
+
[[pos / np.power(10000, 2. * i / self.PE_size) for i in range(self.PE_size)] for pos in range(n_position)])
|
24 |
+
position_table[:, 0::2] = np.sin(position_table[:, 0::2])
|
25 |
+
position_table[:, 1::2] = np.cos(position_table[:, 1::2])
|
26 |
+
return torch.FloatTensor(position_table).unsqueeze(0)
|
27 |
+
|
28 |
+
def forward(self, inputs):
|
29 |
+
return inputs + self.pos_table[:, :inputs.size(1), :].clone().detach()
|
30 |
+
|
31 |
+
|
32 |
+
class MultiHeadAttention(nn.Module):
|
33 |
+
def __init__(self, num_heads, embedding_size, attention_size,
|
34 |
+
drop_rate, future_blind=True, query_mask=False, if_resi=True):
|
35 |
+
super(MultiHeadAttention, self).__init__()
|
36 |
+
self.num_heads = num_heads
|
37 |
+
self.embedding_size = embedding_size
|
38 |
+
self.attention_size = attention_size
|
39 |
+
self.drop_rate = drop_rate
|
40 |
+
self.future_blind = future_blind
|
41 |
+
|
42 |
+
self.Q_proj = nn.Sequential(nn.Linear(self.embedding_size, self.attention_size), nn.ReLU())
|
43 |
+
self.K_proj = nn.Sequential(nn.Linear(self.embedding_size, self.attention_size), nn.ReLU())
|
44 |
+
self.V_proj = nn.Sequential(nn.Linear(self.embedding_size, self.attention_size), nn.ReLU())
|
45 |
+
|
46 |
+
self.drop_out = nn.Dropout(p=self.drop_rate)
|
47 |
+
self.layer_norm = nn.LayerNorm(self.attention_size)
|
48 |
+
self.if_resi = if_resi
|
49 |
+
|
50 |
+
def forward(self, query, key, value):
|
51 |
+
q = self.Q_proj(query)
|
52 |
+
k = self.K_proj(key)
|
53 |
+
v = self.V_proj(value)
|
54 |
+
|
55 |
+
q_ = torch.cat(torch.chunk(q, self.num_heads, dim=2), dim=0)
|
56 |
+
k_ = torch.cat(torch.chunk(k, self.num_heads, dim=2), dim=0)
|
57 |
+
v_ = torch.cat(torch.chunk(v, self.num_heads, dim=2), dim=0)
|
58 |
+
|
59 |
+
outputs = torch.bmm(q_, k_.permute(0, 2, 1))
|
60 |
+
outputs = outputs / (k_.size()[-1] ** 0.5)
|
61 |
+
|
62 |
+
# key mask
|
63 |
+
|
64 |
+
# future mask
|
65 |
+
if self.future_blind:
|
66 |
+
diag_vals = torch.ones_like(outputs[0, :, :]).to(cfg.device)
|
67 |
+
tril = torch.tril(diag_vals, diagonal=0)
|
68 |
+
masks = Variable(torch.unsqueeze(tril, 0).repeat(outputs.size()[0], 1, 1)) # (h*N,T_q,T_k)
|
69 |
+
padding = Variable(torch.ones_like(masks).to(cfg.device) * (-2 ** 32 + 1))
|
70 |
+
condition = masks.eq(0)
|
71 |
+
outputs = torch.where(condition, padding, outputs)
|
72 |
+
|
73 |
+
outputs = F.softmax(outputs, dim=-1)
|
74 |
+
# if self.future_blind==True:a
|
75 |
+
# print(outputs[0])
|
76 |
+
outputs = self.drop_out(outputs)
|
77 |
+
|
78 |
+
outputs = torch.bmm(outputs, v_)
|
79 |
+
outputs = torch.cat(torch.chunk(outputs, self.num_heads, dim=0), dim=2) # N,T_q,C
|
80 |
+
|
81 |
+
if self.if_resi:
|
82 |
+
# outputs += query
|
83 |
+
outputs += q
|
84 |
+
else:
|
85 |
+
outputs = outputs
|
86 |
+
outputs = self.layer_norm(outputs)
|
87 |
+
|
88 |
+
return outputs
|
89 |
+
|
90 |
+
|
91 |
+
class FeedForward(nn.Module):
|
92 |
+
def __init__(self, in_channel, FFN_channel, if_resi=True):
|
93 |
+
super(FeedForward, self).__init__()
|
94 |
+
"""
|
95 |
+
1024 2048
|
96 |
+
"""
|
97 |
+
output_channel = (FFN_channel, in_channel)
|
98 |
+
self.fc1 = nn.Sequential(nn.Linear(in_channel, output_channel[0]), nn.ReLU())
|
99 |
+
self.fc2 = nn.Linear(output_channel[0], output_channel[1])
|
100 |
+
self.layer_norm = nn.LayerNorm(in_channel)
|
101 |
+
self.if_resi = if_resi
|
102 |
+
|
103 |
+
def forward(self, inputs):
|
104 |
+
outputs = self.fc1(inputs)
|
105 |
+
outputs = self.fc2(outputs)
|
106 |
+
if self.if_resi:
|
107 |
+
outputs += inputs
|
108 |
+
else:
|
109 |
+
outputs = outputs
|
110 |
+
outputs = self.layer_norm(outputs)
|
111 |
+
return outputs
|
112 |
+
|
113 |
+
|
114 |
+
class TransformerLayer(nn.Module):
|
115 |
+
def __init__(self, out_dim, num_heads, embedding_size, attention_size,
|
116 |
+
dim_feedforward=1024, drop_rate=0.1, if_resi=True, block_nums=3):
|
117 |
+
super(TransformerLayer, self).__init__()
|
118 |
+
self.block_nums = block_nums
|
119 |
+
self.if_resi = if_resi
|
120 |
+
for i in range(self.block_nums):
|
121 |
+
self.__setattr__('MHA_self_%d' % i, MultiHeadAttention(num_heads, embedding_size, attention_size,
|
122 |
+
drop_rate, future_blind=False, if_resi=if_resi))
|
123 |
+
self.__setattr__('FFN_%d' % i, FeedForward(out_dim, dim_feedforward, if_resi=if_resi))
|
124 |
+
|
125 |
+
def forward(self, query):
|
126 |
+
outputs = None
|
127 |
+
for i in range(self.block_nums):
|
128 |
+
outputs = self.__getattr__('MHA_self_%d' % i)(query, query, query)
|
129 |
+
outputs = self.__getattr__('FFN_%d' % i)(outputs)
|
130 |
+
return outputs
|
131 |
+
|
132 |
+
|
133 |
+
class Transformer(nn.Module):
|
134 |
+
|
135 |
+
def __init__(self, in_dim, out_dim, num_heads=8,
|
136 |
+
dim_feedforward=1024, drop_rate=0.1, if_resi=False, block_nums=3):
|
137 |
+
super().__init__()
|
138 |
+
|
139 |
+
self.bn0 = nn.BatchNorm1d(in_dim, affine=False)
|
140 |
+
self.conv1 = nn.Conv1d(in_dim, out_dim, 1, dilation=1)
|
141 |
+
|
142 |
+
embed_dim = in_dim
|
143 |
+
# self.pos_embedding = Positional_encoding(embed_dim)
|
144 |
+
self.transformer = TransformerLayer(out_dim, num_heads, embedding_size=embed_dim,
|
145 |
+
attention_size=out_dim, dim_feedforward=dim_feedforward,
|
146 |
+
drop_rate=drop_rate, if_resi=if_resi, block_nums=block_nums)
|
147 |
+
|
148 |
+
self.prediction = nn.Sequential(
|
149 |
+
nn.Conv1d(out_dim*2, 128, 1),
|
150 |
+
nn.ReLU(inplace=True),
|
151 |
+
nn.Dropout(0.1),
|
152 |
+
nn.Conv1d(128, 64, 1),
|
153 |
+
nn.ReLU(inplace=True),
|
154 |
+
# nn.Dropout(0.1),
|
155 |
+
nn.Conv1d(64, 2, 1))
|
156 |
+
|
157 |
+
def forward(self, x, adj):
|
158 |
+
x = self.bn0(x)
|
159 |
+
|
160 |
+
x1 = x.permute(0, 2, 1)
|
161 |
+
x1 = self.transformer(x1)
|
162 |
+
x1 = x1.permute(0, 2, 1)
|
163 |
+
|
164 |
+
x = torch.cat([x1, self.conv1(x)], dim=1)
|
165 |
+
# x = x1+self.conv1(x)
|
166 |
+
pred = self.prediction(x)
|
167 |
+
|
168 |
+
return pred
|
169 |
+
|
170 |
+
|
171 |
+
|
IndicPhotoOCR/detection/textbpn/network/layers/__init__.py
ADDED
File without changes
|
IndicPhotoOCR/detection/textbpn/network/layers/gcn_utils.py
ADDED
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
__author__ = "S.X.Zhang"
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
import cv2
|
6 |
+
import torch.nn as nn
|
7 |
+
from torch.autograd import Variable
|
8 |
+
|
9 |
+
|
10 |
+
def normalize_adj(A, type="AD"):
|
11 |
+
if type == "DAD":
|
12 |
+
A = A + np.eye(A.shape[0]) # A=A+I
|
13 |
+
d = np.sum(A, axis=0)
|
14 |
+
d_inv = np.power(d, -0.5).flatten()
|
15 |
+
d_inv[np.isinf(d_inv)] = 0.0
|
16 |
+
d_inv = np.diag(d_inv)
|
17 |
+
G = A.dot(d_inv).transpose().dot(d_inv) # L = D^-1/2 A D^-1/2
|
18 |
+
G = torch.from_numpy(G)
|
19 |
+
elif type == "AD":
|
20 |
+
A = A + np.eye(A.shape[0]) # A=A+I
|
21 |
+
A = torch.from_numpy(A)
|
22 |
+
D = A.sum(1, keepdim=True)
|
23 |
+
G = A.div(D) # L= A/D
|
24 |
+
else:
|
25 |
+
A = A + np.eye(A.shape[0]) # A=A+I
|
26 |
+
D = A.sum(1, keepdim=True)
|
27 |
+
D = np.diag(D)
|
28 |
+
G = torch.from_numpy(D - A) # L = D-A
|
29 |
+
return G
|
30 |
+
|
31 |
+
|
32 |
+
def np_to_variable(x, is_cuda=True, dtype=torch.FloatTensor):
|
33 |
+
v = Variable(torch.from_numpy(x).type(dtype))
|
34 |
+
if is_cuda:
|
35 |
+
v = v.cuda()
|
36 |
+
return v
|
37 |
+
|
38 |
+
|
39 |
+
def set_trainable(model, requires_grad):
|
40 |
+
for param in model.parameters():
|
41 |
+
param.requires_grad = requires_grad
|
42 |
+
|
43 |
+
|
44 |
+
def weights_normal_init(model, dev=0.01):
|
45 |
+
if isinstance(model, list):
|
46 |
+
for m in model:
|
47 |
+
weights_normal_init(m, dev)
|
48 |
+
else:
|
49 |
+
for m in model.modules():
|
50 |
+
if isinstance(m, nn.Conv2d):
|
51 |
+
m.weight.data.normal_(0.0, dev)
|
52 |
+
elif isinstance(m, nn.Linear):
|
53 |
+
m.weight.data.normal_(0.0, dev)
|
54 |
+
|
55 |
+
|
56 |
+
def clip_gradient(model, clip_norm):
|
57 |
+
"""Computes a gradient clipping coefficient based on gradient norm."""
|
58 |
+
totalnorm = 0
|
59 |
+
for p in model.parameters():
|
60 |
+
if p.requires_grad:
|
61 |
+
modulenorm = p.grad.data.norm()
|
62 |
+
totalnorm += modulenorm ** 2
|
63 |
+
totalnorm = np.sqrt(totalnorm)
|
64 |
+
|
65 |
+
norm = clip_norm / max(totalnorm, clip_norm)
|
66 |
+
for p in model.parameters():
|
67 |
+
if p.requires_grad:
|
68 |
+
p.grad.mul_(norm)
|
69 |
+
|
70 |
+
|
71 |
+
def EuclideanDistances(A, B):
|
72 |
+
BT = B.transpose()
|
73 |
+
vecProd = np.dot(A,BT)
|
74 |
+
SqA = A**2
|
75 |
+
sumSqA = np.matrix(np.sum(SqA, axis=1))
|
76 |
+
sumSqAEx = np.tile(sumSqA.transpose(), (1, vecProd.shape[1]))
|
77 |
+
|
78 |
+
SqB = B**2
|
79 |
+
sumSqB = np.sum(SqB, axis=1)
|
80 |
+
sumSqBEx = np.tile(sumSqB, (vecProd.shape[0], 1))
|
81 |
+
SqED = sumSqBEx + sumSqAEx - 2*vecProd
|
82 |
+
SqED[SqED<0]=0.0
|
83 |
+
ED = np.sqrt(SqED)
|
84 |
+
return ED
|
85 |
+
|
86 |
+
|
87 |
+
def get_center_feature(cnn_feature, img_poly, ind, h, w):
|
88 |
+
batch_size = cnn_feature.size(0)
|
89 |
+
for i in range(batch_size):
|
90 |
+
poly = img_poly[ind == i].cpu().numpy()
|
91 |
+
mask = np.zeros((h, w), dtype=np.uint8)
|
92 |
+
cv2.fillPoly(mask, poly.astype(np.int32), color=(1,))
|
93 |
+
return None
|
94 |
+
|
95 |
+
|
96 |
+
def get_node_feature(cnn_feature, img_poly, ind, h, w):
|
97 |
+
img_poly = img_poly.clone().float()
|
98 |
+
img_poly[..., 0] = img_poly[..., 0] / (w / 2.) - 1
|
99 |
+
img_poly[..., 1] = img_poly[..., 1] / (h / 2.) - 1
|
100 |
+
|
101 |
+
batch_size = cnn_feature.size(0)
|
102 |
+
gcn_feature = torch.zeros([img_poly.size(0), cnn_feature.size(1), img_poly.size(1)]).to(img_poly.device)
|
103 |
+
for i in range(batch_size):
|
104 |
+
poly = img_poly[ind == i].unsqueeze(0)
|
105 |
+
gcn_feature[ind == i] = torch.nn.functional.grid_sample(cnn_feature[i:i + 1], poly)[0].permute(1, 0, 2)
|
106 |
+
return gcn_feature
|
107 |
+
|
108 |
+
|
109 |
+
def get_adj_mat(n_adj, n_nodes):
|
110 |
+
a = np.zeros([n_nodes, n_nodes], dtype=np.float)
|
111 |
+
|
112 |
+
for i in range(n_nodes):
|
113 |
+
for j in range(-n_adj // 2, n_adj // 2 + 1):
|
114 |
+
if j != 0:
|
115 |
+
a[i][(i + j) % n_nodes] = 1
|
116 |
+
a[(i + j) % n_nodes][i] = 1
|
117 |
+
return a
|
118 |
+
|
119 |
+
|
120 |
+
def get_adj_ind(n_adj, n_nodes, device):
|
121 |
+
ind = torch.tensor([i for i in range(-n_adj // 2, n_adj // 2 + 1) if i != 0]).long()
|
122 |
+
ind = (torch.arange(n_nodes)[:, None] + ind[None]) % n_nodes
|
123 |
+
return ind.to(device)
|
124 |
+
|
125 |
+
|
126 |
+
def coord_embedding(b, w, h, device):
|
127 |
+
x_range = torch.linspace(0, 1, w, device=device)
|
128 |
+
y_range = torch.linspace(0, 1, h, device=device)
|
129 |
+
y, x = torch.meshgrid(y_range, x_range)
|
130 |
+
y = y.expand([b, 1, -1, -1])
|
131 |
+
x = x.expand([b, 1, -1, -1])
|
132 |
+
coord_map = torch.cat([x, y], 1)
|
133 |
+
|
134 |
+
return coord_map
|
135 |
+
|
136 |
+
|
137 |
+
def img_poly_to_can_poly(img_poly):
|
138 |
+
if len(img_poly) == 0:
|
139 |
+
return torch.zeros_like(img_poly)
|
140 |
+
x_min = torch.min(img_poly[..., 0], dim=-1)[0]
|
141 |
+
y_min = torch.min(img_poly[..., 1], dim=-1)[0]
|
142 |
+
can_poly = img_poly.clone()
|
143 |
+
can_poly[..., 0] = can_poly[..., 0] - x_min[..., None]
|
144 |
+
can_poly[..., 1] = can_poly[..., 1] - y_min[..., None]
|
145 |
+
# x_max = torch.max(img_poly[..., 0], dim=-1)[0]
|
146 |
+
# y_max = torch.max(img_poly[..., 1], dim=-1)[0]
|
147 |
+
# h, w = y_max - y_min + 1, x_max - x_min + 1
|
148 |
+
# long_side = torch.max(h, w)
|
149 |
+
# can_poly = can_poly / long_side[..., None, None]
|
150 |
+
return can_poly
|
IndicPhotoOCR/detection/textbpn/network/layers/model_block.py
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
__author__ = "S.X.Zhang"
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from IndicPhotoOCR.detection.textbpn.network.layers.vgg import VggNet
|
7 |
+
from IndicPhotoOCR.detection.textbpn.network.layers.resnet import ResNet
|
8 |
+
from IndicPhotoOCR.detection.textbpn.network.layers.resnet_dcn import ResNet_DCN
|
9 |
+
from IndicPhotoOCR.detection.textbpn.cfglib.config import config as cfg
|
10 |
+
|
11 |
+
|
12 |
+
class UpBlok(nn.Module):
|
13 |
+
|
14 |
+
def __init__(self, in_channels, out_channels):
|
15 |
+
super().__init__()
|
16 |
+
self.conv1x1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
17 |
+
self.conv3x3 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
18 |
+
self.deconv = nn.ConvTranspose2d(out_channels, out_channels, kernel_size=4, stride=2, padding=1)
|
19 |
+
|
20 |
+
def forward(self, upsampled, shortcut):
|
21 |
+
x = torch.cat([upsampled, shortcut], dim=1)
|
22 |
+
x = self.conv1x1(x)
|
23 |
+
x = F.relu(x)
|
24 |
+
x = self.conv3x3(x)
|
25 |
+
x = F.relu(x)
|
26 |
+
x = self.deconv(x)
|
27 |
+
return x
|
28 |
+
|
29 |
+
|
30 |
+
class MergeBlok(nn.Module):
|
31 |
+
def __init__(self, in_channels, out_channels):
|
32 |
+
super().__init__()
|
33 |
+
self.conv1x1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
34 |
+
self.conv3x3 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
35 |
+
|
36 |
+
def forward(self, upsampled, shortcut):
|
37 |
+
x = torch.cat([upsampled, shortcut], dim=1)
|
38 |
+
x = self.conv1x1(x)
|
39 |
+
x = F.relu(x)
|
40 |
+
x = self.conv3x3(x)
|
41 |
+
return x
|
42 |
+
|
43 |
+
|
44 |
+
class FPN(nn.Module):
|
45 |
+
|
46 |
+
def __init__(self, backbone='resnet50', is_training=True):
|
47 |
+
super().__init__()
|
48 |
+
self.is_training = is_training
|
49 |
+
self.backbone_name = backbone
|
50 |
+
|
51 |
+
if backbone in ['vgg_bn', 'vgg']:
|
52 |
+
self.backbone = VggNet(name=backbone, pretrain=is_training)
|
53 |
+
self.deconv5 = nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1)
|
54 |
+
self.merge4 = UpBlok(512 + 256, 128)
|
55 |
+
self.merge3 = UpBlok(256 + 128, 64)
|
56 |
+
if cfg.scale == 1:
|
57 |
+
self.merge2 = UpBlok(128 + 64, 32) # FPN 1/2
|
58 |
+
self.merge1 = UpBlok(64 + 32, 32) # FPN 1/1
|
59 |
+
elif cfg.scale == 2:
|
60 |
+
self.merge2 = UpBlok(128 + 64, 32) # FPN 1/2
|
61 |
+
self.merge1 = MergeBlok(64 + 32, 32) # FPN 1/2
|
62 |
+
elif cfg.scale == 4:
|
63 |
+
self.merge2 = MergeBlok(128 + 64, 32) # FPN 1/4
|
64 |
+
|
65 |
+
elif backbone in ['resnet50']:
|
66 |
+
self.backbone = ResNet(name=backbone, pretrain=is_training)
|
67 |
+
self.deconv5 = nn.ConvTranspose2d(2048, 256, kernel_size=4, stride=2, padding=1)
|
68 |
+
self.merge4 = UpBlok(1024 + 256, 128)
|
69 |
+
self.merge3 = UpBlok(512 + 128, 64)
|
70 |
+
if cfg.scale == 1:
|
71 |
+
self.merge2 = UpBlok(256 + 64, 32) # FPN 1/2
|
72 |
+
self.merge1 = UpBlok(64 + 32, 32) # FPN 1/1
|
73 |
+
elif cfg.scale == 2:
|
74 |
+
self.merge2 = UpBlok(256 + 64, 32) # FPN 1/2
|
75 |
+
self.merge1 = MergeBlok(64 + 32, 32) # FPN 1/2
|
76 |
+
elif cfg.scale == 4:
|
77 |
+
self.merge2 = MergeBlok(256 + 64, 32) # FPN 1/4
|
78 |
+
self.merge1 = MergeBlok(64 + 32, 32) # FPN 1/4
|
79 |
+
|
80 |
+
elif backbone in ['resnet18']:
|
81 |
+
self.backbone = ResNet(name=backbone, pretrain=is_training)
|
82 |
+
self.deconv5 = nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1)
|
83 |
+
self.merge4 = UpBlok(256 + 256, 128)
|
84 |
+
self.merge3 = UpBlok(128 + 128, 64)
|
85 |
+
if cfg.scale == 1:
|
86 |
+
self.merge2 = UpBlok(64 + 64, 32) # FPN 1/2
|
87 |
+
self.merge1 = UpBlok(64 + 32, 32) # FPN 1/1
|
88 |
+
elif cfg.scale == 2:
|
89 |
+
self.merge2 = UpBlok(64 + 64, 32) # FPN 1/2
|
90 |
+
self.merge1 = MergeBlok(64 + 32, 32) # FPN 1/2
|
91 |
+
elif cfg.scale == 4:
|
92 |
+
self.merge2 = MergeBlok(64 + 64, 32) # FPN 1/4
|
93 |
+
self.merge1 = MergeBlok(64 + 32, 32) # FPN 1/4
|
94 |
+
|
95 |
+
elif backbone in ["deformable_resnet18"]:
|
96 |
+
self.backbone = ResNet_DCN(name=backbone, pretrain=is_training)
|
97 |
+
self.deconv5 = nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1)
|
98 |
+
self.merge4 = UpBlok(256 + 256, 128)
|
99 |
+
self.merge3 = UpBlok(128 + 128, 64)
|
100 |
+
if cfg.scale == 1:
|
101 |
+
self.merge2 = UpBlok(64 + 64, 32) # FPN 1/2
|
102 |
+
self.merge1 = UpBlok(64 + 32, 32) # FPN 1/1
|
103 |
+
elif cfg.scale == 2:
|
104 |
+
self.merge2 = UpBlok(64 + 64, 32) # FPN 1/2
|
105 |
+
self.merge1 = MergeBlok(64 + 32, 32) # FPN 1/2
|
106 |
+
elif cfg.scale == 4:
|
107 |
+
self.merge2 = MergeBlok(64 + 64, 32) # FPN 1/4
|
108 |
+
self.merge1 = MergeBlok(64 + 32, 32) # FPN 1/4
|
109 |
+
|
110 |
+
elif backbone in ["deformable_resnet50"]:
|
111 |
+
self.backbone = ResNet_DCN(name=backbone, pretrain=is_training)
|
112 |
+
self.deconv5 = nn.ConvTranspose2d(2048, 256, kernel_size=4, stride=2, padding=1)
|
113 |
+
self.merge4 = UpBlok(1024 + 256, 128)
|
114 |
+
self.merge3 = UpBlok(512 + 128, 64)
|
115 |
+
if cfg.scale == 1:
|
116 |
+
self.merge2 = UpBlok(256 + 64, 32) # FPN 1/2
|
117 |
+
self.merge1 = UpBlok(64 + 32, 32) # FPN 1/1
|
118 |
+
elif cfg.scale == 2:
|
119 |
+
self.merge2 = UpBlok(256 + 64, 32) # FPN 1/2
|
120 |
+
self.merge1 = MergeBlok(64 + 32, 32) # FPN 1/2
|
121 |
+
elif cfg.scale == 4:
|
122 |
+
self.merge2 = MergeBlok(256 + 64, 32) # FPN 1/4
|
123 |
+
self.merge1 = MergeBlok(64 + 32, 32) # FPN 1/4
|
124 |
+
else:
|
125 |
+
print("backbone is not support !")
|
126 |
+
|
127 |
+
def forward(self, x):
|
128 |
+
C1, C2, C3, C4, C5 = self.backbone(x)
|
129 |
+
#print(C5.size())
|
130 |
+
#print(C4.size())
|
131 |
+
#print(C3.size())
|
132 |
+
#print(C2.size())
|
133 |
+
#print(C1.size())
|
134 |
+
up5 = self.deconv5(C5)
|
135 |
+
up5 = F.relu(up5)
|
136 |
+
|
137 |
+
up4 = self.merge4(C4, up5)
|
138 |
+
up4 = F.relu(up4)
|
139 |
+
|
140 |
+
up3 = self.merge3(C3, up4)
|
141 |
+
up3 = F.relu(up3)
|
142 |
+
|
143 |
+
up2 = self.merge2(C2, up3)
|
144 |
+
up2 = F.relu(up2)
|
145 |
+
|
146 |
+
up1 = self.merge1(C1, up2)
|
147 |
+
up1 = F.relu(up1)
|
148 |
+
|
149 |
+
return up1, up2, up3, up4, up5
|
IndicPhotoOCR/detection/textbpn/network/layers/position_encoding.py
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
2 |
+
"""
|
3 |
+
Various positional encodings for the transformer.
|
4 |
+
"""
|
5 |
+
import math
|
6 |
+
import torch
|
7 |
+
from torch import nn
|
8 |
+
|
9 |
+
from util.misc import NestedTensor
|
10 |
+
|
11 |
+
|
12 |
+
class PositionEmbeddingSine(nn.Module):
|
13 |
+
"""
|
14 |
+
This is a more standard version of the position embedding, very similar to the one
|
15 |
+
used by the Attention is all you need paper, generalized to work on images.
|
16 |
+
"""
|
17 |
+
def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
|
18 |
+
super().__init__()
|
19 |
+
self.num_pos_feats = num_pos_feats
|
20 |
+
self.temperature = temperature
|
21 |
+
self.normalize = normalize
|
22 |
+
if scale is not None and normalize is False:
|
23 |
+
raise ValueError("normalize should be True if scale is passed")
|
24 |
+
if scale is None:
|
25 |
+
scale = 2 * math.pi
|
26 |
+
self.scale = scale
|
27 |
+
|
28 |
+
def forward(self, tensor_list: NestedTensor):
|
29 |
+
x = tensor_list.tensors
|
30 |
+
mask = tensor_list.mask
|
31 |
+
assert mask is not None
|
32 |
+
not_mask = ~mask
|
33 |
+
y_embed = not_mask.cumsum(1, dtype=torch.float32)
|
34 |
+
x_embed = not_mask.cumsum(2, dtype=torch.float32)
|
35 |
+
if self.normalize:
|
36 |
+
eps = 1e-6
|
37 |
+
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
|
38 |
+
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
|
39 |
+
|
40 |
+
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
|
41 |
+
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
|
42 |
+
|
43 |
+
pos_x = x_embed[:, :, :, None] / dim_t
|
44 |
+
pos_y = y_embed[:, :, :, None] / dim_t
|
45 |
+
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
|
46 |
+
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
|
47 |
+
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
|
48 |
+
return pos
|
49 |
+
|
50 |
+
|
51 |
+
class PositionEmbeddingLearned(nn.Module):
|
52 |
+
"""
|
53 |
+
Absolute pos embedding, learned.
|
54 |
+
"""
|
55 |
+
def __init__(self, num_pos_feats=256):
|
56 |
+
super().__init__()
|
57 |
+
self.row_embed = nn.Embedding(50, num_pos_feats)
|
58 |
+
self.col_embed = nn.Embedding(50, num_pos_feats)
|
59 |
+
self.reset_parameters()
|
60 |
+
|
61 |
+
def reset_parameters(self):
|
62 |
+
nn.init.uniform_(self.row_embed.weight)
|
63 |
+
nn.init.uniform_(self.col_embed.weight)
|
64 |
+
|
65 |
+
def forward(self, tensor_list: NestedTensor):
|
66 |
+
x = tensor_list.tensors
|
67 |
+
h, w = x.shape[-2:]
|
68 |
+
i = torch.arange(w, device=x.device)
|
69 |
+
j = torch.arange(h, device=x.device)
|
70 |
+
x_emb = self.col_embed(i)
|
71 |
+
y_emb = self.row_embed(j)
|
72 |
+
pos = torch.cat([
|
73 |
+
x_emb.unsqueeze(0).repeat(h, 1, 1),
|
74 |
+
y_emb.unsqueeze(1).repeat(1, w, 1),
|
75 |
+
], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
|
76 |
+
return pos
|
77 |
+
|
78 |
+
|
79 |
+
def build_position_encoding(args):
|
80 |
+
N_steps = args.hidden_dim // 2
|
81 |
+
if args.position_embedding in ('v2', 'sine'):
|
82 |
+
# TODO find a better way of exposing other arguments
|
83 |
+
position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
|
84 |
+
elif args.position_embedding in ('v3', 'learned'):
|
85 |
+
position_embedding = PositionEmbeddingLearned(N_steps)
|
86 |
+
else:
|
87 |
+
raise ValueError(f"not supported {args.position_embedding}")
|
88 |
+
|
89 |
+
return position_embedding
|
IndicPhotoOCR/detection/textbpn/network/layers/resnet.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from torchvision.models import resnet
|
4 |
+
import torch.utils.model_zoo as model_zoo
|
5 |
+
from IndicPhotoOCR.detection.textbpn.cfglib.config import config as cfg
|
6 |
+
|
7 |
+
model_urls = {
|
8 |
+
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
|
9 |
+
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
|
10 |
+
'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
|
11 |
+
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
|
12 |
+
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
|
13 |
+
|
14 |
+
}
|
15 |
+
|
16 |
+
|
17 |
+
class ResNet(nn.Module):
|
18 |
+
def __init__(self, name="resnet50", pretrain=True):
|
19 |
+
super().__init__()
|
20 |
+
|
21 |
+
if name == "resnet50":
|
22 |
+
base_net = resnet.resnet50(pretrained=False)
|
23 |
+
elif name == "resnet101":
|
24 |
+
base_net = resnet.resnet101(pretrained=False)
|
25 |
+
elif name == "resnet18":
|
26 |
+
base_net = resnet.resnet18(pretrained=False)
|
27 |
+
elif name == "resnet34":
|
28 |
+
base_net = resnet.resnet34(pretrained=False)
|
29 |
+
|
30 |
+
else:
|
31 |
+
print(" base model is not support !")
|
32 |
+
|
33 |
+
if pretrain:
|
34 |
+
print("load the {} weight from ./cache".format(name))
|
35 |
+
base_net.load_state_dict(model_zoo.load_url(model_urls[name], model_dir="./cache",
|
36 |
+
map_location=torch.device(cfg.device)), strict=False)
|
37 |
+
# print(base_net)
|
38 |
+
self.stage1 = nn.Sequential(
|
39 |
+
base_net.conv1,
|
40 |
+
base_net.bn1,
|
41 |
+
base_net.relu,
|
42 |
+
base_net.maxpool
|
43 |
+
)
|
44 |
+
self.stage2 = base_net.layer1
|
45 |
+
self.stage3 = base_net.layer2
|
46 |
+
self.stage4 = base_net.layer3
|
47 |
+
self.stage5 = base_net.layer4
|
48 |
+
self.up2 = nn.ConvTranspose2d(64, 64, kernel_size=4, stride=2, padding=1)
|
49 |
+
|
50 |
+
def forward(self, x):
|
51 |
+
C1 = self.stage1(x)
|
52 |
+
C2 = self.stage2(C1)
|
53 |
+
C3 = self.stage3(C2)
|
54 |
+
C4 = self.stage4(C3)
|
55 |
+
C5 = self.stage5(C4)
|
56 |
+
|
57 |
+
if cfg.scale == 2 or cfg.scale == 1:
|
58 |
+
# up2 --> 1/2
|
59 |
+
C1 = self.up2(C1)
|
60 |
+
|
61 |
+
return C1, C2, C3, C4, C5
|
62 |
+
|
63 |
+
|
64 |
+
if __name__ == '__main__':
|
65 |
+
import torch
|
66 |
+
input = torch.randn((4, 3, 512, 512))
|
67 |
+
net = ResNet()
|
68 |
+
C1, C2, C3, C4, C5 = net(input)
|
69 |
+
print(C1.size())
|
70 |
+
print(C2.size())
|
71 |
+
print(C3.size())
|
72 |
+
print(C4.size())
|
73 |
+
print(C5.size())
|
IndicPhotoOCR/detection/textbpn/network/layers/resnet_dcn.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from IndicPhotoOCR.detection.textbpn.network.backbone.resnet import deformable_resnet18,deformable_resnet50
|
4 |
+
import torch.utils.model_zoo as model_zoo
|
5 |
+
from IndicPhotoOCR.detection.textbpn.cfglib.config import config as cfg
|
6 |
+
|
7 |
+
model_urls = {
|
8 |
+
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
|
9 |
+
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
|
10 |
+
'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
|
11 |
+
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
|
12 |
+
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
|
13 |
+
|
14 |
+
}
|
15 |
+
|
16 |
+
|
17 |
+
class ResNet_DCN(nn.Module):
|
18 |
+
def __init__(self, name="deformable_resnet18", pretrain=False):
|
19 |
+
super().__init__()
|
20 |
+
|
21 |
+
if name == "deformable_resnet18":
|
22 |
+
self.base_net = deformable_resnet18(pretrained=False)
|
23 |
+
if pretrain:
|
24 |
+
print("load the {} weight from ./cache".format(name))
|
25 |
+
self.base_net.load_state_dict(
|
26 |
+
model_zoo.load_url(model_urls["resnet18"], model_dir="./cache",
|
27 |
+
map_location=torch.device(cfg.device)), strict=False)
|
28 |
+
|
29 |
+
elif name == "deformable_resnet50":
|
30 |
+
self.base_net = deformable_resnet50(pretrained=False)
|
31 |
+
if pretrain:
|
32 |
+
print("load the {} weight from ./cache".format(name))
|
33 |
+
self.base_net.load_state_dict(
|
34 |
+
model_zoo.load_url(model_urls["resnet50"], model_dir="./cache",
|
35 |
+
map_location=torch.device(cfg.device)), strict=False)
|
36 |
+
else:
|
37 |
+
print(" base model is not support !")
|
38 |
+
|
39 |
+
# print(base_net)
|
40 |
+
self.up2 = nn.ConvTranspose2d(64, 64, kernel_size=4, stride=2, padding=1)
|
41 |
+
|
42 |
+
def forward(self, x):
|
43 |
+
C1, C2, C3, C4, C5 = self.base_net(x)
|
44 |
+
# up2 --> 1/2
|
45 |
+
C1 = self.up2(C1)
|
46 |
+
|
47 |
+
return C1, C2, C3, C4, C5
|
48 |
+
|
49 |
+
|
50 |
+
if __name__ == '__main__':
|
51 |
+
import torch
|
52 |
+
input = torch.randn((4, 3, 512, 512))
|
53 |
+
net = ResNet_DCN()
|
54 |
+
C1, C2, C3, C4, C5 = net(input)
|
55 |
+
print(C1.size())
|
56 |
+
print(C2.size())
|
57 |
+
print(C3.size())
|
58 |
+
print(C4.size())
|
59 |
+
print(C5.size())
|
IndicPhotoOCR/detection/textbpn/network/layers/vgg.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
import torch.utils.model_zoo as model_zoo
|
3 |
+
import torchvision.models as models
|
4 |
+
from IndicPhotoOCR.detection.textbpn.cfglib.config import config as cfg
|
5 |
+
|
6 |
+
model_urls = {
|
7 |
+
'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
|
8 |
+
'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
|
9 |
+
'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
|
10 |
+
'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
|
11 |
+
'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
|
12 |
+
'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
|
13 |
+
}
|
14 |
+
|
15 |
+
|
16 |
+
class VggNet(nn.Module):
|
17 |
+
def __init__(self, name="vgg16", pretrain=True):
|
18 |
+
super().__init__()
|
19 |
+
if name == "vgg16":
|
20 |
+
base_net = models.vgg16(pretrained=False)
|
21 |
+
elif name == "vgg16_bn":
|
22 |
+
base_net = models.vgg16_bn(pretrained=False)
|
23 |
+
else:
|
24 |
+
print(" base model is not support !")
|
25 |
+
if pretrain:
|
26 |
+
print("load the {} weight from ./cache".format(name))
|
27 |
+
base_net.load_state_dict(model_zoo.load_url(model_urls[name],
|
28 |
+
model_dir="./cache",map_location=torch.device(cfg.device)))
|
29 |
+
|
30 |
+
if name == "vgg16":
|
31 |
+
self.stage1 = nn.Sequential(*[base_net.features[layer] for layer in range(0, 5)])
|
32 |
+
self.stage2 = nn.Sequential(*[base_net.features[layer] for layer in range(5, 10)])
|
33 |
+
self.stage3 = nn.Sequential(*[base_net.features[layer] for layer in range(10, 17)])
|
34 |
+
self.stage4 = nn.Sequential(*[base_net.features[layer] for layer in range(17, 24)])
|
35 |
+
self.stage5 = nn.Sequential(*[base_net.features[layer] for layer in range(24, 31)])
|
36 |
+
elif name == "vgg16_bn":
|
37 |
+
self.stage1 = nn.Sequential(*[base_net.features[layer] for layer in range(0, 7)])
|
38 |
+
self.stage2 = nn.Sequential(*[base_net.features[layer] for layer in range(7, 14)])
|
39 |
+
self.stage3 = nn.Sequential(*[base_net.features[layer] for layer in range(14, 24)])
|
40 |
+
self.stage4 = nn.Sequential(*[base_net.features[layer] for layer in range(24, 34)])
|
41 |
+
self.stage5 = nn.Sequential(*[base_net.features[layer] for layer in range(34, 44)])
|
42 |
+
|
43 |
+
def forward(self, x):
|
44 |
+
C1 = self.stage1(x)
|
45 |
+
C2 = self.stage2(C1)
|
46 |
+
C3 = self.stage3(C2)
|
47 |
+
C4 = self.stage4(C3)
|
48 |
+
C5 = self.stage5(C4)
|
49 |
+
|
50 |
+
return C1, C2, C3, C4, C5
|
51 |
+
|
52 |
+
|
53 |
+
if __name__ == '__main__':
|
54 |
+
import torch
|
55 |
+
input = torch.randn((4, 3, 512, 512))
|
56 |
+
net = VggNet()
|
57 |
+
C1, C2, C3, C4, C5 = net(input)
|
58 |
+
print(C1.size())
|
59 |
+
print(C2.size())
|
60 |
+
print(C3.size())
|
61 |
+
print(C4.size())
|
62 |
+
print(C5.size())
|
IndicPhotoOCR/detection/textbpn/network/loss.py
ADDED
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# @Time : 10/1/21
|
3 |
+
# @Author : GXYM
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from cfglib.config import config as cfg
|
7 |
+
from network.Seg_loss import SegmentLoss
|
8 |
+
from network.Reg_loss import PolyMatchingLoss
|
9 |
+
import torch.nn.functional as F
|
10 |
+
|
11 |
+
|
12 |
+
class TextLoss(nn.Module):
|
13 |
+
|
14 |
+
def __init__(self):
|
15 |
+
super().__init__()
|
16 |
+
self.MSE_loss = torch.nn.MSELoss(reduce=False, size_average=False)
|
17 |
+
self.BCE_loss = torch.nn.BCELoss(reduce=False, size_average=False)
|
18 |
+
self.PolyMatchingLoss = PolyMatchingLoss(cfg.num_points, cfg.device)
|
19 |
+
self.KL_loss = torch.nn.KLDivLoss(reduce=False, size_average=False)
|
20 |
+
|
21 |
+
@staticmethod
|
22 |
+
def single_image_loss(pre_loss, loss_label):
|
23 |
+
batch_size = pre_loss.shape[0]
|
24 |
+
sum_loss = torch.mean(pre_loss.view(-1)) * 0
|
25 |
+
pre_loss = pre_loss.view(batch_size, -1)
|
26 |
+
loss_label = loss_label.view(batch_size, -1)
|
27 |
+
eps = 0.001
|
28 |
+
for i in range(batch_size):
|
29 |
+
average_number = 0
|
30 |
+
positive_pixel = len(pre_loss[i][(loss_label[i] >= eps)])
|
31 |
+
average_number += positive_pixel
|
32 |
+
if positive_pixel != 0:
|
33 |
+
posi_loss = torch.mean(pre_loss[i][(loss_label[i] >= eps)])
|
34 |
+
sum_loss += posi_loss
|
35 |
+
if len(pre_loss[i][(loss_label[i] < eps)]) < 3 * positive_pixel:
|
36 |
+
nega_loss = torch.mean(pre_loss[i][(loss_label[i] < eps)])
|
37 |
+
average_number += len(pre_loss[i][(loss_label[i] < eps)])
|
38 |
+
else:
|
39 |
+
nega_loss = torch.mean(torch.topk(pre_loss[i][(loss_label[i] < eps)], 3 * positive_pixel)[0])
|
40 |
+
average_number += 3 * positive_pixel
|
41 |
+
sum_loss += nega_loss
|
42 |
+
else:
|
43 |
+
nega_loss = torch.mean(torch.topk(pre_loss[i], 100)[0])
|
44 |
+
average_number += 100
|
45 |
+
sum_loss += nega_loss
|
46 |
+
# sum_loss += loss/average_number
|
47 |
+
|
48 |
+
return sum_loss/batch_size
|
49 |
+
|
50 |
+
def cls_ohem(self, predict, target, train_mask, negative_ratio=3.):
|
51 |
+
pos = (target * train_mask).bool()
|
52 |
+
neg = ((1 - target) * train_mask).bool()
|
53 |
+
|
54 |
+
n_pos = pos.float().sum()
|
55 |
+
if n_pos.item() > 0:
|
56 |
+
loss_pos = self.BCE_loss(predict[pos], target[pos]).sum()
|
57 |
+
loss_neg = self.BCE_loss(predict[neg], target[neg])
|
58 |
+
n_neg = min(int(neg.float().sum().item()), int(negative_ratio * n_pos.float()))
|
59 |
+
else:
|
60 |
+
loss_pos = torch.tensor(0.)
|
61 |
+
loss_neg = self.BCE_loss(predict[neg], target[neg])
|
62 |
+
n_neg = 100
|
63 |
+
loss_neg, _ = torch.topk(loss_neg, n_neg)
|
64 |
+
|
65 |
+
return (loss_pos + loss_neg.sum()) / (n_pos + n_neg).float()
|
66 |
+
|
67 |
+
@staticmethod
|
68 |
+
def loss_calc_flux(pred_flux, gt_flux, weight_matrix, mask, train_mask):
|
69 |
+
|
70 |
+
# norm loss
|
71 |
+
gt_flux = 0.999999 * gt_flux / (gt_flux.norm(p=2, dim=1).unsqueeze(1) + 1e-3)
|
72 |
+
norm_loss = weight_matrix * torch.mean((pred_flux - gt_flux) ** 2, dim=1)*train_mask
|
73 |
+
norm_loss = norm_loss.sum(-1).mean()
|
74 |
+
# norm_loss = norm_loss.sum()
|
75 |
+
|
76 |
+
# angle loss
|
77 |
+
mask = train_mask * mask
|
78 |
+
pred_flux = 0.999999 * pred_flux / (pred_flux.norm(p=2, dim=1).unsqueeze(1) + 1e-3)
|
79 |
+
# angle_loss = weight_matrix * (torch.acos(torch.sum(pred_flux * gt_flux, dim=1))) ** 2
|
80 |
+
# angle_loss = angle_loss.sum(-1).mean()
|
81 |
+
angle_loss = (1 - torch.cosine_similarity(pred_flux, gt_flux, dim=1))
|
82 |
+
angle_loss = angle_loss[mask].mean()
|
83 |
+
|
84 |
+
return norm_loss, angle_loss
|
85 |
+
|
86 |
+
@staticmethod
|
87 |
+
def get_poly_energy(energy_field, img_poly, ind, h, w):
|
88 |
+
img_poly = img_poly.clone().float()
|
89 |
+
img_poly[..., 0] = img_poly[..., 0] / (w / 2.) - 1
|
90 |
+
img_poly[..., 1] = img_poly[..., 1] / (h / 2.) - 1
|
91 |
+
|
92 |
+
batch_size = energy_field.size(0)
|
93 |
+
gcn_feature = torch.zeros([img_poly.size(0), energy_field.size(1), img_poly.size(1)]).to(img_poly.device)
|
94 |
+
for i in range(batch_size):
|
95 |
+
poly = img_poly[ind == i].unsqueeze(0)
|
96 |
+
gcn_feature[ind == i] = torch.nn.functional.grid_sample(energy_field[i:i + 1], poly)[0].permute(1, 0, 2)
|
97 |
+
return gcn_feature
|
98 |
+
|
99 |
+
def loss_energy_regularization(self, energy_field, img_poly, inds, h, w):
|
100 |
+
energys = []
|
101 |
+
for i, py in enumerate(img_poly):
|
102 |
+
energy = self.get_poly_energy(energy_field.unsqueeze(1), py, inds, h, w)
|
103 |
+
energys.append(energy.squeeze(1).sum(-1))
|
104 |
+
|
105 |
+
regular_loss = torch.tensor(0.)
|
106 |
+
energy_loss = torch.tensor(0.)
|
107 |
+
for i, e in enumerate(energys[1:]):
|
108 |
+
regular_loss += torch.clamp(e - energys[i], min=0.0).mean()
|
109 |
+
energy_loss += torch.where(e <= 0.01, torch.tensor(0.), e).mean()
|
110 |
+
|
111 |
+
return (energy_loss+regular_loss)/len(energys[1:])
|
112 |
+
|
113 |
+
def forward(self, input_dict, output_dict, eps=None):
|
114 |
+
"""
|
115 |
+
calculate boundary proposal network loss
|
116 |
+
"""
|
117 |
+
# tr_mask = tr_mask.permute(0, 3, 1, 2).contiguous()
|
118 |
+
|
119 |
+
fy_preds = output_dict["fy_preds"]
|
120 |
+
py_preds = output_dict["py_preds"]
|
121 |
+
inds = output_dict["inds"]
|
122 |
+
|
123 |
+
train_mask = input_dict['train_mask']
|
124 |
+
tr_mask = input_dict['tr_mask'] > 0
|
125 |
+
distance_field = input_dict['distance_field']
|
126 |
+
direction_field = input_dict['direction_field']
|
127 |
+
weight_matrix = input_dict['weight_matrix']
|
128 |
+
gt_tags = input_dict['gt_points']
|
129 |
+
|
130 |
+
# # scale the prediction map
|
131 |
+
# fy_preds = F.interpolate(fy_preds, scale_factor=cfg.scale, mode='bilinear')
|
132 |
+
|
133 |
+
if cfg.scale > 1:
|
134 |
+
train_mask = F.interpolate(train_mask.float().unsqueeze(1),
|
135 |
+
scale_factor=1/cfg.scale, mode='bilinear').squeeze().bool()
|
136 |
+
tr_mask = F.interpolate(tr_mask.float().unsqueeze(1),
|
137 |
+
scale_factor=1/cfg.scale, mode='bilinear').squeeze().bool()
|
138 |
+
|
139 |
+
distance_field = F.interpolate(distance_field.unsqueeze(1),
|
140 |
+
scale_factor=1/cfg.scale, mode='bilinear').squeeze()
|
141 |
+
direction_field = F.interpolate(direction_field,
|
142 |
+
scale_factor=1 / cfg.scale, mode='bilinear')
|
143 |
+
weight_matrix = F.interpolate(weight_matrix.unsqueeze(1),
|
144 |
+
scale_factor=1/cfg.scale, mode='bilinear').squeeze()
|
145 |
+
|
146 |
+
# pixel class loss
|
147 |
+
# cls_loss = self.cls_ohem(fy_preds[:, 0, :, :], tr_mask.float(), train_mask)
|
148 |
+
cls_loss = self.BCE_loss(fy_preds[:, 0, :, :], tr_mask.float())
|
149 |
+
cls_loss = torch.mul(cls_loss, train_mask.float()).mean()
|
150 |
+
|
151 |
+
# distance field loss
|
152 |
+
dis_loss = self.MSE_loss(fy_preds[:, 1, :, :], distance_field)
|
153 |
+
dis_loss = torch.mul(dis_loss, train_mask.float())
|
154 |
+
dis_loss = self.single_image_loss(dis_loss, distance_field)
|
155 |
+
|
156 |
+
# # direction field loss
|
157 |
+
norm_loss, angle_loss = self.loss_calc_flux(fy_preds[:, 2:4, :, :], direction_field,
|
158 |
+
weight_matrix, tr_mask, train_mask)
|
159 |
+
|
160 |
+
# boundary point loss
|
161 |
+
point_loss = self.PolyMatchingLoss(py_preds[1:], gt_tags[inds])
|
162 |
+
|
163 |
+
# Minimum energy loss regularization
|
164 |
+
h, w = distance_field.size(1) * cfg.scale, distance_field.size(2) * cfg.scale
|
165 |
+
energy_loss = self.loss_energy_regularization(distance_field, py_preds, inds[0], h, w)
|
166 |
+
|
167 |
+
if eps is None:
|
168 |
+
alpha = 1.0; beta = 3.0; theta=0.5; gama = 0.05
|
169 |
+
else:
|
170 |
+
alpha = 1.0; beta = 3.0; theta=0.5;
|
171 |
+
gama = 0.1*torch.sigmoid(torch.tensor((eps - cfg.max_epoch)/cfg.max_epoch))
|
172 |
+
loss = alpha*cls_loss + beta*dis_loss + theta*(norm_loss + angle_loss) + gama*(point_loss + energy_loss)
|
173 |
+
|
174 |
+
loss_dict = {
|
175 |
+
'total_loss': loss,
|
176 |
+
'cls_loss': alpha*cls_loss,
|
177 |
+
'distance loss': beta*dis_loss,
|
178 |
+
'dir_loss': theta*(norm_loss + angle_loss),
|
179 |
+
'norm_loss': theta*norm_loss,
|
180 |
+
'angle_loss': theta*angle_loss,
|
181 |
+
'point_loss': gama*point_loss,
|
182 |
+
'energy_loss': gama*energy_loss,
|
183 |
+
|
184 |
+
}
|
185 |
+
|
186 |
+
return loss_dict
|
187 |
+
|
IndicPhotoOCR/detection/textbpn/network/loss_org.py
ADDED
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# @Time : 10/1/21
|
3 |
+
# @Author : GXYM
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from cfglib.config import config as cfg
|
7 |
+
from network.Seg_loss import SegmentLoss
|
8 |
+
from network.Reg_loss import PolyMatchingLoss
|
9 |
+
|
10 |
+
|
11 |
+
class TextLoss(nn.Module):
|
12 |
+
|
13 |
+
def __init__(self):
|
14 |
+
super().__init__()
|
15 |
+
self.MSE_loss = torch.nn.MSELoss(reduce=False, size_average=False)
|
16 |
+
self.BCE_loss = torch.nn.BCELoss(reduce=False, size_average=False)
|
17 |
+
self.PolyMatchingLoss = PolyMatchingLoss(cfg.num_points, cfg.device)
|
18 |
+
self.KL_loss = torch.nn.KLDivLoss(reduce=False, size_average=False)
|
19 |
+
|
20 |
+
@staticmethod
|
21 |
+
def single_image_loss(pre_loss, loss_label):
|
22 |
+
batch_size = pre_loss.shape[0]
|
23 |
+
sum_loss = torch.mean(pre_loss.view(-1)) * 0
|
24 |
+
pre_loss = pre_loss.view(batch_size, -1)
|
25 |
+
loss_label = loss_label.view(batch_size, -1)
|
26 |
+
eps = 0.001
|
27 |
+
for i in range(batch_size):
|
28 |
+
average_number = 0
|
29 |
+
positive_pixel = len(pre_loss[i][(loss_label[i] >= eps)])
|
30 |
+
average_number += positive_pixel
|
31 |
+
if positive_pixel != 0:
|
32 |
+
posi_loss = torch.mean(pre_loss[i][(loss_label[i] >= eps)])
|
33 |
+
sum_loss += posi_loss
|
34 |
+
if len(pre_loss[i][(loss_label[i] < eps)]) < 3 * positive_pixel:
|
35 |
+
nega_loss = torch.mean(pre_loss[i][(loss_label[i] < eps)])
|
36 |
+
average_number += len(pre_loss[i][(loss_label[i] < eps)])
|
37 |
+
else:
|
38 |
+
nega_loss = torch.mean(torch.topk(pre_loss[i][(loss_label[i] < eps)], 3 * positive_pixel)[0])
|
39 |
+
average_number += 3 * positive_pixel
|
40 |
+
sum_loss += nega_loss
|
41 |
+
else:
|
42 |
+
nega_loss = torch.mean(torch.topk(pre_loss[i], 100)[0])
|
43 |
+
average_number += 100
|
44 |
+
sum_loss += nega_loss
|
45 |
+
# sum_loss += loss/average_number
|
46 |
+
|
47 |
+
return sum_loss/batch_size
|
48 |
+
|
49 |
+
def cls_ohem(self, predict, target, train_mask, negative_ratio=3.):
|
50 |
+
pos = (target * train_mask).bool()
|
51 |
+
neg = ((1 - target) * train_mask).bool()
|
52 |
+
|
53 |
+
n_pos = pos.float().sum()
|
54 |
+
|
55 |
+
if n_pos.item() > 0:
|
56 |
+
loss_pos = self.BCE_loss(predict[pos], target[pos]).sum()
|
57 |
+
loss_neg = self.BCE_loss(predict[neg], target[neg])
|
58 |
+
n_neg = min(int(neg.float().sum().item()), int(negative_ratio * n_pos.float()))
|
59 |
+
else:
|
60 |
+
loss_pos = torch.tensor(0.)
|
61 |
+
loss_neg = self.BCE_loss(predict[neg], target[neg])
|
62 |
+
n_neg = 100
|
63 |
+
loss_neg, _ = torch.topk(loss_neg, n_neg)
|
64 |
+
|
65 |
+
return (loss_pos + loss_neg.sum()) / (n_pos + n_neg).float()
|
66 |
+
|
67 |
+
@staticmethod
|
68 |
+
def loss_calc_flux(pred_flux, gt_flux, weight_matrix, mask, train_mask):
|
69 |
+
|
70 |
+
# norm loss
|
71 |
+
gt_flux = 0.999999 * gt_flux / (gt_flux.norm(p=2, dim=1).unsqueeze(1) + 1e-9)
|
72 |
+
norm_loss = weight_matrix * torch.sum((pred_flux - gt_flux) ** 2, dim=1)*train_mask
|
73 |
+
norm_loss = norm_loss.sum(-1).mean()
|
74 |
+
|
75 |
+
# angle loss
|
76 |
+
mask = train_mask * mask
|
77 |
+
pred_flux = 0.999999 * pred_flux / (pred_flux.norm(p=2, dim=1).unsqueeze(1) + 1e-9)
|
78 |
+
# angle_loss = weight_matrix * (torch.acos(torch.sum(pred_flux * gt_flux, dim=1))) ** 2
|
79 |
+
# angle_loss = angle_loss.sum(-1).mean()
|
80 |
+
angle_loss = (1 - torch.cosine_similarity(pred_flux, gt_flux, dim=1))
|
81 |
+
angle_loss = angle_loss[mask].mean()
|
82 |
+
|
83 |
+
return norm_loss, angle_loss
|
84 |
+
|
85 |
+
def forward(self, input_dict, output_dict, eps=None):
|
86 |
+
"""
|
87 |
+
calculate boundary proposal network loss
|
88 |
+
"""
|
89 |
+
# tr_mask = tr_mask.permute(0, 3, 1, 2).contiguous()
|
90 |
+
|
91 |
+
fy_preds = output_dict["fy_preds"]
|
92 |
+
py_preds = output_dict["py_preds"]
|
93 |
+
inds = output_dict["inds"]
|
94 |
+
|
95 |
+
train_mask = input_dict['train_mask']
|
96 |
+
tr_mask = input_dict['tr_mask'] > 0
|
97 |
+
distance_field = input_dict['distance_field']
|
98 |
+
direction_field = input_dict['direction_field']
|
99 |
+
weight_matrix = input_dict['weight_matrix']
|
100 |
+
gt_tags = input_dict['gt_points']
|
101 |
+
|
102 |
+
# pixel class loss
|
103 |
+
cls_loss = self.cls_ohem(fy_preds[:, 0, :, :], tr_mask.float(), train_mask.bool())
|
104 |
+
|
105 |
+
# distance field loss
|
106 |
+
dis_loss = self.MSE_loss(fy_preds[:, 1, :, :], distance_field)
|
107 |
+
dis_loss = torch.mul(dis_loss, train_mask.float())
|
108 |
+
dis_loss = self.single_image_loss(dis_loss, distance_field)
|
109 |
+
|
110 |
+
# direction field loss
|
111 |
+
norm_loss, angle_loss = self.loss_calc_flux(fy_preds[:, 2:4, :, :],
|
112 |
+
direction_field, weight_matrix, tr_mask, train_mask)
|
113 |
+
|
114 |
+
# boundary point loss
|
115 |
+
point_loss = self.PolyMatchingLoss(py_preds, gt_tags[inds])
|
116 |
+
|
117 |
+
if eps is None:
|
118 |
+
loss_b = 0.05*point_loss
|
119 |
+
loss = cls_loss + 3.0*dis_loss + norm_loss + angle_loss + loss_b
|
120 |
+
else:
|
121 |
+
loss_b = 0.1*(torch.sigmoid(torch.tensor((eps - cfg.max_epoch)/cfg.max_epoch))) * point_loss
|
122 |
+
loss = cls_loss + 3.0*dis_loss + norm_loss + angle_loss + loss_b
|
123 |
+
|
124 |
+
loss_dict = {
|
125 |
+
'total_loss': loss,
|
126 |
+
'cls_loss': cls_loss,
|
127 |
+
'distance loss': 3.0*dis_loss,
|
128 |
+
'dir_loss': norm_loss + angle_loss,
|
129 |
+
'point_loss': loss_b,
|
130 |
+
'norm_loss': norm_loss,
|
131 |
+
'angle_loss': angle_loss,
|
132 |
+
|
133 |
+
}
|
134 |
+
|
135 |
+
return loss_dict
|
136 |
+
|
IndicPhotoOCR/detection/textbpn/network/textnet.py
ADDED
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# @Time : 10/1/21
|
3 |
+
# @Author : GXYM
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from IndicPhotoOCR.detection.textbpn.network.layers.model_block import FPN
|
7 |
+
from IndicPhotoOCR.detection.textbpn.cfglib.config import config as cfg
|
8 |
+
import numpy as np
|
9 |
+
from IndicPhotoOCR.detection.textbpn.network.layers.CircConv import DeepSnake
|
10 |
+
from IndicPhotoOCR.detection.textbpn.network.layers.GCN import GCN
|
11 |
+
from IndicPhotoOCR.detection.textbpn.network.layers.RNN import RNN
|
12 |
+
from IndicPhotoOCR.detection.textbpn.network.layers.Adaptive_Deformation import AdaptiveDeformation
|
13 |
+
# from IndicPhotoOCR.detection.textbpn.network.layers.Transformer_old import Transformer_old
|
14 |
+
from IndicPhotoOCR.detection.textbpn.network.layers.Transformer import Transformer
|
15 |
+
import cv2
|
16 |
+
from IndicPhotoOCR.detection.textbpn.util.misc import get_sample_point, fill_hole
|
17 |
+
from IndicPhotoOCR.detection.textbpn.network.layers.gcn_utils import get_node_feature, \
|
18 |
+
get_adj_mat, get_adj_ind, coord_embedding, normalize_adj
|
19 |
+
import torch.nn.functional as F
|
20 |
+
import time
|
21 |
+
|
22 |
+
|
23 |
+
class Evolution(nn.Module):
|
24 |
+
def __init__(self, node_num, adj_num, is_training=True, device=None, model="snake"):
|
25 |
+
super(Evolution, self).__init__()
|
26 |
+
self.node_num = node_num
|
27 |
+
self.adj_num = adj_num
|
28 |
+
self.device = device
|
29 |
+
self.is_training = is_training
|
30 |
+
self.clip_dis = 16
|
31 |
+
|
32 |
+
self.iter = 3
|
33 |
+
if model == "gcn":
|
34 |
+
self.adj = get_adj_mat(self.adj_num, self.node_num)
|
35 |
+
self.adj = normalize_adj(self.adj, type="DAD").float().to(self.device)
|
36 |
+
for i in range(self.iter):
|
37 |
+
evolve_gcn = GCN(36, 128)
|
38 |
+
self.__setattr__('evolve_gcn' + str(i), evolve_gcn)
|
39 |
+
elif model == "rnn":
|
40 |
+
self.adj = None
|
41 |
+
for i in range(self.iter):
|
42 |
+
evolve_gcn = RNN(36, 128)
|
43 |
+
self.__setattr__('evolve_gcn' + str(i), evolve_gcn)
|
44 |
+
elif model == "AD":
|
45 |
+
self.adj = get_adj_mat(self.adj_num, self.node_num)
|
46 |
+
self.adj = normalize_adj(self.adj, type="DAD").float().to(self.device)
|
47 |
+
for i in range(self.iter):
|
48 |
+
evolve_gcn = AdaptiveDeformation(36, 128)
|
49 |
+
self.__setattr__('evolve_gcn' + str(i), evolve_gcn)
|
50 |
+
# elif model == "BT_old":
|
51 |
+
# self.adj = None
|
52 |
+
# for i in range(self.iter):
|
53 |
+
# evolve_gcn = Transformer_old(36, 512, num_heads=8,
|
54 |
+
# dim_feedforward=2048, drop_rate=0.0, if_resi=True, block_nums=4)
|
55 |
+
# self.__setattr__('evolve_gcn' + str(i), evolve_gcn)
|
56 |
+
elif model == "BT":
|
57 |
+
self.adj = None
|
58 |
+
for i in range(self.iter):
|
59 |
+
evolve_gcn = Transformer(36, 128, num_heads=8,
|
60 |
+
dim_feedforward=1024, drop_rate=0.0, if_resi=True, block_nums=3)
|
61 |
+
self.__setattr__('evolve_gcn' + str(i), evolve_gcn)
|
62 |
+
else:
|
63 |
+
self.adj = get_adj_ind(self.adj_num, self.node_num, self.device)
|
64 |
+
for i in range(self.iter):
|
65 |
+
evolve_gcn = DeepSnake(state_dim=128, feature_dim=36, conv_type='dgrid')
|
66 |
+
self.__setattr__('evolve_gcn' + str(i), evolve_gcn)
|
67 |
+
|
68 |
+
for m in self.modules():
|
69 |
+
if isinstance(m, nn.Conv1d) or isinstance(m, nn.Conv2d):
|
70 |
+
m.weight.data.normal_(0.0, 0.02)
|
71 |
+
# nn.init.kaiming_normal_(m.weight, mode='fan_in')
|
72 |
+
if m.bias is not None:
|
73 |
+
nn.init.constant_(m.bias, 0)
|
74 |
+
|
75 |
+
@staticmethod
|
76 |
+
def get_boundary_proposal(input=None, seg_preds=None, switch="gt"):
|
77 |
+
|
78 |
+
if switch == "gt":
|
79 |
+
inds = torch.where(input['ignore_tags'] > 0)
|
80 |
+
# if len(inds[0]) > 320:
|
81 |
+
# inds = (inds[0][:320], inds[1][:320])
|
82 |
+
init_polys = input['proposal_points'][inds]
|
83 |
+
else:
|
84 |
+
tr_masks = input['tr_mask'].cpu().numpy()
|
85 |
+
tcl_masks = seg_preds[:, 0, :, :].detach().cpu().numpy() > cfg.threshold
|
86 |
+
inds = []
|
87 |
+
init_polys = []
|
88 |
+
for bid, tcl_mask in enumerate(tcl_masks):
|
89 |
+
ret, labels = cv2.connectedComponents(tcl_mask.astype(np.uint8), connectivity=8)
|
90 |
+
for idx in range(1, ret):
|
91 |
+
text_mask = labels == idx
|
92 |
+
ist_id = int(np.sum(text_mask*tr_masks[bid])/np.sum(text_mask))-1
|
93 |
+
inds.append([bid, ist_id])
|
94 |
+
poly = get_sample_point(text_mask, cfg.num_points, cfg.approx_factor)
|
95 |
+
init_polys.append(poly)
|
96 |
+
inds = torch.from_numpy(np.array(inds)).permute(1, 0).to(input["img"].device)
|
97 |
+
init_polys = torch.from_numpy(np.array(init_polys)).to(input["img"].device)
|
98 |
+
|
99 |
+
return init_polys, inds, None
|
100 |
+
|
101 |
+
def get_boundary_proposal_eval(self, input=None, seg_preds=None):
|
102 |
+
|
103 |
+
# if cfg.scale > 1:
|
104 |
+
# seg_preds = F.interpolate(seg_preds, scale_factor=cfg.scale, mode='bilinear')
|
105 |
+
cls_preds = seg_preds[:, 0, :, :].detach().cpu().numpy()
|
106 |
+
dis_preds = seg_preds[:, 1, :, ].detach().cpu().numpy()
|
107 |
+
|
108 |
+
inds = []
|
109 |
+
init_polys = []
|
110 |
+
confidences = []
|
111 |
+
for bid, dis_pred in enumerate(dis_preds):
|
112 |
+
# # dis_mask = (dis_pred / np.max(dis_pred)) > cfg.dis_threshold
|
113 |
+
dis_mask = dis_pred > cfg.dis_threshold
|
114 |
+
# dis_mask = fill_hole(dis_mask)
|
115 |
+
ret, labels = cv2.connectedComponents(dis_mask.astype(np.uint8), connectivity=8, ltype=cv2.CV_16U)
|
116 |
+
for idx in range(1, ret):
|
117 |
+
text_mask = labels == idx
|
118 |
+
confidence = round(cls_preds[bid][text_mask].mean(), 3)
|
119 |
+
# 50 for MLT2017 and ArT (or DCN is used in backone); else is all 150;
|
120 |
+
# just can set to 50, which has little effect on the performance
|
121 |
+
if np.sum(text_mask) < 50/(cfg.scale*cfg.scale) or confidence < cfg.cls_threshold:
|
122 |
+
continue
|
123 |
+
confidences.append(confidence)
|
124 |
+
inds.append([bid, 0])
|
125 |
+
|
126 |
+
poly = get_sample_point(text_mask, cfg.num_points,
|
127 |
+
cfg.approx_factor, scales=np.array([cfg.scale, cfg.scale]))
|
128 |
+
init_polys.append(poly)
|
129 |
+
|
130 |
+
if len(inds) > 0:
|
131 |
+
inds = torch.from_numpy(np.array(inds)).permute(1, 0).to(input["img"].device, non_blocking=True)
|
132 |
+
init_polys = torch.from_numpy(np.array(init_polys)).to(input["img"].device, non_blocking=True).float()
|
133 |
+
else:
|
134 |
+
init_polys = torch.from_numpy(np.array(init_polys)).to(input["img"].device, non_blocking=True).float()
|
135 |
+
inds = torch.from_numpy(np.array(inds)).to(input["img"].device, non_blocking=True)
|
136 |
+
|
137 |
+
return init_polys, inds, confidences
|
138 |
+
|
139 |
+
def evolve_poly(self, snake, cnn_feature, i_it_poly, ind):
|
140 |
+
if len(i_it_poly) == 0:
|
141 |
+
return torch.zeros_like(i_it_poly)
|
142 |
+
h, w = cnn_feature.size(2)*cfg.scale, cnn_feature.size(3)*cfg.scale
|
143 |
+
node_feats = get_node_feature(cnn_feature, i_it_poly, ind, h, w)
|
144 |
+
i_poly = i_it_poly + torch.clamp(snake(node_feats, self.adj).permute(0, 2, 1), -self.clip_dis, self.clip_dis)
|
145 |
+
if self.is_training:
|
146 |
+
i_poly = torch.clamp(i_poly, 0, w-1)
|
147 |
+
else:
|
148 |
+
i_poly[:, :, 0] = torch.clamp(i_poly[:, :, 0], 0, w - 1)
|
149 |
+
i_poly[:, :, 1] = torch.clamp(i_poly[:, :, 1], 0, h - 1)
|
150 |
+
return i_poly
|
151 |
+
|
152 |
+
def forward(self, embed_feature, input=None, seg_preds=None, switch="gt"):
|
153 |
+
if self.is_training:
|
154 |
+
init_polys, inds, confidences = self.get_boundary_proposal(input=input, seg_preds=seg_preds, switch=switch)
|
155 |
+
# TODO sample fix number
|
156 |
+
else:
|
157 |
+
init_polys, inds, confidences = self.get_boundary_proposal_eval(input=input, seg_preds=seg_preds)
|
158 |
+
if init_polys.shape[0] == 0:
|
159 |
+
return [init_polys for i in range(self.iter+1)], inds, confidences
|
160 |
+
|
161 |
+
py_preds = [init_polys, ]
|
162 |
+
for i in range(self.iter):
|
163 |
+
evolve_gcn = self.__getattr__('evolve_gcn' + str(i))
|
164 |
+
init_polys = self.evolve_poly(evolve_gcn, embed_feature, init_polys, inds[0])
|
165 |
+
py_preds.append(init_polys)
|
166 |
+
|
167 |
+
return py_preds, inds, confidences
|
168 |
+
|
169 |
+
|
170 |
+
class TextNet(nn.Module):
|
171 |
+
|
172 |
+
def __init__(self, backbone='vgg', is_training=True):
|
173 |
+
super().__init__()
|
174 |
+
self.is_training = is_training
|
175 |
+
self.backbone_name = backbone
|
176 |
+
self.fpn = FPN(self.backbone_name, is_training=(not cfg.resume and is_training))
|
177 |
+
|
178 |
+
self.seg_head = nn.Sequential(
|
179 |
+
nn.Conv2d(32, 16, kernel_size=3, padding=2, dilation=2),
|
180 |
+
nn.PReLU(),
|
181 |
+
nn.Conv2d(16, 16, kernel_size=3, padding=4, dilation=4),
|
182 |
+
nn.PReLU(),
|
183 |
+
nn.Conv2d(16, 4, kernel_size=1, stride=1, padding=0),
|
184 |
+
)
|
185 |
+
self.BPN = Evolution(cfg.num_points, adj_num=4,
|
186 |
+
is_training=is_training, device=cfg.device, model="BT")
|
187 |
+
|
188 |
+
def load_model(self, model_path):
|
189 |
+
print('Loading from {}'.format(model_path))
|
190 |
+
state_dict = torch.load(model_path, map_location=torch.device(cfg.device))
|
191 |
+
self.load_state_dict(state_dict['model'], strict=(not self.is_training))
|
192 |
+
|
193 |
+
def forward(self, input_dict, test_speed=False):
|
194 |
+
output = {}
|
195 |
+
b, c, h, w = input_dict["img"].shape
|
196 |
+
if self.is_training or cfg.exp_name in ['ArT', 'MLT2017', "MLT2019"] or test_speed:
|
197 |
+
image = input_dict["img"]
|
198 |
+
else:
|
199 |
+
image = torch.zeros((b, c, cfg.test_size[1], cfg.test_size[1]), dtype=torch.float32).to(cfg.device)
|
200 |
+
image[:, :, :h, :w] = input_dict["img"][:, :, :, :]
|
201 |
+
|
202 |
+
up1, _, _, _, _ = self.fpn(image)
|
203 |
+
up1 = up1[:, :, :h // cfg.scale, :w // cfg.scale]
|
204 |
+
|
205 |
+
preds = self.seg_head(up1)
|
206 |
+
fy_preds = torch.cat([torch.sigmoid(preds[:, 0:2, :, :]), preds[:, 2:4, :, :]], dim=1)
|
207 |
+
cnn_feats = torch.cat([up1, fy_preds], dim=1)
|
208 |
+
|
209 |
+
py_preds, inds, confidences = self.BPN(cnn_feats, input=input_dict, seg_preds=fy_preds, switch="gt")
|
210 |
+
|
211 |
+
output["fy_preds"] = fy_preds
|
212 |
+
output["py_preds"] = py_preds
|
213 |
+
output["inds"] = inds
|
214 |
+
output["confidences"] = confidences
|
215 |
+
|
216 |
+
return output
|
IndicPhotoOCR/detection/textbpn/output.png
ADDED
Git LFS Details
|
IndicPhotoOCR/detection/textbpn/textbpnpp_detector.py
ADDED
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import cv2
|
3 |
+
import numpy as np
|
4 |
+
from IndicPhotoOCR.detection.textbpn.network.textnet import TextNet
|
5 |
+
from IndicPhotoOCR.detection.textbpn.cfglib.config import config as cfg
|
6 |
+
import warnings
|
7 |
+
import os
|
8 |
+
import requests
|
9 |
+
from tqdm import tqdm
|
10 |
+
|
11 |
+
# Suppress warnings
|
12 |
+
warnings.filterwarnings("ignore")
|
13 |
+
|
14 |
+
model_info = {
|
15 |
+
"textbpnpp": {
|
16 |
+
"path": "models/TextBPN_resnet50_300.pth",
|
17 |
+
"url" : "https://github.com/Bhashini-IITJ/SceneTextDetection/releases/download/TextBPN%2B%2B/TextBPN_resnet50_300.pth",
|
18 |
+
},
|
19 |
+
"textbpnpp_deformable": {
|
20 |
+
"path":"models/TextBPN_deformable_resnet50_300.pth",
|
21 |
+
"url": "https://github.com/Bhashini-IITJ/SceneTextDetection/releases/download/TextBPN%2B%2B/TextBPN_deformable_resnet50_300.pth",
|
22 |
+
},
|
23 |
+
"textbpn_resnet18" : {
|
24 |
+
"path":"models/TextBPN_resnet18_300.pth",
|
25 |
+
"url": "https://github.com/Bhashini-IITJ/SceneTextDetection/releases/download/TextBPN%2B%2B/TextBPN_resnet18_300.pth",
|
26 |
+
|
27 |
+
}
|
28 |
+
}
|
29 |
+
# Ensure model file exists; download directly if not
|
30 |
+
def ensure_model(model_name):
|
31 |
+
model_path = model_info[model_name]["path"]
|
32 |
+
url = model_info[model_name]["url"]
|
33 |
+
root_model_dir = "IndicPhotoOCR/detection/textbpn"
|
34 |
+
model_path = os.path.join(root_model_dir, model_path)
|
35 |
+
|
36 |
+
if not os.path.exists(model_path):
|
37 |
+
print(f"Model not found locally. Downloading {model_name} from {url}...")
|
38 |
+
|
39 |
+
# Start the download with a progress bar
|
40 |
+
response = requests.get(url, stream=True)
|
41 |
+
total_size = int(response.headers.get('content-length', 0))
|
42 |
+
os.makedirs(f"{root_model_dir}/models", exist_ok=True)
|
43 |
+
|
44 |
+
with open(model_path, "wb") as f, tqdm(
|
45 |
+
desc=model_name,
|
46 |
+
total=total_size,
|
47 |
+
unit='B',
|
48 |
+
unit_scale=True,
|
49 |
+
unit_divisor=1024,
|
50 |
+
) as bar:
|
51 |
+
for data in response.iter_content(chunk_size=1024):
|
52 |
+
f.write(data)
|
53 |
+
bar.update(len(data))
|
54 |
+
|
55 |
+
print(f"Downloaded model for {model_name}.")
|
56 |
+
|
57 |
+
return model_path
|
58 |
+
|
59 |
+
class TextBPNpp_detector:
|
60 |
+
def __init__(self, model_name="textbpnpp", backbone="resnet50", device="cpu"):
|
61 |
+
"""
|
62 |
+
Initialize the TextBPN model.
|
63 |
+
:param model_path: Path to the pre-trained model.
|
64 |
+
:param backbone: Backbone architecture (default: "resnet50").
|
65 |
+
:param device: Device to run the model on (default: "cpu").
|
66 |
+
"""
|
67 |
+
self.model_path = ensure_model(model_name)
|
68 |
+
self.device = torch.device(device)
|
69 |
+
self.model = TextNet(is_training=False, backbone=backbone)
|
70 |
+
self.model.load_model(self.model_path)
|
71 |
+
self.model.eval()
|
72 |
+
self.model.to(self.device)
|
73 |
+
|
74 |
+
@staticmethod
|
75 |
+
def to_device(tensor, device):
|
76 |
+
"""
|
77 |
+
Move tensor to the specified device.
|
78 |
+
:param tensor: Tensor to move.
|
79 |
+
:param device: Target device.
|
80 |
+
:return: Tensor on the target device.
|
81 |
+
"""
|
82 |
+
return tensor.to(device, non_blocking=True)
|
83 |
+
|
84 |
+
@staticmethod
|
85 |
+
def pad_image(image, stride=32):
|
86 |
+
"""
|
87 |
+
Pad the image to make its dimensions divisible by the stride.
|
88 |
+
:param image: Input image.
|
89 |
+
:param stride: Stride size.
|
90 |
+
:return: Padded image and original dimensions.
|
91 |
+
"""
|
92 |
+
h, w = image.shape[:2]
|
93 |
+
new_h = (h + stride - 1) // stride * stride
|
94 |
+
new_w = (w + stride - 1) // stride * stride
|
95 |
+
padded_image = cv2.copyMakeBorder(
|
96 |
+
image, 0, new_h - h, 0, new_w - w, cv2.BORDER_CONSTANT, value=(0, 0, 0)
|
97 |
+
)
|
98 |
+
return padded_image, (h, w)
|
99 |
+
|
100 |
+
@staticmethod
|
101 |
+
def rescale_result(image, bbox_contours, original_height, original_width):
|
102 |
+
"""
|
103 |
+
Rescale the bounding box contours to the original image size.
|
104 |
+
:param image: Image after resizing.
|
105 |
+
:param bbox_contours: Bounding box contours.
|
106 |
+
:param original_height: Original image height.
|
107 |
+
:param original_width: Original image width.
|
108 |
+
:return: Original image and rescaled contours.
|
109 |
+
"""
|
110 |
+
contours = []
|
111 |
+
for cont in bbox_contours:
|
112 |
+
cont[:, 0] = (cont[:, 0] * original_width / image.shape[1]).astype(int)
|
113 |
+
cont[:, 1] = (cont[:, 1] * original_height / image.shape[0]).astype(int)
|
114 |
+
contours.append(cont)
|
115 |
+
return contours
|
116 |
+
|
117 |
+
def detect(self, image_path):
|
118 |
+
"""
|
119 |
+
Perform text detection on the given image.
|
120 |
+
:param image_path: Path to the input image.
|
121 |
+
:return: Dictionary with detection results.
|
122 |
+
"""
|
123 |
+
image = cv2.imread(image_path)
|
124 |
+
if image is None:
|
125 |
+
raise ValueError(f"Failed to read the image at {image_path}")
|
126 |
+
|
127 |
+
padded_image, original_size = self.pad_image(image)
|
128 |
+
padded_tensor = (
|
129 |
+
torch.from_numpy(padded_image).permute(2, 0, 1).float() / 255.0
|
130 |
+
).unsqueeze(0) # Convert to tensor and add batch dimension
|
131 |
+
|
132 |
+
cfg.test_size = [padded_image.shape[0], padded_image.shape[1]]
|
133 |
+
|
134 |
+
input_dict = {"img": self.to_device(padded_tensor, self.device)}
|
135 |
+
with torch.no_grad():
|
136 |
+
output_dict = self.model(input_dict, padded_image.shape)
|
137 |
+
|
138 |
+
contours = output_dict["py_preds"][-1].int().cpu().numpy()
|
139 |
+
contours = self.rescale_result(image, contours, *original_size)
|
140 |
+
|
141 |
+
bbox_result_dict = {"detections": []}
|
142 |
+
for contour in contours:
|
143 |
+
# x_min, y_min = np.min(contour, axis=0)
|
144 |
+
# x_max, y_max = np.max(contour, axis=0)
|
145 |
+
# bbox_result_dict["detections"].append([x_min, y_min, x_max, y_max])
|
146 |
+
bbox_result_dict["detections"].append(contour.tolist())
|
147 |
+
|
148 |
+
return bbox_result_dict
|
149 |
+
|
150 |
+
def visualize_detections(self, image_path, bbox_result_dict, output_path="output.png"):
|
151 |
+
"""
|
152 |
+
Visualize detections on the image.
|
153 |
+
:param image_path: Path to the input image.
|
154 |
+
:param bbox_result_dict: Detection results in the format:
|
155 |
+
{'detections': [[[x1, y1], [x2, y2], [x3, y3], [x4, y4]], ...]}.
|
156 |
+
:param output_path: Path to save the visualized image. If None, the image is only displayed.
|
157 |
+
"""
|
158 |
+
# Load the image
|
159 |
+
image = cv2.imread(image_path)
|
160 |
+
if image is None:
|
161 |
+
raise ValueError(f"Failed to read the image at {image_path}")
|
162 |
+
|
163 |
+
# Draw each detection
|
164 |
+
for bbox in bbox_result_dict.get("detections", []):
|
165 |
+
points = np.array(bbox, dtype=np.int32) # Convert to numpy array
|
166 |
+
cv2.polylines(image, [points], isClosed=True, color=(0, 255, 0), thickness=2)
|
167 |
+
|
168 |
+
# Display or save the visualized image
|
169 |
+
if output_path:
|
170 |
+
cv2.imwrite(output_path, image)
|
171 |
+
print(f"Visualization saved to {output_path}")
|
172 |
+
else:
|
173 |
+
cv2.imshow("Detections", image)
|
174 |
+
cv2.waitKey(0)
|
175 |
+
cv2.destroyAllWindows()
|
176 |
+
|
177 |
+
if __name__ == "__main__":
|
178 |
+
import argparse
|
179 |
+
parser = argparse.ArgumentParser(description='Text detection using EAST model')
|
180 |
+
parser.add_argument('--image_path', type=str, required=True, help='Path to the input image')
|
181 |
+
parser.add_argument('--device', type=str, default='cpu', help='Device to run the model on, e.g., "cpu" or "cuda"')
|
182 |
+
parser.add_argument('--model_name', type=str, required=True, help='Path to the model checkpoint file')
|
183 |
+
args = parser.parse_args()
|
184 |
+
|
185 |
+
|
186 |
+
|
187 |
+
# model_path = "/DATA1/ocrteam/anik/git/IndicPhotoOCR/IndicPhotoOCR/detection/textbpn/models/TextBPN_resnet50_300.pth"
|
188 |
+
# image_path = "/DATA1/ocrteam/anik/splitonBSTD/detection/D/image_542.jpg"
|
189 |
+
|
190 |
+
detector = TextBPNpp_detector(args.model_name, device="cpu")
|
191 |
+
result = detector.detect(args.image_path)
|
192 |
+
print(result)
|
193 |
+
# detector.visualize_detections(image_path, result)
|
194 |
+
|
195 |
+
# python -m IndicPhotoOCR.detection.textbpn.textbpnpp_detector \
|
196 |
+
# --image_path /DATA1/ocrteam/anik/splitonBSTD/detection/D/image_542.jpg \
|
197 |
+
# --model_name textbpnpp
|
IndicPhotoOCR/detection/textbpn/util/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .visualize import *
|
2 |
+
from .pbox import *
|
IndicPhotoOCR/detection/textbpn/util/augmentation.py
ADDED
@@ -0,0 +1,794 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
__author__ = "S.X.Zhang"
|
3 |
+
import numpy as np
|
4 |
+
import math
|
5 |
+
import cv2
|
6 |
+
import copy
|
7 |
+
import numpy.random as random
|
8 |
+
from shapely.geometry import Polygon
|
9 |
+
import torchvision.transforms as transforms
|
10 |
+
import torchvision.transforms.functional as F
|
11 |
+
from PIL import ImageEnhance, Image
|
12 |
+
|
13 |
+
|
14 |
+
###<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<###
|
15 |
+
###<<<<<<<<< Function >>>>>>>>>>>>###
|
16 |
+
###>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>###
|
17 |
+
def crop_first(image, polygons, scale =10):
|
18 |
+
polygons_new = copy.deepcopy(polygons)
|
19 |
+
h, w, _ = image.shape
|
20 |
+
pad_h = h // scale
|
21 |
+
pad_w = w // scale
|
22 |
+
h_array = np.zeros((h + pad_h * 2), dtype=np.int32)
|
23 |
+
w_array = np.zeros((w + pad_w * 2), dtype=np.int32)
|
24 |
+
|
25 |
+
text_polys = []
|
26 |
+
pos_polys = []
|
27 |
+
for polygon in polygons_new:
|
28 |
+
rect = cv2.minAreaRect(polygon.points.astype(np.int32))
|
29 |
+
box = cv2.boxPoints(rect)
|
30 |
+
box = np.int0(box)
|
31 |
+
text_polys.append([box[0], box[1], box[2], box[3]])
|
32 |
+
if polygon.label != -1:
|
33 |
+
pos_polys.append([box[0], box[1], box[2], box[3]])
|
34 |
+
|
35 |
+
polys = np.array(text_polys, dtype=np.int32)
|
36 |
+
for poly in polys:
|
37 |
+
poly = np.round(poly, decimals=0).astype(np.int32) # 四舍五入
|
38 |
+
minx = np.min(poly[:, 0])
|
39 |
+
maxx = np.max(poly[:, 0])
|
40 |
+
w_array[minx + pad_w:maxx + pad_w] = 1
|
41 |
+
miny = np.min(poly[:, 1])
|
42 |
+
maxy = np.max(poly[:, 1])
|
43 |
+
h_array[miny + pad_h:maxy + pad_h] = 1
|
44 |
+
# ensure the cropped area not across a text 保证截取区域不会横穿文字
|
45 |
+
h_axis = np.where(h_array == 0)[0]
|
46 |
+
w_axis = np.where(w_array == 0)[0]
|
47 |
+
pp_polys = np.array(pos_polys, dtype=np.int32)
|
48 |
+
|
49 |
+
return h_axis, w_axis, pp_polys
|
50 |
+
|
51 |
+
####<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<####
|
52 |
+
####<<<<<<<<<<< Class >>>>>>>>>>>>>####
|
53 |
+
####>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>####
|
54 |
+
class Compose(object):
|
55 |
+
"""Composes several augmentations together.
|
56 |
+
Args:
|
57 |
+
transforms (List[Transform]): list of transforms to compose.
|
58 |
+
Example:
|
59 |
+
>>> augmentations.Compose([
|
60 |
+
>>> transforms.CenterCrop(10),
|
61 |
+
>>> transforms.ToTensor(),
|
62 |
+
>>> ])
|
63 |
+
"""
|
64 |
+
|
65 |
+
def __init__(self, transforms):
|
66 |
+
self.transforms = transforms
|
67 |
+
|
68 |
+
def __call__(self, img, pts=None):
|
69 |
+
for t in self.transforms:
|
70 |
+
img, pts = t(img, pts)
|
71 |
+
return img, pts
|
72 |
+
|
73 |
+
|
74 |
+
class Normalize(object):
|
75 |
+
def __init__(self, mean, std):
|
76 |
+
self.mean = np.array(mean)
|
77 |
+
self.std = np.array(std)
|
78 |
+
|
79 |
+
def __call__(self, image, polygons=None):
|
80 |
+
image = image.astype(np.float32)
|
81 |
+
image /= 255.0
|
82 |
+
image -= self.mean
|
83 |
+
image /= self.std
|
84 |
+
return image, polygons
|
85 |
+
|
86 |
+
|
87 |
+
class MinusMean(object):
|
88 |
+
def __init__(self, mean):
|
89 |
+
self.mean = np.array(mean)
|
90 |
+
|
91 |
+
def __call__(self, image, polygons=None):
|
92 |
+
image = image.astype(np.float32)
|
93 |
+
image -= self.mean
|
94 |
+
return image, polygons
|
95 |
+
|
96 |
+
|
97 |
+
class RandomMirror(object):
|
98 |
+
# 镜像
|
99 |
+
def __init__(self):
|
100 |
+
pass
|
101 |
+
|
102 |
+
def __call__(self, image, polygons=None):
|
103 |
+
if polygons is None:
|
104 |
+
return image, polygons
|
105 |
+
if random.random()< 0.3:
|
106 |
+
image = np.ascontiguousarray(image[:, ::-1])
|
107 |
+
_, width, _ = image.shape
|
108 |
+
for polygon in polygons:
|
109 |
+
polygon.points[:, 0] = width - polygon.points[:, 0]
|
110 |
+
return image, polygons
|
111 |
+
|
112 |
+
|
113 |
+
class AugmentColor(object):
|
114 |
+
# 颜色增强(添加噪声)
|
115 |
+
def __init__(self):
|
116 |
+
self.U = np.array([[-0.56543481, 0.71983482, 0.40240142],
|
117 |
+
[-0.5989477, -0.02304967, -0.80036049],
|
118 |
+
[-0.56694071, -0.6935729, 0.44423429]], dtype=np.float32)
|
119 |
+
self.EV = np.array([1.65513492, 0.48450358, 0.1565086], dtype=np.float32)
|
120 |
+
self.sigma = 0.1
|
121 |
+
self.color_vec = None
|
122 |
+
|
123 |
+
def __call__(self, img, polygons=None):
|
124 |
+
color_vec = self.color_vec
|
125 |
+
if self.color_vec is None:
|
126 |
+
if not self.sigma > 0.0:
|
127 |
+
color_vec = np.zeros(3, dtype=np.float32)
|
128 |
+
else:
|
129 |
+
color_vec = np.random.normal(0.0, self.sigma, 3)
|
130 |
+
|
131 |
+
alpha = color_vec.astype(np.float32) * self.EV
|
132 |
+
noise = np.dot(self.U, alpha.T) * 255
|
133 |
+
return np.clip(img + noise[np.newaxis, np.newaxis, :], 0, 255), polygons
|
134 |
+
|
135 |
+
|
136 |
+
class RandomContrast(object):
|
137 |
+
def __init__(self, lower=0.5, upper=1.5):
|
138 |
+
self.lower = lower
|
139 |
+
self.upper = upper
|
140 |
+
assert self.upper >= self.lower, "contrast upper must be >= lower."
|
141 |
+
assert self.lower >= 0, "contrast lower must be non-negative."
|
142 |
+
|
143 |
+
# expects float image
|
144 |
+
def __call__(self, image, polygons=None):
|
145 |
+
if random.randint(2):
|
146 |
+
alpha = random.uniform(self.lower, self.upper)
|
147 |
+
image *= alpha
|
148 |
+
return np.clip(image, 0, 255), polygons
|
149 |
+
|
150 |
+
|
151 |
+
class RandomBrightness(object):
|
152 |
+
def __init__(self, delta=32):
|
153 |
+
assert delta >= 0.0
|
154 |
+
assert delta <= 255.0
|
155 |
+
self.delta = delta
|
156 |
+
|
157 |
+
def __call__(self, image, polygons=None):
|
158 |
+
image = image.astype(np.float32)
|
159 |
+
if random.randint(2):
|
160 |
+
delta = random.uniform(-self.delta, self.delta)
|
161 |
+
image += delta
|
162 |
+
return np.clip(image, 0, 255), polygons
|
163 |
+
|
164 |
+
|
165 |
+
class RandomErasing(object):
|
166 |
+
def __init__(self, sr=(0.0004, 0.01), scale=(0.5, 3), ratio=0.2, Type ="Erasing"):
|
167 |
+
"""
|
168 |
+
|
169 |
+
:param area:
|
170 |
+
:param type: Erasing or Cutout
|
171 |
+
"""
|
172 |
+
self.sr = sr
|
173 |
+
self.scale= scale
|
174 |
+
self.ratio=ratio
|
175 |
+
self.type=Type
|
176 |
+
|
177 |
+
def __call__(self, img, polygons=None):
|
178 |
+
|
179 |
+
if random.random()< self.ratio:
|
180 |
+
return img, polygons
|
181 |
+
area=img.shape[0]*img.shape[1]
|
182 |
+
target_area=random.randint(*self.sr)*area
|
183 |
+
aspect_ratio=random.uniform(*self.scale)
|
184 |
+
h = int(round(math.sqrt(target_area / aspect_ratio)))
|
185 |
+
w = int(round(math.sqrt(target_area * aspect_ratio)))
|
186 |
+
|
187 |
+
if w < img.shape[1] and h < img.shape[0]:
|
188 |
+
x1 = random.randint(0, img.shape[1] - w)
|
189 |
+
y1 = random.randint(0, img.shape[0] - h)
|
190 |
+
if self.type == "Erasing":
|
191 |
+
color=(random.randint(0, 255),random.randint(0, 255),random.randint(0, 255))
|
192 |
+
img[y1:y1+h, x1:x1+h,:]=color
|
193 |
+
else:
|
194 |
+
Gray_value=random.randint(0, 255)
|
195 |
+
color = (Gray_value, Gray_value ,Gray_value)
|
196 |
+
img[y1:y1 + h, x1:x1 + h, :] = color
|
197 |
+
|
198 |
+
return img, polygons
|
199 |
+
|
200 |
+
|
201 |
+
class RandomMixUp(object):
|
202 |
+
def __init__(self, mixup_alpha=2):
|
203 |
+
self.mixup_alpha = mixup_alpha
|
204 |
+
|
205 |
+
def __call__(self, img1, img2, label1=[], label2=[]):
|
206 |
+
beta=np.random.beta(self.mixup_alpha,self.mixup_alpha)
|
207 |
+
|
208 |
+
#image = img1 * Gama + (1 - Gama) * img2
|
209 |
+
image=cv2.addWeighted(img1, beta, img2, (1-beta), 0)
|
210 |
+
|
211 |
+
if label1 is None or label2 is None:
|
212 |
+
return img1, label1
|
213 |
+
if isinstance(label1, list) and isinstance(label2, list):
|
214 |
+
label=[]
|
215 |
+
for id in range(len(label1)):
|
216 |
+
lab = beta*label1[id]+ (1-beta)*label2[id]
|
217 |
+
label.append(lab)
|
218 |
+
return image, label
|
219 |
+
else:
|
220 |
+
print("Error: label is not a list type")
|
221 |
+
|
222 |
+
return img1, label1
|
223 |
+
|
224 |
+
|
225 |
+
class Rotate(object):
|
226 |
+
def __init__(self, up=30):
|
227 |
+
self.up = up
|
228 |
+
|
229 |
+
@staticmethod
|
230 |
+
def rotate(center, pt, theta): # 二维图形学的旋转
|
231 |
+
xr, yr = center
|
232 |
+
yr = -yr
|
233 |
+
x, y = pt[:, 0], pt[:, 1]
|
234 |
+
y = -y
|
235 |
+
|
236 |
+
theta = theta / 180 * math.pi
|
237 |
+
cos = math.cos(theta)
|
238 |
+
sin = math.sin(theta)
|
239 |
+
|
240 |
+
_x = xr + (x - xr) * cos - (y - yr) * sin
|
241 |
+
_y = yr + (x - xr) * sin + (y - yr) * cos
|
242 |
+
|
243 |
+
return _x, -_y
|
244 |
+
|
245 |
+
def __call__(self, img, polygons=None):
|
246 |
+
if np.random.randint(2):
|
247 |
+
return img, polygons
|
248 |
+
angle = np.random.normal(loc=0.0, scale=0.5) * self.up # angle 按照高斯分布
|
249 |
+
rows, cols = img.shape[0:2]
|
250 |
+
M = cv2.getRotationMatrix2D((cols / 2, rows / 2), angle, 1.0)
|
251 |
+
img = cv2.warpAffine(img, M, (cols, rows), borderValue=[0, 0, 0])
|
252 |
+
center = cols / 2.0, rows / 2.0
|
253 |
+
if polygons is not None:
|
254 |
+
for polygon in polygons:
|
255 |
+
x, y = self.rotate(center, polygon.points, angle)
|
256 |
+
pts = np.vstack([x, y]).T
|
257 |
+
polygon.points = pts
|
258 |
+
return img, polygons
|
259 |
+
|
260 |
+
|
261 |
+
class RotatePadding(object):
|
262 |
+
def __init__(self, up=60,colors=True):
|
263 |
+
self.up = up
|
264 |
+
self.colors = colors
|
265 |
+
self.ratio = 0.5
|
266 |
+
|
267 |
+
@staticmethod
|
268 |
+
def rotate(center, pt, theta, movSize=[0, 0], scale=1): # 二维图形学的旋转
|
269 |
+
(xr, yr) = center
|
270 |
+
yr = -yr
|
271 |
+
x, y = pt[:, 0], pt[:, 1]
|
272 |
+
y = -y
|
273 |
+
|
274 |
+
theta = theta / 180 * math.pi
|
275 |
+
cos = math.cos(theta)
|
276 |
+
sin = math.sin(theta)
|
277 |
+
|
278 |
+
x = (x - xr) * scale
|
279 |
+
y = (y - yr) * scale
|
280 |
+
|
281 |
+
_x = xr + x * cos - y * sin + movSize[0]
|
282 |
+
_y = -(yr + x * sin + y * cos) + movSize[1]
|
283 |
+
|
284 |
+
return _x, _y
|
285 |
+
|
286 |
+
@staticmethod
|
287 |
+
def shift(size, degree):
|
288 |
+
angle = degree * math.pi / 180.0
|
289 |
+
width = size[0]
|
290 |
+
height = size[1]
|
291 |
+
|
292 |
+
alpha = math.cos(angle)
|
293 |
+
beta = math.sin(angle)
|
294 |
+
new_width = int(width * math.fabs(alpha) + height * math.fabs(beta))
|
295 |
+
new_height = int(width * math.fabs(beta) + height * math.fabs(alpha))
|
296 |
+
|
297 |
+
size = [new_width, new_height]
|
298 |
+
return size
|
299 |
+
|
300 |
+
def __call__(self, image, polygons=None, scale=1.0):
|
301 |
+
if np.random.random() <= self.ratio:
|
302 |
+
return image, polygons
|
303 |
+
angle = np.random.normal(loc=0.0, scale=0.5) * self.up # angle 按照高斯分布
|
304 |
+
rows, cols = image.shape[0:2]
|
305 |
+
center = (cols / 2.0, rows / 2.0)
|
306 |
+
newSize = self.shift([cols * scale, rows * scale], angle)
|
307 |
+
movSize = [int((newSize[0] - cols) / 2), int((newSize[1] - rows) / 2)]
|
308 |
+
|
309 |
+
M = cv2.getRotationMatrix2D(center, angle, scale)
|
310 |
+
M[0, 2] += int((newSize[0] - cols) / 2)
|
311 |
+
M[1, 2] += int((newSize[1] - rows) / 2)
|
312 |
+
|
313 |
+
if self.colors:
|
314 |
+
H, W, _ = image.shape
|
315 |
+
mask = np.zeros_like(image)
|
316 |
+
(h_index, w_index) = (np.random.randint(0, H * 7 // 8), np.random.randint(0, W * 7 // 8))
|
317 |
+
img_cut = image[h_index:(h_index + H // 9), w_index:(w_index + W // 9)]
|
318 |
+
img_cut = cv2.resize(img_cut, (newSize[0], newSize[1]))
|
319 |
+
mask = cv2.warpAffine(mask, M, (newSize[0], newSize[1]), borderValue=[1, 1, 1])
|
320 |
+
image = cv2.warpAffine(image, M, (newSize[0], newSize[1]), borderValue=[0,0,0])
|
321 |
+
image=image+img_cut*mask
|
322 |
+
else:
|
323 |
+
color = [0, 0, 0]
|
324 |
+
image = cv2.warpAffine(image, M, (newSize[0], newSize[1]), borderValue=color)
|
325 |
+
|
326 |
+
if polygons is not None:
|
327 |
+
for polygon in polygons:
|
328 |
+
x, y = self.rotate(center, polygon.points, angle,movSize,scale)
|
329 |
+
pts = np.vstack([x, y]).T
|
330 |
+
polygon.points = pts
|
331 |
+
return image, polygons
|
332 |
+
|
333 |
+
|
334 |
+
class SquarePadding(object):
|
335 |
+
|
336 |
+
def __call__(self, image, polygons=None):
|
337 |
+
|
338 |
+
H, W, _ = image.shape
|
339 |
+
|
340 |
+
if H == W:
|
341 |
+
return image, polygons
|
342 |
+
|
343 |
+
padding_size = max(H, W)
|
344 |
+
(h_index, w_index) = (np.random.randint(0, H*7//8),np.random.randint(0, W*7//8))
|
345 |
+
img_cut = image[h_index:(h_index+H//9),w_index:(w_index+W//9)]
|
346 |
+
expand_image = cv2.resize(img_cut,(padding_size, padding_size))
|
347 |
+
#expand_image = np.zeros((padding_size, padding_size, 3), dtype=image.dtype)
|
348 |
+
#expand_image=img_cut[:,:,:]
|
349 |
+
if H > W:
|
350 |
+
y0, x0 = 0, (H - W) // 2
|
351 |
+
else:
|
352 |
+
y0, x0 = (W - H) // 2, 0
|
353 |
+
if polygons is not None:
|
354 |
+
for polygon in polygons:
|
355 |
+
polygon.points += np.array([x0, y0])
|
356 |
+
expand_image[y0:y0+H, x0:x0+W] = image
|
357 |
+
image = expand_image
|
358 |
+
|
359 |
+
return image, polygons
|
360 |
+
|
361 |
+
|
362 |
+
class RandomImgCropPatch(object):
|
363 |
+
def __init__(self, up=30, beta=0.3):
|
364 |
+
self.up = up
|
365 |
+
self.beta=0.3
|
366 |
+
self.scale = 10
|
367 |
+
|
368 |
+
@staticmethod
|
369 |
+
def get_contour_min_area_box(contour):
|
370 |
+
rect = cv2.minAreaRect(contour)
|
371 |
+
box = cv2.boxPoints(rect)
|
372 |
+
box = np.int0(box)
|
373 |
+
return box
|
374 |
+
|
375 |
+
def CropWH(self, image, cut_w, cut_h, polygons=None):
|
376 |
+
h_axis, w_axis, polys = crop_first(image, polygons, scale=self.scale)
|
377 |
+
h, w, _ = image.shape
|
378 |
+
pad_h = h // self.scale
|
379 |
+
pad_w = w // self.scale
|
380 |
+
# TODO try Flip
|
381 |
+
xx = np.random.choice(w_axis, size=2)
|
382 |
+
xmin = np.min(xx) - pad_w
|
383 |
+
xmax = xmin + cut_w
|
384 |
+
yy = np.random.choice(h_axis, size=2)
|
385 |
+
ymin = np.min(yy) - pad_h
|
386 |
+
ymax = ymin + cut_h
|
387 |
+
if polys.shape[0] != 0:
|
388 |
+
poly_axis_in_area = (polys[:, :, 0] >= xmin) & (polys[:, :, 0] <= xmax) \
|
389 |
+
& (polys[:, :, 1] >= ymin) & (polys[:, :, 1] <= ymax)
|
390 |
+
selected_polys = np.where(np.sum(poly_axis_in_area, axis=1) == 4)[0]
|
391 |
+
else:
|
392 |
+
selected_polys = []
|
393 |
+
|
394 |
+
cropped = image[ymin:ymax + 1, xmin:xmax + 1, :]
|
395 |
+
polygons_new = []
|
396 |
+
for idx in selected_polys:
|
397 |
+
polygon = polygons[idx]
|
398 |
+
polygon.points -= np.array([xmin, ymin])
|
399 |
+
polygons_new.append(polygon)
|
400 |
+
image = cropped
|
401 |
+
polygon = polygons_new
|
402 |
+
|
403 |
+
return image, polygon
|
404 |
+
|
405 |
+
def __call__(self, images, polygons_list=None):
|
406 |
+
I_x, I_y = 1024,1024
|
407 |
+
|
408 |
+
w = int(round(I_x * random.beta(self.beta, self.beta)))
|
409 |
+
h = int(round(I_y * random.beta(self.beta, self.beta)))
|
410 |
+
w_ = [w, I_x - w, w, I_x - w]
|
411 |
+
h_ = [h, h, I_y - h, I_y - h]
|
412 |
+
new_img = np.zeros((I_x, I_y, 3), dtype=images[0].dtype)
|
413 |
+
imgs=[]
|
414 |
+
new_polygons=[]
|
415 |
+
for i, im in enumerate(images):
|
416 |
+
img, polygons = self.CropWH(im, w_[i], h_[i], polygons=polygons_list[i])
|
417 |
+
imgs.append(img)
|
418 |
+
new_polygons.append(polygons)
|
419 |
+
new_img[0:w, 0:h, :] = imgs[0]
|
420 |
+
new_img[w:I_x, 0:h, :] = imgs[1]
|
421 |
+
new_img[0:w, h:I_y, :] = imgs[2]
|
422 |
+
new_img[w:I_x, h:I_y, :] = imgs[3]
|
423 |
+
for polygon in new_polygons[1]:
|
424 |
+
polygon.points += np.array([w, 0])
|
425 |
+
for polygon in new_polygons[2]:
|
426 |
+
polygon.points += np.array([0, h])
|
427 |
+
for polygon in new_polygons[3]:
|
428 |
+
polygon.points += np.array([w, h])
|
429 |
+
|
430 |
+
polygons=new_polygons[0]+new_polygons[1]+new_polygons[2]+new_polygons[3]
|
431 |
+
|
432 |
+
return new_img, polygons
|
433 |
+
|
434 |
+
|
435 |
+
class RandomCropFlip(object):
|
436 |
+
|
437 |
+
def __init__(self, min_crop_side_ratio=0.01):
|
438 |
+
self.scale = 10
|
439 |
+
self.ratio = 0.2
|
440 |
+
self.epsilon = 10.0
|
441 |
+
self.min_crop_side_ratio = min_crop_side_ratio
|
442 |
+
|
443 |
+
def __call__(self, image, polygons=None):
|
444 |
+
|
445 |
+
if polygons is None:
|
446 |
+
return image, polygons
|
447 |
+
|
448 |
+
if np.random.random() <= self.ratio:
|
449 |
+
return image, polygons
|
450 |
+
|
451 |
+
# 计算 有效的Crop区域, 方便选取有效的种子点
|
452 |
+
h_axis, w_axis, pp_polys = crop_first(image, polygons, scale =self.scale)
|
453 |
+
if len(h_axis) == 0 or len(w_axis) == 0:
|
454 |
+
return image, polygons
|
455 |
+
|
456 |
+
# TODO try crop
|
457 |
+
attempt = 0
|
458 |
+
h, w, _ = image.shape
|
459 |
+
area = h * w
|
460 |
+
pad_h = h // self.scale
|
461 |
+
pad_w = w // self.scale
|
462 |
+
while attempt < 10:
|
463 |
+
attempt += 1
|
464 |
+
polygons_new = []
|
465 |
+
xx = np.random.choice(w_axis, size=2)
|
466 |
+
xmin = np.min(xx) - pad_w
|
467 |
+
xmax = np.max(xx) - pad_w
|
468 |
+
xmin = np.clip(xmin, 0, w - 1)
|
469 |
+
xmax = np.clip(xmax, 0, w - 1)
|
470 |
+
yy = np.random.choice(h_axis, size=2)
|
471 |
+
ymin = np.min(yy) - pad_h
|
472 |
+
ymax = np.max(yy) - pad_h
|
473 |
+
ymin = np.clip(ymin, 0, h - 1)
|
474 |
+
ymax = np.clip(ymax, 0, h - 1)
|
475 |
+
if (xmax - xmin) * (ymax - ymin) < area * self.min_crop_side_ratio:
|
476 |
+
# area too small
|
477 |
+
continue
|
478 |
+
|
479 |
+
pts = np.stack([[xmin, xmax, xmax, xmin], [ymin, ymin, ymax, ymax]]).T.astype(np.int32)
|
480 |
+
pp = Polygon(pts).buffer(0)
|
481 |
+
Fail_flag = False
|
482 |
+
for polygon in polygons:
|
483 |
+
ppi = Polygon(polygon.points).buffer(0)
|
484 |
+
ppiou = float(ppi.intersection(pp).area)
|
485 |
+
if np.abs(ppiou - float(ppi.area)) > self.epsilon and np.abs(ppiou) > self.epsilon:
|
486 |
+
Fail_flag = True
|
487 |
+
break
|
488 |
+
if np.abs(ppiou - float(ppi.area)) < self.epsilon:
|
489 |
+
polygons_new.append(polygon)
|
490 |
+
|
491 |
+
if Fail_flag:
|
492 |
+
continue
|
493 |
+
else:
|
494 |
+
break
|
495 |
+
|
496 |
+
if len(polygons_new) == 0:
|
497 |
+
cropped = image[ymin:ymax, xmin:xmax, :]
|
498 |
+
select_type = random.randint(3)
|
499 |
+
if select_type == 0:
|
500 |
+
img = np.ascontiguousarray(cropped[:, ::-1])
|
501 |
+
elif select_type == 1:
|
502 |
+
img = np.ascontiguousarray(cropped[::-1, :])
|
503 |
+
else:
|
504 |
+
img = np.ascontiguousarray(cropped[::-1, ::-1])
|
505 |
+
image[ymin:ymax, xmin:xmax, :] = img
|
506 |
+
return image, polygons
|
507 |
+
|
508 |
+
else:
|
509 |
+
cropped = image[ymin:ymax, xmin:xmax, :]
|
510 |
+
height, width, _ = cropped.shape
|
511 |
+
select_type = random.randint(3)
|
512 |
+
if select_type == 0:
|
513 |
+
img = np.ascontiguousarray(cropped[:, ::-1])
|
514 |
+
for polygon in polygons_new:
|
515 |
+
polygon.points[:, 0] = width - polygon.points[:, 0] + 2 * xmin
|
516 |
+
elif select_type == 1:
|
517 |
+
img = np.ascontiguousarray(cropped[::-1, :])
|
518 |
+
for polygon in polygons_new:
|
519 |
+
polygon.points[:, 1] = height - polygon.points[:, 1] + 2 * ymin
|
520 |
+
else:
|
521 |
+
img = np.ascontiguousarray(cropped[::-1, ::-1])
|
522 |
+
for polygon in polygons_new:
|
523 |
+
polygon.points[:, 0] = width - polygon.points[:, 0] + 2 * xmin
|
524 |
+
polygon.points[:, 1] = height - polygon.points[:, 1] + 2 * ymin
|
525 |
+
image[ymin:ymax, xmin:xmax, :] = img
|
526 |
+
|
527 |
+
return image, polygons
|
528 |
+
|
529 |
+
|
530 |
+
class RandomResizedCrop(object):
|
531 |
+
def __init__(self, min_crop_side_ratio=0.1):
|
532 |
+
self.scale = 10
|
533 |
+
self.epsilon = 1e-2
|
534 |
+
self.min_crop_side_ratio = min_crop_side_ratio
|
535 |
+
|
536 |
+
def __call__(self, image, polygons):
|
537 |
+
|
538 |
+
if polygons is None:
|
539 |
+
return image, polygons
|
540 |
+
|
541 |
+
# 计算 有效的Crop区域, 方便选取有效的种子点
|
542 |
+
h_axis, w_axis, pp_polys = crop_first(image, polygons, scale =self.scale)
|
543 |
+
if len(h_axis) == 0 or len(w_axis) == 0:
|
544 |
+
return image, polygons
|
545 |
+
|
546 |
+
# TODO try crop
|
547 |
+
attempt = 0
|
548 |
+
h, w, _ = image.shape
|
549 |
+
area = h * w
|
550 |
+
pad_h = h // self.scale
|
551 |
+
pad_w = w // self.scale
|
552 |
+
while attempt < 10:
|
553 |
+
attempt += 1
|
554 |
+
xx = np.random.choice(w_axis, size=2)
|
555 |
+
xmin = np.min(xx) - pad_w
|
556 |
+
xmax = np.max(xx) - pad_w
|
557 |
+
xmin = np.clip(xmin, 0, w - 1)
|
558 |
+
xmax = np.clip(xmax, 0, w - 1)
|
559 |
+
yy = np.random.choice(h_axis, size=2)
|
560 |
+
ymin = np.min(yy) - pad_h
|
561 |
+
ymax = np.max(yy) - pad_h
|
562 |
+
ymin = np.clip(ymin, 0, h - 1)
|
563 |
+
ymax = np.clip(ymax, 0, h - 1)
|
564 |
+
if (xmax - xmin)*(ymax - ymin) <area*self.min_crop_side_ratio:
|
565 |
+
# area too small
|
566 |
+
continue
|
567 |
+
if pp_polys.shape[0] != 0:
|
568 |
+
poly_axis_in_area = (pp_polys[:, :, 0] >= xmin) & (pp_polys[:, :, 0] <= xmax) \
|
569 |
+
& (pp_polys[:, :, 1] >= ymin) & (pp_polys[:, :, 1] <= ymax)
|
570 |
+
selected_polys = np.where(np.sum(poly_axis_in_area, axis=1) == 4)[0]
|
571 |
+
else:
|
572 |
+
selected_polys = []
|
573 |
+
|
574 |
+
if len(selected_polys) == 0:
|
575 |
+
continue
|
576 |
+
else:
|
577 |
+
pts = np.stack([[xmin, xmax, xmax, xmin], [ymin, ymin, ymax, ymax]]).T.astype(np.int32)
|
578 |
+
pp = Polygon(pts).buffer(0)
|
579 |
+
polygons_new = []
|
580 |
+
Fail_flag = False
|
581 |
+
for polygon in copy.deepcopy(polygons):
|
582 |
+
ppi = Polygon(polygon.points).buffer(0)
|
583 |
+
ppiou = float(ppi.intersection(pp).area)
|
584 |
+
if np.abs(ppiou - float(ppi.area)) > self.epsilon and np.abs(ppiou) > self.epsilon:
|
585 |
+
Fail_flag = True
|
586 |
+
break
|
587 |
+
elif np.abs(ppiou - float(ppi.area)) < self.epsilon:
|
588 |
+
# polygon.points -= np.array([xmin, ymin])
|
589 |
+
polygons_new.append(polygon)
|
590 |
+
|
591 |
+
if Fail_flag:
|
592 |
+
continue
|
593 |
+
else:
|
594 |
+
cropped = image[ymin:ymax + 1, xmin:xmax + 1, :]
|
595 |
+
for polygon in polygons_new:
|
596 |
+
polygon.points -= np.array([xmin, ymin])
|
597 |
+
|
598 |
+
return cropped, polygons_new
|
599 |
+
|
600 |
+
return image, polygons
|
601 |
+
|
602 |
+
|
603 |
+
class RandomResizeScale(object):
|
604 |
+
def __init__(self, size=512, ratio=(3./4, 5./2)):
|
605 |
+
self.size = size
|
606 |
+
self.ratio = ratio
|
607 |
+
|
608 |
+
def __call__(self, image, polygons=None):
|
609 |
+
|
610 |
+
aspect_ratio = np.random.uniform(self.ratio[0], self.ratio[1])
|
611 |
+
h, w, _ = image.shape
|
612 |
+
scales = self.size*1.0/max(h, w)
|
613 |
+
aspect_ratio = scales * aspect_ratio
|
614 |
+
aspect_ratio = int(w * aspect_ratio)*1.0/w
|
615 |
+
image = cv2.resize(image, (int(w * aspect_ratio), int(h*aspect_ratio)))
|
616 |
+
scales = np.array([aspect_ratio, aspect_ratio])
|
617 |
+
if polygons is not None:
|
618 |
+
for polygon in polygons:
|
619 |
+
polygon.points = polygon.points * scales
|
620 |
+
|
621 |
+
return image, polygons
|
622 |
+
|
623 |
+
|
624 |
+
class Resize(object):
|
625 |
+
def __init__(self, size=1024):
|
626 |
+
self.size = size
|
627 |
+
self.SP = SquarePadding()
|
628 |
+
|
629 |
+
def __call__(self, image, polygons=None):
|
630 |
+
h, w, _ = image.shape
|
631 |
+
image = cv2.resize(image, (self.size,
|
632 |
+
self.size))
|
633 |
+
scales = np.array([self.size / w, self.size / h])
|
634 |
+
|
635 |
+
if polygons is not None:
|
636 |
+
for polygon in polygons:
|
637 |
+
polygon.points = polygon.points * scales
|
638 |
+
|
639 |
+
return image, polygons
|
640 |
+
|
641 |
+
|
642 |
+
class ResizeSquare(object):
|
643 |
+
def __init__(self, size=(480, 1280)):
|
644 |
+
self.size = size
|
645 |
+
|
646 |
+
def __call__(self, image, polygons=None):
|
647 |
+
h, w, _ = image.shape
|
648 |
+
img_size_min = min(h, w)
|
649 |
+
img_size_max = max(h, w)
|
650 |
+
|
651 |
+
if img_size_min < self.size[0]:
|
652 |
+
im_scale = float(self.size[0]) / float(img_size_min) # expand min to size[0]
|
653 |
+
if np.ceil(im_scale * img_size_max) > self.size[1]: # expand max can't > size[1]
|
654 |
+
im_scale = float(self.size[1]) / float(img_size_max)
|
655 |
+
elif img_size_max > self.size[1]:
|
656 |
+
im_scale = float(self.size[1]) / float(img_size_max)
|
657 |
+
else:
|
658 |
+
im_scale = 1.0
|
659 |
+
|
660 |
+
new_h = int(int(h * im_scale/32)*32)
|
661 |
+
new_w = int(int(w * im_scale/32)*32)
|
662 |
+
# if new_h*new_w > 1600*1920:
|
663 |
+
# im_scale = 1600 / float(img_size_max)
|
664 |
+
# new_h = int(int(h * im_scale/32)*32)
|
665 |
+
# new_w = int(int(w * im_scale/32)*32)
|
666 |
+
image = cv2.resize(image, (new_w, new_h))
|
667 |
+
scales = np.array([new_w / w, new_h / h])
|
668 |
+
if polygons is not None:
|
669 |
+
for polygon in polygons:
|
670 |
+
polygon.points = polygon.points * scales
|
671 |
+
|
672 |
+
return image, polygons
|
673 |
+
|
674 |
+
|
675 |
+
class ResizeLimitSquare(object):
|
676 |
+
def __init__(self, size=512, ratio=0.6):
|
677 |
+
self.size = size
|
678 |
+
self.ratio = ratio
|
679 |
+
self.SP = SquarePadding()
|
680 |
+
|
681 |
+
def __call__(self, image, polygons=None):
|
682 |
+
if np.random.random() <= self.ratio:
|
683 |
+
image, polygons = self.SP(image, polygons)
|
684 |
+
h, w, _ = image.shape
|
685 |
+
image = cv2.resize(image, (self.size,self.size))
|
686 |
+
scales = np.array([self.size*1.0/ w, self.size*1.0 / h])
|
687 |
+
|
688 |
+
if polygons is not None:
|
689 |
+
for polygon in polygons:
|
690 |
+
polygon.points = polygon.points * scales
|
691 |
+
|
692 |
+
return image, polygons
|
693 |
+
|
694 |
+
|
695 |
+
class RandomResizePadding(object):
|
696 |
+
def __init__(self, size=512, random_scale=np.array([0.75, 1.0, 1.25,1.5,2.0]),stride=32, ratio=0.6667):
|
697 |
+
self.random_scale = random_scale
|
698 |
+
self.size = size
|
699 |
+
self.ratio=ratio
|
700 |
+
self.stride=stride
|
701 |
+
self.SP=SquarePadding()
|
702 |
+
|
703 |
+
###########Random size for different eproches ########################
|
704 |
+
rd_scale = np.random.choice(self.random_scale)
|
705 |
+
step_num = round(np.random.normal(loc=0.0, scale=0.35) * 8) # step 按照高斯分布
|
706 |
+
self.input_size = np.clip(int(self.size * rd_scale + step_num * self.stride),
|
707 |
+
(int(self.size * self.random_scale[0] - self.stride)),
|
708 |
+
int(self.size * self.random_scale[-1] + self.stride))
|
709 |
+
############################ end ########################
|
710 |
+
|
711 |
+
def __call__(self, image, polygons=None):
|
712 |
+
|
713 |
+
if np.random.random() <= self.ratio:
|
714 |
+
image, polygons = self.SP(image, polygons)
|
715 |
+
h, w, _ = image.shape
|
716 |
+
image = cv2.resize(image, (self.input_size,self.input_size))
|
717 |
+
scales = np.array([self.input_size*1.0/ w, self.input_size*1.0 / h])
|
718 |
+
|
719 |
+
if polygons is not None:
|
720 |
+
for polygon in polygons:
|
721 |
+
polygon.points = polygon.points * scales
|
722 |
+
|
723 |
+
return image, polygons
|
724 |
+
|
725 |
+
transform_type_dict = dict(
|
726 |
+
brightness=ImageEnhance.Brightness, contrast=ImageEnhance.Contrast,
|
727 |
+
sharpness=ImageEnhance.Sharpness, color=ImageEnhance.Color
|
728 |
+
)
|
729 |
+
|
730 |
+
|
731 |
+
class RandomDistortion(object):
|
732 |
+
def __init__(self, transform_dict, prob=0.5):
|
733 |
+
self.transforms = [(transform_type_dict[k], transform_dict[k]) for k in transform_dict]
|
734 |
+
self.prob = prob
|
735 |
+
|
736 |
+
def __call__(self, img, target):
|
737 |
+
if random.random() > self.prob:
|
738 |
+
return img, target
|
739 |
+
out = Image.fromarray(img)
|
740 |
+
rand_num = np.random.uniform(0, 1, len(self.transforms))
|
741 |
+
|
742 |
+
for i, (transformer, alpha) in enumerate(self.transforms):
|
743 |
+
r = alpha * (rand_num[i] * 2.0 - 1.0) + 1 # r in [1-alpha, 1+alpha)
|
744 |
+
out = transformer(out).enhance(r)
|
745 |
+
|
746 |
+
return np.array(out), target
|
747 |
+
|
748 |
+
|
749 |
+
class Augmentation(object):
|
750 |
+
def __init__(self, size, mean, std):
|
751 |
+
self.size = size
|
752 |
+
self.mean = mean
|
753 |
+
self.std = std
|
754 |
+
self._transform_dict = {'brightness': 0.5, 'contrast': 0.5, 'sharpness': 0.8386, 'color': 0.5}
|
755 |
+
self.augmentation = Compose([
|
756 |
+
RandomCropFlip(),
|
757 |
+
RandomResizeScale(size=self.size, ratio=(3. / 8, 5. / 2)),
|
758 |
+
RandomResizedCrop(),
|
759 |
+
RotatePadding(up=60, colors=True), # pretrain on Syn is "up=30", else is "up=60"
|
760 |
+
ResizeLimitSquare(size=self.size),
|
761 |
+
RandomMirror(),
|
762 |
+
RandomDistortion(self._transform_dict),
|
763 |
+
Normalize(mean=self.mean, std=self.std),
|
764 |
+
])
|
765 |
+
|
766 |
+
def __call__(self, image, polygons=None):
|
767 |
+
return self.augmentation(image, polygons)
|
768 |
+
|
769 |
+
|
770 |
+
class BaseTransform(object):
|
771 |
+
def __init__(self, size, mean, std):
|
772 |
+
self.size = size
|
773 |
+
self.mean = mean
|
774 |
+
self.std = std
|
775 |
+
self.augmentation = Compose([
|
776 |
+
# Resize(size=640),
|
777 |
+
ResizeSquare(size=self.size),
|
778 |
+
Normalize(mean, std)
|
779 |
+
])
|
780 |
+
|
781 |
+
def __call__(self, image, polygons=None):
|
782 |
+
return self.augmentation(image, polygons)
|
783 |
+
|
784 |
+
|
785 |
+
class BaseTransformNresize(object):
|
786 |
+
def __init__(self, mean, std):
|
787 |
+
self.mean = mean
|
788 |
+
self.std = std
|
789 |
+
self.augmentation = Compose([
|
790 |
+
Normalize(mean, std)
|
791 |
+
])
|
792 |
+
|
793 |
+
def __call__(self, image, polygons=None):
|
794 |
+
return self.augmentation(image, polygons)
|
IndicPhotoOCR/detection/textbpn/util/canvas.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
__author__ = '古溪'
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import random
|
7 |
+
import matplotlib.pyplot as plt
|
8 |
+
|
9 |
+
|
10 |
+
def heatmap(im_gray):
|
11 |
+
cmap = plt.get_cmap('jet')
|
12 |
+
rgba_img = cmap(255 - im_gray)
|
13 |
+
Hmap = np.delete(rgba_img, 3, 2)
|
14 |
+
# print(Hmap.shape, Hmap.max(), Hmap.min())
|
15 |
+
# cv2.imshow("heat_img", Hmap)
|
16 |
+
# cv2.waitKey(0)
|
17 |
+
return Hmap
|
18 |
+
|
19 |
+
|
20 |
+
def loss_ploy(loss_list, steps, period, name=""):
|
21 |
+
fig1, ax1 = plt.subplots(figsize=(16, 9))
|
22 |
+
ax1.plot(range(steps // period), loss_list)
|
23 |
+
ax1.set_title("Average loss vs step*{}".format(period))
|
24 |
+
ax1.set_xlabel("step*{}".format(period))
|
25 |
+
ax1.set_ylabel("Current loss")
|
26 |
+
plt.savefig('{}@loss_vs_step*{}.png'.format(name,period))
|
27 |
+
plt.clf()
|
28 |
+
|
29 |
+
|
30 |
+
def plt_ploys(ploys, period, name=""):
|
31 |
+
fig1, ax1 = plt.subplots(figsize=(16, 9))
|
32 |
+
cnames = ['aliceblue','antiquewhite','aqua','aquamarine','azure',
|
33 |
+
'blanchedalmond','blue','blueviolet','brown','burlywood',
|
34 |
+
'coral','cornflowerblue','cornsilk','crimson','cyan',
|
35 |
+
'darkblue','deeppink','deepskyblue','dodgerblue','forestgreen',
|
36 |
+
'gold','goldenrod','green','greenyellow','honeydew','hotpink',
|
37 |
+
'lawngreen','lightblue','lightgreen','lightpink','lightsalmon',
|
38 |
+
'lightseagreen','lightsteelblue','lightyellow','lime','limegreen',
|
39 |
+
'mediumseagreen','mediumspringgreen','midnightblue','orange','orangered',
|
40 |
+
'pink','red','royalblue','seagreen','skyblue','springgreen','steelblue',
|
41 |
+
'tan','teal','thistle','yellow','yellowgreen']
|
42 |
+
|
43 |
+
color = random.sample(cnames, len(ploys.keys()))
|
44 |
+
for ii, key in enumerate(ploys.keys()):
|
45 |
+
ax1.plot(range(1, len(ploys[key])+1), ploys[key],color=color[ii], label=key)
|
46 |
+
ax1.set_title("Loss Carve line")
|
47 |
+
ax1.set_xlabel("step*{}".format(period))
|
48 |
+
ax1.set_ylabel("Current loss")
|
49 |
+
plt.legend(ploys.keys())
|
50 |
+
plt.savefig('{}@loss_vs_step*{}.png'.format(name, period))
|
51 |
+
plt.clf()
|
52 |
+
|
53 |
+
if __name__ == '__main__':
|
54 |
+
# TODO ADD CODE
|
55 |
+
pass
|
IndicPhotoOCR/detection/textbpn/util/detection.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# c++ version pse based on opencv 3+
|
2 |
+
from pse import decode as pse_decode
|
3 |
+
from cfglib.config import config as cfg
|
4 |
+
|
5 |
+
|
6 |
+
class TextDetector(object):
|
7 |
+
|
8 |
+
def __init__(self, model):
|
9 |
+
# evaluation mode
|
10 |
+
self.model = model
|
11 |
+
model.eval()
|
12 |
+
# parameter
|
13 |
+
self.scale = cfg.scale
|
14 |
+
self.threshold = cfg.threshold
|
15 |
+
|
16 |
+
def detect(self, image, img_show):
|
17 |
+
# get model output
|
18 |
+
preds = self.model.forward(image)
|
19 |
+
preds, boxes, contours = pse_decode(preds[0], self.scale, self.threshold)
|
20 |
+
|
21 |
+
output = {
|
22 |
+
'image': image,
|
23 |
+
'tr': preds,
|
24 |
+
'bbox': boxes
|
25 |
+
}
|
26 |
+
return contours, output
|
27 |
+
|
28 |
+
|
29 |
+
|
30 |
+
|
31 |
+
|
32 |
+
|
33 |
+
|
34 |
+
|
35 |
+
|
36 |
+
|
37 |
+
|
38 |
+
|
39 |
+
|
40 |
+
|
41 |
+
|
42 |
+
|
43 |
+
|
44 |
+
|
45 |
+
|
46 |
+
|
47 |
+
|
48 |
+
|
IndicPhotoOCR/detection/textbpn/util/eval.py
ADDED
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import numpy as np
|
4 |
+
import subprocess
|
5 |
+
from cfglib.config import config as cfg
|
6 |
+
from util.misc import mkdirs
|
7 |
+
|
8 |
+
|
9 |
+
def osmkdir(out_dir):
|
10 |
+
import shutil
|
11 |
+
if os.path.exists(out_dir):
|
12 |
+
shutil.rmtree(out_dir)
|
13 |
+
os.makedirs(out_dir)
|
14 |
+
|
15 |
+
|
16 |
+
def analysize_result(source_dir, fid_path, outpt_dir, name):
|
17 |
+
|
18 |
+
bad_txt = open("{}/eval.txt".format(outpt_dir), 'w')
|
19 |
+
all_eval = open("{}/{}/{}_eval.txt".format(cfg.output_dir, "Analysis", name), 'a+')
|
20 |
+
sel_list = list()
|
21 |
+
with open(fid_path) as f:
|
22 |
+
lines = f.read().split("\n")
|
23 |
+
for line in lines:
|
24 |
+
line_items = line.split(" ")
|
25 |
+
id = line_items[0]
|
26 |
+
precision = float(line_items[2].split('=')[-1])
|
27 |
+
recall = float(line_items[4].split('=')[-1])
|
28 |
+
if id != "ALL" and (precision < 0.5 or recall < 0.5):
|
29 |
+
img_path = os.path.join(source_dir, line_items[0].replace(".txt", ".jpg"))
|
30 |
+
if os.path.exists(img_path):
|
31 |
+
os.system('cp {} {}'.format(img_path, outpt_dir))
|
32 |
+
sel_list.append((int(id.replace(".txt", "").replace("img", "").replace("_", "")), line))
|
33 |
+
if id == "ALL":
|
34 |
+
all_eval.write("{} {} {}\n".format(
|
35 |
+
outpt_dir.split('/')[-1],
|
36 |
+
"{}/{}".format(cfg.dis_threshold, cfg.cls_threshold),
|
37 |
+
line))
|
38 |
+
sel_list = sorted(sel_list, key=lambda its: its[0])
|
39 |
+
bad_txt.write('\n'.join([its[1] for its in sel_list]))
|
40 |
+
all_eval.close()
|
41 |
+
bad_txt.close()
|
42 |
+
|
43 |
+
|
44 |
+
def deal_eval_total_text(debug=False):
|
45 |
+
# compute DetEval
|
46 |
+
eval_dir = os.path.join(cfg.output_dir, "Analysis", "output_eval")
|
47 |
+
if not os.path.exists(eval_dir):
|
48 |
+
os.makedirs(eval_dir)
|
49 |
+
|
50 |
+
print('Computing DetEval in {}/{}'.format(cfg.output_dir, cfg.exp_name))
|
51 |
+
subprocess.call(
|
52 |
+
['python', 'dataset/total_text/Evaluation_Protocol/Python_scripts/Deteval.py', cfg.exp_name, '--tr', '0.7',
|
53 |
+
'--tp', '0.6'])
|
54 |
+
subprocess.call(
|
55 |
+
['python', 'dataset/total_text/Evaluation_Protocol/Python_scripts/Deteval.py', cfg.exp_name, '--tr', '0.8',
|
56 |
+
'--tp', '0.4'])
|
57 |
+
|
58 |
+
if debug:
|
59 |
+
source_dir = os.path.join(cfg.vis_dir, '{}_test'.format(cfg.exp_name))
|
60 |
+
outpt_dir_base = os.path.join(cfg.output_dir, "Analysis", "eval_view", "total_text")
|
61 |
+
if not os.path.exists(outpt_dir_base):
|
62 |
+
mkdirs(outpt_dir_base)
|
63 |
+
|
64 |
+
outpt_dir1 = os.path.join(outpt_dir_base, "{}_{}_{}_{}_{}"
|
65 |
+
.format(cfg.test_size[0], cfg.test_size[1], cfg.checkepoch, 0.7, 0.6))
|
66 |
+
osmkdir(outpt_dir1)
|
67 |
+
fid_path1 = '{}/Eval_TotalText_{}_{}.txt'.format(eval_dir, 0.7, 0.6)
|
68 |
+
|
69 |
+
analysize_result(source_dir, fid_path1, outpt_dir1, "totalText")
|
70 |
+
|
71 |
+
outpt_dir2 = os.path.join(outpt_dir_base, "{}_{}_{}_{}_{}"
|
72 |
+
.format(cfg.test_size[0], cfg.test_size[1], cfg.checkepoch, 0.8, 0.4))
|
73 |
+
osmkdir(outpt_dir2)
|
74 |
+
fid_path2 = '{}/Eval_TotalText_{}_{}.txt'.format(eval_dir, 0.8, 0.4)
|
75 |
+
|
76 |
+
analysize_result(source_dir, fid_path2, outpt_dir2, "totalText")
|
77 |
+
|
78 |
+
print('End.')
|
79 |
+
|
80 |
+
|
81 |
+
def deal_eval_ctw1500(debug=False):
|
82 |
+
# compute DetEval
|
83 |
+
eval_dir = os.path.join(cfg.output_dir, "Analysis", "output_eval")
|
84 |
+
if not os.path.exists(eval_dir):
|
85 |
+
os.makedirs(eval_dir)
|
86 |
+
|
87 |
+
print('Computing DetEval in {}/{}'.format(cfg.output_dir, cfg.exp_name))
|
88 |
+
subprocess.call(['python', 'dataset/ctw1500/Evaluation_Protocol/ctw1500_eval.py', cfg.exp_name])
|
89 |
+
|
90 |
+
if debug:
|
91 |
+
source_dir = os.path.join(cfg.vis_dir, '{}_test'.format(cfg.exp_name))
|
92 |
+
outpt_dir_base = os.path.join(cfg.output_dir, "Analysis", "eval_view", "ctw1500")
|
93 |
+
if not os.path.exists(outpt_dir_base):
|
94 |
+
mkdirs(outpt_dir_base)
|
95 |
+
|
96 |
+
outpt_dir = os.path.join(outpt_dir_base, "{}_{}_{}".format(cfg.test_size[0], cfg.test_size[1], cfg.checkepoch))
|
97 |
+
osmkdir(outpt_dir)
|
98 |
+
fid_path1 = '{}/Eval_ctw1500_{}.txt'.format(eval_dir, 0.5)
|
99 |
+
|
100 |
+
analysize_result(source_dir, fid_path1, outpt_dir, "ctw1500")
|
101 |
+
|
102 |
+
print('End.')
|
103 |
+
|
104 |
+
|
105 |
+
def deal_eval_icdar15(debug=False):
|
106 |
+
# compute DetEval
|
107 |
+
eval_dir = os.path.join(cfg.output_dir, "Analysis", "output_eval")
|
108 |
+
if not os.path.exists(eval_dir):
|
109 |
+
os.makedirs(eval_dir)
|
110 |
+
|
111 |
+
input_dir = 'output/{}'.format(cfg.exp_name)
|
112 |
+
father_path = os.path.abspath(input_dir)
|
113 |
+
print(father_path)
|
114 |
+
print('Computing DetEval in {}/{}'.format(cfg.output_dir, cfg.exp_name))
|
115 |
+
subprocess.call(['sh', 'dataset/icdar15/eval.sh', father_path])
|
116 |
+
|
117 |
+
if debug:
|
118 |
+
source_dir = os.path.join(cfg.vis_dir, '{}_test'.format(cfg.exp_name))
|
119 |
+
outpt_dir_base = os.path.join(cfg.output_dir, "Analysis", "eval_view", "icdar15")
|
120 |
+
if not os.path.exists(outpt_dir_base):
|
121 |
+
mkdirs(outpt_dir_base)
|
122 |
+
|
123 |
+
outpt_dir = os.path.join(outpt_dir_base, "{}_{}_{}".format(cfg.test_size[0], cfg.test_size[1], cfg.checkepoch))
|
124 |
+
osmkdir(outpt_dir)
|
125 |
+
fid_path1 = '{}/Eval_icdar15.txt'.format(eval_dir)
|
126 |
+
|
127 |
+
analysize_result(source_dir, fid_path1, outpt_dir, "icdar15")
|
128 |
+
|
129 |
+
print('End.')
|
130 |
+
|
131 |
+
pass
|
132 |
+
|
133 |
+
|
134 |
+
def deal_eval_TD500(debug=False):
|
135 |
+
# compute DetEval
|
136 |
+
eval_dir = os.path.join(cfg.output_dir, "Analysis", "output_eval")
|
137 |
+
if not os.path.exists(eval_dir):
|
138 |
+
os.makedirs(eval_dir)
|
139 |
+
|
140 |
+
input_dir = 'output/{}'.format(cfg.exp_name)
|
141 |
+
father_path = os.path.abspath(input_dir)
|
142 |
+
print(father_path)
|
143 |
+
print('Computing DetEval in {}/{}'.format(cfg.output_dir, cfg.exp_name))
|
144 |
+
subprocess.call(['sh', 'dataset/TD500/eval.sh', father_path])
|
145 |
+
|
146 |
+
if debug:
|
147 |
+
source_dir = os.path.join(cfg.vis_dir, '{}_test'.format(cfg.exp_name))
|
148 |
+
outpt_dir_base = os.path.join(cfg.output_dir, "Analysis", "eval_view", "TD500")
|
149 |
+
if not os.path.exists(outpt_dir_base):
|
150 |
+
mkdirs(outpt_dir_base)
|
151 |
+
|
152 |
+
outpt_dir = os.path.join(outpt_dir_base, "{}_{}_{}".format(cfg.test_size[0], cfg.test_size[1], cfg.checkepoch))
|
153 |
+
osmkdir(outpt_dir)
|
154 |
+
fid_path1 = '{}/Eval_TD500.txt'.format(eval_dir)
|
155 |
+
|
156 |
+
analysize_result(source_dir, fid_path1, outpt_dir, "TD500")
|
157 |
+
|
158 |
+
print('End.')
|
159 |
+
|
160 |
+
|
161 |
+
def data_transfer_ICDAR(contours):
|
162 |
+
cnts = list()
|
163 |
+
for cont in contours:
|
164 |
+
rect = cv2.minAreaRect(cont)
|
165 |
+
if min(rect[1][0], rect[1][1]) <= 5:
|
166 |
+
continue
|
167 |
+
points = cv2.boxPoints(rect)
|
168 |
+
points = np.int0(points)
|
169 |
+
# print(points.shape)
|
170 |
+
# points = np.reshape(points, (4, 2))
|
171 |
+
cnts.append(points)
|
172 |
+
return cnts
|
173 |
+
|
174 |
+
|
175 |
+
def data_transfer_TD500(contours, res_file, img=None):
|
176 |
+
with open(res_file, 'w') as f:
|
177 |
+
for cont in contours:
|
178 |
+
rect = cv2.minAreaRect(cont)
|
179 |
+
if min(rect[1][0], rect[1][1]) <= 5:
|
180 |
+
continue
|
181 |
+
points = cv2.boxPoints(rect)
|
182 |
+
box = np.int0(points)
|
183 |
+
cv2.drawContours(img, [box], 0, (0, 255, 0), 3)
|
184 |
+
|
185 |
+
cx, cy = rect[0]
|
186 |
+
w_, h_ = rect[1]
|
187 |
+
angle = rect[2]
|
188 |
+
mid_ = 0
|
189 |
+
if angle > 45:
|
190 |
+
angle = 90 - angle
|
191 |
+
mid_ = w_;
|
192 |
+
w_ = h_;
|
193 |
+
h_ = mid_
|
194 |
+
elif angle < -45:
|
195 |
+
angle = 90 + angle
|
196 |
+
mid_ = w_;
|
197 |
+
w_ = h_;
|
198 |
+
h_ = mid_
|
199 |
+
angle = angle / 180 * 3.141592653589
|
200 |
+
|
201 |
+
x_min = int(cx - w_ / 2)
|
202 |
+
x_max = int(cx + w_ / 2)
|
203 |
+
y_min = int(cy - h_ / 2)
|
204 |
+
y_max = int(cy + h_ / 2)
|
205 |
+
f.write('{},{},{},{},{}\r\n'.format(x_min, y_min, x_max, y_max, angle))
|
206 |
+
|
207 |
+
return img
|
208 |
+
|
209 |
+
|
210 |
+
def data_transfer_MLT2017(contours, res_file):
|
211 |
+
with open(res_file, 'w') as f:
|
212 |
+
for cont in contours:
|
213 |
+
rect = cv2.minAreaRect(cont)
|
214 |
+
if min(rect[1][0], rect[1][1]) <= 5:
|
215 |
+
continue
|
216 |
+
ploy_area = cv2.contourArea(cont)
|
217 |
+
rect_area = rect[1][0]*rect[1][1]
|
218 |
+
solidity = ploy_area/rect_area
|
219 |
+
width = rect[1][0] - np.clip(rect[1][0] * (1-np.sqrt(solidity)), 0, 6)
|
220 |
+
height = rect[1][1] - np.clip(rect[1][1] * (1-np.sqrt(solidity)), 0, 4)
|
221 |
+
points = cv2.boxPoints((rect[0], (width, height), rect[2]))
|
222 |
+
points = np.int0(points)
|
223 |
+
p = np.reshape(points, -1)
|
224 |
+
f.write('{},{},{},{},{},{},{},{},{}\r\n'
|
225 |
+
.format(p[0], p[1], p[2], p[3], p[4], p[5], p[6], p[7], 1))
|
226 |
+
|
227 |
+
|
228 |
+
|
IndicPhotoOCR/detection/textbpn/util/graph.py
ADDED
@@ -0,0 +1,309 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import print_function
|
2 |
+
from __future__ import division
|
3 |
+
from __future__ import absolute_import
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import time
|
7 |
+
from util.misc import norm2
|
8 |
+
|
9 |
+
class Data(object):
|
10 |
+
def __init__(self, name):
|
11 |
+
self.__name = name
|
12 |
+
self.__links = set()
|
13 |
+
|
14 |
+
@property
|
15 |
+
def name(self):
|
16 |
+
return self.__name
|
17 |
+
|
18 |
+
@property
|
19 |
+
def links(self):
|
20 |
+
return set(self.__links)
|
21 |
+
|
22 |
+
def add_link(self, other, score):
|
23 |
+
self.__links.add(other)
|
24 |
+
other.__links.add(self)
|
25 |
+
|
26 |
+
|
27 |
+
def connected_components(nodes, score_dict, th):
|
28 |
+
'''
|
29 |
+
conventional connected components searching
|
30 |
+
'''
|
31 |
+
result = []
|
32 |
+
nodes = set(nodes)
|
33 |
+
while nodes:
|
34 |
+
n = nodes.pop()
|
35 |
+
group = {n}
|
36 |
+
queue = [n]
|
37 |
+
while queue:
|
38 |
+
n = queue.pop(0)
|
39 |
+
if th is not None:
|
40 |
+
neighbors = {l for l in n.links if score_dict[tuple(sorted([n.name, l.name]))] >= th}
|
41 |
+
else:
|
42 |
+
neighbors = n.links
|
43 |
+
neighbors.difference_update(group)
|
44 |
+
nodes.difference_update(neighbors)
|
45 |
+
group.update(neighbors)
|
46 |
+
queue.extend(neighbors)
|
47 |
+
result.append(group)
|
48 |
+
return result
|
49 |
+
|
50 |
+
|
51 |
+
def connected_components_constraint(nodes, max_sz, score_dict=None, th=None):
|
52 |
+
'''
|
53 |
+
only use edges whose scores are above `th`
|
54 |
+
if a component is larger than `max_sz`, all the nodes in this component are added into `remain` and returned for next iteration.
|
55 |
+
'''
|
56 |
+
result = []
|
57 |
+
remain = set()
|
58 |
+
nodes = set(nodes)
|
59 |
+
while nodes:
|
60 |
+
n = nodes.pop()
|
61 |
+
group = {n}
|
62 |
+
queue = [n]
|
63 |
+
valid = True
|
64 |
+
while queue:
|
65 |
+
n = queue.pop(0)
|
66 |
+
if th is not None:
|
67 |
+
neighbors = {l for l in n.links if score_dict[tuple(sorted([n.name, l.name]))] >= th}
|
68 |
+
else:
|
69 |
+
neighbors = n.links
|
70 |
+
neighbors.difference_update(group)
|
71 |
+
nodes.difference_update(neighbors)
|
72 |
+
group.update(neighbors)
|
73 |
+
queue.extend(neighbors)
|
74 |
+
if len(group) > max_sz or len(remain.intersection(neighbors)) > 0:
|
75 |
+
# if this group is larger than `max_sz`, add the nodes into `remain`
|
76 |
+
valid = False
|
77 |
+
remain.update(group)
|
78 |
+
break
|
79 |
+
if valid: # if this group is smaller than or equal to `max_sz`, finalize it.
|
80 |
+
result.append(group)
|
81 |
+
return result, remain
|
82 |
+
|
83 |
+
|
84 |
+
def graph_propagation_naive(edges, score, th, bboxs=None, dis_thresh=50, pool='avg'):
|
85 |
+
|
86 |
+
edges = np.sort(edges, axis=1)
|
87 |
+
|
88 |
+
score_dict = {} # score lookup table
|
89 |
+
if pool is None:
|
90 |
+
for i, e in enumerate(edges):
|
91 |
+
score_dict[e[0], e[1]] = score[i]
|
92 |
+
elif pool == 'avg':
|
93 |
+
for i, e in enumerate(edges):
|
94 |
+
if bboxs is not None:
|
95 |
+
box1 = bboxs[e[0]][:8].reshape(4, 2)
|
96 |
+
box2 = bboxs[e[1]][:8].reshape(4, 2)
|
97 |
+
c1 = np.mean(box1, 0); c2 = np.mean(box2, 0)
|
98 |
+
dst = norm2(c1 - c2)
|
99 |
+
if dst > dis_thresh:
|
100 |
+
score[i] = 0
|
101 |
+
if (e[0], e[1]) in score_dict:
|
102 |
+
score_dict[e[0], e[1]] = 0.5 * (score_dict[e[0], e[1]] + score[i])
|
103 |
+
else:
|
104 |
+
score_dict[e[0], e[1]] = score[i]
|
105 |
+
|
106 |
+
elif pool == 'max':
|
107 |
+
for i, e in enumerate(edges):
|
108 |
+
if (e[0], e[1]) in score_dict:
|
109 |
+
score_dict[e[0], e[1]] = max(score_dict[e[0], e[1]], score[i])
|
110 |
+
else:
|
111 |
+
score_dict[e[0], e[1]] = score[i]
|
112 |
+
else:
|
113 |
+
raise ValueError('Pooling operation not supported')
|
114 |
+
|
115 |
+
nodes = np.sort(np.unique(edges.flatten()))
|
116 |
+
mapping = -1 * np.ones((nodes.max()+1), dtype=np.int)
|
117 |
+
mapping[nodes] = np.arange(nodes.shape[0])
|
118 |
+
link_idx = mapping[edges]
|
119 |
+
vertex = [Data(n) for n in nodes]
|
120 |
+
for l, s in zip(link_idx, score):
|
121 |
+
vertex[l[0]].add_link(vertex[l[1]], s)
|
122 |
+
|
123 |
+
# first iteration
|
124 |
+
comps = connected_components(vertex, score_dict,th)
|
125 |
+
|
126 |
+
return comps
|
127 |
+
|
128 |
+
|
129 |
+
def graph_search(edges, scores, edges_num, th=None):
|
130 |
+
# graph search
|
131 |
+
scores = scores.reshape((-1, edges_num))
|
132 |
+
select_index = np.argsort(scores, axis=1)[:, -2:]
|
133 |
+
edges = np.sort(edges, axis=1).reshape((-1, edges_num, 2))
|
134 |
+
|
135 |
+
score_dict = {}
|
136 |
+
for i, ips in enumerate(select_index):
|
137 |
+
edg = edges[i]
|
138 |
+
si = scores[i]
|
139 |
+
for j, idx in enumerate(ips):
|
140 |
+
e = edg[idx, :]
|
141 |
+
if (e[0], e[1]) in score_dict:
|
142 |
+
score_dict[e[0], e[1]] = 0.5 * (score_dict[e[0], e[1]] + si[j])
|
143 |
+
else:
|
144 |
+
score_dict[e[0], e[1]] = si[j]
|
145 |
+
|
146 |
+
nodes = np.sort(np.unique(edges.flatten()))
|
147 |
+
vertex = [Data(n) for n in nodes]
|
148 |
+
for (key, value) in score_dict.items():
|
149 |
+
vertex[key[0]].add_link(vertex[key[1]], value)
|
150 |
+
|
151 |
+
comps = connected_components(vertex, score_dict, th)
|
152 |
+
|
153 |
+
return comps
|
154 |
+
|
155 |
+
|
156 |
+
def graph_propagation(edges, score, max_sz, step=0.1, beg_th=0.5, pool=None):
|
157 |
+
|
158 |
+
edges = np.sort(edges, axis=1)
|
159 |
+
th = score.min()
|
160 |
+
# th = beg_th
|
161 |
+
# construct graph
|
162 |
+
score_dict = {} # score lookup table
|
163 |
+
if pool is None:
|
164 |
+
for i,e in enumerate(edges):
|
165 |
+
score_dict[e[0], e[1]] = score[i]
|
166 |
+
elif pool == 'avg':
|
167 |
+
for i,e in enumerate(edges):
|
168 |
+
if (e[0], e[1]) in score_dict:
|
169 |
+
score_dict[e[0], e[1]] = 0.5*(score_dict[e[0], e[1]] + score[i])
|
170 |
+
else:
|
171 |
+
score_dict[e[0], e[1]] = score[i]
|
172 |
+
|
173 |
+
elif pool == 'max':
|
174 |
+
for i,e in enumerate(edges):
|
175 |
+
if (e[0],e[1]) in score_dict:
|
176 |
+
score_dict[e[0], e[1]] = max(score_dict[e[0], e[1]] , score[i])
|
177 |
+
else:
|
178 |
+
score_dict[e[0], e[1]] = score[i]
|
179 |
+
else:
|
180 |
+
raise ValueError('Pooling operation not supported')
|
181 |
+
|
182 |
+
nodes = np.sort(np.unique(edges.flatten()))
|
183 |
+
mapping = -1 * np.ones((nodes.max()+1), dtype=np.int)
|
184 |
+
mapping[nodes] = np.arange(nodes.shape[0])
|
185 |
+
link_idx = mapping[edges]
|
186 |
+
vertex = [Data(n) for n in nodes]
|
187 |
+
for l, s in zip(link_idx, score):
|
188 |
+
vertex[l[0]].add_link(vertex[l[1]], s)
|
189 |
+
|
190 |
+
# first iteration
|
191 |
+
comps, remain = connected_components_constraint(vertex, max_sz)
|
192 |
+
|
193 |
+
# iteration
|
194 |
+
components = comps[:]
|
195 |
+
while remain:
|
196 |
+
th = th + (1 - th) * step
|
197 |
+
comps, remain = connected_components_constraint(remain, max_sz, score_dict, th)
|
198 |
+
components.extend(comps)
|
199 |
+
return components
|
200 |
+
|
201 |
+
|
202 |
+
def graph_propagation_soft(edges, score, max_sz, step=0.1, **kwargs):
|
203 |
+
|
204 |
+
edges = np.sort(edges, axis=1)
|
205 |
+
th = score.min()
|
206 |
+
|
207 |
+
# construct graph
|
208 |
+
score_dict = {} # score lookup table
|
209 |
+
for i,e in enumerate(edges):
|
210 |
+
score_dict[e[0], e[1]] = score[i]
|
211 |
+
|
212 |
+
nodes = np.sort(np.unique(edges.flatten()))
|
213 |
+
mapping = -1 * np.ones((nodes.max()+1), dtype=np.int)
|
214 |
+
mapping[nodes] = np.arange(nodes.shape[0])
|
215 |
+
link_idx = mapping[edges]
|
216 |
+
vertex = [Data(n) for n in nodes]
|
217 |
+
for l, s in zip(link_idx, score):
|
218 |
+
vertex[l[0]].add_link(vertex[l[1]], s)
|
219 |
+
|
220 |
+
# first iteration
|
221 |
+
comps, remain = connected_components_constraint(vertex, max_sz)
|
222 |
+
first_vertex_idx = np.array([mapping[n.name] for c in comps for n in c])
|
223 |
+
fusion_vertex_idx = np.setdiff1d(np.arange(nodes.shape[0]), first_vertex_idx, assume_unique=True)
|
224 |
+
# iteration
|
225 |
+
components = comps[:]
|
226 |
+
while remain:
|
227 |
+
th = th + (1 - th) * step
|
228 |
+
comps, remain = connected_components_constraint(remain, max_sz, score_dict, th)
|
229 |
+
components.extend(comps)
|
230 |
+
label_dict = {}
|
231 |
+
for i,c in enumerate(components):
|
232 |
+
for n in c:
|
233 |
+
label_dict[n.name] = i
|
234 |
+
print('Propagation ...')
|
235 |
+
prop_vertex = [vertex[idx] for idx in fusion_vertex_idx]
|
236 |
+
label, label_fusion = diffusion(prop_vertex, label_dict, score_dict, **kwargs)
|
237 |
+
return label, label_fusion
|
238 |
+
|
239 |
+
|
240 |
+
def diffusion(vertex, label, score_dict, max_depth=5, weight_decay=0.6, normalize=True):
|
241 |
+
class BFSNode():
|
242 |
+
def __init__(self, node, depth, value):
|
243 |
+
self.node = node
|
244 |
+
self.depth = depth
|
245 |
+
self.value = value
|
246 |
+
|
247 |
+
label_fusion = {}
|
248 |
+
for name in label.keys():
|
249 |
+
label_fusion[name] = {label[name]: 1.0}
|
250 |
+
prog = 0
|
251 |
+
prog_step = len(vertex) // 20
|
252 |
+
start = time.time()
|
253 |
+
for root in vertex:
|
254 |
+
if prog % prog_step == 0:
|
255 |
+
print("progress: {} / {}, elapsed time: {}".format(prog, len(vertex), time.time() - start))
|
256 |
+
prog += 1
|
257 |
+
#queue = {[root, 0, 1.0]}
|
258 |
+
queue = {BFSNode(root, 0, 1.0)}
|
259 |
+
visited = [root.name]
|
260 |
+
root_label = label[root.name]
|
261 |
+
while queue:
|
262 |
+
curr = queue.pop()
|
263 |
+
if curr.depth >= max_depth: # pruning
|
264 |
+
continue
|
265 |
+
neighbors = curr.node.links
|
266 |
+
tmp_value = []
|
267 |
+
tmp_neighbor = []
|
268 |
+
for n in neighbors:
|
269 |
+
if n.name not in visited:
|
270 |
+
sub_value = score_dict[tuple(sorted([curr.node.name, n.name]))] * weight_decay * curr.value
|
271 |
+
tmp_value.append(sub_value)
|
272 |
+
tmp_neighbor.append(n)
|
273 |
+
if root_label not in label_fusion[n.name].keys():
|
274 |
+
label_fusion[n.name][root_label] = sub_value
|
275 |
+
else:
|
276 |
+
label_fusion[n.name][root_label] += sub_value
|
277 |
+
visited.append(n.name)
|
278 |
+
#queue.add([n, curr.depth+1, sub_value])
|
279 |
+
sortidx = np.argsort(tmp_value)[::-1]
|
280 |
+
for si in sortidx:
|
281 |
+
queue.add(BFSNode(tmp_neighbor[si], curr.depth+1, tmp_value[si]))
|
282 |
+
if normalize:
|
283 |
+
for name in label_fusion.keys():
|
284 |
+
summ = sum(label_fusion[name].values())
|
285 |
+
for k in label_fusion[name].keys():
|
286 |
+
label_fusion[name][k] /= summ
|
287 |
+
return label, label_fusion
|
288 |
+
|
289 |
+
|
290 |
+
def clusters2labels(clusters, n_nodes):
|
291 |
+
labels = (-1)* np.ones((n_nodes,))
|
292 |
+
for ci, c in enumerate(clusters):
|
293 |
+
for xid in c:
|
294 |
+
labels[xid.name] = ci
|
295 |
+
assert np.sum(labels < 0) < 1
|
296 |
+
return labels
|
297 |
+
|
298 |
+
|
299 |
+
def single_remove(bbox, pred):
|
300 |
+
single_idcs = np.zeros_like(pred)
|
301 |
+
pred_unique = np.unique(pred)
|
302 |
+
for u in pred_unique:
|
303 |
+
idcs = pred == u
|
304 |
+
if np.sum(idcs) == 1:
|
305 |
+
single_idcs[np.where(idcs)[0][0]] = 1
|
306 |
+
remain_idcs = [i for i in range(len(pred)) if not single_idcs[i]]
|
307 |
+
remain_idcs = np.asarray(remain_idcs)
|
308 |
+
return bbox[remain_idcs, :], pred[remain_idcs]
|
309 |
+
|
IndicPhotoOCR/detection/textbpn/util/io.py
ADDED
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#coding=utf-8
|
2 |
+
'''
|
3 |
+
Created on 2016年9月27日
|
4 |
+
|
5 |
+
@author: dengdan
|
6 |
+
|
7 |
+
Tool functions for file system operation and I/O.
|
8 |
+
In the style of linux shell commands
|
9 |
+
'''
|
10 |
+
import os
|
11 |
+
import pickle as pkl
|
12 |
+
import subprocess
|
13 |
+
import logging
|
14 |
+
from . import strs, io
|
15 |
+
|
16 |
+
|
17 |
+
def mkdir(path):
|
18 |
+
"""
|
19 |
+
If the target directory does not exists, it and its parent directories will created.
|
20 |
+
"""
|
21 |
+
path = get_absolute_path(path)
|
22 |
+
if not exists(path):
|
23 |
+
os.makedirs(path)
|
24 |
+
return path
|
25 |
+
|
26 |
+
def make_parent_dir(path):
|
27 |
+
"""make the parent directories for a file."""
|
28 |
+
parent_dir = get_dir(path)
|
29 |
+
mkdir(parent_dir)
|
30 |
+
|
31 |
+
|
32 |
+
def pwd():
|
33 |
+
return os.getcwd()
|
34 |
+
|
35 |
+
def dump(path, obj):
|
36 |
+
path = get_absolute_path(path)
|
37 |
+
parent_path = get_dir(path)
|
38 |
+
mkdir(parent_path)
|
39 |
+
with open(path, 'w') as f:
|
40 |
+
logging.info('dumping file:' + path);
|
41 |
+
pkl.dump(obj, f)
|
42 |
+
|
43 |
+
def load(path):
|
44 |
+
path = get_absolute_path(path)
|
45 |
+
with open(path, 'r') as f:
|
46 |
+
data = pkl.load(f)
|
47 |
+
return data
|
48 |
+
|
49 |
+
def join_path(a, *p):
|
50 |
+
return os.path.join(a, *p)
|
51 |
+
|
52 |
+
def is_dir(path):
|
53 |
+
path = get_absolute_path(path)
|
54 |
+
return os.path.isdir(path)
|
55 |
+
|
56 |
+
is_directory = is_dir
|
57 |
+
|
58 |
+
def is_path(path):
|
59 |
+
path = get_absolute_path(path)
|
60 |
+
return os.path.ispath(path)
|
61 |
+
|
62 |
+
def get_dir(path):
|
63 |
+
'''
|
64 |
+
return the directory it belongs to.
|
65 |
+
if path is a directory itself, itself will be return
|
66 |
+
'''
|
67 |
+
path = get_absolute_path(path)
|
68 |
+
if is_dir(path):
|
69 |
+
return path;
|
70 |
+
return os.path.split(path)[0]
|
71 |
+
|
72 |
+
def get_parent_dir(path):
|
73 |
+
current_dir = get_dir(path)
|
74 |
+
return get_absolute_path(join_path(current_dir, '..'))
|
75 |
+
|
76 |
+
def get_filename(path):
|
77 |
+
return os.path.split(path)[1]
|
78 |
+
|
79 |
+
def get_absolute_path(p):
|
80 |
+
if p.startswith('~'):
|
81 |
+
p = os.path.expanduser(p)
|
82 |
+
return os.path.abspath(p)
|
83 |
+
|
84 |
+
def cd(p):
|
85 |
+
p = get_absolute_path(p)
|
86 |
+
os.chdir(p)
|
87 |
+
|
88 |
+
def ls(path = '.', suffix = None):
|
89 |
+
"""
|
90 |
+
list files in a directory.
|
91 |
+
return file names in a list
|
92 |
+
"""
|
93 |
+
path = get_absolute_path(path)
|
94 |
+
files = os.listdir(path)
|
95 |
+
|
96 |
+
if suffix is None:
|
97 |
+
return files
|
98 |
+
|
99 |
+
filtered = []
|
100 |
+
for f in files:
|
101 |
+
if string.ends_with(f, suffix, ignore_case = True):
|
102 |
+
filtered.append(f)
|
103 |
+
|
104 |
+
return filtered
|
105 |
+
|
106 |
+
def find_files(pattern):
|
107 |
+
import glob
|
108 |
+
return glob.glob(pattern)
|
109 |
+
|
110 |
+
def read_lines(p):
|
111 |
+
"""return the text in a file in lines as a list """
|
112 |
+
p = get_absolute_path(p)
|
113 |
+
f = open(p,'r')
|
114 |
+
return f.readlines()
|
115 |
+
|
116 |
+
def write_lines(p, lines, append_break = False):
|
117 |
+
p = get_absolute_path(p)
|
118 |
+
make_parent_dir(p)
|
119 |
+
with open(p, 'w') as f:
|
120 |
+
for line in lines:
|
121 |
+
if append_break:
|
122 |
+
f.write(line + '\n')
|
123 |
+
else:
|
124 |
+
f.write(line)
|
125 |
+
|
126 |
+
def cat(p):
|
127 |
+
"""return the text in a file as a whole"""
|
128 |
+
cmd = 'cat ' + p
|
129 |
+
return subprocess.getoutput(cmd)
|
130 |
+
|
131 |
+
def exists(path):
|
132 |
+
path = get_absolute_path(path)
|
133 |
+
return os.path.exists(path)
|
134 |
+
|
135 |
+
def not_exists(path):
|
136 |
+
return not exists(path)
|
137 |
+
|
138 |
+
def load_mat(path):
|
139 |
+
import scipy.io as sio # type: ignore
|
140 |
+
path = get_absolute_path(path)
|
141 |
+
return sio.loadmat(path)
|
142 |
+
|
143 |
+
def dump_mat(path, dict_obj, append = True):
|
144 |
+
import scipy.io as sio # type: ignore
|
145 |
+
path = get_absolute_path(path)
|
146 |
+
make_parent_dir(path)
|
147 |
+
sio.savemat(file_name = path, mdict = dict_obj, appendmat = append)
|
148 |
+
|
149 |
+
def dir_mat(path):
|
150 |
+
'''
|
151 |
+
list the variables in mat file.
|
152 |
+
return a list: [(name, shape, dtype), ...]
|
153 |
+
'''
|
154 |
+
import scipy.io as sio # type: ignore
|
155 |
+
path = get_absolute_path(path)
|
156 |
+
return sio.whosmat(path)
|
157 |
+
|
158 |
+
SIZE_UNIT_K = 1024
|
159 |
+
SIZE_UNIT_M = SIZE_UNIT_K ** 2
|
160 |
+
SIZE_UNIT_G = SIZE_UNIT_K ** 3
|
161 |
+
def get_file_size(path, unit = SIZE_UNIT_K):
|
162 |
+
size = os.path.getsize(get_absolute_path(path))
|
163 |
+
return size * 1.0 / unit
|
164 |
+
|
165 |
+
|
166 |
+
def create_h5(path):
|
167 |
+
import h5py # type: ignore
|
168 |
+
path = get_absolute_path(path)
|
169 |
+
make_parent_dir(path)
|
170 |
+
return h5py.File(path, 'w');
|
171 |
+
|
172 |
+
def open_h5(path, mode = 'r'):
|
173 |
+
import h5py
|
174 |
+
path = get_absolute_path(path)
|
175 |
+
return h5py.File(path, mode);
|
176 |
+
|
177 |
+
def read_h5(h5, key):
|
178 |
+
return h5[key][:]
|
179 |
+
def read_h5_attrs(h5, key, attrs):
|
180 |
+
return h5[key].attrs[attrs]
|
181 |
+
|
182 |
+
def copy(src, dest):
|
183 |
+
io.make_parent_dir(dest)
|
184 |
+
import shutil
|
185 |
+
shutil.copy(get_absolute_path(src), get_absolute_path(dest))
|
186 |
+
|
187 |
+
cp = copy
|
188 |
+
|
189 |
+
def remove(p):
|
190 |
+
import os
|
191 |
+
os.remove(get_absolute_path(p))
|
192 |
+
rm = remove
|
193 |
+
|
194 |
+
def search(pattern, path, file_only = True):
|
195 |
+
"""
|
196 |
+
Search files whose name matches the give pattern. The search scope
|
197 |
+
is the directory and sub-directories of 'path'.
|
198 |
+
"""
|
199 |
+
path = get_absolute_path(path)
|
200 |
+
pattern_here = io.join_path(path, pattern)
|
201 |
+
targets = []
|
202 |
+
|
203 |
+
# find matchings in current directory
|
204 |
+
candidates = find_files(pattern_here)
|
205 |
+
for can in candidates:
|
206 |
+
if io.is_dir(can) and file_only:
|
207 |
+
continue
|
208 |
+
else:
|
209 |
+
targets.append(can)
|
210 |
+
|
211 |
+
# find matching in sub-dirs
|
212 |
+
files = ls(path)
|
213 |
+
for f in files:
|
214 |
+
fpath = io.join_path(path, f)
|
215 |
+
if is_dir(fpath):
|
216 |
+
targets_in_sub_dir = search(pattern, fpath, file_only)
|
217 |
+
targets.extend(targets_in_sub_dir)
|
218 |
+
return targets
|
219 |
+
|
220 |
+
def dump_json(path, data):
|
221 |
+
import ujson as json
|
222 |
+
path = get_absolute_path(path)
|
223 |
+
make_parent_dir(path)
|
224 |
+
|
225 |
+
with open(path, 'w') as f:
|
226 |
+
json.dump(data, f)
|
227 |
+
return path
|
228 |
+
|
229 |
+
def load_json(path):
|
230 |
+
import ujson as json
|
231 |
+
path = get_absolute_path(path)
|
232 |
+
with open(path, 'r') as f:
|
233 |
+
return json.load(f)
|