jadechoghari commited on
Commit
e376079
1 Parent(s): 806ce4b

Create clip_encoder.py

Browse files
Files changed (1) hide show
  1. clip_encoder.py +195 -0
clip_encoder.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
5
+
6
+ # Added for customized Processor.
7
+ import math
8
+ import numpy as np
9
+ from typing import Dict
10
+ from transformers.image_utils import PILImageResampling, ChannelDimension
11
+ from transformers.image_processing_utils import get_size_dict
12
+ from transformers.image_transforms import (
13
+ get_resize_output_image_size,
14
+ resize,
15
+ )
16
+ from typing import List, Optional, Tuple, Union
17
+
18
+
19
+ class CLIPImageProcessor_Ferret(CLIPImageProcessor):
20
+ def resize(
21
+ self,
22
+ image: np.ndarray,
23
+ size: Dict[str, int],
24
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
25
+ data_format: Optional[Union[str, ChannelDimension]] = None,
26
+ **kwargs,
27
+ ) -> np.ndarray:
28
+ """
29
+ Resize an image. The shortest edge of the image is resized to size["shortest_edge"], with the longest edge
30
+ resized to keep the input aspect ratio.
31
+ Args:
32
+ image (`np.ndarray`):
33
+ Image to resize.
34
+ size (`Dict[str, int]`):
35
+ Size of the output image.
36
+ resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
37
+ Resampling filter to use when resiizing the image.
38
+ data_format (`str` or `ChannelDimension`, *optional*):
39
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
40
+ """
41
+ size = get_size_dict(size, default_to_square=True, height_width_order=True)
42
+ # Hack: Bypass the shortest_edge detection. We hope to get a {"height": size[0], "width": size[1]}, where w=h.
43
+ # if "shortest_edge" not in size:
44
+ # raise ValueError(f"The `size` parameter must contain the key `shortest_edge`. Got {size.keys()}")
45
+ # output_size = get_resize_output_image_size(image, size=size["shortest_edge"], default_to_square=True)
46
+ output_size = get_resize_output_image_size(image, size=(size["height"], size["width"]), default_to_square=True)
47
+ return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)
48
+
49
+
50
+ class CLIPVisionTower(nn.Module):
51
+ def __init__(self, vision_tower, args, delay_load=False):
52
+ super().__init__()
53
+
54
+ self.is_loaded = False
55
+
56
+ self.preprocess_type = getattr(args, 'version', 'ferret_v1')
57
+ self.vision_tower_name = vision_tower
58
+ self.select_layer = args.mm_vision_select_layer
59
+ self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
60
+
61
+ if not delay_load:
62
+ self.load_model()
63
+ elif getattr(args, 'unfreeze_mm_vision_tower', False):
64
+ self.load_model()
65
+ else:
66
+ self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
67
+
68
+ def load_model(self, device_map=None):
69
+ if self.is_loaded:
70
+ print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
71
+ return
72
+
73
+ if "ferret" in self.preprocess_type:
74
+ self.image_processor = CLIPImageProcessor_Ferret.from_pretrained(self.vision_tower_name)
75
+ else:
76
+ self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
77
+
78
+ self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
79
+ self.vision_tower.requires_grad_(False)
80
+
81
+ self.is_loaded = True
82
+
83
+ def feature_select(self, image_forward_outs):
84
+ image_features = image_forward_outs.hidden_states[self.select_layer]
85
+ if self.select_feature == 'patch':
86
+ image_features = image_features[:, 1:]
87
+ elif self.select_feature == 'cls_patch':
88
+ image_features = image_features
89
+ else:
90
+ raise ValueError(f'Unexpected select feature: {self.select_feature}')
91
+ return image_features
92
+
93
+ # @torch.no_grad()
94
+ def forward(self, images):
95
+ if type(images) is list:
96
+ image_features = []
97
+ for image in images:
98
+ image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
99
+ image_feature = self.feature_select(image_forward_out).to(image.dtype)
100
+ image_features.append(image_feature)
101
+ else:
102
+ image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
103
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
104
+
105
+ return image_features
106
+
107
+ @property
108
+ def dummy_feature(self):
109
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
110
+
111
+ @property
112
+ def dtype(self):
113
+ return self.vision_tower.dtype
114
+
115
+ @property
116
+ def device(self):
117
+ return self.vision_tower.device
118
+
119
+ @property
120
+ def config(self):
121
+ if self.is_loaded:
122
+ return self.vision_tower.config
123
+ else:
124
+ return self.cfg_only
125
+
126
+ @property
127
+ def hidden_size(self):
128
+ return self.config.hidden_size
129
+
130
+ @property
131
+ def num_patches_per_side(self):
132
+ return self.config.image_size // self.config.patch_size
133
+
134
+ @property
135
+ def num_patches(self):
136
+ return (self.config.image_size // self.config.patch_size) ** 2
137
+
138
+
139
+
140
+ class CLIPVisionTowerS2(CLIPVisionTower):
141
+ def __init__(self, vision_tower, args, delay_load=False):
142
+ super().__init__(vision_tower, args, delay_load)
143
+
144
+ self.s2_scales = getattr(args, 's2_scales', '336,672,1008')
145
+ self.s2_scales = list(map(int, self.s2_scales.split(',')))
146
+ self.s2_scales.sort()
147
+ self.s2_split_size = self.s2_scales[0]
148
+ self.s2_image_size = self.s2_scales[-1]
149
+
150
+ try:
151
+ from s2wrapper import forward as multiscale_forward
152
+ except ImportError:
153
+ raise ImportError('Package s2wrapper not found! Please install by running: \npip install git+https://github.com/bfshi/scaling_on_scales.git')
154
+ self.multiscale_forward = multiscale_forward
155
+
156
+ # change resize/crop size in preprocessing to the largest image size in s2_scale
157
+ if not delay_load or getattr(args, 'unfreeze_mm_vision_tower', False):
158
+ self.image_processor.size['shortest_edge'] = self.s2_image_size
159
+ self.image_processor.crop_size['height'] = self.image_processor.crop_size['width'] = self.s2_image_size
160
+
161
+ def load_model(self, device_map=None):
162
+ if self.is_loaded:
163
+ print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
164
+ return
165
+
166
+ self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
167
+ self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
168
+ self.vision_tower.requires_grad_(False)
169
+
170
+ self.image_processor.size['shortest_edge'] = self.s2_image_size
171
+ self.image_processor.crop_size['height'] = self.image_processor.crop_size['width'] = self.s2_image_size
172
+
173
+ self.is_loaded = True
174
+
175
+ @torch.no_grad()
176
+ def forward_feature(self, images):
177
+ image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
178
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
179
+ return image_features
180
+
181
+ @torch.no_grad()
182
+ def forward(self, images):
183
+ if type(images) is list:
184
+ image_features = []
185
+ for image in images:
186
+ image_feature = self.multiscale_forward(self.forward_feature, image.unsqueeze(0), img_sizes=self.s2_scales, max_split_size=self.s2_split_size)
187
+ image_features.append(image_feature)
188
+ else:
189
+ image_features = self.multiscale_forward(self.forward_feature, images, img_sizes=self.s2_scales, max_split_size=self.s2_split_size)
190
+
191
+ return image_features
192
+
193
+ @property
194
+ def hidden_size(self):
195
+ return self.config.hidden_size * len(self.s2_scales)