weidai00 committed (verified)
Commit 1088f5d · Parent: 4cc70e2

Update AV/models/layers.py
Files changed (1)
  1. AV/models/layers.py +674 -674
AV/models/layers.py CHANGED
@@ -1,674 +1,674 @@
636
- # patch_self_sim_map = self.gamma_patch_self * patch_self_sim_map
637
- patch_self_sim_map = 1 * patch_self_sim_map
644
- # patch_global_sim_map = self.gamma_patch_global * patch_global_sim_map
645
- patch_global_sim_map = 1 * patch_global_sim_map
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import torch
4
+ from torch import nn
5
+ import torch.nn.functional as F
6
+ # from timm.models.layers.cbam import CbamModule
7
+ import numpy as np
8
+ from einops import rearrange, repeat
9
+ import math
10
+
11
+
12
+ class ConvBn2d(nn.Module):
13
+ def __init__(self, in_channels, out_channels, kernel_size, padding):
14
+ super(ConvBn2d, self).__init__()
15
+ self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding)
16
+ self.bn = nn.BatchNorm2d(out_channels)
17
+
18
+ def forward(self, x):
19
+ x = self.conv(x)
20
+ x = self.bn(x)
21
+ return x
22
+
23
+
24
+ class sSE(nn.Module):
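+ # Spatial squeeze-and-excitation: a 1x1 conv (with BN) followed by sigmoid yields a single-channel spatial attention map.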
25
+ def __init__(self, out_channels):
26
+ super(sSE, self).__init__()
27
+ self.conv = ConvBn2d(in_channels=out_channels, out_channels=1, kernel_size=1, padding=0)
28
+
29
+ def forward(self, x):
30
+ x = self.conv(x)
31
+ # print('spatial',x.size())
32
+ x = torch.sigmoid(x)
33
+ return x
34
+
35
+
36
+ class cSE(nn.Module):
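+ # Channel squeeze-and-excitation: global average pooling, a 1x1-conv bottleneck (C -> C/2 -> C) and sigmoid yield per-channel weights.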
37
+ def __init__(self, out_channels):
38
+ super(cSE, self).__init__()
39
+ self.conv1 = ConvBn2d(in_channels=out_channels, out_channels=int(out_channels / 2), kernel_size=1, padding=0)
40
+ self.conv2 = ConvBn2d(in_channels=int(out_channels / 2), out_channels=out_channels, kernel_size=1, padding=0)
41
+
42
+ def forward(self, x):
43
+ x = nn.AvgPool2d(x.size()[2:])(x)
44
+ # print('channel',x.size())
45
+ x = self.conv1(x)
46
+ x = F.relu(x)
47
+ x = self.conv2(x)
48
+ x = torch.sigmoid(x)
49
+ return x
50
+
51
+
52
+ class scSEBlock(nn.Module):
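+ # Concurrent spatial and channel SE: the input is scaled by both gates and the two results are summed.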
53
+ def __init__(self, out_channels):
54
+ super(scSEBlock, self).__init__()
55
+ self.spatial_gate = sSE(out_channels)
56
+ self.channel_gate = cSE(out_channels)
57
+
58
+ def forward(self, x):
59
+ g1 = self.spatial_gate(x)
60
+ g2 = self.channel_gate(x)
61
+ x = g1 * x + g2 * x
62
+ return x
63
+
64
+
65
+ class SaveFeatures():
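+ # Registers a forward hook on a module and keeps its most recent output, reshaping token outputs of shape (B, L, C) into (B, C, H, W).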
66
+ features = None
67
+
68
+ def __init__(self, m):
69
+ self.hook = m.register_forward_hook(self.hook_fn)
70
+
71
+ def hook_fn(self, module, input, output):
72
+ # print('input',input)
73
+ # print('output',output.size())
74
+ if len(output.shape) == 3:
75
+ B, L, C = output.shape
76
+ h = int(L ** 0.5)
77
+ output = output.view(B, h, h, C)
78
+
79
+ output = output.permute(0, 3, 1, 2).contiguous()
80
+ if len(output.shape) == 4 and output.shape[2] != output.shape[3]:
81
+ output = output.permute(0, 3, 1, 2).contiguous()
82
+ # print(module)
83
+ self.features = output
84
+
85
+ def remove(self):
86
+ self.hook.remove()
87
+
88
+
89
+ class DBlock(nn.Module):
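+ # Decoder block: upsamples x by 2 when its channel count differs from the skip tensor, applies attention and a 3x3 conv, concatenates the skip connection, then refines with two more 3x3 convs.
+ # Note: the 'cbam' branch for attention2 references CbamModule, whose import is commented out at the top of this file.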
90
+
91
+ def __init__(self, in_channels, out_channels, use_batchnorm=True, attention_type=None):
92
+
93
+ super(DBlock, self).__init__()
94
+
95
+ self.conv1 = nn.Sequential(
96
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=1, bias=False),
97
+ nn.BatchNorm2d(out_channels),
98
+ nn.ReLU(inplace=True),
99
+ )
100
+
101
+ if attention_type == 'scse':
102
+ self.attention1 = scSEBlock(in_channels)
103
+ elif attention_type == 'cbam':
104
+ self.attention1 = nn.Identity()
105
+
106
+ elif attention_type == 'transformer':
107
+
108
+ self.attention1 = nn.Identity()
109
+
110
+
111
+ else:
112
+ self.attention1 = nn.Identity()
113
+
114
+ self.conv2 = \
115
+ nn.Sequential(
116
+ nn.Conv2d(out_channels * 2, out_channels, kernel_size=3, padding=1, stride=1, bias=False),
117
+ nn.BatchNorm2d(out_channels),
118
+ nn.ReLU(inplace=True),
119
+ )
120
+
121
+ self.conv3 = nn.Sequential(
122
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, stride=1, bias=False),
123
+ nn.BatchNorm2d(out_channels),
124
+ nn.ReLU(inplace=True),
125
+ )
126
+ if attention_type == 'scse':
127
+ self.attention2 = scSEBlock(out_channels)
128
+ elif attention_type == 'cbam':
129
+ self.attention2 = CbamModule(channels=out_channels)
130
+
131
+ elif attention_type == 'transformer':
132
+ self.attention2 = nn.Identity()
133
+
134
+ else:
135
+ self.attention2 = nn.Identity()
136
+
137
+ def forward(self, x, skip):
138
+ if x.shape[1] != skip.shape[1]:
139
+ x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
140
+
141
+ # print(x.shape,skip.shape)
142
+ x = self.attention1(x)
143
+ x = self.conv1(x)
144
+
145
+ x = torch.cat([x, skip], dim=1)
146
+
147
+ x = self.conv2(x)
148
+ x = self.conv3(x)
149
+ x = self.attention2(x)
150
+
151
+ return x
152
+
153
+
154
+ class DBlock_res(nn.Module):
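+ # Same structure as DBlock, but always upsamples by 2x and uses convolutions with bias.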
155
+
156
+ def __init__(self, in_channels, out_channels, use_batchnorm=True, attention_type=None):
157
+
158
+ super(DBlock_res, self).__init__()
159
+
160
+ self.conv1 = nn.Sequential(
161
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=1),
162
+ nn.BatchNorm2d(out_channels),
163
+ nn.ReLU(inplace=True),
164
+ )
165
+
166
+ if attention_type == 'scse':
167
+ self.attention1 = scSEBlock(in_channels)
168
+ elif attention_type == 'cbam':
169
+ self.attention1 = CbamModule(channels=in_channels)
170
+
171
+ elif attention_type == 'transformer':
172
+
173
+ self.attention1 = nn.Identity()
174
+
175
+
176
+ else:
177
+ self.attention1 = nn.Identity()
178
+
179
+ self.conv2 = \
180
+ nn.Sequential(
181
+ nn.Conv2d(out_channels * 2, out_channels, kernel_size=3, padding=1, stride=1),
182
+ nn.BatchNorm2d(out_channels),
183
+ nn.ReLU(inplace=True),
184
+ )
185
+
186
+ self.conv3 = nn.Sequential(
187
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, stride=1),
188
+ nn.BatchNorm2d(out_channels),
189
+ nn.ReLU(inplace=True),
190
+ )
191
+ if attention_type == 'scse':
192
+ self.attention2 = scSEBlock(out_channels)
193
+ elif attention_type == 'cbam':
194
+ self.attention2 = CbamModule(channels=out_channels)
195
+
196
+ elif attention_type == 'transformer':
197
+ self.attention2 = nn.Identity()
198
+
199
+ else:
200
+ self.attention2 = nn.Identity()
201
+
202
+ def forward(self, x, skip):
203
+
204
+ x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
205
+
206
+ # print(x.shape,skip.shape)
207
+ x = self.attention1(x)
208
+ x = self.conv1(x)
209
+
210
+ x = torch.cat([x, skip], dim=1)
211
+
212
+ x = self.conv2(x)
213
+ x = self.conv3(x)
214
+ x = self.attention2(x)
215
+
216
+ return x
217
+
218
+
219
+ class DBlock_att(nn.Module):
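+ # DBlock variant that defaults to attention_type='transformer'; with that setting both attention slots are identity placeholders.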
220
+
221
+ def __init__(self, in_channels, out_channels, use_batchnorm=True, attention_type='transformer'):
222
+
223
+ super(DBlock_att, self).__init__()
224
+
225
+ self.conv1 = nn.Sequential(
226
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=1, bias=False),
227
+ nn.BatchNorm2d(out_channels),
228
+ nn.ReLU(inplace=True),
229
+ )
230
+
231
+ if attention_type == 'scse':
232
+ self.attention1 = scSEBlock(in_channels)
233
+ elif attention_type == 'cbam':
234
+ self.attention1 = CbamModule(channels=in_channels)
235
+
236
+ elif attention_type == 'transformer':
237
+
238
+ self.attention1 = nn.Identity()
239
+
240
+
241
+ else:
242
+ self.attention1 = nn.Identity()
243
+
244
+ self.conv2 = \
245
+ nn.Sequential(
246
+ nn.Conv2d(out_channels * 2, out_channels, kernel_size=3, padding=1, stride=1, bias=False),
247
+ nn.BatchNorm2d(out_channels),
248
+ nn.ReLU(inplace=True),
249
+ )
250
+
251
+ self.conv3 = nn.Sequential(
252
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, stride=1, bias=False),
253
+ nn.BatchNorm2d(out_channels),
254
+ nn.ReLU(inplace=True),
255
+ )
256
+ if attention_type == 'scse':
257
+ self.attention2 = scSEBlock(out_channels)
258
+ elif attention_type == 'cbam':
259
+ self.attention2 = CbamModule(channels=out_channels)
260
+
261
+ elif attention_type == 'transformer':
262
+ self.attention2 = nn.Identity()
263
+
264
+ else:
265
+ self.attention2 = nn.Identity()
266
+
267
+ def forward(self, x, skip):
268
+ if x.shape[1] != skip.shape[1]:
269
+ x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
270
+
271
+ # print(x.shape,skip.shape)
272
+ x = self.attention1(x)
273
+ x = self.conv1(x)
274
+
275
+ x = torch.cat([x, skip], dim=1)
276
+ x = self.conv2(x)
277
+ x = self.conv3(x)
278
+
279
+ x = self.attention2(x)
280
+
281
+ return x
282
+
283
+
284
+ class SegmentationHead(nn.Module):
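+ # Final head: bilinear upsampling (if upsample > 1) followed by a single conv that maps to num_class channels.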
285
+ def __init__(self, in_channels, num_class, kernel_size=3, upsample=4):
286
+ super(SegmentationHead, self).__init__()
287
+ self.upsample = nn.UpsamplingBilinear2d(scale_factor=upsample) if upsample > 1 else nn.Identity()
288
+ self.conv = nn.Conv2d(in_channels, num_class, kernel_size=kernel_size, padding=kernel_size // 2)
289
+
290
+ def forward(self, x):
291
+ x = self.upsample(x)
292
+ x = self.conv(x)
293
+ return x
294
+
295
+
296
+ class AV_Cross(nn.Module):
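+ # Attention across the artery / vessel / vein channels of a 3-channel map: gates computed from the (a, ve), (v, ve) and (a, ve, v) groups reweight each channel, optionally with a residual connection, before a final 1x1 conv.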
297
+
298
+ def __init__(self, channels=2, r=2, residual=True, block=4, kernel_size=1):
299
+ super(AV_Cross, self).__init__()
300
+ out_channels = int(channels // r)
301
+ self.residual = residual
302
+ self.block = block
303
+ self.bn = nn.BatchNorm2d(3)
304
+ self.relu = False
305
+ self.kernel_size = kernel_size
306
+ self.a_ve_att = nn.ModuleList()
307
+ self.v_ve_att = nn.ModuleList()
308
+ self.ve_att = nn.ModuleList()
309
+ for i in range(self.block):
310
+ self.a_ve_att.append(nn.Sequential(
311
+ nn.Conv2d(channels, out_channels, kernel_size=self.kernel_size, stride=1,
312
+ padding=(self.kernel_size - 1) // 2),
313
+ nn.BatchNorm2d(out_channels),
314
+ ))
315
+ self.v_ve_att.append(nn.Sequential(
316
+ nn.Conv2d(channels, out_channels, kernel_size=self.kernel_size, stride=1,
317
+ padding=(self.kernel_size - 1) // 2),
318
+ nn.BatchNorm2d(out_channels),
319
+ ))
320
+ self.ve_att.append(nn.Sequential(
321
+ nn.Conv2d(3, out_channels, kernel_size=self.kernel_size, stride=1, padding=(self.kernel_size - 1) // 2),
322
+ nn.BatchNorm2d(out_channels),
323
+ ))
324
+ self.sigmoid = nn.Sigmoid()
325
+ self.final = nn.Conv2d(3, 3, kernel_size=1, stride=1, padding=0)
326
+
327
+ def forward(self, x):
328
+ a, ve, v = x[:, 0:1, :, :], x[:, 1:2, :, :], x[:, 2:, :, :]
329
+ for i in range(self.block):
330
+ # x = self.relu(self.bn(x))
331
+ a_ve = torch.concat([a, ve], dim=1)
332
+ v_ve = torch.concat([v, ve], dim=1)
333
+ a_v_ve = torch.concat([a, ve, v], dim=1)
334
+ x_a = self.a_ve_att[i](a_ve)
335
+ x_v = self.v_ve_att[i](v_ve)
336
+ x_a_v = self.ve_att[i](a_v_ve)
337
+ a_weight = self.sigmoid(x_a)
338
+ v_weight = self.sigmoid(x_v)
339
+ ve_weight = self.sigmoid(x_a_v)
340
+ if self.residual:
341
+ a = a + a * a_weight
342
+ v = v + v * v_weight
343
+ ve = ve + ve * ve_weight
344
+ else:
345
+ a = a * a_weight
346
+ v = v * v_weight
347
+ ve = ve * ve_weight
348
+
349
+ out = torch.concat([a, ve, v], dim=1)
350
+
351
+ if self.relu:
352
+ out = F.relu(out)
353
+ out = self.final(out)
354
+ return out
355
+
356
+
357
+ class AV_Cross_v2(nn.Module):
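+ # AV_Cross variant whose gates are computed from channel-wise max and mean statistics of the (a, ve), (v, ve) and (a, ve, v) groups.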
358
+
359
+ def __init__(self, channels=2, r=2, residual=True, block=1, relu=False, kernel_size=1):
360
+ super(AV_Cross_v2, self).__init__()
361
+ out_channels = int(channels // r)
362
+ self.residual = residual
363
+ self.block = block
364
+ self.relu = relu
365
+ self.kernel_size = kernel_size
366
+ self.a_ve_att = nn.ModuleList()
367
+ self.v_ve_att = nn.ModuleList()
368
+ self.ve_att = nn.ModuleList()
369
+ for i in range(self.block):
370
+ self.a_ve_att.append(nn.Sequential(
371
+ nn.Conv2d(channels, out_channels, kernel_size=self.kernel_size, stride=1,
372
+ padding=(self.kernel_size - 1) // 2),
373
+ nn.BatchNorm2d(out_channels)
374
+ ))
375
+ self.v_ve_att.append(nn.Sequential(
376
+ nn.Conv2d(channels, out_channels, kernel_size=self.kernel_size, stride=1,
377
+ padding=(self.kernel_size - 1) // 2),
378
+ nn.BatchNorm2d(out_channels)
379
+ ))
380
+ self.ve_att.append(nn.Sequential(
381
+ nn.Conv2d(channels, out_channels, kernel_size=self.kernel_size, stride=1,
382
+ padding=(self.kernel_size - 1) // 2),
383
+ nn.BatchNorm2d(out_channels)
384
+ ))
385
+
386
+ self.sigmoid = nn.Sigmoid()
387
+ self.final = nn.Conv2d(3, 3, kernel_size=1, stride=1, padding=0)
388
+
389
+ def forward(self, x):
390
+ a, ve, v = x[:, 0:1, :, :], x[:, 1:2, :, :], x[:, 2:, :, :]
391
+
392
+ for i in range(self.block):
393
+ tmp = torch.cat([a, ve, v], dim=1)
394
+ a_ve = torch.concat([a, ve], dim=1)
395
+ a_ve = torch.cat([torch.max(a_ve, dim=1, keepdim=True)[0], torch.mean(a_ve, dim=1, keepdim=True)], dim=1)
396
+ v_ve = torch.concat([v, ve], dim=1)
397
+ v_ve = torch.cat([torch.max(v_ve, dim=1, keepdim=True)[0], torch.mean(v_ve, dim=1, keepdim=True)], dim=1)
398
+ a_v_ve = torch.concat([torch.max(tmp, dim=1, keepdim=True)[0], torch.mean(tmp, dim=1, keepdim=True)], dim=1)
399
+
400
+ a_ve = self.a_ve_att[i](a_ve)
401
+ v_ve = self.v_ve_att[i](v_ve)
402
+ a_v_ve = self.ve_att[i](a_v_ve)
403
+ a_weight = self.sigmoid(a_ve)
404
+ v_weight = self.sigmoid(v_ve)
405
+ ve_weight = self.sigmoid(a_v_ve)
406
+ if self.residual:
407
+ a = a + a * a_weight
408
+ v = v + v * v_weight
409
+ ve = ve + ve * ve_weight
410
+ else:
411
+ a = a * a_weight
412
+ v = v * v_weight
413
+ ve = ve * ve_weight
414
+
415
+ out = torch.concat([a, ve, v], dim=1)
416
+
417
+ if self.relu:
418
+ out = F.relu(out)
419
+ out = self.final(out)
420
+ return out
421
+
422
+
423
+ class MultiHeadAttention(nn.Module):
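+ # Multi-head self-attention using a fused QKV projection and einsum.
+ # Note: the logits are scaled by multiplying with self.dk = sqrt(head_dim), whereas standard scaled dot-product attention divides by sqrt(head_dim).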
424
+ def __init__(self, embedding_dim, head_num):
425
+ super().__init__()
426
+
427
+ self.head_num = head_num
428
+ self.dk = (embedding_dim // head_num) ** (1 / 2)
429
+
430
+ self.qkv_layer = nn.Linear(embedding_dim, embedding_dim * 3, bias=False)
431
+ self.out_attention = nn.Linear(embedding_dim, embedding_dim, bias=False)
432
+
433
+ def forward(self, x, mask=None):
434
+ qkv = self.qkv_layer(x)
435
+
436
+ query, key, value = tuple(rearrange(qkv, 'b t (d k h ) -> k b h t d ', k=3, h=self.head_num))
437
+ energy = torch.einsum("... i d , ... j d -> ... i j", query, key) * self.dk
438
+
439
+ if mask is not None:
440
+ energy = energy.masked_fill(mask, -np.inf)
441
+
442
+ attention = torch.softmax(energy, dim=-1)
443
+
444
+ x = torch.einsum("... i j , ... j d -> ... i d", attention, value)
445
+
446
+ x = rearrange(x, "b h t d -> b t (h d)")
447
+ x = self.out_attention(x)
448
+
449
+ return x
450
+
451
+
452
+ class MLP(nn.Module):
453
+ def __init__(self, embedding_dim, mlp_dim):
454
+ super().__init__()
455
+
456
+ self.mlp_layers = nn.Sequential(
457
+ nn.Linear(embedding_dim, mlp_dim),
458
+ nn.GELU(),
459
+ nn.Dropout(0.1),
460
+ nn.Linear(mlp_dim, embedding_dim),
461
+ nn.Dropout(0.1)
462
+ )
463
+
464
+ def forward(self, x):
465
+ x = self.mlp_layers(x)
466
+
467
+ return x
468
+
469
+
470
+ class TransformerEncoderBlock(nn.Module):
471
+ def __init__(self, embedding_dim, head_num, mlp_dim):
472
+ super().__init__()
473
+
474
+ self.multi_head_attention = MultiHeadAttention(embedding_dim, head_num)
475
+ self.mlp = MLP(embedding_dim, mlp_dim)
476
+
477
+ self.layer_norm1 = nn.LayerNorm(embedding_dim)
478
+ self.layer_norm2 = nn.LayerNorm(embedding_dim)
479
+
480
+ self.dropout = nn.Dropout(0.1)
481
+
482
+ def forward(self, x):
483
+ _x = self.multi_head_attention(x)
484
+ _x = self.dropout(_x)
485
+ x = x + _x
486
+ x = self.layer_norm1(x)
487
+
488
+ _x = self.mlp(x)
489
+ x = x + _x
490
+ x = self.layer_norm2(x)
491
+
492
+ return x
493
+
494
+
495
+ class TransformerEncoder(nn.Module):
496
+ """
497
+ embedding_dim: length of each token vector
499
+ head_num: number of self-attention heads
500
+ block_num: number of transformer blocks
500
+ """
501
+
502
+ def __init__(self, embedding_dim, head_num, block_num=2):
503
+ super().__init__()
504
+ self.layer_blocks = nn.ModuleList(
505
+ [TransformerEncoderBlock(embedding_dim, head_num, 2 * embedding_dim) for _ in range(block_num)])
506
+
507
+ def forward(self, x):
508
+ for layer_block in self.layer_blocks:
509
+ x = layer_block(x)
510
+ return x
511
+
512
+
513
+ class PathEmbedding(nn.Module):
514
+ """
515
+ img_dim: input image size
516
+ in_channels: number of input channels
517
+ embedding_dim: vector length of each token
518
+ patch_size: patch (token) size used when tokenizing the input image
519
+ """
520
+
521
+ def __init__(self, img_dim, in_channels, embedding_dim, patch_size):
522
+ super().__init__()
523
+
524
+ self.patch_size = patch_size
525
+ self.num_tokens = (img_dim // patch_size) ** 2
526
+ self.token_dim = in_channels * (patch_size ** 2)
527
+ # 1. projection
528
+ self.projection = nn.Linear(self.token_dim, embedding_dim)
529
+ # 2. position embedding
530
+ self.embedding = nn.Parameter(torch.rand(self.num_tokens + 1, embedding_dim))
531
+ # 3. cls token
532
+ self.cls_token = nn.Parameter(torch.randn(1, 1, embedding_dim))
533
+
534
+ def forward(self, x):
535
+ img_patches = rearrange(x,
536
+ 'b c (patch_x x) (patch_y y) -> b (x y) (patch_x patch_y c)',
537
+ patch_x=self.patch_size, patch_y=self.patch_size)
538
+
539
+ batch_size, tokens_num, _ = img_patches.shape
540
+
541
+ patch_token = self.projection(img_patches)
542
+ cls_token = repeat(self.cls_token, 'b ... -> (b batch_size) ...',
543
+ batch_size=batch_size)
544
+
545
+ patches = torch.cat([cls_token, patch_token], dim=1)
546
+ # add position embedding
547
+ patches += self.embedding[:tokens_num + 1, :]
548
+
549
+ # B,tokens_num+1,embedding_dim
550
+ return patches
551
+
552
+
553
+ class TransformerBottleNeck(nn.Module):
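+ # Patch-embeds the input, applies dropout and the transformer encoder, then returns either the class-token prediction (classification=True) or the patch tokens.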
554
+ def __init__(self, img_dim, in_channels, embedding_dim, head_num,
555
+ block_num, patch_size=1, classification=False, dropout=0.1, num_classes=1):
556
+ super().__init__()
557
+ self.patch_embedding = PathEmbedding(img_dim, in_channels, embedding_dim, patch_size)
558
+ self.transformer = TransformerEncoder(embedding_dim, head_num, block_num)
559
+ self.dropout = nn.Dropout(dropout)
560
+ self.classification = classification
561
+ if self.classification:
562
+ self.mlp_head = nn.Linear(embedding_dim, num_classes)
563
+
564
+ def forward(self, x):
565
+ x = self.patch_embedding(x)
566
+ x = self.dropout(x)
567
+ x = self.transformer(x)
568
+ x = self.mlp_head(x[:, 0, :]) if self.classification else x[:, 1:, :]
569
+ return x
570
+
571
+
572
+ class PGFusion(nn.Module):
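+ # Patch-global fusion: computes patch self-attention and patch-to-global cross-attention, derives pixel-wise fusion weights from both maps via a 1x1 conv, sigmoid and softmax, and adds the weighted maps to the input representation.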
573
+
574
+ def __init__(self, in_channel=384, out_channel=384):
575
+
576
+ super(PGFusion, self).__init__()
577
+
578
+ self.in_channel = in_channel
579
+ self.out_channel = out_channel
580
+
581
+ self.patch_query = nn.Conv2d(in_channel, in_channel, kernel_size=1)
582
+ self.patch_key = nn.Conv2d(in_channel, in_channel, kernel_size=1)
583
+ self.patch_value = nn.Conv2d(in_channel, in_channel, kernel_size=1, bias=False)
584
+ self.patch_global_query = nn.Conv2d(in_channel, in_channel, kernel_size=1)
585
+
586
+ self.global_key = nn.Conv2d(in_channel, in_channel, kernel_size=1)
587
+ self.global_value = nn.Conv2d(in_channel, in_channel, kernel_size=1, bias=False)
588
+
589
+ self.fusion = nn.Conv2d(in_channel * 2, in_channel * 2, kernel_size=1)
590
+
591
+ self.out_patch = nn.Conv2d(in_channel, out_channel, kernel_size=1)
592
+ self.out_global = nn.Conv2d(in_channel, out_channel, kernel_size=1)
593
+
594
+ self.softmax = nn.Softmax(dim=2)
595
+ self.softmax_concat = nn.Softmax(dim=0)
596
+
597
+ self.gamma_patch_self = nn.Parameter(torch.zeros(1))  # learnable scales used in forward
598
+ self.gamma_patch_global = nn.Parameter(torch.zeros(1))
599
+
600
+ self.init_parameters()
601
+
602
+ def init_parameters(self):
603
+ for m in self.modules():
604
+ if isinstance(m, nn.Conv3d) or isinstance(m, nn.Conv2d) or isinstance(m, nn.Conv1d):
605
+ nn.init.normal_(m.weight, 0, 0.01)
606
+ # nn.init.xavier_uniform_(m.weight.data)
607
+ if m.bias is not None:
608
+ nn.init.zeros_(m.bias)
609
+ # nn.init.constant_(m.bias, 0)
610
+ m.inited = True
611
+
612
+ def forward(self, patch_rep, global_rep):
613
+ patch_rep_ = patch_rep.clone()
614
+ patch_value = self.patch_value(patch_rep)
615
+ patch_value = patch_value.view(patch_value.size(0), patch_value.size(1), -1)
616
+ patch_key = self.patch_key(patch_rep)
617
+ patch_key = patch_key.view(patch_key.size(0), patch_key.size(1), -1)
618
+ dim_k = patch_key.shape[-1]
619
+ patch_query = self.patch_query(patch_rep)
620
+ patch_query = patch_query.view(patch_query.size(0), patch_query.size(1), -1)
621
+
622
+ patch_global_query = self.patch_global_query(patch_rep)
623
+ patch_global_query = patch_global_query.view(patch_global_query.size(0), patch_global_query.size(1), -1)
624
+
625
+ global_value = self.global_value(global_rep)
626
+ global_value = global_value.view(global_value.size(0), global_value.size(1), -1)
627
+ global_key = self.global_key(global_rep)
628
+ global_key = global_key.view(global_key.size(0), global_key.size(1), -1)
629
+
630
+ ### patch self attention
631
+ patch_self_sim_map = patch_query @ patch_key.transpose(-2, -1) / math.sqrt(dim_k)
632
+ patch_self_sim_map = self.softmax(patch_self_sim_map)
633
+ patch_self_sim_map = patch_self_sim_map @ patch_value
634
+ patch_self_sim_map = patch_self_sim_map.view(patch_self_sim_map.size(0), patch_self_sim_map.size(1),
635
+ *patch_rep.size()[2:])
636
+ patch_self_sim_map = self.gamma_patch_self * patch_self_sim_map
637
+ # patch_self_sim_map = 1 * patch_self_sim_map
638
+ ### patch global attention
639
+ patch_global_sim_map = patch_global_query @ global_key.transpose(-2, -1) / math.sqrt(dim_k)
640
+ patch_global_sim_map = self.softmax(patch_global_sim_map)
641
+ patch_global_sim_map = patch_global_sim_map @ global_value
642
+ patch_global_sim_map = patch_global_sim_map.view(patch_global_sim_map.size(0), patch_global_sim_map.size(1),
643
+ *patch_rep.size()[2:])
644
+ patch_global_sim_map = self.gamma_patch_global * patch_global_sim_map
645
+ # patch_global_sim_map = 1 * patch_global_sim_map
646
+
647
+ fusion_sim_weight_map = torch.cat((patch_self_sim_map, patch_global_sim_map), dim=1)
648
+ fusion_sim_weight_map = self.fusion(fusion_sim_weight_map)
649
+ fusion_sim_weight_map = 1 * fusion_sim_weight_map
650
+
651
+ patch_self_sim_weight_map = torch.split(fusion_sim_weight_map, dim=1, split_size_or_sections=self.in_channel)[0]
652
+ patch_self_sim_weight_map = torch.sigmoid(patch_self_sim_weight_map) # 0-1
653
+
654
+ patch_global_sim_weight_map = torch.split(fusion_sim_weight_map, dim=1, split_size_or_sections=self.in_channel)[
655
+ 1]
656
+ patch_global_sim_weight_map = torch.sigmoid(patch_global_sim_weight_map) # 0-1
657
+
658
+ patch_self_sim_weight_map = torch.unsqueeze(patch_self_sim_weight_map, 0)
659
+ patch_global_sim_weight_map = torch.unsqueeze(patch_global_sim_weight_map, 0)
660
+
661
+ ct = torch.concat((patch_self_sim_weight_map, patch_global_sim_weight_map), 0)
662
+ ct = self.softmax_concat(ct)
663
+
664
+ out = patch_rep_ + patch_self_sim_map * ct[0] + patch_global_sim_map * (1 - ct[0])
665
+
666
+ return out
667
+
668
+
669
+ if __name__ == '__main__':
670
+ x = torch.randn((2, 384, 16, 16))
671
+ m = PGFusion()
672
+ print(m)
673
+ # y = TransformerBottleNeck(x.shape[2],x.shape[1],x.shape[1],8,4)
674
+ print(m(x, x).shape)
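+
+ # Additional illustrative check (sketch): AV_Cross preserves the 3-channel shape.
+ av = AV_Cross()
+ print(av(torch.randn(2, 3, 64, 64)).shape)  # torch.Size([2, 3, 64, 64])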