RashiAgarwal committed
Commit 35c891d · verified · 1 Parent(s): 9bfe3b8

Delete app_gradio.py

Files changed (1)
  1. app_gradio.py +0 -126
app_gradio.py DELETED
@@ -1,126 +0,0 @@
-
- import gradio as gr
- import peft
- from peft import LoraConfig, PeftModel
- from transformers import AutoTokenizer, AutoModelForCausalLM, CLIPVisionModel, AutoProcessor
- import torch
- from PIL import Image
- import requests
- import numpy as np
- import torch.nn as nn
- import whisperx
- import ffmpeg, pydub
- from pydub import AudioSegment
-
- clip_model_name = "wkcn/TinyCLIP-ViT-61M-32-Text-29M-LAION400M"
- phi_model_name = "microsoft/phi-2"
-
- tokenizer = AutoTokenizer.from_pretrained(phi_model_name, trust_remote_code=True)
- processor = AutoProcessor.from_pretrained(clip_model_name)
- tokenizer.pad_token = tokenizer.eos_token
- IMAGE_TOKEN_ID = 23893 # token for word comment
- device = "cuda" if torch.cuda.is_available() else "cpu"
- clip_embed = 640
- phi_embed = 2560
- compute_type = "float16"
- audio_batch_size = 1
-
- import gc
-
- # models
- clip_model = CLIPVisionModel.from_pretrained(clip_model_name).to(device)
- projection = torch.nn.Linear(clip_embed, phi_embed).to(device)
- gc.collect()
- phi_model = AutoModelForCausalLM.from_pretrained(
-     phi_model_name,
-     trust_remote_code=True,
- )
- audio_model = whisperx.load_model("small", device, compute_type=compute_type)
-
- # load weights
- model_to_merge = PeftModel.from_pretrained(phi_model,'./model_chkpt/qlora_adaptor')
- merged_model = model_to_merge.merge_and_unload().to(device)
- projection.load_state_dict(torch.load('./model_chkpt/ft_projection.pth',map_location=torch.device(device)))
-
- def inference(img=None,img_audio=None,val_q=None):
-
-     max_generate_length = 100
-     val_combined_embeds = []
-
-     with torch.no_grad():
-
-         # image
-         if img is not None:
-             image_processed = processor(images=img, return_tensors="pt").to(device)
-             clip_val_outputs = clip_model(**image_processed).last_hidden_state[:,1:,:]
-             val_image_embeds = projection(clip_val_outputs)
-
-             img_token_tensor = torch.tensor(IMAGE_TOKEN_ID).to(device)
-             img_token_embeds = merged_model.model.embed_tokens(img_token_tensor).unsqueeze(0).unsqueeze(0)
-
-             val_combined_embeds.append(val_image_embeds)
-             val_combined_embeds.append(img_token_embeds)
-
-         # audio
-         if img_audio is not None:
-
-             # accepting only initial 15 secs speech
-             audio = AudioSegment.from_mp3( img_audio)
-             clipped_audio = audio[:15*1000]
-             clipped_audio.export( 'audio.mp3', format="mp3")
-             result = audio_model.transcribe('audio.mp3')
-             audio_text = ''
-
-             audio_text = result["segments"][0]['text']
-             audio_text = audio_text.strip()
-             audio_tokens = tokenizer(audio_text, return_tensors="pt", return_attention_mask=False)['input_ids'].squeeze(0).to(device)
-             audio_embeds = merged_model.model.embed_tokens(audio_tokens).unsqueeze(0)
-             val_combined_embeds.append(audio_embeds)
-
-         # text question
-         if len(val_q) != 0:
-             val_q_tokenised = tokenizer(val_q, return_tensors="pt", return_attention_mask=False)['input_ids'].squeeze(0).to(device)
-             val_q_embeds = merged_model.model.embed_tokens(val_q_tokenised).unsqueeze(0)
-             val_combined_embeds.append(val_q_embeds)
-
-         # val_combined_emb
-         val_combined_embeds = torch.cat(val_combined_embeds,dim=1)
-
-         predicted_caption = torch.full((1,max_generate_length),50256).to(device)
-
-         for g in range(max_generate_length):
-             phi_output_logits = merged_model(inputs_embeds=val_combined_embeds)['logits'] # 4, 69, 51200
-             predicted_word_token_logits = phi_output_logits[:, -1, :].unsqueeze(1) # 4,1,51200
-             predicted_word_token = torch.argmax(predicted_word_token_logits, dim = -1) # 4,1
-             predicted_caption[:,g] = predicted_word_token.view(1,-1)
-             next_token_embeds = phi_model.model.embed_tokens(predicted_word_token) # 4,1,2560
-             val_combined_embeds = torch.cat([val_combined_embeds, next_token_embeds], dim=1)
-
-         predicted_captions_decoded = tokenizer.batch_decode(predicted_caption,ignore_index = 50256)[0]
-
-         return predicted_captions_decoded
-
- with gr.Blocks() as demo:
-
-     gr.Markdown(
-         """
-         # MultiModal GPT _TSAI
-         Build on Tiny Clip model and Microsoft's Phi-2 model further fine tuned on Instruct150K.
-         """
-     )
-
-     # app GUI
-     with gr.Row():
-         with gr.Column():
-             img_input = gr.Image(label='Image',type="pil")
-             img_audio = gr.Audio(label="Speak a Query", sources=['microphone', 'upload'], type='filepath')
-             img_question = gr.Text(label ='Write a Query')
-         with gr.Column():
-             img_answer = gr.Text(label ='Answer')
-
-     section_btn = gr.Button("Generate")
-     section_btn.click(inference, inputs=[img_input,img_audio,img_question], outputs=[img_answer])
-
- if __name__ == "__main__":
-     demo.launch(debug=True)
-
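
One note on the decoding step in the deleted file: tokenizer.batch_decode has no ignore_index parameter, so the EOS padding value (50256) pre-filled into predicted_caption is not actually stripped from the output. If this app is restored later, a minimal sketch of that decode call using the supported skip_special_tokens flag (assuming the same tokenizer and predicted_caption as above, and that 50256 is phi-2's EOS token id):

    # Sketch only: decode the generated ids and drop the EOS padding (50256)
    # via the documented skip_special_tokens flag instead of ignore_index.
    predicted_captions_decoded = tokenizer.batch_decode(
        predicted_caption,
        skip_special_tokens=True,
    )[0]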