# VideoChatGPT/models/videochat.py
import os

import psutil
import torch
import torch.nn as nn
from transformers import LlamaTokenizer

from .blip2 import Blip2Base, disabled_train
from .modeling_llama import LlamaForCausalLM
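

# Helper consolidating the memory diagnostics that the original printed inline
# (in Chinese) at several points below; same numbers, with the messages
# translated to English.
def _print_memory_usage(prefix=""):
    rss_gb = psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024
    info = psutil.virtual_memory()
    print(f"{prefix}Current process memory usage: {rss_gb:.4f} GB")
    print(f"Total system memory: {info.total / 1024 / 1024 / 1024:.4f} GB")
    print(f"System memory in use: {info.percent}%")
    print(f"CPU count: {psutil.cpu_count()}")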


class VideoChat(Blip2Base):
    """VideoChat model.

    Couples a ViT visual encoder (with UniFormerV2 temporal modules) and a
    Q-Former with a frozen LLAMA language model.
    """

def __init__(self, config):
super().__init__()
vit_model = config.get("vit_model", "eva_clip_g")
vit_model_path = config.get("vit_model_path", None)
q_former_model_path = config.get("q_former_model_path", None)
llama_model_path = config.get("llama_model_path")
videochat_model_path = config.get("videochat_model_path", "")
img_size = config.get("img_size")
drop_path_rate = config.get("drop_path_rate", 0)
use_grad_checkpoint = config.get("use_grad_checkpoint", False)
vit_precision = config.get("vit_precision", "fp16")
freeze_vit = config.get("freeze_vit", True)
freeze_qformer = config.get("freeze_qformer", True)
        low_resource = config.get("low_resource", False)  # load LLAMA in 8-bit and move the ViT to CPU
max_txt_len = config.get("max_txt_len", 32)
# uniformerv2
freeze_mhra = config.get("freeze_mhra", False)
temporal_downsample = config.get("temporal_downsample", True)
no_lmhra = config.get("no_lmhra", False)
double_lmhra = config.get("double_lmhra", False)
lmhra_reduction = config.get("lmhra_reduction", 2.0)
gmhra_layers = config.get("gmhra_layers", 8)
gmhra_drop_path_rate = config.get("gmhra_drop_path_rate", 0.)
gmhra_dropout = config.get("gmhra_dropout", 0.5)
# qformer
num_query_token = config.get("num_query_token")
extra_num_query_token = config.get("extra_num_query_token", 64)
self.tokenizer = self.init_tokenizer()
self.low_resource = low_resource
self.vit_precision = vit_precision
        print(f'Loading ViT (precision: {vit_precision})')
self.visual_encoder, self.ln_vision = self.init_vision_encoder(
vit_model, img_size, drop_path_rate,
use_grad_checkpoint, vit_precision, vit_model_path,
temporal_downsample=temporal_downsample,
no_lmhra=no_lmhra,
double_lmhra=double_lmhra,
lmhra_reduction=lmhra_reduction,
gmhra_layers=gmhra_layers,
gmhra_drop_path_rate=gmhra_drop_path_rate,
gmhra_dropout=gmhra_dropout,
)
if freeze_vit:
print("freeze vision encoder")
if not freeze_mhra:
open_list = []
for name, param in self.visual_encoder.named_parameters():
if 'mhra' not in name:
param.requires_grad = False
else:
open_list.append(name)
print(f"open module: {open_list}")
print("open ln_vision")
else:
for name, param in self.visual_encoder.named_parameters():
param.requires_grad = False
self.visual_encoder = self.visual_encoder.eval()
self.visual_encoder.train = disabled_train
for name, param in self.ln_vision.named_parameters():
param.requires_grad = False
self.ln_vision = self.ln_vision.eval()
self.ln_vision.train = disabled_train
print('Loading VIT Done')
print('Loading Q-Former')
self.Qformer, self.query_tokens = self.init_Qformer(
num_query_token, self.visual_encoder.num_features,
)
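        # Drop the Q-Former's text-specific heads; only the query /
        # cross-attention path is used to pool visual features.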
self.Qformer.cls = None
self.Qformer.bert.embeddings.word_embeddings = None
self.Qformer.bert.embeddings.position_embeddings = None
for layer in self.Qformer.bert.encoder.layer:
layer.output = None
layer.intermediate = None
self.load_from_pretrained(model_path=q_former_model_path)
print(f"Add extra {extra_num_query_token} tokens in QFormer")
self.extra_query_tokens = nn.Parameter(
torch.zeros(1, extra_num_query_token, self.query_tokens.shape[-1])
)
if freeze_qformer:
print("freeze Qformer")
for name, param in self.Qformer.named_parameters():
param.requires_grad = False
self.Qformer = self.Qformer.eval()
self.Qformer.train = disabled_train
self.query_tokens.requires_grad = False
print('Loading Q-Former Done')
print('Loading LLAMA')
        # Expects the HF_TOKEN environment variable (raises KeyError if unset).
        self.llama_tokenizer = LlamaTokenizer.from_pretrained(
            llama_model_path, use_fast=False,
            use_auth_token=os.environ["HF_TOKEN"],
        )
self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token
print(u'ε½“ε‰θΏ›η¨‹ηš„ε†…ε­˜δ½Ώη”¨οΌš%.4f GB' % (psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024) )
info = psutil.virtual_memory()
print( u'η”΅θ„‘ζ€»ε†…ε­˜οΌš%.4f GB' % (info.total / 1024 / 1024 / 1024) )
print(u'ε½“ε‰δ½Ώη”¨ηš„ζ€»ε†…ε­˜ε ζ―”οΌš',info.percent)
print(u'cpuδΈͺζ•°οΌš',psutil.cpu_count())
        # Both the low_resource and default branches originally loaded LLAMA
        # with identical arguments (fp16, 8-bit, auto device map), so a single
        # call suffices here.
        self.llama_model = LlamaForCausalLM.from_pretrained(
            llama_model_path,
            torch_dtype=torch.float16,
            load_in_8bit=True,
            device_map="auto",
            use_auth_token=os.environ["HF_TOKEN"],
        )
print("freeze LLAMA")
for name, param in self.llama_model.named_parameters():
param.requires_grad = False
print('Loading LLAMA Done')
print(u'ε½“ε‰θΏ›η¨‹ηš„ε†…ε­˜δ½Ώη”¨οΌš%.4f GB' % (psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024) )
info = psutil.virtual_memory()
print( u'η”΅θ„‘ζ€»ε†…ε­˜οΌš%.4f GB' % (info.total / 1024 / 1024 / 1024) )
print(u'ε½“ε‰δ½Ώη”¨ηš„ζ€»ε†…ε­˜ε ζ―”οΌš',info.percent)
print(u'cpuδΈͺζ•°οΌš',psutil.cpu_count())
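        # Linear projection from the Q-Former hidden size into the LLAMA
        # embedding space.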
self.llama_proj = nn.Linear(
self.Qformer.config.hidden_size, self.llama_model.config.hidden_size
)
self.max_txt_len = max_txt_len
print(u'ε½“ε‰θΏ›η¨‹ηš„ε†…ε­˜δ½Ώη”¨οΌš%.4f GB' % (psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024) )
info = psutil.virtual_memory()
print( u'η”΅θ„‘ζ€»ε†…ε­˜οΌš%.4f GB' % (info.total / 1024 / 1024 / 1024) )
print(u'ε½“ε‰δ½Ώη”¨ηš„ζ€»ε†…ε­˜ε ζ―”οΌš',info.percent)
print(u'cpuδΈͺζ•°οΌš',psutil.cpu_count())
# load weights of VideoChat
if videochat_model_path:
print(u'ε½“ε‰θΏ›η¨‹ηš„ε†…ε­˜δ½Ώη”¨οΌš%.4f GB' % (psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024) )
info = psutil.virtual_memory()
print( u'η”΅θ„‘ζ€»ε†…ε­˜οΌš%.4f GB' % (info.total / 1024 / 1024 / 1024) )
print(u'ε½“ε‰δ½Ώη”¨ηš„ζ€»ε†…ε­˜ε ζ―”οΌš',info.percent)
print(u'cpuδΈͺζ•°οΌš',psutil.cpu_count())
print(f"Load VideoChat from: {videochat_model_path}")
ckpt = torch.load(videochat_model_path, map_location="cpu")
            _print_memory_usage("ckpt loaded successfully. ")
msg = self.load_state_dict(ckpt['model'], strict=False)
print(msg)
print(u'ε½“ε‰θΏ›η¨‹ηš„ε†…ε­˜δ½Ώη”¨οΌš%.4f GB' % (psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024) )
info = psutil.virtual_memory()
print( u'η”΅θ„‘ζ€»ε†…ε­˜οΌš%.4f GB' % (info.total / 1024 / 1024 / 1024) )
print(u'ε½“ε‰δ½Ώη”¨ηš„ζ€»ε†…ε­˜ε ζ―”οΌš',info.percent)
print(u'cpuδΈͺζ•°οΌš',psutil.cpu_count())

    def vit_to_cpu(self):
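        """Move the vision tower to CPU in fp32 (used in low-resource mode)."""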
self.ln_vision.to("cpu")
self.ln_vision.float()
self.visual_encoder.to("cpu")
self.visual_encoder.float()

    def encode_img(self, image):
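        """Encode a video into LLAMA-space embeddings.

        Args:
            image: video tensor of shape [B, T, C, H, W].

        Returns:
            Tuple of (inputs_llama, atts_llama): projected query embeddings
            and a matching attention mask.
        """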
device = image.device
if self.low_resource:
self.vit_to_cpu()
image = image.to("cpu")
with self.maybe_autocast():
            T = image.shape[1]  # number of frames; T == 1 would indicate a single image
image = image.permute(0, 2, 1, 3, 4) # [B,T,C,H,W] -> [B,C,T,H,W]
image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
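            # Concatenate the pretrained query tokens with the extra learnable ones.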
query_tokens = torch.cat([self.query_tokens, self.extra_query_tokens], dim=1)
query_tokens = query_tokens.expand(image_embeds.shape[0], -1, -1)
query_output = self.Qformer.bert(
query_embeds=query_tokens,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_atts,
return_dict=True,
)
inputs_llama = self.llama_proj(query_output.last_hidden_state)
            # Use `device` (captured before any low-resource CPU remap) so the
            # mask stays on the same device as inputs_llama.
            atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(device)
return inputs_llama, atts_llama
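
# Minimal construction sketch (hypothetical values; the real config and
# checkpoint paths come from the repo's config files):
#
#   config = {
#       "llama_model_path": "path/to/llama",  # assumed local LLAMA weights
#       "img_size": 224,
#       "num_query_token": 32,
#   }
#   model = VideoChat(config)
#   video = torch.randn(1, 8, 3, 224, 224)  # [B, T, C, H, W]
#   inputs_llama, atts_llama = model.encode_img(video)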