File size: 2,257 Bytes
8fa1f84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import warnings

import torch
import torch.nn as nn
from transformers import CLIPImageProcessor

from .clip import CLIP

class CLIPVisionTower(nn.Module):
    """Wrap the project-local ``CLIP`` model and expose flattened ``res4`` features.

    Args:
        args: Config object. ``args.mm_vision_tower`` is used when present
            (test time); otherwise ``args.vision_tower`` (training). The value
            is passed to ``torch.load``, so it must be a checkpoint path.
        img_size: Side length used for both the shortest-edge resize and the
            square center crop of the image processor.
        delay_load: If True, the constructor does not call ``load_model()``;
            the caller must do so before ``forward`` / ``dtype`` / ``device``.
    """

    def __init__(self, args, img_size=512, delay_load=False):
        super().__init__()

        # Test-time args carry `mm_vision_tower`; training args carry `vision_tower`.
        if hasattr(args, 'mm_vision_tower'):
            self.clip_model = args.mm_vision_tower
        else:
            self.clip_model = args.vision_tower
        self.is_loaded = False
        self.img_size = img_size

        if not delay_load:
            self.load_model()

    def load_model(self):
        """Build the image processor, construct ``CLIP`` and load its weights.

        Sets ``self.image_processor`` and ``self.vision_tower``, then marks the
        tower as loaded. Emits a ``UserWarning`` if the checkpoint lacks any
        parameter the model defines.
        """
        self.image_processor = CLIPImageProcessor(
            do_resize=True,
            size={"shortest_edge": self.img_size},
            resample=3,  # PIL bicubic
            do_center_crop=True,
            crop_size={"height": self.img_size, "width": self.img_size},
            do_rescale=True,
            rescale_factor=0.00392156862745098,  # 1 / 255
            do_normalize=True,
            image_mean=[0.48145466, 0.4578275, 0.40821073],
            image_std=[0.26862954, 0.26130258, 0.27577711],
            do_convert_rgb=True,
        )

        self.vision_tower = CLIP()

        # load_state_dict copies checkpoint tensors into the module's existing
        # parameters, so mapping to CPU avoids needing (or allocating memory
        # on) whatever device the checkpoint was saved from.
        # NOTE: torch.load unpickles the file - only point it at trusted checkpoints.
        state_dict = torch.load(self.clip_model, map_location="cpu")

        # strict=False tolerates extra checkpoint entries (presumably weights
        # for parts of CLIP not built here - TODO confirm). Missing keys are
        # reported because those parameters silently keep their initial values.
        load_result = self.vision_tower.load_state_dict(state_dict, strict=False)
        missing = list(load_result.missing_keys)
        if missing:
            warnings.warn(
                f"CLIPVisionTower: {len(missing)} parameter(s) missing from "
                f"checkpoint {self.clip_model!r} keep their initial values "
                f"(first few: {missing[:5]})",
                stacklevel=2,
            )

        self.is_loaded = True

    @staticmethod
    def _flatten_res4(feature_dict):
        """Turn the ``res4`` entry of a tower output into a token sequence.

        Keeps the first two axes, flattens the rest, and moves axis 1 last:
        ``(A, B, *rest) -> (A, prod(rest), B)``. Presumably ``res4`` is an
        ``(N, C, H, W)`` feature map - TODO confirm against ``CLIP``.
        """
        res4 = feature_dict['res4']
        return res4.reshape(*res4.shape[:2], -1).permute(0, 2, 1)

    @torch.no_grad()
    def forward(self, images):
        """Encode images without tracking gradients.

        Args:
            images: A batched tensor, or a list of per-image tensors; each
                list element gets a leading batch axis of size 1 and is encoded
                on its own.

        Returns:
            ``(image_features, image_features_dict)``. For tensor input, the
            flattened ``res4`` tensor and the raw tower output. For list input,
            two lists with one such entry per image.
        """
        if isinstance(images, list):
            image_features_dict = [self.vision_tower(image.unsqueeze(0)) for image in images]
            image_features = [self._flatten_res4(out) for out in image_features_dict]
        else:
            image_features_dict = self.vision_tower(images)
            image_features = self._flatten_res4(image_features_dict)

        return image_features, image_features_dict

    def _reference_parameter(self):
        """Return the tower's first parameter, used to infer dtype/device."""
        param = next(self.vision_tower.parameters(), None)
        if param is None:
            raise RuntimeError(
                "CLIPVisionTower: vision tower has no parameters; cannot infer dtype/device"
            )
        return param

    @property
    def dtype(self):
        """dtype of the vision tower.

        Uses ``vision_tower.dtype`` when the wrapped model defines it; a plain
        ``nn.Module`` does not, so fall back to the first parameter's dtype.
        """
        tower_dtype = getattr(self.vision_tower, 'dtype', None)
        if tower_dtype is not None:
            return tower_dtype
        return self._reference_parameter().dtype

    @property
    def device(self):
        """Device of the vision tower.

        Uses ``vision_tower.device`` when the wrapped model defines it; a plain
        ``nn.Module`` does not, so fall back to the first parameter's device.
        """
        tower_device = getattr(self.vision_tower, 'device', None)
        if tower_device is not None:
            return tower_device
        return self._reference_parameter().device