import torch
import torch.nn as nn
from transformers import CLIPImageProcessor
from .clip import CLIP


class CLIPVisionTower(nn.Module):
    def __init__(self, args, img_size=512, delay_load=False):
        super().__init__()
        # At test time the checkpoint path is read from `mm_vision_tower`;
        # at train time it is read from `vision_tower`.
        if hasattr(args, 'mm_vision_tower'):  # test
            self.clip_model = args.mm_vision_tower
        else:  # train
            self.clip_model = args.vision_tower
        self.is_loaded = False
        self.img_size = img_size
        if not delay_load:
            self.load_model()

    def load_model(self):
        # Preprocessing follows the OpenAI CLIP defaults: resize the shortest
        # edge to `img_size`, center-crop to a square, rescale to [0, 1], and
        # normalize with the CLIP mean/std.
        self.image_processor = CLIPImageProcessor(
            do_resize=True,
            size={"shortest_edge": self.img_size},
            resample=3,  # bicubic
            do_center_crop=True,
            crop_size={"height": self.img_size, "width": self.img_size},
            do_rescale=True,
            rescale_factor=0.00392156862745098,  # 1 / 255
            do_normalize=True,
            image_mean=[0.48145466, 0.4578275, 0.40821073],
            image_std=[0.26862954, 0.26130258, 0.27577711],
            do_convert_rgb=True,
        )
        self.vision_tower = CLIP()
        # `strict=False` ignores missing/unexpected keys when loading the checkpoint.
        self.vision_tower.load_state_dict(torch.load(self.clip_model), strict=False)
        self.is_loaded = True

    @torch.no_grad()
    def forward(self, images):
        if isinstance(images, list):
            # Variable-sized inputs: run the backbone one image at a time.
            image_features = []
            image_features_dict = []
            for image in images:
                image_feature_dict = self.vision_tower(image.unsqueeze(0))
                image_features_dict.append(image_feature_dict)
                # Flatten the 'res4' map from (1, C, H, W) to (1, H*W, C).
                image_feature = image_feature_dict['res4']
                image_feature = image_feature.reshape(*image_feature.shape[:2], -1).permute(0, 2, 1)
                image_features.append(image_feature)
        else:
            image_features_dict = self.vision_tower(images)
            # Flatten the 'res4' map from (B, C, H, W) to (B, H*W, C).
            image_features = image_features_dict['res4']
            image_features = image_features.reshape(*image_features.shape[:2], -1).permute(0, 2, 1)
        return image_features, image_features_dict

    @property
    def dtype(self):
        return self.vision_tower.dtype

    @property
    def device(self):
        return self.vision_tower.device
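

# Minimal usage sketch (assumptions: `args.vision_tower` points to a local
# checkpoint compatible with the `CLIP` module above; the paths and the
# SimpleNamespace argument container below are hypothetical placeholders,
# not part of this repo):
#
#     from types import SimpleNamespace
#     from PIL import Image
#
#     args = SimpleNamespace(vision_tower="path/to/clip_checkpoint.pth")
#     tower = CLIPVisionTower(args, img_size=512)
#     pixels = tower.image_processor(Image.open("example.jpg"),
#                                    return_tensors="pt")["pixel_values"]
#     pixels = pixels.to(device=tower.device, dtype=tower.dtype)
#     features, features_dict = tower(pixels)
#     # features: (1, H*W, C) tokens flattened from 'res4';
#     # features_dict: the raw multi-scale maps returned by the backbone.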