Spaces:
Runtime error
Runtime error
File size: 4,345 Bytes
2e5e07d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 |
import torch
# import argparse
# from omegaconf import OmegaConf
# from models import get_models
# import sys
# import os
# from PIL import Image
# from copy import deepcopy
def tca_transform_model(model):
    """Call ``tca_transform()`` on ``transformer_blocks[0]`` of every attention module.

    The model is visited in the order ``model.down_blocks``, ``model.mid_block``,
    ``model.up_blocks``.
    - Down/up blocks without an ``attentions`` attribute (or with ``attentions`` set
      to ``None``) are skipped.
    - ``model.mid_block`` must have ``attentions``; otherwise ``AttributeError`` is raised.
    - Errors raised by ``tca_transform()`` itself now propagate instead of being
      silently swallowed.

    Args:
        model: UNet-like module exposing ``down_blocks``, ``mid_block`` and
            ``up_blocks``. What ``tca_transform()`` does is defined elsewhere and is
            not visible here.

    Returns:
        The same ``model`` object that was passed in.
    """
    for down_block in model.down_blocks:
        # Skip only blocks that have no attentions. The former bare ``except:`` also
        # hid genuine transform errors and left the rest of the block untransformed.
        for attention in getattr(down_block, "attentions", None) or ():
            # NOTE(review): the call is made twice on transformer_blocks[0], exactly as
            # in the original. Confirm whether the second call was meant for another
            # block (e.g. transformer_blocks[1]) or can be dropped.
            attention.transformer_blocks[0].tca_transform()
            attention.transformer_blocks[0].tca_transform()
    for attention in model.mid_block.attentions:
        attention.transformer_blocks[0].tca_transform()
        attention.transformer_blocks[0].tca_transform()
    for up_block in model.up_blocks:
        for attention in getattr(up_block, "attentions", None) or ():
            attention.transformer_blocks[0].tca_transform()
            attention.transformer_blocks[0].tca_transform()
    return model
class ImageProjModel(torch.nn.Module):
    """Projection Model.

    Maps an embedding of size ``clip_embeddings_dim`` to
    ``clip_extra_context_tokens`` tokens of width ``cross_attention_dim``:
    one linear layer, a reshape into tokens, then ``LayerNorm`` over each token.

    Args:
        cross_attention_dim: Width of every output token.
        clip_embeddings_dim: Size of the last dimension of the input embedding.
        clip_extra_context_tokens: Number of tokens produced per input embedding.
    """

    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
        super().__init__()
        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens
        # One linear output holds all tokens side by side; forward() splits them apart.
        total_width = clip_extra_context_tokens * cross_attention_dim
        # Registration order (proj, then norm) is kept so state_dict layout is unchanged.
        self.proj = torch.nn.Linear(clip_embeddings_dim, total_width)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds):
        """Return a ``(N, clip_extra_context_tokens, cross_attention_dim)`` tensor.

        The last dimension of ``image_embeds`` must equal ``clip_embeddings_dim``.
        All leading dimensions are folded into ``N``.
        """
        projected = self.proj(image_embeds)
        token_shape = (-1, self.clip_extra_context_tokens, self.cross_attention_dim)
        return self.norm(projected.reshape(token_shape))
def ip_transform_model(model):
    """Attach an image-projection module and call ``attn2.ip_transform()`` everywhere.

    Steps:
    1. Set ``model.image_proj_model`` to
       ``ImageProjModel(cross_attention_dim=768, clip_embeddings_dim=1024,
       clip_extra_context_tokens=4)``, moved to ``model.device``.
    2. Visit ``model.down_blocks``, ``model.mid_block`` and ``model.up_blocks``.
       For each attention module, call ``ip_transform()`` on
       ``transformer_blocks[0].attn2``.

    Down/up blocks without an ``attentions`` attribute (or with ``attentions`` set
    to ``None``) are skipped. ``model.mid_block`` must have ``attentions``. Errors
    raised by ``ip_transform()`` now propagate instead of being silently swallowed.

    Args:
        model: UNet-like module exposing ``device``, ``down_blocks``, ``mid_block``
            and ``up_blocks``. What ``ip_transform()`` does is defined elsewhere.

    Returns:
        The same ``model`` object that was passed in.
    """
    # NOTE(review): the projection is moved to model.device only, not model.dtype.
    # If the UNet runs in reduced precision this may need a matching dtype; confirm.
    model.image_proj_model = ImageProjModel(cross_attention_dim=768, clip_embeddings_dim=1024,
                                            clip_extra_context_tokens=4).to(model.device)
    for down_block in model.down_blocks:
        # Skip only blocks that have no attentions. The former bare ``except:`` also
        # hid genuine errors and left the rest of the block untransformed.
        for attention in getattr(down_block, "attentions", None) or ():
            # NOTE(review): ip_transform() is called twice on the same attn2, as in the
            # original. Confirm whether the repeat is intentional.
            attention.transformer_blocks[0].attn2.ip_transform()
            attention.transformer_blocks[0].attn2.ip_transform()
    for attention in model.mid_block.attentions:
        attention.transformer_blocks[0].attn2.ip_transform()
        attention.transformer_blocks[0].attn2.ip_transform()
    for up_block in model.up_blocks:
        for attention in getattr(up_block, "attentions", None) or ():
            attention.transformer_blocks[0].attn2.ip_transform()
            attention.transformer_blocks[0].attn2.ip_transform()
    return model
def ip_scale_set(model, scale):
    """Call ``attn2.set_scale(scale)`` on ``transformer_blocks[0]`` of every attention module.

    The model is visited in the order ``model.down_blocks``, ``model.mid_block``,
    ``model.up_blocks``.
    - Down/up blocks without an ``attentions`` attribute (or with ``attentions`` set
      to ``None``) are skipped.
    - ``model.mid_block`` must have ``attentions``.
    - Errors raised by ``set_scale()`` now propagate instead of being silently
      swallowed.

    Args:
        model: UNet-like module exposing ``down_blocks``, ``mid_block`` and
            ``up_blocks``.
        scale: Value forwarded unchanged to every ``attn2.set_scale``. It is
            presumably the image-conditioning strength; confirm against ``set_scale``.

    Returns:
        The same ``model`` object that was passed in.
    """
    for down_block in model.down_blocks:
        # Skip only blocks that have no attentions. The former bare ``except:`` also
        # hid genuine errors and left the rest of the block unchanged.
        for attention in getattr(down_block, "attentions", None) or ():
            # NOTE(review): set_scale is called twice with the same value, as in the
            # original. It looks redundant; confirm before removing.
            attention.transformer_blocks[0].attn2.set_scale(scale)
            attention.transformer_blocks[0].attn2.set_scale(scale)
    for attention in model.mid_block.attentions:
        attention.transformer_blocks[0].attn2.set_scale(scale)
        attention.transformer_blocks[0].attn2.set_scale(scale)
    for up_block in model.up_blocks:
        for attention in getattr(up_block, "attentions", None) or ():
            attention.transformer_blocks[0].attn2.set_scale(scale)
            attention.transformer_blocks[0].attn2.set_scale(scale)
    return model
def ip_train_set(model):
    """Freeze the model, unfreeze the image projection, then call ``attn2.ip_train_set()``.

    Steps:
    1. Call ``model.requires_grad_(False)``.
    2. Call ``model.image_proj_model.requires_grad_(True)``. ``model.image_proj_model``
       must already exist; ``ip_transform_model`` in this module creates it.
    3. Call ``ip_train_set()`` on ``transformer_blocks[0].attn2`` of every attention
       module, visiting ``model.down_blocks``, ``model.mid_block`` and
       ``model.up_blocks`` in that order.

    Down/up blocks without an ``attentions`` attribute (or with ``attentions`` set
    to ``None``) are skipped. ``model.mid_block`` must have ``attentions``. Errors
    raised by ``ip_train_set()`` now propagate instead of being silently swallowed.

    Args:
        model: UNet-like module exposing ``requires_grad_``, ``image_proj_model``,
            ``down_blocks``, ``mid_block`` and ``up_blocks``.

    Returns:
        The same ``model`` object that was passed in.
    """
    model.requires_grad_(False)
    model.image_proj_model.requires_grad_(True)
    for down_block in model.down_blocks:
        # Skip only blocks that have no attentions. The former bare ``except:`` also
        # hid genuine errors and left the rest of the block unchanged.
        for attention in getattr(down_block, "attentions", None) or ():
            # NOTE(review): called twice on the same attn2, as in the original.
            # Confirm whether the repeat is intentional.
            attention.transformer_blocks[0].attn2.ip_train_set()
            attention.transformer_blocks[0].attn2.ip_train_set()
    for attention in model.mid_block.attentions:
        attention.transformer_blocks[0].attn2.ip_train_set()
        attention.transformer_blocks[0].attn2.ip_train_set()
    for up_block in model.up_blocks:
        for attention in getattr(up_block, "attentions", None) or ():
            attention.transformer_blocks[0].attn2.ip_train_set()
            attention.transformer_blocks[0].attn2.ip_train_set()
    return model
|