import torch
import torch.nn as nn
import torch.nn.functional as F
import whisperx
from transformers import AutoTokenizer
from transformers import AutoModelForCausalLM
from transformers import CLIPVisionModel, CLIPImageProcessor
import peft
import gradio as gr

device = 'cpu'

# QLoRA adapter checkpoint on the Hugging Face Hub and the base model it was trained from.
user = "VarunSivamani"
model_name = "QLoRA-phi2"
model_id = f"{user}/{model_name}"

model_name = "microsoft/phi-2"

phi2_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    device_map='cpu'
)
phi2_model.config.use_cache = False

# Whisper (via whisperx) handles audio transcription; CLIP encodes images.
whisper_model = whisperx.load_model('small', device='cpu', compute_type='float32')

image_processor = CLIPImageProcessor.from_pretrained('openai/clip-vit-base-patch32')
clip_model = CLIPVisionModel.from_pretrained('openai/clip-vit-base-patch32')

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=False)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.bos_token = tokenizer.eos_token


def text_to_embeddings(text):
    """Tokenize text and map the token ids to phi-2 input embeddings."""
    input_tokens = tokenizer(text, return_tensors="pt", return_attention_mask=False)
    return phi2_model.get_input_embeddings()(input_tokens.input_ids)


def audio_to_text_embeds(file_name):
    """Transcribe an audio file with Whisper and return the concatenated segment text."""
    result = whisper_model.transcribe(file_name)
    res_text = ''
    for segment in result['segments']:
        res_text = res_text + segment['text']
    return res_text.strip()


def select_features(image_out):
    """Take the last hidden state of the CLIP vision tower and drop the CLS token."""
    image_features = image_out.hidden_states[-1]
    return image_features[:, 1:, :]


def CLIP_embeddings(image):
    """Run the frozen CLIP vision model and return patch-level features."""
    _ = clip_model.requires_grad_(False)
    image = image_processor(images=image, return_tensors="pt")
    image_out = clip_model(image['pixel_values'].to(device=clip_model.device), output_hidden_states=True)
    return select_features(image_out)


class ResBlock(nn.Module):
    """LayerNorm followed by a two-layer MLP with a residual connection."""

    def __init__(self, input_size):
        super().__init__()
        self.pre_norm = nn.LayerNorm(input_size)
        self.proj = nn.Sequential(
            nn.Linear(input_size, input_size),
            nn.GELU(),
            nn.Linear(input_size, input_size)
        )

    def forward(self, x):
        x = self.pre_norm(x)
        return x + self.proj(x)


class CLIP_projection(nn.Module):
    """Project CLIP patch embeddings (768-d) into the phi-2 embedding space (2560-d)."""

    def __init__(self, dim_input_CLIP=768, dim_input_Phi2=2560):
        super(CLIP_projection, self).__init__()
        self.projection_img = nn.Linear(dim_input_CLIP, dim_input_Phi2, bias=False)
        self.resblock = ResBlock(dim_input_Phi2)

    def forward(self, x):
        x = self.projection_img(x)
        return self.resblock(x)


proj_layer = CLIP_projection()
proj_layer.projection_img.load_state_dict(torch.load("proj.pth", map_location='cpu'))
proj_layer.resblock.load_state_dict(torch.load("block.pth", map_location='cpu'))


def img_embeddings(image):
    """CLIP patch features projected into the phi-2 embedding space."""
    clip_embeddings = CLIP_embeddings(image)
    return proj_layer(clip_embeddings)


# Attach the QLoRA adapter to the base phi-2 model.
phi2_model_peft = peft.PeftModel.from_pretrained(phi2_model, model_id)


def multimodal_phi2(image=None, audio=None, text=None):
    """Build a prompt of the form "Context: <image/audio/text> Question: <query> Answer:"
    directly in embedding space and generate a response with the fine-tuned phi-2."""
    if text is not None and len(text) == 0:
        text = None
    if image is None and audio is None and text is None:
        return None

    # Default to an empty query so generation still works when only image/audio is given.
    query = ''

    context = tokenizer("Context: ", return_tensors="pt", return_attention_mask=False)
    input_embeds = phi2_model_peft.get_input_embeddings()(context.input_ids)

    if image is not None:
        image_embeds = img_embeddings(image)
        input_embeds = torch.cat((input_embeds, image_embeds), dim=1)

    if audio is not None:
        audio_transcribed = audio_to_text_embeds(audio)
        audio_embeds = text_to_embeddings(audio_transcribed)
        input_embeds = torch.cat((input_embeds, audio_embeds), dim=1)

    if text is not None:
        query = text
        text_embeds = text_to_embeddings(text)
        input_embeds = torch.cat((input_embeds, text_embeds), dim=1)

    question = tokenizer(" Question: " + query, return_tensors="pt", return_attention_mask=False)
    question_embeds = phi2_model_peft.get_input_embeddings()(question.input_ids)
    input_embeds = torch.cat((input_embeds, question_embeds), dim=1)

    answer = tokenizer(" Answer: ", return_tensors="pt", return_attention_mask=False)
    answer_embeds = phi2_model_peft.get_input_embeddings()(answer.input_ids)
    input_embeds = torch.cat((input_embeds, answer_embeds), dim=1)

    result = phi2_model_peft.generate(inputs_embeds=input_embeds, bos_token_id=tokenizer.bos_token_id)
    process = tokenizer.batch_decode(result)[0]
    process = process.split(tokenizer.eos_token)

    if process[0] == '':
        return process[1]
    else:
        return process[0]


demo = gr.Interface(
    fn=multimodal_phi2,
    inputs=[
        gr.Image(label="Image"),
        gr.Audio(label="Audio", sources=["microphone", "upload"], type="filepath"),
        gr.Textbox(label="Text"),
    ],
    outputs=[
        gr.Textbox(label='Answer'),
    ],
)

demo.launch()
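# A minimal sketch of calling the pipeline directly (without the Gradio UI), left
# commented out so the script still just launches the app. The file names
# "sample.jpg" and "sample.wav" are hypothetical placeholders.
# from PIL import Image
# print(multimodal_phi2(image=Image.open("sample.jpg"), text="What is in this picture?"))
# print(multimodal_phi2(audio="sample.wav", text="Summarise the audio."))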