jwyang commited on
Commit
2fafc55
1 Parent(s): d51616b

push unicl demo

Browse files
Files changed (43) hide show
  1. README.md +1 -1
  2. app.py +143 -0
  3. apple_with_ipod.jpg +0 -0
  4. config.py +245 -0
  5. configs/unicl_focalnet_giant.yaml +16 -0
  6. configs/unicl_swin_base.yaml +16 -0
  7. configs/unicl_swin_tiny.yaml +16 -0
  8. crowd2.jpg +0 -0
  9. elephants.png +0 -0
  10. model/__init__.py +1 -0
  11. model/__pycache__/__init__.cpython-39.pyc +0 -0
  12. model/__pycache__/model.cpython-39.pyc +0 -0
  13. model/__pycache__/templates.cpython-39.pyc +0 -0
  14. model/image_encoder/__init__.py +1 -0
  15. model/image_encoder/__pycache__/__init__.cpython-38.pyc +0 -0
  16. model/image_encoder/__pycache__/__init__.cpython-39.pyc +0 -0
  17. model/image_encoder/__pycache__/build.cpython-38.pyc +0 -0
  18. model/image_encoder/__pycache__/build.cpython-39.pyc +0 -0
  19. model/image_encoder/__pycache__/focalnet.cpython-38.pyc +0 -0
  20. model/image_encoder/__pycache__/focalnet.cpython-39.pyc +0 -0
  21. model/image_encoder/__pycache__/swin_transformer.cpython-38.pyc +0 -0
  22. model/image_encoder/__pycache__/swin_transformer.cpython-39.pyc +0 -0
  23. model/image_encoder/build.py +59 -0
  24. model/image_encoder/focalnet.py +649 -0
  25. model/image_encoder/swin_transformer.py +586 -0
  26. model/model.py +204 -0
  27. model/templates.py +83 -0
  28. model/text_encoder/__init__.py +9 -0
  29. model/text_encoder/__pycache__/__init__.cpython-38.pyc +0 -0
  30. model/text_encoder/__pycache__/__init__.cpython-39.pyc +0 -0
  31. model/text_encoder/__pycache__/build.cpython-38.pyc +0 -0
  32. model/text_encoder/__pycache__/build.cpython-39.pyc +0 -0
  33. model/text_encoder/__pycache__/hf_model.cpython-38.pyc +0 -0
  34. model/text_encoder/__pycache__/hf_model.cpython-39.pyc +0 -0
  35. model/text_encoder/__pycache__/registry.cpython-38.pyc +0 -0
  36. model/text_encoder/__pycache__/registry.cpython-39.pyc +0 -0
  37. model/text_encoder/__pycache__/transformer.cpython-38.pyc +0 -0
  38. model/text_encoder/__pycache__/transformer.cpython-39.pyc +0 -0
  39. model/text_encoder/build.py +31 -0
  40. model/text_encoder/hf_model.py +27 -0
  41. model/text_encoder/registry.py +18 -0
  42. model/text_encoder/transformer.py +194 -0
  43. requirements.txt +7 -0
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- title: Unicl Img Recognition Demo
3
  emoji: 🏢
4
  colorFrom: red
5
  colorTo: purple
 
1
  ---
2
+ title: Unicl Image Recognition Demo
3
  emoji: 🏢
4
  colorFrom: red
5
  colorTo: purple
app.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import requests
3
+ import gradio as gr
4
+ import numpy as np
5
+ import cv2
6
+ import torch
7
+ import torch.nn as nn
8
+ from PIL import Image
9
+ from torchvision import transforms
10
+ from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
11
+ from timm.data import create_transform
12
+ from config import get_config
13
+ from model import build_model
14
+
15
+ # Download human-readable labels for ImageNet.
16
+ response = requests.get("https://git.io/JJkYN")
17
+ labels = response.text.split("\n")
18
+
19
+ def parse_option():
20
+ parser = argparse.ArgumentParser('UniCL demo script', add_help=False)
21
+ parser.add_argument('--cfg', type=str, default="configs/unicl_swin_base.yaml", metavar="FILE", help='path to config file', )
22
+ args, unparsed = parser.parse_known_args()
23
+
24
+ config = get_config(args)
25
+
26
+ return args, config
27
+
28
+ def build_transforms(img_size, center_crop=True):
29
+ t = [transforms.ToPILImage()]
30
+ if center_crop:
31
+ size = int((256 / 224) * img_size)
32
+ t.append(
33
+ transforms.Resize(size)
34
+ )
35
+ t.append(
36
+ transforms.CenterCrop(img_size)
37
+ )
38
+ else:
39
+ t.append(
40
+ transforms.Resize(img_size)
41
+ )
42
+ t.append(transforms.ToTensor())
43
+ t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
44
+ return transforms.Compose(t)
45
+
46
+ def build_transforms4display(img_size, center_crop=True):
47
+ t = [transforms.ToPILImage()]
48
+ if center_crop:
49
+ size = int((256 / 224) * img_size)
50
+ t.append(
51
+ transforms.Resize(size)
52
+ )
53
+ t.append(
54
+ transforms.CenterCrop(img_size)
55
+ )
56
+ else:
57
+ t.append(
58
+ transforms.Resize(img_size)
59
+ )
60
+ t.append(transforms.ToTensor())
61
+ return transforms.Compose(t)
62
+
63
+ args, config = parse_option()
64
+
65
+ '''
66
+ build model
67
+ '''
68
+ model = build_model(config)
69
+
70
+ url = './in21k_yfcc14m_gcc15m_swin_base.pth'
71
+ checkpoint = torch.load(url, map_location="cpu")
72
+ model.load_state_dict(checkpoint["model"])
73
+ model.eval()
74
+
75
+ '''
76
+ build data transform
77
+ '''
78
+ eval_transforms = build_transforms(224, center_crop=True)
79
+ display_transforms = build_transforms4display(224, center_crop=True)
80
+
81
+ '''
82
+ build upsampler
83
+ '''
84
+ # upsampler = nn.Upsample(scale_factor=16, mode='bilinear')
85
+
86
+ '''
87
+ borrow code from here: https://github.com/jacobgil/pytorch-grad-cam/blob/master/pytorch_grad_cam/utils/image.py
88
+ '''
89
+ def show_cam_on_image(img: np.ndarray,
90
+ mask: np.ndarray,
91
+ use_rgb: bool = False,
92
+ colormap: int = cv2.COLORMAP_JET) -> np.ndarray:
93
+ """ This function overlays the cam mask on the image as an heatmap.
94
+ By default the heatmap is in BGR format.
95
+ :param img: The base image in RGB or BGR format.
96
+ :param mask: The cam mask.
97
+ :param use_rgb: Whether to use an RGB or BGR heatmap, this should be set to True if 'img' is in RGB format.
98
+ :param colormap: The OpenCV colormap to be used.
99
+ :returns: The default image with the cam overlay.
100
+ """
101
+ heatmap = cv2.applyColorMap(np.uint8(255 * mask), colormap)
102
+ if use_rgb:
103
+ heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
104
+ heatmap = np.float32(heatmap) / 255
105
+
106
+ if np.max(img) > 1:
107
+ raise Exception(
108
+ "The input image should np.float32 in the range [0, 1]")
109
+
110
+ cam = 0.7*heatmap + 0.3*img
111
+ # cam = cam / np.max(cam)
112
+ return np.uint8(255 * cam)
113
+
114
+ def recognize_image(image, texts):
115
+ img_t = eval_transforms(image)
116
+ img_d = display_transforms(image).permute(1, 2, 0).numpy()
117
+
118
+ text_embeddings = model.get_text_embeddings(texts.split(';'))
119
+
120
+ # compute output
121
+ feat_img = model.encode_image(img_t.unsqueeze(0))
122
+ output = model.logit_scale.exp() * feat_img @ text_embeddings.t()
123
+ prediction = output.softmax(-1).flatten()
124
+
125
+ return {texts.split(';')[i]: float(prediction[i]) for i in range(len(texts.split(';')))}
126
+
127
+
128
+ image = gr.inputs.Image()
129
+ label = gr.outputs.Label(num_top_classes=100)
130
+
131
+ gr.Interface(
132
+ description="UniCL for Zero-shot Image Recognition Demo (https://github.com/microsoft/unicl)",
133
+ fn=recognize_image,
134
+ inputs=["image", "text"],
135
+ outputs=[
136
+ label,
137
+ ],
138
+ examples=[
139
+ ["./elephants.png", "an elephant; an elephant walking in the river; four elephants walking in the river"],
140
+ ["./apple_with_ipod.jpg", "an ipod; an apple with a write note 'ipod'; an apple"],
141
+ ["./crowd2.jpg", "a street; a street with a woman walking in the middle; a street with a man walking in the middle"],
142
+ ],
143
+ ).launch()
apple_with_ipod.jpg ADDED
config.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # Unified Contrastive Learning (UniCL)
3
+ # Copyright (c) 2022 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Jianwei Yang ([email protected])
6
+ # Based on Swin Transformer written by Zhe Liu
7
+ # --------------------------------------------------------
8
+
9
+ import os
10
+ import yaml
11
+ from yacs.config import CfgNode as CN
12
+
13
+ _C = CN()
14
+ _C.VERBOSE = False
15
+
16
+ # Base config files
17
+ _C.BASE = ['']
18
+
19
+ # -----------------------------------------------------------------------------
20
+ # Data settings
21
+ # -----------------------------------------------------------------------------
22
+ _C.DATA = CN()
23
+ # Batch size for a single GPU, could be overwritten by command line argument
24
+ _C.DATA.BATCH_SIZE = 128
25
+ # Path to dataset, could be overwritten by command line argument
26
+ _C.DATA.DATA_PATH = ''
27
+ # Dataset name
28
+ _C.DATA.DATASET = 'imagenet'
29
+ # Input image size
30
+ _C.DATA.IMG_SIZE = 224
31
+ # Interpolation to resize image (random, bilinear, bicubic)
32
+ _C.DATA.INTERPOLATION = 'bicubic'
33
+ # Use zipped dataset instead of folder dataset
34
+ # could be overwritten by command line argument
35
+ _C.DATA.ZIP_MODE = False
36
+ # Cache Data in Memory, could be overwritten by command line argument
37
+ _C.DATA.CACHE_MODE = 'part'
38
+ # Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.
39
+ _C.DATA.PIN_MEMORY = True
40
+ # Number of data loading threads
41
+ _C.DATA.NUM_WORKERS = 8
42
+
43
+ # -----------------------------------------------------------------------------
44
+ # Model settings
45
+ # -----------------------------------------------------------------------------
46
+ _C.MODEL = CN()
47
+ # Model name
48
+ _C.MODEL.NAME = ''
49
+ # Checkpoint to resume, could be overwritten by command line argument
50
+ _C.MODEL.RESUME = ''
51
+ # Number of classes, overwritten in data preparation
52
+ _C.MODEL.NUM_CLASSES = 0
53
+ # Label Smoothing
54
+ _C.MODEL.LABEL_SMOOTHING = 0.1
55
+ # Whether load pretrained model
56
+ _C.MODEL.PRETRAINED = ''
57
+ # Projection dimension
58
+ _C.MODEL.DIM_PROJECTION = 512
59
+ # Mode specific
60
+ _C.MODEL.SPEC = CN(new_allowed=True)
61
+ # -----------------------------------------------------------------------------
62
+ # Build Image Encoder
63
+ # -----------------------------------------------------------------------------
64
+ _C.MODEL.IMAGE_ENCODER = CN()
65
+ # Image encoder type
66
+ _C.MODEL.IMAGE_ENCODER.TYPE = 'swin'
67
+ # Input image size
68
+ _C.MODEL.IMAGE_ENCODER.IMG_SIZE = 224
69
+ # Dropout rate
70
+ _C.MODEL.IMAGE_ENCODER.DROP_RATE = 0.0
71
+ # Drop path rate
72
+ _C.MODEL.IMAGE_ENCODER.DROP_PATH_RATE = 0.1
73
+
74
+ # Swin Transformer parameters
75
+ _C.MODEL.IMAGE_ENCODER.SWIN = CN()
76
+ _C.MODEL.IMAGE_ENCODER.SWIN.PATCH_SIZE = 4
77
+ _C.MODEL.IMAGE_ENCODER.SWIN.IN_CHANS = 3
78
+ _C.MODEL.IMAGE_ENCODER.SWIN.EMBED_DIM = 96
79
+ _C.MODEL.IMAGE_ENCODER.SWIN.DEPTHS = [2, 2, 6, 2]
80
+ _C.MODEL.IMAGE_ENCODER.SWIN.NUM_HEADS = [3, 6, 12, 24]
81
+ _C.MODEL.IMAGE_ENCODER.SWIN.WINDOW_SIZE = 7
82
+ _C.MODEL.IMAGE_ENCODER.SWIN.MLP_RATIO = 4.
83
+ _C.MODEL.IMAGE_ENCODER.SWIN.QKV_BIAS = True
84
+ _C.MODEL.IMAGE_ENCODER.SWIN.QK_SCALE = None
85
+ _C.MODEL.IMAGE_ENCODER.SWIN.APE = False
86
+ _C.MODEL.IMAGE_ENCODER.SWIN.PATCH_NORM = True
87
+
88
+ # FocalNet parameters
89
+ _C.MODEL.IMAGE_ENCODER.FOCAL = CN()
90
+ _C.MODEL.IMAGE_ENCODER.FOCAL.PATCH_SIZE = 4
91
+ _C.MODEL.IMAGE_ENCODER.FOCAL.IN_CHANS = 3
92
+ _C.MODEL.IMAGE_ENCODER.FOCAL.EMBED_DIM = 96
93
+ _C.MODEL.IMAGE_ENCODER.FOCAL.DEPTHS = [2, 2, 6, 2]
94
+ _C.MODEL.IMAGE_ENCODER.FOCAL.MLP_RATIO = 4.
95
+ _C.MODEL.IMAGE_ENCODER.FOCAL.PATCH_NORM = True
96
+ _C.MODEL.IMAGE_ENCODER.FOCAL.FOCAL_LEVELS = [2, 2, 2, 2]
97
+ _C.MODEL.IMAGE_ENCODER.FOCAL.FOCAL_WINDOWS = [3, 3, 3, 3]
98
+ _C.MODEL.IMAGE_ENCODER.FOCAL.FOCAL_FACTORS = [2, 2, 2, 2]
99
+ _C.MODEL.IMAGE_ENCODER.FOCAL.USE_CONV_EMBED = False
100
+ _C.MODEL.IMAGE_ENCODER.FOCAL.USE_LAYERSCALE = False
101
+ _C.MODEL.IMAGE_ENCODER.FOCAL.USE_POSTLN = False
102
+
103
+ # -----------------------------------------------------------------------------
104
+ # Build Text Encoder
105
+ # -----------------------------------------------------------------------------
106
+ _C.MODEL.TEXT_ENCODER = CN()
107
+
108
+ _C.MODEL.TEXT_ENCODER.NAME = 'transformer'
109
+ _C.MODEL.TEXT_ENCODER.LOAD_PRETRAINED = False
110
+ _C.MODEL.TEXT_ENCODER.PRETRAINED = ''
111
+ _C.MODEL.TEXT_ENCODER.TOKENIZER = 'clip'
112
+ _C.MODEL.TEXT_ENCODER.CONTEXT_LENGTH = 77
113
+ _C.MODEL.TEXT_ENCODER.WIDTH = 1024
114
+ _C.MODEL.TEXT_ENCODER.HEADS = 16
115
+ _C.MODEL.TEXT_ENCODER.LAYERS = 12
116
+ _C.MODEL.TEXT_ENCODER.AUTOGRESSIVE = True
117
+
118
+ # -----------------------------------------------------------------------------
119
+ # Training settings
120
+ # -----------------------------------------------------------------------------
121
+ _C.TRAIN = CN()
122
+ _C.TRAIN.START_EPOCH = 0
123
+ _C.TRAIN.EPOCHS = 32
124
+ _C.TRAIN.WARMUP_EPOCHS = 5
125
+ _C.TRAIN.WEIGHT_DECAY = 0.1
126
+ _C.TRAIN.BASE_LR = 5e-4
127
+ _C.TRAIN.WARMUP_LR = 5e-7
128
+ _C.TRAIN.MIN_LR = 5e-6
129
+ # Clip gradient norm
130
+ _C.TRAIN.CLIP_GRAD = 5.0
131
+ # Auto resume from latest checkpoint
132
+ _C.TRAIN.AUTO_RESUME = True
133
+ # Gradient accumulation steps
134
+ # could be overwritten by command line argument
135
+ _C.TRAIN.ACCUMULATION_STEPS = 0
136
+ # Whether to use gradient checkpointing to save memory
137
+ # could be overwritten by command line argument
138
+ _C.TRAIN.USE_CHECKPOINT = False
139
+
140
+ # LR scheduler
141
+ _C.TRAIN.LR_SCHEDULER = CN()
142
+ _C.TRAIN.LR_SCHEDULER.NAME = 'cosine'
143
+ # Epoch interval to decay LR, used in StepLRScheduler
144
+ _C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30
145
+ # LR decay rate, used in StepLRScheduler
146
+ _C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1
147
+
148
+ # Optimizer
149
+ _C.TRAIN.OPTIMIZER = CN()
150
+ _C.TRAIN.OPTIMIZER.NAME = 'adamw'
151
+ # Optimizer Epsilon
152
+ _C.TRAIN.OPTIMIZER.EPS = 1e-8
153
+ # Optimizer Betas
154
+ _C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999)
155
+ # SGD momentum
156
+ _C.TRAIN.OPTIMIZER.MOMENTUM = 0.9
157
+
158
+ # -----------------------------------------------------------------------------
159
+ # Augmentation settings
160
+ # -----------------------------------------------------------------------------
161
+ _C.AUG = CN()
162
+ # Color jitter factor
163
+ _C.AUG.COLOR_JITTER = 0.4
164
+ # Use AutoAugment policy. "v0" or "original"
165
+ _C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1'
166
+ # Random erase prob
167
+ _C.AUG.REPROB = 0.25
168
+ # Random erase mode
169
+ _C.AUG.REMODE = 'pixel'
170
+ # Random erase count
171
+ _C.AUG.RECOUNT = 1
172
+ # Mixup alpha, mixup enabled if > 0
173
+ _C.AUG.MIXUP = 0.8
174
+ # Cutmix alpha, cutmix enabled if > 0
175
+ _C.AUG.CUTMIX = 1.0
176
+ # Cutmix min/max ratio, overrides alpha and enables cutmix if set
177
+ _C.AUG.CUTMIX_MINMAX = None
178
+ # Probability of performing mixup or cutmix when either/both is enabled
179
+ _C.AUG.MIXUP_PROB = 1.0
180
+ # Probability of switching to cutmix when both mixup and cutmix enabled
181
+ _C.AUG.MIXUP_SWITCH_PROB = 0.5
182
+ # How to apply mixup/cutmix params. Per "batch", "pair", or "elem"
183
+ _C.AUG.MIXUP_MODE = 'batch'
184
+
185
+ # -----------------------------------------------------------------------------
186
+ # Testing settings
187
+ # -----------------------------------------------------------------------------
188
+ _C.TEST = CN()
189
+ # Whether to use center crop when testing
190
+ _C.TEST.CROP = True
191
+
192
+ # -----------------------------------------------------------------------------
193
+ # Misc
194
+ # -----------------------------------------------------------------------------
195
+ # Mixed precision opt level, if O0, no amp is used ('O0', 'O1', 'O2')
196
+ # overwritten by command line argument
197
+ _C.AMP_OPT_LEVEL = ''
198
+ # Path to output folder, overwritten by command line argument
199
+ _C.OUTPUT = ''
200
+ # Tag of experiment, overwritten by command line argument
201
+ _C.TAG = 'default'
202
+ # Frequency to save checkpoint
203
+ _C.SAVE_FREQ = 1
204
+ # Frequency to logging info
205
+ _C.PRINT_FREQ = 100
206
+ # Fixed random seed
207
+ _C.SEED = 0
208
+ # Perform evaluation only, overwritten by command line argument
209
+ _C.EVAL_MODE = False
210
+ # Test throughput only, overwritten by command line argument
211
+ _C.THROUGHPUT_MODE = False
212
+ # Debug only so that skip dataloader initialization, overwritten by command line argument
213
+ _C.DEBUG_MODE = False
214
+ # local rank for DistributedDataParallel, given by command line argument
215
+ _C.LOCAL_RANK = 0
216
+
217
+
218
+ def _update_config_from_file(config, cfg_file):
219
+ config.defrost()
220
+ with open(cfg_file, 'r') as f:
221
+ yaml_cfg = yaml.load(f, Loader=yaml.FullLoader)
222
+
223
+ for cfg in yaml_cfg.setdefault('BASE', ['']):
224
+ if cfg:
225
+ _update_config_from_file(
226
+ config, os.path.join(os.path.dirname(cfg_file), cfg)
227
+ )
228
+ print('=> merge config from {}'.format(cfg_file))
229
+ config.merge_from_file(cfg_file)
230
+ config.freeze()
231
+
232
+
233
+ def update_config(config, args):
234
+ _update_config_from_file(config, args.cfg)
235
+ config.freeze()
236
+
237
+
238
+ def get_config(args):
239
+ """Get a yacs CfgNode object with default values."""
240
+ # Return a clone so that the defaults will not be altered
241
+ # This is for the "local variable" use pattern
242
+ config = _C.clone()
243
+ update_config(config, args)
244
+
245
+ return config
configs/unicl_focalnet_giant.yaml ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MODEL:
2
+ NAME: unicl_focalnet_giant
3
+ DIM_PROJECTION: 1024
4
+ IMAGE_ENCODER:
5
+ TYPE: focalnet_giant_lrf
6
+ DROP_PATH_RATE: 0.5
7
+ FOCAL:
8
+ USE_POSTLN: False
9
+ USE_CONV_EMBED: False
10
+ EMBED_DIM: 512
11
+ USE_LAYERSCALE: True
12
+ TEXT_ENCODER:
13
+ NAME: 'transformer'
14
+ WIDTH: 1024
15
+ HEADS: 16
16
+ LAYERS: 16
configs/unicl_swin_base.yaml ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MODEL:
2
+ NAME: unicl_swin_base
3
+ DIM_PROJECTION: 512
4
+ IMAGE_ENCODER:
5
+ TYPE: swin
6
+ DROP_PATH_RATE: 0.5
7
+ SWIN:
8
+ EMBED_DIM: 128
9
+ DEPTHS: [ 2, 2, 18, 2 ]
10
+ NUM_HEADS: [ 4, 8, 16, 32 ]
11
+ WINDOW_SIZE: 7
12
+ TEXT_ENCODER:
13
+ NAME: 'transformer'
14
+ WIDTH: 512
15
+ HEADS: 8
16
+ LAYERS: 12
configs/unicl_swin_tiny.yaml ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MODEL:
2
+ NAME: unicl_swin_tiny
3
+ DIM_PROJECTION: 512
4
+ IMAGE_ENCODER:
5
+ TYPE: swin
6
+ DROP_PATH_RATE: 0.2
7
+ SWIN:
8
+ EMBED_DIM: 96
9
+ DEPTHS: [ 2, 2, 6, 2 ]
10
+ NUM_HEADS: [ 3, 6, 12, 24 ]
11
+ WINDOW_SIZE: 7
12
+ TEXT_ENCODER:
13
+ NAME: 'transformer'
14
+ WIDTH: 512
15
+ HEADS: 8
16
+ LAYERS: 12
crowd2.jpg ADDED
elephants.png ADDED
model/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .model import build_unicl_model as build_model
model/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (190 Bytes). View file
 
model/__pycache__/model.cpython-39.pyc ADDED
Binary file (6.82 kB). View file
 
model/__pycache__/templates.cpython-39.pyc ADDED
Binary file (1.99 kB). View file
 
model/image_encoder/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .build import build_model as build_image_encoder
model/image_encoder/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (201 Bytes). View file
 
model/image_encoder/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (206 Bytes). View file
 
model/image_encoder/__pycache__/build.cpython-38.pyc ADDED
Binary file (1.13 kB). View file
 
model/image_encoder/__pycache__/build.cpython-39.pyc ADDED
Binary file (1.36 kB). View file
 
model/image_encoder/__pycache__/focalnet.cpython-38.pyc ADDED
Binary file (19.6 kB). View file
 
model/image_encoder/__pycache__/focalnet.cpython-39.pyc ADDED
Binary file (19.8 kB). View file
 
model/image_encoder/__pycache__/swin_transformer.cpython-38.pyc ADDED
Binary file (19.9 kB). View file
 
model/image_encoder/__pycache__/swin_transformer.cpython-39.pyc ADDED
Binary file (19.8 kB). View file
 
model/image_encoder/build.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from timm.models import create_model
2
+ from .swin_transformer import SwinTransformer
3
+ from . import focalnet
4
+
5
+ def build_model(config):
6
+ model_type = config.TYPE
7
+ print(f"Creating model: {model_type}")
8
+
9
+ if "swin" in model_type:
10
+ model = SwinTransformer(
11
+ num_classes=0,
12
+ img_size=config.IMG_SIZE,
13
+ patch_size=config.SWIN.PATCH_SIZE,
14
+ in_chans=config.SWIN.IN_CHANS,
15
+ embed_dim=config.SWIN.EMBED_DIM,
16
+ depths=config.SWIN.DEPTHS,
17
+ num_heads=config.SWIN.NUM_HEADS,
18
+ window_size=config.SWIN.WINDOW_SIZE,
19
+ mlp_ratio=config.SWIN.MLP_RATIO,
20
+ qkv_bias=config.SWIN.QKV_BIAS,
21
+ qk_scale=config.SWIN.QK_SCALE,
22
+ drop_rate=config.DROP_RATE,
23
+ drop_path_rate=config.DROP_PATH_RATE,
24
+ ape=config.SWIN.APE,
25
+ patch_norm=config.SWIN.PATCH_NORM,
26
+ use_checkpoint=False
27
+ )
28
+ elif "focal" in model_type:
29
+ model = create_model(
30
+ model_type,
31
+ pretrained=False,
32
+ img_size=config.IMG_SIZE,
33
+ num_classes=0,
34
+ drop_path_rate=config.DROP_PATH_RATE,
35
+ use_conv_embed=config.FOCAL.USE_CONV_EMBED,
36
+ use_layerscale=config.FOCAL.USE_LAYERSCALE,
37
+ use_postln=config.FOCAL.USE_POSTLN
38
+ )
39
+
40
+ elif "vit" in model_type:
41
+ model = create_model(
42
+ model_type,
43
+ pretrained=is_pretrained,
44
+ img_size=config.DATA.IMG_SIZE,
45
+ num_classes=config.MODEL.NUM_CLASSES,
46
+ )
47
+ elif "resnet" in model_type:
48
+ model = create_model(
49
+ model_type,
50
+ pretrained=is_pretrained,
51
+ num_classes=config.MODEL.NUM_CLASSES
52
+ )
53
+ else:
54
+ model = create_model(
55
+ model_type,
56
+ pretrained=is_pretrained,
57
+ num_classes=config.MODEL.NUM_CLASSES
58
+ )
59
+ return model
model/image_encoder/focalnet.py ADDED
@@ -0,0 +1,649 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # FocalNets -- Focal Modulation Networks
3
+ # Copyright (c) 2022 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Jianwei Yang ([email protected])
6
+ # --------------------------------------------------------
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ import torch.utils.checkpoint as checkpoint
12
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
13
+ from timm.models.registry import register_model
14
+
15
+ from torchvision import transforms
16
+ from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
17
+ from timm.data import create_transform
18
+ from timm.data.transforms import _pil_interp
19
+
20
+ class Mlp(nn.Module):
21
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
22
+ super().__init__()
23
+ out_features = out_features or in_features
24
+ hidden_features = hidden_features or in_features
25
+ self.fc1 = nn.Linear(in_features, hidden_features)
26
+ self.act = act_layer()
27
+ self.fc2 = nn.Linear(hidden_features, out_features)
28
+ self.drop = nn.Dropout(drop)
29
+
30
+ def forward(self, x):
31
+ x = self.fc1(x)
32
+ x = self.act(x)
33
+ x = self.drop(x)
34
+ x = self.fc2(x)
35
+ x = self.drop(x)
36
+ return x
37
+
38
+ class FocalModulation(nn.Module):
39
+ def __init__(self, dim, focal_window, focal_level, focal_factor=2, bias=True, proj_drop=0.):
40
+ super().__init__()
41
+
42
+ self.dim = dim
43
+ self.focal_window = focal_window
44
+ self.focal_level = focal_level
45
+ self.focal_factor = focal_factor
46
+
47
+ self.f = nn.Linear(dim, 2*dim + (self.focal_level+1), bias=bias)
48
+ self.h = nn.Conv2d(dim, dim, kernel_size=1, stride=1, bias=bias)
49
+
50
+ self.act = nn.GELU()
51
+ self.proj = nn.Linear(dim, dim)
52
+ self.proj_drop = nn.Dropout(proj_drop)
53
+ self.focal_layers = nn.ModuleList()
54
+
55
+ self.kernel_sizes = []
56
+ for k in range(self.focal_level):
57
+ kernel_size = self.focal_factor*k + self.focal_window
58
+ self.focal_layers.append(
59
+ nn.Sequential(
60
+ nn.Conv2d(dim, dim, kernel_size=kernel_size, stride=1,
61
+ groups=dim, padding=kernel_size//2, bias=False),
62
+ nn.GELU(),
63
+ )
64
+ )
65
+ self.kernel_sizes.append(kernel_size)
66
+ def forward(self, x):
67
+ """
68
+ Args:
69
+ x: input features with shape of (B, H, W, C)
70
+ """
71
+ C = x.shape[-1]
72
+
73
+ # pre linear projection
74
+ x = self.f(x).permute(0, 3, 1, 2).contiguous()
75
+ q, ctx, self.gates = torch.split(x, (C, C, self.focal_level+1), 1)
76
+
77
+ # context aggreation
78
+ ctx_all = 0
79
+ for l in range(self.focal_level):
80
+ ctx = self.focal_layers[l](ctx)
81
+ ctx_all = ctx_all + ctx*self.gates[:, l:l+1]
82
+ ctx_global = self.act(ctx.mean(2, keepdim=True).mean(3, keepdim=True))
83
+ ctx_all = ctx_all + ctx_global*self.gates[:,self.focal_level:]
84
+
85
+ # focal modulation
86
+ self.modulator = self.h(ctx_all)
87
+ x_out = q*self.modulator
88
+ x_out = x_out.permute(0, 2, 3, 1).contiguous()
89
+
90
+ # post linear porjection
91
+ x_out = self.proj(x_out)
92
+ x_out = self.proj_drop(x_out)
93
+ return x_out
94
+
95
+ def extra_repr(self) -> str:
96
+ return f'dim={self.dim}'
97
+
98
+ def flops(self, N):
99
+ # calculate flops for 1 window with token length of N
100
+ flops = 0
101
+
102
+ flops += N * self.dim * (self.dim * 2 + (self.focal_level+1))
103
+
104
+ # focal convolution
105
+ for k in range(self.focal_level):
106
+ flops += N * (self.kernel_sizes[k]**2+1) * self.dim
107
+
108
+ # global gating
109
+ flops += N * 1 * self.dim
110
+
111
+ # self.linear
112
+ flops += N * self.dim * (self.dim + 1)
113
+
114
+ # x = self.proj(x)
115
+ flops += N * self.dim * self.dim
116
+ return flops
117
+
118
+ class FocalNetBlock(nn.Module):
119
+ r""" Focal Modulation Network Block.
120
+
121
+ Args:
122
+ dim (int): Number of input channels.
123
+ input_resolution (tuple[int]): Input resulotion.
124
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
125
+ drop (float, optional): Dropout rate. Default: 0.0
126
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
127
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
128
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
129
+ focal_level (int): Number of focal levels.
130
+ focal_window (int): Focal window size at first focal level
131
+ use_layerscale (bool): Whether use layerscale
132
+ layerscale_value (float): Initial layerscale value
133
+ use_postln (bool): Whether use layernorm after modulation
134
+ """
135
+
136
+ def __init__(self, dim, input_resolution, mlp_ratio=4., drop=0., drop_path=0.,
137
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm,
138
+ focal_level=1, focal_window=3,
139
+ use_layerscale=False, layerscale_value=1e-4,
140
+ use_postln=False):
141
+ super().__init__()
142
+ self.dim = dim
143
+ self.input_resolution = input_resolution
144
+ self.mlp_ratio = mlp_ratio
145
+
146
+ self.focal_window = focal_window
147
+ self.focal_level = focal_level
148
+ self.use_postln = use_postln
149
+
150
+ self.norm1 = norm_layer(dim)
151
+ self.modulation = FocalModulation(dim, proj_drop=drop, focal_window=focal_window, focal_level=self.focal_level)
152
+
153
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
154
+ self.norm2 = norm_layer(dim)
155
+ mlp_hidden_dim = int(dim * mlp_ratio)
156
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
157
+
158
+ self.alpha = 3.0 if self.use_postln else 1.0
159
+
160
+ self.gamma_1 = 1.0
161
+ self.gamma_2 = 1.0
162
+ if use_layerscale:
163
+ self.gamma_1 = nn.Parameter(layerscale_value * torch.ones((dim)), requires_grad=True)
164
+ self.gamma_2 = nn.Parameter(layerscale_value * torch.ones((dim)), requires_grad=True)
165
+
166
+ self.H = None
167
+ self.W = None
168
+
169
+ def forward(self, x):
170
+ H, W = self.H, self.W
171
+ B, L, C = x.shape
172
+ shortcut = x
173
+
174
+ # Focal Modulation
175
+ if not self.use_postln:
176
+ x = self.norm1(x)
177
+ x = x.view(B, H, W, C)
178
+ x = self.modulation(x).view(B, H * W, C)
179
+
180
+ # FFN
181
+ x = shortcut*self.alpha + self.drop_path(self.gamma_1 * x)
182
+ if self.use_postln:
183
+ x = self.norm1(x)
184
+
185
+ if not self.use_postln:
186
+ x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
187
+ else:
188
+ x = x*self.alpha + self.drop_path(self.gamma_2 * self.mlp(x))
189
+ x = self.norm2(x)
190
+
191
+ return x
192
+
193
+ def extra_repr(self) -> str:
194
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, " \
195
+ f"mlp_ratio={self.mlp_ratio}"
196
+
197
+ def flops(self):
198
+ flops = 0
199
+ H, W = self.input_resolution
200
+ # norm1
201
+ flops += self.dim * H * W
202
+
203
+ # W-MSA/SW-MSA
204
+ flops += self.modulation.flops(H*W)
205
+
206
+ # mlp
207
+ flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
208
+ # norm2
209
+ flops += self.dim * H * W
210
+ return flops
211
+
212
+ class BasicLayer(nn.Module):
213
+ """ A basic Focal Transformer layer for one stage.
214
+
215
+ Args:
216
+ dim (int): Number of input channels.
217
+ input_resolution (tuple[int]): Input resolution.
218
+ depth (int): Number of blocks.
219
+ window_size (int): Local window size.
220
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
221
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
222
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
223
+ drop (float, optional): Dropout rate. Default: 0.0
224
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
225
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
226
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
227
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
228
+ focal_level (int): Number of focal levels
229
+ focal_window (int): Focal window size at first focal level
230
+ use_layerscale (bool): Whether use layerscale
231
+ layerscale_value (float): Initial layerscale value
232
+ use_postln (bool): Whether use layernorm after modulation
233
+ """
234
+
235
+ def __init__(self, dim, out_dim, input_resolution, depth,
236
+ mlp_ratio=4., drop=0., drop_path=0., norm_layer=nn.LayerNorm,
237
+ downsample=None, use_checkpoint=False,
238
+ focal_level=1, focal_window=1,
239
+ use_conv_embed=False,
240
+ use_layerscale=False, layerscale_value=1e-4, use_postln=False):
241
+
242
+ super().__init__()
243
+ self.dim = dim
244
+ self.input_resolution = input_resolution
245
+ self.depth = depth
246
+ self.use_checkpoint = use_checkpoint
247
+
248
+ # build blocks
249
+ self.blocks = nn.ModuleList([
250
+ FocalNetBlock(
251
+ dim=dim,
252
+ input_resolution=input_resolution,
253
+ mlp_ratio=mlp_ratio,
254
+ drop=drop,
255
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
256
+ norm_layer=norm_layer,
257
+ focal_level=focal_level,
258
+ focal_window=focal_window,
259
+ use_layerscale=use_layerscale,
260
+ layerscale_value=layerscale_value,
261
+ use_postln=use_postln,
262
+ )
263
+ for i in range(depth)])
264
+
265
+ if downsample is not None:
266
+ self.downsample = downsample(
267
+ img_size=input_resolution,
268
+ patch_size=2,
269
+ in_chans=dim,
270
+ embed_dim=out_dim,
271
+ use_conv_embed=use_conv_embed,
272
+ norm_layer=norm_layer,
273
+ is_stem=False
274
+ )
275
+ else:
276
+ self.downsample = None
277
+
278
+ def forward(self, x, H, W):
279
+ for blk in self.blocks:
280
+ blk.H, blk.W = H, W
281
+ if self.use_checkpoint:
282
+ x = checkpoint.checkpoint(blk, x)
283
+ else:
284
+ x = blk(x)
285
+
286
+ if self.downsample is not None:
287
+ x = x.transpose(1, 2).reshape(x.shape[0], -1, H, W)
288
+ x, Ho, Wo = self.downsample(x)
289
+ else:
290
+ Ho, Wo = H, W
291
+ return x, Ho, Wo
292
+
293
+ def extra_repr(self) -> str:
294
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
295
+
296
+ def flops(self):
297
+ flops = 0
298
+ for blk in self.blocks:
299
+ flops += blk.flops()
300
+ if self.downsample is not None:
301
+ flops += self.downsample.flops()
302
+ return flops
303
+
304
+ class PatchEmbed(nn.Module):
305
+ r""" Image to Patch Embedding
306
+
307
+ Args:
308
+ img_size (int): Image size. Default: 224.
309
+ patch_size (int): Patch token size. Default: 4.
310
+ in_chans (int): Number of input image channels. Default: 3.
311
+ embed_dim (int): Number of linear projection output channels. Default: 96.
312
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
313
+ """
314
+
315
+ def __init__(self, img_size=(224, 224), patch_size=4, in_chans=3, embed_dim=96, use_conv_embed=False, norm_layer=None, is_stem=False):
316
+ super().__init__()
317
+ patch_size = to_2tuple(patch_size)
318
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
319
+ self.img_size = img_size
320
+ self.patch_size = patch_size
321
+ self.patches_resolution = patches_resolution
322
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
323
+
324
+ self.in_chans = in_chans
325
+ self.embed_dim = embed_dim
326
+
327
+ if use_conv_embed:
328
+ # if we choose to use conv embedding, then we treat the stem and non-stem differently
329
+ if is_stem:
330
+ kernel_size = 7; padding = 2; stride = 4
331
+ else:
332
+ kernel_size = 3; padding = 1; stride = 2
333
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
334
+ else:
335
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
336
+
337
+ if norm_layer is not None:
338
+ self.norm = norm_layer(embed_dim)
339
+ else:
340
+ self.norm = None
341
+
342
+ def forward(self, x):
343
+ B, C, H, W = x.shape
344
+
345
+ x = self.proj(x)
346
+ H, W = x.shape[2:]
347
+ x = x.flatten(2).transpose(1, 2) # B Ph*Pw C
348
+ if self.norm is not None:
349
+ x = self.norm(x)
350
+ return x, H, W
351
+
352
+ def flops(self):
353
+ Ho, Wo = self.patches_resolution
354
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
355
+ if self.norm is not None:
356
+ flops += Ho * Wo * self.embed_dim
357
+ return flops
358
+
359
+ class FocalNet(nn.Module):
360
+ r""" Focal Modulation Networks (FocalNets)
361
+
362
+ Args:
363
+ img_size (int | tuple(int)): Input image size. Default 224
364
+ patch_size (int | tuple(int)): Patch size. Default: 4
365
+ in_chans (int): Number of input image channels. Default: 3
366
+ num_classes (int): Number of classes for classification head. Default: 1000
367
+ embed_dim (int): Patch embedding dimension. Default: 96
368
+ depths (tuple(int)): Depth of each Focal Transformer layer.
369
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
370
+ drop_rate (float): Dropout rate. Default: 0
371
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
372
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
373
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
374
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
375
+ focal_levels (list): How many focal levels at all stages. Note that this excludes the finest-grain level. Default: [1, 1, 1, 1]
376
+ focal_windows (list): The focal window size at all stages. Default: [7, 5, 3, 1]
377
+ use_conv_embed (bool): Whether use convolutional embedding. We noted that using convolutional embedding usually improve the performance, but we do not use it by default. Default: False
378
+ use_layerscale (bool): Whether use layerscale proposed in CaiT. Default: False
379
+ layerscale_value (float): Value for layer scale. Default: 1e-4
380
+ use_postln (bool): Whether use layernorm after modulation (it helps stablize training of large models)
381
+ """
382
+ def __init__(self,
383
+ img_size=224,
384
+ patch_size=4,
385
+ in_chans=3,
386
+ num_classes=1000,
387
+ embed_dim=96,
388
+ depths=[2, 2, 6, 2],
389
+ mlp_ratio=4.,
390
+ drop_rate=0.,
391
+ drop_path_rate=0.1,
392
+ norm_layer=nn.LayerNorm,
393
+ patch_norm=True,
394
+ use_checkpoint=False,
395
+ focal_levels=[2, 2, 2, 2],
396
+ focal_windows=[3, 3, 3, 3],
397
+ use_conv_embed=False,
398
+ use_layerscale=False,
399
+ layerscale_value=1e-4,
400
+ use_postln=False,
401
+ **kwargs):
402
+ super().__init__()
403
+
404
+ self.num_layers = len(depths)
405
+ embed_dim = [embed_dim * (2 ** i) for i in range(self.num_layers)]
406
+
407
+ self.num_classes = num_classes
408
+ self.embed_dim = embed_dim
409
+ self.patch_norm = patch_norm
410
+ self.num_features = embed_dim[-1]
411
+ self.mlp_ratio = mlp_ratio
412
+
413
+ # split image into patches using either non-overlapped embedding or overlapped embedding
414
+ self.patch_embed = PatchEmbed(
415
+ img_size=to_2tuple(img_size),
416
+ patch_size=patch_size,
417
+ in_chans=in_chans,
418
+ embed_dim=embed_dim[0],
419
+ use_conv_embed=use_conv_embed,
420
+ norm_layer=norm_layer if self.patch_norm else None,
421
+ is_stem=True)
422
+
423
+ num_patches = self.patch_embed.num_patches
424
+ patches_resolution = self.patch_embed.patches_resolution
425
+ self.patches_resolution = patches_resolution
426
+ self.pos_drop = nn.Dropout(p=drop_rate)
427
+
428
+ # stochastic depth
429
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
430
+
431
+ # build layers
432
+ self.layers = nn.ModuleList()
433
+ for i_layer in range(self.num_layers):
434
+ layer = BasicLayer(dim=embed_dim[i_layer],
435
+ out_dim=embed_dim[i_layer+1] if (i_layer < self.num_layers - 1) else None,
436
+ input_resolution=(patches_resolution[0] // (2 ** i_layer),
437
+ patches_resolution[1] // (2 ** i_layer)),
438
+ depth=depths[i_layer],
439
+ mlp_ratio=self.mlp_ratio,
440
+ drop=drop_rate,
441
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
442
+ norm_layer=norm_layer,
443
+ downsample=PatchEmbed if (i_layer < self.num_layers - 1) else None,
444
+ focal_level=focal_levels[i_layer],
445
+ focal_window=focal_windows[i_layer],
446
+ use_conv_embed=use_conv_embed,
447
+ use_checkpoint=use_checkpoint,
448
+ use_layerscale=use_layerscale,
449
+ layerscale_value=layerscale_value,
450
+ use_postln=use_postln,
451
+ )
452
+ self.layers.append(layer)
453
+
454
+ self.norm = norm_layer(self.num_features)
455
+ self.avgpool = nn.AdaptiveAvgPool1d(1)
456
+ self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
457
+ self.dim_out = self.num_features
458
+
459
+ self.apply(self._init_weights)
460
+
461
+ def _init_weights(self, m):
462
+ if isinstance(m, nn.Linear):
463
+ trunc_normal_(m.weight, std=.02)
464
+ if isinstance(m, nn.Linear) and m.bias is not None:
465
+ nn.init.constant_(m.bias, 0)
466
+ elif isinstance(m, nn.LayerNorm):
467
+ nn.init.constant_(m.bias, 0)
468
+ nn.init.constant_(m.weight, 1.0)
469
+
470
+ @torch.jit.ignore
471
+ def no_weight_decay(self):
472
+ return {''}
473
+
474
+ @torch.jit.ignore
475
+ def no_weight_decay_keywords(self):
476
+ return {''}
477
+
478
+ def forward_features(self, x):
479
+ x, H, W = self.patch_embed(x)
480
+ x = self.pos_drop(x)
481
+
482
+ for layer in self.layers:
483
+ x, H, W = layer(x, H, W)
484
+ x = self.norm(x) # B L C
485
+ x = self.avgpool(x.transpose(1, 2)) # B C 1
486
+ x = torch.flatten(x, 1)
487
+ return x
488
+
489
+ def forward(self, x):
490
+ x = self.forward_features(x)
491
+ x = self.head(x)
492
+ return x
493
+
494
+ def flops(self):
495
+ flops = 0
496
+ flops += self.patch_embed.flops()
497
+ for i, layer in enumerate(self.layers):
498
+ flops += layer.flops()
499
+ flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
500
+ flops += self.num_features * self.num_classes
501
+ return flops
502
+
503
+ def build_transforms(img_size, center_crop=False):
504
+ t = []
505
+ if center_crop:
506
+ size = int((256 / 224) * img_size)
507
+ t.append(
508
+ transforms.Resize(size, interpolation=_pil_interp('bicubic'))
509
+ )
510
+ t.append(
511
+ transforms.CenterCrop(img_size)
512
+ )
513
+ else:
514
+ t.append(
515
+ transforms.Resize(img_size, interpolation=_pil_interp('bicubic'))
516
+ )
517
+ t.append(transforms.ToTensor())
518
+ t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
519
+ return transforms.Compose(t)
520
+
521
+ def build_transforms4display(img_size, center_crop=False):
522
+ t = []
523
+ if center_crop:
524
+ size = int((256 / 224) * img_size)
525
+ t.append(
526
+ transforms.Resize(size, interpolation=_pil_interp('bicubic'))
527
+ )
528
+ t.append(
529
+ transforms.CenterCrop(img_size)
530
+ )
531
+ else:
532
+ t.append(
533
+ transforms.Resize(img_size, interpolation=_pil_interp('bicubic'))
534
+ )
535
+ t.append(transforms.ToTensor())
536
+ return transforms.Compose(t)
537
+
538
+ model_urls = {
539
+ "focalnet_tiny_srf": "",
540
+ "focalnet_small_srf": "",
541
+ "focalnet_base_srf": "",
542
+ "focalnet_tiny_lrf": "",
543
+ "focalnet_small_lrf": "",
544
+ "focalnet_base_lrf": "",
545
+ }
546
+
547
+ @register_model
548
+ def focalnet_tiny_srf(pretrained=False, **kwargs):
549
+ model = FocalNet(depths=[2, 2, 6, 2], embed_dim=96, **kwargs)
550
+ if pretrained:
551
+ url = model_urls['focalnet_tiny_srf']
552
+ checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
553
+ model.load_state_dict(checkpoint["model"])
554
+ return model
555
+
556
+ @register_model
557
+ def focalnet_small_srf(pretrained=False, **kwargs):
558
+ model = FocalNet(depths=[2, 2, 18, 2], embed_dim=96, **kwargs)
559
+ if pretrained:
560
+ url = model_urls['focalnet_small_srf']
561
+ checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
562
+ model.load_state_dict(checkpoint["model"])
563
+ return model
564
+
565
+ @register_model
566
+ def focalnet_base_srf(pretrained=False, **kwargs):
567
+ model = FocalNet(depths=[2, 2, 18, 2], embed_dim=128, **kwargs)
568
+ if pretrained:
569
+ url = model_urls['focalnet_base_srf']
570
+ checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
571
+ model.load_state_dict(checkpoint["model"])
572
+ return model
573
+
574
+ @register_model
575
+ def focalnet_tiny_lrf(pretrained=False, **kwargs):
576
+ model = FocalNet(depths=[2, 2, 6, 2], embed_dim=96, focal_levels=[3, 3, 3, 3], **kwargs)
577
+ if pretrained:
578
+ url = model_urls['focalnet_tiny_lrf']
579
+ checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
580
+ model.load_state_dict(checkpoint["model"])
581
+ return model
582
+
583
+ @register_model
584
+ def focalnet_small_lrf(pretrained=False, **kwargs):
585
+ model = FocalNet(depths=[2, 2, 18, 2], embed_dim=96, focal_levels=[3, 3, 3, 3], **kwargs)
586
+ if pretrained:
587
+ url = model_urls['focalnet_small_lrf']
588
+ checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
589
+ model.load_state_dict(checkpoint["model"])
590
+ return model
591
+
592
+ @register_model
593
+ def focalnet_base_lrf(pretrained=False, **kwargs):
594
+ model = FocalNet(depths=[2, 2, 18, 2], embed_dim=128, focal_levels=[3, 3, 3, 3], **kwargs)
595
+ if pretrained:
596
+ url = model_urls['focalnet_base_lrf']
597
+ checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
598
+ model.load_state_dict(checkpoint["model"])
599
+ return model
600
+
601
+ @register_model
602
+ def focalnet_giant_lrf(pretrained=False, **kwargs):
603
+ model = FocalNet(depths=[2, 2, 42, 2], embed_dim=512, focal_levels=[3, 3, 3, 3], **kwargs)
604
+ if pretrained:
605
+ url = model_urls['focalnet_giant_lrf']
606
+ checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
607
+ model.load_state_dict(checkpoint["model"])
608
+ return model
609
+
610
+ @register_model
611
+ def focalnet_tiny_iso_16(pretrained=False, **kwargs):
612
+ model = FocalNet(depths=[12], patch_size=16, embed_dim=192, focal_levels=[3], focal_windows=[3], **kwargs)
613
+ if pretrained:
614
+ url = model_urls['focalnet_tiny_iso_16']
615
+ checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
616
+ model.load_state_dict(checkpoint["model"])
617
+ return model
618
+
619
+ @register_model
620
+ def focalnet_small_iso_16(pretrained=False, **kwargs):
621
+ model = FocalNet(depths=[12], patch_size=16, embed_dim=384, focal_levels=[3], focal_windows=[3], **kwargs)
622
+ if pretrained:
623
+ url = model_urls['focalnet_small_iso_16']
624
+ checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
625
+ model.load_state_dict(checkpoint["model"])
626
+ return model
627
+
628
+ @register_model
629
+ def focalnet_base_iso_16(pretrained=False, **kwargs):
630
+ model = FocalNet(depths=[12], patch_size=16, embed_dim=768, focal_levels=[3], focal_windows=[3], use_layerscale=True, use_postln=True, **kwargs)
631
+ if pretrained:
632
+ url = model_urls['focalnet_base_iso_16']
633
+ checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
634
+ model.load_state_dict(checkpoint["model"])
635
+ return model
636
+
637
+ if __name__ == '__main__':
638
+ img_size = 224
639
+ x = torch.rand(16, 3, img_size, img_size).cuda()
640
+ # model = FocalNet(depths=[2, 2, 6, 2], embed_dim=96)
641
+ # model = FocalNet(depths=[12], patch_size=16, embed_dim=768, focal_levels=[3], focal_windows=[3], focal_factors=[2])
642
+ model = FocalNet(depths=[2, 2, 6, 2], embed_dim=96, focal_levels=[3, 3, 3, 3]).cuda()
643
+ print(model); model(x)
644
+
645
+ flops = model.flops()
646
+ print(f"number of GFLOPs: {flops / 1e9}")
647
+
648
+ n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
649
+ print(f"number of params: {n_parameters}")
model/image_encoder/swin_transformer.py ADDED
@@ -0,0 +1,586 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # Swin Transformer
3
+ # Copyright (c) 2021 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Ze Liu
6
+ # --------------------------------------------------------
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.utils.checkpoint as checkpoint
11
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
12
+
13
+
14
+ class Mlp(nn.Module):
15
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
16
+ super().__init__()
17
+ out_features = out_features or in_features
18
+ hidden_features = hidden_features or in_features
19
+ self.fc1 = nn.Linear(in_features, hidden_features)
20
+ self.act = act_layer()
21
+ self.fc2 = nn.Linear(hidden_features, out_features)
22
+ self.drop = nn.Dropout(drop)
23
+
24
+ def forward(self, x):
25
+ x = self.fc1(x)
26
+ x = self.act(x)
27
+ x = self.drop(x)
28
+ x = self.fc2(x)
29
+ x = self.drop(x)
30
+ return x
31
+
32
+
33
+ def window_partition(x, window_size):
34
+ """
35
+ Args:
36
+ x: (B, H, W, C)
37
+ window_size (int): window size
38
+
39
+ Returns:
40
+ windows: (num_windows*B, window_size, window_size, C)
41
+ """
42
+ B, H, W, C = x.shape
43
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
44
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
45
+ return windows
46
+
47
+
48
+ def window_reverse(windows, window_size, H, W):
49
+ """
50
+ Args:
51
+ windows: (num_windows*B, window_size, window_size, C)
52
+ window_size (int): Window size
53
+ H (int): Height of image
54
+ W (int): Width of image
55
+
56
+ Returns:
57
+ x: (B, H, W, C)
58
+ """
59
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
60
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
61
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
62
+ return x
63
+
64
+
65
+ class WindowAttention(nn.Module):
66
+ r""" Window based multi-head self attention (W-MSA) module with relative position bias.
67
+ It supports both of shifted and non-shifted window.
68
+
69
+ Args:
70
+ dim (int): Number of input channels.
71
+ window_size (tuple[int]): The height and width of the window.
72
+ num_heads (int): Number of attention heads.
73
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
74
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
75
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
76
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
77
+ """
78
+
79
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
80
+
81
+ super().__init__()
82
+ self.dim = dim
83
+ self.window_size = window_size # Wh, Ww
84
+ self.num_heads = num_heads
85
+ head_dim = dim // num_heads
86
+ self.scale = qk_scale or head_dim ** -0.5
87
+
88
+ # define a parameter table of relative position bias
89
+ self.relative_position_bias_table = nn.Parameter(
90
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
91
+
92
+ # get pair-wise relative position index for each token inside the window
93
+ coords_h = torch.arange(self.window_size[0])
94
+ coords_w = torch.arange(self.window_size[1])
95
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
96
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
97
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
98
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
99
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
100
+ relative_coords[:, :, 1] += self.window_size[1] - 1
101
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
102
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
103
+ self.register_buffer("relative_position_index", relative_position_index)
104
+
105
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
106
+ self.attn_drop = nn.Dropout(attn_drop)
107
+ self.proj = nn.Linear(dim, dim)
108
+ self.proj_drop = nn.Dropout(proj_drop)
109
+
110
+ trunc_normal_(self.relative_position_bias_table, std=.02)
111
+ self.softmax = nn.Softmax(dim=-1)
112
+
113
+ def forward(self, x, mask=None):
114
+ """
115
+ Args:
116
+ x: input features with shape of (num_windows*B, N, C)
117
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
118
+ """
119
+ B_, N, C = x.shape
120
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
121
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
122
+
123
+ q = q * self.scale
124
+ attn = (q @ k.transpose(-2, -1))
125
+
126
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
127
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
128
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
129
+ attn = attn + relative_position_bias.unsqueeze(0)
130
+
131
+ if mask is not None:
132
+ nW = mask.shape[0]
133
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
134
+ attn = attn.view(-1, self.num_heads, N, N)
135
+ attn = self.softmax(attn)
136
+ else:
137
+ attn = self.softmax(attn)
138
+
139
+ attn = self.attn_drop(attn)
140
+
141
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
142
+ x = self.proj(x)
143
+ x = self.proj_drop(x)
144
+ return x
145
+
146
+ def extra_repr(self) -> str:
147
+ return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
148
+
149
+ def flops(self, N):
150
+ # calculate flops for 1 window with token length of N
151
+ flops = 0
152
+ # qkv = self.qkv(x)
153
+ flops += N * self.dim * 3 * self.dim
154
+ # attn = (q @ k.transpose(-2, -1))
155
+ flops += self.num_heads * N * (self.dim // self.num_heads) * N
156
+ # x = (attn @ v)
157
+ flops += self.num_heads * N * N * (self.dim // self.num_heads)
158
+ # x = self.proj(x)
159
+ flops += N * self.dim * self.dim
160
+ return flops
161
+
162
+
163
+ class SwinTransformerBlock(nn.Module):
164
+ r""" Swin Transformer Block.
165
+
166
+ Args:
167
+ dim (int): Number of input channels.
168
+ input_resolution (tuple[int]): Input resulotion.
169
+ num_heads (int): Number of attention heads.
170
+ window_size (int): Window size.
171
+ shift_size (int): Shift size for SW-MSA.
172
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
173
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
174
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
175
+ drop (float, optional): Dropout rate. Default: 0.0
176
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
177
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
178
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
179
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
180
+ """
181
+
182
+ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
183
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
184
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm):
185
+ super().__init__()
186
+ self.dim = dim
187
+ self.input_resolution = input_resolution
188
+ self.num_heads = num_heads
189
+ self.window_size = window_size
190
+ self.shift_size = shift_size
191
+ self.mlp_ratio = mlp_ratio
192
+ if min(self.input_resolution) <= self.window_size:
193
+ # if window size is larger than input resolution, we don't partition windows
194
+ self.shift_size = 0
195
+ self.window_size = min(self.input_resolution)
196
+ assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
197
+
198
+ self.norm1 = norm_layer(dim)
199
+ self.attn = WindowAttention(
200
+ dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
201
+ qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
202
+
203
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
204
+ self.norm2 = norm_layer(dim)
205
+ mlp_hidden_dim = int(dim * mlp_ratio)
206
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
207
+
208
+ if self.shift_size > 0:
209
+ # calculate attention mask for SW-MSA
210
+ H, W = self.input_resolution
211
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
212
+ h_slices = (slice(0, -self.window_size),
213
+ slice(-self.window_size, -self.shift_size),
214
+ slice(-self.shift_size, None))
215
+ w_slices = (slice(0, -self.window_size),
216
+ slice(-self.window_size, -self.shift_size),
217
+ slice(-self.shift_size, None))
218
+ cnt = 0
219
+ for h in h_slices:
220
+ for w in w_slices:
221
+ img_mask[:, h, w, :] = cnt
222
+ cnt += 1
223
+
224
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
225
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
226
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
227
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
228
+ else:
229
+ attn_mask = None
230
+
231
+ self.register_buffer("attn_mask", attn_mask)
232
+
233
+ def forward(self, x):
234
+ H, W = self.input_resolution
235
+ B, L, C = x.shape
236
+ assert L == H * W, "input feature has wrong size"
237
+
238
+ shortcut = x
239
+ x = self.norm1(x)
240
+ x = x.view(B, H, W, C)
241
+
242
+ # cyclic shift
243
+ if self.shift_size > 0:
244
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
245
+ else:
246
+ shifted_x = x
247
+
248
+ # partition windows
249
+ x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
250
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
251
+
252
+ # W-MSA/SW-MSA
253
+ attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
254
+
255
+ # merge windows
256
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
257
+ shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
258
+
259
+ # reverse cyclic shift
260
+ if self.shift_size > 0:
261
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
262
+ else:
263
+ x = shifted_x
264
+ x = x.view(B, H * W, C)
265
+
266
+ # FFN
267
+ x = shortcut + self.drop_path(x)
268
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
269
+
270
+ return x
271
+
272
+ def extra_repr(self) -> str:
273
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
274
+ f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
275
+
276
+ def flops(self):
277
+ flops = 0
278
+ H, W = self.input_resolution
279
+ # norm1
280
+ flops += self.dim * H * W
281
+ # W-MSA/SW-MSA
282
+ nW = H * W / self.window_size / self.window_size
283
+ flops += nW * self.attn.flops(self.window_size * self.window_size)
284
+ # mlp
285
+ flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
286
+ # norm2
287
+ flops += self.dim * H * W
288
+ return flops
289
+
290
+
291
+ class PatchMerging(nn.Module):
292
+ r""" Patch Merging Layer.
293
+
294
+ Args:
295
+ input_resolution (tuple[int]): Resolution of input feature.
296
+ dim (int): Number of input channels.
297
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
298
+ """
299
+
300
+ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
301
+ super().__init__()
302
+ self.input_resolution = input_resolution
303
+ self.dim = dim
304
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
305
+ self.norm = norm_layer(4 * dim)
306
+
307
+ def forward(self, x):
308
+ """
309
+ x: B, H*W, C
310
+ """
311
+ H, W = self.input_resolution
312
+ B, L, C = x.shape
313
+ assert L == H * W, "input feature has wrong size"
314
+ assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
315
+
316
+ x = x.view(B, H, W, C)
317
+
318
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
319
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
320
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
321
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
322
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
323
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
324
+
325
+ x = self.norm(x)
326
+ x = self.reduction(x)
327
+
328
+ return x
329
+
330
+ def extra_repr(self) -> str:
331
+ return f"input_resolution={self.input_resolution}, dim={self.dim}"
332
+
333
+ def flops(self):
334
+ H, W = self.input_resolution
335
+ flops = H * W * self.dim
336
+ flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
337
+ return flops
338
+
339
+
340
+ class BasicLayer(nn.Module):
341
+ """ A basic Swin Transformer layer for one stage.
342
+
343
+ Args:
344
+ dim (int): Number of input channels.
345
+ input_resolution (tuple[int]): Input resolution.
346
+ depth (int): Number of blocks.
347
+ num_heads (int): Number of attention heads.
348
+ window_size (int): Local window size.
349
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
350
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
351
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
352
+ drop (float, optional): Dropout rate. Default: 0.0
353
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
354
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
355
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
356
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
357
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
358
+ """
359
+
360
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size,
361
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
362
+ drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
363
+
364
+ super().__init__()
365
+ self.dim = dim
366
+ self.input_resolution = input_resolution
367
+ self.depth = depth
368
+ self.use_checkpoint = use_checkpoint
369
+
370
+ # build blocks
371
+ self.blocks = nn.ModuleList([
372
+ SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
373
+ num_heads=num_heads, window_size=window_size,
374
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
375
+ mlp_ratio=mlp_ratio,
376
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
377
+ drop=drop, attn_drop=attn_drop,
378
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
379
+ norm_layer=norm_layer)
380
+ for i in range(depth)])
381
+
382
+ # patch merging layer
383
+ if downsample is not None:
384
+ self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
385
+ else:
386
+ self.downsample = None
387
+
388
+ def forward(self, x):
389
+ for blk in self.blocks:
390
+ if self.use_checkpoint:
391
+ x = checkpoint.checkpoint(blk, x)
392
+ else:
393
+ x = blk(x)
394
+ if self.downsample is not None:
395
+ x = self.downsample(x)
396
+ return x
397
+
398
+ def extra_repr(self) -> str:
399
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
400
+
401
+ def flops(self):
402
+ flops = 0
403
+ for blk in self.blocks:
404
+ flops += blk.flops()
405
+ if self.downsample is not None:
406
+ flops += self.downsample.flops()
407
+ return flops
408
+
409
+
410
+ class PatchEmbed(nn.Module):
411
+ r""" Image to Patch Embedding
412
+
413
+ Args:
414
+ img_size (int): Image size. Default: 224.
415
+ patch_size (int): Patch token size. Default: 4.
416
+ in_chans (int): Number of input image channels. Default: 3.
417
+ embed_dim (int): Number of linear projection output channels. Default: 96.
418
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
419
+ """
420
+
421
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
422
+ super().__init__()
423
+ img_size = to_2tuple(img_size)
424
+ patch_size = to_2tuple(patch_size)
425
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
426
+ self.img_size = img_size
427
+ self.patch_size = patch_size
428
+ self.patches_resolution = patches_resolution
429
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
430
+
431
+ self.in_chans = in_chans
432
+ self.embed_dim = embed_dim
433
+
434
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
435
+ if norm_layer is not None:
436
+ self.norm = norm_layer(embed_dim)
437
+ else:
438
+ self.norm = None
439
+
440
+ def forward(self, x):
441
+ B, C, H, W = x.shape
442
+ # FIXME look at relaxing size constraints
443
+ assert H == self.img_size[0] and W == self.img_size[1], \
444
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
445
+ x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
446
+ if self.norm is not None:
447
+ x = self.norm(x)
448
+ return x
449
+
450
+ def flops(self):
451
+ Ho, Wo = self.patches_resolution
452
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
453
+ if self.norm is not None:
454
+ flops += Ho * Wo * self.embed_dim
455
+ return flops
456
+
457
+
458
+ class SwinTransformer(nn.Module):
459
+ r""" Swin Transformer
460
+ A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
461
+ https://arxiv.org/pdf/2103.14030
462
+
463
+ Args:
464
+ img_size (int | tuple(int)): Input image size. Default 224
465
+ patch_size (int | tuple(int)): Patch size. Default: 4
466
+ in_chans (int): Number of input image channels. Default: 3
467
+ num_classes (int): Number of classes for classification head. Default: 1000
468
+ embed_dim (int): Patch embedding dimension. Default: 96
469
+ depths (tuple(int)): Depth of each Swin Transformer layer.
470
+ num_heads (tuple(int)): Number of attention heads in different layers.
471
+ window_size (int): Window size. Default: 7
472
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
473
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
474
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
475
+ drop_rate (float): Dropout rate. Default: 0
476
+ attn_drop_rate (float): Attention dropout rate. Default: 0
477
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
478
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
479
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
480
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
481
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
482
+ """
483
+
484
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
485
+ embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
486
+ window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
487
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
488
+ norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
489
+ use_checkpoint=False, **kwargs):
490
+ super().__init__()
491
+
492
+ self.num_classes = num_classes
493
+ self.num_layers = len(depths)
494
+ self.embed_dim = embed_dim
495
+ self.ape = ape
496
+ self.patch_norm = patch_norm
497
+ self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
498
+ self.mlp_ratio = mlp_ratio
499
+
500
+ # split image into non-overlapping patches
501
+ self.patch_embed = PatchEmbed(
502
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
503
+ norm_layer=norm_layer if self.patch_norm else None)
504
+ num_patches = self.patch_embed.num_patches
505
+ patches_resolution = self.patch_embed.patches_resolution
506
+ self.patches_resolution = patches_resolution
507
+
508
+ # absolute position embedding
509
+ if self.ape:
510
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
511
+ trunc_normal_(self.absolute_pos_embed, std=.02)
512
+
513
+ self.pos_drop = nn.Dropout(p=drop_rate)
514
+
515
+ # stochastic depth
516
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
517
+
518
+ # build layers
519
+ self.layers = nn.ModuleList()
520
+ for i_layer in range(self.num_layers):
521
+ layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
522
+ input_resolution=(patches_resolution[0] // (2 ** i_layer),
523
+ patches_resolution[1] // (2 ** i_layer)),
524
+ depth=depths[i_layer],
525
+ num_heads=num_heads[i_layer],
526
+ window_size=window_size,
527
+ mlp_ratio=self.mlp_ratio,
528
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
529
+ drop=drop_rate, attn_drop=attn_drop_rate,
530
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
531
+ norm_layer=norm_layer,
532
+ downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
533
+ use_checkpoint=use_checkpoint)
534
+ self.layers.append(layer)
535
+
536
+ self.norm = norm_layer(self.num_features)
537
+ self.avgpool = nn.AdaptiveAvgPool1d(1)
538
+ self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
539
+ self.dim_out = self.num_features
540
+
541
+ self.apply(self._init_weights)
542
+
543
+ def _init_weights(self, m):
544
+ if isinstance(m, nn.Linear):
545
+ trunc_normal_(m.weight, std=.02)
546
+ if isinstance(m, nn.Linear) and m.bias is not None:
547
+ nn.init.constant_(m.bias, 0)
548
+ elif isinstance(m, nn.LayerNorm):
549
+ nn.init.constant_(m.bias, 0)
550
+ nn.init.constant_(m.weight, 1.0)
551
+
552
+ @torch.jit.ignore
553
+ def no_weight_decay(self):
554
+ return {'absolute_pos_embed'}
555
+
556
+ @torch.jit.ignore
557
+ def no_weight_decay_keywords(self):
558
+ return {'relative_position_bias_table'}
559
+
560
+ def forward_features(self, x):
561
+ x = self.patch_embed(x)
562
+ if self.ape:
563
+ x = x + self.absolute_pos_embed
564
+ x = self.pos_drop(x)
565
+
566
+ for layer in self.layers:
567
+ x = layer(x)
568
+
569
+ x = self.norm(x) # B L C
570
+ x = self.avgpool(x.transpose(1, 2)) # B C 1
571
+ x = torch.flatten(x, 1)
572
+ return x
573
+
574
+ def forward(self, x):
575
+ x = self.forward_features(x)
576
+ x = self.head(x)
577
+ return x
578
+
579
+ def flops(self):
580
+ flops = 0
581
+ flops += self.patch_embed.flops()
582
+ for i, layer in enumerate(self.layers):
583
+ flops += layer.flops()
584
+ flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
585
+ flops += self.num_features * self.num_classes
586
+ return flops
model/model.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pathlib
2
+ import tempfile
3
+ from collections import OrderedDict
4
+ from typing import Tuple, Union
5
+ import logging
6
+ import os
7
+
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from torch import nn
12
+
13
+ from timm.models.layers import DropPath, trunc_normal_
14
+
15
+ from .image_encoder import build_image_encoder
16
+ from .text_encoder import build_text_encoder
17
+ from .text_encoder import build_tokenizer
18
+ from .templates import DEFAULT_TEMPLATES
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+
23
+ class UniCLModel(nn.Module):
24
+ def __init__(self, config: dict,):
25
+ super().__init__()
26
+
27
+ self.conf_lang_encoder = config['MODEL']['TEXT_ENCODER']
28
+ self.tokenizer = build_tokenizer(self.conf_lang_encoder)
29
+
30
+ self.text_encoder = build_text_encoder(self.conf_lang_encoder, self.tokenizer, config['VERBOSE'])
31
+
32
+ dim_projection = config['MODEL']['DIM_PROJECTION']
33
+ if hasattr(self.text_encoder, 'dim_out'):
34
+ dim_out = self.text_encoder.dim_out
35
+ else:
36
+ with torch.no_grad():
37
+ dim_out = self.text_encoder(
38
+ torch.zeros(1,1).type(torch.LongTensor)
39
+ )['last_hidden_state'].size(2)
40
+
41
+ self.text_projection = nn.Parameter(torch.empty(dim_out, dim_projection))
42
+
43
+ self.conf_image_encoder = config['MODEL']['IMAGE_ENCODER']
44
+ self.image_encoder = build_image_encoder(self.conf_image_encoder)
45
+
46
+ self.image_projection = nn.Parameter(
47
+ torch.empty(self.image_encoder.dim_out, dim_projection)
48
+ )
49
+
50
+ self.logit_scale = nn.Parameter(torch.ones([]))
51
+
52
+ trunc_normal_(self.text_projection, std=.02)
53
+ trunc_normal_(self.image_projection, std=.02)
54
+
55
+ def _convert_old_weights(self, model_dict):
56
+ model_dict_updated = {}
57
+ for k, v in model_dict.items():
58
+ if k.startswith('visual.'):
59
+ model_dict_updated['image_encoder.'+k[7:]] = v
60
+ elif k.startswith('text.'):
61
+ model_dict_updated['lang_encoder.'+k[5:]] = v
62
+ elif k == 'vision_projection':
63
+ model_dict_updated['image_projection'] = v
64
+ elif k == 'text_projection':
65
+ model_dict_updated['text_projection'] = v
66
+ else:
67
+ model_dict_updated[k] = v
68
+
69
+ return model_dict_updated
70
+
71
+ def from_pretrained(self, pretrained='', pretrained_layers=[], verbose=True):
72
+ if not os.path.isfile(pretrained):
73
+ logger.warning(f'=> Pretrained model ({pretrained}) is not a file, skip init weight')
74
+ return
75
+
76
+ pretrained_dict = torch.load(pretrained, map_location='cpu')
77
+ logger.info(f'=> Loading pretrained model {pretrained}')
78
+ pretrained_dict = self._convert_old_weights(pretrained_dict)
79
+ model_dict = self.state_dict()
80
+ pretrained_dict = {
81
+ k: v for k, v in pretrained_dict.items()
82
+ if k in model_dict.keys()
83
+ }
84
+ need_init_state_dict = {}
85
+ image_encoder_state_dict = {}
86
+ for k, v in pretrained_dict.items():
87
+ need_init = (
88
+ k.split('.')[0] in pretrained_layers
89
+ or pretrained_layers[0] == '*'
90
+ )
91
+
92
+ if need_init:
93
+ if k.startswith('image_encoder.'):
94
+ image_encoder_state_dict[k] = v
95
+ else:
96
+ if verbose:
97
+ logger.info(f'=> init {k} from {pretrained}')
98
+
99
+ need_init_state_dict[k] = v
100
+ self.image_encoder.from_state_dict(image_encoder_state_dict, ['*'], verbose)
101
+ self.load_state_dict(need_init_state_dict, strict=False)
102
+
103
+ @torch.jit.ignore
104
+ def no_weight_decay(self):
105
+ no_weight_decay = {'logit_scale'}
106
+ if hasattr(self.text_encoder, 'no_weight_decay'):
107
+ for k in self.text_encoder.no_weight_decay():
108
+ no_weight_decay.add('lang_encoder.'+k)
109
+
110
+ if hasattr(self.image_encoder, 'no_weight_decay'):
111
+ for k in self.image_encoder.no_weight_decay():
112
+ no_weight_decay.add('image_encoder.'+k)
113
+
114
+ return no_weight_decay
115
+
116
+ @property
117
+ def dtype(self):
118
+ return self.logit_scale.dtype
119
+
120
+ def get_imnet_embeddings(self):
121
+ templates = IMAGENET_DEFAULT_TEMPLATES[:1]
122
+ clss_embeddings = []
123
+ for clss in IMAGENET_CLASSES:
124
+ txts = [template.format(clss) for template in templates]
125
+
126
+ tokens = self.tokenizer(
127
+ txts, padding='max_length', truncation=True, max_length=77, return_tensors='pt'
128
+ )
129
+ tokens = {key:(val.cuda() if next(self.parameters()).is_cuda else val) for key,val in tokens.items()}
130
+
131
+ clss_embedding = self.encode_text(tokens)
132
+ clss_embedding = clss_embedding.mean(dim=0)
133
+ clss_embedding /= clss_embedding.norm()
134
+ clss_embeddings.append(clss_embedding)
135
+ imnet_text_embeddings = torch.stack(clss_embeddings, dim=0)
136
+ return imnet_text_embeddings
137
+
138
+ def get_text_embeddings(self, texts):
139
+ templates = DEFAULT_TEMPLATES[:1]
140
+ clss_embeddings = []
141
+ for clss in texts:
142
+ txts = [template.format(clss) for template in templates]
143
+
144
+ tokens = self.tokenizer(
145
+ txts, padding='max_length', truncation=True, max_length=77, return_tensors='pt'
146
+ )
147
+ tokens = {key:(val.cuda() if next(self.parameters()).is_cuda else val) for key,val in tokens.items()}
148
+
149
+ clss_embedding = self.encode_text(tokens)
150
+ clss_embedding = clss_embedding.mean(dim=0)
151
+ clss_embedding /= clss_embedding.norm()
152
+ clss_embeddings.append(clss_embedding)
153
+ imnet_text_embeddings = torch.stack(clss_embeddings, dim=0)
154
+ return imnet_text_embeddings
155
+
156
+ def encode_image(self, image, norm=True):
157
+ x = self.image_encoder.forward_features(image)
158
+ x = x @ self.image_projection
159
+
160
+ if norm:
161
+ x = x / x.norm(dim=-1, keepdim=True)
162
+
163
+ return x
164
+
165
+ def encode_text(self, text, norm=True):
166
+ x = self.text_encoder(**text)
167
+ x = x['last_hidden_state']
168
+
169
+ if self.conf_lang_encoder['TOKENIZER'] == 'clip':
170
+ x = x[torch.arange(x.size(0)), text['input_ids'].argmax(dim=-1)]
171
+ else:
172
+ x = x[:, 0]
173
+
174
+ x = x @ self.text_projection
175
+
176
+ if norm:
177
+ x = x / x.norm(dim=-1, keepdim=True)
178
+
179
+ return x
180
+
181
+ def forward(self, image, text):
182
+ features_image = self.encode_image(image)
183
+ features_text = self.encode_text(text)
184
+
185
+ # cosine similarity as logits
186
+ T = self.logit_scale.exp()
187
+
188
+ return features_image, features_text, T
189
+
190
+
191
+ def build_unicl_model(config, **kwargs):
192
+ model = UniCLModel(config)
193
+ if config['MODEL']['PRETRAINED'] != '':
194
+ pretrained_path = config['MODEL']['PRETRAINED']
195
+ from ..Utils.Utils import is_valid_url, download_file
196
+ if is_valid_url(pretrained_path):
197
+ with tempfile.TemporaryDirectory() as tmp_path:
198
+ file_local_path = pathlib.Path(tmp_path) / 'base_model.pt'
199
+ download_file(pretrained_path, file_local_path)
200
+ model.from_pretrained(str(file_local_path), config['MODEL']['PRETRAINED_LAYERS'], config['VERBOSE'])
201
+ else:
202
+ model.from_pretrained(pretrained_path, config['MODEL']['PRETRAINED_LAYERS'], config['VERBOSE'])
203
+
204
+ return model
model/templates.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ DEFAULT_TEMPLATES = [
2
+ '{}.',
3
+ 'a bad photo of a {}.',
4
+ 'a photo of many {}.',
5
+ 'a sculpture of a {}.',
6
+ 'a photo of the hard to see {}.',
7
+ 'a low resolution photo of the {}.',
8
+ 'a rendering of a {}.',
9
+ 'graffiti of a {}.',
10
+ 'a bad photo of the {}.',
11
+ 'a cropped photo of the {}.',
12
+ 'a tattoo of a {}.',
13
+ 'the embroidered {}.',
14
+ 'a photo of a hard to see {}.',
15
+ 'a bright photo of a {}.',
16
+ 'a photo of a clean {}.',
17
+ 'a photo of a dirty {}.',
18
+ 'a dark photo of the {}.',
19
+ 'a drawing of a {}.',
20
+ 'a photo of my {}.',
21
+ 'the plastic {}.',
22
+ 'a photo of the cool {}.',
23
+ 'a close-up photo of a {}.',
24
+ 'a black and white photo of the {}.',
25
+ 'a painting of the {}.',
26
+ 'a painting of a {}.',
27
+ 'a pixelated photo of the {}.',
28
+ 'a sculpture of the {}.',
29
+ 'a bright photo of the {}.',
30
+ 'a cropped photo of a {}.',
31
+ 'a plastic {}.',
32
+ 'a photo of the dirty {}.',
33
+ 'a jpeg corrupted photo of a {}.',
34
+ 'a blurry photo of the {}.',
35
+ 'a photo of the {}.',
36
+ 'a good photo of the {}.',
37
+ 'a rendering of the {}.',
38
+ 'a {} in a video game.',
39
+ 'a photo of one {}.',
40
+ 'a doodle of a {}.',
41
+ 'a close-up photo of the {}.',
42
+ 'a photo of a {}.',
43
+ 'the origami {}.',
44
+ 'the {} in a video game.',
45
+ 'a sketch of a {}.',
46
+ 'a doodle of the {}.',
47
+ 'a origami {}.',
48
+ 'a low resolution photo of a {}.',
49
+ 'the toy {}.',
50
+ 'a rendition of the {}.',
51
+ 'a photo of the clean {}.',
52
+ 'a photo of a large {}.',
53
+ 'a rendition of a {}.',
54
+ 'a photo of a nice {}.',
55
+ 'a photo of a weird {}.',
56
+ 'a blurry photo of a {}.',
57
+ 'a cartoon {}.',
58
+ 'art of a {}.',
59
+ 'a sketch of the {}.',
60
+ 'a embroidered {}.',
61
+ 'a pixelated photo of a {}.',
62
+ 'itap of the {}.',
63
+ 'a jpeg corrupted photo of the {}.',
64
+ 'a good photo of a {}.',
65
+ 'a plushie {}.',
66
+ 'a photo of the nice {}.',
67
+ 'a photo of the small {}.',
68
+ 'a photo of the weird {}.',
69
+ 'the cartoon {}.',
70
+ 'art of the {}.',
71
+ 'a drawing of the {}.',
72
+ 'a photo of the large {}.',
73
+ 'a black and white photo of a {}.',
74
+ 'the plushie {}.',
75
+ 'a dark photo of a {}.',
76
+ 'itap of a {}.',
77
+ 'graffiti of the {}.',
78
+ 'a toy {}.',
79
+ 'itap of my {}.',
80
+ 'a photo of a cool {}.',
81
+ 'a photo of a small {}.',
82
+ 'a tattoo of the {}.',
83
+ ]
model/text_encoder/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import absolute_import
2
+ from __future__ import division
3
+ from __future__ import print_function
4
+
5
+ from .build import build_lang_encoder as build_text_encoder
6
+ from .build import build_tokenizer
7
+
8
+ from .transformer import *
9
+ from .hf_model import *
model/text_encoder/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (415 Bytes). View file
 
model/text_encoder/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (420 Bytes). View file
 
model/text_encoder/__pycache__/build.cpython-38.pyc ADDED
Binary file (1.01 kB). View file
 
model/text_encoder/__pycache__/build.cpython-39.pyc ADDED
Binary file (1.01 kB). View file
 
model/text_encoder/__pycache__/hf_model.cpython-38.pyc ADDED
Binary file (786 Bytes). View file
 
model/text_encoder/__pycache__/hf_model.cpython-39.pyc ADDED
Binary file (791 Bytes). View file
 
model/text_encoder/__pycache__/registry.cpython-38.pyc ADDED
Binary file (598 Bytes). View file
 
model/text_encoder/__pycache__/registry.cpython-39.pyc ADDED
Binary file (603 Bytes). View file
 
model/text_encoder/__pycache__/transformer.cpython-38.pyc ADDED
Binary file (6.79 kB). View file
 
model/text_encoder/__pycache__/transformer.cpython-39.pyc ADDED
Binary file (6.78 kB). View file
 
model/text_encoder/build.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ from transformers import CLIPTokenizer
4
+ from transformers import AutoTokenizer
5
+
6
+ from .registry import lang_encoders
7
+ from .registry import is_lang_encoder
8
+
9
+
10
+ def build_lang_encoder(config_encoder, tokenizer, verbose, **kwargs):
11
+ model_name = config_encoder['NAME']
12
+
13
+ if not is_lang_encoder(model_name):
14
+ raise ValueError(f'Unknown model: {model_name}')
15
+
16
+ return lang_encoders(model_name)(config_encoder, tokenizer, verbose, **kwargs)
17
+
18
+
19
+ def build_tokenizer(config_encoder):
20
+ tokenizer = None
21
+ os.environ['TOKENIZERS_PARALLELISM'] = 'true'
22
+ if config_encoder['TOKENIZER'] == 'clip':
23
+ pretrained_tokenizer = config_encoder.get(
24
+ 'PRETRAINED_TOKENIZER', 'openai/clip-vit-base-patch32'
25
+ )
26
+ tokenizer = CLIPTokenizer.from_pretrained(pretrained_tokenizer)
27
+ tokenizer.add_special_tokens({'cls_token': tokenizer.eos_token})
28
+ else:
29
+ tokenizer = AutoTokenizer.from_pretrained(config_encoder['TOKENIZER'])
30
+
31
+ return tokenizer
model/text_encoder/hf_model.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+
3
+ from transformers import AutoConfig
4
+ from transformers import AutoModel
5
+
6
+ from .registry import register_lang_encoder
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+
11
+ @register_lang_encoder
12
+ def lang_encoder(config_encoder, tokenizer, verbose, **kwargs):
13
+
14
+ hf_model = None
15
+ if config_encoder['LOAD_PRETRAINED']:
16
+ hf_model = AutoModel.from_pretrained(config_encoder['HF_MODEL'])
17
+ else:
18
+ hf_config = AutoConfig.from_pretrained(config_encoder['HF_MODEL'])
19
+
20
+ if 'CONFIG_OVERRIDE' in config_encoder:
21
+ logger.warning(f'Override config: {config_encoder["CONFIG_OVERRIDE"]}')
22
+ hf_config.update(config_encoder['CONFIG_OVERRIDE'])
23
+
24
+ logger.info(f'HF model config: {hf_config}')
25
+ hf_model = AutoModel.from_config(hf_config)
26
+
27
+ return hf_model
model/text_encoder/registry.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _lang_encoders = {}
2
+
3
+
4
+ def register_lang_encoder(fn):
5
+ module_name_split = fn.__module__.split('.')
6
+ model_name = module_name_split[-1]
7
+
8
+ _lang_encoders[model_name] = fn
9
+
10
+ return fn
11
+
12
+
13
+ def lang_encoders(model_name):
14
+ return _lang_encoders[model_name]
15
+
16
+
17
+ def is_lang_encoder(model_name):
18
+ return model_name in _lang_encoders
model/text_encoder/transformer.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import OrderedDict
2
+ from typing import Tuple, Union
3
+ import logging
4
+ import os
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from torch import nn
10
+
11
+ from timm.models.layers import DropPath, trunc_normal_
12
+
13
+ from .registry import register_lang_encoder
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+ class LayerNorm(nn.Module):
18
+ def __init__(self, hidden_size, eps=1e-12):
19
+ """Construct a layernorm module in the TF style (epsilon inside the square root).
20
+ """
21
+ super(LayerNorm, self).__init__()
22
+ self.weight = nn.Parameter(torch.ones(hidden_size))
23
+ self.bias = nn.Parameter(torch.zeros(hidden_size))
24
+ self.variance_epsilon = eps
25
+
26
+ def forward(self, x):
27
+ pdtype = x.dtype
28
+ x = x.float()
29
+ u = x.mean(-1, keepdim=True)
30
+ s = (x - u).pow(2).mean(-1, keepdim=True)
31
+ x = (x - u) / torch.sqrt(s + self.variance_epsilon)
32
+ return self.weight * x.to(pdtype) + self.bias
33
+
34
+
35
+ class QuickGELU(nn.Module):
36
+ def forward(self, x: torch.Tensor):
37
+ return x * torch.sigmoid(1.702 * x)
38
+
39
+
40
+ class ResidualAttentionBlock(nn.Module):
41
+ def __init__(self,
42
+ d_model: int,
43
+ n_head: int,
44
+ attn_mask: torch.Tensor = None,
45
+ drop_path: float = 0.0):
46
+ super().__init__()
47
+
48
+ self.attn = nn.MultiheadAttention(d_model, n_head)
49
+ self.ln_1 = LayerNorm(d_model)
50
+ self.mlp = nn.Sequential(OrderedDict([
51
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
52
+ ("gelu", QuickGELU()),
53
+ ("c_proj", nn.Linear(d_model * 4, d_model))
54
+ ]))
55
+ self.ln_2 = LayerNorm(d_model)
56
+ self.attn_mask = attn_mask
57
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
58
+
59
+ def attention(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None):
60
+ self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) \
61
+ if self.attn_mask is not None else None
62
+
63
+
64
+ return self.attn(
65
+ x, x, x,
66
+ key_padding_mask=key_padding_mask,
67
+ need_weights=False,
68
+ attn_mask=self.attn_mask
69
+ )[0]
70
+
71
+ def forward(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None):
72
+ x = x + self.drop_path(self.attention(self.ln_1(x), key_padding_mask=key_padding_mask))
73
+ x = x + self.drop_path(self.mlp(self.ln_2(x)))
74
+ return x
75
+
76
+
77
+ class Transformer(nn.Module):
78
+ def __init__(self,
79
+ context_length: int,
80
+ vocab_size: int,
81
+ width: int,
82
+ layers: int,
83
+ heads: int,
84
+ drop_path: float = 0.0,
85
+ autogressive: bool =True):
86
+ super().__init__()
87
+
88
+ self.token_embedding = nn.Embedding(vocab_size, width)
89
+
90
+ self.context_length = context_length
91
+ self.positional_embedding = nn.Parameter(
92
+ torch.empty(self.context_length, width)
93
+ )
94
+
95
+ self.width = width
96
+ self.layers = layers
97
+ self.autogressive = autogressive
98
+ attn_mask = self.build_attention_mask() if autogressive else None
99
+ dpr = [x.item() for x in torch.linspace(0, drop_path, layers)] # stochastic depth decay rule
100
+ self.resblocks = nn.ModuleList(
101
+ [
102
+ ResidualAttentionBlock(width, heads, attn_mask, dpr[i])
103
+ for i in range(layers)
104
+ ]
105
+ )
106
+
107
+ self.ln_final = LayerNorm(width)
108
+
109
+ trunc_normal_(self.positional_embedding, std=.02)
110
+ # nn.init.normal_(self.token_embedding, std=.02)
111
+ trunc_normal_(self.token_embedding.weight, std=.02)
112
+ self.apply(self._init_weights)
113
+
114
+ @property
115
+ def dim_out(self):
116
+ return self.width
117
+
118
+ def build_attention_mask(self):
119
+ # lazily create causal attention mask, with full attention between the vision tokens
120
+ # pytorch uses additive attention mask; fill with -inf
121
+ mask = torch.empty(self.context_length, self.context_length)
122
+ mask.fill_(float("-inf"))
123
+ mask.triu_(1) # zero out the lower diagonal
124
+ return mask
125
+
126
+ def _init_weights(self, m):
127
+ if isinstance(m, (nn.Linear, nn.Conv2d)):
128
+ logger.info('=> init weight of Linear/Conv2d from trunc norm')
129
+ trunc_normal_(m.weight, std=0.02)
130
+ if m.bias is not None:
131
+ logger.info('=> init bias of Linear/Conv2d to zeros')
132
+ nn.init.constant_(m.bias, 0)
133
+ elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
134
+ nn.init.constant_(m.bias, 0)
135
+
136
+ def load_pretrained(self, pretrained='', pretrained_layers=[], verbose=True):
137
+ if os.path.isfile(pretrained):
138
+ pretrained_dict = torch.load(pretrained, map_location='cpu')
139
+ logging.info(f'=> loading pretrained model {pretrained}')
140
+ model_dict = self.state_dict()
141
+ pretrained_dict = {
142
+ k: v for k, v in pretrained_dict.items()
143
+ if k in model_dict.keys()
144
+ }
145
+ need_init_state_dict = {}
146
+ for k, v in pretrained_dict.items():
147
+ need_init = (
148
+ k.split('.')[0] in pretrained_layers
149
+ or pretrained_layers[0] == '*'
150
+ )
151
+ if need_init:
152
+ if verbose:
153
+ logging.info(f'=> init {k} from {pretrained}')
154
+
155
+ need_init_state_dict[k] = v
156
+ self.load_state_dict(need_init_state_dict, strict=False)
157
+
158
+
159
+ @torch.jit.ignore
160
+ def no_weight_decay(self):
161
+ return {
162
+ 'positional_embedding',
163
+ 'token_embedding',
164
+ }
165
+
166
+ def forward(self, input_ids, attention_mask=None):
167
+ key_padding_mask = (input_ids == 0) if not self.autogressive else None
168
+ x = self.token_embedding(input_ids) # [batch_size, n_ctx, d_model]
169
+ x = x + self.positional_embedding
170
+ x = x.permute(1, 0, 2) # NLD -> LND
171
+ for block in self.resblocks:
172
+ x = block(x, key_padding_mask)
173
+ x = x.permute(1, 0, 2) # LND -> NLD
174
+
175
+ x = self.ln_final(x)
176
+
177
+ return {'last_hidden_state': x}
178
+
179
+
180
+ @register_lang_encoder
181
+ def lang_encoder(config_encoder, tokenizer, verbose, **kwargs):
182
+ transformer = Transformer(
183
+ context_length=config_encoder['CONTEXT_LENGTH'],
184
+ vocab_size=tokenizer.vocab_size,
185
+ width=config_encoder['WIDTH'],
186
+ layers=config_encoder['LAYERS'],
187
+ heads=config_encoder['HEADS'],
188
+ autogressive=config_encoder.get('AUTOGRESSIVE', True)
189
+ )
190
+
191
+ if config_encoder['LOAD_PRETRAINED']:
192
+ transformer.load_pretrained()
193
+
194
+ return transformer
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ torch==1.10.1
2
+ torchvision==0.11.2
3
+ opencv-python-headless==4.5.3.56
4
+ timm==0.4.12
5
+ numpy
6
+ yacs
7
+ transformers