# pheye_builder/encoder.py
import math

import PIL
import torch
import torch.nn.functional as F
from peft import LoraConfig, get_peft_model
from PIL.Image import BICUBIC
from torch import nn
from torchvision import transforms
from transformers import CLIPModel

# Level 4 (21 patches) was used in earlier experiments; it has to stay in this
# mapping, otherwise checkpoints trained with it can no longer be loaded.
LEVELS_TO_PATCHES = {
    1: 1,
    2: 5,
    3: 10,
    4: 21,
}
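
# How the patch counts arise (see Encoder.forward below): level 1 uses a single
# global view of the image; level 2 adds a 2x2 grid of crops from a 2x upscaled
# copy (1 + 4 = 5); level 3 adds a 3x3 grid from a 3x upscale (1 + 9 = 10).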


def cut_image_patches(image: PIL.Image.Image, encoder_resolution: int = 224):
    """Crop `image` into encoder_resolution x encoder_resolution tiles.

    Tiles are laid out on a regular grid; the last tile in each direction is
    anchored to the right/bottom edge, so it may overlap its neighbour when the
    image size is not an exact multiple of `encoder_resolution`.
    """
    coordinates = []
    width, height = image.size
    width_tiles = [i * encoder_resolution for i in range(math.ceil(width / encoder_resolution) - 1)]
    width_tiles.append(width - encoder_resolution)
    height_tiles = [i * encoder_resolution for i in range(math.ceil(height / encoder_resolution) - 1)]
    height_tiles.append(height - encoder_resolution)
    for w in width_tiles:
        for h in height_tiles:
            coordinates.append((w, h, w + encoder_resolution, h + encoder_resolution))
    cropped_images = [image.crop(c) for c in coordinates]
    return cropped_images
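
# Worked example: with encoder_resolution=224, a 500x448 image gives
# width_tiles=[0, 224, 276] (the last tile snapped to the right edge, overlapping
# its neighbour) and height_tiles=[0, 224], i.e. 6 crops of 224x224.
# Encoder.forward below always resizes to an exact multiple of the encoder
# resolution first, so its crops never overlap.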


class Encoder(nn.Module):
    def __init__(self, clip_name, level=2, dtype=None, use_dropout=True) -> None:
        super().__init__()
        if level not in LEVELS_TO_PATCHES:
            raise ValueError(f"Level {level} not supported; choose one of {sorted(LEVELS_TO_PATCHES)}")
        self.n_patches = LEVELS_TO_PATCHES[level]
        self.vision_model = CLIPModel.from_pretrained(clip_name, torch_dtype=dtype).vision_model
        self.has_first_adapter = False
        self.image_size = self.vision_model.config.image_size
        self.patch_size = self.vision_model.config.patch_size
        self.use_dropout = use_dropout
        self.dtype = dtype

        # CLIP's own normalisation statistics.
        mean = (0.48145466, 0.4578275, 0.40821073)
        std = (0.26862954, 0.26130258, 0.27577711)
        self.image_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ])

        self.norm_lvl_1 = nn.LayerNorm(self.vision_model.config.hidden_size, dtype=dtype)
        self.norm_lvl_2 = nn.LayerNorm(self.vision_model.config.hidden_size, dtype=dtype)
        # Kept from earlier experiments; removing it would break loading of older checkpoints.
        self.norm_lvl_3 = nn.LayerNorm(self.vision_model.config.hidden_size, dtype=dtype)

        if level == 1:
            self.connector = nn.LayerNorm(self.vision_model.config.hidden_size, dtype=dtype)
        else:
            self.connector = Position(self.n_patches, self.vision_model.config.hidden_size, dtype=dtype)

        # LoRA adapter ("second") used when encoding the grid of high-resolution crops.
        config_level2 = LoraConfig(
            r=16,
            lora_alpha=32,
            target_modules=["q_proj", "k_proj", "v_proj", "out_proj", "patch_embedding", "fc1", "fc2"],
            lora_dropout=0.05 if self.use_dropout else 0,
            bias="none",
        )
        self.vision_model = get_peft_model(self.vision_model, config_level2, "second")
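
    # Two LoRA adapters live on the shared CLIP vision tower: "second" (r=16, created
    # above) is active when encoding the high-resolution crops, while an optional
    # "first" adapter (r=8, added below) can adapt the global low-resolution view.
    # If "first" is never added, the global view runs through the frozen base CLIP
    # weights (forward() disables the adapters for it).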

    def add_first_level_adapter(self):
        """Attach the optional low-rank adapter ("first") used for the global low-resolution view."""
        config_224 = LoraConfig(
            r=8,
            lora_alpha=16,
            target_modules=["q_proj", "k_proj", "v_proj", "out_proj", "patch_embedding", "fc1", "fc2"],
            lora_dropout=0.05 if self.use_dropout else 0,
            bias="none",
        )
        self.vision_model.add_adapter("first", config_224)
        self.has_first_adapter = True
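
    # Call add_first_level_adapter() before training if the global view should be
    # tuned as well, and (presumably) before loading a checkpoint that was saved
    # with the "first" adapter, so the extra LoRA parameters exist to receive weights.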

    def forward(self, images: list, device="cpu", **kwargs):
        """Encode a list of B PIL images into a (B, seq_len, hidden_size) tensor."""
        B = len(images)
        # Number of tokens the vision tower produces per crop (patches + CLS).
        h = int((self.image_size / self.patch_size) ** 2 + 1)

        # Key 1 holds the global view of each image; key 2 holds the crops cut from a
        # 2x (level 2) or 3x (level 3) upscaled copy of the same image.
        resized_images = {1: [], 2: []}
        for i in images:
            resized_images[1].append(self.image_transform(i.resize((self.image_size, self.image_size), resample=BICUBIC)))
            if self.n_patches == 5:
                for crop in cut_image_patches(i.resize((self.image_size * 2, self.image_size * 2), resample=BICUBIC), encoder_resolution=self.image_size):
                    resized_images[2].append(self.image_transform(crop))
            elif self.n_patches == 10:
                for crop in cut_image_patches(i.resize((self.image_size * 3, self.image_size * 3), resample=BICUBIC), encoder_resolution=self.image_size):
                    resized_images[2].append(self.image_transform(crop))

        vision_features = []
        for res, imgs in resized_images.items():
            if imgs != []:
                resized_images[res] = torch.stack(imgs, dim=0).to(device)
                if res == 1 and self.has_first_adapter:
                    self.vision_model.set_adapter("first")
                    vision_features.append(self.norm_lvl_1(self.vision_model(resized_images[res]).last_hidden_state))
                elif res == 1:
                    # No adapter for the global view: run the frozen base CLIP weights.
                    with self.vision_model.disable_adapter():
                        vision_features.append(self.norm_lvl_1(self.vision_model(resized_images[res]).last_hidden_state))
                elif res == 2:
                    self.vision_model.set_adapter("second")
                    # Fold the per-crop token sequences back into one sequence per image.
                    if self.n_patches == 5:
                        vision_features.append(self.norm_lvl_2(self.vision_model(resized_images[res]).last_hidden_state.view(B, h * 4, -1)))
                    elif self.n_patches == 10:
                        vision_features.append(self.norm_lvl_2(self.vision_model(resized_images[res]).last_hidden_state.view(B, h * 9, -1)))

        vision_features = torch.cat(vision_features, dim=1)
        vision_features = self.connector(vision_features)
        return vision_features
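
    # For example, with a CLIP ViT-L/14 backbone (224px input, 14px patches) each crop
    # yields h = 257 tokens, so a level-2 encoder returns (B, 257 * 5, 1024) = (B, 1285, 1024)
    # and a level-3 encoder returns (B, 257 * 10, 1024).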


class Position(nn.Module):
    """Adds a learned per-crop position embedding, broadcast over each crop's tokens."""

    def __init__(self, n_patches, dim, dtype) -> None:
        super().__init__()
        self.embedding = nn.Embedding(max(LEVELS_TO_PATCHES.values()), dim, dtype=dtype)
        self.n_patches = n_patches
        self.apply(self._init_weights)

    def forward(self, vision_features):
        batch_size, seq_len, dim = vision_features.size()
        # Number of tokens contributed by each crop (the sequence is the crops concatenated in order).
        single_encoder_dim = seq_len // self.n_patches
        pos = torch.arange(self.n_patches, dtype=torch.long, device=vision_features.device)
        # (1, n_patches, dim) -> (batch, n_patches * single_encoder_dim, dim): each crop's
        # embedding is repeated once per token of that crop.
        pos = torch.repeat_interleave(self.embedding(pos).unsqueeze(0), single_encoder_dim, 1).expand(batch_size, -1, -1)
        return vision_features + pos

    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        # Carried over from a decoder-style init: Position has no fc1/fc2/to_out parameters
        # (and no self.n_decoder_layers), so this loop never rescales anything here.
        for name, p in module.named_parameters():
            if name == "fc1.weight" or name == "fc2.weight" or name == "to_out.weight":
                p.data.normal_(mean=0.0, std=(0.02 / math.sqrt(2 * self.n_decoder_layers)))
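

# Minimal usage sketch, assuming the "openai/clip-vit-large-patch14" checkpoint purely
# for illustration; any CLIP checkpoint accepted by CLIPModel.from_pretrained should
# work. Running this downloads the CLIP weights on first use.
if __name__ == "__main__":
    from PIL import Image

    encoder = Encoder("openai/clip-vit-large-patch14", level=2, dtype=torch.float32)
    encoder.eval()

    # A uniform dummy image stands in for a real input.
    image = Image.new("RGB", (640, 480), color=(127, 127, 127))
    with torch.no_grad():
        features = encoder([image], device="cpu")

    # Expected for ViT-L/14: torch.Size([1, 1285, 1024]) = 257 tokens for the global
    # view plus 4 * 257 tokens for the 2x2 grid of crops.
    print(features.shape)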