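# Patches Hugging Face's CLIP vision embeddings so the model accepts inputs of
# arbitrary resolution and aspect ratio: images are resized to a whole number
# of 14 px patches and the pretrained position-embedding grid is resampled to
# match. (Summary comment added for orientation; behavior is as defined below.)
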
import torch
import types
import einops

from abc import ABCMeta
from transformers import CLIPVisionModelWithProjection


def preprocess(image):
    """Resize a [0, 1] image batch for CLIP and apply CLIP's normalization.

    The shorter side is mapped to 16 patches (16 * 14 = 224 px) and the longer
    side is scaled proportionally, rounded to a whole number of 14 px patches
    so the image divides exactly into the patch grid at any aspect ratio.
    """
    # Standard CLIP per-channel normalization constants, broadcast to NCHW.
    mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device=image.device, dtype=image.dtype)[None, :, None, None]
    std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device=image.device, dtype=image.dtype)[None, :, None, None]

    scale = 16 / min(image.shape[2], image.shape[3])
    image = torch.nn.functional.interpolate(
        image,
        size=(14 * round(scale * image.shape[2]), 14 * round(scale * image.shape[3])),
        mode="bicubic",
        antialias=True,
    )

    return (image - mean) / std
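
# For example, a (B, 3, 300, 500) input gives scale = 16 / 300, so the output
# size is (14 * round(16.0), 14 * round(26.67)) = (224, 378), i.e. a 16 x 27
# patch grid. (Illustrative numbers, not taken from the original source.)

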
def arbitrary_positional_encoding(p, H, W):
    """Resample CLIP's pretrained 16 x 16 position embedding to an H x W grid.

    `p` is the embeddings module's `position_embedding` (an nn.Embedding whose
    row 0 belongs to the class token); the remaining 256 rows are reshaped to
    the pretrained 16 x 16 grid, resized to H x W, and flattened back.
    """
    weight = p.weight
    cls = weight[:1]   # class-token position, kept unchanged
    pos = weight[1:]   # patch positions from the pretrained 16 x 16 grid
    pos = einops.rearrange(pos, '(H W) C -> 1 C H W', H=16, W=16)
    pos = torch.nn.functional.interpolate(pos, size=(H, W), mode="nearest")
    pos = einops.rearrange(pos, '1 C H W -> (H W) C')
    weight = torch.cat([cls, pos])[None]   # (1, 1 + H * W, C)
    return weight
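
# Design note: nearest-neighbor resampling copies each target cell from its
# closest pretrained position vector, so no new embedding values are
# synthesized; bicubic interpolation of the position grid is a common
# alternative trade-off.

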
def improved_clipvision_embedding_forward(self, pixel_values):
    # Inputs are expected in [-1, 1] (the usual diffusion-pipeline range);
    # map them to [0, 1] before CLIP resizing and normalization.
    pixel_values = pixel_values * 0.5 + 0.5
    pixel_values = preprocess(pixel_values)
    batch_size = pixel_values.shape[0]
    target_dtype = self.patch_embedding.weight.dtype
    patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
    B, C, H, W = patch_embeds.shape
    patch_embeds = einops.rearrange(patch_embeds, 'B C H W -> B (H W) C')
    class_embeds = self.class_embedding.expand(batch_size, 1, -1)  # (C,) -> (B, 1, C)
    embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
    # Resample the position embeddings to the actual H x W patch grid instead
    # of assuming the pretrained 16 x 16 (224 px) layout.
    embeddings = embeddings + arbitrary_positional_encoding(self.position_embedding, H, W)
    return embeddings
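
# Shape trace for the (B, 3, 300, 500) example above: patch_embeds comes out
# as (B, C, 16, 27), so embeddings is (B, 1 + 16 * 27, C) = (B, 433, C), and
# the resampled position tensor (1, 433, C) broadcasts across the batch.

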
class ImprovedCLIPVisionModelWithProjection(CLIPVisionModelWithProjection, metaclass=ABCMeta):
    def __init__(self, config):
        super().__init__(config)
        # Rebind the embedding module's forward so the vision tower accepts
        # arbitrary (non-square, non-224) input resolutions.
        self.vision_model.embeddings.forward = types.MethodType(
            improved_clipvision_embedding_forward,
            self.vision_model.embeddings,
        )
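
# Usage sketch; the checkpoint name is illustrative (not from the original
# source) and assumes a ViT-L/14-style model trained at 224 px, i.e. one whose
# position embedding really is a 16 x 16 grid plus a class token:
#
#   model = ImprovedCLIPVisionModelWithProjection.from_pretrained(
#       "openai/clip-vit-large-patch14"
#   )
#   image = torch.rand(1, 3, 300, 500) * 2.0 - 1.0  # any H x W, values in [-1, 1]
#   image_embeds = model(pixel_values=image).image_embeds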