import torch
import torch.nn as nn
import torch.nn.functional as F
import whisperx
from transformers import AutoTokenizer, AutoModelForCausalLM, CLIPVisionModel, CLIPImageProcessor
import peft
import gradio as gr

device = 'cpu'
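
# Repo id of the fine-tuned QLoRA adapter and the phi-2 base model it was trained on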
user = "VarunSivamani" | |
model_name = "QLoRA-phi2" | |
model_id = f"{user}/{model_name}" | |
model_name = "microsoft/phi-2" | |
phi2_model = AutoModelForCausalLM.from_pretrained( | |
model_name, | |
trust_remote_code=True, | |
device_map = 'cpu' | |
) | |
phi2_model.config.use_cache = False | |
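
# WhisperX for speech-to-text, plus the CLIP ViT-B/32 vision encoder and its image processor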
whisper_model = whisperx.load_model('small', device='cpu', compute_type='float32')
image_processor = CLIPImageProcessor.from_pretrained('openai/clip-vit-base-patch32')
clip_model = CLIPVisionModel.from_pretrained('openai/clip-vit-base-patch32')
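
# Reuse the EOS token as both the pad and BOS token for the phi-2 tokenizer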
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=False)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.bos_token = tokenizer.eos_token
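
# Map raw text to phi-2 input embeddings (used for both the typed question and the audio transcript)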
def text_to_embeddings(text):
    input_tokens = tokenizer(text, return_tensors="pt", return_attention_mask=False)
    return phi2_model.get_input_embeddings()(input_tokens.input_ids)
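
# Transcribe an audio file with WhisperX and return the concatenated segment text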
def audio_to_text_embeds(file_name):
    result = whisper_model.transcribe(file_name)
    res_text = ''
    for segment in result['segments']:
        res_text += segment['text']
    return res_text.strip()
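
# Keep the patch embeddings from CLIP's last hidden state and drop the leading CLS token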
def select_features(image_out):
    image_features = image_out.hidden_states[-1]
    return image_features[:, 1:, :]

def CLIP_embeddings(image):
    _ = clip_model.requires_grad_(False)
    image = image_processor(images=image, return_tensors="pt")
    image_out = clip_model(image['pixel_values'].to(device=clip_model.device), output_hidden_states=True)
    return select_features(image_out)
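
# Residual MLP block applied after the linear image projection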
class ResBlock(nn.Module):
    def __init__(self, input_size):
        super().__init__()
        self.pre_norm = nn.LayerNorm(input_size)
        self.proj = nn.Sequential(
            nn.Linear(input_size, input_size),
            nn.GELU(),
            nn.Linear(input_size, input_size)
        )

    def forward(self, x):
        x = self.pre_norm(x)
        return x + self.proj(x)
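
# Projection layer that maps 768-dim CLIP patch embeddings into phi-2's 2560-dim embedding space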
class CLIP_projection(nn.Module):
    def __init__(
        self,
        dim_input_CLIP=768,
        dim_input_Phi2=2560
    ):
        super().__init__()
        self.projection_img = nn.Linear(
            dim_input_CLIP, dim_input_Phi2, bias=False
        )
        self.resblock = ResBlock(dim_input_Phi2)

    def forward(self, x):
        x = self.projection_img(x)
        return self.resblock(x)
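
# Pretrained projection weights, expected next to this script as proj.pth and block.pth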
proj_layer = CLIP_projection()
proj_layer.projection_img.load_state_dict(torch.load("proj.pth", map_location='cpu'))
proj_layer.resblock.load_state_dict(torch.load("block.pth", map_location='cpu'))
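
# Image -> CLIP patch features -> projection into phi-2's embedding space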
def img_embeddings(image):
    clip_embeddings = CLIP_embeddings(image)
    return proj_layer(clip_embeddings)
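
# Attach the fine-tuned QLoRA adapter on top of the phi-2 base model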
phi2_model_peft = peft.PeftModel.from_pretrained(phi2_model, model_id)
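
# Core handler: builds the prompt directly in embedding space
# ("Context: " + image/audio/text embeddings + " Question: <query> Answer: ")
# and decodes phi-2's generated answer.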
def multimodal_phi2(image=None, audio=None, text=None):
    # Treat an empty textbox as "no text"
    if not text:
        text = None
    if image is None and audio is None and text is None:
        return None

    query = ""  # question text appended after the context; filled from the textbox or the audio transcript

    context = tokenizer("Context: ", return_tensors="pt", return_attention_mask=False)
    input_embeds = phi2_model_peft.get_input_embeddings()(context.input_ids)

    if image is not None:
        image_embeds = img_embeddings(image)
        input_embeds = torch.cat((input_embeds, image_embeds), dim=1)

    if audio is not None:
        audio_transcribed = audio_to_text_embeds(audio)
        audio_embeds = text_to_embeddings(audio_transcribed)
        input_embeds = torch.cat((input_embeds, audio_embeds), dim=1)
        if text is None:
            query = audio_transcribed  # no typed question, so use the transcript as the query

    if text is not None:
        query = text
        text_embeds = text_to_embeddings(text)
        input_embeds = torch.cat((input_embeds, text_embeds), dim=1)

    question = tokenizer(" Question: " + query, return_tensors="pt", return_attention_mask=False)
    question_embeds = phi2_model_peft.get_input_embeddings()(question.input_ids)
    input_embeds = torch.cat((input_embeds, question_embeds), dim=1)

    answer = tokenizer(" Answer: ", return_tensors="pt", return_attention_mask=False)
    answer_embeds = phi2_model_peft.get_input_embeddings()(answer.input_ids)
    input_embeds = torch.cat((input_embeds, answer_embeds), dim=1)

    result = phi2_model_peft.generate(inputs_embeds=input_embeds, bos_token_id=tokenizer.bos_token_id)
    decoded = tokenizer.batch_decode(result)[0]
    parts = decoded.split(tokenizer.eos_token)
    # generation may start with an EOS token, which leaves an empty first chunk
    if parts[0] == '':
        return parts[1]
    return parts[0]
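
# Gradio UI: optional image, optional microphone/uploaded audio, and a text question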
demo = gr.Interface(
    fn=multimodal_phi2,
    inputs=[
        gr.Image(label="Image"),
        gr.Audio(label="Audio", sources=["microphone", "upload"], type="filepath"),
        gr.Textbox(label="Text"),
    ],
    outputs=[
        gr.Textbox(label='Answer'),
    ],
)

demo.launch()
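
# Example of calling the handler directly (hypothetical inputs; Gradio supplies the image
# as a numpy array and the audio as a file path in the same way):
#   multimodal_phi2(text="What is the capital of France?")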