Weiyun1025 committed on
Commit
e279bc7
·
1 Parent(s): 904d4ae

Upload model

Browse files
Files changed (6) hide show
  1. config.json +38 -0
  2. dcnv3.py +176 -0
  3. dcnv3_func.py +112 -0
  4. intern_image.py +554 -0
  5. intern_image_config.py +42 -0
  6. pytorch_model.bin +3 -0
config.json ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "act_layer": "GELU",
3
+ "architectures": [
4
+ "InternImageModel"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "intern_image_config.InternImageConfig",
8
+ "AutoModel": "intern_image.InternImageModel"
9
+ },
10
+ "channels": 192,
11
+ "cls_scale": 1.5,
12
+ "core_op": "DCNv3_pytorch",
13
+ "depths": [
14
+ 5,
15
+ 5,
16
+ 24,
17
+ 5
18
+ ],
19
+ "drop_path_rate": 0.1,
20
+ "drop_path_type": "linear",
21
+ "drop_rate": 0.0,
22
+ "groups": [
23
+ 12,
24
+ 24,
25
+ 48,
26
+ 96
27
+ ],
28
+ "layer_scale": 1e-05,
29
+ "mlp_ratio": 4.0,
30
+ "model_type": "intern_image",
31
+ "norm_layer": "LN",
32
+ "num_classes": 1000,
33
+ "offset_scale": 2.0,
34
+ "post_norm": true,
35
+ "torch_dtype": "float32",
36
+ "transformers_version": "4.26.1",
37
+ "with_cp": true
38
+ }
dcnv3.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # InternImage
3
+ # Copyright (c) 2022 OpenGVLab
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+
7
+ from __future__ import absolute_import
8
+ from __future__ import print_function
9
+ from __future__ import division
10
+
11
+ import warnings
12
+ from torch import nn
13
+ import torch.nn.functional as F
14
+ from torch.nn.init import xavier_uniform_, constant_
15
+ from .dcnv3_func import dcnv3_core_pytorch
16
+
17
+
18
class to_channels_first(nn.Module):
    """Reorder a 4-D tensor's axes as (0, 3, 1, 2), e.g. NHWC -> NCHW."""

    # No ``__init__`` override: the module has no state, so nn.Module's
    # constructor is sufficient.

    def forward(self, x):
        channels_first = x.permute(0, 3, 1, 2)
        return channels_first
25
+
26
+
27
class to_channels_last(nn.Module):
    """Reorder a 4-D tensor's axes as (0, 2, 3, 1), e.g. NCHW -> NHWC."""

    # No ``__init__`` override: the module has no state, so nn.Module's
    # constructor is sufficient.

    def forward(self, x):
        channels_last = x.permute(0, 2, 3, 1)
        return channels_last
34
+
35
+
36
def build_norm_layer(dim,
                     norm_layer,
                     in_format='channels_last',
                     out_format='channels_last',
                     eps=1e-6):
    """Build a normalization layer wrapped with any required axis permutes.

    :param dim: feature size handed to the normalization module
    :param norm_layer: 'BN' (nn.BatchNorm2d) or 'LN' (nn.LayerNorm)
    :param in_format: 'channels_last' or 'channels_first' layout of the input
    :param out_format: 'channels_last' or 'channels_first' layout of the output
    :param eps: epsilon passed to nn.LayerNorm (not used for 'BN')
    :return: nn.Sequential of [optional permute, norm, optional permute]
    :raises NotImplementedError: for any other ``norm_layer`` value
    """
    if norm_layer not in ('BN', 'LN'):
        raise NotImplementedError(
            f'build_norm_layer does not support {norm_layer}')

    use_bn = norm_layer == 'BN'
    # BatchNorm2d is applied on channels-first data and LayerNorm on
    # channels-last data; the other layout needs a permute on the way in/out.
    foreign_format = 'channels_last' if use_bn else 'channels_first'

    stages = []
    if in_format == foreign_format:
        stages.append(to_channels_first() if use_bn else to_channels_last())
    stages.append(nn.BatchNorm2d(dim) if use_bn
                  else nn.LayerNorm(dim, eps=eps))
    if out_format == foreign_format:
        stages.append(to_channels_last() if use_bn else to_channels_first())
    return nn.Sequential(*stages)
58
+
59
+
60
def build_act_layer(act_layer):
    """Return a new activation module for the given name.

    :param act_layer: 'ReLU', 'SiLU' (both created with inplace=True) or 'GELU'
    :return: the freshly constructed nn.Module
    :raises NotImplementedError: for any other name
    """
    known_activations = (
        ('ReLU', lambda: nn.ReLU(inplace=True)),
        ('SiLU', lambda: nn.SiLU(inplace=True)),
        ('GELU', lambda: nn.GELU()),
    )
    for name, make in known_activations:
        if act_layer == name:
            return make()

    raise NotImplementedError(f'build_act_layer does not support {act_layer}')
69
+
70
+
71
def _is_power_of_2(n):
    """Return True when ``n`` is a positive integer power of two.

    :param n: non-negative int to test
    :return: True for 1, 2, 4, 8, ...; False for 0 and every other value
    :raises ValueError: if ``n`` is not an int or is negative
    """
    if (not isinstance(n, int)) or (n < 0):
        raise ValueError(
            "invalid input for _is_power_of_2: {} (type: {})".format(n, type(n)))

    # The bit test must be parenthesised: ``==`` binds tighter than ``&``, so
    # the unparenthesised ``n & (n-1) == 0`` evaluates ``n & ((n-1) == 0)``
    # and is truthy only for n == 1.
    return n != 0 and (n & (n - 1)) == 0
77
+
78
+
79
class DCNv3_pytorch(nn.Module):
    """DCNv3 operator built on the pure-PyTorch ``dcnv3_core_pytorch`` kernel.

    A depth-wise convolution branch predicts per-group sampling offsets and
    softmax-normalised modulation masks; the linearly projected input is then
    sampled by ``dcnv3_core_pytorch`` and projected once more.
    """

    def __init__(
            self, channels=64, kernel_size=3, stride=1,
            pad=1, dilation=1, group=4, offset_scale=1.0,
            act_layer='GELU', norm_layer='LN'):
        """
        DCNv3 Module
        :param channels: number of input/output channels
        :param kernel_size: side length of the (square) sampling kernel
        :param stride: sampling stride forwarded to the core function
        :param pad: zero padding forwarded to the core function
        :param dilation: kernel dilation forwarded to the core function
        :param group: number of groups; must divide ``channels``
        :param offset_scale: multiplier applied to the sampling offsets
        :param act_layer: activation name for the depth-wise branch
        :param norm_layer: normalization name for the depth-wise branch
        :raises ValueError: if ``channels`` is not divisible by ``group``
        """
        super().__init__()
        if channels % group != 0:
            raise ValueError(
                f'channels must be divisible by group, but got {channels} and {group}')
        _d_per_group = channels // group
        # you'd better set _d_per_group to a power of 2 which is more efficient in our CUDA implementation
        if not _is_power_of_2(_d_per_group):
            warnings.warn(
                "You'd better set channels in DCNv3 to make the dimension of each attention head a power of 2 "
                "which is more efficient in our CUDA implementation.")

        self.offset_scale = offset_scale
        self.channels = channels
        self.kernel_size = kernel_size
        self.stride = stride
        # Honour the constructor argument; it used to be overwritten with a
        # hard-coded 1, silently ignoring the caller's value.
        self.dilation = dilation
        self.pad = pad
        self.group = group
        self.group_channels = _d_per_group

        # Depth-wise conv -> norm (channels-first in, channels-last out) -> act.
        self.dw_conv = nn.Sequential(
            nn.Conv2d(
                channels,
                channels,
                kernel_size=kernel_size,
                stride=1,
                padding=(kernel_size-1)//2,
                groups=channels),
            build_norm_layer(
                channels,
                norm_layer,
                'channels_first',
                'channels_last'),
            build_act_layer(act_layer))
        # Two offset coordinates and one mask weight per group per kernel point.
        self.offset = nn.Linear(
            channels,
            group * kernel_size * kernel_size * 2)
        self.mask = nn.Linear(
            channels,
            group * kernel_size * kernel_size)
        self.input_proj = nn.Linear(channels, channels)
        self.output_proj = nn.Linear(channels, channels)
        self._reset_parameters()

    def _reset_parameters(self):
        """Zero the offset/mask heads; Xavier-init both projections with zero bias."""
        constant_(self.offset.weight.data, 0.)
        constant_(self.offset.bias.data, 0.)
        constant_(self.mask.weight.data, 0.)
        constant_(self.mask.bias.data, 0.)
        xavier_uniform_(self.input_proj.weight.data)
        constant_(self.input_proj.bias.data, 0.)
        xavier_uniform_(self.output_proj.weight.data)
        constant_(self.output_proj.bias.data, 0.)

    def forward(self, input):
        """
        :param input: channels-last feature map (N, H, W, C)
        :return: output (N, H, W, C)
        """
        N, H, W, _ = input.shape

        x = self.input_proj(input)

        # The offset/mask branch works on the raw input, not on the projection.
        x1 = input.permute(0, 3, 1, 2)
        x1 = self.dw_conv(x1)
        offset = self.offset(x1)
        # Softmax over the kernel points of each group, then flatten back.
        mask = self.mask(x1).reshape(N, H, W, self.group, -1)
        mask = F.softmax(mask, -1).reshape(N, H, W, -1)

        x = dcnv3_core_pytorch(
            x, offset, mask,
            self.kernel_size, self.kernel_size,
            self.stride, self.stride,
            self.pad, self.pad,
            self.dilation, self.dilation,
            self.group, self.group_channels,
            self.offset_scale)
        x = self.output_proj(x)

        return x
dcnv3_func.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # InternImage
3
+ # Copyright (c) 2022 OpenGVLab
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+
7
+ from __future__ import absolute_import
8
+ from __future__ import print_function
9
+ from __future__ import division
10
+
11
+ import torch
12
+ import torch.nn.functional as F
13
+
14
+
15
def _get_reference_points(spatial_shapes, device, kernel_h, kernel_w, dilation_h, dilation_w, pad_h=0, pad_w=0, stride_h=1, stride_w=1):
    """Return the normalised centre of every kernel placement.

    :param spatial_shapes: 4-tuple unpacked as (_, H, W, _)
    :param device: device on which the result is created
    :param kernel_h, kernel_w: kernel size
    :param dilation_h, dilation_w: kernel dilation
    :param pad_h, pad_w: accepted for interface compatibility; not used here
    :param stride_h, stride_w: step between neighbouring placements
    :return: float32 tensor of shape (1, H_out, W_out, 1, 2) whose last axis
        holds (x / W, y / H) for each placement
    """
    _, H_, W_, _ = spatial_shapes
    H_out = (H_ - (dilation_h * (kernel_h - 1) + 1)) // stride_h + 1
    W_out = (W_ - (dilation_w * (kernel_w - 1) + 1)) // stride_w + 1

    # Centre of the first placement along each axis (the +0.5 addresses the
    # middle of a pixel).
    first_y = (dilation_h * (kernel_h - 1)) // 2 + 0.5
    first_x = (dilation_w * (kernel_w - 1)) // 2 + 0.5

    ys = torch.linspace(
        first_y,
        first_y + (H_out - 1) * stride_h,
        H_out,
        dtype=torch.float32,
        device=device) / H_
    xs = torch.linspace(
        first_x,
        first_x + (W_out - 1) * stride_w,
        W_out,
        dtype=torch.float32,
        device=device) / W_

    # Broadcast the two 1-D axes into an (H_out, W_out) grid; this matches
    # 'ij'-indexed meshgrid output without relying on torch.meshgrid's
    # deprecated implicit default.
    ref_y = ys.reshape(H_out, 1).expand(H_out, W_out)
    ref_x = xs.reshape(1, W_out).expand(H_out, W_out)

    ref = torch.stack((ref_x, ref_y), -1).reshape(
        1, H_out, W_out, 1, 2)

    return ref
44
+
45
+
46
def _generate_dilation_grids(spatial_shapes, kernel_h, kernel_w, dilation_h, dilation_w, group, device):
    """Return the normalised offsets of the kernel points, repeated per group.

    :param spatial_shapes: 4-tuple unpacked as (_, H, W, _)
    :param kernel_h, kernel_w: kernel size
    :param dilation_h, dilation_w: kernel dilation
    :param group: number of groups the kernel points are repeated for
    :param device: device on which the result is created
    :return: float32 tensor of shape (1, 1, 1, group * kernel_h * kernel_w, 2)
        whose last axis holds (dx / W, dy / H); within a group the points are
        ordered with the x index varying slowest
    """
    _, H_, W_, _ = spatial_shapes

    x_start = -((dilation_w * (kernel_w - 1)) // 2)
    y_start = -((dilation_h * (kernel_h - 1)) // 2)
    xs = torch.linspace(
        x_start,
        x_start + (kernel_w - 1) * dilation_w,
        kernel_w,
        dtype=torch.float32,
        device=device) / W_
    ys = torch.linspace(
        y_start,
        y_start + (kernel_h - 1) * dilation_h,
        kernel_h,
        dtype=torch.float32,
        device=device) / H_

    # Broadcast to a (kernel_w, kernel_h) grid; this matches 'ij'-indexed
    # meshgrid(xs, ys) without relying on torch.meshgrid's deprecated
    # implicit default.
    grid_x = xs.reshape(kernel_w, 1).expand(kernel_w, kernel_h)
    grid_y = ys.reshape(1, kernel_h).expand(kernel_w, kernel_h)

    points = torch.stack((grid_x, grid_y), -1).reshape(
        1, kernel_w * kernel_h, 2)
    # Every group uses the same set of kernel points.
    grid = points.repeat(group, 1, 1)
    grid = grid.reshape(1, 1, 1, group * kernel_h * kernel_w, 2)

    return grid
69
+
70
+
71
def dcnv3_core_pytorch(
        input, offset, mask, kernel_h,
        kernel_w, stride_h, stride_w, pad_h,
        pad_w, dilation_h, dilation_w, group,
        group_channels, offset_scale):
    """Pure-PyTorch DCNv3 sampling (reference implementation).

    :param input: channels-last features (N, H, W, group * group_channels)
    :param offset: sampling offsets (N, H_out, W_out, group * kernel_h * kernel_w * 2)
    :param mask: modulation weights (N, H_out, W_out, group * kernel_h * kernel_w)
    :param kernel_h, kernel_w: kernel size
    :param stride_h, stride_w: sampling stride
    :param pad_h, pad_w: zero padding applied to the H and W axes of ``input``
    :param dilation_h, dilation_w: kernel dilation
    :param group: number of groups
    :param group_channels: channels per group
    :param offset_scale: multiplier applied to the kernel grid and the offsets
    :return: contiguous tensor (N, H_out, W_out, group * group_channels)
    """
    # for debug and test only,
    # need to use cuda version instead

    # F.pad consumes (before, after) pairs starting from the LAST dimension,
    # so for an (N, H, W, C) tensor the order is C, then W, then H.
    input = F.pad(
        input,
        [0, 0, pad_w, pad_w, pad_h, pad_h])
    N_, H_in, W_in, _ = input.shape
    _, H_out, W_out, _ = offset.shape

    ref = _get_reference_points(
        input.shape, input.device, kernel_h, kernel_w, dilation_h, dilation_w, pad_h, pad_w, stride_h, stride_w)
    grid = _generate_dilation_grids(
        input.shape, kernel_h, kernel_w, dilation_h, dilation_w, group, input.device)
    # (W, H) pairs tiled once per sampling point; converts the pixel-unit
    # offsets to the same normalised units as ``ref`` and ``grid``.
    spatial_norm = torch.tensor([W_in, H_in], device=input.device).reshape(1, 1, 1, 2).\
        repeat(1, 1, 1, group*kernel_h*kernel_w)

    sampling_locations = (ref + grid * offset_scale).repeat(N_, 1, 1, 1, 1).flatten(3, 4) + \
        offset * offset_scale / spatial_norm

    P_ = kernel_h * kernel_w
    # grid_sample expects coordinates in [-1, 1].
    sampling_grids = 2 * sampling_locations - 1
    # N_, H_in, W_in, group*group_channels -> N_, H_in*W_in, group*group_channels -> N_, group*group_channels, H_in*W_in -> N_*group, group_channels, H_in, W_in
    input_ = input.reshape(N_, H_in*W_in, group*group_channels).transpose(1, 2).\
        reshape(N_*group, group_channels, H_in, W_in)
    # N_, H_out, W_out, group*P_*2 -> N_, H_out*W_out, group, P_, 2 -> N_, group, H_out*W_out, P_, 2 -> N_*group, H_out*W_out, P_, 2
    sampling_grid_ = sampling_grids.reshape(N_, H_out*W_out, group, P_, 2).transpose(1, 2).\
        flatten(0, 1)
    # N_*group, group_channels, H_out*W_out, P_
    sampling_input_ = F.grid_sample(
        input_, sampling_grid_, mode='bilinear', padding_mode='zeros', align_corners=False)

    # (N_, H_out, W_out, group*P_) -> N_, H_out*W_out, group, P_ -> (N_, group, H_out*W_out, P_) -> (N_*group, 1, H_out*W_out, P_)
    mask = mask.reshape(N_, H_out*W_out, group, P_).transpose(1, 2).\
        reshape(N_*group, 1, H_out*W_out, P_)
    # Weighted sum over the P_ sampling points of every output location.
    output = (sampling_input_ * mask).sum(-1).view(N_,
                                                   group*group_channels, H_out*W_out)

    return output.transpose(1, 2).reshape(N_, H_out, W_out, -1).contiguous()
intern_image.py ADDED
@@ -0,0 +1,554 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # InternImage
3
+ # Copyright (c) 2022 OpenGVLab
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ import torch.utils.checkpoint as checkpoint
11
+ from transformers import PreTrainedModel
12
+ from timm.models.layers import trunc_normal_, DropPath
13
+ from .intern_image_config import InternImageConfig
14
+ from .dcnv3 import DCNv3_pytorch
15
+
16
+
17
class to_channels_first(nn.Module):
    """Stateless module that permutes a 4-D tensor with axis order (0, 3, 1, 2)."""

    # The default nn.Module constructor is enough; there is nothing to set up.

    def forward(self, x):
        permuted = x.permute(0, 3, 1, 2)
        return permuted
24
+
25
+
26
class to_channels_last(nn.Module):
    """Stateless module that permutes a 4-D tensor with axis order (0, 2, 3, 1)."""

    # The default nn.Module constructor is enough; there is nothing to set up.

    def forward(self, x):
        permuted = x.permute(0, 2, 3, 1)
        return permuted
33
+
34
+
35
def build_norm_layer(dim,
                     norm_layer,
                     in_format='channels_last',
                     out_format='channels_last',
                     eps=1e-6):
    """Create a norm module, adding permutes so it accepts/returns the given layouts.

    Args:
        dim (int): feature size passed to the norm module
        norm_layer (str): 'BN' for nn.BatchNorm2d or 'LN' for nn.LayerNorm
        in_format (str): 'channels_last' or 'channels_first'
        out_format (str): 'channels_last' or 'channels_first'
        eps (float): epsilon for nn.LayerNorm; not used for 'BN'

    Returns:
        nn.Sequential holding [optional permute, norm, optional permute].

    Raises:
        NotImplementedError: if ``norm_layer`` is neither 'BN' nor 'LN'.
    """
    if norm_layer not in ('BN', 'LN'):
        raise NotImplementedError(
            f'build_norm_layer does not support {norm_layer}')

    batch_norm = norm_layer == 'BN'
    # Layout that needs converting: BatchNorm2d runs channels-first, LayerNorm
    # runs channels-last.
    needs_permute = 'channels_last' if batch_norm else 'channels_first'

    pipeline = []
    if in_format == needs_permute:
        pipeline.append(
            to_channels_first() if batch_norm else to_channels_last())
    pipeline.append(
        nn.BatchNorm2d(dim) if batch_norm else nn.LayerNorm(dim, eps=eps))
    if out_format == needs_permute:
        pipeline.append(
            to_channels_last() if batch_norm else to_channels_first())
    return nn.Sequential(*pipeline)
57
+
58
+
59
def build_act_layer(act_layer):
    """Instantiate the activation module named by ``act_layer``.

    Args:
        act_layer (str): 'ReLU' or 'SiLU' (both built with inplace=True), or 'GELU'

    Returns:
        A new nn.Module instance.

    Raises:
        NotImplementedError: for any other name.
    """
    builders = (
        ('ReLU', lambda: nn.ReLU(inplace=True)),
        ('SiLU', lambda: nn.SiLU(inplace=True)),
        ('GELU', lambda: nn.GELU()),
    )
    for candidate, build in builders:
        if act_layer == candidate:
            return build()

    raise NotImplementedError(f'build_act_layer does not support {act_layer}')
68
class StemLayer(nn.Module):
    r""" Stem layer of InternImage

    Two 3x3, stride-2 convolutions; the first is followed by a norm and an
    activation, the second by a norm whose output is channels-last.

    Args:
        in_chans (int): number of input channels
        out_chans (int): number of output channels
        act_layer (str): activation layer
        norm_layer (str): normalization layer
    """

    def __init__(self,
                 in_chans=3,
                 out_chans=96,
                 act_layer='GELU',
                 norm_layer='BN'):
        super().__init__()
        mid_chans = out_chans // 2
        conv_kwargs = dict(kernel_size=3, stride=2, padding=1)

        self.conv1 = nn.Conv2d(in_chans, mid_chans, **conv_kwargs)
        self.norm1 = build_norm_layer(mid_chans, norm_layer,
                                      'channels_first', 'channels_first')
        self.act = build_act_layer(act_layer)
        self.conv2 = nn.Conv2d(mid_chans, out_chans, **conv_kwargs)
        self.norm2 = build_norm_layer(out_chans, norm_layer,
                                      'channels_first', 'channels_last')

    def forward(self, x):
        # conv -> norm -> act -> conv -> norm, applied in sequence.
        for stage in (self.conv1, self.norm1, self.act, self.conv2,
                      self.norm2):
            x = stage(x)
        return x
106
+
107
+
108
class DownsampleLayer(nn.Module):
    r""" Downsample layer of InternImage

    A bias-free 3x3, stride-2 convolution that doubles the channel count,
    followed by a norm; input and output are channels-last.

    Args:
        channels (int): number of input channels
        norm_layer (str): normalization layer
    """

    def __init__(self, channels, norm_layer='LN'):
        super().__init__()
        out_channels = 2 * channels
        self.conv = nn.Conv2d(channels,
                              out_channels,
                              kernel_size=3,
                              stride=2,
                              padding=1,
                              bias=False)
        self.norm = build_norm_layer(out_channels, norm_layer,
                                     'channels_first', 'channels_last')

    def forward(self, x):
        # The convolution needs channels-first input; the norm wrapper
        # converts back to channels-last.
        channels_first = x.permute(0, 3, 1, 2)
        return self.norm(self.conv(channels_first))
130
+
131
+
132
class MLPLayer(nn.Module):
    r""" MLP layer of InternImage

    Linear -> activation -> dropout -> Linear -> dropout.

    Args:
        in_features (int): number of input features
        hidden_features (int): number of hidden features; falls back to
            ``in_features`` when falsy
        out_features (int): number of output features; falls back to
            ``in_features`` when falsy
        act_layer (str): activation layer
        drop (float): dropout rate
    """

    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_layer='GELU',
                 drop=0.):
        super().__init__()
        hidden_dim = hidden_features or in_features
        out_dim = out_features or in_features

        self.fc1 = nn.Linear(in_features, hidden_dim)
        self.act = build_act_layer(act_layer)
        self.fc2 = nn.Linear(hidden_dim, out_dim)
        # One dropout module, applied after the activation and after fc2.
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
163
+
164
+
165
class InternImageLayer(nn.Module):
    r""" Basic layer of InternImage

    Two residual branches: the core operator, then the MLP. Each branch is
    normalised either before (pre-norm) or after (post-norm) its operator and
    may be multiplied by a learnable per-channel scale.

    Args:
        core_op (nn.Module): core operation of InternImage
        channels (int): number of input channels
        groups (int): number of groups passed to the core operation
        mlp_ratio (float): ratio of mlp hidden features to input channels
        drop (float): dropout rate
        drop_path (float): drop path rate
        act_layer (str): activation layer
        norm_layer (str): normalization layer passed to the core operation
        post_norm (bool): whether to use post normalization
        layer_scale (float): initial value of the layer scale; None disables it
        offset_scale (float): offset scale
        with_cp (bool): whether to use checkpoint
    """

    def __init__(self,
                 core_op,
                 channels,
                 groups,
                 mlp_ratio=4.,
                 drop=0.,
                 drop_path=0.,
                 act_layer='GELU',
                 norm_layer='LN',
                 post_norm=False,
                 layer_scale=None,
                 offset_scale=1.0,
                 with_cp=False):
        super().__init__()
        self.channels = channels
        self.groups = groups
        self.mlp_ratio = mlp_ratio
        self.with_cp = with_cp

        self.norm1 = build_norm_layer(channels, 'LN')
        self.post_norm = post_norm
        self.dcn = core_op(channels=channels,
                           kernel_size=3,
                           stride=1,
                           pad=1,
                           dilation=1,
                           group=groups,
                           offset_scale=offset_scale,
                           act_layer=act_layer,
                           norm_layer=norm_layer)
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = build_norm_layer(channels, 'LN')
        self.mlp = MLPLayer(in_features=channels,
                            hidden_features=int(channels * mlp_ratio),
                            act_layer=act_layer,
                            drop=drop)
        # ``layer_scale`` is kept as a flag; the values live in gamma1/gamma2.
        self.layer_scale = layer_scale is not None
        if self.layer_scale:
            self.gamma1 = nn.Parameter(layer_scale * torch.ones(channels),
                                       requires_grad=True)
            self.gamma2 = nn.Parameter(layer_scale * torch.ones(channels),
                                       requires_grad=True)

    def forward(self, x):

        def _inner_forward(x):
            branches = (('gamma1', self.norm1, self.dcn),
                        ('gamma2', self.norm2, self.mlp))
            for gamma_name, norm, op in branches:
                # post-norm: norm(op(x)); pre-norm: op(norm(x))
                update = norm(op(x)) if self.post_norm else op(norm(x))
                if self.layer_scale:
                    update = getattr(self, gamma_name) * update
                x = x + self.drop_path(update)
            return x

        if self.with_cp and x.requires_grad:
            return checkpoint.checkpoint(_inner_forward, x)
        return _inner_forward(x)
250
+
251
+
252
class InternImageBlock(nn.Module):
    r""" Block of InternImage

    A stack of ``depth`` InternImageLayer modules, an extra norm when
    ``post_norm`` is off, and an optional DownsampleLayer.

    Args:
        core_op (nn.Module): core operation of InternImage
        channels (int): number of input channels
        depth (int): number of layers in this block
        groups (int): number of groups used by every layer
        downsample (bool): whether to append a downsample layer
        mlp_ratio (float): ratio of mlp hidden features to input channels
        drop (float): dropout rate
        drop_path (float | list): drop path rate, or one rate per layer
        act_layer (str): activation layer
        norm_layer (str): normalization layer
        post_norm (bool): whether to use post normalization
        offset_scale (float): offset scale
        layer_scale (float): layer scale
        with_cp (bool): whether to use checkpoint
    """

    def __init__(self,
                 core_op,
                 channels,
                 depth,
                 groups,
                 downsample=True,
                 mlp_ratio=4.,
                 drop=0.,
                 drop_path=0.,
                 act_layer='GELU',
                 norm_layer='LN',
                 post_norm=False,
                 offset_scale=1.0,
                 layer_scale=None,
                 with_cp=False):
        super().__init__()
        self.channels = channels
        self.depth = depth
        self.post_norm = post_norm

        # A list supplies one rate per layer; a scalar is shared by all.
        per_layer_rates = drop_path if isinstance(drop_path, list) \
            else [drop_path] * depth

        layers = []
        for index in range(depth):
            layers.append(
                InternImageLayer(core_op=core_op,
                                 channels=channels,
                                 groups=groups,
                                 mlp_ratio=mlp_ratio,
                                 drop=drop,
                                 drop_path=per_layer_rates[index],
                                 act_layer=act_layer,
                                 norm_layer=norm_layer,
                                 post_norm=post_norm,
                                 layer_scale=layer_scale,
                                 offset_scale=offset_scale,
                                 with_cp=with_cp))
        self.blocks = nn.ModuleList(layers)

        if not self.post_norm:
            self.norm = build_norm_layer(channels, 'LN')

        if downsample:
            self.downsample = DownsampleLayer(channels=channels,
                                              norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x, return_wo_downsample=False):
        for layer in self.blocks:
            x = layer(x)
        if not self.post_norm:
            x = self.norm(x)

        # Keep a handle on the features before any downsampling.
        before_downsample = x
        if self.downsample is not None:
            x = self.downsample(x)

        if return_wo_downsample:
            return x, before_downsample
        return x
323
+
324
+
325
class InternImage(nn.Module):
    r""" InternImage
        A PyTorch impl of : `InternImage: Exploring Large-Scale Vision Foundation Models with Deformable Convolutions` -
          https://arxiv.org/abs/2211.05778
    Args:
        core_op (str): Core operator; only 'DCNv3_pytorch' is accepted. Default: 'DCNv3_pytorch'
        channels (int): Number of channels of the first stage. Default: 64
        depths (list): Depth of each block. Default: [3, 4, 18, 5]
        groups (list): Groups of each block. Default: [3, 6, 12, 24]
        num_classes (int): Number of classes; values <= 0 replace the head with nn.Identity. Default: 1000
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
        drop_rate (float): Probability of an element to be zeroed. Default: 0.
        drop_path_rate (float): Stochastic depth rate. Default: 0.2
        drop_path_type (str): 'uniform' uses ``drop_path_rate`` for every layer; any other
            value ramps linearly from 0 to ``drop_path_rate``. Default: 'linear'
        act_layer (str): Activation layer. Default: 'GELU'
        norm_layer (str): Normalization layer. Default: 'LN'
        layer_scale (float): Initial layer-scale value; None disables layer scale. Default: None
        offset_scale (float): Offset scale passed to the core operator. Default: 1.0
        post_norm (bool): Whether to use post normalization. Default: False
        cls_scale (float): Width multiplier of the classification head. Default: 1.5
        with_cp (bool): Use checkpoint or not. Default: False
    """

    # NOTE: the list defaults below are only read, never mutated.
    def __init__(self,
                 core_op='DCNv3_pytorch',
                 channels=64,
                 depths=[3, 4, 18, 5],
                 groups=[3, 6, 12, 24],
                 num_classes=1000,
                 mlp_ratio=4.,
                 drop_rate=0.,
                 drop_path_rate=0.2,
                 drop_path_type='linear',
                 act_layer='GELU',
                 norm_layer='LN',
                 layer_scale=None,
                 offset_scale=1.0,
                 post_norm=False,
                 cls_scale=1.5,
                 with_cp=False,
                 **kwargs):
        super().__init__()
        assert core_op == 'DCNv3_pytorch', \
            f'unsupported core_op: {core_op}'
        core_op = DCNv3_pytorch

        self.core_op = core_op
        self.num_classes = num_classes
        self.num_levels = len(depths)
        self.depths = depths
        self.channels = channels
        # The channel count doubles at every level after the first.
        self.num_features = int(channels * 2**(self.num_levels - 1))
        self.post_norm = post_norm
        self.mlp_ratio = mlp_ratio
        print(f'using core type: {core_op}')
        print(f'using activation layer: {act_layer}')
        print(f'using main norm layer: {norm_layer}')
        print(f'using dpr: {drop_path_type}, {drop_path_rate}')

        in_chans = 3
        self.patch_embed = StemLayer(in_chans=in_chans,
                                     out_chans=channels,
                                     act_layer=act_layer,
                                     norm_layer=norm_layer)
        self.pos_drop = nn.Dropout(p=drop_rate)

        # One drop-path rate per layer across all levels.
        total_depth = sum(depths)
        if drop_path_type == 'uniform':
            dpr = [drop_path_rate] * total_depth
        else:
            dpr = [
                x.item()
                for x in torch.linspace(0, drop_path_rate, total_depth)
            ]

        self.levels = nn.ModuleList()
        for i in range(self.num_levels):
            level = InternImageBlock(
                core_op=core_op,
                channels=int(channels * 2**i),
                depth=depths[i],
                groups=groups[i],
                mlp_ratio=self.mlp_ratio,
                drop=drop_rate,
                drop_path=dpr[sum(depths[:i]):sum(depths[:i + 1])],
                act_layer=act_layer,
                norm_layer=norm_layer,
                post_norm=post_norm,
                # every level except the last one downsamples
                downsample=(i < self.num_levels - 1),
                layer_scale=layer_scale,
                offset_scale=offset_scale,
                with_cp=with_cp)
            self.levels.append(level)

        head_channels = int(self.num_features * cls_scale)
        self.conv_head = nn.Sequential(
            nn.Conv2d(self.num_features,
                      head_channels,
                      kernel_size=1,
                      bias=False),
            build_norm_layer(head_channels, 'BN',
                             'channels_first', 'channels_first'),
            build_act_layer(act_layer))
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.head = nn.Linear(head_channels, num_classes) \
            if num_classes > 0 else nn.Identity()
        self.num_layers = len(depths)
        self.apply(self._init_weights)
        # Runs second so the core operators' own initialisation overrides the
        # generic Linear initialisation above.
        self.apply(self._init_deform_weights)

    def _init_weights(self, m):
        """Truncated-normal init for Linear weights; constant init for LayerNorm."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def _init_deform_weights(self, m):
        """Re-run the core operator's own parameter reset."""
        if isinstance(m, self.core_op):
            m._reset_parameters()

    @torch.jit.ignore
    def lr_decay_keywards(self, decay_ratio=0.87):
        """Map parameter-name prefixes to layer-wise learning-rate multipliers.

        The last layer of the last level gets 1.0 and each earlier layer is
        multiplied by a further ``decay_ratio``. ``patch_embed`` shares the
        ratio of the first layer; each level's ``downsample``/``norm`` share
        the ratio of the first layer of the following level.

        Works for any number of levels (it was previously hard-coded to 4).
        """
        lr_ratios = {}

        # blocks, walked from the last layer of the last level backwards
        idx = 0
        for layer_num in reversed(range(self.num_levels)):
            for block_num in reversed(range(self.depths[layer_num])):
                tag = 'levels.{}.blocks.{}.'.format(layer_num, block_num)
                decay = 1.0 * (decay_ratio**idx)
                lr_ratios[tag] = decay
                idx += 1
        # patch_embed (before stage-1)
        lr_ratios["patch_embed"] = lr_ratios['levels.0.blocks.0.']
        # levels.<i>.downsample / norm sit between stage i+1 and stage i+2
        for i in range(self.num_levels - 1):
            next_ratio = lr_ratios['levels.{}.blocks.0.'.format(i + 1)]
            lr_ratios["levels.{}.downsample".format(i)] = next_ratio
            lr_ratios["levels.{}.norm".format(i)] = next_ratio
        return lr_ratios

    def forward_features(self, x):
        """Return pooled, flattened features produced by the conv head."""
        x = self.patch_embed(x)
        x = self.pos_drop(x)

        for level in self.levels:
            x = level(x)

        # Levels emit channels-last maps; the conv head needs channels-first.
        x = self.conv_head(x.permute(0, 3, 1, 2))
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return x

    def forward_features_seq_out(self, x):
        """Return a list with each level's output taken before downsampling."""
        x = self.patch_embed(x)
        x = self.pos_drop(x)

        seq_out = []
        for level in self.levels:
            x, x_ = level(x, return_wo_downsample=True)
            seq_out.append(x_)
        return seq_out

    def forward(self, x):
        """Return ``head(forward_features(x))``."""
        x = self.forward_features(x)
        x = self.head(x)
        return x
494
+
495
+
496
class InternImageModel(PreTrainedModel):
    """Hugging Face wrapper whose forward returns InternImage's ``forward_features`` output."""

    config_class = InternImageConfig

    def __init__(self, config):
        super().__init__(config)
        # Config attributes forwarded one-to-one as InternImage keyword arguments.
        option_names = (
            'core_op', 'channels', 'depths', 'groups', 'num_classes',
            'mlp_ratio', 'drop_rate', 'drop_path_rate', 'drop_path_type',
            'act_layer', 'norm_layer', 'layer_scale', 'offset_scale',
            'post_norm', 'cls_scale', 'with_cp',
        )
        options = {name: getattr(config, name) for name in option_names}
        self.model = InternImage(**options)

    def forward(self, tensor):
        features = self.model.forward_features(tensor)
        return features
522
+
523
class InternImageModelForImageClassification(PreTrainedModel):
    """Hugging Face wrapper returning InternImage logits and, given labels, a cross-entropy loss."""

    config_class = InternImageConfig

    def __init__(self, config):
        super().__init__(config)
        # Config attributes forwarded one-to-one as InternImage keyword arguments.
        option_names = (
            'core_op', 'channels', 'depths', 'groups', 'num_classes',
            'mlp_ratio', 'drop_rate', 'drop_path_rate', 'drop_path_type',
            'act_layer', 'norm_layer', 'layer_scale', 'offset_scale',
            'post_norm', 'cls_scale', 'with_cp',
        )
        options = {name: getattr(config, name) for name in option_names}
        self.model = InternImage(**options)

    def forward(self, tensor, labels=None):
        logits = self.model(tensor)

        if labels is None:
            return {'logits': logits}

        return {'loss': F.cross_entropy(logits, labels), 'logits': logits}
intern_image_config.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
class InternImageConfig(PretrainedConfig):
    """Configuration holding the constructor arguments of the InternImage model.

    Every keyword below is stored as an attribute of the same name; any
    remaining keyword arguments are passed on to ``PretrainedConfig``.
    """

    model_type = "intern_image"

    def __init__(
        self,
        core_op='DCNv3_pytorch',
        channels=64,
        depths=(4, 4, 18, 4),
        groups=(4, 8, 16, 32),
        num_classes=1000,
        mlp_ratio=4.,
        drop_rate=0.,
        drop_path_rate=0.1,
        drop_path_type='linear',
        act_layer='GELU',
        norm_layer='LN',
        layer_scale=None,
        offset_scale=1.0,
        post_norm=False,
        cls_scale=1.5,
        with_cp=False,
        **kwargs,
    ):
        model_options = dict(
            core_op=core_op,
            channels=channels,
            depths=depths,
            groups=groups,
            num_classes=num_classes,
            mlp_ratio=mlp_ratio,
            drop_rate=drop_rate,
            drop_path_rate=drop_path_rate,
            drop_path_type=drop_path_type,
            act_layer=act_layer,
            norm_layer=norm_layer,
            layer_scale=layer_scale,
            offset_scale=offset_scale,
            post_norm=post_norm,
            cls_scale=cls_scale,
            with_cp=with_cp,
        )
        # The model options are stored before the base-class initialiser runs.
        for option_name, option_value in model_options.items():
            setattr(self, option_name, option_value)
        super().__init__(**kwargs)
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e57f1a2e26743105b697ee199d0abde18307fa9475fecb361d19d39883b9ecb9
3
+ size 1339575003