anikde committed on
Commit 85f0ffb · 1 Parent(s): 82accf8

added textbpn++ detection module

Files changed (50)
  1. IndicPhotoOCR/detection/textbpn/__init__.py +0 -0
  2. IndicPhotoOCR/detection/textbpn/cfglib/config.py +90 -0
  3. IndicPhotoOCR/detection/textbpn/cfglib/option.py +123 -0
  4. IndicPhotoOCR/detection/textbpn/models/TextBPN_resnet50_300.pth +3 -0
  5. IndicPhotoOCR/detection/textbpn/network/Reg_loss.py +196 -0
  6. IndicPhotoOCR/detection/textbpn/network/Seg_loss.py +107 -0
  7. IndicPhotoOCR/detection/textbpn/network/__init__.py +1 -0
  8. IndicPhotoOCR/detection/textbpn/network/backbone/__init__.py +1 -0
  9. IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/Makefile +6 -0
  10. IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/Makefile.sh +6 -0
  11. IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/__init__.py +13 -0
  12. IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/functions/__init__.py +0 -0
  13. IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/functions/deform_conv.py +181 -0
  14. IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/functions/deform_pool.py +69 -0
  15. IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/modules/__init__.py +0 -0
  16. IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/modules/deform_conv.py +157 -0
  17. IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/modules/deform_pool.py +172 -0
  18. IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/setup.py +19 -0
  19. IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/src/deform_conv_cuda.cpp +695 -0
  20. IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/src/deform_conv_cuda_kernel.cu +866 -0
  21. IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/src/deform_pool_cuda.cpp +87 -0
  22. IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/src/deform_pool_cuda_kernel.cu +364 -0
  23. IndicPhotoOCR/detection/textbpn/network/backbone/resnet.py +336 -0
  24. IndicPhotoOCR/detection/textbpn/network/backbone/vgg.py +60 -0
  25. IndicPhotoOCR/detection/textbpn/network/layers/Adaptive_Deformation.py +88 -0
  26. IndicPhotoOCR/detection/textbpn/network/layers/CircConv.py +91 -0
  27. IndicPhotoOCR/detection/textbpn/network/layers/GCN.py +77 -0
  28. IndicPhotoOCR/detection/textbpn/network/layers/GraphConv.py +45 -0
  29. IndicPhotoOCR/detection/textbpn/network/layers/RNN.py +35 -0
  30. IndicPhotoOCR/detection/textbpn/network/layers/Transformer.py +140 -0
  31. IndicPhotoOCR/detection/textbpn/network/layers/Transformer_old.py +171 -0
  32. IndicPhotoOCR/detection/textbpn/network/layers/__init__.py +0 -0
  33. IndicPhotoOCR/detection/textbpn/network/layers/gcn_utils.py +150 -0
  34. IndicPhotoOCR/detection/textbpn/network/layers/model_block.py +149 -0
  35. IndicPhotoOCR/detection/textbpn/network/layers/position_encoding.py +89 -0
  36. IndicPhotoOCR/detection/textbpn/network/layers/resnet.py +73 -0
  37. IndicPhotoOCR/detection/textbpn/network/layers/resnet_dcn.py +59 -0
  38. IndicPhotoOCR/detection/textbpn/network/layers/vgg.py +62 -0
  39. IndicPhotoOCR/detection/textbpn/network/loss.py +187 -0
  40. IndicPhotoOCR/detection/textbpn/network/loss_org.py +136 -0
  41. IndicPhotoOCR/detection/textbpn/network/textnet.py +216 -0
  42. IndicPhotoOCR/detection/textbpn/output.png +3 -0
  43. IndicPhotoOCR/detection/textbpn/textbpnpp_detector.py +197 -0
  44. IndicPhotoOCR/detection/textbpn/util/__init__.py +2 -0
  45. IndicPhotoOCR/detection/textbpn/util/augmentation.py +794 -0
  46. IndicPhotoOCR/detection/textbpn/util/canvas.py +55 -0
  47. IndicPhotoOCR/detection/textbpn/util/detection.py +48 -0
  48. IndicPhotoOCR/detection/textbpn/util/eval.py +228 -0
  49. IndicPhotoOCR/detection/textbpn/util/graph.py +309 -0
  50. IndicPhotoOCR/detection/textbpn/util/io.py +233 -0
IndicPhotoOCR/detection/textbpn/__init__.py ADDED
File without changes
IndicPhotoOCR/detection/textbpn/cfglib/config.py ADDED
@@ -0,0 +1,90 @@
from easydict import EasyDict
import torch
import os

config = EasyDict()


# Normalize image
config.means = (0.485, 0.456, 0.406)
config.stds = (0.229, 0.224, 0.225)

config.gpu = "1"

# Experiment name #
config.exp_name = "Synthtext"

# dataloader jobs number
config.num_workers = 24

# batch_size
config.batch_size = 12

# training epoch number
config.max_epoch = 200

config.start_epoch = 0

# learning rate
config.lr = 1e-4

# using GPU
config.cuda = False

config.output_dir = 'output'

config.input_size = 640

# max polygons per image
# SynText, Total-Text: 64; CTW1500: 64; ICDAR: 64; MLT: 32; TD500: 64.
config.max_annotation = 64

# adj num for graph
config.adj_num = 4

# control points number
config.num_points = 20

# use hard examples (annotated as '#')
config.use_hard = True

# Load data into memory at one time
config.load_memory = False

# prediction on 1/scale feature map
config.scale = 1

# clip gradient of loss
config.grad_clip = 25

# demo tcl threshold
config.dis_threshold = 0.4

config.cls_threshold = 0.8

# Contour approximation factor
config.approx_factor = 0.004


def update_config(config, extra_config):
    for k, v in vars(extra_config).items():
        config[k] = v
    # print(config.gpu)
    # config.device = torch.device('cuda') if config.cuda else torch.device('cpu')
    config.device = torch.device('cpu')


def print_config(config):
    print('==========Options============')
    for k, v in config.items():
        print('{}: {}'.format(k, v))
    print('=============End=============')


################### MY Settings ##################
config.resume = True

config.device = "cpu"

# config.test_size = [224, 224]
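The global config above is meant to be merged with parsed command-line options via update_config and inspected with print_config. A minimal usage sketch, not part of this commit (the import path is assumed from the file tree; any object that supports vars(), such as an argparse.Namespace, works as the override source):

from argparse import Namespace

from IndicPhotoOCR.detection.textbpn.cfglib.config import (
    config, update_config, print_config)

# Overrides are copied attribute-by-attribute into the EasyDict;
# update_config also pins config.device to torch.device('cpu').
overrides = Namespace(exp_name="TD500", batch_size=4, cuda=False)
update_config(config, overrides)
print_config(config)  # dumps every key/value pair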
IndicPhotoOCR/detection/textbpn/cfglib/option.py ADDED
@@ -0,0 +1,123 @@
import argparse
import torch
import os
import torch.backends.cudnn as cudnn

from datetime import datetime


def str2bool(v):
    return v.lower() in ("yes", "true", "t", "1")


def arg2str(args):
    args_dict = vars(args)
    option_str = datetime.now().strftime('%b%d_%H-%M-%S') + '\n'

    for k, v in sorted(args_dict.items()):
        option_str += ('{}: {}\n'.format(str(k), str(v)))

    return option_str


class BaseOptions(object):

    def __init__(self):

        self.parser = argparse.ArgumentParser()

        # basic opts
        self.parser.add_argument('--exp_name', default="TD500", type=str,
                                 choices=['Synthtext', 'Totaltext', 'Ctw1500', 'Icdar2015',
                                          "MLT2017", 'TD500', "MLT2019", "ArT", "ALL"], help='Experiment name')
        self.parser.add_argument("--gpu", default="1", help="set gpu id", type=str)
        self.parser.add_argument('--resume', default=None, type=str, help='Path to target resume checkpoint')
        self.parser.add_argument('--num_workers', default=24, type=int, help='Number of workers used in dataloading')
        self.parser.add_argument('--cuda', default=True, type=str2bool, help='Use cuda to train model')
        self.parser.add_argument('--mgpu', action='store_true', help='Use multi-gpu to train model')
        self.parser.add_argument('--save_dir', default='./model/', help='Path to save checkpoint models')
        self.parser.add_argument('--vis_dir', default='./vis/', help='Path to save visualization images')
        self.parser.add_argument('--log_dir', default='./logs/', help='Path to tensorboard log')
        self.parser.add_argument('--loss', default='CrossEntropyLoss', type=str, help='Training Loss')
        # self.parser.add_argument('--input_channel', default=1, type=int, help='number of input channels')
        self.parser.add_argument('--pretrain', default=False, type=str2bool, help='Pretrained AutoEncoder model')
        self.parser.add_argument('--verbose', '-v', default=True, type=str2bool, help='Whether to output debug info')
        self.parser.add_argument('--viz', action='store_true', help='Whether to output debug info')
        # self.parser.add_argument('--viz', default=True, type=str2bool, help='Whether to output debug info')

        # train opts
        self.parser.add_argument('--max_epoch', default=250, type=int, help='Max epochs')
        self.parser.add_argument('--lr', '--learning-rate', default=1e-3, type=float, help='initial learning rate')
        self.parser.add_argument('--lr_adjust', default='fix',
                                 choices=['fix', 'poly'], type=str, help='Learning Rate Adjust Strategy')
        self.parser.add_argument('--stepvalues', default=[], nargs='+', type=int, help='# of iter to change lr')
        self.parser.add_argument('--weight_decay', '--wd', default=0., type=float, help='Weight decay for SGD')
        self.parser.add_argument('--gamma', default=0.1, type=float, help='Gamma update for SGD lr')
        self.parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
        self.parser.add_argument('--batch_size', default=6, type=int, help='Batch size for training')
        self.parser.add_argument('--optim', default='Adam', type=str, choices=['SGD', 'Adam'], help='Optimizer')
        self.parser.add_argument('--save_freq', default=5, type=int, help='save weights every # epoch')
        self.parser.add_argument('--display_freq', default=10, type=int, help='display training metrics every # iter')
        self.parser.add_argument('--viz_freq', default=50, type=int, help='visualize training process every # iter')
        self.parser.add_argument('--log_freq', default=10000, type=int, help='log to tensorboard every # iterations')
        self.parser.add_argument('--val_freq', default=1000, type=int, help='do validation every # iterations')

        # backbone
        self.parser.add_argument('--scale', default=1, type=int, help='prediction on 1/scale feature map')
        self.parser.add_argument('--net', default='resnet50', type=str,
                                 choices=['vgg', 'resnet50', 'resnet18',
                                          "deformable_resnet18", "deformable_resnet50"],
                                 help='Network architecture')
        # data args
        self.parser.add_argument('--load_memory', default=False, type=str2bool, help='Load data into memory')
        self.parser.add_argument('--rescale', type=float, default=255.0, help='rescale factor')
        self.parser.add_argument('--input_size', default=640, type=int, help='model input size')
        self.parser.add_argument('--test_size', default=[640, 960], type=int, nargs='+', help='test size')

        # eval args
        self.parser.add_argument('--checkepoch', default=1070, type=int, help='Load checkpoint number')
        self.parser.add_argument('--start_epoch', default=0, type=int, help='start epoch number')
        self.parser.add_argument('--cls_threshold', default=0.875, type=float, help='threshold of pse')
        self.parser.add_argument('--dis_threshold', default=0.35, type=float, help='filter out predictions with score below this threshold')

        # demo args
        self.parser.add_argument('--img_root', default=None, type=str, help='Path to deploy images')

    def parse(self, fixed=None):

        if fixed is not None:
            args = self.parser.parse_args(fixed)
        else:
            args = self.parser.parse_args()

        return args

    def initialize(self, fixed=None):

        # Parse options
        self.args = self.parse(fixed)
        os.environ['CUDA_VISIBLE_DEVICES'] = self.args.gpu

        # Setting default torch Tensor type
        if self.args.cuda and torch.cuda.is_available():
            torch.set_default_tensor_type('torch.cuda.FloatTensor')
            cudnn.benchmark = True
        else:
            torch.set_default_tensor_type('torch.FloatTensor')

        # Create weights saving directory
        if not os.path.exists(self.args.save_dir):
            os.mkdir(self.args.save_dir)

        # Create weights saving directory of target model
        model_save_path = os.path.join(self.args.save_dir, self.args.exp_name)

        if not os.path.exists(model_save_path):
            os.mkdir(model_save_path)

        return self.args

    def update(self, args, extra_options):

        for k, v in extra_options.items():
            setattr(args, k, v)
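Because parse() accepts a `fixed` argument list, BaseOptions can be driven programmatically instead of from sys.argv. A hedged sketch of that flow (illustrative, not part of the commit; note that initialize() has side effects: it sets CUDA_VISIBLE_DEVICES, changes the default tensor type, and creates the ./model/ checkpoint directories):

from IndicPhotoOCR.detection.textbpn.cfglib.option import BaseOptions
from IndicPhotoOCR.detection.textbpn.cfglib.config import config, update_config

# `fixed` bypasses sys.argv so the parser can run inside another program.
args = BaseOptions().initialize(fixed=["--exp_name", "TD500", "--cuda", "False"])
update_config(config, args)  # merge the parsed options into the global config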
IndicPhotoOCR/detection/textbpn/models/TextBPN_resnet50_300.pth ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b735b9c93c8758972d3b8cfd3ef8e1c09afa8cd9106f4cb11406300b141b1d78
size 145703602
IndicPhotoOCR/detection/textbpn/network/Reg_loss.py ADDED
@@ -0,0 +1,196 @@
# -*- coding: utf-8 -*-
# @Time    : 10/1/21
# @Author  : GXYM
import torch
from torch import nn
import numpy as np
import torch.nn.functional as F


class PolyMatchingLoss(nn.Module):
    def __init__(self, pnum, device, loss_type="L1"):
        super(PolyMatchingLoss, self).__init__()

        self.pnum = pnum
        self.device = device
        self.loss_type = loss_type
        self.smooth_L1 = F.smooth_l1_loss
        self.L2_loss = torch.nn.MSELoss(reduce=False, size_average=False)

        batch_size = 1
        pidxall = np.zeros(shape=(batch_size, pnum, pnum), dtype=np.int32)
        for b in range(batch_size):
            for i in range(pnum):
                pidx = (np.arange(pnum) + i) % pnum
                pidxall[b, i] = pidx

        pidxall = torch.from_numpy(np.reshape(pidxall, newshape=(batch_size, -1))).to(device)
        self.feature_id = pidxall.unsqueeze_(2).long().expand(pidxall.size(0), pidxall.size(1), 2).detach()
        print(self.feature_id.shape)

    def match_loss(self, pred, gt):
        batch_size = pred.shape[0]
        feature_id = self.feature_id.expand(batch_size, self.feature_id.size(1), 2)

        gt_expand = torch.gather(gt, 1, feature_id).view(batch_size, self.pnum, self.pnum, 2)
        pred_expand = pred.unsqueeze(1)

        if self.loss_type == "L2":
            dis = self.L2_loss(pred_expand, gt_expand)
            dis = dis.sum(3).sqrt().mean(2)
        elif self.loss_type == "L1":
            dis = self.smooth_L1(pred_expand, gt_expand, reduction='none')
            dis = dis.sum(3).mean(2)

        min_dis, min_id = torch.min(dis, dim=1, keepdim=True)

        return min_dis

    def forward(self, pred_list, gt):
        loss = torch.tensor(0.)
        for pred in pred_list:
            loss += torch.mean(self.match_loss(pred, gt))

        return loss / torch.tensor(len(pred_list))

        # los = []
        # for pred in pred_list:
        #     los.append(self.match_loss(pred, gt))
        #
        # los_b = torch.tensor(0.)
        # loss_c = torch.tensor(0.)
        # for i, _ in enumerate(los):
        #     los_b += torch.mean(los[i])
        #     loss_c += (torch.mean(torch.clamp(los[i] - los[i - 1], min=0.0)) if i > 0 else torch.tensor(0.))
        # loss = los_b / torch.tensor(len(los)) + 0.5*loss_c / torch.tensor(len(los)-1)
        #
        # return loss


class AttentionLoss(nn.Module):
    def __init__(self, beta=4, gamma=0.5):
        super(AttentionLoss, self).__init__()

        self.beta = beta
        self.gamma = gamma

    def forward(self, pred, gt):
        num_pos = torch.sum(gt)
        num_neg = torch.sum(1 - gt)
        alpha = num_neg / (num_pos + num_neg)
        edge_beta = torch.pow(self.beta, torch.pow(1 - pred, self.gamma))
        bg_beta = torch.pow(self.beta, torch.pow(pred, self.gamma))

        loss = 0
        loss = loss - alpha * edge_beta * torch.log(pred) * gt
        loss = loss - (1 - alpha) * bg_beta * torch.log(1 - pred) * (1 - gt)
        return torch.mean(loss)


class GeoCrossEntropyLoss(nn.Module):
    def __init__(self):
        super(GeoCrossEntropyLoss, self).__init__()

    def forward(self, output, target, poly):
        output = torch.nn.functional.softmax(output, dim=1)
        output = torch.log(torch.clamp(output, min=1e-4))
        poly = poly.view(poly.size(0), 4, poly.size(1) // 4, 2)
        target = target[..., None, None].expand(poly.size(0), poly.size(1), 1, poly.size(3))
        target_poly = torch.gather(poly, 2, target)
        sigma = (poly[:, :, 0] - poly[:, :, 1]).pow(2).sum(2, keepdim=True)
        kernel = torch.exp(-(poly - target_poly).pow(2).sum(3) / (sigma / 3))
        loss = -(output * kernel.transpose(2, 1)).sum(1).mean()
        return loss


class AELoss(nn.Module):
    def __init__(self):
        super(AELoss, self).__init__()

    def forward(self, ae, ind, ind_mask):
        """
        ae: [b, 1, h, w]
        ind: [b, max_objs, max_parts]
        ind_mask: [b, max_objs, max_parts]
        obj_mask: [b, max_objs]
        """
        # first index
        b, _, h, w = ae.shape
        b, max_objs, max_parts = ind.shape
        obj_mask = torch.sum(ind_mask, dim=2) != 0

        ae = ae.view(b, h * w, 1)
        seed_ind = ind.view(b, max_objs * max_parts, 1)
        tag = ae.gather(1, seed_ind).view(b, max_objs, max_parts)

        # compute the mean
        tag_mean = tag * ind_mask
        tag_mean = tag_mean.sum(2) / (ind_mask.sum(2) + 1e-4)

        # pull ae of the same object to their mean
        pull_dist = (tag - tag_mean.unsqueeze(2)).pow(2) * ind_mask
        obj_num = obj_mask.sum(dim=1).float()
        pull = (pull_dist.sum(dim=(1, 2)) / (obj_num + 1e-4)).sum()
        pull /= b

        # push away the mean of different objects
        push_dist = torch.abs(tag_mean.unsqueeze(1) - tag_mean.unsqueeze(2))
        push_dist = 1 - push_dist
        push_dist = nn.functional.relu(push_dist, inplace=True)
        obj_mask = (obj_mask.unsqueeze(1) + obj_mask.unsqueeze(2)) == 2
        push_dist = push_dist * obj_mask.float()
        push = ((push_dist.sum(dim=(1, 2)) - obj_num) / (obj_num * (obj_num - 1) + 1e-4)).sum()
        push /= b
        return pull, push


def smooth_l1_loss(inputs, target, sigma=9.0):
    try:
        diff = torch.abs(inputs - target)
        less_one = (diff < 1.0 / sigma).float()
        loss = less_one * 0.5 * diff ** 2 * sigma \
               + torch.abs(torch.tensor(1.0) - less_one) * (diff - 0.5 / sigma)
        loss = torch.mean(loss) if loss.numel() > 0 else torch.tensor(0.0)
    except Exception as e:
        print('RPN_REGR_Loss Exception:', e)
        loss = torch.tensor(0.0)

    return loss


def _neg_loss(pred, gt):
    ''' Modified focal loss. Exactly the same as CornerNet.
    Runs faster and costs a little bit more memory
    Arguments:
        pred (batch x c x h x w)
        gt_regr (batch x c x h x w)
    '''
    pos_inds = gt.eq(1).float()
    neg_inds = gt.lt(1).float()

    neg_weights = torch.pow(1 - gt, 4)

    loss = 0

    pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds
    neg_loss = torch.log(1 - pred) * torch.pow(pred, 2) * neg_weights * neg_inds

    num_pos = pos_inds.float().sum()
    pos_loss = pos_loss.sum()
    neg_loss = neg_loss.sum()

    if num_pos == 0:
        loss = loss - neg_loss
    else:
        loss = loss - (pos_loss + neg_loss) / num_pos
    return loss


class FocalLoss(nn.Module):
    '''nn.Module wrapper for focal loss'''
    def __init__(self):
        super(FocalLoss, self).__init__()
        self.neg_loss = _neg_loss

    def forward(self, out, target):
        return self.neg_loss(out, target)
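PolyMatchingLoss precomputes every cyclic re-indexing of a pnum-point polygon in feature_id, scores the prediction against each rotation of the ground truth, and keeps the minimum, so the loss does not depend on which vertex the ground-truth polygon starts at. A small illustrative check of that property (not part of the commit; the import path is assumed from the file tree):

import torch
from IndicPhotoOCR.detection.textbpn.network.Reg_loss import PolyMatchingLoss

pnum = 20                                  # must match config.num_points
criterion = PolyMatchingLoss(pnum, device=torch.device("cpu"))

gt = torch.rand(2, pnum, 2)                # batch of 2 polygons, 20 points each
rolled = torch.roll(gt, shifts=7, dims=1)  # same polygons, different start vertex
loss_a = criterion([gt], gt)               # ~0
loss_b = criterion([rolled], gt)           # also ~0, thanks to the min over shifts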
IndicPhotoOCR/detection/textbpn/network/Seg_loss.py ADDED
@@ -0,0 +1,107 @@
# -*- coding: utf-8 -*-
# @Time    : 10/1/21
# @Author  : GXYM
import torch
from torch import nn
import numpy as np


class SegmentLoss(nn.Module):
    def __init__(self, Lambda, ratio=3, reduction='mean'):
        """Implement PSE Loss.
        """
        super(SegmentLoss, self).__init__()
        assert reduction in ['mean', 'sum'], "reduction must be in ['mean', 'sum']"
        self.Lambda = Lambda
        self.ratio = ratio
        self.reduction = reduction

    def forward(self, outputs, labels, training_masks, th=0.5):
        texts = outputs[:, -1, :, :]
        kernels = outputs[:, :-1, :, :]
        gt_texts = labels[:, -1, :, :]
        gt_kernels = labels[:, :-1, :, :]

        selected_masks = self.ohem_batch(texts, gt_texts, training_masks)
        selected_masks = selected_masks.to(outputs.device)

        loss_text = self.dice_loss(texts, gt_texts, selected_masks)

        loss_kernels = []
        # mask0 = torch.sigmoid(texts).data.cpu().numpy()
        mask0 = texts.data.cpu().numpy()
        mask1 = training_masks.data.cpu().numpy()
        selected_masks = ((mask0 > th) & (mask1 > th)).astype('float32')
        selected_masks = torch.from_numpy(selected_masks).float()
        selected_masks = selected_masks.to(outputs.device)
        kernels_num = gt_kernels.size()[1]
        for i in range(kernels_num):
            kernel_i = kernels[:, i, :, :]
            gt_kernel_i = gt_kernels[:, i, :, :]
            loss_kernel_i = self.dice_loss(kernel_i, gt_kernel_i, selected_masks)
            loss_kernels.append(loss_kernel_i)
        loss_kernels = torch.stack(loss_kernels).mean(0)
        if self.reduction == 'mean':
            loss_text = loss_text.mean()
            loss_kernels = loss_kernels.mean()
        elif self.reduction == 'sum':
            loss_text = loss_text.sum()
            loss_kernels = loss_kernels.sum()

        loss = self.Lambda * loss_text + (1 - self.Lambda) * loss_kernels
        return loss_text, loss_kernels, loss

    def dice_loss(self, input, target, mask):
        # input = torch.sigmoid(input)

        input = input.contiguous().view(input.size()[0], -1)
        target = target.contiguous().view(target.size()[0], -1)
        mask = mask.contiguous().view(mask.size()[0], -1)

        input = input * mask
        target = (target.float()) * mask

        a = torch.sum(input * target, 1)
        b = torch.sum(input * input, 1) + 0.001
        c = torch.sum(target * target, 1) + 0.001
        d = (2 * a) / (b + c)
        return 1 - d

    def ohem_single(self, score, gt_text, training_mask, th=0.5):
        pos_num = (int)(np.sum(gt_text > th)) - (int)(np.sum((gt_text > th) & (training_mask <= th)))

        if pos_num == 0:
            # selected_mask = gt_text.copy() * 0  # may be not good
            selected_mask = training_mask
            selected_mask = selected_mask.reshape(1, selected_mask.shape[0], selected_mask.shape[1]).astype('float32')
            return selected_mask

        neg_num = (int)(np.sum(gt_text <= th))
        neg_num = (int)(min(pos_num * 3, neg_num))

        if neg_num == 0:
            selected_mask = training_mask
            selected_mask = selected_mask.reshape(1, selected_mask.shape[0], selected_mask.shape[1]).astype('float32')
            return selected_mask

        neg_score = score[gt_text <= th]
        # sort negative-sample scores from high to low
        neg_score_sorted = np.sort(-neg_score)
        threshold = -neg_score_sorted[neg_num - 1]
        # select the masks of high-scoring negative samples and of positive samples
        selected_mask = ((score >= threshold) | (gt_text > th)) & (training_mask > th)
        selected_mask = selected_mask.reshape(1, selected_mask.shape[0], selected_mask.shape[1]).astype('float32')
        return selected_mask

    def ohem_batch(self, scores, gt_texts, training_masks):
        scores = scores.data.cpu().numpy()
        gt_texts = gt_texts.data.cpu().numpy()
        training_masks = training_masks.data.cpu().numpy()
        selected_masks = []
        for i in range(scores.shape[0]):
            selected_masks.append(self.ohem_single(scores[i, :, :], gt_texts[i, :, :], training_masks[i, :, :]))

        selected_masks = np.concatenate(selected_masks, 0)
        selected_masks = torch.from_numpy(selected_masks).float()

        return selected_masks
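In SegmentLoss, the last channel of outputs/labels is the full text map and the remaining channels are the shrunken kernel maps; OHEM keeps at most a 3:1 negative-to-positive pixel ratio before the dice losses are combined with weight Lambda. A shape-level sketch (illustrative, not part of the commit; outputs are assumed to already be probabilities, since the sigmoid is commented out above):

import torch
from IndicPhotoOCR.detection.textbpn.network.Seg_loss import SegmentLoss

criterion = SegmentLoss(Lambda=0.7)
outputs = torch.rand(2, 3, 64, 64)               # 2 kernel maps + 1 text map
labels = (torch.rand(2, 3, 64, 64) > 0.5).float()
training_masks = torch.ones(2, 64, 64)           # 1 = pixel participates in the loss
loss_text, loss_kernels, loss = criterion(outputs, labels, training_masks)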
IndicPhotoOCR/detection/textbpn/network/__init__.py ADDED
@@ -0,0 +1 @@
from . import *
IndicPhotoOCR/detection/textbpn/network/backbone/__init__.py ADDED
@@ -0,0 +1 @@
from .resnet import resnet18, resnet34, resnet50, resnet101, deformable_resnet50, deformable_resnet18
IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/Makefile ADDED
@@ -0,0 +1,6 @@
#!/bin/bash
rm *.so
python setup.py build_ext --inplace
rm -rf ./build
IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/Makefile.sh ADDED
@@ -0,0 +1,6 @@
#!/bin/bash
rm *.so
python setup.py build_ext --inplace
rm -rf ./build
IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/__init__.py ADDED
@@ -0,0 +1,13 @@
from .functions.deform_conv import deform_conv, modulated_deform_conv
from .functions.deform_pool import deform_roi_pooling
from .modules.deform_conv import (DeformConv, ModulatedDeformConv,
                                  DeformConvPack, ModulatedDeformConvPack)
from .modules.deform_pool import (DeformRoIPooling, DeformRoIPoolingPack,
                                  ModulatedDeformRoIPoolingPack)

__all__ = [
    'DeformConv', 'DeformConvPack', 'ModulatedDeformConv',
    'ModulatedDeformConvPack', 'DeformRoIPooling', 'DeformRoIPoolingPack',
    'ModulatedDeformRoIPoolingPack', 'deform_conv', 'modulated_deform_conv',
    'deform_roi_pooling'
]
IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/functions/__init__.py ADDED
File without changes
IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/functions/deform_conv.py ADDED
@@ -0,0 +1,181 @@
import torch
from torch.autograd import Function
from torch.nn.modules.utils import _pair

from .. import deform_conv_cuda


class DeformConvFunction(Function):

    @staticmethod
    def forward(ctx,
                input,
                offset,
                weight,
                stride=1,
                padding=0,
                dilation=1,
                groups=1,
                deformable_groups=1,
                im2col_step=64):
        if input is not None and input.dim() != 4:
            raise ValueError(
                "Expected 4D tensor as input, got {}D tensor instead.".format(
                    input.dim()))
        ctx.stride = _pair(stride)
        ctx.padding = _pair(padding)
        ctx.dilation = _pair(dilation)
        ctx.groups = groups
        ctx.deformable_groups = deformable_groups
        ctx.im2col_step = im2col_step

        ctx.save_for_backward(input, offset, weight)

        output = input.new_empty(
            DeformConvFunction._output_size(input, weight, ctx.padding,
                                            ctx.dilation, ctx.stride))

        ctx.bufs_ = [input.new_empty(0), input.new_empty(0)]  # columns, ones

        if not input.is_cuda:
            raise NotImplementedError
        else:
            cur_im2col_step = min(ctx.im2col_step, input.shape[0])
            assert (input.shape[0] %
                    cur_im2col_step) == 0, 'im2col step must divide batchsize'
            deform_conv_cuda.deform_conv_forward_cuda(
                input, weight, offset, output, ctx.bufs_[0], ctx.bufs_[1],
                weight.size(3), weight.size(2), ctx.stride[1], ctx.stride[0],
                ctx.padding[1], ctx.padding[0], ctx.dilation[1],
                ctx.dilation[0], ctx.groups, ctx.deformable_groups,
                cur_im2col_step)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, offset, weight = ctx.saved_tensors

        grad_input = grad_offset = grad_weight = None

        if not grad_output.is_cuda:
            raise NotImplementedError
        else:
            cur_im2col_step = min(ctx.im2col_step, input.shape[0])
            assert (input.shape[0] %
                    cur_im2col_step) == 0, 'im2col step must divide batchsize'

            if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
                grad_input = torch.zeros_like(input)
                grad_offset = torch.zeros_like(offset)
                deform_conv_cuda.deform_conv_backward_input_cuda(
                    input, offset, grad_output, grad_input,
                    grad_offset, weight, ctx.bufs_[0], weight.size(3),
                    weight.size(2), ctx.stride[1], ctx.stride[0],
                    ctx.padding[1], ctx.padding[0], ctx.dilation[1],
                    ctx.dilation[0], ctx.groups, ctx.deformable_groups,
                    cur_im2col_step)

            if ctx.needs_input_grad[2]:
                grad_weight = torch.zeros_like(weight)
                deform_conv_cuda.deform_conv_backward_parameters_cuda(
                    input, offset, grad_output,
                    grad_weight, ctx.bufs_[0], ctx.bufs_[1], weight.size(3),
                    weight.size(2), ctx.stride[1], ctx.stride[0],
                    ctx.padding[1], ctx.padding[0], ctx.dilation[1],
                    ctx.dilation[0], ctx.groups, ctx.deformable_groups, 1,
                    cur_im2col_step)

        return (grad_input, grad_offset, grad_weight, None, None, None, None,
                None)

    @staticmethod
    def _output_size(input, weight, padding, dilation, stride):
        channels = weight.size(0)
        output_size = (input.size(0), channels)
        for d in range(input.dim() - 2):
            in_size = input.size(d + 2)
            pad = padding[d]
            kernel = dilation[d] * (weight.size(d + 2) - 1) + 1
            stride_ = stride[d]
            output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, )
        if not all(map(lambda s: s > 0, output_size)):
            raise ValueError(
                "convolution input is too small (output would be {})".format(
                    'x'.join(map(str, output_size))))
        return output_size


class ModulatedDeformConvFunction(Function):

    @staticmethod
    def forward(ctx,
                input,
                offset,
                mask,
                weight,
                bias=None,
                stride=1,
                padding=0,
                dilation=1,
                groups=1,
                deformable_groups=1):
        ctx.stride = stride
        ctx.padding = padding
        ctx.dilation = dilation
        ctx.groups = groups
        ctx.deformable_groups = deformable_groups
        ctx.with_bias = bias is not None
        if not ctx.with_bias:
            bias = input.new_empty(1)  # fake tensor
        if not input.is_cuda:
            raise NotImplementedError
        if weight.requires_grad or mask.requires_grad or offset.requires_grad \
                or input.requires_grad:
            ctx.save_for_backward(input, offset, mask, weight, bias)
        output = input.new_empty(
            ModulatedDeformConvFunction._infer_shape(ctx, input, weight))
        ctx._bufs = [input.new_empty(0), input.new_empty(0)]
        deform_conv_cuda.modulated_deform_conv_cuda_forward(
            input, weight, bias, ctx._bufs[0], offset, mask, output,
            ctx._bufs[1], weight.shape[2], weight.shape[3], ctx.stride,
            ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
            ctx.groups, ctx.deformable_groups, ctx.with_bias)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        if not grad_output.is_cuda:
            raise NotImplementedError
        input, offset, mask, weight, bias = ctx.saved_tensors
        grad_input = torch.zeros_like(input)
        grad_offset = torch.zeros_like(offset)
        grad_mask = torch.zeros_like(mask)
        grad_weight = torch.zeros_like(weight)
        grad_bias = torch.zeros_like(bias)
        deform_conv_cuda.modulated_deform_conv_cuda_backward(
            input, weight, bias, ctx._bufs[0], offset, mask, ctx._bufs[1],
            grad_input, grad_weight, grad_bias, grad_offset, grad_mask,
            grad_output, weight.shape[2], weight.shape[3], ctx.stride,
            ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
            ctx.groups, ctx.deformable_groups, ctx.with_bias)
        if not ctx.with_bias:
            grad_bias = None

        return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias,
                None, None, None, None, None)

    @staticmethod
    def _infer_shape(ctx, input, weight):
        n = input.size(0)
        channels_out = weight.size(0)
        height, width = input.shape[2:4]
        kernel_h, kernel_w = weight.shape[2:4]
        height_out = (height + 2 * ctx.padding -
                      (ctx.dilation * (kernel_h - 1) + 1)) // ctx.stride + 1
        width_out = (width + 2 * ctx.padding -
                     (ctx.dilation * (kernel_w - 1) + 1)) // ctx.stride + 1
        return n, channels_out, height_out, width_out


deform_conv = DeformConvFunction.apply
modulated_deform_conv = ModulatedDeformConvFunction.apply
IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/functions/deform_pool.py ADDED
@@ -0,0 +1,69 @@
import torch
from torch.autograd import Function

from .. import deform_pool_cuda


class DeformRoIPoolingFunction(Function):

    @staticmethod
    def forward(ctx,
                data,
                rois,
                offset,
                spatial_scale,
                out_size,
                out_channels,
                no_trans,
                group_size=1,
                part_size=None,
                sample_per_part=4,
                trans_std=.0):
        ctx.spatial_scale = spatial_scale
        ctx.out_size = out_size
        ctx.out_channels = out_channels
        ctx.no_trans = no_trans
        ctx.group_size = group_size
        ctx.part_size = out_size if part_size is None else part_size
        ctx.sample_per_part = sample_per_part
        ctx.trans_std = trans_std

        assert 0.0 <= ctx.trans_std <= 1.0
        if not data.is_cuda:
            raise NotImplementedError

        n = rois.shape[0]
        output = data.new_empty(n, out_channels, out_size, out_size)
        output_count = data.new_empty(n, out_channels, out_size, out_size)
        deform_pool_cuda.deform_psroi_pooling_cuda_forward(
            data, rois, offset, output, output_count, ctx.no_trans,
            ctx.spatial_scale, ctx.out_channels, ctx.group_size, ctx.out_size,
            ctx.part_size, ctx.sample_per_part, ctx.trans_std)

        if data.requires_grad or rois.requires_grad or offset.requires_grad:
            ctx.save_for_backward(data, rois, offset)
        ctx.output_count = output_count

        return output

    @staticmethod
    def backward(ctx, grad_output):
        if not grad_output.is_cuda:
            raise NotImplementedError

        data, rois, offset = ctx.saved_tensors
        output_count = ctx.output_count
        grad_input = torch.zeros_like(data)
        grad_rois = None
        grad_offset = torch.zeros_like(offset)

        deform_pool_cuda.deform_psroi_pooling_cuda_backward(
            grad_output, data, rois, offset, output_count, grad_input,
            grad_offset, ctx.no_trans, ctx.spatial_scale, ctx.out_channels,
            ctx.group_size, ctx.out_size, ctx.part_size, ctx.sample_per_part,
            ctx.trans_std)
        return (grad_input, grad_rois, grad_offset, None, None, None, None,
                None, None, None, None)


deform_roi_pooling = DeformRoIPoolingFunction.apply
IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/modules/__init__.py ADDED
File without changes
IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/modules/deform_conv.py ADDED
@@ -0,0 +1,157 @@
import math

import torch
import torch.nn as nn
from torch.nn.modules.utils import _pair

from ..functions.deform_conv import deform_conv, modulated_deform_conv


class DeformConv(nn.Module):

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 deformable_groups=1,
                 bias=False):
        super(DeformConv, self).__init__()

        assert not bias
        assert in_channels % groups == 0, \
            'in_channels {} is not divisible by groups {}'.format(
                in_channels, groups)
        assert out_channels % groups == 0, \
            'out_channels {} is not divisible by groups {}'.format(
                out_channels, groups)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)
        self.padding = _pair(padding)
        self.dilation = _pair(dilation)
        self.groups = groups
        self.deformable_groups = deformable_groups

        self.weight = nn.Parameter(
            torch.Tensor(out_channels, in_channels // self.groups,
                         *self.kernel_size))

        self.reset_parameters()

    def reset_parameters(self):
        n = self.in_channels
        for k in self.kernel_size:
            n *= k
        stdv = 1. / math.sqrt(n)
        self.weight.data.uniform_(-stdv, stdv)

    def forward(self, x, offset):
        return deform_conv(x, offset, self.weight, self.stride, self.padding,
                           self.dilation, self.groups, self.deformable_groups)


class DeformConvPack(DeformConv):

    def __init__(self, *args, **kwargs):
        super(DeformConvPack, self).__init__(*args, **kwargs)

        self.conv_offset = nn.Conv2d(
            self.in_channels,
            self.deformable_groups * 2 * self.kernel_size[0] *
            self.kernel_size[1],
            kernel_size=self.kernel_size,
            stride=_pair(self.stride),
            padding=_pair(self.padding),
            bias=True)
        self.init_offset()

    def init_offset(self):
        self.conv_offset.weight.data.zero_()
        self.conv_offset.bias.data.zero_()

    def forward(self, x):
        offset = self.conv_offset(x)
        return deform_conv(x, offset, self.weight, self.stride, self.padding,
                           self.dilation, self.groups, self.deformable_groups)


class ModulatedDeformConv(nn.Module):

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 deformable_groups=1,
                 bias=True):
        super(ModulatedDeformConv, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.deformable_groups = deformable_groups
        self.with_bias = bias

        self.weight = nn.Parameter(
            torch.Tensor(out_channels, in_channels // groups,
                         *self.kernel_size))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        n = self.in_channels
        for k in self.kernel_size:
            n *= k
        stdv = 1. / math.sqrt(n)
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.zero_()

    def forward(self, x, offset, mask):
        return modulated_deform_conv(x, offset, mask, self.weight, self.bias,
                                     self.stride, self.padding, self.dilation,
                                     self.groups, self.deformable_groups)


class ModulatedDeformConvPack(ModulatedDeformConv):

    def __init__(self, *args, **kwargs):
        super(ModulatedDeformConvPack, self).__init__(*args, **kwargs)

        self.conv_offset_mask = nn.Conv2d(
            self.in_channels,
            self.deformable_groups * 3 * self.kernel_size[0] *
            self.kernel_size[1],
            kernel_size=self.kernel_size,
            stride=_pair(self.stride),
            padding=_pair(self.padding),
            bias=True)
        self.init_offset()

    def init_offset(self):
        self.conv_offset_mask.weight.data.zero_()
        self.conv_offset_mask.bias.data.zero_()

    def forward(self, x):
        out = self.conv_offset_mask(x)
        o1, o2, mask = torch.chunk(out, 3, dim=1)
        offset = torch.cat((o1, o2), dim=1)
        mask = torch.sigmoid(mask)
        return modulated_deform_conv(x, offset, mask, self.weight, self.bias,
                                     self.stride, self.padding, self.dilation,
                                     self.groups, self.deformable_groups)
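The "Pack" variants are the drop-in modules: they predict their own offsets (and, for the modulated version, a per-tap sigmoid mask) from the input instead of taking them as arguments. A minimal sketch (illustrative only; the underlying autograd functions raise NotImplementedError on CPU, so a CUDA device and the compiled deform_conv_cuda extension are required):

import torch
from IndicPhotoOCR.detection.textbpn.network.backbone.assets.dcn import (
    DeformConvPack, ModulatedDeformConvPack)

x = torch.randn(1, 64, 32, 32, device="cuda")
# DeformConvPack predicts its own 2*kH*kW offset channels from x ...
y = DeformConvPack(64, 128, kernel_size=3, padding=1).cuda()(x)
# ... ModulatedDeformConvPack additionally predicts a sigmoid mask per tap.
z = ModulatedDeformConvPack(64, 128, kernel_size=3, padding=1).cuda()(x)
assert y.shape == z.shape == (1, 128, 32, 32)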
IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/modules/deform_pool.py ADDED
@@ -0,0 +1,172 @@
from torch import nn

from ..functions.deform_pool import deform_roi_pooling


class DeformRoIPooling(nn.Module):

    def __init__(self,
                 spatial_scale,
                 out_size,
                 out_channels,
                 no_trans,
                 group_size=1,
                 part_size=None,
                 sample_per_part=4,
                 trans_std=.0):
        super(DeformRoIPooling, self).__init__()
        self.spatial_scale = spatial_scale
        self.out_size = out_size
        self.out_channels = out_channels
        self.no_trans = no_trans
        self.group_size = group_size
        self.part_size = out_size if part_size is None else part_size
        self.sample_per_part = sample_per_part
        self.trans_std = trans_std

    def forward(self, data, rois, offset):
        if self.no_trans:
            offset = data.new_empty(0)
        return deform_roi_pooling(
            data, rois, offset, self.spatial_scale, self.out_size,
            self.out_channels, self.no_trans, self.group_size, self.part_size,
            self.sample_per_part, self.trans_std)


class DeformRoIPoolingPack(DeformRoIPooling):

    def __init__(self,
                 spatial_scale,
                 out_size,
                 out_channels,
                 no_trans,
                 group_size=1,
                 part_size=None,
                 sample_per_part=4,
                 trans_std=.0,
                 num_offset_fcs=3,
                 deform_fc_channels=1024):
        super(DeformRoIPoolingPack,
              self).__init__(spatial_scale, out_size, out_channels, no_trans,
                             group_size, part_size, sample_per_part, trans_std)

        self.num_offset_fcs = num_offset_fcs
        self.deform_fc_channels = deform_fc_channels

        if not no_trans:
            seq = []
            ic = self.out_size * self.out_size * self.out_channels
            for i in range(self.num_offset_fcs):
                if i < self.num_offset_fcs - 1:
                    oc = self.deform_fc_channels
                else:
                    oc = self.out_size * self.out_size * 2
                seq.append(nn.Linear(ic, oc))
                ic = oc
                if i < self.num_offset_fcs - 1:
                    seq.append(nn.ReLU(inplace=True))
            self.offset_fc = nn.Sequential(*seq)
            self.offset_fc[-1].weight.data.zero_()
            self.offset_fc[-1].bias.data.zero_()

    def forward(self, data, rois):
        assert data.size(1) == self.out_channels
        if self.no_trans:
            offset = data.new_empty(0)
            return deform_roi_pooling(
                data, rois, offset, self.spatial_scale, self.out_size,
                self.out_channels, self.no_trans, self.group_size,
                self.part_size, self.sample_per_part, self.trans_std)
        else:
            n = rois.shape[0]
            offset = data.new_empty(0)
            x = deform_roi_pooling(data, rois, offset, self.spatial_scale,
                                   self.out_size, self.out_channels, True,
                                   self.group_size, self.part_size,
                                   self.sample_per_part, self.trans_std)
            offset = self.offset_fc(x.view(n, -1))
            offset = offset.view(n, 2, self.out_size, self.out_size)
            return deform_roi_pooling(
                data, rois, offset, self.spatial_scale, self.out_size,
                self.out_channels, self.no_trans, self.group_size,
                self.part_size, self.sample_per_part, self.trans_std)


class ModulatedDeformRoIPoolingPack(DeformRoIPooling):

    def __init__(self,
                 spatial_scale,
                 out_size,
                 out_channels,
                 no_trans,
                 group_size=1,
                 part_size=None,
                 sample_per_part=4,
                 trans_std=.0,
                 num_offset_fcs=3,
                 num_mask_fcs=2,
                 deform_fc_channels=1024):
        super(ModulatedDeformRoIPoolingPack, self).__init__(
            spatial_scale, out_size, out_channels, no_trans, group_size,
            part_size, sample_per_part, trans_std)

        self.num_offset_fcs = num_offset_fcs
        self.num_mask_fcs = num_mask_fcs
        self.deform_fc_channels = deform_fc_channels

        if not no_trans:
            offset_fc_seq = []
            ic = self.out_size * self.out_size * self.out_channels
            for i in range(self.num_offset_fcs):
                if i < self.num_offset_fcs - 1:
                    oc = self.deform_fc_channels
                else:
                    oc = self.out_size * self.out_size * 2
                offset_fc_seq.append(nn.Linear(ic, oc))
                ic = oc
                if i < self.num_offset_fcs - 1:
                    offset_fc_seq.append(nn.ReLU(inplace=True))
            self.offset_fc = nn.Sequential(*offset_fc_seq)
            self.offset_fc[-1].weight.data.zero_()
            self.offset_fc[-1].bias.data.zero_()

            mask_fc_seq = []
            ic = self.out_size * self.out_size * self.out_channels
            for i in range(self.num_mask_fcs):
                if i < self.num_mask_fcs - 1:
                    oc = self.deform_fc_channels
                else:
                    oc = self.out_size * self.out_size
                mask_fc_seq.append(nn.Linear(ic, oc))
                ic = oc
                if i < self.num_mask_fcs - 1:
                    mask_fc_seq.append(nn.ReLU(inplace=True))
                else:
                    mask_fc_seq.append(nn.Sigmoid())
            self.mask_fc = nn.Sequential(*mask_fc_seq)
            self.mask_fc[-2].weight.data.zero_()
            self.mask_fc[-2].bias.data.zero_()

    def forward(self, data, rois):
        assert data.size(1) == self.out_channels
        if self.no_trans:
            offset = data.new_empty(0)
            return deform_roi_pooling(
                data, rois, offset, self.spatial_scale, self.out_size,
                self.out_channels, self.no_trans, self.group_size,
                self.part_size, self.sample_per_part, self.trans_std)
        else:
            n = rois.shape[0]
            offset = data.new_empty(0)
            x = deform_roi_pooling(data, rois, offset, self.spatial_scale,
                                   self.out_size, self.out_channels, True,
                                   self.group_size, self.part_size,
                                   self.sample_per_part, self.trans_std)
            offset = self.offset_fc(x.view(n, -1))
            offset = offset.view(n, 2, self.out_size, self.out_size)
            mask = self.mask_fc(x.view(n, -1))
            mask = mask.view(n, 1, self.out_size, self.out_size)
            return deform_roi_pooling(
                data, rois, offset, self.spatial_scale, self.out_size,
                self.out_channels, self.no_trans, self.group_size,
                self.part_size, self.sample_per_part, self.trans_std) * mask
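The RoI pooling packs follow the same pattern: a first, non-deformed pooling pass feeds small FC heads that predict the offsets (and mask), then a second pass applies them. A shape-level sketch (illustrative only; CUDA and the compiled deform_pool_cuda extension are required, and the (batch_idx, x1, y1, x2, y2) RoI layout is an assumption about the kernel's convention):

import torch
from IndicPhotoOCR.detection.textbpn.network.backbone.assets.dcn import (
    ModulatedDeformRoIPoolingPack)

pool = ModulatedDeformRoIPoolingPack(
    spatial_scale=1.0 / 16,   # feature map assumed to be 1/16 of the image
    out_size=7, out_channels=256, no_trans=False).cuda()

feat = torch.randn(1, 256, 40, 40, device="cuda")
rois = torch.tensor([[0., 32., 32., 255., 255.]], device="cuda")
out = pool(feat, rois)        # -> (1, 256, 7, 7); offsets and mask come from the FC heads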
IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/setup.py ADDED
@@ -0,0 +1,19 @@
import os
PATH = "{}:{}".format(os.environ['PATH'], "/opt/cuda/bin")
# os.environ['CUDA_VISIBLE_DEVICES'] = "1"
os.environ['PATH'] = PATH
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name='deform_conv',
    ext_modules=[
        CUDAExtension('deform_conv_cuda', [
            'src/deform_conv_cuda.cpp',
            'src/deform_conv_cuda_kernel.cu',
        ]),
        CUDAExtension('deform_pool_cuda', [
            'src/deform_pool_cuda.cpp', 'src/deform_pool_cuda_kernel.cu'
        ]),
    ],
    cmdclass={'build_ext': BuildExtension})
IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/src/deform_conv_cuda.cpp ADDED
@@ -0,0 +1,695 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // modify from
2
+ // https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c
3
+
4
+ #include <torch/extension.h>
5
+
6
+ #include <cmath>
7
+ #include <vector>
8
+
9
+ void deformable_im2col(const at::Tensor data_im, const at::Tensor data_offset,
10
+ const int channels, const int height, const int width,
11
+ const int ksize_h, const int ksize_w, const int pad_h,
12
+ const int pad_w, const int stride_h, const int stride_w,
13
+ const int dilation_h, const int dilation_w,
14
+ const int parallel_imgs, const int deformable_group,
15
+ at::Tensor data_col);
16
+
17
+ void deformable_col2im(const at::Tensor data_col, const at::Tensor data_offset,
18
+ const int channels, const int height, const int width,
19
+ const int ksize_h, const int ksize_w, const int pad_h,
20
+ const int pad_w, const int stride_h, const int stride_w,
21
+ const int dilation_h, const int dilation_w,
22
+ const int parallel_imgs, const int deformable_group,
23
+ at::Tensor grad_im);
24
+
25
+ void deformable_col2im_coord(
26
+ const at::Tensor data_col, const at::Tensor data_im,
27
+ const at::Tensor data_offset, const int channels, const int height,
28
+ const int width, const int ksize_h, const int ksize_w, const int pad_h,
29
+ const int pad_w, const int stride_h, const int stride_w,
30
+ const int dilation_h, const int dilation_w, const int parallel_imgs,
31
+ const int deformable_group, at::Tensor grad_offset);
32
+
33
+ void modulated_deformable_im2col_cuda(
34
+ const at::Tensor data_im, const at::Tensor data_offset,
35
+ const at::Tensor data_mask, const int batch_size, const int channels,
36
+ const int height_im, const int width_im, const int height_col,
37
+ const int width_col, const int kernel_h, const int kenerl_w,
38
+ const int pad_h, const int pad_w, const int stride_h, const int stride_w,
39
+ const int dilation_h, const int dilation_w, const int deformable_group,
40
+ at::Tensor data_col);
41
+
42
+ void modulated_deformable_col2im_cuda(
43
+ const at::Tensor data_col, const at::Tensor data_offset,
44
+ const at::Tensor data_mask, const int batch_size, const int channels,
45
+ const int height_im, const int width_im, const int height_col,
46
+ const int width_col, const int kernel_h, const int kenerl_w,
47
+ const int pad_h, const int pad_w, const int stride_h, const int stride_w,
48
+ const int dilation_h, const int dilation_w, const int deformable_group,
49
+ at::Tensor grad_im);
50
+
51
+ void modulated_deformable_col2im_coord_cuda(
52
+ const at::Tensor data_col, const at::Tensor data_im,
53
+ const at::Tensor data_offset, const at::Tensor data_mask,
54
+ const int batch_size, const int channels, const int height_im,
55
+ const int width_im, const int height_col, const int width_col,
56
+ const int kernel_h, const int kenerl_w, const int pad_h, const int pad_w,
57
+ const int stride_h, const int stride_w, const int dilation_h,
58
+ const int dilation_w, const int deformable_group, at::Tensor grad_offset,
59
+ at::Tensor grad_mask);
60
+
61
+ void shape_check(at::Tensor input, at::Tensor offset, at::Tensor *gradOutput,
62
+ at::Tensor weight, int kH, int kW, int dH, int dW, int padH,
63
+ int padW, int dilationH, int dilationW, int group,
64
+ int deformable_group) {
65
+ TORCH_CHECK(weight.ndimension() == 4,
66
+ "4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, "
67
+ "but got: %s",
68
+ weight.ndimension());
69
+
70
+ TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
71
+
72
+ TORCH_CHECK(kW > 0 && kH > 0,
73
+ "kernel size should be greater than zero, but got kH: %d kW: %d", kH,
74
+ kW);
75
+
76
+ TORCH_CHECK((weight.size(2) == kH && weight.size(3) == kW),
77
+ "kernel size should be consistent with weight, ",
78
+ "but got kH: %d kW: %d weight.size(2): %d, weight.size(3): %d", kH,
79
+ kW, weight.size(2), weight.size(3));
80
+
81
+ TORCH_CHECK(dW > 0 && dH > 0,
82
+ "stride should be greater than zero, but got dH: %d dW: %d", dH, dW);
83
+
84
+ TORCH_CHECK(
85
+ dilationW > 0 && dilationH > 0,
86
+ "dilation should be greater than 0, but got dilationH: %d dilationW: %d",
87
+ dilationH, dilationW);
88
+
89
+ int ndim = input.ndimension();
90
+ int dimf = 0;
91
+ int dimh = 1;
92
+ int dimw = 2;
93
+
94
+ if (ndim == 4) {
95
+ dimf++;
96
+ dimh++;
97
+ dimw++;
98
+ }
99
+
100
+ TORCH_CHECK(ndim == 3 || ndim == 4, "3D or 4D input tensor expected but got: %s",
101
+ ndim);
102
+
103
+ long nInputPlane = weight.size(1) * group;
104
+ long inputHeight = input.size(dimh);
105
+ long inputWidth = input.size(dimw);
106
+ long nOutputPlane = weight.size(0);
107
+ long outputHeight =
108
+ (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
109
+ long outputWidth =
110
+ (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
111
+
112
+ TORCH_CHECK(nInputPlane % deformable_group == 0,
113
+ "input channels must divide deformable group size");
114
+
115
+ if (outputWidth < 1 || outputHeight < 1)
116
+ AT_ERROR(
117
+ "Given input size: (%ld x %ld x %ld). "
118
+ "Calculated output size: (%ld x %ld x %ld). Output size is too small",
119
+ nInputPlane, inputHeight, inputWidth, nOutputPlane, outputHeight,
120
+ outputWidth);
121
+
122
+ TORCH_CHECK(input.size(1) == nInputPlane,
123
+ "invalid number of input planes, expected: %d, but got: %d",
124
+ nInputPlane, input.size(1));
125
+
126
+ TORCH_CHECK((inputHeight >= kH && inputWidth >= kW),
127
+ "input image is smaller than kernel");
128
+
129
+ TORCH_CHECK((offset.size(2) == outputHeight && offset.size(3) == outputWidth),
130
+ "invalid spatial size of offset, expected height: %d width: %d, but "
131
+ "got height: %d width: %d",
132
+ outputHeight, outputWidth, offset.size(2), offset.size(3));
133
+
134
+ TORCH_CHECK((offset.size(1) == deformable_group * 2 * kH * kW),
135
+ "invalid number of channels of offset");
136
+
137
+ if (gradOutput != NULL) {
138
+ TORCH_CHECK(gradOutput->size(dimf) == nOutputPlane,
139
+ "invalid number of gradOutput planes, expected: %d, but got: %d",
140
+ nOutputPlane, gradOutput->size(dimf));
141
+
142
+ TORCH_CHECK((gradOutput->size(dimh) == outputHeight &&
143
+ gradOutput->size(dimw) == outputWidth),
144
+ "invalid size of gradOutput, expected height: %d width: %d , but "
145
+ "got height: %d width: %d",
146
+ outputHeight, outputWidth, gradOutput->size(dimh),
147
+ gradOutput->size(dimw));
148
+ }
149
+ }
150
+
151
+ int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
152
+ at::Tensor offset, at::Tensor output,
153
+ at::Tensor columns, at::Tensor ones, int kW,
154
+ int kH, int dW, int dH, int padW, int padH,
155
+ int dilationW, int dilationH, int group,
156
+ int deformable_group, int im2col_step) {
157
+ // todo: resize columns to include im2col: done
158
+ // todo: add im2col_step as input
159
+ // todo: add new output buffer and transpose it to output (or directly
160
+ // transpose output) todo: possibly change data indexing because of
161
+ // parallel_imgs
162
+
163
+ shape_check(input, offset, NULL, weight, kH, kW, dH, dW, padH, padW,
164
+ dilationH, dilationW, group, deformable_group);
165
+
166
+ input = input.contiguous();
167
+ offset = offset.contiguous();
168
+ weight = weight.contiguous();
169
+
170
+ int batch = 1;
171
+ if (input.ndimension() == 3) {
172
+ // Force batch
173
+ batch = 0;
174
+ input.unsqueeze_(0);
175
+ offset.unsqueeze_(0);
176
+ }
177
+
178
+ // todo: assert batchsize dividable by im2col_step
179
+
180
+ long batchSize = input.size(0);
181
+ long nInputPlane = input.size(1);
182
+ long inputHeight = input.size(2);
183
+ long inputWidth = input.size(3);
184
+
185
+ long nOutputPlane = weight.size(0);
186
+
187
+ long outputWidth =
188
+ (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
189
+ long outputHeight =
190
+ (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
191
+
192
+ TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
193
+
194
+ output = output.view({batchSize / im2col_step, im2col_step, nOutputPlane,
195
+ outputHeight, outputWidth});
196
+ columns = at::zeros(
197
+ {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
198
+ input.options());
199
+
200
+ if (ones.ndimension() != 2 ||
201
+ ones.size(0) * ones.size(1) < outputHeight * outputWidth) {
202
+ ones = at::ones({outputHeight, outputWidth}, input.options());
203
+ }
204
+
205
+ input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
206
+ inputHeight, inputWidth});
207
+ offset =
208
+ offset.view({batchSize / im2col_step, im2col_step,
209
+ deformable_group * 2 * kH * kW, outputHeight, outputWidth});
210
+
211
+ at::Tensor output_buffer =
212
+ at::zeros({batchSize / im2col_step, nOutputPlane,
213
+ im2col_step * outputHeight, outputWidth},
214
+ output.options());
215
+
216
+ output_buffer = output_buffer.view(
217
+ {output_buffer.size(0), group, output_buffer.size(1) / group,
218
+ output_buffer.size(2), output_buffer.size(3)});
219
+
220
+ for (int elt = 0; elt < batchSize / im2col_step; elt++) {
221
+ deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
222
+ inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
223
+ dilationW, im2col_step, deformable_group, columns);
224
+
225
+ columns = columns.view({group, columns.size(0) / group, columns.size(1)});
226
+ weight = weight.view({group, weight.size(0) / group, weight.size(1),
227
+ weight.size(2), weight.size(3)});
228
+
229
+ for (int g = 0; g < group; g++) {
230
+ output_buffer[elt][g] = output_buffer[elt][g]
231
+ .flatten(1)
232
+ .addmm_(weight[g].flatten(1), columns[g])
233
+ .view_as(output_buffer[elt][g]);
234
+ }
235
+ }
236
+
237
+ output_buffer = output_buffer.view(
238
+ {output_buffer.size(0), output_buffer.size(1) * output_buffer.size(2),
239
+ output_buffer.size(3), output_buffer.size(4)});
240
+
241
+ output_buffer = output_buffer.view({batchSize / im2col_step, nOutputPlane,
242
+ im2col_step, outputHeight, outputWidth});
243
+ output_buffer.transpose_(1, 2);
244
+ output.copy_(output_buffer);
245
+ output = output.view({batchSize, nOutputPlane, outputHeight, outputWidth});
246
+
247
+ input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
248
+ offset = offset.view(
249
+ {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
250
+
251
+ if (batch == 0) {
252
+ output = output.view({nOutputPlane, outputHeight, outputWidth});
253
+ input = input.view({nInputPlane, inputHeight, inputWidth});
254
+ offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
255
+ }
256
+
257
+ return 1;
258
+ }
259
+
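
The forward pass above is im2col followed by a per-group GEMM (the `addmm_` loop). A minimal PyTorch sketch of that GEMM step, assuming an im2col buffer has already been built (all names and shapes illustrative):

```python
import torch

group, c_out, c_in, kh, kw = 2, 8, 4, 3, 3
out_h, out_w = 5, 5

# weight[g]: (c_out/group, (c_in/group)*kh*kw); columns[g]: ((c_in/group)*kh*kw, out_h*out_w)
weight = torch.randn(group, c_out // group, (c_in // group) * kh * kw)
columns = torch.randn(group, (c_in // group) * kh * kw, out_h * out_w)

# output[g] = weight[g] @ columns[g], mirroring the addmm_ loop over groups
output = torch.bmm(weight, columns).reshape(c_out, out_h, out_w)
```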
260
+ int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset,
261
+ at::Tensor gradOutput, at::Tensor gradInput,
262
+ at::Tensor gradOffset, at::Tensor weight,
263
+ at::Tensor columns, int kW, int kH, int dW,
264
+ int dH, int padW, int padH, int dilationW,
265
+ int dilationH, int group,
266
+ int deformable_group, int im2col_step) {
267
+ shape_check(input, offset, &gradOutput, weight, kH, kW, dH, dW, padH, padW,
268
+ dilationH, dilationW, group, deformable_group);
269
+
270
+ input = input.contiguous();
271
+ offset = offset.contiguous();
272
+ gradOutput = gradOutput.contiguous();
273
+ weight = weight.contiguous();
274
+
275
+ int batch = 1;
276
+
277
+ if (input.ndimension() == 3) {
278
+ // Force batch
279
+ batch = 0;
280
+ input = input.view({1, input.size(0), input.size(1), input.size(2)});
281
+ offset = offset.view({1, offset.size(0), offset.size(1), offset.size(2)});
282
+ gradOutput = gradOutput.view(
283
+ {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
284
+ }
285
+
286
+ long batchSize = input.size(0);
287
+ long nInputPlane = input.size(1);
288
+ long inputHeight = input.size(2);
289
+ long inputWidth = input.size(3);
290
+
291
+ long nOutputPlane = weight.size(0);
292
+
293
+ long outputWidth =
294
+ (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
295
+ long outputHeight =
296
+ (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
297
+
298
+ TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
299
+ gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
300
+ columns = at::zeros(
301
+ {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
302
+ input.options());
303
+
304
+ // change order of grad output
305
+ gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
306
+ nOutputPlane, outputHeight, outputWidth});
307
+ gradOutput.transpose_(1, 2);
308
+
309
+ gradInput = gradInput.view({batchSize / im2col_step, im2col_step, nInputPlane,
310
+ inputHeight, inputWidth});
311
+ input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
312
+ inputHeight, inputWidth});
313
+ gradOffset = gradOffset.view({batchSize / im2col_step, im2col_step,
314
+ deformable_group * 2 * kH * kW, outputHeight,
315
+ outputWidth});
316
+ offset =
317
+ offset.view({batchSize / im2col_step, im2col_step,
318
+ deformable_group * 2 * kH * kW, outputHeight, outputWidth});
319
+
320
+ for (int elt = 0; elt < batchSize / im2col_step; elt++) {
321
+ // divide into groups
322
+ columns = columns.view({group, columns.size(0) / group, columns.size(1)});
323
+ weight = weight.view({group, weight.size(0) / group, weight.size(1),
324
+ weight.size(2), weight.size(3)});
325
+ gradOutput = gradOutput.view(
326
+ {gradOutput.size(0), group, gradOutput.size(1) / group,
327
+ gradOutput.size(2), gradOutput.size(3), gradOutput.size(4)});
328
+
329
+ for (int g = 0; g < group; g++) {
330
+ columns[g] = columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
331
+ gradOutput[elt][g].flatten(1), 0.0f, 1.0f);
332
+ }
333
+
334
+ columns =
335
+ columns.view({columns.size(0) * columns.size(1), columns.size(2)});
336
+ gradOutput = gradOutput.view(
337
+ {gradOutput.size(0), gradOutput.size(1) * gradOutput.size(2),
338
+ gradOutput.size(3), gradOutput.size(4), gradOutput.size(5)});
339
+
340
+ deformable_col2im_coord(columns, input[elt], offset[elt], nInputPlane,
341
+ inputHeight, inputWidth, kH, kW, padH, padW, dH, dW,
342
+ dilationH, dilationW, im2col_step, deformable_group,
343
+ gradOffset[elt]);
344
+
345
+ deformable_col2im(columns, offset[elt], nInputPlane, inputHeight,
346
+ inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
347
+ dilationW, im2col_step, deformable_group, gradInput[elt]);
348
+ }
349
+
350
+ gradOutput.transpose_(1, 2);
351
+ gradOutput =
352
+ gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});
353
+
354
+ gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
355
+ input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
356
+ gradOffset = gradOffset.view(
357
+ {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
358
+ offset = offset.view(
359
+ {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
360
+
361
+ if (batch == 0) {
362
+ gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
363
+ input = input.view({nInputPlane, inputHeight, inputWidth});
364
+ gradInput = gradInput.view({nInputPlane, inputHeight, inputWidth});
365
+ offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
366
+ gradOffset =
367
+ gradOffset.view({offset.size(1), offset.size(2), offset.size(3)});
368
+ }
369
+
370
+ return 1;
371
+ }
372
+
373
+ int deform_conv_backward_parameters_cuda(
374
+ at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
375
+ at::Tensor gradWeight, // at::Tensor gradBias,
376
+ at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
377
+ int padW, int padH, int dilationW, int dilationH, int group,
378
+ int deformable_group, float scale, int im2col_step) {
379
+ // todo: transpose and reshape outGrad
380
+ // todo: reshape columns
381
+ // todo: add im2col_step as input
382
+
383
+ shape_check(input, offset, &gradOutput, gradWeight, kH, kW, dH, dW, padH,
384
+ padW, dilationH, dilationW, group, deformable_group);
385
+
386
+ input = input.contiguous();
387
+ offset = offset.contiguous();
388
+ gradOutput = gradOutput.contiguous();
389
+
390
+ int batch = 1;
391
+
392
+ if (input.ndimension() == 3) {
393
+ // Force batch
394
+ batch = 0;
395
+ input = input.view(
396
+ at::IntList({1, input.size(0), input.size(1), input.size(2)}));
397
+ gradOutput = gradOutput.view(
398
+ {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
399
+ }
400
+
401
+ long batchSize = input.size(0);
402
+ long nInputPlane = input.size(1);
403
+ long inputHeight = input.size(2);
404
+ long inputWidth = input.size(3);
405
+
406
+ long nOutputPlane = gradWeight.size(0);
407
+
408
+ long outputWidth =
409
+ (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
410
+ long outputHeight =
411
+ (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
412
+
413
+ TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
414
+
415
+ columns = at::zeros(
416
+ {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
417
+ input.options());
418
+
419
+ gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
420
+ nOutputPlane, outputHeight, outputWidth});
421
+ gradOutput.transpose_(1, 2);
422
+
423
+ at::Tensor gradOutputBuffer = at::zeros_like(gradOutput);
424
+ gradOutputBuffer =
425
+ gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, im2col_step,
426
+ outputHeight, outputWidth});
427
+ gradOutputBuffer.copy_(gradOutput);
428
+ gradOutputBuffer =
429
+ gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane,
430
+ im2col_step * outputHeight, outputWidth});
431
+
432
+ gradOutput.transpose_(1, 2);
433
+ gradOutput =
434
+ gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});
435
+
436
+ input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
437
+ inputHeight, inputWidth});
438
+ offset =
439
+ offset.view({batchSize / im2col_step, im2col_step,
440
+ deformable_group * 2 * kH * kW, outputHeight, outputWidth});
441
+
442
+ for (int elt = 0; elt < batchSize / im2col_step; elt++) {
443
+ deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
444
+ inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
445
+ dilationW, im2col_step, deformable_group, columns);
446
+
447
+ // divide into groups
448
+ gradOutputBuffer = gradOutputBuffer.view(
449
+ {gradOutputBuffer.size(0), group, gradOutputBuffer.size(1) / group,
450
+ gradOutputBuffer.size(2), gradOutputBuffer.size(3)});
451
+ columns = columns.view({group, columns.size(0) / group, columns.size(1)});
452
+ gradWeight =
453
+ gradWeight.view({group, gradWeight.size(0) / group, gradWeight.size(1),
454
+ gradWeight.size(2), gradWeight.size(3)});
455
+
456
+ for (int g = 0; g < group; g++) {
457
+ gradWeight[g] = gradWeight[g]
458
+ .flatten(1)
459
+ .addmm_(gradOutputBuffer[elt][g].flatten(1),
460
+ columns[g].transpose(1, 0), 1.0, scale)
461
+ .view_as(gradWeight[g]);
462
+ }
463
+ gradOutputBuffer = gradOutputBuffer.view(
464
+ {gradOutputBuffer.size(0),
465
+ gradOutputBuffer.size(1) * gradOutputBuffer.size(2),
466
+ gradOutputBuffer.size(3), gradOutputBuffer.size(4)});
467
+ columns =
468
+ columns.view({columns.size(0) * columns.size(1), columns.size(2)});
469
+ gradWeight = gradWeight.view({gradWeight.size(0) * gradWeight.size(1),
470
+ gradWeight.size(2), gradWeight.size(3),
471
+ gradWeight.size(4)});
472
+ }
473
+
474
+ input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
475
+ offset = offset.view(
476
+ {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
477
+
478
+ if (batch == 0) {
479
+ gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
480
+ input = input.view({nInputPlane, inputHeight, inputWidth});
481
+ }
482
+
483
+ return 1;
484
+ }
485
+
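
The weight gradient above is, per group, `grad_output_flat @ columns^T` accumulated into `gradWeight` with the `scale` factor. A small PyTorch sketch of one such accumulation (shapes illustrative):

```python
import torch

scale = 1.0
grad_out = torch.randn(4, 25)    # (c_out/group, out_h*out_w)
columns = torch.randn(18, 25)    # ((c_in/group)*kh*kw, out_h*out_w)
grad_weight = torch.zeros(4, 18)

# grad_weight += scale * grad_out @ columns^T, as in the addmm_ call above
grad_weight.addmm_(grad_out, columns.t(), beta=1.0, alpha=scale)
```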
486
+ void modulated_deform_conv_cuda_forward(
487
+ at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
488
+ at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns,
489
+ int kernel_h, int kernel_w, const int stride_h, const int stride_w,
490
+ const int pad_h, const int pad_w, const int dilation_h,
491
+ const int dilation_w, const int group, const int deformable_group,
492
+ const bool with_bias) {
493
+ TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
494
+ TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
495
+
496
+ const int batch = input.size(0);
497
+ const int channels = input.size(1);
498
+ const int height = input.size(2);
499
+ const int width = input.size(3);
500
+
501
+ const int channels_out = weight.size(0);
502
+ const int channels_kernel = weight.size(1);
503
+ const int kernel_h_ = weight.size(2);
504
+ const int kernel_w_ = weight.size(3);
505
+
506
+ if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
507
+ AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
508
+ kernel_h_, kernel_w, kernel_h_, kernel_w_);
509
+ if (channels != channels_kernel * group)
510
+ AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).",
511
+ channels, channels_kernel * group);
512
+
513
+ const int height_out =
514
+ (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
515
+ const int width_out =
516
+ (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
517
+
518
+ if (ones.ndimension() != 2 ||
519
+ ones.size(0) * ones.size(1) < height_out * width_out) {
520
+ // Resize plane and fill with ones...
521
+ ones = at::ones({height_out, width_out}, input.options());
522
+ }
523
+
524
+ // resize output
525
+ output = output.view({batch, channels_out, height_out, width_out}).zero_();
526
+ // resize temporary columns
527
+ columns =
528
+ at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out},
529
+ input.options());
530
+
531
+ output = output.view({output.size(0), group, output.size(1) / group,
532
+ output.size(2), output.size(3)});
533
+
534
+ for (int b = 0; b < batch; b++) {
535
+ modulated_deformable_im2col_cuda(
536
+ input[b], offset[b], mask[b], 1, channels, height, width, height_out,
537
+ width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
538
+ dilation_h, dilation_w, deformable_group, columns);
539
+
540
+ // divide into group
541
+ weight = weight.view({group, weight.size(0) / group, weight.size(1),
542
+ weight.size(2), weight.size(3)});
543
+ columns = columns.view({group, columns.size(0) / group, columns.size(1)});
544
+
545
+ for (int g = 0; g < group; g++) {
546
+ output[b][g] = output[b][g]
547
+ .flatten(1)
548
+ .addmm_(weight[g].flatten(1), columns[g])
549
+ .view_as(output[b][g]);
550
+ }
551
+
552
+ weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
553
+ weight.size(3), weight.size(4)});
554
+ columns =
555
+ columns.view({columns.size(0) * columns.size(1), columns.size(2)});
556
+ }
557
+
558
+ output = output.view({output.size(0), output.size(1) * output.size(2),
559
+ output.size(3), output.size(4)});
560
+
561
+ if (with_bias) {
562
+ output += bias.view({1, bias.size(0), 1, 1});
563
+ }
564
+ }
565
+
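
The modulated variant multiplies each bilinearly sampled value by a learned mask before the GEMM. For reference, recent torchvision ships the same operator; a hedged sketch assuming torchvision >= 0.10 is available:

```python
import torch
from torchvision.ops import deform_conv2d

n, c_in, h, w = 1, 4, 8, 8
c_out, kh, kw = 6, 3, 3
x = torch.randn(n, c_in, h, w)
weight = torch.randn(c_out, c_in, kh, kw)
offset = torch.zeros(n, 2 * kh * kw, h, w)  # zero offsets -> regular sampling grid
mask = torch.ones(n, kh * kw, h, w)         # all-ones mask -> no modulation
out = deform_conv2d(x, offset, weight, padding=1, mask=mask)  # (1, 6, 8, 8)
```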
566
+ void modulated_deform_conv_cuda_backward(
567
+ at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
568
+ at::Tensor offset, at::Tensor mask, at::Tensor columns,
569
+ at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
570
+ at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,
571
+ int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
572
+ int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
573
+ const bool with_bias) {
574
+ TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
575
+ TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
576
+
577
+ const int batch = input.size(0);
578
+ const int channels = input.size(1);
579
+ const int height = input.size(2);
580
+ const int width = input.size(3);
581
+
582
+ const int channels_kernel = weight.size(1);
583
+ const int kernel_h_ = weight.size(2);
584
+ const int kernel_w_ = weight.size(3);
585
+ if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
586
+ AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
587
+ kernel_h_, kernel_w, kernel_h_, kernel_w_);
588
+ if (channels != channels_kernel * group)
589
+ AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).",
590
+ channels, channels_kernel * group);
591
+
592
+ const int height_out =
593
+ (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
594
+ const int width_out =
595
+ (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
596
+
597
+ if (ones.ndimension() != 2 ||
598
+ ones.size(0) * ones.size(1) < height_out * width_out) {
599
+ // Resize plane and fill with ones...
600
+ ones = at::ones({height_out, width_out}, input.options());
601
+ }
602
+
603
+ grad_input = grad_input.view({batch, channels, height, width});
604
+ columns = at::zeros({channels * kernel_h * kernel_w, height_out * width_out},
605
+ input.options());
606
+
607
+ grad_output =
608
+ grad_output.view({grad_output.size(0), group, grad_output.size(1) / group,
609
+ grad_output.size(2), grad_output.size(3)});
610
+
611
+ for (int b = 0; b < batch; b++) {
612
+ // divide into groups
613
+ columns = columns.view({group, columns.size(0) / group, columns.size(1)});
614
+ weight = weight.view({group, weight.size(0) / group, weight.size(1),
615
+ weight.size(2), weight.size(3)});
616
+
617
+ for (int g = 0; g < group; g++) {
618
+ columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
619
+ grad_output[b][g].flatten(1), 0.0f, 1.0f);
620
+ }
621
+
622
+ columns =
623
+ columns.view({columns.size(0) * columns.size(1), columns.size(2)});
624
+ weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
625
+ weight.size(3), weight.size(4)});
626
+
627
+ // gradient w.r.t. input coordinate data
628
+ modulated_deformable_col2im_coord_cuda(
629
+ columns, input[b], offset[b], mask[b], 1, channels, height, width,
630
+ height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h,
631
+ stride_w, dilation_h, dilation_w, deformable_group, grad_offset[b],
632
+ grad_mask[b]);
633
+ // gradient w.r.t. input data
634
+ modulated_deformable_col2im_cuda(
635
+ columns, offset[b], mask[b], 1, channels, height, width, height_out,
636
+ width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
637
+ dilation_h, dilation_w, deformable_group, grad_input[b]);
638
+
639
+ // gradient w.r.t. weight, dWeight should accumulate across the batch and
640
+ // group
641
+ modulated_deformable_im2col_cuda(
642
+ input[b], offset[b], mask[b], 1, channels, height, width, height_out,
643
+ width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
644
+ dilation_h, dilation_w, deformable_group, columns);
645
+
646
+ columns = columns.view({group, columns.size(0) / group, columns.size(1)});
647
+ grad_weight = grad_weight.view({group, grad_weight.size(0) / group,
648
+ grad_weight.size(1), grad_weight.size(2),
649
+ grad_weight.size(3)});
650
+ if (with_bias)
651
+ grad_bias = grad_bias.view({group, grad_bias.size(0) / group});
652
+
653
+ for (int g = 0; g < group; g++) {
654
+ grad_weight[g] =
655
+ grad_weight[g]
656
+ .flatten(1)
657
+ .addmm_(grad_output[b][g].flatten(1), columns[g].transpose(0, 1))
658
+ .view_as(grad_weight[g]);
659
+ if (with_bias) {
660
+ grad_bias[g] =
661
+ grad_bias[g]
662
+ .view({-1, 1})
663
+ .addmm_(grad_output[b][g].flatten(1), ones.view({-1, 1}))
664
+ .view(-1);
665
+ }
666
+ }
667
+
668
+ columns =
669
+ columns.view({columns.size(0) * columns.size(1), columns.size(2)});
670
+ grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1),
671
+ grad_weight.size(2), grad_weight.size(3),
672
+ grad_weight.size(4)});
673
+ if (with_bias)
674
+ grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)});
675
+ }
676
+ grad_output = grad_output.view({grad_output.size(0) * grad_output.size(1),
677
+ grad_output.size(2), grad_output.size(3),
678
+ grad_output.size(4)});
679
+ }
680
+
681
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
682
+ m.def("deform_conv_forward_cuda", &deform_conv_forward_cuda,
683
+ "deform forward (CUDA)");
684
+ m.def("deform_conv_backward_input_cuda", &deform_conv_backward_input_cuda,
685
+ "deform_conv_backward_input (CUDA)");
686
+ m.def("deform_conv_backward_parameters_cuda",
687
+ &deform_conv_backward_parameters_cuda,
688
+ "deform_conv_backward_parameters (CUDA)");
689
+ m.def("modulated_deform_conv_cuda_forward",
690
+ &modulated_deform_conv_cuda_forward,
691
+ "modulated deform conv forward (CUDA)");
692
+ m.def("modulated_deform_conv_cuda_backward",
693
+ &modulated_deform_conv_cuda_backward,
694
+ "modulated deform conv backward (CUDA)");
695
+ }
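
Once compiled via the setup.py in this directory, the functions registered above are callable from Python; functions/deform_conv.py in this package is the intended wrapper around these bindings. A hedged sketch of a raw forward call (the module name is an assumption based on the pybind11 registration; the argument order follows the signature above):

```python
import torch
import deform_conv_cuda  # assumed compiled-extension name

x = torch.randn(1, 4, 8, 8, device="cuda")
weight = torch.randn(6, 4, 3, 3, device="cuda")
offset = torch.zeros(1, 2 * 3 * 3, 8, 8, device="cuda")
output = torch.empty(1, 6, 8, 8, device="cuda")  # pre-sized output buffer
columns, ones = x.new_empty(0), x.new_empty(0)   # resized inside the op

deform_conv_cuda.deform_conv_forward_cuda(
    x, weight, offset, output, columns, ones,
    3, 3,        # kW, kH
    1, 1, 1, 1,  # dW, dH, padW, padH
    1, 1,        # dilationW, dilationH
    1, 1, 1)     # group, deformable_group, im2col_step
```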
IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/src/deform_conv_cuda_kernel.cu ADDED
@@ -0,0 +1,866 @@
1
+ /*!
2
+ ******************* BEGIN Caffe Copyright Notice and Disclaimer ****************
3
+ *
4
+ * COPYRIGHT
5
+ *
6
+ * All contributions by the University of California:
7
+ * Copyright (c) 2014-2017 The Regents of the University of California (Regents)
8
+ * All rights reserved.
9
+ *
10
+ * All other contributions:
11
+ * Copyright (c) 2014-2017, the respective contributors
12
+ * All rights reserved.
13
+ *
14
+ * Caffe uses a shared copyright model: each contributor holds copyright over
15
+ * their contributions to Caffe. The project versioning records all such
16
+ * contribution and copyright details. If a contributor wants to further mark
17
+ * their specific copyright on a particular contribution, they should indicate
18
+ * their copyright solely in the commit message of the change when it is
19
+ * committed.
20
+ *
21
+ * LICENSE
22
+ *
23
+ * Redistribution and use in source and binary forms, with or without
24
+ * modification, are permitted provided that the following conditions are met:
25
+ *
26
+ * 1. Redistributions of source code must retain the above copyright notice, this
27
+ * list of conditions and the following disclaimer.
28
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
29
+ * this list of conditions and the following disclaimer in the documentation
30
+ * and/or other materials provided with the distribution.
31
+ *
32
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
33
+ * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
34
+ * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
35
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
36
+ * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
37
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
38
+ * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
39
+ * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
40
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
41
+ * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
42
+ *
43
+ * CONTRIBUTION AGREEMENT
44
+ *
45
+ * By contributing to the BVLC/caffe repository through pull-request, comment,
46
+ * or otherwise, the contributor releases their content to the
47
+ * license and copyright terms herein.
48
+ *
49
+ ***************** END Caffe Copyright Notice and Disclaimer ********************
50
+ *
51
+ * Copyright (c) 2018 Microsoft
52
+ * Licensed under The MIT License [see LICENSE for details]
53
+ * \file modulated_deformable_im2col.cuh
54
+ * \brief Function definitions of converting an image to
55
+ * column matrix based on kernel, padding, dilation, and offset.
56
+ * These functions are mainly used in deformable convolution operators.
57
+ * \ref: https://arxiv.org/abs/1703.06211
58
+ * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng
59
+ */
60
+
61
+ // modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu
62
+
63
+ #include <ATen/ATen.h>
64
+ #include <THC/THCAtomics.cuh>
65
+ #include <stdio.h>
66
+ #include <math.h>
67
+ #include <float.h>
68
+
69
+ using namespace at;
70
+
71
+ #define CUDA_KERNEL_LOOP(i, n) \
72
+ for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
73
+ i += blockDim.x * gridDim.x)
74
+
75
+ const int CUDA_NUM_THREADS = 1024;
76
+ const int kMaxGridNum = 65535;
77
+
78
+ inline int GET_BLOCKS(const int N)
79
+ {
80
+ return std::min(kMaxGridNum, (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS);
81
+ }
82
+
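
GET_BLOCKS ceil-divides the element count by the thread count and clamps the result to the grid limit; the grid-stride loop in CUDA_KERNEL_LOOP then covers any remainder. The same arithmetic in Python, for illustration:

```python
CUDA_NUM_THREADS = 1024
K_MAX_GRID_NUM = 65535

def get_blocks(n: int) -> int:
    # ceil(n / CUDA_NUM_THREADS), clamped to the maximum grid dimension
    return min(K_MAX_GRID_NUM, (n + CUDA_NUM_THREADS - 1) // CUDA_NUM_THREADS)

assert get_blocks(1) == 1
assert get_blocks(1024 * 70000) == 65535  # clamped; the grid-stride loop picks up the rest
```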
83
+ template <typename scalar_t>
84
+ __device__ scalar_t deformable_im2col_bilinear(const scalar_t *bottom_data, const int data_width,
85
+ const int height, const int width, scalar_t h, scalar_t w)
86
+ {
87
+
88
+ int h_low = floor(h);
89
+ int w_low = floor(w);
90
+ int h_high = h_low + 1;
91
+ int w_high = w_low + 1;
92
+
93
+ scalar_t lh = h - h_low;
94
+ scalar_t lw = w - w_low;
95
+ scalar_t hh = 1 - lh, hw = 1 - lw;
96
+
97
+ scalar_t v1 = 0;
98
+ if (h_low >= 0 && w_low >= 0)
99
+ v1 = bottom_data[h_low * data_width + w_low];
100
+ scalar_t v2 = 0;
101
+ if (h_low >= 0 && w_high <= width - 1)
102
+ v2 = bottom_data[h_low * data_width + w_high];
103
+ scalar_t v3 = 0;
104
+ if (h_high <= height - 1 && w_low >= 0)
105
+ v3 = bottom_data[h_high * data_width + w_low];
106
+ scalar_t v4 = 0;
107
+ if (h_high <= height - 1 && w_high <= width - 1)
108
+ v4 = bottom_data[h_high * data_width + w_high];
109
+
110
+ scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
111
+
112
+ scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
113
+ return val;
114
+ }
115
+
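
deformable_im2col_bilinear is plain bilinear interpolation with zero contribution from out-of-range corners. A Python mirror of the same logic, assuming `img` is a 2-D numpy array:

```python
import math
import numpy as np

def bilinear(img: np.ndarray, h: float, w: float) -> float:
    # Corners outside the image contribute zero, matching the CUDA guards.
    height, width = img.shape
    h_low, w_low = math.floor(h), math.floor(w)
    h_high, w_high = h_low + 1, w_low + 1
    lh, lw = h - h_low, w - w_low
    hh, hw = 1 - lh, 1 - lw
    v1 = img[h_low, w_low] if h_low >= 0 and w_low >= 0 else 0.0
    v2 = img[h_low, w_high] if h_low >= 0 and w_high <= width - 1 else 0.0
    v3 = img[h_high, w_low] if h_high <= height - 1 and w_low >= 0 else 0.0
    v4 = img[h_high, w_high] if h_high <= height - 1 and w_high <= width - 1 else 0.0
    return hh * hw * v1 + hh * lw * v2 + lh * hw * v3 + lh * lw * v4
```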
116
+ template <typename scalar_t>
117
+ __device__ scalar_t get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w,
118
+ const int h, const int w, const int height, const int width)
119
+ {
120
+
121
+ if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
122
+ {
123
+ //empty
124
+ return 0;
125
+ }
126
+
127
+ int argmax_h_low = floor(argmax_h);
128
+ int argmax_w_low = floor(argmax_w);
129
+ int argmax_h_high = argmax_h_low + 1;
130
+ int argmax_w_high = argmax_w_low + 1;
131
+
132
+ scalar_t weight = 0;
133
+ if (h == argmax_h_low && w == argmax_w_low)
134
+ weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
135
+ if (h == argmax_h_low && w == argmax_w_high)
136
+ weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
137
+ if (h == argmax_h_high && w == argmax_w_low)
138
+ weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
139
+ if (h == argmax_h_high && w == argmax_w_high)
140
+ weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
141
+ return weight;
142
+ }
143
+
144
+ template <typename scalar_t>
145
+ __device__ scalar_t get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w,
146
+ const int height, const int width, const scalar_t *im_data,
147
+ const int data_width, const int bp_dir)
148
+ {
149
+
150
+ if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
151
+ {
152
+ //empty
153
+ return 0;
154
+ }
155
+
156
+ int argmax_h_low = floor(argmax_h);
157
+ int argmax_w_low = floor(argmax_w);
158
+ int argmax_h_high = argmax_h_low + 1;
159
+ int argmax_w_high = argmax_w_low + 1;
160
+
161
+ scalar_t weight = 0;
162
+
163
+ if (bp_dir == 0)
164
+ {
165
+ if (argmax_h_low >= 0 && argmax_w_low >= 0)
166
+ weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low];
167
+ if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
168
+ weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high];
169
+ if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
170
+ weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low];
171
+ if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
172
+ weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high];
173
+ }
174
+ else if (bp_dir == 1)
175
+ {
176
+ if (argmax_h_low >= 0 && argmax_w_low >= 0)
177
+ weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low];
178
+ if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
179
+ weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high];
180
+ if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
181
+ weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low];
182
+ if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
183
+ weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high];
184
+ }
185
+
186
+ return weight;
187
+ }
188
+
189
+ template <typename scalar_t>
190
+ __global__ void deformable_im2col_gpu_kernel(const int n, const scalar_t *data_im, const scalar_t *data_offset,
191
+ const int height, const int width, const int kernel_h, const int kernel_w,
192
+ const int pad_h, const int pad_w, const int stride_h, const int stride_w,
193
+ const int dilation_h, const int dilation_w, const int channel_per_deformable_group,
194
+ const int batch_size, const int num_channels, const int deformable_group,
195
+ const int height_col, const int width_col,
196
+ scalar_t *data_col)
197
+ {
198
+ CUDA_KERNEL_LOOP(index, n)
199
+ {
200
+ // index: position in the output column matrix
201
+ const int w_col = index % width_col;
202
+ const int h_col = (index / width_col) % height_col;
203
+ const int b_col = (index / width_col / height_col) % batch_size;
204
+ const int c_im = (index / width_col / height_col) / batch_size;
205
+ const int c_col = c_im * kernel_h * kernel_w;
206
+
207
+ // compute deformable group index
208
+ const int deformable_group_index = c_im / channel_per_deformable_group;
209
+
210
+ const int h_in = h_col * stride_h - pad_h;
211
+ const int w_in = w_col * stride_w - pad_w;
212
+ scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
213
+ //const scalar_t* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in;
214
+ const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width;
215
+ const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
216
+
217
+ for (int i = 0; i < kernel_h; ++i)
218
+ {
219
+ for (int j = 0; j < kernel_w; ++j)
220
+ {
221
+ const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
222
+ const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col;
223
+ const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
224
+ const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
225
+ scalar_t val = static_cast<scalar_t>(0);
226
+ const scalar_t h_im = h_in + i * dilation_h + offset_h;
227
+ const scalar_t w_im = w_in + j * dilation_w + offset_w;
228
+ if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
229
+ {
230
+ //const scalar_t map_h = i * dilation_h + offset_h;
231
+ //const scalar_t map_w = j * dilation_w + offset_w;
232
+ //const int cur_height = height - h_in;
233
+ //const int cur_width = width - w_in;
234
+ //val = deformable_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w);
235
+ val = deformable_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im);
236
+ }
237
+ *data_col_ptr = val;
238
+ data_col_ptr += batch_size * height_col * width_col;
239
+ }
240
+ }
241
+ }
242
+ }
243
+
244
+ void deformable_im2col(
245
+ const at::Tensor data_im, const at::Tensor data_offset, const int channels,
246
+ const int height, const int width, const int ksize_h, const int ksize_w,
247
+ const int pad_h, const int pad_w, const int stride_h, const int stride_w,
248
+ const int dilation_h, const int dilation_w, const int parallel_imgs,
249
+ const int deformable_group, at::Tensor data_col)
250
+ {
251
+ // num_axes should be smaller than block size
252
+ // todo: check parallel_imgs is correctly passed in
253
+ int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
254
+ int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
255
+ int num_kernels = channels * height_col * width_col * parallel_imgs;
256
+ int channel_per_deformable_group = channels / deformable_group;
257
+
258
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
259
+ data_im.type(), "deformable_im2col_gpu", ([&] {
260
+ const scalar_t *data_im_ = data_im.data<scalar_t>();
261
+ const scalar_t *data_offset_ = data_offset.data<scalar_t>();
262
+ scalar_t *data_col_ = data_col.data<scalar_t>();
263
+
264
+ deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
265
+ num_kernels, data_im_, data_offset_, height, width, ksize_h, ksize_w,
266
+ pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
267
+ channel_per_deformable_group, parallel_imgs, channels, deformable_group,
268
+ height_col, width_col, data_col_);
269
+ }));
270
+
271
+ cudaError_t err = cudaGetLastError();
272
+ if (err != cudaSuccess)
273
+ {
274
+ printf("error in deformable_im2col: %s\n", cudaGetErrorString(err));
275
+ }
276
+ }
277
+
278
+ template <typename scalar_t>
279
+ __global__ void deformable_col2im_gpu_kernel(
280
+ const int n, const scalar_t *data_col, const scalar_t *data_offset,
281
+ const int channels, const int height, const int width,
282
+ const int kernel_h, const int kernel_w,
283
+ const int pad_h, const int pad_w,
284
+ const int stride_h, const int stride_w,
285
+ const int dilation_h, const int dilation_w,
286
+ const int channel_per_deformable_group,
287
+ const int batch_size, const int deformable_group,
288
+ const int height_col, const int width_col,
289
+ scalar_t *grad_im)
290
+ {
291
+ CUDA_KERNEL_LOOP(index, n)
292
+ {
293
+ const int j = (index / width_col / height_col / batch_size) % kernel_w;
294
+ const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h;
295
+ const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;
296
+ // compute the start and end of the output
297
+
298
+ const int deformable_group_index = c / channel_per_deformable_group;
299
+
300
+ int w_out = index % width_col;
301
+ int h_out = (index / width_col) % height_col;
302
+ int b = (index / width_col / height_col) % batch_size;
303
+ int w_in = w_out * stride_w - pad_w;
304
+ int h_in = h_out * stride_h - pad_h;
305
+
306
+ const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) *
307
+ 2 * kernel_h * kernel_w * height_col * width_col;
308
+ const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
309
+ const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
310
+ const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
311
+ const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
312
+ const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h;
313
+ const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w;
314
+
315
+ const scalar_t cur_top_grad = data_col[index];
316
+ const int cur_h = (int)cur_inv_h_data;
317
+ const int cur_w = (int)cur_inv_w_data;
318
+ for (int dy = -2; dy <= 2; dy++)
319
+ {
320
+ for (int dx = -2; dx <= 2; dx++)
321
+ {
322
+ if (cur_h + dy >= 0 && cur_h + dy < height &&
323
+ cur_w + dx >= 0 && cur_w + dx < width &&
324
+ abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
325
+ abs(cur_inv_w_data - (cur_w + dx)) < 1)
326
+ {
327
+ int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
328
+ scalar_t weight = get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width);
329
+ atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
330
+ }
331
+ }
332
+ }
333
+ }
334
+ }
335
+
336
+ void deformable_col2im(
337
+ const at::Tensor data_col, const at::Tensor data_offset, const int channels,
338
+ const int height, const int width, const int ksize_h,
339
+ const int ksize_w, const int pad_h, const int pad_w,
340
+ const int stride_h, const int stride_w,
341
+ const int dilation_h, const int dilation_w,
342
+ const int parallel_imgs, const int deformable_group,
343
+ at::Tensor grad_im)
344
+ {
345
+
346
+ // todo: make sure parallel_imgs is passed in correctly
347
+ int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
348
+ int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
349
+ int num_kernels = channels * ksize_h * ksize_w * height_col * width_col * parallel_imgs;
350
+ int channel_per_deformable_group = channels / deformable_group;
351
+
352
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
353
+ data_col.type(), "deformable_col2im_gpu", ([&] {
354
+ const scalar_t *data_col_ = data_col.data<scalar_t>();
355
+ const scalar_t *data_offset_ = data_offset.data<scalar_t>();
356
+ scalar_t *grad_im_ = grad_im.data<scalar_t>();
357
+
358
+ deformable_col2im_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
359
+ num_kernels, data_col_, data_offset_, channels, height, width, ksize_h,
360
+ ksize_w, pad_h, pad_w, stride_h, stride_w,
361
+ dilation_h, dilation_w, channel_per_deformable_group,
362
+ parallel_imgs, deformable_group, height_col, width_col, grad_im_);
363
+ }));
364
+
365
+ cudaError_t err = cudaGetLastError();
366
+ if (err != cudaSuccess)
367
+ {
368
+ printf("error in deformable_col2im: %s\n", cudaGetErrorString(err));
369
+ }
370
+ }
371
+
372
+ template <typename scalar_t>
373
+ __global__ void deformable_col2im_coord_gpu_kernel(const int n, const scalar_t *data_col,
374
+ const scalar_t *data_im, const scalar_t *data_offset,
375
+ const int channels, const int height, const int width,
376
+ const int kernel_h, const int kernel_w,
377
+ const int pad_h, const int pad_w,
378
+ const int stride_h, const int stride_w,
379
+ const int dilation_h, const int dilation_w,
380
+ const int channel_per_deformable_group,
381
+ const int batch_size, const int offset_channels, const int deformable_group,
382
+ const int height_col, const int width_col, scalar_t *grad_offset)
383
+ {
384
+ CUDA_KERNEL_LOOP(index, n)
385
+ {
386
+ scalar_t val = 0;
387
+ int w = index % width_col;
388
+ int h = (index / width_col) % height_col;
389
+ int c = (index / width_col / height_col) % offset_channels;
390
+ int b = (index / width_col / height_col) / offset_channels;
391
+ // compute the start and end of the output
392
+
393
+ const int deformable_group_index = c / (2 * kernel_h * kernel_w);
394
+ const int col_step = kernel_h * kernel_w;
395
+ int cnt = 0;
396
+ const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group *
397
+ batch_size * width_col * height_col;
398
+ const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) *
399
+ channel_per_deformable_group / kernel_h / kernel_w * height * width;
400
+ const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 *
401
+ kernel_h * kernel_w * height_col * width_col;
402
+
403
+ const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;
404
+
405
+ for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step)
406
+ {
407
+ const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w;
408
+ const int bp_dir = offset_c % 2;
409
+
410
+ int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
411
+ int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
412
+ int w_out = col_pos % width_col;
413
+ int h_out = (col_pos / width_col) % height_col;
414
+ int w_in = w_out * stride_w - pad_w;
415
+ int h_in = h_out * stride_h - pad_h;
416
+ const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
417
+ const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out);
418
+ const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
419
+ const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
420
+ scalar_t inv_h = h_in + i * dilation_h + offset_h;
421
+ scalar_t inv_w = w_in + j * dilation_w + offset_w;
422
+ if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width)
423
+ {
424
+ inv_h = inv_w = -2;
425
+ }
426
+ const scalar_t weight = get_coordinate_weight(
427
+ inv_h, inv_w,
428
+ height, width, data_im_ptr + cnt * height * width, width, bp_dir);
429
+ val += weight * data_col_ptr[col_pos];
430
+ cnt += 1;
431
+ }
432
+
433
+ grad_offset[index] = val;
434
+ }
435
+ }
436
+
437
+ void deformable_col2im_coord(
438
+ const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset,
439
+ const int channels, const int height, const int width, const int ksize_h,
440
+ const int ksize_w, const int pad_h, const int pad_w, const int stride_h,
441
+ const int stride_w, const int dilation_h, const int dilation_w,
442
+ const int parallel_imgs, const int deformable_group, at::Tensor grad_offset)
443
+ {
444
+
445
+ int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
446
+ int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
447
+ int num_kernels = height_col * width_col * 2 * ksize_h * ksize_w * deformable_group * parallel_imgs;
448
+ int channel_per_deformable_group = channels * ksize_h * ksize_w / deformable_group;
449
+
450
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
451
+ data_col.type(), "deformable_col2im_coord_gpu", ([&] {
452
+ const scalar_t *data_col_ = data_col.data<scalar_t>();
453
+ const scalar_t *data_im_ = data_im.data<scalar_t>();
454
+ const scalar_t *data_offset_ = data_offset.data<scalar_t>();
455
+ scalar_t *grad_offset_ = grad_offset.data<scalar_t>();
456
+
457
+ deformable_col2im_coord_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
458
+ num_kernels, data_col_, data_im_, data_offset_, channels, height, width,
459
+ ksize_h, ksize_w, pad_h, pad_w, stride_h, stride_w,
460
+ dilation_h, dilation_w, channel_per_deformable_group,
461
+ parallel_imgs, 2 * ksize_h * ksize_w * deformable_group, deformable_group,
462
+ height_col, width_col, grad_offset_);
463
+ }));
464
+ }
465
+
466
+ template <typename scalar_t>
467
+ __device__ scalar_t dmcn_im2col_bilinear(const scalar_t *bottom_data, const int data_width,
468
+ const int height, const int width, scalar_t h, scalar_t w)
469
+ {
470
+ int h_low = floor(h);
471
+ int w_low = floor(w);
472
+ int h_high = h_low + 1;
473
+ int w_high = w_low + 1;
474
+
475
+ scalar_t lh = h - h_low;
476
+ scalar_t lw = w - w_low;
477
+ scalar_t hh = 1 - lh, hw = 1 - lw;
478
+
479
+ scalar_t v1 = 0;
480
+ if (h_low >= 0 && w_low >= 0)
481
+ v1 = bottom_data[h_low * data_width + w_low];
482
+ scalar_t v2 = 0;
483
+ if (h_low >= 0 && w_high <= width - 1)
484
+ v2 = bottom_data[h_low * data_width + w_high];
485
+ scalar_t v3 = 0;
486
+ if (h_high <= height - 1 && w_low >= 0)
487
+ v3 = bottom_data[h_high * data_width + w_low];
488
+ scalar_t v4 = 0;
489
+ if (h_high <= height - 1 && w_high <= width - 1)
490
+ v4 = bottom_data[h_high * data_width + w_high];
491
+
492
+ scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
493
+
494
+ scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
495
+ return val;
496
+ }
497
+
498
+ template <typename scalar_t>
499
+ __device__ scalar_t dmcn_get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w,
500
+ const int h, const int w, const int height, const int width)
501
+ {
502
+ if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
503
+ {
504
+ //empty
505
+ return 0;
506
+ }
507
+
508
+ int argmax_h_low = floor(argmax_h);
509
+ int argmax_w_low = floor(argmax_w);
510
+ int argmax_h_high = argmax_h_low + 1;
511
+ int argmax_w_high = argmax_w_low + 1;
512
+
513
+ scalar_t weight = 0;
514
+ if (h == argmax_h_low && w == argmax_w_low)
515
+ weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
516
+ if (h == argmax_h_low && w == argmax_w_high)
517
+ weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
518
+ if (h == argmax_h_high && w == argmax_w_low)
519
+ weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
520
+ if (h == argmax_h_high && w == argmax_w_high)
521
+ weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
522
+ return weight;
523
+ }
524
+
525
+ template <typename scalar_t>
526
+ __device__ scalar_t dmcn_get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w,
527
+ const int height, const int width, const scalar_t *im_data,
528
+ const int data_width, const int bp_dir)
529
+ {
530
+ if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
531
+ {
532
+ //empty
533
+ return 0;
534
+ }
535
+
536
+ int argmax_h_low = floor(argmax_h);
537
+ int argmax_w_low = floor(argmax_w);
538
+ int argmax_h_high = argmax_h_low + 1;
539
+ int argmax_w_high = argmax_w_low + 1;
540
+
541
+ scalar_t weight = 0;
542
+
543
+ if (bp_dir == 0)
544
+ {
545
+ if (argmax_h_low >= 0 && argmax_w_low >= 0)
546
+ weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low];
547
+ if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
548
+ weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high];
549
+ if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
550
+ weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low];
551
+ if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
552
+ weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high];
553
+ }
554
+ else if (bp_dir == 1)
555
+ {
556
+ if (argmax_h_low >= 0 && argmax_w_low >= 0)
557
+ weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low];
558
+ if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
559
+ weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high];
560
+ if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
561
+ weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low];
562
+ if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
563
+ weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high];
564
+ }
565
+
566
+ return weight;
567
+ }
568
+
569
+ template <typename scalar_t>
570
+ __global__ void modulated_deformable_im2col_gpu_kernel(const int n,
571
+ const scalar_t *data_im, const scalar_t *data_offset, const scalar_t *data_mask,
572
+ const int height, const int width, const int kernel_h, const int kernel_w,
573
+ const int pad_h, const int pad_w,
574
+ const int stride_h, const int stride_w,
575
+ const int dilation_h, const int dilation_w,
576
+ const int channel_per_deformable_group,
577
+ const int batch_size, const int num_channels, const int deformable_group,
578
+ const int height_col, const int width_col,
579
+ scalar_t *data_col)
580
+ {
581
+ CUDA_KERNEL_LOOP(index, n)
582
+ {
583
+ // index: position in the output column matrix
584
+ const int w_col = index % width_col;
585
+ const int h_col = (index / width_col) % height_col;
586
+ const int b_col = (index / width_col / height_col) % batch_size;
587
+ const int c_im = (index / width_col / height_col) / batch_size;
588
+ const int c_col = c_im * kernel_h * kernel_w;
589
+
590
+ // compute deformable group index
591
+ const int deformable_group_index = c_im / channel_per_deformable_group;
592
+
593
+ const int h_in = h_col * stride_h - pad_h;
594
+ const int w_in = w_col * stride_w - pad_w;
595
+
596
+ scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
597
+ //const float* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in;
598
+ const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width;
599
+ const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
600
+
601
+ const scalar_t *data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
602
+
603
+ for (int i = 0; i < kernel_h; ++i)
604
+ {
605
+ for (int j = 0; j < kernel_w; ++j)
606
+ {
607
+ const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
608
+ const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col;
609
+ const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col;
610
+ const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
611
+ const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
612
+ const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
613
+ scalar_t val = static_cast<scalar_t>(0);
614
+ const scalar_t h_im = h_in + i * dilation_h + offset_h;
615
+ const scalar_t w_im = w_in + j * dilation_w + offset_w;
616
+ //if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) {
617
+ if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
618
+ {
619
+ //const float map_h = i * dilation_h + offset_h;
620
+ //const float map_w = j * dilation_w + offset_w;
621
+ //const int cur_height = height - h_in;
622
+ //const int cur_width = width - w_in;
623
+ //val = dmcn_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w);
624
+ val = dmcn_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im);
625
+ }
626
+ *data_col_ptr = val * mask;
627
+ data_col_ptr += batch_size * height_col * width_col;
628
+ //data_col_ptr += height_col * width_col;
629
+ }
630
+ }
631
+ }
632
+ }
633
+
634
+ template <typename scalar_t>
635
+ __global__ void modulated_deformable_col2im_gpu_kernel(const int n,
636
+ const scalar_t *data_col, const scalar_t *data_offset, const scalar_t *data_mask,
637
+ const int channels, const int height, const int width,
638
+ const int kernel_h, const int kernel_w,
639
+ const int pad_h, const int pad_w,
640
+ const int stride_h, const int stride_w,
641
+ const int dilation_h, const int dilation_w,
642
+ const int channel_per_deformable_group,
643
+ const int batch_size, const int deformable_group,
644
+ const int height_col, const int width_col,
645
+ scalar_t *grad_im)
646
+ {
647
+ CUDA_KERNEL_LOOP(index, n)
648
+ {
649
+ const int j = (index / width_col / height_col / batch_size) % kernel_w;
650
+ const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h;
651
+ const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;
652
+ // compute the start and end of the output
653
+
654
+ const int deformable_group_index = c / channel_per_deformable_group;
655
+
656
+ int w_out = index % width_col;
657
+ int h_out = (index / width_col) % height_col;
658
+ int b = (index / width_col / height_col) % batch_size;
659
+ int w_in = w_out * stride_w - pad_w;
660
+ int h_in = h_out * stride_h - pad_h;
661
+
662
+ const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
663
+     const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
+     const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
+     const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
+     const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out;
+     const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
+     const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
+     const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
+     const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h;
+     const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w;
+
+     const scalar_t cur_top_grad = data_col[index] * mask;
+     const int cur_h = (int)cur_inv_h_data;
+     const int cur_w = (int)cur_inv_w_data;
+     for (int dy = -2; dy <= 2; dy++)
+     {
+       for (int dx = -2; dx <= 2; dx++)
+       {
+         if (cur_h + dy >= 0 && cur_h + dy < height &&
+             cur_w + dx >= 0 && cur_w + dx < width &&
+             abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
+             abs(cur_inv_w_data - (cur_w + dx)) < 1)
+         {
+           int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
+           scalar_t weight = dmcn_get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width);
+           atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
+         }
+       }
+     }
+   }
+ }
+
+ template <typename scalar_t>
+ __global__ void modulated_deformable_col2im_coord_gpu_kernel(const int n,
+     const scalar_t *data_col, const scalar_t *data_im,
+     const scalar_t *data_offset, const scalar_t *data_mask,
+     const int channels, const int height, const int width,
+     const int kernel_h, const int kernel_w,
+     const int pad_h, const int pad_w,
+     const int stride_h, const int stride_w,
+     const int dilation_h, const int dilation_w,
+     const int channel_per_deformable_group,
+     const int batch_size, const int offset_channels, const int deformable_group,
+     const int height_col, const int width_col,
+     scalar_t *grad_offset, scalar_t *grad_mask)
+ {
+   CUDA_KERNEL_LOOP(index, n)
+   {
+     scalar_t val = 0, mval = 0;
+     int w = index % width_col;
+     int h = (index / width_col) % height_col;
+     int c = (index / width_col / height_col) % offset_channels;
+     int b = (index / width_col / height_col) / offset_channels;
+     // compute the start and end of the output
+
+     const int deformable_group_index = c / (2 * kernel_h * kernel_w);
+     const int col_step = kernel_h * kernel_w;
+     int cnt = 0;
+     const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col;
+     const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width;
+     const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
+     const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
+
+     const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;
+
+     for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step)
+     {
+       const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w;
+       const int bp_dir = offset_c % 2;
+
+       int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
+       int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
+       int w_out = col_pos % width_col;
+       int h_out = (col_pos / width_col) % height_col;
+       int w_in = w_out * stride_w - pad_w;
+       int h_in = h_out * stride_h - pad_h;
+       const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
+       const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out);
+       const int data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out);
+       const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
+       const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
+       const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
+       scalar_t inv_h = h_in + i * dilation_h + offset_h;
+       scalar_t inv_w = w_in + j * dilation_w + offset_w;
+       if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width)
+       {
+         inv_h = inv_w = -2;
+       }
+       else
+       {
+         mval += data_col_ptr[col_pos] * dmcn_im2col_bilinear(data_im_ptr + cnt * height * width, width, height, width, inv_h, inv_w);
+       }
+       const scalar_t weight = dmcn_get_coordinate_weight(
+           inv_h, inv_w,
+           height, width, data_im_ptr + cnt * height * width, width, bp_dir);
+       val += weight * data_col_ptr[col_pos] * mask;
+       cnt += 1;
+     }
+     // KERNEL_ASSIGN(grad_offset[index], offset_req, val);
+     grad_offset[index] = val;
+     if (offset_c % 2 == 0)
+       // KERNEL_ASSIGN(grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w], mask_req, mval);
+       grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w] = mval;
+   }
+ }
+
+ void modulated_deformable_im2col_cuda(
+     const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask,
+     const int batch_size, const int channels, const int height_im, const int width_im,
+     const int height_col, const int width_col, const int kernel_h, const int kernel_w,
+     const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+     const int dilation_h, const int dilation_w,
+     const int deformable_group, at::Tensor data_col)
+ {
+   // num_axes should be smaller than block size
+   const int channel_per_deformable_group = channels / deformable_group;
+   const int num_kernels = channels * batch_size * height_col * width_col;
+
+   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+       data_im.type(), "modulated_deformable_im2col_gpu", ([&] {
+         const scalar_t *data_im_ = data_im.data<scalar_t>();
+         const scalar_t *data_offset_ = data_offset.data<scalar_t>();
+         const scalar_t *data_mask_ = data_mask.data<scalar_t>();
+         scalar_t *data_col_ = data_col.data<scalar_t>();
+
+         modulated_deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
+             num_kernels, data_im_, data_offset_, data_mask_, height_im, width_im, kernel_h, kernel_w,
+             pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group,
+             batch_size, channels, deformable_group, height_col, width_col, data_col_);
+       }));
+
+   cudaError_t err = cudaGetLastError();
+   if (err != cudaSuccess)
+   {
+     printf("error in modulated_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
+   }
+ }
+
+ void modulated_deformable_col2im_cuda(
+     const at::Tensor data_col, const at::Tensor data_offset, const at::Tensor data_mask,
+     const int batch_size, const int channels, const int height_im, const int width_im,
+     const int height_col, const int width_col, const int kernel_h, const int kernel_w,
+     const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+     const int dilation_h, const int dilation_w,
+     const int deformable_group, at::Tensor grad_im)
+ {
+
+   const int channel_per_deformable_group = channels / deformable_group;
+   const int num_kernels = channels * kernel_h * kernel_w * batch_size * height_col * width_col;
+
+   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+       data_col.type(), "modulated_deformable_col2im_gpu", ([&] {
+         const scalar_t *data_col_ = data_col.data<scalar_t>();
+         const scalar_t *data_offset_ = data_offset.data<scalar_t>();
+         const scalar_t *data_mask_ = data_mask.data<scalar_t>();
+         scalar_t *grad_im_ = grad_im.data<scalar_t>();
+
+         modulated_deformable_col2im_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
+             num_kernels, data_col_, data_offset_, data_mask_, channels, height_im, width_im,
+             kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+             dilation_h, dilation_w, channel_per_deformable_group,
+             batch_size, deformable_group, height_col, width_col, grad_im_);
+       }));
+
+   cudaError_t err = cudaGetLastError();
+   if (err != cudaSuccess)
+   {
+     printf("error in modulated_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
+   }
+ }
+
+ void modulated_deformable_col2im_coord_cuda(
+     const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask,
+     const int batch_size, const int channels, const int height_im, const int width_im,
+     const int height_col, const int width_col, const int kernel_h, const int kernel_w,
+     const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+     const int dilation_h, const int dilation_w,
+     const int deformable_group,
+     at::Tensor grad_offset, at::Tensor grad_mask)
+ {
+   const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h * kernel_w * deformable_group;
+   const int channel_per_deformable_group = channels * kernel_h * kernel_w / deformable_group;
+
+   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+       data_col.type(), "modulated_deformable_col2im_coord_gpu", ([&] {
+         const scalar_t *data_col_ = data_col.data<scalar_t>();
+         const scalar_t *data_im_ = data_im.data<scalar_t>();
+         const scalar_t *data_offset_ = data_offset.data<scalar_t>();
+         const scalar_t *data_mask_ = data_mask.data<scalar_t>();
+         scalar_t *grad_offset_ = grad_offset.data<scalar_t>();
+         scalar_t *grad_mask_ = grad_mask.data<scalar_t>();
+
+         modulated_deformable_col2im_coord_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
+             num_kernels, data_col_, data_im_, data_offset_, data_mask_, channels, height_im, width_im,
+             kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+             dilation_h, dilation_w, channel_per_deformable_group,
+             batch_size, 2 * kernel_h * kernel_w * deformable_group, deformable_group, height_col, width_col,
+             grad_offset_, grad_mask_);
+       }));
+   cudaError_t err = cudaGetLastError();
+   if (err != cudaSuccess)
+   {
+     printf("error in modulated_deformable_col2im_coord_cuda: %s\n", cudaGetErrorString(err));
+   }
+ }
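For reference, the channel bookkeeping these modulated (DCNv2) kernels index into can be sanity-checked from Python; a minimal sketch (the concrete values are illustrative, only the channel arithmetic comes from the kernels above):

# Per deformable group, the offset map carries (dy, dx) for each of the
# kernel_h * kernel_w sampling points, and the mask one scalar per point.
kernel_h, kernel_w, deformable_group = 3, 3, 1   # illustrative values
offset_channels = deformable_group * 2 * kernel_h * kernel_w   # 18
mask_channels = deformable_group * kernel_h * kernel_w         # 9
# Matches the offset_mask[:, :18] / offset_mask[:, -9:] split used by the
# backbone (see resnet.py later in this diff): 18 + 9 = 27 channels total.
assert offset_channels + mask_channels == 27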
IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/src/deform_pool_cuda.cpp ADDED
@@ -0,0 +1,87 @@
+ // modify from
+ // https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/modulated_dcn_cuda.c
+
+ // based on
+ // author: Charles Shang
+ // https://github.com/torch/cunn/blob/master/lib/THCUNN/generic/SpatialConvolutionMM.cu
+
+ #include <torch/extension.h>
+
+ #include <cmath>
+ #include <vector>
+
+ void DeformablePSROIPoolForward(
+     const at::Tensor data, const at::Tensor bbox, const at::Tensor trans,
+     at::Tensor out, at::Tensor top_count, const int batch, const int channels,
+     const int height, const int width, const int num_bbox,
+     const int channels_trans, const int no_trans, const float spatial_scale,
+     const int output_dim, const int group_size, const int pooled_size,
+     const int part_size, const int sample_per_part, const float trans_std);
+
+ void DeformablePSROIPoolBackwardAcc(
+     const at::Tensor out_grad, const at::Tensor data, const at::Tensor bbox,
+     const at::Tensor trans, const at::Tensor top_count, at::Tensor in_grad,
+     at::Tensor trans_grad, const int batch, const int channels,
+     const int height, const int width, const int num_bbox,
+     const int channels_trans, const int no_trans, const float spatial_scale,
+     const int output_dim, const int group_size, const int pooled_size,
+     const int part_size, const int sample_per_part, const float trans_std);
+
+ void deform_psroi_pooling_cuda_forward(
+     at::Tensor input, at::Tensor bbox, at::Tensor trans, at::Tensor out,
+     at::Tensor top_count, const int no_trans, const float spatial_scale,
+     const int output_dim, const int group_size, const int pooled_size,
+     const int part_size, const int sample_per_part, const float trans_std) {
+   TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
+
+   const int batch = input.size(0);
+   const int channels = input.size(1);
+   const int height = input.size(2);
+   const int width = input.size(3);
+   const int channels_trans = no_trans ? 2 : trans.size(1);
+
+   const int num_bbox = bbox.size(0);
+   if (num_bbox != out.size(0))
+     AT_ERROR("Output shape and bbox number won't match: (%d vs %d).",
+              out.size(0), num_bbox);
+
+   DeformablePSROIPoolForward(
+       input, bbox, trans, out, top_count, batch, channels, height, width,
+       num_bbox, channels_trans, no_trans, spatial_scale, output_dim, group_size,
+       pooled_size, part_size, sample_per_part, trans_std);
+ }
+
+ void deform_psroi_pooling_cuda_backward(
+     at::Tensor out_grad, at::Tensor input, at::Tensor bbox, at::Tensor trans,
+     at::Tensor top_count, at::Tensor input_grad, at::Tensor trans_grad,
+     const int no_trans, const float spatial_scale, const int output_dim,
+     const int group_size, const int pooled_size, const int part_size,
+     const int sample_per_part, const float trans_std) {
+   TORCH_CHECK(out_grad.is_contiguous(), "out_grad tensor has to be contiguous");
+   TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
+
+   const int batch = input.size(0);
+   const int channels = input.size(1);
+   const int height = input.size(2);
+   const int width = input.size(3);
+   const int channels_trans = no_trans ? 2 : trans.size(1);
+
+   const int num_bbox = bbox.size(0);
+   if (num_bbox != out_grad.size(0))
+     AT_ERROR("Output shape and bbox number won't match: (%d vs %d).",
+              out_grad.size(0), num_bbox);
+
+   DeformablePSROIPoolBackwardAcc(
+       out_grad, input, bbox, trans, top_count, input_grad, trans_grad, batch,
+       channels, height, width, num_bbox, channels_trans, no_trans,
+       spatial_scale, output_dim, group_size, pooled_size, part_size,
+       sample_per_part, trans_std);
+ }
+
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+   m.def("deform_psroi_pooling_cuda_forward", &deform_psroi_pooling_cuda_forward,
+         "deform psroi pooling forward(CUDA)");
+   m.def("deform_psroi_pooling_cuda_backward",
+         &deform_psroi_pooling_cuda_backward,
+         "deform psroi pooling backward(CUDA)");
+ }
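A hedged sketch of driving the pooling op exposed by the PYBIND11_MODULE above. The JIT build via torch.utils.cpp_extension.load is illustrative only (the dcn directory ships its own build scripts), and every tensor value, path, and hyperparameter below is an assumption, not the project's configuration:

import torch
from torch.utils.cpp_extension import load

# Compile the two source files added above into a loadable module (sketch).
dpp = load(name="deform_pool_cuda",
           sources=["src/deform_pool_cuda.cpp",
                    "src/deform_pool_cuda_kernel.cu"])

output_dim, group_size, pooled_size = 8, 7, 7
# Input channels must equal output_dim * group_size^2 (position-sensitive maps).
feats = torch.randn(1, output_dim * group_size * group_size, 32, 32, device="cuda")
rois = torch.tensor([[0., 4., 4., 20., 20.]], device="cuda")  # (batch_idx, x1, y1, x2, y2)
trans = torch.zeros(1, device="cuda")                         # ignored when no_trans=1
out = feats.new_zeros(1, output_dim, pooled_size, pooled_size)
top_count = feats.new_zeros(1, output_dim, pooled_size, pooled_size)

dpp.deform_psroi_pooling_cuda_forward(
    feats, rois, trans, out, top_count,
    1,            # no_trans
    0.25,         # spatial_scale
    output_dim, group_size, pooled_size,
    pooled_size,  # part_size
    4,            # sample_per_part
    0.1)          # trans_std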
IndicPhotoOCR/detection/textbpn/network/backbone/assets/dcn/src/deform_pool_cuda_kernel.cu ADDED
@@ -0,0 +1,364 @@
+ /*!
+  * Copyright (c) 2017 Microsoft
+  * Licensed under The MIT License [see LICENSE for details]
+  * \file deformable_psroi_pooling.cu
+  * \brief
+  * \author Yi Li, Guodong Zhang, Jifeng Dai
+  */
+ /***************** Adapted by Charles Shang *********************/
+ // modify from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/cuda/deform_psroi_pooling_cuda.cu
+
+ #include <ATen/ATen.h>
+ #include <THC/THCAtomics.cuh>
+ #include <stdio.h>
+ #include <math.h>
+ #include <algorithm>
+
+ using namespace at;
+
+ #define CUDA_KERNEL_LOOP(i, n)                        \
+   for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
+        i < (n);                                       \
+        i += blockDim.x * gridDim.x)
+
+ const int CUDA_NUM_THREADS = 1024;
+ inline int GET_BLOCKS(const int N)
+ {
+   return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
+ }
+
+ template <typename scalar_t>
+ __device__ scalar_t bilinear_interp(
+     const scalar_t *data,
+     const scalar_t x,
+     const scalar_t y,
+     const int width,
+     const int height)
+ {
+   int x1 = floor(x);
+   int x2 = ceil(x);
+   int y1 = floor(y);
+   int y2 = ceil(y);
+   scalar_t dist_x = (scalar_t)(x - x1);
+   scalar_t dist_y = (scalar_t)(y - y1);
+   scalar_t value11 = data[y1 * width + x1];
+   scalar_t value12 = data[y2 * width + x1];
+   scalar_t value21 = data[y1 * width + x2];
+   scalar_t value22 = data[y2 * width + x2];
+   scalar_t value = (1 - dist_x) * (1 - dist_y) * value11 + (1 - dist_x) * dist_y * value12 + dist_x * (1 - dist_y) * value21 + dist_x * dist_y * value22;
+   return value;
+ }
+
+ template <typename scalar_t>
+ __global__ void DeformablePSROIPoolForwardKernel(
+     const int count,
+     const scalar_t *bottom_data,
+     const scalar_t spatial_scale,
+     const int channels,
+     const int height, const int width,
+     const int pooled_height, const int pooled_width,
+     const scalar_t *bottom_rois, const scalar_t *bottom_trans,
+     const int no_trans,
+     const scalar_t trans_std,
+     const int sample_per_part,
+     const int output_dim,
+     const int group_size,
+     const int part_size,
+     const int num_classes,
+     const int channels_each_class,
+     scalar_t *top_data,
+     scalar_t *top_count)
+ {
+   CUDA_KERNEL_LOOP(index, count)
+   {
+     // The output is in order (n, ctop, ph, pw)
+     int pw = index % pooled_width;
+     int ph = (index / pooled_width) % pooled_height;
+     int ctop = (index / pooled_width / pooled_height) % output_dim;
+     int n = index / pooled_width / pooled_height / output_dim;
+
+     // [start, end) interval for spatial sampling
+     const scalar_t *offset_bottom_rois = bottom_rois + n * 5;
+     int roi_batch_ind = offset_bottom_rois[0];
+     scalar_t roi_start_w = (scalar_t)(round(offset_bottom_rois[1])) * spatial_scale - 0.5;
+     scalar_t roi_start_h = (scalar_t)(round(offset_bottom_rois[2])) * spatial_scale - 0.5;
+     scalar_t roi_end_w = (scalar_t)(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5;
+     scalar_t roi_end_h = (scalar_t)(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5;
+
+     // Force too small ROIs to be 1x1
+     scalar_t roi_width = max(roi_end_w - roi_start_w, 0.1); // avoid 0
+     scalar_t roi_height = max(roi_end_h - roi_start_h, 0.1);
+
+     // Compute w and h at bottom
+     scalar_t bin_size_h = roi_height / (scalar_t)(pooled_height);
+     scalar_t bin_size_w = roi_width / (scalar_t)(pooled_width);
+
+     scalar_t sub_bin_size_h = bin_size_h / (scalar_t)(sample_per_part);
+     scalar_t sub_bin_size_w = bin_size_w / (scalar_t)(sample_per_part);
+
+     int part_h = floor((scalar_t)(ph) / pooled_height * part_size);
+     int part_w = floor((scalar_t)(pw) / pooled_width * part_size);
+     int class_id = ctop / channels_each_class;
+     scalar_t trans_x = no_trans ? (scalar_t)(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * (scalar_t)trans_std;
+     scalar_t trans_y = no_trans ? (scalar_t)(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * (scalar_t)trans_std;
+
+     scalar_t wstart = (scalar_t)(pw)*bin_size_w + roi_start_w;
+     wstart += trans_x * roi_width;
+     scalar_t hstart = (scalar_t)(ph)*bin_size_h + roi_start_h;
+     hstart += trans_y * roi_height;
+
+     scalar_t sum = 0;
+     int count = 0;
+     int gw = floor((scalar_t)(pw)*group_size / pooled_width);
+     int gh = floor((scalar_t)(ph)*group_size / pooled_height);
+     gw = min(max(gw, 0), group_size - 1);
+     gh = min(max(gh, 0), group_size - 1);
+
+     const scalar_t *offset_bottom_data = bottom_data + (roi_batch_ind * channels) * height * width;
+     for (int ih = 0; ih < sample_per_part; ih++)
+     {
+       for (int iw = 0; iw < sample_per_part; iw++)
+       {
+         scalar_t w = wstart + iw * sub_bin_size_w;
+         scalar_t h = hstart + ih * sub_bin_size_h;
+         // bilinear interpolation
+         if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5)
+         {
+           continue;
+         }
+         w = min(max(w, 0.), width - 1.);
+         h = min(max(h, 0.), height - 1.);
+         int c = (ctop * group_size + gh) * group_size + gw;
+         scalar_t val = bilinear_interp(offset_bottom_data + c * height * width, w, h, width, height);
+         sum += val;
+         count++;
+       }
+     }
+     top_data[index] = count == 0 ? (scalar_t)(0) : sum / count;
+     top_count[index] = count;
+   }
+ }
+
+ template <typename scalar_t>
+ __global__ void DeformablePSROIPoolBackwardAccKernel(
+     const int count,
+     const scalar_t *top_diff,
+     const scalar_t *top_count,
+     const int num_rois,
+     const scalar_t spatial_scale,
+     const int channels,
+     const int height, const int width,
+     const int pooled_height, const int pooled_width,
+     const int output_dim,
+     scalar_t *bottom_data_diff, scalar_t *bottom_trans_diff,
+     const scalar_t *bottom_data,
+     const scalar_t *bottom_rois,
+     const scalar_t *bottom_trans,
+     const int no_trans,
+     const scalar_t trans_std,
+     const int sample_per_part,
+     const int group_size,
+     const int part_size,
+     const int num_classes,
+     const int channels_each_class)
+ {
+   CUDA_KERNEL_LOOP(index, count)
+   {
+     // The output is in order (n, ctop, ph, pw)
+     int pw = index % pooled_width;
+     int ph = (index / pooled_width) % pooled_height;
+     int ctop = (index / pooled_width / pooled_height) % output_dim;
+     int n = index / pooled_width / pooled_height / output_dim;
+
+     // [start, end) interval for spatial sampling
+     const scalar_t *offset_bottom_rois = bottom_rois + n * 5;
+     int roi_batch_ind = offset_bottom_rois[0];
+     scalar_t roi_start_w = (scalar_t)(round(offset_bottom_rois[1])) * spatial_scale - 0.5;
+     scalar_t roi_start_h = (scalar_t)(round(offset_bottom_rois[2])) * spatial_scale - 0.5;
+     scalar_t roi_end_w = (scalar_t)(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5;
+     scalar_t roi_end_h = (scalar_t)(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5;
+
+     // Force too small ROIs to be 1x1
+     scalar_t roi_width = max(roi_end_w - roi_start_w, 0.1); // avoid 0
+     scalar_t roi_height = max(roi_end_h - roi_start_h, 0.1);
+
+     // Compute w and h at bottom
+     scalar_t bin_size_h = roi_height / (scalar_t)(pooled_height);
+     scalar_t bin_size_w = roi_width / (scalar_t)(pooled_width);
+
+     scalar_t sub_bin_size_h = bin_size_h / (scalar_t)(sample_per_part);
+     scalar_t sub_bin_size_w = bin_size_w / (scalar_t)(sample_per_part);
+
+     int part_h = floor((scalar_t)(ph) / pooled_height * part_size);
+     int part_w = floor((scalar_t)(pw) / pooled_width * part_size);
+     int class_id = ctop / channels_each_class;
+     scalar_t trans_x = no_trans ? (scalar_t)(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * (scalar_t)trans_std;
+     scalar_t trans_y = no_trans ? (scalar_t)(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * (scalar_t)trans_std;
+
+     scalar_t wstart = (scalar_t)(pw)*bin_size_w + roi_start_w;
+     wstart += trans_x * roi_width;
+     scalar_t hstart = (scalar_t)(ph)*bin_size_h + roi_start_h;
+     hstart += trans_y * roi_height;
+
+     if (top_count[index] <= 0)
+     {
+       continue;
+     }
+     scalar_t diff_val = top_diff[index] / top_count[index];
+     const scalar_t *offset_bottom_data = bottom_data + roi_batch_ind * channels * height * width;
+     scalar_t *offset_bottom_data_diff = bottom_data_diff + roi_batch_ind * channels * height * width;
+     int gw = floor((scalar_t)(pw)*group_size / pooled_width);
+     int gh = floor((scalar_t)(ph)*group_size / pooled_height);
+     gw = min(max(gw, 0), group_size - 1);
+     gh = min(max(gh, 0), group_size - 1);
+
+     for (int ih = 0; ih < sample_per_part; ih++)
+     {
+       for (int iw = 0; iw < sample_per_part; iw++)
+       {
+         scalar_t w = wstart + iw * sub_bin_size_w;
+         scalar_t h = hstart + ih * sub_bin_size_h;
+         // bilinear interpolation
+         if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5)
+         {
+           continue;
+         }
+         w = min(max(w, 0.), width - 1.);
+         h = min(max(h, 0.), height - 1.);
+         int c = (ctop * group_size + gh) * group_size + gw;
+         // backward on feature
+         int x0 = floor(w);
+         int x1 = ceil(w);
+         int y0 = floor(h);
+         int y1 = ceil(h);
+         scalar_t dist_x = w - x0, dist_y = h - y0;
+         scalar_t q00 = (1 - dist_x) * (1 - dist_y);
+         scalar_t q01 = (1 - dist_x) * dist_y;
+         scalar_t q10 = dist_x * (1 - dist_y);
+         scalar_t q11 = dist_x * dist_y;
+         int bottom_index_base = c * height * width;
+         atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x0, q00 * diff_val);
+         atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x0, q01 * diff_val);
+         atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x1, q10 * diff_val);
+         atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x1, q11 * diff_val);
+
+         if (no_trans)
+         {
+           continue;
+         }
+         scalar_t U00 = offset_bottom_data[bottom_index_base + y0 * width + x0];
+         scalar_t U01 = offset_bottom_data[bottom_index_base + y1 * width + x0];
+         scalar_t U10 = offset_bottom_data[bottom_index_base + y0 * width + x1];
+         scalar_t U11 = offset_bottom_data[bottom_index_base + y1 * width + x1];
+         scalar_t diff_x = (U11 * dist_y + U10 * (1 - dist_y) - U01 * dist_y - U00 * (1 - dist_y)) * trans_std * diff_val;
+         diff_x *= roi_width;
+         scalar_t diff_y = (U11 * dist_x + U01 * (1 - dist_x) - U10 * dist_x - U00 * (1 - dist_x)) * trans_std * diff_val;
+         diff_y *= roi_height;
+
+         atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w, diff_x);
+         atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w, diff_y);
+       }
+     }
+   }
+ }
+
+ void DeformablePSROIPoolForward(const at::Tensor data,
+                                 const at::Tensor bbox,
+                                 const at::Tensor trans,
+                                 at::Tensor out,
+                                 at::Tensor top_count,
+                                 const int batch,
+                                 const int channels,
+                                 const int height,
+                                 const int width,
+                                 const int num_bbox,
+                                 const int channels_trans,
+                                 const int no_trans,
+                                 const float spatial_scale,
+                                 const int output_dim,
+                                 const int group_size,
+                                 const int pooled_size,
+                                 const int part_size,
+                                 const int sample_per_part,
+                                 const float trans_std)
+ {
+   const int pooled_height = pooled_size;
+   const int pooled_width = pooled_size;
+   const int count = num_bbox * output_dim * pooled_height * pooled_width;
+   const int num_classes = no_trans ? 1 : channels_trans / 2;
+   const int channels_each_class = no_trans ? output_dim : output_dim / num_classes;
+
+   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+       data.type(), "deformable_psroi_pool_forward", ([&] {
+         const scalar_t *bottom_data = data.data<scalar_t>();
+         const scalar_t *bottom_rois = bbox.data<scalar_t>();
+         const scalar_t *bottom_trans = no_trans ? NULL : trans.data<scalar_t>();
+         scalar_t *top_data = out.data<scalar_t>();
+         scalar_t *top_count_data = top_count.data<scalar_t>();
+
+         DeformablePSROIPoolForwardKernel<<<GET_BLOCKS(count), CUDA_NUM_THREADS>>>(
+             count, bottom_data, (scalar_t)spatial_scale, channels, height, width, pooled_height, pooled_width,
+             bottom_rois, bottom_trans, no_trans, (scalar_t)trans_std, sample_per_part, output_dim,
+             group_size, part_size, num_classes, channels_each_class, top_data, top_count_data);
+       }));
+
+   cudaError_t err = cudaGetLastError();
+   if (err != cudaSuccess)
+   {
+     printf("error in DeformablePSROIPoolForward: %s\n", cudaGetErrorString(err));
+   }
+ }
+
+ void DeformablePSROIPoolBackwardAcc(const at::Tensor out_grad,
+                                     const at::Tensor data,
+                                     const at::Tensor bbox,
+                                     const at::Tensor trans,
+                                     const at::Tensor top_count,
+                                     at::Tensor in_grad,
+                                     at::Tensor trans_grad,
+                                     const int batch,
+                                     const int channels,
+                                     const int height,
+                                     const int width,
+                                     const int num_bbox,
+                                     const int channels_trans,
+                                     const int no_trans,
+                                     const float spatial_scale,
+                                     const int output_dim,
+                                     const int group_size,
+                                     const int pooled_size,
+                                     const int part_size,
+                                     const int sample_per_part,
+                                     const float trans_std)
+ {
+   // LOG(INFO) << "DeformablePSROIPoolBackward";
+   const int num_rois = num_bbox;
+   const int pooled_height = pooled_size;
+   const int pooled_width = pooled_size;
+   const int count = num_bbox * output_dim * pooled_height * pooled_width;
+   const int num_classes = no_trans ? 1 : channels_trans / 2;
+   const int channels_each_class = no_trans ? output_dim : output_dim / num_classes;
+
+   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+       out_grad.type(), "deformable_psroi_pool_backward_acc", ([&] {
+         const scalar_t *top_diff = out_grad.data<scalar_t>();
+         const scalar_t *bottom_data = data.data<scalar_t>();
+         const scalar_t *bottom_rois = bbox.data<scalar_t>();
+         const scalar_t *bottom_trans = no_trans ? NULL : trans.data<scalar_t>();
+         scalar_t *bottom_data_diff = in_grad.data<scalar_t>();
+         scalar_t *bottom_trans_diff = no_trans ? NULL : trans_grad.data<scalar_t>();
+         const scalar_t *top_count_data = top_count.data<scalar_t>();
+
+         DeformablePSROIPoolBackwardAccKernel<<<GET_BLOCKS(count), CUDA_NUM_THREADS>>>(
+             count, top_diff, top_count_data, num_rois, (scalar_t)spatial_scale, channels, height, width,
+             pooled_height, pooled_width, output_dim, bottom_data_diff, bottom_trans_diff,
+             bottom_data, bottom_rois, bottom_trans, no_trans, (scalar_t)trans_std, sample_per_part,
+             group_size, part_size, num_classes, channels_each_class);
+       }));
+
+   cudaError_t err = cudaGetLastError();
+   if (err != cudaSuccess)
+   {
+     printf("error in DeformablePSROIPoolBackwardAcc: %s\n", cudaGetErrorString(err));
+   }
+ }
IndicPhotoOCR/detection/textbpn/network/backbone/resnet.py ADDED
@@ -0,0 +1,336 @@
+ import torch.nn as nn
+ import math
+ import torch.utils.model_zoo as model_zoo
+ BatchNorm2d = nn.BatchNorm2d
+
+ __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
+            'resnet152']
+
+
+ model_urls = {
+     'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
+     'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
+     'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
+     'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
+     'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
+ }
+
+
+ def constant_init(module, constant, bias=0):
+     nn.init.constant_(module.weight, constant)
+     if hasattr(module, 'bias'):
+         nn.init.constant_(module.bias, bias)
+
+
+ def conv3x3(in_planes, out_planes, stride=1):
+     """3x3 convolution with padding"""
+     return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+                      padding=1, bias=False)
+
+
+ class BasicBlock(nn.Module):
+     expansion = 1
+
+     def __init__(self, inplanes, planes, stride=1, downsample=None, dcn=None):
+         super(BasicBlock, self).__init__()
+         self.with_dcn = dcn is not None
+         self.conv1 = conv3x3(inplanes, planes, stride)
+         self.bn1 = BatchNorm2d(planes)
+         self.relu = nn.ReLU(inplace=True)
+         self.with_modulated_dcn = False
+         if self.with_dcn:
+             fallback_on_stride = dcn.get('fallback_on_stride', False)
+             self.with_modulated_dcn = dcn.get('modulated', False)
+         # self.conv2 = conv3x3(planes, planes)
+         if not self.with_dcn or fallback_on_stride:
+             self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
+                                    padding=1, bias=False)
+         else:
+             deformable_groups = dcn.get('deformable_groups', 1)
+             if not self.with_modulated_dcn:
+                 from IndicPhotoOCR.detection.textbpn.network.backbone.assets.dcn import DeformConv
+                 conv_op = DeformConv
+                 offset_channels = 18
+             else:
+                 from IndicPhotoOCR.detection.textbpn.network.backbone.assets.dcn import ModulatedDeformConv
+                 conv_op = ModulatedDeformConv
+                 offset_channels = 27
+             self.conv2_offset = nn.Conv2d(
+                 planes,
+                 deformable_groups * offset_channels,
+                 kernel_size=3,
+                 padding=1)
+             self.conv2 = conv_op(
+                 planes,
+                 planes,
+                 kernel_size=3,
+                 padding=1,
+                 deformable_groups=deformable_groups,
+                 bias=False)
+         self.bn2 = BatchNorm2d(planes)
+         self.downsample = downsample
+         self.stride = stride
+
+     def forward(self, x):
+         residual = x
+
+         out = self.conv1(x)
+         out = self.bn1(out)
+         out = self.relu(out)
+
+         # out = self.conv2(out)
+         if not self.with_dcn:
+             out = self.conv2(out)
+         elif self.with_modulated_dcn:
+             offset_mask = self.conv2_offset(out)
+             offset = offset_mask[:, :18, :, :]
+             mask = offset_mask[:, -9:, :, :].sigmoid()
+             out = self.conv2(out, offset, mask)
+         else:
+             offset = self.conv2_offset(out)
+             out = self.conv2(out, offset)
+         out = self.bn2(out)
+
+         if self.downsample is not None:
+             residual = self.downsample(x)
+
+         out += residual
+         out = self.relu(out)
+
+         return out
+
+
+ class Bottleneck(nn.Module):
+     expansion = 4
+
+     def __init__(self, inplanes, planes, stride=1, downsample=None, dcn=None):
+         super(Bottleneck, self).__init__()
+         self.with_dcn = dcn is not None
+         self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+         self.bn1 = BatchNorm2d(planes)
+         fallback_on_stride = False
+         self.with_modulated_dcn = False
+         if self.with_dcn:
+             fallback_on_stride = dcn.get('fallback_on_stride', False)
+             self.with_modulated_dcn = dcn.get('modulated', False)
+         if not self.with_dcn or fallback_on_stride:
+             self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
+                                    stride=stride, padding=1, bias=False)
+         else:
+             deformable_groups = dcn.get('deformable_groups', 1)
+             if not self.with_modulated_dcn:
+                 from IndicPhotoOCR.detection.textbpn.network.backbone.assets.dcn import DeformConv
+                 conv_op = DeformConv
+                 offset_channels = 18
+             else:
+                 from IndicPhotoOCR.detection.textbpn.network.backbone.assets.dcn import ModulatedDeformConv
+                 conv_op = ModulatedDeformConv
+                 offset_channels = 27
+             self.conv2_offset = nn.Conv2d(
+                 planes, deformable_groups * offset_channels,
+                 kernel_size=3,
+                 padding=1)
+             self.conv2 = conv_op(
+                 planes, planes, kernel_size=3, padding=1, stride=stride,
+                 deformable_groups=deformable_groups, bias=False)
+         self.bn2 = BatchNorm2d(planes)
+         self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
+         self.bn3 = BatchNorm2d(planes * 4)
+         self.relu = nn.ReLU(inplace=True)
+         self.downsample = downsample
+         self.stride = stride
+         self.dcn = dcn
+         self.with_dcn = dcn is not None
+
+     def forward(self, x):
+         residual = x
+
+         out = self.conv1(x)
+         out = self.bn1(out)
+         out = self.relu(out)
+
+         # out = self.conv2(out)
+         if not self.with_dcn:
+             out = self.conv2(out)
+         elif self.with_modulated_dcn:
+             offset_mask = self.conv2_offset(out)
+             offset = offset_mask[:, :18, :, :]
+             mask = offset_mask[:, -9:, :, :].sigmoid()
+             out = self.conv2(out, offset, mask)
+         else:
+             offset = self.conv2_offset(out)
+             out = self.conv2(out, offset)
+         out = self.bn2(out)
+         out = self.relu(out)
+
+         out = self.conv3(out)
+         out = self.bn3(out)
+
+         if self.downsample is not None:
+             residual = self.downsample(x)
+
+         out += residual
+         out = self.relu(out)
+
+         return out
+
+
+ class ResNet(nn.Module):
+     def __init__(self, block, layers, num_classes=1000,
+                  dcn=None, stage_with_dcn=(False, False, False, False)):
+         self.dcn = dcn
+         self.stage_with_dcn = stage_with_dcn
+         self.inplanes = 64
+         super(ResNet, self).__init__()
+         self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
+                                bias=False)
+         self.bn1 = BatchNorm2d(64)
+         self.relu = nn.ReLU(inplace=True)
+         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+         self.layer1 = self._make_layer(block, 64, layers[0])
+         self.layer2 = self._make_layer(
+             block, 128, layers[1], stride=2, dcn=dcn)
+         self.layer3 = self._make_layer(
+             block, 256, layers[2], stride=2, dcn=dcn)
+         self.layer4 = self._make_layer(
+             block, 512, layers[3], stride=2, dcn=dcn)
+         self.avgpool = nn.AvgPool2d(7, stride=1)
+         self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+         self.smooth = nn.Conv2d(2048, 256, kernel_size=1, stride=1, padding=1)
+
+         for m in self.modules():
+             if isinstance(m, nn.Conv2d):
+                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                 m.weight.data.normal_(0, math.sqrt(2. / n))
+             elif isinstance(m, BatchNorm2d):
+                 m.weight.data.fill_(1)
+                 m.bias.data.zero_()
+         if self.dcn is not None:
+             for m in self.modules():
+                 if isinstance(m, Bottleneck) or isinstance(m, BasicBlock):
+                     if hasattr(m, 'conv2_offset'):
+                         constant_init(m.conv2_offset, 0)
+
+     def _make_layer(self, block, planes, blocks, stride=1, dcn=None):
+         downsample = None
+         if stride != 1 or self.inplanes != planes * block.expansion:
+             downsample = nn.Sequential(
+                 nn.Conv2d(self.inplanes, planes * block.expansion,
+                           kernel_size=1, stride=stride, bias=False),
+                 BatchNorm2d(planes * block.expansion),
+             )
+
+         layers = []
+         layers.append(block(self.inplanes, planes,
+                             stride, downsample, dcn=dcn))
+         self.inplanes = planes * block.expansion
+         for i in range(1, blocks):
+             layers.append(block(self.inplanes, planes, dcn=dcn))
+
+         return nn.Sequential(*layers)
+
+     def forward(self, x):
+         x = self.conv1(x)
+         x = self.bn1(x)
+         x = self.relu(x)
+         x1 = self.maxpool(x)
+
+         x2 = self.layer1(x1)
+         x3 = self.layer2(x2)
+         x4 = self.layer3(x3)
+         x5 = self.layer4(x4)
+
+         return x1, x2, x3, x4, x5
+
+
+ def resnet18(pretrained=True, **kwargs):
+     """Constructs a ResNet-18 model.
+     Args:
+         pretrained (bool): If True, returns a model pre-trained on ImageNet
+     """
+     model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
+     if pretrained:
+         model.load_state_dict(model_zoo.load_url(
+             model_urls['resnet18']), strict=False)
+     return model
+
+
+ def deformable_resnet18(pretrained=True, **kwargs):
+     """Constructs a ResNet-18 model with deformable conv.
+     Args:
+         pretrained (bool): If True, returns a model pre-trained on ImageNet
+     """
+     model = ResNet(BasicBlock, [2, 2, 2, 2],
+                    dcn=dict(modulated=True,
+                             deformable_groups=1,
+                             fallback_on_stride=False),
+                    stage_with_dcn=[False, True, True, True], **kwargs)
+     if pretrained:
+         model.load_state_dict(model_zoo.load_url(
+             model_urls['resnet18']), strict=False)
+     return model
+
+
+ def resnet34(pretrained=True, **kwargs):
+     """Constructs a ResNet-34 model.
+     Args:
+         pretrained (bool): If True, returns a model pre-trained on ImageNet
+     """
+     model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
+     if pretrained:
+         model.load_state_dict(model_zoo.load_url(
+             model_urls['resnet34']), strict=False)
+     return model
+
+
+ def resnet50(pretrained=True, **kwargs):
+     """Constructs a ResNet-50 model.
+     Args:
+         pretrained (bool): If True, returns a model pre-trained on ImageNet
+     """
+     model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
+     if pretrained:
+         model.load_state_dict(model_zoo.load_url(
+             model_urls['resnet50']), strict=False)
+     return model
+
+
+ def deformable_resnet50(pretrained=True, **kwargs):
+     """Constructs a ResNet-50 model with deformable conv.
+     Args:
+         pretrained (bool): If True, returns a model pre-trained on ImageNet
+     """
+     model = ResNet(Bottleneck, [3, 4, 6, 3],
+                    dcn=dict(modulated=True,
+                             deformable_groups=1,
+                             fallback_on_stride=False),
+                    stage_with_dcn=[False, True, True, True],
+                    **kwargs)
+     if pretrained:
+         model.load_state_dict(model_zoo.load_url(
+             model_urls['resnet50']), strict=False)
+     return model
+
+
+ def resnet101(pretrained=True, **kwargs):
+     """Constructs a ResNet-101 model.
+     Args:
+         pretrained (bool): If True, returns a model pre-trained on ImageNet
+     """
+     model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
+     if pretrained:
+         model.load_state_dict(model_zoo.load_url(
+             model_urls['resnet101']), strict=False)
+     return model
+
+
+ def resnet152(pretrained=True, **kwargs):
+     """Constructs a ResNet-152 model.
+     Args:
+         pretrained (bool): If True, returns a model pre-trained on ImageNet
+     """
+     model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
+     if pretrained:
+         model.load_state_dict(model_zoo.load_url(
+             model_urls['resnet152']), strict=False)
+     return model
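A minimal sketch of the backbone contract (the input size is illustrative; pretrained=False only avoids the model-zoo download in this example):

import torch
net = resnet50(pretrained=False)
x1, x2, x3, x4, x5 = net(torch.randn(1, 3, 512, 512))
# strides 4/4/8/16/32 -> x1: (1, 64, 128, 128), x2: (1, 256, 128, 128),
# x3: (1, 512, 64, 64), x4: (1, 1024, 32, 32), x5: (1, 2048, 16, 16)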
IndicPhotoOCR/detection/textbpn/network/backbone/vgg.py ADDED
@@ -0,0 +1,60 @@
+ import torch.nn as nn
+ import torch.utils.model_zoo as model_zoo
+ import torchvision.models as models
+
+ model_urls = {
+     'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
+     'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
+     'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
+     'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
+     'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
+     'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
+ }
+
+
+ class VggNet(nn.Module):
+     def __init__(self, name="vgg16", pretrain=True):
+         super().__init__()
+         if name == "vgg16":
+             base_net = models.vgg16(pretrained=False)
+         elif name == "vgg16_bn":
+             base_net = models.vgg16_bn(pretrained=False)
+         else:
+             raise ValueError("base model {} is not supported!".format(name))
+         if pretrain:
+             print("load the {} weight from ./cache".format(name))
+             base_net.load_state_dict(model_zoo.load_url(model_urls[name], model_dir="./cache"))
+
+         if name == "vgg16":
+             self.stage1 = nn.Sequential(*[base_net.features[layer] for layer in range(0, 5)])
+             self.stage2 = nn.Sequential(*[base_net.features[layer] for layer in range(5, 10)])
+             self.stage3 = nn.Sequential(*[base_net.features[layer] for layer in range(10, 17)])
+             self.stage4 = nn.Sequential(*[base_net.features[layer] for layer in range(17, 24)])
+             self.stage5 = nn.Sequential(*[base_net.features[layer] for layer in range(24, 31)])
+         elif name == "vgg16_bn":
+             self.stage1 = nn.Sequential(*[base_net.features[layer] for layer in range(0, 7)])
+             self.stage2 = nn.Sequential(*[base_net.features[layer] for layer in range(7, 14)])
+             self.stage3 = nn.Sequential(*[base_net.features[layer] for layer in range(14, 24)])
+             self.stage4 = nn.Sequential(*[base_net.features[layer] for layer in range(24, 34)])
+             self.stage5 = nn.Sequential(*[base_net.features[layer] for layer in range(34, 44)])
+
+     def forward(self, x):
+         C1 = self.stage1(x)
+         C2 = self.stage2(C1)
+         C3 = self.stage3(C2)
+         C4 = self.stage4(C3)
+         C5 = self.stage5(C4)
+
+         return C1, C2, C3, C4, C5
+
+
+ if __name__ == '__main__':
+     import torch
+     input = torch.randn((4, 3, 512, 512))
+     net = VggNet()
+     C1, C2, C3, C4, C5 = net(input)
+     print(C1.size())
+     print(C2.size())
+     print(C3.size())
+     print(C4.size())
+     print(C5.size())
IndicPhotoOCR/detection/textbpn/network/layers/Adaptive_Deformation.py ADDED
@@ -0,0 +1,88 @@
+ ###################################################################
+ # File Name: Adaptive_Deformation.py
+ # Author: S.X.Zhang
+ ###################################################################
+
+ from __future__ import print_function
+ from __future__ import division
+ from __future__ import absolute_import
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.nn import init
+
+
+ class MeanAggregator(nn.Module):
+     def __init__(self):
+         super(MeanAggregator, self).__init__()
+
+     def forward(self, features, A):
+         x = torch.bmm(A, features)
+         return x
+
+
+ class GraphConv(nn.Module):
+     def __init__(self, in_dim, out_dim, agg):
+         super(GraphConv, self).__init__()
+         self.in_dim = in_dim
+         self.out_dim = out_dim
+         self.weight = nn.Parameter(torch.FloatTensor(in_dim * 2, out_dim))
+         self.bias = nn.Parameter(torch.FloatTensor(out_dim))
+         init.xavier_uniform_(self.weight)
+         init.constant_(self.bias, 0)
+         self.agg = agg()
+
+     def forward(self, features, A):
+         b, n, d = features.shape
+         assert (d == self.in_dim)
+         agg_feats = self.agg(features, A)
+         cat_feats = torch.cat([features, agg_feats], dim=2)
+         out = torch.einsum('bnd,df->bnf', (cat_feats, self.weight))
+         out = F.relu(out + self.bias)
+         return out
+
+
+ class AdaptiveDeformation(nn.Module):
+     def __init__(self, input, state_dim):
+         super(AdaptiveDeformation, self).__init__()
+         self.bn0 = nn.BatchNorm1d(input, affine=False)
+         self.conv1 = nn.Conv1d(input, state_dim, 1)
+         self.rnn = nn.LSTM(input, state_dim, 1, bidirectional=True)
+         self.gconv1 = GraphConv(input, 256, MeanAggregator)
+         self.gconv2 = GraphConv(256, 1024, MeanAggregator)
+         self.gconv3 = GraphConv(1024, 512, MeanAggregator)
+         self.gconv4 = GraphConv(512, state_dim, MeanAggregator)
+
+         self.prediction = nn.Sequential(
+             nn.Conv1d(4*state_dim, 128, 1),
+             nn.ReLU(inplace=True),
+             nn.Dropout(0.1),
+             nn.Conv1d(128, 64, 1),
+             nn.ReLU(inplace=True),
+             nn.Dropout(0.1),
+             nn.Conv1d(64, 2, 1))
+
+     def forward(self, x, A):
+         x = self.bn0(x)
+
+         # rnn block
+         yl = x.permute(2, 0, 1)
+         yl, _ = self.rnn(yl)
+         yl = yl.permute(1, 2, 0)
+
+         # gcn block
+         yg = x.permute(0, 2, 1)
+         b, n, c = yg.shape
+         A = A.expand(b, n, n)
+         yg = self.gconv1(yg, A)
+         yg = self.gconv2(yg, A)
+         yg = self.gconv3(yg, A)
+         yg = self.gconv4(yg, A)
+         yg = yg.permute(0, 2, 1)
+
+         # res block
+         x = torch.cat([yl, yg, self.conv1(x)], dim=1)
+         pred = self.prediction(x)
+
+         return pred
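A sketch of the tensor contract this head expects (all dimensions are illustrative): x is a (B, C, N) feature sequence sampled along a proposal contour, A is an adjacency broadcastable to (B, N, N), and the head regresses a 2-D offset per contour point:

import torch
head = AdaptiveDeformation(36, 128)   # input dim 36, state dim 128
x = torch.randn(4, 36, 100)           # B=4 contours, 36-dim features, 100 points
A = torch.eye(100).unsqueeze(0)       # illustrative adjacency, expanded per batch
offsets = head(x, A)                  # -> (4, 2, 100); 4*state_dim = LSTM(2x) + GCN + conv1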
IndicPhotoOCR/detection/textbpn/network/layers/CircConv.py ADDED
@@ -0,0 +1,91 @@
+ import torch.nn as nn
+ import torch
+
+
+ class CircConv(nn.Module):
+     def __init__(self, state_dim, out_state_dim=None, n_adj=4):
+         super(CircConv, self).__init__()
+
+         self.n_adj = n_adj
+         out_state_dim = state_dim if out_state_dim is None else out_state_dim
+         self.fc = nn.Conv1d(state_dim, out_state_dim, kernel_size=self.n_adj*2+1)
+
+     def forward(self, input, adj):
+         input = torch.cat([input[..., -self.n_adj:], input, input[..., :self.n_adj]], dim=2)
+         return self.fc(input)
+
+
+ class DilatedCircConv(nn.Module):
+     def __init__(self, state_dim, out_state_dim=None, n_adj=4, dilation=1):
+         super(DilatedCircConv, self).__init__()
+
+         self.n_adj = n_adj
+         self.dilation = dilation
+         out_state_dim = state_dim if out_state_dim is None else out_state_dim
+         self.fc = nn.Conv1d(state_dim, out_state_dim, kernel_size=self.n_adj*2+1, dilation=self.dilation)
+
+     def forward(self, input, adj):
+         if self.n_adj != 0:
+             input = torch.cat([input[..., -self.n_adj*self.dilation:], input, input[..., :self.n_adj*self.dilation]], dim=2)
+         return self.fc(input)
+
+
+ _conv_factory = {
+     'grid': CircConv,
+     'dgrid': DilatedCircConv
+ }
+
+
+ class BasicBlock(nn.Module):
+     def __init__(self, state_dim, out_state_dim, conv_type, n_adj=4, dilation=1):
+         super(BasicBlock, self).__init__()
+
+         self.conv = _conv_factory[conv_type](state_dim, out_state_dim, n_adj, dilation)
+         self.relu = nn.ReLU(inplace=True)
+         self.norm = nn.BatchNorm1d(out_state_dim)
+
+     def forward(self, x, adj=None):
+         x = self.conv(x, adj)
+         x = self.relu(x)
+         x = self.norm(x)
+         return x
+
+
+ class DeepSnake(nn.Module):
+     def __init__(self, state_dim, feature_dim, conv_type='dgrid'):
+         super(DeepSnake, self).__init__()
+
+         self.head = BasicBlock(feature_dim, state_dim, conv_type)
+
+         self.res_layer_num = 7
+         dilation = [1, 1, 1, 2, 2, 4, 4]
+         for i in range(self.res_layer_num):
+             conv = BasicBlock(state_dim, state_dim, conv_type, n_adj=4, dilation=dilation[i])
+             self.__setattr__('res'+str(i), conv)
+
+         fusion_state_dim = 256
+         self.fusion = nn.Conv1d(state_dim * (self.res_layer_num + 1), fusion_state_dim, 1)
+         self.prediction = nn.Sequential(
+             nn.Conv1d(state_dim * (self.res_layer_num + 1) + fusion_state_dim, 256, 1),
+             nn.ReLU(inplace=True),
+             nn.Conv1d(256, 64, 1),
+             nn.ReLU(inplace=True),
+             nn.Conv1d(64, 2, 1)
+         )
+
+     def forward(self, x, adj):
+         states = []
+
+         x = self.head(x, adj)
+         states.append(x)
+         for i in range(self.res_layer_num):
+             x = self.__getattr__('res'+str(i))(x, adj) + x
+             states.append(x)
+
+         state = torch.cat(states, dim=1)
+         global_state = torch.max(self.fusion(state), dim=2, keepdim=True)[0]
+         global_state = global_state.expand(global_state.size(0), global_state.size(1), state.size(2))
+         state = torch.cat([global_state, state], dim=1)
+         x = self.prediction(state)
+
+         return x
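A sketch with illustrative dimensions: DeepSnake consumes the same (B, C, N) contour features, but the circular convolutions wrap around the point dimension, so no adjacency matrix is needed (the adj argument is accepted and unused):

import torch
snake = DeepSnake(state_dim=128, feature_dim=36, conv_type='dgrid')
x = torch.randn(4, 36, 100)        # B=4 contours, 36-dim features, 100 points
offsets = snake(x, adj=None)       # -> (4, 2, 100)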
IndicPhotoOCR/detection/textbpn/network/layers/GCN.py ADDED
@@ -0,0 +1,77 @@
+ ###################################################################
+ # File Name: GCN.py
+ # Author: S.X.Zhang
+ ###################################################################
+
+ from __future__ import print_function
+ from __future__ import division
+ from __future__ import absolute_import
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.nn import init
+
+
+ class MeanAggregator(nn.Module):
+     def __init__(self):
+         super(MeanAggregator, self).__init__()
+
+     def forward(self, features, A):
+         x = torch.bmm(A, features)
+         return x
+
+
+ class GraphConv(nn.Module):
+     def __init__(self, in_dim, out_dim, agg):
+         super(GraphConv, self).__init__()
+         self.in_dim = in_dim
+         self.out_dim = out_dim
+         self.weight = nn.Parameter(torch.FloatTensor(in_dim * 2, out_dim))
+         self.bias = nn.Parameter(torch.FloatTensor(out_dim))
+         init.xavier_uniform_(self.weight)
+         init.constant_(self.bias, 0)
+         self.agg = agg()
+
+     def forward(self, features, A):
+         b, n, d = features.shape
+         assert (d == self.in_dim)
+         agg_feats = self.agg(features, A)
+         cat_feats = torch.cat([features, agg_feats], dim=2)
+         out = torch.einsum('bnd,df->bnf', (cat_feats, self.weight))
+         out = F.relu(out + self.bias)
+         return out
+
+
+ class GCN(nn.Module):
+     def __init__(self, in_dim, out_dim):
+         super(GCN, self).__init__()
+         self.bn0 = nn.BatchNorm1d(in_dim, affine=False)
+
+         self.conv1 = GraphConv(in_dim, 256, MeanAggregator)
+         self.conv2 = GraphConv(256, 1024, MeanAggregator)
+         self.conv3 = GraphConv(1024, 512, MeanAggregator)
+         self.conv4 = GraphConv(512, out_dim, MeanAggregator)
+
+         self.prediction = nn.Sequential(
+             nn.Conv1d(out_dim, 128, 1),
+             nn.ReLU(inplace=True),
+             nn.Conv1d(128, 64, 1),
+             nn.ReLU(inplace=True),
+             nn.Conv1d(64, 2, 1))
+
+     def forward(self, x, A):
+         x = self.bn0(x)
+         x = x.permute(0, 2, 1)
+         b, n, c = x.shape
+         A = A.expand(b, n, n)
+
+         x = self.conv1(x, A)
+         x = self.conv2(x, A)
+         x = self.conv3(x, A)
+         x = self.conv4(x, A)
+
+         x = x.permute(0, 2, 1)
+         pred = self.prediction(x)
+
+         return pred
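All of the decoder heads in this commit (GCN here, and the RNN, AdaptiveDeformation, and Transformer variants in neighbouring files) share one contract: pred = head(x, A) with x of shape (B, C, N) and a (B, 2, N) per-point offset map out, so they are drop-in alternatives for the boundary-refinement stage. Illustrative values:

import torch
gcn = GCN(in_dim=36, out_dim=128)
pred = gcn(torch.randn(4, 36, 100), torch.eye(100).unsqueeze(0))  # -> (4, 2, 100)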
IndicPhotoOCR/detection/textbpn/network/layers/GraphConv.py ADDED
@@ -0,0 +1,45 @@
+ import math
+
+ import torch
+ from torch.nn.parameter import Parameter
+ from torch.nn.modules.module import Module
+ from torch.nn import init
+
+
+ class GraphConvolution(Module):
+     """
+     Simple GCN layer, similar to https://arxiv.org/abs/1609.02907
+     """
+
+     def __init__(self, in_features, out_features, bias=True):
+         super(GraphConvolution, self).__init__()
+         self.in_features = in_features
+         self.out_features = out_features
+         self.weight = Parameter(torch.FloatTensor(in_features, out_features))
+         init.xavier_uniform_(self.weight)
+         if bias:
+             self.bias = Parameter(torch.FloatTensor(out_features))
+             init.constant_(self.bias, 0)
+         else:
+             self.register_parameter('bias', None)
+
+         self.reset_parameters()
+
+     def reset_parameters(self):
+         stdv = 1. / math.sqrt(self.weight.size(1))
+         self.weight.data.uniform_(-stdv, stdv)
+         if self.bias is not None:
+             self.bias.data.uniform_(-stdv, stdv)
+
+     def forward(self, input, adj):
+         support = torch.mm(input, self.weight)
+         output = torch.spmm(adj, support)
+         if self.bias is not None:
+             return output + self.bias
+         else:
+             return output
+
+     def __repr__(self):
+         return self.__class__.__name__ + ' (' \
+                + str(self.in_features) + ' -> ' \
+                + str(self.out_features) + ')'
IndicPhotoOCR/detection/textbpn/network/layers/RNN.py ADDED
@@ -0,0 +1,35 @@
+ ###################################################################
+ # File Name: RNN.py
+ # Author: S.X.Zhang
+ ###################################################################
+
+ from __future__ import print_function
+ from __future__ import division
+ from __future__ import absolute_import
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.nn import init
+
+
+ class RNN(nn.Module):
+     def __init__(self, input, state_dim):
+         super(RNN, self).__init__()
+         self.bn0 = nn.BatchNorm1d(input, affine=False)
+         self.rnn = nn.LSTM(input, state_dim, 1, dropout=0.1, bidirectional=True)
+         self.prediction = nn.Sequential(
+             nn.Conv1d(state_dim*2, 128, 1),
+             nn.ReLU(inplace=True),
+             nn.Conv1d(128, 64, 1),
+             nn.ReLU(inplace=True),
+             nn.Conv1d(64, 2, 1))
+
+     def forward(self, x, adj):
+         x = self.bn0(x)
+         x = x.permute(2, 0, 1)
+         x, _ = self.rnn(x)
+         x = x.permute(1, 2, 0)
+         pred = self.prediction(x)
+
+         return pred
IndicPhotoOCR/detection/textbpn/network/layers/Transformer.py ADDED
@@ -0,0 +1,140 @@
+ ###################################################################
2
+ # File Name: GCN.py
3
+ # Author: S.X.Zhang
4
+ ###################################################################
5
+ import torch
6
+ from torch import nn, Tensor
7
+ import numpy as np
8
+ from IndicPhotoOCR.detection.textbpn.cfglib.config import config as cfg
9
+
10
+
11
+ class Positional_encoding(nn.Module):
12
+ def __init__(self, PE_size, n_position=256):
13
+ super(Positional_encoding, self).__init__()
14
+ self.PE_size = PE_size
15
+ self.n_position = n_position
16
+ self.register_buffer('pos_table', self.get_encoding_table(n_position, PE_size))
17
+
18
+ def get_encoding_table(self, n_position, PE_size):
19
+ position_table = np.array(
20
+ [[pos / np.power(10000, 2. * i / self.PE_size) for i in range(self.PE_size)] for pos in range(n_position)])
21
+ position_table[:, 0::2] = np.sin(position_table[:, 0::2])
22
+ position_table[:, 1::2] = np.cos(position_table[:, 1::2])
23
+ return torch.FloatTensor(position_table).unsqueeze(0)
24
+
25
+ def forward(self, inputs):
26
+ return inputs + self.pos_table[:, :inputs.size(1), :].clone().detach()
27
+
28
+
29
+ class MultiHeadAttention(nn.Module):
30
+ def __init__(self, num_heads, embed_dim, dropout=0.1, if_resi=True):
31
+ super(MultiHeadAttention, self).__init__()
32
+ self.layer_norm = nn.LayerNorm(embed_dim)
33
+ self.MultiheadAttention = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)
34
+ self.Q_proj = nn.Sequential(nn.Linear(embed_dim, embed_dim), nn.ReLU())
35
+ self.K_proj = nn.Sequential(nn.Linear(embed_dim, embed_dim), nn.ReLU())
36
+ self.V_proj = nn.Sequential(nn.Linear(embed_dim, embed_dim), nn.ReLU())
37
+ self.if_resi = if_resi
38
+
39
+ def forward(self, inputs):
40
+ query = self.layer_norm(inputs)
41
+ q = self.Q_proj(query)
42
+ k = self.K_proj(query)
43
+ v = self.V_proj(query)
44
+ attn_output, attn_output_weights = self.MultiheadAttention(q, k, v)
45
+ if self.if_resi:
46
+ attn_output += inputs
47
+ else:
48
+ attn_output = attn_output
49
+
50
+ return attn_output
51
+
52
+
53
+ class FeedForward(nn.Module):
54
+ def __init__(self, in_channel, FFN_channel, if_resi=True):
55
+ super(FeedForward, self).__init__()
56
+ """
57
+ 1024 2048
58
+ """
59
+ output_channel = (FFN_channel, in_channel)
60
+ self.fc1 = nn.Sequential(nn.Linear(in_channel, output_channel[0]), nn.ReLU())
61
+ self.fc2 = nn.Linear(output_channel[0], output_channel[1])
62
+ self.layer_norm = nn.LayerNorm(in_channel)
63
+ self.if_resi = if_resi
64
+
65
+ def forward(self, inputs):
66
+ outputs = self.layer_norm(inputs)
67
+ outputs = self.fc1(outputs)
68
+ outputs = self.fc2(outputs)
69
+ if self.if_resi:
70
+ outputs += inputs
71
+ else:
72
+ outputs = outputs
73
+ return outputs
74
+
75
+
76
+ class TransformerLayer(nn.Module):
77
+ def __init__(self, out_dim, in_dim, num_heads, attention_size,
78
+ dim_feedforward=1024, drop_rate=0.1, if_resi=True, block_nums=3):
79
+ super(TransformerLayer, self).__init__()
80
+ self.block_nums = block_nums
81
+ self.if_resi = if_resi
82
+ self.linear = nn.Linear(in_dim, attention_size)
83
+ for i in range(self.block_nums):
84
+ self.__setattr__('MHA_self_%d' % i, MultiHeadAttention(num_heads, attention_size,
85
+ dropout=drop_rate, if_resi=if_resi))
86
+ self.__setattr__('FFN_%d' % i, FeedForward(out_dim, dim_feedforward, if_resi=if_resi))
87
+
88
+ def forward(self, query):
89
+ inputs = self.linear(query)
90
+ # outputs = inputs
91
+ for i in range(self.block_nums):
92
+ outputs = self.__getattr__('MHA_self_%d' % i)(inputs)
93
+ outputs = self.__getattr__('FFN_%d' % i)(outputs)
94
+ if self.if_resi:
95
+ inputs = inputs+outputs
96
+ else:
97
+ inputs = outputs
98
+ # outputs = inputs
99
+ return inputs
100
+
101
+
102
+ class Transformer(nn.Module):
103
+
104
+ def __init__(self, in_dim, out_dim, num_heads=8,
105
+ dim_feedforward=1024, drop_rate=0.1, if_resi=False, block_nums=3):
106
+ super().__init__()
107
+
108
+ self.bn0 = nn.BatchNorm1d(in_dim, affine=False)
109
+ self.conv1 = nn.Conv1d(in_dim, out_dim, 1, dilation=1)
110
+
111
+ # self.pos_embedding = Positional_encoding(in_dim)
112
+ self.transformer = TransformerLayer(out_dim, in_dim, num_heads, attention_size=out_dim,
113
+ dim_feedforward=dim_feedforward, drop_rate=drop_rate,
114
+ if_resi=if_resi, block_nums=block_nums)
115
+
116
+ self.prediction = nn.Sequential(
117
+ nn.Conv1d(2*out_dim, 128, 1),
118
+ nn.ReLU(inplace=True),
119
+ nn.Dropout(0.1),
120
+ nn.Conv1d(128, 64, 1),
121
+ nn.ReLU(inplace=True),
122
+ # nn.Dropout(0.1),
123
+ nn.Conv1d(64, 2, 1))
124
+
125
+ def forward(self, x, adj):
126
+ x = self.bn0(x)
127
+
128
+ x1 = x.permute(0, 2, 1)
129
+ # x1 = self.pos_embedding(x1)
130
+ x1 = self.transformer(x1)
131
+ x1 = x1.permute(0, 2, 1)
132
+
133
+ x = torch.cat([x1, self.conv1(x)], dim=1)
134
+ # x = x1+self.conv1(x)
135
+ pred = self.prediction(x)
136
+
137
+ return pred
138
+
139
+
140
+
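
This `Transformer` is the "BT" evolution head that `Evolution` in `network/textnet.py` instantiates as `Transformer(36, 128, num_heads=8, dim_feedforward=1024, drop_rate=0.0, if_resi=True, block_nums=3)`. A minimal sketch of the tensor shapes under that configuration:

    import torch
    from IndicPhotoOCR.detection.textbpn.network.layers.Transformer import Transformer

    head = Transformer(36, 128, num_heads=8, dim_feedforward=1024,
                       drop_rate=0.0, if_resi=True, block_nums=3).eval()
    x = torch.randn(4, 36, 20)           # (instances, channels, contour nodes)
    with torch.no_grad():
        pred = head(x, None)             # adj is unused by this head too
    print(pred.shape)                    # torch.Size([4, 2, 20])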
IndicPhotoOCR/detection/textbpn/network/layers/Transformer_old.py ADDED
@@ -0,0 +1,171 @@
+ ###################################################################
+ # File Name: Transformer_old.py
+ # Author: S.X.Zhang
+ ###################################################################
+
+ import torch
+ import torch.nn.functional as F
+ from torch import nn, Tensor
+ from torch.autograd import Variable
+ import numpy as np
+ from IndicPhotoOCR.detection.textbpn.cfglib.config import config as cfg
+
+
+ class Positional_encoding(nn.Module):
+     def __init__(self, PE_size, n_position=200):
+         super(Positional_encoding, self).__init__()
+         self.PE_size = PE_size
+         self.n_position = n_position
+         self.register_buffer('pos_table', self.get_encoding_table(n_position, PE_size))
+
+     def get_encoding_table(self, n_position, PE_size):
+         position_table = np.array(
+             [[pos / np.power(10000, 2. * i / self.PE_size) for i in range(self.PE_size)]
+              for pos in range(n_position)])
+         position_table[:, 0::2] = np.sin(position_table[:, 0::2])
+         position_table[:, 1::2] = np.cos(position_table[:, 1::2])
+         return torch.FloatTensor(position_table).unsqueeze(0)
+
+     def forward(self, inputs):
+         return inputs + self.pos_table[:, :inputs.size(1), :].clone().detach()
+
+
+ class MultiHeadAttention(nn.Module):
+     def __init__(self, num_heads, embedding_size, attention_size,
+                  drop_rate, future_blind=True, query_mask=False, if_resi=True):
+         super(MultiHeadAttention, self).__init__()
+         self.num_heads = num_heads
+         self.embedding_size = embedding_size
+         self.attention_size = attention_size
+         self.drop_rate = drop_rate
+         self.future_blind = future_blind
+
+         self.Q_proj = nn.Sequential(nn.Linear(self.embedding_size, self.attention_size), nn.ReLU())
+         self.K_proj = nn.Sequential(nn.Linear(self.embedding_size, self.attention_size), nn.ReLU())
+         self.V_proj = nn.Sequential(nn.Linear(self.embedding_size, self.attention_size), nn.ReLU())
+
+         self.drop_out = nn.Dropout(p=self.drop_rate)
+         self.layer_norm = nn.LayerNorm(self.attention_size)
+         self.if_resi = if_resi
+
+     def forward(self, query, key, value):
+         q = self.Q_proj(query)
+         k = self.K_proj(key)
+         v = self.V_proj(value)
+
+         # split channels into heads and stack them along the batch axis
+         q_ = torch.cat(torch.chunk(q, self.num_heads, dim=2), dim=0)
+         k_ = torch.cat(torch.chunk(k, self.num_heads, dim=2), dim=0)
+         v_ = torch.cat(torch.chunk(v, self.num_heads, dim=2), dim=0)
+
+         # scaled dot-product attention logits
+         outputs = torch.bmm(q_, k_.permute(0, 2, 1))
+         outputs = outputs / (k_.size()[-1] ** 0.5)
+
+         # key mask
+
+         # future mask
+         if self.future_blind:
+             diag_vals = torch.ones_like(outputs[0, :, :]).to(cfg.device)
+             tril = torch.tril(diag_vals, diagonal=0)
+             masks = Variable(torch.unsqueeze(tril, 0).repeat(outputs.size()[0], 1, 1))  # (h*N, T_q, T_k)
+             padding = Variable(torch.ones_like(masks).to(cfg.device) * (-2 ** 32 + 1))
+             condition = masks.eq(0)
+             outputs = torch.where(condition, padding, outputs)
+
+         outputs = F.softmax(outputs, dim=-1)
+         outputs = self.drop_out(outputs)
+
+         outputs = torch.bmm(outputs, v_)
+         outputs = torch.cat(torch.chunk(outputs, self.num_heads, dim=0), dim=2)  # (N, T_q, C)
+
+         if self.if_resi:
+             # outputs += query
+             outputs += q
+         outputs = self.layer_norm(outputs)
+
+         return outputs
+
+
+ class FeedForward(nn.Module):
+     def __init__(self, in_channel, FFN_channel, if_resi=True):
+         super(FeedForward, self).__init__()
+         # hidden/output widths, e.g. 1024 -> 2048 -> 1024
+         output_channel = (FFN_channel, in_channel)
+         self.fc1 = nn.Sequential(nn.Linear(in_channel, output_channel[0]), nn.ReLU())
+         self.fc2 = nn.Linear(output_channel[0], output_channel[1])
+         self.layer_norm = nn.LayerNorm(in_channel)
+         self.if_resi = if_resi
+
+     def forward(self, inputs):
+         outputs = self.fc1(inputs)
+         outputs = self.fc2(outputs)
+         if self.if_resi:
+             outputs += inputs
+         outputs = self.layer_norm(outputs)
+         return outputs
+
+
+ class TransformerLayer(nn.Module):
+     def __init__(self, out_dim, num_heads, embedding_size, attention_size,
+                  dim_feedforward=1024, drop_rate=0.1, if_resi=True, block_nums=3):
+         super(TransformerLayer, self).__init__()
+         self.block_nums = block_nums
+         self.if_resi = if_resi
+         for i in range(self.block_nums):
+             self.__setattr__('MHA_self_%d' % i,
+                              MultiHeadAttention(num_heads, embedding_size, attention_size,
+                                                 drop_rate, future_blind=False, if_resi=if_resi))
+             self.__setattr__('FFN_%d' % i, FeedForward(out_dim, dim_feedforward, if_resi=if_resi))
+
+     def forward(self, query):
+         outputs = None
+         for i in range(self.block_nums):
+             outputs = self.__getattr__('MHA_self_%d' % i)(query, query, query)
+             outputs = self.__getattr__('FFN_%d' % i)(outputs)
+         return outputs
+
+
+ class Transformer(nn.Module):
+
+     def __init__(self, in_dim, out_dim, num_heads=8,
+                  dim_feedforward=1024, drop_rate=0.1, if_resi=False, block_nums=3):
+         super().__init__()
+
+         self.bn0 = nn.BatchNorm1d(in_dim, affine=False)
+         self.conv1 = nn.Conv1d(in_dim, out_dim, 1, dilation=1)
+
+         embed_dim = in_dim
+         # self.pos_embedding = Positional_encoding(embed_dim)
+         self.transformer = TransformerLayer(out_dim, num_heads, embedding_size=embed_dim,
+                                             attention_size=out_dim, dim_feedforward=dim_feedforward,
+                                             drop_rate=drop_rate, if_resi=if_resi, block_nums=block_nums)
+
+         self.prediction = nn.Sequential(
+             nn.Conv1d(out_dim * 2, 128, 1),
+             nn.ReLU(inplace=True),
+             nn.Dropout(0.1),
+             nn.Conv1d(128, 64, 1),
+             nn.ReLU(inplace=True),
+             # nn.Dropout(0.1),
+             nn.Conv1d(64, 2, 1))
+
+     def forward(self, x, adj):
+         x = self.bn0(x)
+
+         x1 = x.permute(0, 2, 1)
+         x1 = self.transformer(x1)
+         x1 = x1.permute(0, 2, 1)
+
+         x = torch.cat([x1, self.conv1(x)], dim=1)
+         # x = x1 + self.conv1(x)
+         pred = self.prediction(x)
+
+         return pred
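
The distinctive piece of this older variant is the hand-rolled `future_blind` (causal) mask; `TransformerLayer` disables it here, but the mechanics are easy to check in isolation with plain PyTorch:

    import torch

    scores = torch.randn(1, 4, 4)                  # (heads*batch, T_q, T_k) attention logits
    tril = torch.tril(torch.ones(4, 4))            # lower-triangular keep-mask
    masked = torch.where(tril.unsqueeze(0).eq(0),
                         torch.full_like(scores, -2 ** 32 + 1.0),
                         scores)
    attn = masked.softmax(dim=-1)                  # each position attends only to the past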
IndicPhotoOCR/detection/textbpn/network/layers/__init__.py ADDED
File without changes
IndicPhotoOCR/detection/textbpn/network/layers/gcn_utils.py ADDED
@@ -0,0 +1,150 @@
+ # -*- coding: utf-8 -*-
+ __author__ = "S.X.Zhang"
+ import torch
+ import numpy as np
+ import cv2
+ import torch.nn as nn
+ from torch.autograd import Variable
+
+
+ def normalize_adj(A, type="AD"):
+     if type == "DAD":
+         A = A + np.eye(A.shape[0])  # A = A + I
+         d = np.sum(A, axis=0)
+         d_inv = np.power(d, -0.5).flatten()
+         d_inv[np.isinf(d_inv)] = 0.0
+         d_inv = np.diag(d_inv)
+         G = A.dot(d_inv).transpose().dot(d_inv)  # L = D^-1/2 A D^-1/2
+         G = torch.from_numpy(G)
+     elif type == "AD":
+         A = A + np.eye(A.shape[0])  # A = A + I
+         A = torch.from_numpy(A)
+         D = A.sum(1, keepdim=True)
+         G = A.div(D)  # L = A/D
+     else:
+         A = A + np.eye(A.shape[0])  # A = A + I
+         D = np.diag(A.sum(axis=1))
+         G = torch.from_numpy(D - A)  # L = D - A
+     return G
+
+
+ def np_to_variable(x, is_cuda=True, dtype=torch.FloatTensor):
+     v = Variable(torch.from_numpy(x).type(dtype))
+     if is_cuda:
+         v = v.cuda()
+     return v
+
+
+ def set_trainable(model, requires_grad):
+     for param in model.parameters():
+         param.requires_grad = requires_grad
+
+
+ def weights_normal_init(model, dev=0.01):
+     if isinstance(model, list):
+         for m in model:
+             weights_normal_init(m, dev)
+     else:
+         for m in model.modules():
+             if isinstance(m, nn.Conv2d):
+                 m.weight.data.normal_(0.0, dev)
+             elif isinstance(m, nn.Linear):
+                 m.weight.data.normal_(0.0, dev)
+
+
+ def clip_gradient(model, clip_norm):
+     """Computes a gradient clipping coefficient based on gradient norm."""
+     totalnorm = 0
+     for p in model.parameters():
+         if p.requires_grad:
+             modulenorm = p.grad.data.norm()
+             totalnorm += modulenorm ** 2
+     totalnorm = np.sqrt(totalnorm)
+
+     norm = clip_norm / max(totalnorm, clip_norm)
+     for p in model.parameters():
+         if p.requires_grad:
+             p.grad.mul_(norm)
+
+
+ def EuclideanDistances(A, B):
+     BT = B.transpose()
+     vecProd = np.dot(A, BT)
+     SqA = A ** 2
+     sumSqA = np.matrix(np.sum(SqA, axis=1))
+     sumSqAEx = np.tile(sumSqA.transpose(), (1, vecProd.shape[1]))
+
+     SqB = B ** 2
+     sumSqB = np.sum(SqB, axis=1)
+     sumSqBEx = np.tile(sumSqB, (vecProd.shape[0], 1))
+     SqED = sumSqBEx + sumSqAEx - 2 * vecProd
+     SqED[SqED < 0] = 0.0
+     ED = np.sqrt(SqED)
+     return ED
+
+
+ def get_center_feature(cnn_feature, img_poly, ind, h, w):
+     batch_size = cnn_feature.size(0)
+     for i in range(batch_size):
+         poly = img_poly[ind == i].cpu().numpy()
+         mask = np.zeros((h, w), dtype=np.uint8)
+         cv2.fillPoly(mask, poly.astype(np.int32), color=(1,))
+     return None
+
+
+ def get_node_feature(cnn_feature, img_poly, ind, h, w):
+     # normalize polygon coordinates to [-1, 1] for grid_sample
+     img_poly = img_poly.clone().float()
+     img_poly[..., 0] = img_poly[..., 0] / (w / 2.) - 1
+     img_poly[..., 1] = img_poly[..., 1] / (h / 2.) - 1
+
+     batch_size = cnn_feature.size(0)
+     gcn_feature = torch.zeros([img_poly.size(0), cnn_feature.size(1), img_poly.size(1)]).to(img_poly.device)
+     for i in range(batch_size):
+         poly = img_poly[ind == i].unsqueeze(0)
+         gcn_feature[ind == i] = torch.nn.functional.grid_sample(cnn_feature[i:i + 1], poly)[0].permute(1, 0, 2)
+     return gcn_feature
+
+
+ def get_adj_mat(n_adj, n_nodes):
+     a = np.zeros([n_nodes, n_nodes], dtype=float)
+
+     for i in range(n_nodes):
+         for j in range(-n_adj // 2, n_adj // 2 + 1):
+             if j != 0:
+                 a[i][(i + j) % n_nodes] = 1
+                 a[(i + j) % n_nodes][i] = 1
+     return a
+
+
+ def get_adj_ind(n_adj, n_nodes, device):
+     ind = torch.tensor([i for i in range(-n_adj // 2, n_adj // 2 + 1) if i != 0]).long()
+     ind = (torch.arange(n_nodes)[:, None] + ind[None]) % n_nodes
+     return ind.to(device)
+
+
+ def coord_embedding(b, w, h, device):
+     x_range = torch.linspace(0, 1, w, device=device)
+     y_range = torch.linspace(0, 1, h, device=device)
+     y, x = torch.meshgrid(y_range, x_range)
+     y = y.expand([b, 1, -1, -1])
+     x = x.expand([b, 1, -1, -1])
+     coord_map = torch.cat([x, y], 1)
+
+     return coord_map
+
+
+ def img_poly_to_can_poly(img_poly):
+     if len(img_poly) == 0:
+         return torch.zeros_like(img_poly)
+     x_min = torch.min(img_poly[..., 0], dim=-1)[0]
+     y_min = torch.min(img_poly[..., 1], dim=-1)[0]
+     can_poly = img_poly.clone()
+     can_poly[..., 0] = can_poly[..., 0] - x_min[..., None]
+     can_poly[..., 1] = can_poly[..., 1] - y_min[..., None]
+     # x_max = torch.max(img_poly[..., 0], dim=-1)[0]
+     # y_max = torch.max(img_poly[..., 1], dim=-1)[0]
+     # h, w = y_max - y_min + 1, x_max - x_min + 1
+     # long_side = torch.max(h, w)
+     # can_poly = can_poly / long_side[..., None, None]
+     return can_poly
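
For the GCN and adaptive-deformation heads, `Evolution` builds a circular adjacency with `get_adj_mat` and normalizes it with `normalize_adj(..., type="DAD")`. A small sketch:

    import numpy as np
    from IndicPhotoOCR.detection.textbpn.network.layers.gcn_utils import get_adj_mat, normalize_adj

    A = get_adj_mat(n_adj=4, n_nodes=6)   # each node linked to its 4 nearest ring neighbours
    G = normalize_adj(A, type="DAD")      # D^-1/2 (A + I) D^-1/2, returned as a torch tensor
    print(G.shape)                        # torch.Size([6, 6])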
IndicPhotoOCR/detection/textbpn/network/layers/model_block.py ADDED
@@ -0,0 +1,149 @@
+ # -*- coding: utf-8 -*-
+ __author__ = "S.X.Zhang"
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from IndicPhotoOCR.detection.textbpn.network.layers.vgg import VggNet
+ from IndicPhotoOCR.detection.textbpn.network.layers.resnet import ResNet
+ from IndicPhotoOCR.detection.textbpn.network.layers.resnet_dcn import ResNet_DCN
+ from IndicPhotoOCR.detection.textbpn.cfglib.config import config as cfg
+
+
+ class UpBlok(nn.Module):
+
+     def __init__(self, in_channels, out_channels):
+         super().__init__()
+         self.conv1x1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+         self.conv3x3 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+         self.deconv = nn.ConvTranspose2d(out_channels, out_channels, kernel_size=4, stride=2, padding=1)
+
+     def forward(self, upsampled, shortcut):
+         x = torch.cat([upsampled, shortcut], dim=1)
+         x = self.conv1x1(x)
+         x = F.relu(x)
+         x = self.conv3x3(x)
+         x = F.relu(x)
+         x = self.deconv(x)
+         return x
+
+
+ class MergeBlok(nn.Module):
+     def __init__(self, in_channels, out_channels):
+         super().__init__()
+         self.conv1x1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+         self.conv3x3 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+
+     def forward(self, upsampled, shortcut):
+         x = torch.cat([upsampled, shortcut], dim=1)
+         x = self.conv1x1(x)
+         x = F.relu(x)
+         x = self.conv3x3(x)
+         return x
+
+
+ class FPN(nn.Module):
+
+     def __init__(self, backbone='resnet50', is_training=True):
+         super().__init__()
+         self.is_training = is_training
+         self.backbone_name = backbone
+
+         if backbone in ['vgg_bn', 'vgg']:
+             self.backbone = VggNet(name=backbone, pretrain=is_training)
+             self.deconv5 = nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1)
+             self.merge4 = UpBlok(512 + 256, 128)
+             self.merge3 = UpBlok(256 + 128, 64)
+             if cfg.scale == 1:
+                 self.merge2 = UpBlok(128 + 64, 32)     # FPN 1/2
+                 self.merge1 = UpBlok(64 + 32, 32)      # FPN 1/1
+             elif cfg.scale == 2:
+                 self.merge2 = UpBlok(128 + 64, 32)     # FPN 1/2
+                 self.merge1 = MergeBlok(64 + 32, 32)   # FPN 1/2
+             elif cfg.scale == 4:
+                 self.merge2 = MergeBlok(128 + 64, 32)  # FPN 1/4
+                 self.merge1 = MergeBlok(64 + 32, 32)   # FPN 1/4
+
+         elif backbone in ['resnet50']:
+             self.backbone = ResNet(name=backbone, pretrain=is_training)
+             self.deconv5 = nn.ConvTranspose2d(2048, 256, kernel_size=4, stride=2, padding=1)
+             self.merge4 = UpBlok(1024 + 256, 128)
+             self.merge3 = UpBlok(512 + 128, 64)
+             if cfg.scale == 1:
+                 self.merge2 = UpBlok(256 + 64, 32)     # FPN 1/2
+                 self.merge1 = UpBlok(64 + 32, 32)      # FPN 1/1
+             elif cfg.scale == 2:
+                 self.merge2 = UpBlok(256 + 64, 32)     # FPN 1/2
+                 self.merge1 = MergeBlok(64 + 32, 32)   # FPN 1/2
+             elif cfg.scale == 4:
+                 self.merge2 = MergeBlok(256 + 64, 32)  # FPN 1/4
+                 self.merge1 = MergeBlok(64 + 32, 32)   # FPN 1/4
+
+         elif backbone in ['resnet18']:
+             self.backbone = ResNet(name=backbone, pretrain=is_training)
+             self.deconv5 = nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1)
+             self.merge4 = UpBlok(256 + 256, 128)
+             self.merge3 = UpBlok(128 + 128, 64)
+             if cfg.scale == 1:
+                 self.merge2 = UpBlok(64 + 64, 32)      # FPN 1/2
+                 self.merge1 = UpBlok(64 + 32, 32)      # FPN 1/1
+             elif cfg.scale == 2:
+                 self.merge2 = UpBlok(64 + 64, 32)      # FPN 1/2
+                 self.merge1 = MergeBlok(64 + 32, 32)   # FPN 1/2
+             elif cfg.scale == 4:
+                 self.merge2 = MergeBlok(64 + 64, 32)   # FPN 1/4
+                 self.merge1 = MergeBlok(64 + 32, 32)   # FPN 1/4
+
+         elif backbone in ["deformable_resnet18"]:
+             self.backbone = ResNet_DCN(name=backbone, pretrain=is_training)
+             self.deconv5 = nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1)
+             self.merge4 = UpBlok(256 + 256, 128)
+             self.merge3 = UpBlok(128 + 128, 64)
+             if cfg.scale == 1:
+                 self.merge2 = UpBlok(64 + 64, 32)      # FPN 1/2
+                 self.merge1 = UpBlok(64 + 32, 32)      # FPN 1/1
+             elif cfg.scale == 2:
+                 self.merge2 = UpBlok(64 + 64, 32)      # FPN 1/2
+                 self.merge1 = MergeBlok(64 + 32, 32)   # FPN 1/2
+             elif cfg.scale == 4:
+                 self.merge2 = MergeBlok(64 + 64, 32)   # FPN 1/4
+                 self.merge1 = MergeBlok(64 + 32, 32)   # FPN 1/4
+
+         elif backbone in ["deformable_resnet50"]:
+             self.backbone = ResNet_DCN(name=backbone, pretrain=is_training)
+             self.deconv5 = nn.ConvTranspose2d(2048, 256, kernel_size=4, stride=2, padding=1)
+             self.merge4 = UpBlok(1024 + 256, 128)
+             self.merge3 = UpBlok(512 + 128, 64)
+             if cfg.scale == 1:
+                 self.merge2 = UpBlok(256 + 64, 32)     # FPN 1/2
+                 self.merge1 = UpBlok(64 + 32, 32)      # FPN 1/1
+             elif cfg.scale == 2:
+                 self.merge2 = UpBlok(256 + 64, 32)     # FPN 1/2
+                 self.merge1 = MergeBlok(64 + 32, 32)   # FPN 1/2
+             elif cfg.scale == 4:
+                 self.merge2 = MergeBlok(256 + 64, 32)  # FPN 1/4
+                 self.merge1 = MergeBlok(64 + 32, 32)   # FPN 1/4
+         else:
+             print("backbone is not supported!")
+
+     def forward(self, x):
+         C1, C2, C3, C4, C5 = self.backbone(x)
+         up5 = self.deconv5(C5)
+         up5 = F.relu(up5)
+
+         up4 = self.merge4(C4, up5)
+         up4 = F.relu(up4)
+
+         up3 = self.merge3(C3, up4)
+         up3 = F.relu(up3)
+
+         up2 = self.merge2(C2, up3)
+         up2 = F.relu(up2)
+
+         up1 = self.merge1(C1, up2)
+         up1 = F.relu(up1)
+
+         return up1, up2, up3, up4, up5
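
A shape sketch for the FPN wrapper, assuming `cfg.scale == 1` (so `merge1` is an `UpBlok` and `up1` comes back at full resolution); `is_training=False` skips the pretrained-weight download:

    import torch
    from IndicPhotoOCR.detection.textbpn.network.layers.model_block import FPN

    fpn = FPN(backbone='resnet50', is_training=False).eval()
    with torch.no_grad():
        up1, up2, up3, up4, up5 = fpn(torch.randn(1, 3, 640, 640))
    print(up1.shape)                      # torch.Size([1, 32, 640, 640])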
IndicPhotoOCR/detection/textbpn/network/layers/position_encoding.py ADDED
@@ -0,0 +1,89 @@
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+ """
+ Various positional encodings for the transformer.
+ """
+ import math
+ import torch
+ from torch import nn
+
+ from IndicPhotoOCR.detection.textbpn.util.misc import NestedTensor
+
+
+ class PositionEmbeddingSine(nn.Module):
+     """
+     This is a more standard version of the position embedding, very similar to the one
+     used by the Attention is all you need paper, generalized to work on images.
+     """
+     def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
+         super().__init__()
+         self.num_pos_feats = num_pos_feats
+         self.temperature = temperature
+         self.normalize = normalize
+         if scale is not None and normalize is False:
+             raise ValueError("normalize should be True if scale is passed")
+         if scale is None:
+             scale = 2 * math.pi
+         self.scale = scale
+
+     def forward(self, tensor_list: NestedTensor):
+         x = tensor_list.tensors
+         mask = tensor_list.mask
+         assert mask is not None
+         not_mask = ~mask
+         y_embed = not_mask.cumsum(1, dtype=torch.float32)
+         x_embed = not_mask.cumsum(2, dtype=torch.float32)
+         if self.normalize:
+             eps = 1e-6
+             y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+             x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+
+         dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+         dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+
+         pos_x = x_embed[:, :, :, None] / dim_t
+         pos_y = y_embed[:, :, :, None] / dim_t
+         pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
+         pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
+         pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+         return pos
+
+
+ class PositionEmbeddingLearned(nn.Module):
+     """
+     Absolute pos embedding, learned.
+     """
+     def __init__(self, num_pos_feats=256):
+         super().__init__()
+         self.row_embed = nn.Embedding(50, num_pos_feats)
+         self.col_embed = nn.Embedding(50, num_pos_feats)
+         self.reset_parameters()
+
+     def reset_parameters(self):
+         nn.init.uniform_(self.row_embed.weight)
+         nn.init.uniform_(self.col_embed.weight)
+
+     def forward(self, tensor_list: NestedTensor):
+         x = tensor_list.tensors
+         h, w = x.shape[-2:]
+         i = torch.arange(w, device=x.device)
+         j = torch.arange(h, device=x.device)
+         x_emb = self.col_embed(i)
+         y_emb = self.row_embed(j)
+         pos = torch.cat([
+             x_emb.unsqueeze(0).repeat(h, 1, 1),
+             y_emb.unsqueeze(1).repeat(1, w, 1),
+         ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
+         return pos
+
+
+ def build_position_encoding(args):
+     N_steps = args.hidden_dim // 2
+     if args.position_embedding in ('v2', 'sine'):
+         # TODO find a better way of exposing other arguments
+         position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
+     elif args.position_embedding in ('v3', 'learned'):
+         position_embedding = PositionEmbeddingLearned(N_steps)
+     else:
+         raise ValueError(f"not supported {args.position_embedding}")
+
+     return position_embedding
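
`PositionEmbeddingSine.forward` only reads `.tensors` and `.mask` from its argument, so its output shape can be checked without constructing a real `NestedTensor`:

    import torch
    from types import SimpleNamespace
    from IndicPhotoOCR.detection.textbpn.network.layers.position_encoding import PositionEmbeddingSine

    feat = torch.randn(1, 256, 32, 32)
    mask = torch.zeros(1, 32, 32, dtype=torch.bool)   # False = valid pixel
    nt = SimpleNamespace(tensors=feat, mask=mask)     # stand-in exposing the two fields used
    pos = PositionEmbeddingSine(num_pos_feats=128, normalize=True)(nt)
    print(pos.shape)                                  # torch.Size([1, 256, 32, 32])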
IndicPhotoOCR/detection/textbpn/network/layers/resnet.py ADDED
@@ -0,0 +1,73 @@
+ import torch
+ import torch.nn as nn
+ from torchvision.models import resnet
+ import torch.utils.model_zoo as model_zoo
+ from IndicPhotoOCR.detection.textbpn.cfglib.config import config as cfg
+
+ model_urls = {
+     'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
+     'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
+     'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
+     'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
+     'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
+ }
+
+
+ class ResNet(nn.Module):
+     def __init__(self, name="resnet50", pretrain=True):
+         super().__init__()
+
+         if name == "resnet50":
+             base_net = resnet.resnet50(pretrained=False)
+         elif name == "resnet101":
+             base_net = resnet.resnet101(pretrained=False)
+         elif name == "resnet18":
+             base_net = resnet.resnet18(pretrained=False)
+         elif name == "resnet34":
+             base_net = resnet.resnet34(pretrained=False)
+         else:
+             print("base model is not supported!")
+
+         if pretrain:
+             print("loading the {} weights from ./cache".format(name))
+             base_net.load_state_dict(model_zoo.load_url(model_urls[name], model_dir="./cache",
+                                                         map_location=torch.device(cfg.device)), strict=False)
+
+         self.stage1 = nn.Sequential(
+             base_net.conv1,
+             base_net.bn1,
+             base_net.relu,
+             base_net.maxpool
+         )
+         self.stage2 = base_net.layer1
+         self.stage3 = base_net.layer2
+         self.stage4 = base_net.layer3
+         self.stage5 = base_net.layer4
+         self.up2 = nn.ConvTranspose2d(64, 64, kernel_size=4, stride=2, padding=1)
+
+     def forward(self, x):
+         C1 = self.stage1(x)
+         C2 = self.stage2(C1)
+         C3 = self.stage3(C2)
+         C4 = self.stage4(C3)
+         C5 = self.stage5(C4)
+
+         if cfg.scale == 2 or cfg.scale == 1:
+             # up2 --> 1/2
+             C1 = self.up2(C1)
+
+         return C1, C2, C3, C4, C5
+
+
+ if __name__ == '__main__':
+     input = torch.randn((4, 3, 512, 512))
+     net = ResNet()
+     C1, C2, C3, C4, C5 = net(input)
+     print(C1.size())
+     print(C2.size())
+     print(C3.size())
+     print(C4.size())
+     print(C5.size())
IndicPhotoOCR/detection/textbpn/network/layers/resnet_dcn.py ADDED
@@ -0,0 +1,59 @@
+ import torch
+ import torch.nn as nn
+ from IndicPhotoOCR.detection.textbpn.network.backbone.resnet import deformable_resnet18, deformable_resnet50
+ import torch.utils.model_zoo as model_zoo
+ from IndicPhotoOCR.detection.textbpn.cfglib.config import config as cfg
+
+ model_urls = {
+     'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
+     'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
+     'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
+     'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
+     'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
+ }
+
+
+ class ResNet_DCN(nn.Module):
+     def __init__(self, name="deformable_resnet18", pretrain=False):
+         super().__init__()
+
+         if name == "deformable_resnet18":
+             self.base_net = deformable_resnet18(pretrained=False)
+             if pretrain:
+                 print("loading the {} weights from ./cache".format(name))
+                 self.base_net.load_state_dict(
+                     model_zoo.load_url(model_urls["resnet18"], model_dir="./cache",
+                                        map_location=torch.device(cfg.device)), strict=False)
+
+         elif name == "deformable_resnet50":
+             self.base_net = deformable_resnet50(pretrained=False)
+             if pretrain:
+                 print("loading the {} weights from ./cache".format(name))
+                 self.base_net.load_state_dict(
+                     model_zoo.load_url(model_urls["resnet50"], model_dir="./cache",
+                                        map_location=torch.device(cfg.device)), strict=False)
+         else:
+             print("base model is not supported!")
+
+         self.up2 = nn.ConvTranspose2d(64, 64, kernel_size=4, stride=2, padding=1)
+
+     def forward(self, x):
+         C1, C2, C3, C4, C5 = self.base_net(x)
+         # up2 --> 1/2
+         C1 = self.up2(C1)
+
+         return C1, C2, C3, C4, C5
+
+
+ if __name__ == '__main__':
+     input = torch.randn((4, 3, 512, 512))
+     net = ResNet_DCN()
+     C1, C2, C3, C4, C5 = net(input)
+     print(C1.size())
+     print(C2.size())
+     print(C3.size())
+     print(C4.size())
+     print(C5.size())
IndicPhotoOCR/detection/textbpn/network/layers/vgg.py ADDED
@@ -0,0 +1,62 @@
+ import torch
+ import torch.nn as nn
+ import torch.utils.model_zoo as model_zoo
+ import torchvision.models as models
+ from IndicPhotoOCR.detection.textbpn.cfglib.config import config as cfg
+
+ model_urls = {
+     'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
+     'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
+     'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
+     'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
+     'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
+     'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
+ }
+
+
+ class VggNet(nn.Module):
+     def __init__(self, name="vgg16", pretrain=True):
+         super().__init__()
+         if name in ("vgg", "vgg16"):       # FPN passes 'vgg'
+             name = "vgg16"
+             base_net = models.vgg16(pretrained=False)
+         elif name in ("vgg_bn", "vgg16_bn"):  # FPN passes 'vgg_bn'
+             name = "vgg16_bn"
+             base_net = models.vgg16_bn(pretrained=False)
+         else:
+             print("base model is not supported!")
+         if pretrain:
+             print("loading the {} weights from ./cache".format(name))
+             base_net.load_state_dict(model_zoo.load_url(model_urls[name], model_dir="./cache",
+                                                         map_location=torch.device(cfg.device)))
+
+         if name == "vgg16":
+             self.stage1 = nn.Sequential(*[base_net.features[layer] for layer in range(0, 5)])
+             self.stage2 = nn.Sequential(*[base_net.features[layer] for layer in range(5, 10)])
+             self.stage3 = nn.Sequential(*[base_net.features[layer] for layer in range(10, 17)])
+             self.stage4 = nn.Sequential(*[base_net.features[layer] for layer in range(17, 24)])
+             self.stage5 = nn.Sequential(*[base_net.features[layer] for layer in range(24, 31)])
+         elif name == "vgg16_bn":
+             self.stage1 = nn.Sequential(*[base_net.features[layer] for layer in range(0, 7)])
+             self.stage2 = nn.Sequential(*[base_net.features[layer] for layer in range(7, 14)])
+             self.stage3 = nn.Sequential(*[base_net.features[layer] for layer in range(14, 24)])
+             self.stage4 = nn.Sequential(*[base_net.features[layer] for layer in range(24, 34)])
+             self.stage5 = nn.Sequential(*[base_net.features[layer] for layer in range(34, 44)])
+
+     def forward(self, x):
+         C1 = self.stage1(x)
+         C2 = self.stage2(C1)
+         C3 = self.stage3(C2)
+         C4 = self.stage4(C3)
+         C5 = self.stage5(C4)
+
+         return C1, C2, C3, C4, C5
+
+
+ if __name__ == '__main__':
+     input = torch.randn((4, 3, 512, 512))
+     net = VggNet()
+     C1, C2, C3, C4, C5 = net(input)
+     print(C1.size())
+     print(C2.size())
+     print(C3.size())
+     print(C4.size())
+     print(C5.size())
IndicPhotoOCR/detection/textbpn/network/loss.py ADDED
@@ -0,0 +1,187 @@
+ # -*- coding: utf-8 -*-
+ # @Time    : 10/1/21
+ # @Author  : GXYM
+ import torch
+ import torch.nn as nn
+ from IndicPhotoOCR.detection.textbpn.cfglib.config import config as cfg
+ from IndicPhotoOCR.detection.textbpn.network.Seg_loss import SegmentLoss
+ from IndicPhotoOCR.detection.textbpn.network.Reg_loss import PolyMatchingLoss
+ import torch.nn.functional as F
+
+
+ class TextLoss(nn.Module):
+
+     def __init__(self):
+         super().__init__()
+         self.MSE_loss = torch.nn.MSELoss(reduction='none')
+         self.BCE_loss = torch.nn.BCELoss(reduction='none')
+         self.PolyMatchingLoss = PolyMatchingLoss(cfg.num_points, cfg.device)
+         self.KL_loss = torch.nn.KLDivLoss(reduction='none')
+
+     @staticmethod
+     def single_image_loss(pre_loss, loss_label):
+         batch_size = pre_loss.shape[0]
+         sum_loss = torch.mean(pre_loss.view(-1)) * 0
+         pre_loss = pre_loss.view(batch_size, -1)
+         loss_label = loss_label.view(batch_size, -1)
+         eps = 0.001
+         for i in range(batch_size):
+             average_number = 0
+             positive_pixel = len(pre_loss[i][(loss_label[i] >= eps)])
+             average_number += positive_pixel
+             if positive_pixel != 0:
+                 posi_loss = torch.mean(pre_loss[i][(loss_label[i] >= eps)])
+                 sum_loss += posi_loss
+                 if len(pre_loss[i][(loss_label[i] < eps)]) < 3 * positive_pixel:
+                     nega_loss = torch.mean(pre_loss[i][(loss_label[i] < eps)])
+                     average_number += len(pre_loss[i][(loss_label[i] < eps)])
+                 else:
+                     nega_loss = torch.mean(torch.topk(pre_loss[i][(loss_label[i] < eps)], 3 * positive_pixel)[0])
+                     average_number += 3 * positive_pixel
+                 sum_loss += nega_loss
+             else:
+                 nega_loss = torch.mean(torch.topk(pre_loss[i], 100)[0])
+                 average_number += 100
+                 sum_loss += nega_loss
+             # sum_loss += loss / average_number
+
+         return sum_loss / batch_size
+
+     def cls_ohem(self, predict, target, train_mask, negative_ratio=3.):
+         pos = (target * train_mask).bool()
+         neg = ((1 - target) * train_mask).bool()
+
+         n_pos = pos.float().sum()
+         if n_pos.item() > 0:
+             loss_pos = self.BCE_loss(predict[pos], target[pos]).sum()
+             loss_neg = self.BCE_loss(predict[neg], target[neg])
+             n_neg = min(int(neg.float().sum().item()), int(negative_ratio * n_pos.float()))
+         else:
+             loss_pos = torch.tensor(0.)
+             loss_neg = self.BCE_loss(predict[neg], target[neg])
+             n_neg = 100
+         loss_neg, _ = torch.topk(loss_neg, n_neg)
+
+         return (loss_pos + loss_neg.sum()) / (n_pos + n_neg).float()
+
+     @staticmethod
+     def loss_calc_flux(pred_flux, gt_flux, weight_matrix, mask, train_mask):
+
+         # norm loss
+         gt_flux = 0.999999 * gt_flux / (gt_flux.norm(p=2, dim=1).unsqueeze(1) + 1e-3)
+         norm_loss = weight_matrix * torch.mean((pred_flux - gt_flux) ** 2, dim=1) * train_mask
+         norm_loss = norm_loss.sum(-1).mean()
+         # norm_loss = norm_loss.sum()
+
+         # angle loss
+         mask = train_mask * mask
+         pred_flux = 0.999999 * pred_flux / (pred_flux.norm(p=2, dim=1).unsqueeze(1) + 1e-3)
+         # angle_loss = weight_matrix * (torch.acos(torch.sum(pred_flux * gt_flux, dim=1))) ** 2
+         # angle_loss = angle_loss.sum(-1).mean()
+         angle_loss = (1 - torch.cosine_similarity(pred_flux, gt_flux, dim=1))
+         angle_loss = angle_loss[mask].mean()
+
+         return norm_loss, angle_loss
+
+     @staticmethod
+     def get_poly_energy(energy_field, img_poly, ind, h, w):
+         img_poly = img_poly.clone().float()
+         img_poly[..., 0] = img_poly[..., 0] / (w / 2.) - 1
+         img_poly[..., 1] = img_poly[..., 1] / (h / 2.) - 1
+
+         batch_size = energy_field.size(0)
+         gcn_feature = torch.zeros([img_poly.size(0), energy_field.size(1), img_poly.size(1)]).to(img_poly.device)
+         for i in range(batch_size):
+             poly = img_poly[ind == i].unsqueeze(0)
+             gcn_feature[ind == i] = torch.nn.functional.grid_sample(energy_field[i:i + 1], poly)[0].permute(1, 0, 2)
+         return gcn_feature
+
+     def loss_energy_regularization(self, energy_field, img_poly, inds, h, w):
+         energys = []
+         for i, py in enumerate(img_poly):
+             energy = self.get_poly_energy(energy_field.unsqueeze(1), py, inds, h, w)
+             energys.append(energy.squeeze(1).sum(-1))
+
+         regular_loss = torch.tensor(0.)
+         energy_loss = torch.tensor(0.)
+         for i, e in enumerate(energys[1:]):
+             regular_loss += torch.clamp(e - energys[i], min=0.0).mean()
+             energy_loss += torch.where(e <= 0.01, torch.tensor(0.), e).mean()
+
+         return (energy_loss + regular_loss) / len(energys[1:])
+
+     def forward(self, input_dict, output_dict, eps=None):
+         """
+         calculate the boundary proposal network loss
+         """
+         # tr_mask = tr_mask.permute(0, 3, 1, 2).contiguous()
+
+         fy_preds = output_dict["fy_preds"]
+         py_preds = output_dict["py_preds"]
+         inds = output_dict["inds"]
+
+         train_mask = input_dict['train_mask']
+         tr_mask = input_dict['tr_mask'] > 0
+         distance_field = input_dict['distance_field']
+         direction_field = input_dict['direction_field']
+         weight_matrix = input_dict['weight_matrix']
+         gt_tags = input_dict['gt_points']
+
+         # # scale the prediction map
+         # fy_preds = F.interpolate(fy_preds, scale_factor=cfg.scale, mode='bilinear')
+
+         if cfg.scale > 1:
+             train_mask = F.interpolate(train_mask.float().unsqueeze(1),
+                                        scale_factor=1 / cfg.scale, mode='bilinear').squeeze().bool()
+             tr_mask = F.interpolate(tr_mask.float().unsqueeze(1),
+                                     scale_factor=1 / cfg.scale, mode='bilinear').squeeze().bool()
+
+             distance_field = F.interpolate(distance_field.unsqueeze(1),
+                                            scale_factor=1 / cfg.scale, mode='bilinear').squeeze()
+             direction_field = F.interpolate(direction_field,
+                                             scale_factor=1 / cfg.scale, mode='bilinear')
+             weight_matrix = F.interpolate(weight_matrix.unsqueeze(1),
+                                           scale_factor=1 / cfg.scale, mode='bilinear').squeeze()
+
+         # pixel class loss
+         # cls_loss = self.cls_ohem(fy_preds[:, 0, :, :], tr_mask.float(), train_mask)
+         cls_loss = self.BCE_loss(fy_preds[:, 0, :, :], tr_mask.float())
+         cls_loss = torch.mul(cls_loss, train_mask.float()).mean()
+
+         # distance field loss
+         dis_loss = self.MSE_loss(fy_preds[:, 1, :, :], distance_field)
+         dis_loss = torch.mul(dis_loss, train_mask.float())
+         dis_loss = self.single_image_loss(dis_loss, distance_field)
+
+         # direction field loss
+         norm_loss, angle_loss = self.loss_calc_flux(fy_preds[:, 2:4, :, :], direction_field,
+                                                     weight_matrix, tr_mask, train_mask)
+
+         # boundary point loss
+         point_loss = self.PolyMatchingLoss(py_preds[1:], gt_tags[inds])
+
+         # minimum-energy loss regularization
+         h, w = distance_field.size(1) * cfg.scale, distance_field.size(2) * cfg.scale
+         energy_loss = self.loss_energy_regularization(distance_field, py_preds, inds[0], h, w)
+
+         if eps is None:
+             alpha, beta, theta, gama = 1.0, 3.0, 0.5, 0.05
+         else:
+             alpha, beta, theta = 1.0, 3.0, 0.5
+             gama = 0.1 * torch.sigmoid(torch.tensor((eps - cfg.max_epoch) / cfg.max_epoch))
+         loss = alpha * cls_loss + beta * dis_loss + theta * (norm_loss + angle_loss) + gama * (point_loss + energy_loss)
+
+         loss_dict = {
+             'total_loss': loss,
+             'cls_loss': alpha * cls_loss,
+             'distance loss': beta * dis_loss,
+             'dir_loss': theta * (norm_loss + angle_loss),
+             'norm_loss': theta * norm_loss,
+             'angle_loss': theta * angle_loss,
+             'point_loss': gama * point_loss,
+             'energy_loss': gama * energy_loss,
+         }
+
+         return loss_dict
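
`single_image_loss` is a static method, so its OHEM-style behaviour (average all positives, keep at most 3x as many hard negatives) can be exercised on toy maps without building the full loss:

    import torch
    from IndicPhotoOCR.detection.textbpn.network.loss import TextLoss

    per_pixel = torch.rand(2, 1, 16, 16)                   # per-pixel distance losses
    labels = (torch.rand(2, 1, 16, 16) > 0.7).float()      # sparse positive mask
    print(TextLoss.single_image_loss(per_pixel, labels))   # scalar tensor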
IndicPhotoOCR/detection/textbpn/network/loss_org.py ADDED
@@ -0,0 +1,136 @@
+ # -*- coding: utf-8 -*-
+ # @Time    : 10/1/21
+ # @Author  : GXYM
+ import torch
+ import torch.nn as nn
+ from IndicPhotoOCR.detection.textbpn.cfglib.config import config as cfg
+ from IndicPhotoOCR.detection.textbpn.network.Seg_loss import SegmentLoss
+ from IndicPhotoOCR.detection.textbpn.network.Reg_loss import PolyMatchingLoss
+
+
+ class TextLoss(nn.Module):
+
+     def __init__(self):
+         super().__init__()
+         self.MSE_loss = torch.nn.MSELoss(reduction='none')
+         self.BCE_loss = torch.nn.BCELoss(reduction='none')
+         self.PolyMatchingLoss = PolyMatchingLoss(cfg.num_points, cfg.device)
+         self.KL_loss = torch.nn.KLDivLoss(reduction='none')
+
+     @staticmethod
+     def single_image_loss(pre_loss, loss_label):
+         batch_size = pre_loss.shape[0]
+         sum_loss = torch.mean(pre_loss.view(-1)) * 0
+         pre_loss = pre_loss.view(batch_size, -1)
+         loss_label = loss_label.view(batch_size, -1)
+         eps = 0.001
+         for i in range(batch_size):
+             average_number = 0
+             positive_pixel = len(pre_loss[i][(loss_label[i] >= eps)])
+             average_number += positive_pixel
+             if positive_pixel != 0:
+                 posi_loss = torch.mean(pre_loss[i][(loss_label[i] >= eps)])
+                 sum_loss += posi_loss
+                 if len(pre_loss[i][(loss_label[i] < eps)]) < 3 * positive_pixel:
+                     nega_loss = torch.mean(pre_loss[i][(loss_label[i] < eps)])
+                     average_number += len(pre_loss[i][(loss_label[i] < eps)])
+                 else:
+                     nega_loss = torch.mean(torch.topk(pre_loss[i][(loss_label[i] < eps)], 3 * positive_pixel)[0])
+                     average_number += 3 * positive_pixel
+                 sum_loss += nega_loss
+             else:
+                 nega_loss = torch.mean(torch.topk(pre_loss[i], 100)[0])
+                 average_number += 100
+                 sum_loss += nega_loss
+             # sum_loss += loss / average_number
+
+         return sum_loss / batch_size
+
+     def cls_ohem(self, predict, target, train_mask, negative_ratio=3.):
+         pos = (target * train_mask).bool()
+         neg = ((1 - target) * train_mask).bool()
+
+         n_pos = pos.float().sum()
+
+         if n_pos.item() > 0:
+             loss_pos = self.BCE_loss(predict[pos], target[pos]).sum()
+             loss_neg = self.BCE_loss(predict[neg], target[neg])
+             n_neg = min(int(neg.float().sum().item()), int(negative_ratio * n_pos.float()))
+         else:
+             loss_pos = torch.tensor(0.)
+             loss_neg = self.BCE_loss(predict[neg], target[neg])
+             n_neg = 100
+         loss_neg, _ = torch.topk(loss_neg, n_neg)
+
+         return (loss_pos + loss_neg.sum()) / (n_pos + n_neg).float()
+
+     @staticmethod
+     def loss_calc_flux(pred_flux, gt_flux, weight_matrix, mask, train_mask):
+
+         # norm loss
+         gt_flux = 0.999999 * gt_flux / (gt_flux.norm(p=2, dim=1).unsqueeze(1) + 1e-9)
+         norm_loss = weight_matrix * torch.sum((pred_flux - gt_flux) ** 2, dim=1) * train_mask
+         norm_loss = norm_loss.sum(-1).mean()
+
+         # angle loss
+         mask = train_mask * mask
+         pred_flux = 0.999999 * pred_flux / (pred_flux.norm(p=2, dim=1).unsqueeze(1) + 1e-9)
+         # angle_loss = weight_matrix * (torch.acos(torch.sum(pred_flux * gt_flux, dim=1))) ** 2
+         # angle_loss = angle_loss.sum(-1).mean()
+         angle_loss = (1 - torch.cosine_similarity(pred_flux, gt_flux, dim=1))
+         angle_loss = angle_loss[mask].mean()
+
+         return norm_loss, angle_loss
+
+     def forward(self, input_dict, output_dict, eps=None):
+         """
+         calculate the boundary proposal network loss
+         """
+         # tr_mask = tr_mask.permute(0, 3, 1, 2).contiguous()
+
+         fy_preds = output_dict["fy_preds"]
+         py_preds = output_dict["py_preds"]
+         inds = output_dict["inds"]
+
+         train_mask = input_dict['train_mask']
+         tr_mask = input_dict['tr_mask'] > 0
+         distance_field = input_dict['distance_field']
+         direction_field = input_dict['direction_field']
+         weight_matrix = input_dict['weight_matrix']
+         gt_tags = input_dict['gt_points']
+
+         # pixel class loss
+         cls_loss = self.cls_ohem(fy_preds[:, 0, :, :], tr_mask.float(), train_mask.bool())
+
+         # distance field loss
+         dis_loss = self.MSE_loss(fy_preds[:, 1, :, :], distance_field)
+         dis_loss = torch.mul(dis_loss, train_mask.float())
+         dis_loss = self.single_image_loss(dis_loss, distance_field)
+
+         # direction field loss
+         norm_loss, angle_loss = self.loss_calc_flux(fy_preds[:, 2:4, :, :],
+                                                     direction_field, weight_matrix, tr_mask, train_mask)
+
+         # boundary point loss
+         point_loss = self.PolyMatchingLoss(py_preds, gt_tags[inds])
+
+         if eps is None:
+             loss_b = 0.05 * point_loss
+             loss = cls_loss + 3.0 * dis_loss + norm_loss + angle_loss + loss_b
+         else:
+             loss_b = 0.1 * (torch.sigmoid(torch.tensor((eps - cfg.max_epoch) / cfg.max_epoch))) * point_loss
+             loss = cls_loss + 3.0 * dis_loss + norm_loss + angle_loss + loss_b
+
+         loss_dict = {
+             'total_loss': loss,
+             'cls_loss': cls_loss,
+             'distance loss': 3.0 * dis_loss,
+             'dir_loss': norm_loss + angle_loss,
+             'point_loss': loss_b,
+             'norm_loss': norm_loss,
+             'angle_loss': angle_loss,
+         }
+
+         return loss_dict
IndicPhotoOCR/detection/textbpn/network/textnet.py ADDED
@@ -0,0 +1,216 @@
+ # -*- coding: utf-8 -*-
+ # @Time    : 10/1/21
+ # @Author  : GXYM
+ import torch
+ import torch.nn as nn
+ from IndicPhotoOCR.detection.textbpn.network.layers.model_block import FPN
+ from IndicPhotoOCR.detection.textbpn.cfglib.config import config as cfg
+ import numpy as np
+ from IndicPhotoOCR.detection.textbpn.network.layers.CircConv import DeepSnake
+ from IndicPhotoOCR.detection.textbpn.network.layers.GCN import GCN
+ from IndicPhotoOCR.detection.textbpn.network.layers.RNN import RNN
+ from IndicPhotoOCR.detection.textbpn.network.layers.Adaptive_Deformation import AdaptiveDeformation
+ # from IndicPhotoOCR.detection.textbpn.network.layers.Transformer_old import Transformer_old
+ from IndicPhotoOCR.detection.textbpn.network.layers.Transformer import Transformer
+ import cv2
+ from IndicPhotoOCR.detection.textbpn.util.misc import get_sample_point, fill_hole
+ from IndicPhotoOCR.detection.textbpn.network.layers.gcn_utils import get_node_feature, \
+     get_adj_mat, get_adj_ind, coord_embedding, normalize_adj
+ import torch.nn.functional as F
+ import time
+
+
+ class Evolution(nn.Module):
+     def __init__(self, node_num, adj_num, is_training=True, device=None, model="snake"):
+         super(Evolution, self).__init__()
+         self.node_num = node_num
+         self.adj_num = adj_num
+         self.device = device
+         self.is_training = is_training
+         self.clip_dis = 16
+
+         self.iter = 3
+         if model == "gcn":
+             self.adj = get_adj_mat(self.adj_num, self.node_num)
+             self.adj = normalize_adj(self.adj, type="DAD").float().to(self.device)
+             for i in range(self.iter):
+                 evolve_gcn = GCN(36, 128)
+                 self.__setattr__('evolve_gcn' + str(i), evolve_gcn)
+         elif model == "rnn":
+             self.adj = None
+             for i in range(self.iter):
+                 evolve_gcn = RNN(36, 128)
+                 self.__setattr__('evolve_gcn' + str(i), evolve_gcn)
+         elif model == "AD":
+             self.adj = get_adj_mat(self.adj_num, self.node_num)
+             self.adj = normalize_adj(self.adj, type="DAD").float().to(self.device)
+             for i in range(self.iter):
+                 evolve_gcn = AdaptiveDeformation(36, 128)
+                 self.__setattr__('evolve_gcn' + str(i), evolve_gcn)
+         # elif model == "BT_old":
+         #     self.adj = None
+         #     for i in range(self.iter):
+         #         evolve_gcn = Transformer_old(36, 512, num_heads=8,
+         #                                      dim_feedforward=2048, drop_rate=0.0, if_resi=True, block_nums=4)
+         #         self.__setattr__('evolve_gcn' + str(i), evolve_gcn)
+         elif model == "BT":
+             self.adj = None
+             for i in range(self.iter):
+                 evolve_gcn = Transformer(36, 128, num_heads=8,
+                                          dim_feedforward=1024, drop_rate=0.0, if_resi=True, block_nums=3)
+                 self.__setattr__('evolve_gcn' + str(i), evolve_gcn)
+         else:
+             self.adj = get_adj_ind(self.adj_num, self.node_num, self.device)
+             for i in range(self.iter):
+                 evolve_gcn = DeepSnake(state_dim=128, feature_dim=36, conv_type='dgrid')
+                 self.__setattr__('evolve_gcn' + str(i), evolve_gcn)
+
+         for m in self.modules():
+             if isinstance(m, nn.Conv1d) or isinstance(m, nn.Conv2d):
+                 m.weight.data.normal_(0.0, 0.02)
+                 # nn.init.kaiming_normal_(m.weight, mode='fan_in')
+                 if m.bias is not None:
+                     nn.init.constant_(m.bias, 0)
+
+     @staticmethod
+     def get_boundary_proposal(input=None, seg_preds=None, switch="gt"):
+
+         if switch == "gt":
+             inds = torch.where(input['ignore_tags'] > 0)
+             # if len(inds[0]) > 320:
+             #     inds = (inds[0][:320], inds[1][:320])
+             init_polys = input['proposal_points'][inds]
+         else:
+             tr_masks = input['tr_mask'].cpu().numpy()
+             tcl_masks = seg_preds[:, 0, :, :].detach().cpu().numpy() > cfg.threshold
+             inds = []
+             init_polys = []
+             for bid, tcl_mask in enumerate(tcl_masks):
+                 ret, labels = cv2.connectedComponents(tcl_mask.astype(np.uint8), connectivity=8)
+                 for idx in range(1, ret):
+                     text_mask = labels == idx
+                     ist_id = int(np.sum(text_mask * tr_masks[bid]) / np.sum(text_mask)) - 1
+                     inds.append([bid, ist_id])
+                     poly = get_sample_point(text_mask, cfg.num_points, cfg.approx_factor)
+                     init_polys.append(poly)
+             inds = torch.from_numpy(np.array(inds)).permute(1, 0).to(input["img"].device)
+             init_polys = torch.from_numpy(np.array(init_polys)).to(input["img"].device)
+
+         return init_polys, inds, None
+
+     def get_boundary_proposal_eval(self, input=None, seg_preds=None):
+
+         # if cfg.scale > 1:
+         #     seg_preds = F.interpolate(seg_preds, scale_factor=cfg.scale, mode='bilinear')
+         cls_preds = seg_preds[:, 0, :, :].detach().cpu().numpy()
+         dis_preds = seg_preds[:, 1, :, :].detach().cpu().numpy()
+
+         inds = []
+         init_polys = []
+         confidences = []
+         for bid, dis_pred in enumerate(dis_preds):
+             # dis_mask = (dis_pred / np.max(dis_pred)) > cfg.dis_threshold
+             dis_mask = dis_pred > cfg.dis_threshold
+             # dis_mask = fill_hole(dis_mask)
+             ret, labels = cv2.connectedComponents(dis_mask.astype(np.uint8), connectivity=8, ltype=cv2.CV_16U)
+             for idx in range(1, ret):
+                 text_mask = labels == idx
+                 confidence = round(cls_preds[bid][text_mask].mean(), 3)
+                 # the area threshold is 50 for MLT2017 and ArT (or when DCN is used in the
+                 # backbone) and 150 otherwise; using 50 everywhere has little effect
+                 if np.sum(text_mask) < 50 / (cfg.scale * cfg.scale) or confidence < cfg.cls_threshold:
+                     continue
+                 confidences.append(confidence)
+                 inds.append([bid, 0])
+
+                 poly = get_sample_point(text_mask, cfg.num_points,
+                                         cfg.approx_factor, scales=np.array([cfg.scale, cfg.scale]))
+                 init_polys.append(poly)
+
+         if len(inds) > 0:
+             inds = torch.from_numpy(np.array(inds)).permute(1, 0).to(input["img"].device, non_blocking=True)
+             init_polys = torch.from_numpy(np.array(init_polys)).to(input["img"].device, non_blocking=True).float()
+         else:
+             init_polys = torch.from_numpy(np.array(init_polys)).to(input["img"].device, non_blocking=True).float()
+             inds = torch.from_numpy(np.array(inds)).to(input["img"].device, non_blocking=True)
+
+         return init_polys, inds, confidences
+
+     def evolve_poly(self, snake, cnn_feature, i_it_poly, ind):
+         if len(i_it_poly) == 0:
+             return torch.zeros_like(i_it_poly)
+         h, w = cnn_feature.size(2) * cfg.scale, cnn_feature.size(3) * cfg.scale
+         node_feats = get_node_feature(cnn_feature, i_it_poly, ind, h, w)
+         i_poly = i_it_poly + torch.clamp(snake(node_feats, self.adj).permute(0, 2, 1), -self.clip_dis, self.clip_dis)
+         if self.is_training:
+             i_poly = torch.clamp(i_poly, 0, w - 1)
+         else:
+             i_poly[:, :, 0] = torch.clamp(i_poly[:, :, 0], 0, w - 1)
+             i_poly[:, :, 1] = torch.clamp(i_poly[:, :, 1], 0, h - 1)
+         return i_poly
+
+     def forward(self, embed_feature, input=None, seg_preds=None, switch="gt"):
+         if self.is_training:
+             init_polys, inds, confidences = self.get_boundary_proposal(input=input, seg_preds=seg_preds, switch=switch)
+             # TODO sample fix number
+         else:
+             init_polys, inds, confidences = self.get_boundary_proposal_eval(input=input, seg_preds=seg_preds)
+             if init_polys.shape[0] == 0:
+                 return [init_polys for i in range(self.iter + 1)], inds, confidences
+
+         py_preds = [init_polys, ]
+         for i in range(self.iter):
+             evolve_gcn = self.__getattr__('evolve_gcn' + str(i))
+             init_polys = self.evolve_poly(evolve_gcn, embed_feature, init_polys, inds[0])
+             py_preds.append(init_polys)
+
+         return py_preds, inds, confidences
+
+
+ class TextNet(nn.Module):
+
+     def __init__(self, backbone='vgg', is_training=True):
+         super().__init__()
+         self.is_training = is_training
+         self.backbone_name = backbone
+         self.fpn = FPN(self.backbone_name, is_training=(not cfg.resume and is_training))
+
+         self.seg_head = nn.Sequential(
+             nn.Conv2d(32, 16, kernel_size=3, padding=2, dilation=2),
+             nn.PReLU(),
+             nn.Conv2d(16, 16, kernel_size=3, padding=4, dilation=4),
+             nn.PReLU(),
+             nn.Conv2d(16, 4, kernel_size=1, stride=1, padding=0),
+         )
+         self.BPN = Evolution(cfg.num_points, adj_num=4,
+                              is_training=is_training, device=cfg.device, model="BT")
+
+     def load_model(self, model_path):
+         print('Loading from {}'.format(model_path))
+         state_dict = torch.load(model_path, map_location=torch.device(cfg.device))
+         self.load_state_dict(state_dict['model'], strict=(not self.is_training))
+
+     def forward(self, input_dict, test_speed=False):
+         output = {}
+         b, c, h, w = input_dict["img"].shape
+         if self.is_training or cfg.exp_name in ['ArT', 'MLT2017', "MLT2019"] or test_speed:
+             image = input_dict["img"]
+         else:
+             image = torch.zeros((b, c, cfg.test_size[1], cfg.test_size[1]), dtype=torch.float32).to(cfg.device)
+             image[:, :, :h, :w] = input_dict["img"][:, :, :, :]
+
+         up1, _, _, _, _ = self.fpn(image)
+         up1 = up1[:, :, :h // cfg.scale, :w // cfg.scale]
+
+         preds = self.seg_head(up1)
+         fy_preds = torch.cat([torch.sigmoid(preds[:, 0:2, :, :]), preds[:, 2:4, :, :]], dim=1)
+         cnn_feats = torch.cat([up1, fy_preds], dim=1)
+
+         py_preds, inds, confidences = self.BPN(cnn_feats, input=input_dict, seg_preds=fy_preds, switch="gt")
+
+         output["fy_preds"] = fy_preds
+         output["py_preds"] = py_preds
+         output["inds"] = inds
+         output["confidences"] = confidences
+
+         return output
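
A rough inference sketch for `TextNet`; the detector class in `textbpnpp_detector.py` below does this properly, and the checkpoint path and config defaults (`cfg.scale`, `cfg.num_points`, the thresholds) are assumptions here:

    import torch
    from IndicPhotoOCR.detection.textbpn.network.textnet import TextNet

    net = TextNet(backbone="resnet50", is_training=False).eval()
    # net.load_model("IndicPhotoOCR/detection/textbpn/models/TextBPN_resnet50_300.pth")
    with torch.no_grad():
        out = net({"img": torch.randn(1, 3, 640, 640)}, test_speed=True)  # truthy flag: use image as-is
    # out["py_preds"] holds the polygon snapshots per iteration; out["confidences"] the region scores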
IndicPhotoOCR/detection/textbpn/output.png ADDED

Git LFS Details

  • SHA256: 44b8104e8e5d470e051d4b568214e3591e098f2677d1c0e1a0d6594e2d049636
  • Pointer size: 132 Bytes
  • Size of remote file: 8.7 MB
IndicPhotoOCR/detection/textbpn/textbpnpp_detector.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ import torch
+ import cv2
+ import numpy as np
+ from IndicPhotoOCR.detection.textbpn.network.textnet import TextNet
+ from IndicPhotoOCR.detection.textbpn.cfglib.config import config as cfg
+ import warnings
+ import os
+ import requests
+ from tqdm import tqdm
+
+ # Suppress warnings
+ warnings.filterwarnings("ignore")
+
+ model_info = {
+     "textbpnpp": {
+         "path": "models/TextBPN_resnet50_300.pth",
+         "url": "https://github.com/Bhashini-IITJ/SceneTextDetection/releases/download/TextBPN%2B%2B/TextBPN_resnet50_300.pth",
+     },
+     "textbpnpp_deformable": {
+         "path": "models/TextBPN_deformable_resnet50_300.pth",
+         "url": "https://github.com/Bhashini-IITJ/SceneTextDetection/releases/download/TextBPN%2B%2B/TextBPN_deformable_resnet50_300.pth",
+     },
+     "textbpn_resnet18": {
+         "path": "models/TextBPN_resnet18_300.pth",
+         "url": "https://github.com/Bhashini-IITJ/SceneTextDetection/releases/download/TextBPN%2B%2B/TextBPN_resnet18_300.pth",
+     },
+ }
+
+ # Ensure the model file exists locally; download it if not
+ def ensure_model(model_name):
+     model_path = model_info[model_name]["path"]
+     url = model_info[model_name]["url"]
+     root_model_dir = "IndicPhotoOCR/detection/textbpn"
+     model_path = os.path.join(root_model_dir, model_path)
+
+     if not os.path.exists(model_path):
+         print(f"Model not found locally. Downloading {model_name} from {url}...")
+
+         # Start the download with a progress bar
+         response = requests.get(url, stream=True)
+         total_size = int(response.headers.get('content-length', 0))
+         os.makedirs(f"{root_model_dir}/models", exist_ok=True)
+
+         with open(model_path, "wb") as f, tqdm(
+             desc=model_name,
+             total=total_size,
+             unit='B',
+             unit_scale=True,
+             unit_divisor=1024,
+         ) as bar:
+             for data in response.iter_content(chunk_size=1024):
+                 f.write(data)
+                 bar.update(len(data))
+
+         print(f"Downloaded model for {model_name}.")
+
+     return model_path
+
+ class TextBPNpp_detector:
+     def __init__(self, model_name="textbpnpp", backbone="resnet50", device="cpu"):
+         """
+         Initialize the TextBPN++ detector.
+         :param model_name: Name of the pre-trained model to load (a key of model_info).
+         :param backbone: Backbone architecture (default: "resnet50").
+         :param device: Device to run the model on (default: "cpu").
+         """
+         self.model_path = ensure_model(model_name)
+         self.device = torch.device(device)
+         self.model = TextNet(is_training=False, backbone=backbone)
+         self.model.load_model(self.model_path)
+         self.model.eval()
+         self.model.to(self.device)
+
+     @staticmethod
+     def to_device(tensor, device):
+         """
+         Move tensor to the specified device.
+         :param tensor: Tensor to move.
+         :param device: Target device.
+         :return: Tensor on the target device.
+         """
+         return tensor.to(device, non_blocking=True)
+
+     @staticmethod
+     def pad_image(image, stride=32):
+         """
+         Pad the image so that its dimensions are divisible by the stride.
+         :param image: Input image.
+         :param stride: Stride size.
+         :return: Padded image and original dimensions.
+         """
+         h, w = image.shape[:2]
+         new_h = (h + stride - 1) // stride * stride
+         new_w = (w + stride - 1) // stride * stride
+         padded_image = cv2.copyMakeBorder(
+             image, 0, new_h - h, 0, new_w - w, cv2.BORDER_CONSTANT, value=(0, 0, 0)
+         )
+         return padded_image, (h, w)
+
+     @staticmethod
+     def rescale_result(image, bbox_contours, original_height, original_width):
+         """
+         Rescale the bounding box contours to the original image size.
+         :param image: Image after resizing.
+         :param bbox_contours: Bounding box contours.
+         :param original_height: Original image height.
+         :param original_width: Original image width.
+         :return: Rescaled contours.
+         """
+         contours = []
+         for cont in bbox_contours:
+             cont[:, 0] = (cont[:, 0] * original_width / image.shape[1]).astype(int)
+             cont[:, 1] = (cont[:, 1] * original_height / image.shape[0]).astype(int)
+             contours.append(cont)
+         return contours
+
+     def detect(self, image_path):
+         """
+         Perform text detection on the given image.
+         :param image_path: Path to the input image.
+         :return: Dictionary with detection results.
+         """
+         image = cv2.imread(image_path)
+         if image is None:
+             raise ValueError(f"Failed to read the image at {image_path}")
+
+         padded_image, original_size = self.pad_image(image)
+         padded_tensor = (
+             torch.from_numpy(padded_image).permute(2, 0, 1).float() / 255.0
+         ).unsqueeze(0)  # Convert to tensor and add batch dimension
+
+         cfg.test_size = [padded_image.shape[0], padded_image.shape[1]]
+
+         input_dict = {"img": self.to_device(padded_tensor, self.device)}
+         with torch.no_grad():
+             # The second (truthy) argument is TextNet's test_speed flag, so the
+             # network runs on the padded image directly instead of zero-padding
+             # it again to cfg.test_size.
+             output_dict = self.model(input_dict, padded_image.shape)
+
+         contours = output_dict["py_preds"][-1].int().cpu().numpy()
+         contours = self.rescale_result(image, contours, *original_size)
+
+         bbox_result_dict = {"detections": []}
+         for contour in contours:
+             # x_min, y_min = np.min(contour, axis=0)
+             # x_max, y_max = np.max(contour, axis=0)
+             # bbox_result_dict["detections"].append([x_min, y_min, x_max, y_max])
+             bbox_result_dict["detections"].append(contour.tolist())
+
+         return bbox_result_dict
+
+     def visualize_detections(self, image_path, bbox_result_dict, output_path="output.png"):
+         """
+         Visualize detections on the image.
+         :param image_path: Path to the input image.
+         :param bbox_result_dict: Detection results in the format:
+                                  {'detections': [[[x1, y1], [x2, y2], [x3, y3], [x4, y4]], ...]}.
+         :param output_path: Path to save the visualized image. If None, the image is only displayed.
+         """
+         # Load the image
+         image = cv2.imread(image_path)
+         if image is None:
+             raise ValueError(f"Failed to read the image at {image_path}")
+
+         # Draw each detection
+         for bbox in bbox_result_dict.get("detections", []):
+             points = np.array(bbox, dtype=np.int32)  # Convert to numpy array
+             cv2.polylines(image, [points], isClosed=True, color=(0, 255, 0), thickness=2)
+
+         # Display or save the visualized image
+         if output_path:
+             cv2.imwrite(output_path, image)
+             print(f"Visualization saved to {output_path}")
+         else:
+             cv2.imshow("Detections", image)
+             cv2.waitKey(0)
+             cv2.destroyAllWindows()
+
+ if __name__ == "__main__":
+     import argparse
+     parser = argparse.ArgumentParser(description='Text detection using the TextBPN++ model')
+     parser.add_argument('--image_path', type=str, required=True, help='Path to the input image')
+     parser.add_argument('--device', type=str, default='cpu', help='Device to run the model on, e.g., "cpu" or "cuda"')
+     parser.add_argument('--model_name', type=str, required=True, help='Model name, e.g., "textbpnpp" or "textbpnpp_deformable"')
+     args = parser.parse_args()
+
+     # model_path = "/DATA1/ocrteam/anik/git/IndicPhotoOCR/IndicPhotoOCR/detection/textbpn/models/TextBPN_resnet50_300.pth"
+     # image_path = "/DATA1/ocrteam/anik/splitonBSTD/detection/D/image_542.jpg"
+
+     detector = TextBPNpp_detector(args.model_name, device=args.device)
+     result = detector.detect(args.image_path)
+     print(result)
+     # detector.visualize_detections(image_path, result)
+
+ # python -m IndicPhotoOCR.detection.textbpn.textbpnpp_detector \
+ #     --image_path /DATA1/ocrteam/anik/splitonBSTD/detection/D/image_542.jpg \
+ #     --model_name textbpnpp
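A Python-level usage sketch of the wrapper above, mirroring its __main__ block (the image path is a placeholder; weights are fetched on first use by ensure_model):

    from IndicPhotoOCR.detection.textbpn.textbpnpp_detector import TextBPNpp_detector

    detector = TextBPNpp_detector("textbpnpp", device="cpu")
    result = detector.detect("sample.jpg")  # {'detections': [[[x, y], ...], ...]}
    detector.visualize_detections("sample.jpg", result, output_path="output.png")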
IndicPhotoOCR/detection/textbpn/util/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .visualize import *
+ from .pbox import *
IndicPhotoOCR/detection/textbpn/util/augmentation.py ADDED
@@ -0,0 +1,794 @@
+ # -*- coding: utf-8 -*-
+ __author__ = "S.X.Zhang"
+ import numpy as np
+ import math
+ import cv2
+ import copy
+ import numpy.random as random
+ from shapely.geometry import Polygon
+ import torchvision.transforms as transforms
+ import torchvision.transforms.functional as F
+ from PIL import ImageEnhance, Image
+
+
+ ###<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<###
+ ###<<<<<<<<< Function >>>>>>>>>>>>###
+ ###>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>###
+ def crop_first(image, polygons, scale=10):
+     polygons_new = copy.deepcopy(polygons)
+     h, w, _ = image.shape
+     pad_h = h // scale
+     pad_w = w // scale
+     h_array = np.zeros((h + pad_h * 2), dtype=np.int32)
+     w_array = np.zeros((w + pad_w * 2), dtype=np.int32)
+
+     text_polys = []
+     pos_polys = []
+     for polygon in polygons_new:
+         rect = cv2.minAreaRect(polygon.points.astype(np.int32))
+         box = cv2.boxPoints(rect)
+         box = np.int0(box)
+         text_polys.append([box[0], box[1], box[2], box[3]])
+         if polygon.label != -1:
+             pos_polys.append([box[0], box[1], box[2], box[3]])
+
+     polys = np.array(text_polys, dtype=np.int32)
+     for poly in polys:
+         poly = np.round(poly, decimals=0).astype(np.int32)  # round to nearest integer
+         minx = np.min(poly[:, 0])
+         maxx = np.max(poly[:, 0])
+         w_array[minx + pad_w:maxx + pad_w] = 1
+         miny = np.min(poly[:, 1])
+         maxy = np.max(poly[:, 1])
+         h_array[miny + pad_h:maxy + pad_h] = 1
+     # ensure the cropped area does not cut across text
+     h_axis = np.where(h_array == 0)[0]
+     w_axis = np.where(w_array == 0)[0]
+     pp_polys = np.array(pos_polys, dtype=np.int32)
+
+     return h_axis, w_axis, pp_polys
+
+ ####<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<####
+ ####<<<<<<<<<<< Class >>>>>>>>>>>>>####
+ ####>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>####
+ class Compose(object):
+     """Composes several augmentations together.
+     Args:
+         transforms (List[Transform]): list of transforms to compose.
+     Example:
+         >>> augmentations.Compose([
+         >>>     transforms.CenterCrop(10),
+         >>>     transforms.ToTensor(),
+         >>> ])
+     """
+
+     def __init__(self, transforms):
+         self.transforms = transforms
+
+     def __call__(self, img, pts=None):
+         for t in self.transforms:
+             img, pts = t(img, pts)
+         return img, pts
+
+
+ class Normalize(object):
+     def __init__(self, mean, std):
+         self.mean = np.array(mean)
+         self.std = np.array(std)
+
+     def __call__(self, image, polygons=None):
+         image = image.astype(np.float32)
+         image /= 255.0
+         image -= self.mean
+         image /= self.std
+         return image, polygons
+
+
+ class MinusMean(object):
+     def __init__(self, mean):
+         self.mean = np.array(mean)
+
+     def __call__(self, image, polygons=None):
+         image = image.astype(np.float32)
+         image -= self.mean
+         return image, polygons
+
+
+ class RandomMirror(object):
+     # horizontal mirror
+     def __init__(self):
+         pass
+
+     def __call__(self, image, polygons=None):
+         if polygons is None:
+             return image, polygons
+         if random.random() < 0.3:
+             image = np.ascontiguousarray(image[:, ::-1])
+             _, width, _ = image.shape
+             for polygon in polygons:
+                 polygon.points[:, 0] = width - polygon.points[:, 0]
+         return image, polygons
+
+
+ class AugmentColor(object):
+     # color augmentation (PCA-based noise)
+     def __init__(self):
+         self.U = np.array([[-0.56543481, 0.71983482, 0.40240142],
+                            [-0.5989477, -0.02304967, -0.80036049],
+                            [-0.56694071, -0.6935729, 0.44423429]], dtype=np.float32)
+         self.EV = np.array([1.65513492, 0.48450358, 0.1565086], dtype=np.float32)
+         self.sigma = 0.1
+         self.color_vec = None
+
+     def __call__(self, img, polygons=None):
+         color_vec = self.color_vec
+         if self.color_vec is None:
+             if not self.sigma > 0.0:
+                 color_vec = np.zeros(3, dtype=np.float32)
+             else:
+                 color_vec = np.random.normal(0.0, self.sigma, 3)
+
+         alpha = color_vec.astype(np.float32) * self.EV
+         noise = np.dot(self.U, alpha.T) * 255
+         return np.clip(img + noise[np.newaxis, np.newaxis, :], 0, 255), polygons
+
+
+ class RandomContrast(object):
+     def __init__(self, lower=0.5, upper=1.5):
+         self.lower = lower
+         self.upper = upper
+         assert self.upper >= self.lower, "contrast upper must be >= lower."
+         assert self.lower >= 0, "contrast lower must be non-negative."
+
+     # expects float image
+     def __call__(self, image, polygons=None):
+         if random.randint(2):
+             alpha = random.uniform(self.lower, self.upper)
+             image *= alpha
+         return np.clip(image, 0, 255), polygons
+
+
+ class RandomBrightness(object):
+     def __init__(self, delta=32):
+         assert delta >= 0.0
+         assert delta <= 255.0
+         self.delta = delta
+
+     def __call__(self, image, polygons=None):
+         image = image.astype(np.float32)
+         if random.randint(2):
+             delta = random.uniform(-self.delta, self.delta)
+             image += delta
+         return np.clip(image, 0, 255), polygons
+
+
+ class RandomErasing(object):
+     def __init__(self, sr=(0.0004, 0.01), scale=(0.5, 3), ratio=0.2, Type="Erasing"):
+         """
+         :param sr: range of the erased area as a fraction of the image area
+         :param scale: range of the aspect ratio of the erased patch
+         :param ratio: probability of skipping the augmentation
+         :param Type: "Erasing" (random color) or "Cutout" (random gray value)
+         """
+         self.sr = sr
+         self.scale = scale
+         self.ratio = ratio
+         self.type = Type
+
+     def __call__(self, img, polygons=None):
+
+         if random.random() < self.ratio:
+             return img, polygons
+         area = img.shape[0] * img.shape[1]
+         target_area = random.uniform(*self.sr) * area  # sr holds float ratios, so sample uniformly
+         aspect_ratio = random.uniform(*self.scale)
+         h = int(round(math.sqrt(target_area / aspect_ratio)))
+         w = int(round(math.sqrt(target_area * aspect_ratio)))
+
+         if w < img.shape[1] and h < img.shape[0]:
+             x1 = random.randint(0, img.shape[1] - w)
+             y1 = random.randint(0, img.shape[0] - h)
+             if self.type == "Erasing":
+                 color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
+                 img[y1:y1 + h, x1:x1 + w, :] = color
+             else:
+                 Gray_value = random.randint(0, 255)
+                 color = (Gray_value, Gray_value, Gray_value)
+                 img[y1:y1 + h, x1:x1 + w, :] = color
+
+         return img, polygons
+
+
+ class RandomMixUp(object):
+     def __init__(self, mixup_alpha=2):
+         self.mixup_alpha = mixup_alpha
+
+     def __call__(self, img1, img2, label1=[], label2=[]):
+         beta = np.random.beta(self.mixup_alpha, self.mixup_alpha)
+
+         # image = img1 * beta + (1 - beta) * img2
+         image = cv2.addWeighted(img1, beta, img2, (1 - beta), 0)
+
+         if label1 is None or label2 is None:
+             return img1, label1
+         if isinstance(label1, list) and isinstance(label2, list):
+             label = []
+             for id in range(len(label1)):
+                 lab = beta * label1[id] + (1 - beta) * label2[id]
+                 label.append(lab)
+             return image, label
+         else:
+             print("Error: label is not a list type")
+
+         return img1, label1
+
+
+ class Rotate(object):
+     def __init__(self, up=30):
+         self.up = up
+
+     @staticmethod
+     def rotate(center, pt, theta):  # 2D rotation of points about a center
+         xr, yr = center
+         yr = -yr
+         x, y = pt[:, 0], pt[:, 1]
+         y = -y
+
+         theta = theta / 180 * math.pi
+         cos = math.cos(theta)
+         sin = math.sin(theta)
+
+         _x = xr + (x - xr) * cos - (y - yr) * sin
+         _y = yr + (x - xr) * sin + (y - yr) * cos
+
+         return _x, -_y
+
+     def __call__(self, img, polygons=None):
+         if np.random.randint(2):
+             return img, polygons
+         angle = np.random.normal(loc=0.0, scale=0.5) * self.up  # angle drawn from a Gaussian distribution
+         rows, cols = img.shape[0:2]
+         M = cv2.getRotationMatrix2D((cols / 2, rows / 2), angle, 1.0)
+         img = cv2.warpAffine(img, M, (cols, rows), borderValue=[0, 0, 0])
+         center = cols / 2.0, rows / 2.0
+         if polygons is not None:
+             for polygon in polygons:
+                 x, y = self.rotate(center, polygon.points, angle)
+                 pts = np.vstack([x, y]).T
+                 polygon.points = pts
+         return img, polygons
+
+
+ class RotatePadding(object):
+     def __init__(self, up=60, colors=True):
+         self.up = up
+         self.colors = colors
+         self.ratio = 0.5
+
+     @staticmethod
+     def rotate(center, pt, theta, movSize=[0, 0], scale=1):  # 2D rotation of points about a center
+         (xr, yr) = center
+         yr = -yr
+         x, y = pt[:, 0], pt[:, 1]
+         y = -y
+
+         theta = theta / 180 * math.pi
+         cos = math.cos(theta)
+         sin = math.sin(theta)
+
+         x = (x - xr) * scale
+         y = (y - yr) * scale
+
+         _x = xr + x * cos - y * sin + movSize[0]
+         _y = -(yr + x * sin + y * cos) + movSize[1]
+
+         return _x, _y
+
+     @staticmethod
+     def shift(size, degree):
+         angle = degree * math.pi / 180.0
+         width = size[0]
+         height = size[1]
+
+         alpha = math.cos(angle)
+         beta = math.sin(angle)
+         new_width = int(width * math.fabs(alpha) + height * math.fabs(beta))
+         new_height = int(width * math.fabs(beta) + height * math.fabs(alpha))
+
+         size = [new_width, new_height]
+         return size
+
+     def __call__(self, image, polygons=None, scale=1.0):
+         if np.random.random() <= self.ratio:
+             return image, polygons
+         angle = np.random.normal(loc=0.0, scale=0.5) * self.up  # angle drawn from a Gaussian distribution
+         rows, cols = image.shape[0:2]
+         center = (cols / 2.0, rows / 2.0)
+         newSize = self.shift([cols * scale, rows * scale], angle)
+         movSize = [int((newSize[0] - cols) / 2), int((newSize[1] - rows) / 2)]
+
+         M = cv2.getRotationMatrix2D(center, angle, scale)
+         M[0, 2] += int((newSize[0] - cols) / 2)
+         M[1, 2] += int((newSize[1] - rows) / 2)
+
+         if self.colors:
+             H, W, _ = image.shape
+             mask = np.zeros_like(image)
+             (h_index, w_index) = (np.random.randint(0, H * 7 // 8), np.random.randint(0, W * 7 // 8))
+             img_cut = image[h_index:(h_index + H // 9), w_index:(w_index + W // 9)]
+             img_cut = cv2.resize(img_cut, (newSize[0], newSize[1]))
+             mask = cv2.warpAffine(mask, M, (newSize[0], newSize[1]), borderValue=[1, 1, 1])
+             image = cv2.warpAffine(image, M, (newSize[0], newSize[1]), borderValue=[0, 0, 0])
+             image = image + img_cut * mask
+         else:
+             color = [0, 0, 0]
+             image = cv2.warpAffine(image, M, (newSize[0], newSize[1]), borderValue=color)
+
+         if polygons is not None:
+             for polygon in polygons:
+                 x, y = self.rotate(center, polygon.points, angle, movSize, scale)
+                 pts = np.vstack([x, y]).T
+                 polygon.points = pts
+         return image, polygons
+
+
+ class SquarePadding(object):
+
+     def __call__(self, image, polygons=None):
+
+         H, W, _ = image.shape
+
+         if H == W:
+             return image, polygons
+
+         padding_size = max(H, W)
+         (h_index, w_index) = (np.random.randint(0, H * 7 // 8), np.random.randint(0, W * 7 // 8))
+         img_cut = image[h_index:(h_index + H // 9), w_index:(w_index + W // 9)]
+         expand_image = cv2.resize(img_cut, (padding_size, padding_size))
+         # expand_image = np.zeros((padding_size, padding_size, 3), dtype=image.dtype)
+         # expand_image = img_cut[:, :, :]
+         if H > W:
+             y0, x0 = 0, (H - W) // 2
+         else:
+             y0, x0 = (W - H) // 2, 0
+         if polygons is not None:
+             for polygon in polygons:
+                 polygon.points += np.array([x0, y0])
+         expand_image[y0:y0 + H, x0:x0 + W] = image
+         image = expand_image
+
+         return image, polygons
+
+
+ class RandomImgCropPatch(object):
+     def __init__(self, up=30, beta=0.3):
+         self.up = up
+         self.beta = beta
+         self.scale = 10
+
+     @staticmethod
+     def get_contour_min_area_box(contour):
+         rect = cv2.minAreaRect(contour)
+         box = cv2.boxPoints(rect)
+         box = np.int0(box)
+         return box
+
+     def CropWH(self, image, cut_w, cut_h, polygons=None):
+         h_axis, w_axis, polys = crop_first(image, polygons, scale=self.scale)
+         h, w, _ = image.shape
+         pad_h = h // self.scale
+         pad_w = w // self.scale
+         # TODO try Flip
+         xx = np.random.choice(w_axis, size=2)
+         xmin = np.min(xx) - pad_w
+         xmax = xmin + cut_w
+         yy = np.random.choice(h_axis, size=2)
+         ymin = np.min(yy) - pad_h
+         ymax = ymin + cut_h
+         if polys.shape[0] != 0:
+             poly_axis_in_area = (polys[:, :, 0] >= xmin) & (polys[:, :, 0] <= xmax) \
+                                 & (polys[:, :, 1] >= ymin) & (polys[:, :, 1] <= ymax)
+             selected_polys = np.where(np.sum(poly_axis_in_area, axis=1) == 4)[0]
+         else:
+             selected_polys = []
+
+         cropped = image[ymin:ymax + 1, xmin:xmax + 1, :]
+         polygons_new = []
+         for idx in selected_polys:
+             polygon = polygons[idx]
+             polygon.points -= np.array([xmin, ymin])
+             polygons_new.append(polygon)
+         image = cropped
+         polygon = polygons_new
+
+         return image, polygon
+
+     def __call__(self, images, polygons_list=None):
+         I_x, I_y = 1024, 1024
+
+         w = int(round(I_x * random.beta(self.beta, self.beta)))
+         h = int(round(I_y * random.beta(self.beta, self.beta)))
+         w_ = [w, I_x - w, w, I_x - w]
+         h_ = [h, h, I_y - h, I_y - h]
+         new_img = np.zeros((I_x, I_y, 3), dtype=images[0].dtype)
+         imgs = []
+         new_polygons = []
+         for i, im in enumerate(images):
+             img, polygons = self.CropWH(im, w_[i], h_[i], polygons=polygons_list[i])
+             imgs.append(img)
+             new_polygons.append(polygons)
+         new_img[0:w, 0:h, :] = imgs[0]
+         new_img[w:I_x, 0:h, :] = imgs[1]
+         new_img[0:w, h:I_y, :] = imgs[2]
+         new_img[w:I_x, h:I_y, :] = imgs[3]
+         for polygon in new_polygons[1]:
+             polygon.points += np.array([w, 0])
+         for polygon in new_polygons[2]:
+             polygon.points += np.array([0, h])
+         for polygon in new_polygons[3]:
+             polygon.points += np.array([w, h])
+
+         polygons = new_polygons[0] + new_polygons[1] + new_polygons[2] + new_polygons[3]
+
+         return new_img, polygons
+
+
+ class RandomCropFlip(object):
+
+     def __init__(self, min_crop_side_ratio=0.01):
+         self.scale = 10
+         self.ratio = 0.2
+         self.epsilon = 10.0
+         self.min_crop_side_ratio = min_crop_side_ratio
+
+     def __call__(self, image, polygons=None):
+
+         if polygons is None:
+             return image, polygons
+
+         if np.random.random() <= self.ratio:
+             return image, polygons
+
+         # compute the valid crop region so that valid seed points can be sampled
+         h_axis, w_axis, pp_polys = crop_first(image, polygons, scale=self.scale)
+         if len(h_axis) == 0 or len(w_axis) == 0:
+             return image, polygons
+
+         # TODO try crop
+         attempt = 0
+         h, w, _ = image.shape
+         area = h * w
+         pad_h = h // self.scale
+         pad_w = w // self.scale
+         while attempt < 10:
+             attempt += 1
+             polygons_new = []
+             xx = np.random.choice(w_axis, size=2)
+             xmin = np.min(xx) - pad_w
+             xmax = np.max(xx) - pad_w
+             xmin = np.clip(xmin, 0, w - 1)
+             xmax = np.clip(xmax, 0, w - 1)
+             yy = np.random.choice(h_axis, size=2)
+             ymin = np.min(yy) - pad_h
+             ymax = np.max(yy) - pad_h
+             ymin = np.clip(ymin, 0, h - 1)
+             ymax = np.clip(ymax, 0, h - 1)
+             if (xmax - xmin) * (ymax - ymin) < area * self.min_crop_side_ratio:
+                 # area too small
+                 continue
+
+             pts = np.stack([[xmin, xmax, xmax, xmin], [ymin, ymin, ymax, ymax]]).T.astype(np.int32)
+             pp = Polygon(pts).buffer(0)
+             Fail_flag = False
+             for polygon in polygons:
+                 ppi = Polygon(polygon.points).buffer(0)
+                 ppiou = float(ppi.intersection(pp).area)
+                 if np.abs(ppiou - float(ppi.area)) > self.epsilon and np.abs(ppiou) > self.epsilon:
+                     Fail_flag = True
+                     break
+                 if np.abs(ppiou - float(ppi.area)) < self.epsilon:
+                     polygons_new.append(polygon)
+
+             if Fail_flag:
+                 continue
+             else:
+                 break
+
+         if len(polygons_new) == 0:
+             cropped = image[ymin:ymax, xmin:xmax, :]
+             select_type = random.randint(3)
+             if select_type == 0:
+                 img = np.ascontiguousarray(cropped[:, ::-1])
+             elif select_type == 1:
+                 img = np.ascontiguousarray(cropped[::-1, :])
+             else:
+                 img = np.ascontiguousarray(cropped[::-1, ::-1])
+             image[ymin:ymax, xmin:xmax, :] = img
+             return image, polygons
+
+         else:
+             cropped = image[ymin:ymax, xmin:xmax, :]
+             height, width, _ = cropped.shape
+             select_type = random.randint(3)
+             if select_type == 0:
+                 img = np.ascontiguousarray(cropped[:, ::-1])
+                 for polygon in polygons_new:
+                     polygon.points[:, 0] = width - polygon.points[:, 0] + 2 * xmin
+             elif select_type == 1:
+                 img = np.ascontiguousarray(cropped[::-1, :])
+                 for polygon in polygons_new:
+                     polygon.points[:, 1] = height - polygon.points[:, 1] + 2 * ymin
+             else:
+                 img = np.ascontiguousarray(cropped[::-1, ::-1])
+                 for polygon in polygons_new:
+                     polygon.points[:, 0] = width - polygon.points[:, 0] + 2 * xmin
+                     polygon.points[:, 1] = height - polygon.points[:, 1] + 2 * ymin
+             image[ymin:ymax, xmin:xmax, :] = img
+
+         return image, polygons
+
+
+ class RandomResizedCrop(object):
+     def __init__(self, min_crop_side_ratio=0.1):
+         self.scale = 10
+         self.epsilon = 1e-2
+         self.min_crop_side_ratio = min_crop_side_ratio
+
+     def __call__(self, image, polygons):
+
+         if polygons is None:
+             return image, polygons
+
+         # compute the valid crop region so that valid seed points can be sampled
+         h_axis, w_axis, pp_polys = crop_first(image, polygons, scale=self.scale)
+         if len(h_axis) == 0 or len(w_axis) == 0:
+             return image, polygons
+
+         # TODO try crop
+         attempt = 0
+         h, w, _ = image.shape
+         area = h * w
+         pad_h = h // self.scale
+         pad_w = w // self.scale
+         while attempt < 10:
+             attempt += 1
+             xx = np.random.choice(w_axis, size=2)
+             xmin = np.min(xx) - pad_w
+             xmax = np.max(xx) - pad_w
+             xmin = np.clip(xmin, 0, w - 1)
+             xmax = np.clip(xmax, 0, w - 1)
+             yy = np.random.choice(h_axis, size=2)
+             ymin = np.min(yy) - pad_h
+             ymax = np.max(yy) - pad_h
+             ymin = np.clip(ymin, 0, h - 1)
+             ymax = np.clip(ymax, 0, h - 1)
+             if (xmax - xmin) * (ymax - ymin) < area * self.min_crop_side_ratio:
+                 # area too small
+                 continue
+             if pp_polys.shape[0] != 0:
+                 poly_axis_in_area = (pp_polys[:, :, 0] >= xmin) & (pp_polys[:, :, 0] <= xmax) \
+                                     & (pp_polys[:, :, 1] >= ymin) & (pp_polys[:, :, 1] <= ymax)
+                 selected_polys = np.where(np.sum(poly_axis_in_area, axis=1) == 4)[0]
+             else:
+                 selected_polys = []
+
+             if len(selected_polys) == 0:
+                 continue
+             else:
+                 pts = np.stack([[xmin, xmax, xmax, xmin], [ymin, ymin, ymax, ymax]]).T.astype(np.int32)
+                 pp = Polygon(pts).buffer(0)
+                 polygons_new = []
+                 Fail_flag = False
+                 for polygon in copy.deepcopy(polygons):
+                     ppi = Polygon(polygon.points).buffer(0)
+                     ppiou = float(ppi.intersection(pp).area)
+                     if np.abs(ppiou - float(ppi.area)) > self.epsilon and np.abs(ppiou) > self.epsilon:
+                         Fail_flag = True
+                         break
+                     elif np.abs(ppiou - float(ppi.area)) < self.epsilon:
+                         # polygon.points -= np.array([xmin, ymin])
+                         polygons_new.append(polygon)
+
+                 if Fail_flag:
+                     continue
+                 else:
+                     cropped = image[ymin:ymax + 1, xmin:xmax + 1, :]
+                     for polygon in polygons_new:
+                         polygon.points -= np.array([xmin, ymin])
+
+                     return cropped, polygons_new
+
+         return image, polygons
+
+
+ class RandomResizeScale(object):
+     def __init__(self, size=512, ratio=(3. / 4, 5. / 2)):
+         self.size = size
+         self.ratio = ratio
+
+     def __call__(self, image, polygons=None):
+
+         aspect_ratio = np.random.uniform(self.ratio[0], self.ratio[1])
+         h, w, _ = image.shape
+         scales = self.size * 1.0 / max(h, w)
+         aspect_ratio = scales * aspect_ratio
+         aspect_ratio = int(w * aspect_ratio) * 1.0 / w
+         image = cv2.resize(image, (int(w * aspect_ratio), int(h * aspect_ratio)))
+         scales = np.array([aspect_ratio, aspect_ratio])
+         if polygons is not None:
+             for polygon in polygons:
+                 polygon.points = polygon.points * scales
+
+         return image, polygons
+
+
+ class Resize(object):
+     def __init__(self, size=1024):
+         self.size = size
+         self.SP = SquarePadding()
+
+     def __call__(self, image, polygons=None):
+         h, w, _ = image.shape
+         image = cv2.resize(image, (self.size, self.size))
+         scales = np.array([self.size / w, self.size / h])
+
+         if polygons is not None:
+             for polygon in polygons:
+                 polygon.points = polygon.points * scales
+
+         return image, polygons
+
+
+ class ResizeSquare(object):
+     def __init__(self, size=(480, 1280)):
+         self.size = size
+
+     def __call__(self, image, polygons=None):
+         h, w, _ = image.shape
+         img_size_min = min(h, w)
+         img_size_max = max(h, w)
+
+         if img_size_min < self.size[0]:
+             im_scale = float(self.size[0]) / float(img_size_min)  # expand the short side to size[0]
+             if np.ceil(im_scale * img_size_max) > self.size[1]:  # the long side cannot exceed size[1]
+                 im_scale = float(self.size[1]) / float(img_size_max)
+         elif img_size_max > self.size[1]:
+             im_scale = float(self.size[1]) / float(img_size_max)
+         else:
+             im_scale = 1.0
+
+         new_h = int(int(h * im_scale / 32) * 32)
+         new_w = int(int(w * im_scale / 32) * 32)
+         # if new_h * new_w > 1600 * 1920:
+         #     im_scale = 1600 / float(img_size_max)
+         #     new_h = int(int(h * im_scale / 32) * 32)
+         #     new_w = int(int(w * im_scale / 32) * 32)
+         image = cv2.resize(image, (new_w, new_h))
+         scales = np.array([new_w / w, new_h / h])
+         if polygons is not None:
+             for polygon in polygons:
+                 polygon.points = polygon.points * scales
+
+         return image, polygons
+
+
+ class ResizeLimitSquare(object):
+     def __init__(self, size=512, ratio=0.6):
+         self.size = size
+         self.ratio = ratio
+         self.SP = SquarePadding()
+
+     def __call__(self, image, polygons=None):
+         if np.random.random() <= self.ratio:
+             image, polygons = self.SP(image, polygons)
+         h, w, _ = image.shape
+         image = cv2.resize(image, (self.size, self.size))
+         scales = np.array([self.size * 1.0 / w, self.size * 1.0 / h])
+
+         if polygons is not None:
+             for polygon in polygons:
+                 polygon.points = polygon.points * scales
+
+         return image, polygons
+
+
+ class RandomResizePadding(object):
+     def __init__(self, size=512, random_scale=np.array([0.75, 1.0, 1.25, 1.5, 2.0]), stride=32, ratio=0.6667):
+         self.random_scale = random_scale
+         self.size = size
+         self.ratio = ratio
+         self.stride = stride
+         self.SP = SquarePadding()
+
+         ########## Random size for different epochs ##########
+         rd_scale = np.random.choice(self.random_scale)
+         step_num = round(np.random.normal(loc=0.0, scale=0.35) * 8)  # step drawn from a Gaussian distribution
+         self.input_size = np.clip(int(self.size * rd_scale + step_num * self.stride),
+                                   (int(self.size * self.random_scale[0] - self.stride)),
+                                   int(self.size * self.random_scale[-1] + self.stride))
+         ########################## end ##########################
+
+     def __call__(self, image, polygons=None):
+
+         if np.random.random() <= self.ratio:
+             image, polygons = self.SP(image, polygons)
+         h, w, _ = image.shape
+         image = cv2.resize(image, (self.input_size, self.input_size))
+         scales = np.array([self.input_size * 1.0 / w, self.input_size * 1.0 / h])
+
+         if polygons is not None:
+             for polygon in polygons:
+                 polygon.points = polygon.points * scales
+
+         return image, polygons
+
+ transform_type_dict = dict(
+     brightness=ImageEnhance.Brightness, contrast=ImageEnhance.Contrast,
+     sharpness=ImageEnhance.Sharpness, color=ImageEnhance.Color
+ )
+
+
+ class RandomDistortion(object):
+     def __init__(self, transform_dict, prob=0.5):
+         self.transforms = [(transform_type_dict[k], transform_dict[k]) for k in transform_dict]
+         self.prob = prob
+
+     def __call__(self, img, target):
+         if random.random() > self.prob:
+             return img, target
+         out = Image.fromarray(img)
+         rand_num = np.random.uniform(0, 1, len(self.transforms))
+
+         for i, (transformer, alpha) in enumerate(self.transforms):
+             r = alpha * (rand_num[i] * 2.0 - 1.0) + 1  # r in [1-alpha, 1+alpha)
+             out = transformer(out).enhance(r)
+
+         return np.array(out), target
+
+
+ class Augmentation(object):
+     def __init__(self, size, mean, std):
+         self.size = size
+         self.mean = mean
+         self.std = std
+         self._transform_dict = {'brightness': 0.5, 'contrast': 0.5, 'sharpness': 0.8386, 'color': 0.5}
+         self.augmentation = Compose([
+             RandomCropFlip(),
+             RandomResizeScale(size=self.size, ratio=(3. / 8, 5. / 2)),
+             RandomResizedCrop(),
+             RotatePadding(up=60, colors=True),  # pretraining on Syn uses up=30; otherwise up=60
+             ResizeLimitSquare(size=self.size),
+             RandomMirror(),
+             RandomDistortion(self._transform_dict),
+             Normalize(mean=self.mean, std=self.std),
+         ])
+
+     def __call__(self, image, polygons=None):
+         return self.augmentation(image, polygons)
+
+
+ class BaseTransform(object):
+     def __init__(self, size, mean, std):
+         self.size = size
+         self.mean = mean
+         self.std = std
+         self.augmentation = Compose([
+             # Resize(size=640),
+             ResizeSquare(size=self.size),
+             Normalize(mean, std)
+         ])
+
+     def __call__(self, image, polygons=None):
+         return self.augmentation(image, polygons)
+
+
+ class BaseTransformNresize(object):
+     def __init__(self, mean, std):
+         self.mean = mean
+         self.std = std
+         self.augmentation = Compose([
+             Normalize(mean, std)
+         ])
+
+     def __call__(self, image, polygons=None):
+         return self.augmentation(image, polygons)
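A short sketch of how these transforms compose at test time; the mean/std values below are the usual ImageNet statistics and are assumed here purely for illustration:

    import numpy as np

    transform = BaseTransform(size=(640, 1024), mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
    image = np.random.randint(0, 255, (480, 720, 3), dtype=np.uint8)
    image, _ = transform(image)  # resized to multiples of 32, then normalized; polygons pass through as None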
IndicPhotoOCR/detection/textbpn/util/canvas.py ADDED
@@ -0,0 +1,55 @@
+ #!/usr/bin/env python
+ # -*- coding: utf-8 -*-
+ __author__ = '古溪'
+
+ import numpy as np
+ import random
+ import matplotlib.pyplot as plt
+
+
+ def heatmap(im_gray):
+     cmap = plt.get_cmap('jet')
+     rgba_img = cmap(255 - im_gray)
+     Hmap = np.delete(rgba_img, 3, 2)
+     # print(Hmap.shape, Hmap.max(), Hmap.min())
+     # cv2.imshow("heat_img", Hmap)
+     # cv2.waitKey(0)
+     return Hmap
+
+
+ def loss_ploy(loss_list, steps, period, name=""):
+     fig1, ax1 = plt.subplots(figsize=(16, 9))
+     ax1.plot(range(steps // period), loss_list)
+     ax1.set_title("Average loss vs step*{}".format(period))
+     ax1.set_xlabel("step*{}".format(period))
+     ax1.set_ylabel("Current loss")
+     plt.savefig('{}@loss_vs_step*{}.png'.format(name, period))
+     plt.clf()
+
+
+ def plt_ploys(ploys, period, name=""):
+     fig1, ax1 = plt.subplots(figsize=(16, 9))
+     cnames = ['aliceblue', 'antiquewhite', 'aqua', 'aquamarine', 'azure',
+               'blanchedalmond', 'blue', 'blueviolet', 'brown', 'burlywood',
+               'coral', 'cornflowerblue', 'cornsilk', 'crimson', 'cyan',
+               'darkblue', 'deeppink', 'deepskyblue', 'dodgerblue', 'forestgreen',
+               'gold', 'goldenrod', 'green', 'greenyellow', 'honeydew', 'hotpink',
+               'lawngreen', 'lightblue', 'lightgreen', 'lightpink', 'lightsalmon',
+               'lightseagreen', 'lightsteelblue', 'lightyellow', 'lime', 'limegreen',
+               'mediumseagreen', 'mediumspringgreen', 'midnightblue', 'orange', 'orangered',
+               'pink', 'red', 'royalblue', 'seagreen', 'skyblue', 'springgreen', 'steelblue',
+               'tan', 'teal', 'thistle', 'yellow', 'yellowgreen']
+
+     color = random.sample(cnames, len(ploys.keys()))
+     for ii, key in enumerate(ploys.keys()):
+         ax1.plot(range(1, len(ploys[key]) + 1), ploys[key], color=color[ii], label=key)
+     ax1.set_title("Loss curve")
+     ax1.set_xlabel("step*{}".format(period))
+     ax1.set_ylabel("Current loss")
+     plt.legend(ploys.keys())
+     plt.savefig('{}@loss_vs_step*{}.png'.format(name, period))
+     plt.clf()
+
+ if __name__ == '__main__':
+     # TODO ADD CODE
+     pass
IndicPhotoOCR/detection/textbpn/util/detection.py ADDED
@@ -0,0 +1,48 @@
+ # c++ version pse based on opencv 3+
+ from pse import decode as pse_decode
+ from cfglib.config import config as cfg
+
+
+ class TextDetector(object):
+
+     def __init__(self, model):
+         # evaluation mode
+         self.model = model
+         model.eval()
+         # parameters
+         self.scale = cfg.scale
+         self.threshold = cfg.threshold
+
+     def detect(self, image, img_show):
+         # get model output
+         preds = self.model.forward(image)
+         preds, boxes, contours = pse_decode(preds[0], self.scale, self.threshold)
+
+         output = {
+             'image': image,
+             'tr': preds,
+             'bbox': boxes
+         }
+         return contours, output
IndicPhotoOCR/detection/textbpn/util/eval.py ADDED
@@ -0,0 +1,228 @@
+ import os
+ import cv2
+ import numpy as np
+ import subprocess
+ from cfglib.config import config as cfg
+ from util.misc import mkdirs
+
+
+ def osmkdir(out_dir):
+     import shutil
+     if os.path.exists(out_dir):
+         shutil.rmtree(out_dir)
+     os.makedirs(out_dir)
+
+
+ def analysize_result(source_dir, fid_path, outpt_dir, name):
+     bad_txt = open("{}/eval.txt".format(outpt_dir), 'w')
+     all_eval = open("{}/{}/{}_eval.txt".format(cfg.output_dir, "Analysis", name), 'a+')
+     sel_list = list()
+     with open(fid_path) as f:
+         lines = f.read().split("\n")
+         for line in lines:
+             line_items = line.split(" ")
+             id = line_items[0]
+             precision = float(line_items[2].split('=')[-1])
+             recall = float(line_items[4].split('=')[-1])
+             if id != "ALL" and (precision < 0.5 or recall < 0.5):
+                 img_path = os.path.join(source_dir, line_items[0].replace(".txt", ".jpg"))
+                 if os.path.exists(img_path):
+                     os.system('cp {} {}'.format(img_path, outpt_dir))
+                 sel_list.append((int(id.replace(".txt", "").replace("img", "").replace("_", "")), line))
+             if id == "ALL":
+                 all_eval.write("{} {} {}\n".format(
+                     outpt_dir.split('/')[-1],
+                     "{}/{}".format(cfg.dis_threshold, cfg.cls_threshold),
+                     line))
+     sel_list = sorted(sel_list, key=lambda its: its[0])
+     bad_txt.write('\n'.join([its[1] for its in sel_list]))
+     all_eval.close()
+     bad_txt.close()
+
+
+ def deal_eval_total_text(debug=False):
+     # compute DetEval
+     eval_dir = os.path.join(cfg.output_dir, "Analysis", "output_eval")
+     if not os.path.exists(eval_dir):
+         os.makedirs(eval_dir)
+
+     print('Computing DetEval in {}/{}'.format(cfg.output_dir, cfg.exp_name))
+     subprocess.call(
+         ['python', 'dataset/total_text/Evaluation_Protocol/Python_scripts/Deteval.py', cfg.exp_name, '--tr', '0.7',
+          '--tp', '0.6'])
+     subprocess.call(
+         ['python', 'dataset/total_text/Evaluation_Protocol/Python_scripts/Deteval.py', cfg.exp_name, '--tr', '0.8',
+          '--tp', '0.4'])
+
+     if debug:
+         source_dir = os.path.join(cfg.vis_dir, '{}_test'.format(cfg.exp_name))
+         outpt_dir_base = os.path.join(cfg.output_dir, "Analysis", "eval_view", "total_text")
+         if not os.path.exists(outpt_dir_base):
+             mkdirs(outpt_dir_base)
+
+         outpt_dir1 = os.path.join(outpt_dir_base, "{}_{}_{}_{}_{}"
+                                   .format(cfg.test_size[0], cfg.test_size[1], cfg.checkepoch, 0.7, 0.6))
+         osmkdir(outpt_dir1)
+         fid_path1 = '{}/Eval_TotalText_{}_{}.txt'.format(eval_dir, 0.7, 0.6)
+
+         analysize_result(source_dir, fid_path1, outpt_dir1, "totalText")
+
+         outpt_dir2 = os.path.join(outpt_dir_base, "{}_{}_{}_{}_{}"
+                                   .format(cfg.test_size[0], cfg.test_size[1], cfg.checkepoch, 0.8, 0.4))
+         osmkdir(outpt_dir2)
+         fid_path2 = '{}/Eval_TotalText_{}_{}.txt'.format(eval_dir, 0.8, 0.4)
+
+         analysize_result(source_dir, fid_path2, outpt_dir2, "totalText")
+
+     print('End.')
+
+
+ def deal_eval_ctw1500(debug=False):
+     # compute DetEval
+     eval_dir = os.path.join(cfg.output_dir, "Analysis", "output_eval")
+     if not os.path.exists(eval_dir):
+         os.makedirs(eval_dir)
+
+     print('Computing DetEval in {}/{}'.format(cfg.output_dir, cfg.exp_name))
+     subprocess.call(['python', 'dataset/ctw1500/Evaluation_Protocol/ctw1500_eval.py', cfg.exp_name])
+
+     if debug:
+         source_dir = os.path.join(cfg.vis_dir, '{}_test'.format(cfg.exp_name))
+         outpt_dir_base = os.path.join(cfg.output_dir, "Analysis", "eval_view", "ctw1500")
+         if not os.path.exists(outpt_dir_base):
+             mkdirs(outpt_dir_base)
+
+         outpt_dir = os.path.join(outpt_dir_base, "{}_{}_{}".format(cfg.test_size[0], cfg.test_size[1], cfg.checkepoch))
+         osmkdir(outpt_dir)
+         fid_path1 = '{}/Eval_ctw1500_{}.txt'.format(eval_dir, 0.5)
+
+         analysize_result(source_dir, fid_path1, outpt_dir, "ctw1500")
+
+     print('End.')
+
+
+ def deal_eval_icdar15(debug=False):
+     # compute DetEval
+     eval_dir = os.path.join(cfg.output_dir, "Analysis", "output_eval")
+     if not os.path.exists(eval_dir):
+         os.makedirs(eval_dir)
+
+     input_dir = 'output/{}'.format(cfg.exp_name)
+     father_path = os.path.abspath(input_dir)
+     print(father_path)
+     print('Computing DetEval in {}/{}'.format(cfg.output_dir, cfg.exp_name))
+     subprocess.call(['sh', 'dataset/icdar15/eval.sh', father_path])
+
+     if debug:
+         source_dir = os.path.join(cfg.vis_dir, '{}_test'.format(cfg.exp_name))
+         outpt_dir_base = os.path.join(cfg.output_dir, "Analysis", "eval_view", "icdar15")
+         if not os.path.exists(outpt_dir_base):
+             mkdirs(outpt_dir_base)
+
+         outpt_dir = os.path.join(outpt_dir_base, "{}_{}_{}".format(cfg.test_size[0], cfg.test_size[1], cfg.checkepoch))
+         osmkdir(outpt_dir)
+         fid_path1 = '{}/Eval_icdar15.txt'.format(eval_dir)
+
+         analysize_result(source_dir, fid_path1, outpt_dir, "icdar15")
+
+     print('End.')
+
+
+ def deal_eval_TD500(debug=False):
+     # compute DetEval
+     eval_dir = os.path.join(cfg.output_dir, "Analysis", "output_eval")
+     if not os.path.exists(eval_dir):
+         os.makedirs(eval_dir)
+
+     input_dir = 'output/{}'.format(cfg.exp_name)
+     father_path = os.path.abspath(input_dir)
+     print(father_path)
+     print('Computing DetEval in {}/{}'.format(cfg.output_dir, cfg.exp_name))
+     subprocess.call(['sh', 'dataset/TD500/eval.sh', father_path])
+
+     if debug:
+         source_dir = os.path.join(cfg.vis_dir, '{}_test'.format(cfg.exp_name))
+         outpt_dir_base = os.path.join(cfg.output_dir, "Analysis", "eval_view", "TD500")
+         if not os.path.exists(outpt_dir_base):
+             mkdirs(outpt_dir_base)
+
+         outpt_dir = os.path.join(outpt_dir_base, "{}_{}_{}".format(cfg.test_size[0], cfg.test_size[1], cfg.checkepoch))
+         osmkdir(outpt_dir)
+         fid_path1 = '{}/Eval_TD500.txt'.format(eval_dir)
+
+         analysize_result(source_dir, fid_path1, outpt_dir, "TD500")
+
+     print('End.')
+
+
+ def data_transfer_ICDAR(contours):
+     cnts = list()
+     for cont in contours:
+         rect = cv2.minAreaRect(cont)
+         if min(rect[1][0], rect[1][1]) <= 5:
+             continue
+         points = cv2.boxPoints(rect)
+         points = np.int0(points)
+         # print(points.shape)
+         # points = np.reshape(points, (4, 2))
+         cnts.append(points)
+     return cnts
+
+
+ def data_transfer_TD500(contours, res_file, img=None):
+     with open(res_file, 'w') as f:
+         for cont in contours:
+             rect = cv2.minAreaRect(cont)
+             if min(rect[1][0], rect[1][1]) <= 5:
+                 continue
+             points = cv2.boxPoints(rect)
+             box = np.int0(points)
+             cv2.drawContours(img, [box], 0, (0, 255, 0), 3)
+
+             cx, cy = rect[0]
+             w_, h_ = rect[1]
+             angle = rect[2]
+             if angle > 45:
+                 angle = 90 - angle
+                 w_, h_ = h_, w_
+             elif angle < -45:
+                 angle = 90 + angle
+                 w_, h_ = h_, w_
+             angle = angle / 180 * 3.141592653589
+
+             x_min = int(cx - w_ / 2)
+             x_max = int(cx + w_ / 2)
+             y_min = int(cy - h_ / 2)
+             y_max = int(cy + h_ / 2)
+             f.write('{},{},{},{},{}\r\n'.format(x_min, y_min, x_max, y_max, angle))
+
+     return img
+
+
+ def data_transfer_MLT2017(contours, res_file):
+     with open(res_file, 'w') as f:
+         for cont in contours:
+             rect = cv2.minAreaRect(cont)
+             if min(rect[1][0], rect[1][1]) <= 5:
+                 continue
+             ploy_area = cv2.contourArea(cont)
+             rect_area = rect[1][0] * rect[1][1]
+             solidity = ploy_area / rect_area
+             width = rect[1][0] - np.clip(rect[1][0] * (1 - np.sqrt(solidity)), 0, 6)
+             height = rect[1][1] - np.clip(rect[1][1] * (1 - np.sqrt(solidity)), 0, 4)
+             points = cv2.boxPoints((rect[0], (width, height), rect[2]))
+             points = np.int0(points)
+             p = np.reshape(points, -1)
+             f.write('{},{},{},{},{},{},{},{},{}\r\n'
+                     .format(p[0], p[1], p[2], p[3], p[4], p[5], p[6], p[7], 1))
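A toy illustration of data_transfer_ICDAR above, which keeps only boxes whose shorter side exceeds 5 px and returns their 4-point minimum-area rectangles (the contour below is made up):

    import numpy as np

    contour = np.array([[10, 10], [210, 12], [208, 60], [12, 58]], dtype=np.int32)
    quads = data_transfer_ICDAR([contour])
    print(quads[0].shape)  # (4, 2)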
IndicPhotoOCR/detection/textbpn/util/graph.py ADDED
@@ -0,0 +1,309 @@
+ from __future__ import print_function
+ from __future__ import division
+ from __future__ import absolute_import
+
+ import numpy as np
+ import time
+ from util.misc import norm2
+
+ class Data(object):
+     def __init__(self, name):
+         self.__name = name
+         self.__links = set()
+
+     @property
+     def name(self):
+         return self.__name
+
+     @property
+     def links(self):
+         return set(self.__links)
+
+     def add_link(self, other, score):
+         self.__links.add(other)
+         other.__links.add(self)
+
+
+ def connected_components(nodes, score_dict, th):
+     '''
+     conventional connected-components search
+     '''
+     result = []
+     nodes = set(nodes)
+     while nodes:
+         n = nodes.pop()
+         group = {n}
+         queue = [n]
+         while queue:
+             n = queue.pop(0)
+             if th is not None:
+                 neighbors = {l for l in n.links if score_dict[tuple(sorted([n.name, l.name]))] >= th}
+             else:
+                 neighbors = n.links
+             neighbors.difference_update(group)
+             nodes.difference_update(neighbors)
+             group.update(neighbors)
+             queue.extend(neighbors)
+         result.append(group)
+     return result
+
+
+ def connected_components_constraint(nodes, max_sz, score_dict=None, th=None):
+     '''
+     Only use edges whose scores are above `th`.
+     If a component is larger than `max_sz`, all the nodes in this component are
+     added into `remain` and returned for the next iteration.
+     '''
+     result = []
+     remain = set()
+     nodes = set(nodes)
+     while nodes:
+         n = nodes.pop()
+         group = {n}
+         queue = [n]
+         valid = True
+         while queue:
+             n = queue.pop(0)
+             if th is not None:
+                 neighbors = {l for l in n.links if score_dict[tuple(sorted([n.name, l.name]))] >= th}
+             else:
+                 neighbors = n.links
+             neighbors.difference_update(group)
+             nodes.difference_update(neighbors)
+             group.update(neighbors)
+             queue.extend(neighbors)
+             if len(group) > max_sz or len(remain.intersection(neighbors)) > 0:
+                 # if this group is larger than `max_sz`, add the nodes into `remain`
+                 valid = False
+                 remain.update(group)
+                 break
+         if valid:  # if this group is smaller than or equal to `max_sz`, finalize it.
+             result.append(group)
+     return result, remain
+
+
+ def graph_propagation_naive(edges, score, th, bboxs=None, dis_thresh=50, pool='avg'):
+
+     edges = np.sort(edges, axis=1)
+
+     score_dict = {}  # score lookup table
+     if pool is None:
+         for i, e in enumerate(edges):
+             score_dict[e[0], e[1]] = score[i]
+     elif pool == 'avg':
+         for i, e in enumerate(edges):
+             if bboxs is not None:
+                 box1 = bboxs[e[0]][:8].reshape(4, 2)
+                 box2 = bboxs[e[1]][:8].reshape(4, 2)
+                 c1 = np.mean(box1, 0)
+                 c2 = np.mean(box2, 0)
+                 dst = norm2(c1 - c2)
+                 if dst > dis_thresh:
+                     score[i] = 0
+             if (e[0], e[1]) in score_dict:
+                 score_dict[e[0], e[1]] = 0.5 * (score_dict[e[0], e[1]] + score[i])
+             else:
+                 score_dict[e[0], e[1]] = score[i]
+
+     elif pool == 'max':
+         for i, e in enumerate(edges):
+             if (e[0], e[1]) in score_dict:
+                 score_dict[e[0], e[1]] = max(score_dict[e[0], e[1]], score[i])
+             else:
+                 score_dict[e[0], e[1]] = score[i]
+     else:
+         raise ValueError('Pooling operation not supported')
+
+     nodes = np.sort(np.unique(edges.flatten()))
+     mapping = -1 * np.ones((nodes.max() + 1), dtype=int)
+     mapping[nodes] = np.arange(nodes.shape[0])
+     link_idx = mapping[edges]
+     vertex = [Data(n) for n in nodes]
+     for l, s in zip(link_idx, score):
+         vertex[l[0]].add_link(vertex[l[1]], s)
+
+     # first iteration
+     comps = connected_components(vertex, score_dict, th)
+
+     return comps
+
+
+ def graph_search(edges, scores, edges_num, th=None):
+     # graph search
+     scores = scores.reshape((-1, edges_num))
+     select_index = np.argsort(scores, axis=1)[:, -2:]
+     edges = np.sort(edges, axis=1).reshape((-1, edges_num, 2))
+
+     score_dict = {}
+     for i, ips in enumerate(select_index):
+         edg = edges[i]
+         si = scores[i]
+         for j, idx in enumerate(ips):
+             e = edg[idx, :]
+             if (e[0], e[1]) in score_dict:
+                 score_dict[e[0], e[1]] = 0.5 * (score_dict[e[0], e[1]] + si[j])
+             else:
+                 score_dict[e[0], e[1]] = si[j]
+
+     nodes = np.sort(np.unique(edges.flatten()))
+     vertex = [Data(n) for n in nodes]
+     for (key, value) in score_dict.items():
+         vertex[key[0]].add_link(vertex[key[1]], value)
+
+     comps = connected_components(vertex, score_dict, th)
+
+     return comps
+
+
+ def graph_propagation(edges, score, max_sz, step=0.1, beg_th=0.5, pool=None):
+
+     edges = np.sort(edges, axis=1)
+     th = score.min()
+     # th = beg_th
+     # construct graph
+     score_dict = {}  # score lookup table
+     if pool is None:
+         for i, e in enumerate(edges):
+             score_dict[e[0], e[1]] = score[i]
+     elif pool == 'avg':
+         for i, e in enumerate(edges):
+             if (e[0], e[1]) in score_dict:
+                 score_dict[e[0], e[1]] = 0.5 * (score_dict[e[0], e[1]] + score[i])
+             else:
+                 score_dict[e[0], e[1]] = score[i]
+
+     elif pool == 'max':
+         for i, e in enumerate(edges):
+             if (e[0], e[1]) in score_dict:
+                 score_dict[e[0], e[1]] = max(score_dict[e[0], e[1]], score[i])
+             else:
+                 score_dict[e[0], e[1]] = score[i]
+     else:
+         raise ValueError('Pooling operation not supported')
+
+     nodes = np.sort(np.unique(edges.flatten()))
+     mapping = -1 * np.ones((nodes.max() + 1), dtype=int)
+     mapping[nodes] = np.arange(nodes.shape[0])
+     link_idx = mapping[edges]
+     vertex = [Data(n) for n in nodes]
+     for l, s in zip(link_idx, score):
+         vertex[l[0]].add_link(vertex[l[1]], s)
+
+     # first iteration
+     comps, remain = connected_components_constraint(vertex, max_sz)
+
+     # iteration
+     components = comps[:]
+     while remain:
+         th = th + (1 - th) * step
+         comps, remain = connected_components_constraint(remain, max_sz, score_dict, th)
+         components.extend(comps)
+     return components
+
+
+ def graph_propagation_soft(edges, score, max_sz, step=0.1, **kwargs):
+
+     edges = np.sort(edges, axis=1)
+     th = score.min()
+
+     # construct graph
+     score_dict = {}  # score lookup table
+     for i, e in enumerate(edges):
+         score_dict[e[0], e[1]] = score[i]
+
+     nodes = np.sort(np.unique(edges.flatten()))
+     mapping = -1 * np.ones((nodes.max() + 1), dtype=int)
+     mapping[nodes] = np.arange(nodes.shape[0])
+     link_idx = mapping[edges]
+     vertex = [Data(n) for n in nodes]
+     for l, s in zip(link_idx, score):
+         vertex[l[0]].add_link(vertex[l[1]], s)
+
+     # first iteration
+     comps, remain = connected_components_constraint(vertex, max_sz)
+     first_vertex_idx = np.array([mapping[n.name] for c in comps for n in c])
+     fusion_vertex_idx = np.setdiff1d(np.arange(nodes.shape[0]), first_vertex_idx, assume_unique=True)
+     # iteration
+     components = comps[:]
+     while remain:
+         th = th + (1 - th) * step
+         comps, remain = connected_components_constraint(remain, max_sz, score_dict, th)
+         components.extend(comps)
+     label_dict = {}
+     for i, c in enumerate(components):
+         for n in c:
+             label_dict[n.name] = i
+     print('Propagation ...')
+     prop_vertex = [vertex[idx] for idx in fusion_vertex_idx]
+     label, label_fusion = diffusion(prop_vertex, label_dict, score_dict, **kwargs)
+     return label, label_fusion
+
+
+ def diffusion(vertex, label, score_dict, max_depth=5, weight_decay=0.6, normalize=True):
+     class BFSNode():
+         def __init__(self, node, depth, value):
+             self.node = node
+             self.depth = depth
+             self.value = value
+
+     label_fusion = {}
+     for name in label.keys():
+         label_fusion[name] = {label[name]: 1.0}
+     prog = 0
+     prog_step = len(vertex) // 20
+     start = time.time()
+     for root in vertex:
+         if prog % prog_step == 0:
+             print("progress: {} / {}, elapsed time: {}".format(prog, len(vertex), time.time() - start))
+         prog += 1
+         # queue = {[root, 0, 1.0]}
+         queue = {BFSNode(root, 0, 1.0)}
+         visited = [root.name]
+         root_label = label[root.name]
+         while queue:
+             curr = queue.pop()
+             if curr.depth >= max_depth:  # pruning
+                 continue
+             neighbors = curr.node.links
+             tmp_value = []
+             tmp_neighbor = []
+             for n in neighbors:
+                 if n.name not in visited:
+                     sub_value = score_dict[tuple(sorted([curr.node.name, n.name]))] * weight_decay * curr.value
+                     tmp_value.append(sub_value)
+                     tmp_neighbor.append(n)
+                     if root_label not in label_fusion[n.name].keys():
+                         label_fusion[n.name][root_label] = sub_value
+                     else:
+                         label_fusion[n.name][root_label] += sub_value
+                     visited.append(n.name)
+                     # queue.add([n, curr.depth + 1, sub_value])
+             sortidx = np.argsort(tmp_value)[::-1]
+             for si in sortidx:
+                 queue.add(BFSNode(tmp_neighbor[si], curr.depth + 1, tmp_value[si]))
+     if normalize:
+         for name in label_fusion.keys():
+             summ = sum(label_fusion[name].values())
+             for k in label_fusion[name].keys():
+                 label_fusion[name][k] /= summ
+     return label, label_fusion
+
+
+ def clusters2labels(clusters, n_nodes):
+     labels = (-1) * np.ones((n_nodes,))
+     for ci, c in enumerate(clusters):
+         for xid in c:
+             labels[xid.name] = ci
+     assert np.sum(labels < 0) < 1
+     return labels
+
+
+ def single_remove(bbox, pred):
+     single_idcs = np.zeros_like(pred)
+     pred_unique = np.unique(pred)
+     for u in pred_unique:
+         idcs = pred == u
+         if np.sum(idcs) == 1:
+             single_idcs[np.where(idcs)[0][0]] = 1
+     remain_idcs = [i for i in range(len(pred)) if not single_idcs[i]]
+     remain_idcs = np.asarray(remain_idcs)
+     return bbox[remain_idcs, :], pred[remain_idcs]
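A toy run of graph_propagation on a 4-node chain (made-up scores): the first pass groups everything, the size cap of 2 forces a re-pass at a raised threshold, which drops the weak 1-2 edge and yields two clusters:

    import numpy as np

    edges = np.array([[0, 1], [1, 2], [2, 3]])
    score = np.array([0.9, 0.2, 0.8])
    comps = graph_propagation(edges, score, max_sz=2, step=0.1)
    print([sorted(n.name for n in c) for c in comps])  # e.g., [[0, 1], [2, 3]]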
IndicPhotoOCR/detection/textbpn/util/io.py ADDED
@@ -0,0 +1,233 @@
+ # coding=utf-8
+ '''
+ Created on Sep 27, 2016
+
+ @author: dengdan
+
+ Tool functions for file-system operations and I/O,
+ in the style of Linux shell commands.
+ '''
+ import os
+ import pickle as pkl
+ import subprocess
+ import logging
+ from . import strs
+
+
+ def mkdir(path):
+     """
+     If the target directory does not exist, it and its parent
+     directories will be created.
+     """
+     path = get_absolute_path(path)
+     if not exists(path):
+         os.makedirs(path)
+     return path
+
+
+ def make_parent_dir(path):
+     """Make the parent directories for a file."""
+     parent_dir = get_dir(path)
+     mkdir(parent_dir)
+
+
+ def pwd():
+     return os.getcwd()
+
+
+ def dump(path, obj):
+     path = get_absolute_path(path)
+     parent_path = get_dir(path)
+     mkdir(parent_path)
+     with open(path, 'wb') as f:  # pickle needs binary mode in Python 3
+         logging.info('dumping file: ' + path)
+         pkl.dump(obj, f)
+
+
+ def load(path):
+     path = get_absolute_path(path)
+     with open(path, 'rb') as f:  # pickle needs binary mode in Python 3
+         data = pkl.load(f)
+     return data
+
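A minimal pickle round trip with the two helpers above (not part of the commit; the path and payload are illustrative):

dump('~/tmp/example.pkl', {'boxes': [1, 2, 3]})
obj = load('~/tmp/example.pkl')
assert obj['boxes'] == [1, 2, 3]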
+ def join_path(a, *p):
+     return os.path.join(a, *p)
+
+
+ def is_dir(path):
+     path = get_absolute_path(path)
+     return os.path.isdir(path)
+
+
+ is_directory = is_dir
+
+
+ def is_path(path):
+     # os.path.ispath does not exist; os.path.exists is the intended check
+     path = get_absolute_path(path)
+     return os.path.exists(path)
+
+
+ def get_dir(path):
+     '''
+     Return the directory a path belongs to.
+     If the path is a directory itself, it is returned unchanged.
+     '''
+     path = get_absolute_path(path)
+     if is_dir(path):
+         return path
+     return os.path.split(path)[0]
+
+
+ def get_parent_dir(path):
+     current_dir = get_dir(path)
+     return get_absolute_path(join_path(current_dir, '..'))
+
+
+ def get_filename(path):
+     return os.path.split(path)[1]
+
+
+ def get_absolute_path(p):
+     if p.startswith('~'):
+         p = os.path.expanduser(p)
+     return os.path.abspath(p)
+
+
+ def cd(p):
+     p = get_absolute_path(p)
+     os.chdir(p)
+
+
+ def ls(path='.', suffix=None):
+     """
+     List the files in a directory.
+     Return the file names as a list, optionally filtered by suffix.
+     """
+     path = get_absolute_path(path)
+     files = os.listdir(path)
+
+     if suffix is None:
+         return files
+
+     filtered = []
+     for f in files:
+         if strs.ends_with(f, suffix, ignore_case=True):  # was `string.ends_with`, a NameError
+             filtered.append(f)
+     return filtered
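For example (illustrative path, not part of the commit), listing only the images in a dataset folder:

images = ls('~/datasets/scene_text', suffix='.jpg')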
+
+ def find_files(pattern):
+     import glob
+     return glob.glob(pattern)
+
+ def read_lines(p):
+     """Return the text in a file as a list of lines."""
+     p = get_absolute_path(p)
+     with open(p, 'r') as f:  # close the handle instead of leaking it
+         return f.readlines()
+
+ def write_lines(p, lines, append_break=False):
+     p = get_absolute_path(p)
+     make_parent_dir(p)
+     with open(p, 'w') as f:
+         for line in lines:
+             if append_break:
+                 f.write(line + '\n')
+             else:
+                 f.write(line)
+
+ def cat(p):
+     """Return the text in a file as a whole."""
+     cmd = 'cat ' + p
+     return subprocess.getoutput(cmd)
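A small round trip with these helpers (paths illustrative, not part of the commit); note that `read_lines` keeps the trailing newlines:

write_lines('~/tmp/results.txt', ['box1', 'box2'], append_break=True)
lines = read_lines('~/tmp/results.txt')  # ['box1\n', 'box2\n']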
+
+ def exists(path):
+     path = get_absolute_path(path)
+     return os.path.exists(path)
+
+ def not_exists(path):
+     return not exists(path)
+
+ def load_mat(path):
+     import scipy.io as sio  # type: ignore
+     path = get_absolute_path(path)
+     return sio.loadmat(path)
+
+ def dump_mat(path, dict_obj, append=True):
+     import scipy.io as sio  # type: ignore
+     path = get_absolute_path(path)
+     make_parent_dir(path)
+     sio.savemat(file_name=path, mdict=dict_obj, appendmat=append)
+
+ def dir_mat(path):
+     '''
+     List the variables in a .mat file.
+     Return a list: [(name, shape, dtype), ...]
+     '''
+     import scipy.io as sio  # type: ignore
+     path = get_absolute_path(path)
+     return sio.whosmat(path)
+
+ SIZE_UNIT_K = 1024
+ SIZE_UNIT_M = SIZE_UNIT_K ** 2
+ SIZE_UNIT_G = SIZE_UNIT_K ** 3
+
+ def get_file_size(path, unit=SIZE_UNIT_K):
+     size = os.path.getsize(get_absolute_path(path))
+     return size * 1.0 / unit
+
+ def create_h5(path):
+     import h5py  # type: ignore
+     path = get_absolute_path(path)
+     make_parent_dir(path)
+     return h5py.File(path, 'w')
+
+ def open_h5(path, mode='r'):
+     import h5py  # type: ignore
+     path = get_absolute_path(path)
+     return h5py.File(path, mode)
+
+ def read_h5(h5, key):
+     return h5[key][:]
+
+ def read_h5_attrs(h5, key, attrs):
+     return h5[key].attrs[attrs]
+
+ def copy(src, dest):
+     make_parent_dir(dest)
+     import shutil
+     shutil.copy(get_absolute_path(src), get_absolute_path(dest))
+
+ cp = copy
+
+ def remove(p):
+     os.remove(get_absolute_path(p))
+
+ rm = remove
+
+ def search(pattern, path, file_only=True):
+     """
+     Search for files whose names match the given pattern. The search scope
+     is the directory and sub-directories of 'path'.
+     """
+     path = get_absolute_path(path)
+     pattern_here = join_path(path, pattern)
+     targets = []
+
+     # find matches in the current directory
+     candidates = find_files(pattern_here)
+     for can in candidates:
+         if is_dir(can) and file_only:
+             continue
+         targets.append(can)
+
+     # recurse into sub-directories
+     files = ls(path)
+     for f in files:
+         fpath = join_path(path, f)
+         if is_dir(fpath):
+             targets.extend(search(pattern, fpath, file_only))
+     return targets
+
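For example (directory name illustrative, not part of the commit), recursively collecting every PNG under an output directory:

pngs = search('*.png', '~/experiments/outputs')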
+ def dump_json(path, data):
+     import ujson as json
+     path = get_absolute_path(path)
+     make_parent_dir(path)
+     with open(path, 'w') as f:
+         json.dump(data, f)
+     return path
+
+ def load_json(path):
+     import ujson as json
+     path = get_absolute_path(path)
+     with open(path, 'r') as f:
+         return json.load(f)
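A JSON round-trip sketch (not part of the commit; requires the `ujson` package, path illustrative):

dump_json('~/tmp/labels.json', {'text': 'hello'})
data = load_json('~/tmp/labels.json')
assert data['text'] == 'hello'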