PaintsUndo / diffusers_vdm /improved_clip_vision.py
MohamedRashad's picture
Upload code
6dd488f
raw
history blame
2.2 kB
# A CLIP Vision supporting arbitrary aspect ratios, by lllyasviel
# The input range is changed to [-1, 1] rather than [0, 1] !!!! (same as VAE's range)
import torch
import types
import einops
from abc import ABCMeta
from transformers import CLIPVisionModelWithProjection
def preprocess(image):
mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device=image.device, dtype=image.dtype)[None, :, None, None]
std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device=image.device, dtype=image.dtype)[None, :, None, None]
scale = 16 / min(image.shape[2], image.shape[3])
image = torch.nn.functional.interpolate(
image,
size=(14 * round(scale * image.shape[2]), 14 * round(scale * image.shape[3])),
mode="bicubic",
antialias=True
)
return (image - mean) / std
def arbitrary_positional_encoding(p, H, W):
weight = p.weight
cls = weight[:1]
pos = weight[1:]
pos = einops.rearrange(pos, '(H W) C -> 1 C H W', H=16, W=16)
pos = torch.nn.functional.interpolate(pos, size=(H, W), mode="nearest")
pos = einops.rearrange(pos, '1 C H W -> (H W) C')
weight = torch.cat([cls, pos])[None]
return weight
def improved_clipvision_embedding_forward(self, pixel_values):
pixel_values = pixel_values * 0.5 + 0.5
pixel_values = preprocess(pixel_values)
batch_size = pixel_values.shape[0]
target_dtype = self.patch_embedding.weight.dtype
patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
B, C, H, W = patch_embeds.shape
patch_embeds = einops.rearrange(patch_embeds, 'B C H W -> B (H W) C')
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
embeddings = embeddings + arbitrary_positional_encoding(self.position_embedding, H, W)
return embeddings
class ImprovedCLIPVisionModelWithProjection(CLIPVisionModelWithProjection, metaclass=ABCMeta):
def __init__(self, config):
super().__init__(config)
self.vision_model.embeddings.forward = types.MethodType(
improved_clipvision_embedding_forward,
self.vision_model.embeddings
)