# Source: hugging face upload by loveisgone ("Upload folder using huggingface_hub", commit 0c0cc86)
from typing import List
import warnings
import torch
from torch import nn, Tensor
from torchvision import transforms
from torchtune.models.llama3 import lora_llama3_8b, llama3_8b
from torchtune.modules.peft import LORA_ATTN_MODULES, LoRALinear
from torchtune.modules import TransformerDecoder
with warnings.catch_warnings():
warnings.simplefilter("ignore", UserWarning)
from imagebind.models import imagebind_model
from models.imagebind_wrapper import get_imagebind_v2, V2_PATH
from models.imagebind_wrapper import ImageBind
# Embedding widths of the pretrained perception encoders; these size the
# input of the projection MLPs below.
IMAGEBIND_DIM = 1024
CLIP_DIM = 768


class MMEmbedding(nn.Embedding):
    """Token embedding that splices projected perception embeddings into the
    token-embedding output at caller-specified positions.

    Copies the configuration of an existing ``nn.Embedding`` and adds a small
    MLP (``proj_to_llama``) that maps an ImageBind embedding — optionally
    concatenated with a CLIP embedding — onto ``perception_tokens``
    consecutive slots of the llama embedding sequence.
    """

    def __init__(self, e: nn.Embedding, perception_tokens: int = 1, use_clip: bool = False):
        super().__init__(
            num_embeddings=e.num_embeddings,
            embedding_dim=e.embedding_dim,
            padding_idx=e.padding_idx,
            max_norm=e.max_norm,
            norm_type=e.norm_type,
            scale_grad_by_freq=e.scale_grad_by_freq,
            sparse=e.sparse,
        )
        self._perception_tokens = perception_tokens
        # Indexed first by batch position, then by sequence position; each
        # entry is a dict holding "ib_embed" (and "clip_embed" when use_clip).
        self._context = []
        self._use_clip = use_clip
        dim_in = IMAGEBIND_DIM + (CLIP_DIM if use_clip else 0)
        # One llama-width vector per perception token, produced in one shot.
        dim_out = e.embedding_dim * perception_tokens
        self.proj_to_llama = nn.Sequential(
            nn.Linear(dim_in, dim_out),
            nn.GELU(),
            nn.LayerNorm(dim_out),
            nn.Linear(dim_out, dim_out),
        )

    def set_context(self, context) -> None:
        """Install the multimodal context consumed by the next ``forward``."""
        self._context = context

    def forward(self, input: Tensor) -> Tensor:
        """Embed ``input`` tokens, then overwrite the perception-token slots
        with projections of the stored multimodal embeddings."""
        r = super().forward(input)
        # self._context is first indexed by batch idx
        for b, context_dict in enumerate(self._context):
            # then by sequence idx
            for s, embed in context_dict.items():
                # project from perception dim(s) -> llama3 embedding dim
                if self._use_clip:
                    proj_in = torch.cat([embed["ib_embed"], embed["clip_embed"]])
                else:
                    # fix: original wrapped the single tensor in torch.cat([...]),
                    # a no-op concat that just made an extra copy.
                    proj_in = embed["ib_embed"]
                llama_embed = self.proj_to_llama(proj_in)
                # The projection fills perception_tokens consecutive slots.
                r[b, s : s + self._perception_tokens] = llama_embed.view(
                    self._perception_tokens, -1
                )
        return r
class MMLinear(nn.Linear):
    """Output head mirroring an existing ``nn.Linear``, plus a projection
    (``proj_from_llama``) from llama hidden states back to CLIP space.

    ``proj_from_llama`` is not applied in ``forward`` — the bookkeeping that
    used it is currently disabled (see the commented block below).
    """

    def __init__(self, o: nn.Linear):
        super().__init__(
            in_features=o.in_features,
            out_features=o.out_features,
            # fix: original used `o.bias != None`, which invokes Tensor
            # comparison semantics; identity (`is not None`) is the correct
            # and safe way to test for the presence of a bias term.
            bias=o.bias is not None,
        )
        self._context = []
        dim_out = CLIP_DIM
        dim_in = o.in_features
        self.proj_from_llama = nn.Sequential(
            nn.Linear(dim_in, dim_out),
            nn.GELU(),
            nn.LayerNorm(dim_out),
            nn.Linear(dim_out, dim_out),
        )

    def set_context(self, context) -> None:
        """Install per-batch {seq_idx: embed-dict} context for ``forward``."""
        self._context = context

    def forward(self, input_bsd: Tensor) -> Tensor:
        # self._context has the indexes of image llama tokens: process these with proj_from_llama
        self._clip_projections = []
        # NOTE(review): the clip-projection bookkeeping below is deliberately
        # disabled; kept for reference until that loss path is re-enabled.
        # # self._context is first indexed by batch idx
        # for b, context_dict in enumerate(self._context):
        #     # then by sequence idx
        #     for s, embed in context_dict.items():
        #         # and then must be transformed from llama3 dim -> clip dim
        #         self._clip_projections.append((
        #             self.proj_from_llama(input_bsd[b, s]),
        #             (embed["clip_embed"] if "clip_embed" in embed else None)  # terrible
        #         ))
        r = super().forward(input_bsd)
        return r
def lora_mmllama3_8b(
    lora_attn_modules: List[LORA_ATTN_MODULES],
    apply_lora_to_mlp: bool = False,
    apply_lora_to_output: bool = False,
    lora_rank: int = 8,
    lora_alpha: float = 16,
    quantize_base: bool = False,
    perception_tokens: int = 2,
    use_clip: bool = False
) -> TransformerDecoder:
    """Build a LoRA llama3-8b decoder with the multimodal embedding and
    output layers swapped in.

    The first six arguments are forwarded to ``lora_llama3_8b``; the last
    two configure the ``MMEmbedding`` replacement.
    """
    decoder = lora_llama3_8b(
        lora_attn_modules,
        apply_lora_to_mlp,
        apply_lora_to_output,
        lora_rank,
        lora_alpha,
        quantize_base,
    )
    decoder.tok_embeddings = MMEmbedding(
        decoder.tok_embeddings, perception_tokens, use_clip
    )
    decoder.output = MMLinear(decoder.output)
    return decoder
def mmllama3_8b(
    perception_tokens: int = 2,
    use_clip: bool = False
) -> TransformerDecoder:
    """Build a full-precision llama3-8b decoder with the multimodal
    embedding and output layers swapped in."""
    model = llama3_8b()
    model.tok_embeddings = MMEmbedding(model.tok_embeddings, perception_tokens, use_clip)
    model.output = MMLinear(model.output)
    return model
def imagebind_huge(use_v2: bool = True):
    """Load an ImageBind-huge encoder and attach a PIL preprocessing
    transform as ``transform_from_pil``.

    The transform is bicubic resize to 224, center crop, tensor conversion,
    and normalization with the listed mean/std constants.
    """
    if use_v2:
        model = ImageBind(v2=True)
    else:
        model = imagebind_model.imagebind_huge(pretrained=True)

    preprocessing = [
        transforms.Resize(
            224, interpolation=transforms.InterpolationMode.BICUBIC
        ),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711),
        ),
    ]
    model.transform_from_pil = transforms.Compose(preprocessing)
    return model