Spaces:
Build error
Build error
import random | |
import logging | |
import torch | |
from torch.cuda.amp import autocast as autocast | |
import torch.nn as nn | |
from .blip2 import Blip2Base, disabled_train | |
from .modeling_llama import LlamaForCausalLM | |
from transformers import LlamaTokenizer, LlamaConfig | |
class VideoChat(Blip2Base): | |
""" | |
VideoChat model. | |
""" | |
def __init__(self, config): | |
super().__init__() | |
vit_model = config.get("vit_model", "eva_clip_g") | |
vit_model_path = config.get("vit_model_path", None) | |
q_former_model_path = config.get("q_former_model_path", None) | |
llama_model_path = config.get("llama_model_path") | |
videochat_model_path = config.get("videochat_model_path", "") | |
img_size = config.get("img_size") | |
drop_path_rate = config.get("drop_path_rate", 0) | |
use_grad_checkpoint = config.get("use_grad_checkpoint", False) | |
vit_precision = config.get("vit_precision", "fp16") | |
freeze_vit = config.get("freeze_vit", True) | |
freeze_qformer = config.get("freeze_qformer", True) | |
low_resource = config.get("low_resource", False) # use 8 bit and put vit in cpu | |
max_txt_len = config.get("max_txt_len", 32) | |
# uniformerv2 | |
freeze_mhra = config.get("freeze_mhra", False) | |
temporal_downsample = config.get("temporal_downsample", True) | |
no_lmhra = config.get("no_lmhra", False) | |
double_lmhra = config.get("double_lmhra", False) | |
lmhra_reduction = config.get("lmhra_reduction", 2.0) | |
gmhra_layers = config.get("gmhra_layers", 8) | |
gmhra_drop_path_rate = config.get("gmhra_drop_path_rate", 0.) | |
gmhra_dropout = config.get("gmhra_dropout", 0.5) | |
# qformer | |
num_query_token = config.get("num_query_token") | |
extra_num_query_token = config.get("extra_num_query_token", 64) | |
self.tokenizer = self.init_tokenizer() | |
self.low_resource = low_resource | |
self.vit_precision = vit_precision | |
print(f'Loading VIT. Use fp16: {vit_precision}') | |
self.visual_encoder, self.ln_vision = self.init_vision_encoder( | |
vit_model, img_size, drop_path_rate, | |
use_grad_checkpoint, vit_precision, vit_model_path, | |
temporal_downsample=temporal_downsample, | |
no_lmhra=no_lmhra, | |
double_lmhra=double_lmhra, | |
lmhra_reduction=lmhra_reduction, | |
gmhra_layers=gmhra_layers, | |
gmhra_drop_path_rate=gmhra_drop_path_rate, | |
gmhra_dropout=gmhra_dropout, | |
) | |
if freeze_vit: | |
print("freeze vision encoder") | |
if not freeze_mhra: | |
open_list = [] | |
for name, param in self.visual_encoder.named_parameters(): | |
if 'mhra' not in name: | |
param.requires_grad = False | |
else: | |
open_list.append(name) | |
print(f"open module: {open_list}") | |
print("open ln_vision") | |
else: | |
for name, param in self.visual_encoder.named_parameters(): | |
param.requires_grad = False | |
self.visual_encoder = self.visual_encoder.eval() | |
self.visual_encoder.train = disabled_train | |
for name, param in self.ln_vision.named_parameters(): | |
param.requires_grad = False | |
self.ln_vision = self.ln_vision.eval() | |
self.ln_vision.train = disabled_train | |
print('Loading VIT Done') | |
print('Loading Q-Former') | |
self.Qformer, self.query_tokens = self.init_Qformer( | |
num_query_token, self.visual_encoder.num_features, | |
) | |
self.Qformer.cls = None | |
self.Qformer.bert.embeddings.word_embeddings = None | |
self.Qformer.bert.embeddings.position_embeddings = None | |
for layer in self.Qformer.bert.encoder.layer: | |
layer.output = None | |
layer.intermediate = None | |
self.load_from_pretrained(model_path=q_former_model_path) | |
print(f"Add extra {extra_num_query_token} tokens in QFormer") | |
self.extra_query_tokens = nn.Parameter( | |
torch.zeros(1, extra_num_query_token, self.query_tokens.shape[-1]) | |
) | |
if freeze_qformer: | |
print("freeze Qformer") | |
for name, param in self.Qformer.named_parameters(): | |
param.requires_grad = False | |
self.Qformer = self.Qformer.eval() | |
self.Qformer.train = disabled_train | |
self.query_tokens.requires_grad = False | |
print('Loading Q-Former Done') | |
print('Loading LLAMA') | |
self.llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model_path, use_fast=False, use_auth_token=os.environ["HF_TOKEN"]) | |
self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token | |
if self.low_resource: | |
self.llama_model = LlamaForCausalLM.from_pretrained( | |
llama_model_path, | |
torch_dtype=torch.float16, | |
load_in_8bit=True, | |
device_map="auto", | |
use_auth_token=os.environ["HF_TOKEN"], | |
) | |
else: | |
self.llama_model = LlamaForCausalLM.from_pretrained( | |
llama_model_path, | |
torch_dtype=torch.float16, | |
use_auth_token=os.environ["HF_TOKEN"], | |
) | |
print("freeze LLAMA") | |
for name, param in self.llama_model.named_parameters(): | |
param.requires_grad = False | |
print('Loading LLAMA Done') | |
self.llama_proj = nn.Linear( | |
self.Qformer.config.hidden_size, self.llama_model.config.hidden_size | |
) | |
self.max_txt_len = max_txt_len | |
# load weights of VideoChat | |
if videochat_model_path: | |
print(f"Load VideoChat from: {videochat_model_path}") | |
ckpt = torch.load(videochat_model_path, map_location="cpu") | |
msg = self.load_state_dict(ckpt['model'], strict=False) | |
print(msg) | |
def vit_to_cpu(self): | |
self.ln_vision.to("cpu") | |
self.ln_vision.float() | |
self.visual_encoder.to("cpu") | |
self.visual_encoder.float() | |
def encode_img(self, image): | |
device = image.device | |
if self.low_resource: | |
self.vit_to_cpu() | |
image = image.to("cpu") | |
with self.maybe_autocast(): | |
T = image.shape[1] | |
# use_image = True if T == 1 else False | |
image = image.permute(0, 2, 1, 3, 4) # [B,T,C,H,W] -> [B,C,T,H,W] | |
image_embeds = self.ln_vision(self.visual_encoder(image)).to(device) | |
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device) | |
query_tokens = torch.cat([self.query_tokens, self.extra_query_tokens], dim=1) | |
query_tokens = query_tokens.expand(image_embeds.shape[0], -1, -1) | |
query_output = self.Qformer.bert( | |
query_embeds=query_tokens, | |
encoder_hidden_states=image_embeds, | |
encoder_attention_mask=image_atts, | |
return_dict=True, | |
) | |
inputs_llama = self.llama_proj(query_output.last_hidden_state) | |
atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device) | |
return inputs_llama, atts_llama | |