# A CLIP Vision model supporting arbitrary aspect ratios, by lllyasviel
# The input range is changed to [-1, 1] rather than [0, 1] !!!! (same as the VAE's range)

import torch
import types
import einops

from abc import ABCMeta
from transformers import CLIPVisionModelWithProjection


def preprocess(image):
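    # Resize so that the shorter image side maps to exactly 16 patches (16 * 14 = 224 px)
    # while preserving the aspect ratio, then normalize with the standard OpenAI CLIP
    # mean/std. Expects a [0, 1]-range NCHW tensor (the caller converts from [-1, 1]).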
    mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device=image.device, dtype=image.dtype)[None, :, None, None]
    std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device=image.device, dtype=image.dtype)[None, :, None, None]

    scale = 16 / min(image.shape[2], image.shape[3])
    image = torch.nn.functional.interpolate(
        image,
        size=(14 * round(scale * image.shape[2]), 14 * round(scale * image.shape[3])),
        mode="bicubic",
        antialias=True
    )

    return (image - mean) / std


def arbitrary_positional_encoding(p, H, W):
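    # Keep the CLS position embedding as-is, reshape the learned 16x16 grid of patch
    # position embeddings to 2D, nearest-neighbor resize it to the new H x W patch grid,
    # and flatten it back so it matches the variable-length patch sequence.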
    weight = p.weight
    cls = weight[:1]
    pos = weight[1:]
    pos = einops.rearrange(pos, '(H W) C -> 1 C H W', H=16, W=16)
    pos = torch.nn.functional.interpolate(pos, size=(H, W), mode="nearest")
    pos = einops.rearrange(pos, '1 C H W -> (H W) C')
    weight = torch.cat([cls, pos])[None]
    return weight


def improved_clipvision_embedding_forward(self, pixel_values):
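    # Map the input from [-1, 1] back to [0, 1], run the aspect-preserving CLIP
    # preprocessing, embed the patches, and add a positional encoding resized to the
    # actual patch grid instead of the fixed 16x16 one.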
    pixel_values = pixel_values * 0.5 + 0.5
    pixel_values = preprocess(pixel_values)
    batch_size = pixel_values.shape[0]
    target_dtype = self.patch_embedding.weight.dtype
    patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
    B, C, H, W = patch_embeds.shape
    patch_embeds = einops.rearrange(patch_embeds, 'B C H W -> B (H W) C')
    class_embeds = self.class_embedding.expand(batch_size, 1, -1)
    embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
    embeddings = embeddings + arbitrary_positional_encoding(self.position_embedding, H, W)
    return embeddings


class ImprovedCLIPVisionModelWithProjection(CLIPVisionModelWithProjection, metaclass=ABCMeta):
    def __init__(self, config):
        super().__init__(config)
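        # Monkey-patch the embeddings module so it uses the arbitrary-aspect-ratio
        # forward defined above instead of the stock fixed-resolution CLIP forward.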
        self.vision_model.embeddings.forward = types.MethodType(
            improved_clipvision_embedding_forward,
            self.vision_model.embeddings
        )
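

# Usage sketch (an illustration, not part of the original module): the checkpoint name
# below is an assumption -- "openai/clip-vit-large-patch14" is chosen because its
# position embedding is a 16x16 grid, which is what arbitrary_positional_encoding
# expects -- and this also assumes a transformers version in which CLIPVisionTransformer
# calls self.embeddings(pixel_values) with no extra keyword arguments. Input must be an
# NCHW tensor in the [-1, 1] range.
if __name__ == "__main__":
    model = ImprovedCLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
    model.eval()

    # A dummy 480x640 image in [-1, 1]; any aspect ratio should work.
    image = torch.rand(1, 3, 480, 640) * 2.0 - 1.0

    with torch.no_grad():
        output = model(pixel_values=image)

    # For a 480x640 input the shorter side becomes 16 patches and the longer side
    # round(16 * 640 / 480) = 21 patches, so the token sequence is 1 + 16 * 21 = 337.
    print(output.last_hidden_state.shape)  # e.g. torch.Size([1, 337, 1024])
    print(output.image_embeds.shape)       # projected embedding, e.g. torch.Size([1, 768])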