|
import torch
from torch import nn
import torch.nn.functional as F
from typing import List, Union

from transformers import CLIPProcessor, CLIPModel
from dnns.clip.custom_clip import CLIPModelCanReceiveTextEmbeds

from utils.common.log import logger


class Clip_ViTB16(nn.Module):
    def __init__(self, img_size):
        super().__init__()
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
        self.model: CLIPModel = CLIPModelCanReceiveTextEmbeds.from_pretrained("openai/clip-vit-base-patch16")

        self.img_size = img_size

        # The pretrained vision embeddings assume 224x224 inputs. A different input size changes
        # the number of patches (and therefore the position ids), so rebuild them here. Note that
        # the pretrained position-embedding table itself is kept, so only its first `num_positions`
        # rows are used; this assumes img_size is not larger than the pretrained resolution.
        vm_embed = self.model.vision_model.embeddings
        raw_num_patches = vm_embed.num_patches
        vm_embed.num_patches = (img_size // vm_embed.patch_size) ** 2
        vm_embed.num_positions = vm_embed.num_patches + 1  # +1 for the class token
        vm_embed.register_buffer("position_ids", torch.arange(vm_embed.num_positions).expand((1, -1)), persistent=False)

        logger.info(f'due to changed input image size ({img_size}), num patches are updated from {raw_num_patches} to {vm_embed.num_patches}')

        self.first_inference = True
|
    def forward(self, images, texts: Union[List[str], torch.Tensor], for_training,
                disable_return_loss=False, only_return_logits_per_text=False, no_grad_text=False):
        if isinstance(texts[0], str):
            # raw text prompts: let the processor tokenize them together with the images
            inputs = self.processor(text=texts, images=images, return_tensors="pt", padding=True)
        else:
            # precomputed text embeddings: only preprocess the images and pass the embeddings
            # (with an all-ones attention mask) straight through to the custom CLIP model
            inputs = self.processor(images=images, return_tensors="pt")
            inputs['attention_mask'] = torch.ones((texts.size(0), texts.size(1)))
            inputs['input_embeds'] = texts

        inputs['return_loss'] = bool(for_training and not disable_return_loss)
        inputs['only_return_logits_per_text'] = only_return_logits_per_text
        inputs['no_grad_text'] = no_grad_text

        for k, v in inputs.items():
            if isinstance(v, torch.Tensor):
                inputs[k] = v.to('cuda')

        if self.first_inference:
            logger.info(f'before input size: {inputs["pixel_values"].size()}')

        # the processor resizes images to the pretrained 224x224, so rescale them to the
        # configured input size to match the patch/position layout rebuilt in __init__
        inputs['pixel_values'] = F.interpolate(inputs['pixel_values'], size=(self.img_size, self.img_size))

        if self.first_inference:
            logger.info(f'after input size: {inputs["pixel_values"].size()}')
            self.first_inference = False

        return self.model(**inputs)
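

# Optional companion sketch (not used by Clip_ViTB16 above): __init__ keeps the pretrained
# position-embedding table and only re-derives position_ids, which works when the new grid is
# no larger than the pretrained one. A common alternative is to bilinearly interpolate the
# pretrained table to the new patch grid. This helper is an illustrative assumption, not part of
# the original code, and assumes the standard HF CLIPVisionEmbeddings layout
# (index 0 = class token, the remaining indices a square patch grid).
def interpolate_clip_position_embedding(vm_embed, img_size):
    old_weight = vm_embed.position_embedding.weight.data  # (1 + old_grid ** 2, dim)
    cls_pos, grid_pos = old_weight[:1], old_weight[1:]
    old_grid = int(grid_pos.size(0) ** 0.5)
    new_grid = img_size // vm_embed.patch_size

    # reshape the flat patch positions into a 2D grid, resize it, then flatten it back
    grid_pos = grid_pos.reshape(1, old_grid, old_grid, -1).permute(0, 3, 1, 2)
    grid_pos = F.interpolate(grid_pos, size=(new_grid, new_grid), mode='bilinear', align_corners=False)
    grid_pos = grid_pos.permute(0, 2, 3, 1).reshape(new_grid * new_grid, -1)

    new_weight = torch.cat([cls_pos, grid_pos], dim=0)
    new_embedding = nn.Embedding(new_weight.size(0), new_weight.size(1)).to(old_weight.device, old_weight.dtype)
    new_embedding.weight.data.copy_(new_weight)

    vm_embed.position_embedding = new_embedding
    vm_embed.num_patches = new_grid ** 2
    vm_embed.num_positions = new_grid ** 2 + 1
    vm_embed.register_buffer("position_ids", torch.arange(vm_embed.num_positions, device=old_weight.device).expand((1, -1)), persistent=False)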
|
|
@torch.no_grad()
def clip_vit_b_16(img_size):
    return Clip_ViTB16(img_size)
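

# Hedged usage sketch for the tensor branch of Clip_ViTB16.forward(): `texts` may be a
# (batch, seq_len, dim) tensor that forward() passes through as `input_embeds` to
# CLIPModelCanReceiveTextEmbeds. This file does not show how such embeddings are produced;
# the token-embedding lookup below is one plausible way and is an assumption, not original code.
def build_text_token_embeds(wrapper: Clip_ViTB16, prompts: List[str]) -> torch.Tensor:
    tokenized = wrapper.processor(text=prompts, return_tensors="pt", padding=True)
    token_embedding = wrapper.model.text_model.embeddings.token_embedding
    input_ids = tokenized["input_ids"].to(token_embedding.weight.device)
    return token_embedding(input_ids)  # usable as the `texts` argument of Clip_ViTB16.forward()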
|
|
if __name__ == '__main__':
    # img_size=32 follows the Resize((32, 32)) transform used below; adjust as needed
    model = clip_vit_b_16(32).cuda()
|
    from data import get_dataset, build_dataloader
    from torchvision.transforms import Compose, ToTensor, Resize

    dataset = get_dataset('Caltech256', '/data/zql/datasets/Caltech-256/data/caltech256/256_ObjectCategories/', 'train', transform=Compose([
        Resize((32, 32)), ToTensor()
    ]))
    dataloader = build_dataloader(dataset, 8, 0, True, None)

    images, labels = next(iter(dataloader))

    classes = dataset.classes
    text = [f"a photo of a {classes[i]}" for i in labels]
    print(text)
    print(images.size())

    o = model(images, text, True)
    print(o)
    print(o.logits_per_image.softmax(dim=1))
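
    # Small sanity check (assumption: with one prompt per image in matching order, the in-batch
    # "correct" text for image i is text i, as in CLIP's contrastive setup; repeated classes in
    # a batch make this only approximate).
    preds = o.logits_per_image.argmax(dim=1)
    targets = torch.arange(preds.size(0), device=preds.device)
    print(f'in-batch image->text retrieval accuracy: {(preds == targets).float().mean().item():.3f}')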