# llava_chat / app.py
import base64
import os

import gradio as gr
import torch
import torch.nn as nn
from PIL import Image
from peft import PeftModel
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    CLIPModel,
    CLIPProcessor,
    GenerationConfig,
    pipeline,
)
hf_token = os.environ.get("HUGGINGFACE_HUB_TOKEN")  # token provided via the Space's secrets
model_name = "meta-llama/Llama-3.2-1B-Instruct"
print("loading_model")
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype = torch.float32,
trust_remote_code=True,
token=hf_token
)
print("loaded_model")
tokenizer = AutoTokenizer.from_pretrained(model_name,token=hf_token)
tokenizer.pad_token = tokenizer.eos_token
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
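# Feature selection for the CLIP vision tower, following the LLaVA convention:
# 'patch' drops the CLS token and keeps only the patch embeddings, while
# 'cls_patch' keeps both. With CLIP ViT-B/32 at its default 224x224 input this
# yields 49 patch tokens of width 768 (checkpoint-specific; adjust if you swap encoders).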
select_feature = 'patch'
def feature_select(image_forward_outs):
image_features = image_forward_outs.hidden_states[-1]
if select_feature == 'patch':
image_features = image_features[:, 1:] # Skip CLS token if selecting patch
elif select_feature == 'cls_patch':
image_features = image_features # Keep CLS + patch tokens
else:
raise ValueError(f'Unexpected select feature: {select_feature}')
return image_features
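# MLP projector that maps CLIP patch embeddings (768-dim here) into the language
# model's token-embedding space so they can be concatenated with text embeddings.
# With depth=2 this builds Linear -> GELU -> Linear, the two-layer projector used
# by LLaVA-style models.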
class MLPProjection(nn.Module):
def __init__(self, input_dim, output_dim, hidden_dim=768, depth=2):
super(MLPProjection, self).__init__()
modules = []
modules.append(nn.Linear(input_dim, hidden_dim,bias = False))
for _ in range(1, depth):
modules.append(nn.GELU())
modules.append(nn.Linear(hidden_dim, output_dim,bias=False))
self.mlp = nn.Sequential(*modules)
def forward(self, x):
return self.mlp(x)
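# Wrapper that prepends projected image embeddings to the token embeddings before
# calling the underlying language model. The attention mask is extended with ones
# for the image positions, and labels (when given) are padded with -100 there so
# the loss ignores the image prefix. (The "PHI2" name is historical; the model
# wrapped here is Llama-3.2-1B-Instruct.)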
class PHI2WithMLP(nn.Module):
def __init__(self, phi2_model, mlp_projection):
super(PHI2WithMLP, self).__init__()
self.phi2_model = phi2_model
self.mlp_projection = mlp_projection
self.config = phi2_model.config
    def forward(
        self,
        image_embeddings=None,
        inputs_embeds=None,
        input_ids=None,
        attention_mask=None,
        labels=None,
        output_attentions=False,
        output_hidden_states=False,
        **kwargs,  # catch any additional arguments
    ):
if input_ids is not None:
token_embeddings = self.phi2_model.get_input_embeddings()(input_ids)
elif inputs_embeds is not None:
token_embeddings = inputs_embeds
else:
raise ValueError("You must provide either input_ids or inputs_embeds.")
if image_embeddings is not None:
# Apply MLP to image embeddings to map to text embedding space
projected_image_embeddings = self.mlp_projection(image_embeddings).to(device = token_embeddings.device)
# Get the sequence length for the image embeddings
image_embedding_length = projected_image_embeddings.size(1)
batch_size, text_sequence_length = attention_mask.shape
# Extend attention mask for image embeddings (ones for image embedding positions)
            new_attention_mask = torch.cat(
                [torch.ones((batch_size, image_embedding_length), device=attention_mask.device), attention_mask],
                dim=1,
            )
# Combine image and token embeddings
combined_embeddings = torch.cat([projected_image_embeddings, token_embeddings], dim=1) # Concatenating along sequence length
else:
# No image embeddings: Use only token embeddings and the original attention mask
combined_embeddings = token_embeddings
new_attention_mask = attention_mask
if labels is not None:
# Labels should match the sequence length of combined embeddings
# If labels correspond only to text tokens, pad them to match the new sequence length
if image_embeddings is not None:
label_padding = torch.full(
(batch_size, image_embedding_length), -100, device=labels.device # Use -100 for ignore index
)
new_labels = torch.cat([label_padding,labels], dim=1)
else:
new_labels = labels
else:
new_labels = labels
# Pass the combined embeddings through the PHI2 model with the (updated or original) attention mask
        outputs = self.phi2_model(
            inputs_embeds=combined_embeddings,
            attention_mask=new_attention_mask,
            labels=new_labels,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            **kwargs,
        )
return outputs
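    # Generation path: the image embeddings are projected and concatenated to the
    # prompt embeddings once up front, then decoding is delegated to the wrapped
    # model's generate() via inputs_embeds.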
def prepare_inputs_for_generation(self, input_ids, attention_mask=None, image_embeddings=None, **kwargs):
# Generate inputs with projections where necessary
if image_embeddings is not None:
projected_image_embeddings = self.mlp_projection(image_embeddings)
projected_image_embeddings = projected_image_embeddings.unsqueeze(0)
token_embeddings = self.phi2_model.get_input_embeddings()(input_ids)
combined_embeddings = torch.cat([projected_image_embeddings, token_embeddings], dim=1)
            image_embedding_length = projected_image_embeddings.size(1)
batch_size, text_sequence_length = attention_mask.shape
# Extend attention mask for image embeddings (ones for image embedding positions)
            new_attention_mask = torch.cat(
                [torch.ones((batch_size, image_embedding_length), device=attention_mask.device), attention_mask],
                dim=1,
            )
else:
combined_embeddings = self.phi2_model.get_input_embeddings()(input_ids)
new_attention_mask = attention_mask
return {
"inputs_embeds": combined_embeddings,
"attention_mask": new_attention_mask,
**kwargs
}
def generate(self, input_ids, attention_mask=None, image_embeddings=None, **kwargs):
self.eval() # Set to evaluation mode
# Prepare inputs for generation
inputs = self.prepare_inputs_for_generation(input_ids, attention_mask, image_embeddings, **kwargs)
# Use the model's built-in generate method
return self.phi2_model.generate(**inputs)
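# Despite its name, this helper only marks the MLP projector as trainable and wires
# it to the language model; the LoRA adapters themselves are attached afterwards via
# PeftModel.from_pretrained.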
def create_phi2_model_with_lora(mlp_projection,lan_model):
for param in mlp_projection.parameters():
param.requires_grad = True
# Return PHI2 model with MLP projection
return PHI2WithMLP(lan_model, mlp_projection)
model_embedding_dim = model.config.hidden_size  # embedding width of the language model
input_dim = 768                    # CLIP ViT-B/32 patch-embedding size
output_dim = model_embedding_dim   # target size: the LM's token-embedding space
hidden_dim = 1024                  # hidden layer width of the MLP projector
mlp_projection = MLPProjection(input_dim, output_dim, hidden_dim, depth=2).to(device)
combined_model = create_phi2_model_with_lora(mlp_projection, model)
peft_model_id = "Kartheekb7/results1"
loaded_model = PeftModel.from_pretrained(combined_model, peft_model_id)
loaded_mlp_weights = torch.load("mlp_projection_weights.pth",map_location=torch.device('cpu'))
loaded_model.base_model.model.mlp_projection.load_state_dict(loaded_mlp_weights)
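# Rough shape flow at inference (assuming CLIP ViT-B/32 at 224x224, as noted above):
#   image  -> CLIP patch features (49, 768) -> MLP projector -> (49, hidden) -> unsqueeze -> (1, 49, hidden)
#   prompt -> token ids (1, T) -> token embeddings (1, T, hidden)
#   concat -> (1, 49 + T, hidden), fed to the language model as inputs_embeds
# where hidden = model.config.hidden_size.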
# Create a new GenerationConfig with desired settings
generation_config = GenerationConfig(max_new_tokens=128, temperature=0.01, top_p=1)
loaded_model.generation_config = generation_config
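# Vision encoder (CLIP ViT-B/32) supplying the patch features used above, and a
# Whisper-small ASR pipeline for transcribing voice messages.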
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
pipe = pipeline(
"automatic-speech-recognition",
model="openai/whisper-small",
chunk_length_s=30,
device=device,
)
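# Helpers that inline uploaded media into the chat transcript as base64 data URLs,
# so the image/audio the user sent is echoed back in the chat history.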
def image_to_base64(image_path):
with open(image_path, 'rb') as img:
encoded_string = base64.b64encode(img.read())
return encoded_string.decode('utf-8')
def audio_to_base64(audio_path):
with open(audio_path, 'rb') as audio_file:
encoded_string = base64.b64encode(audio_file.read())
return encoded_string.decode('utf-8')
def get_clip_embedding(image_path):
image = Image.open(image_path)
inputs = processor(images=image, return_tensors="pt", padding=True)
with torch.no_grad():
image_features = clip_model.get_image_features(**inputs)
image_forward_outs = clip_model.vision_model(**inputs, output_hidden_states=True)
image_features = feature_select(image_forward_outs)
image_embedding = image_features.squeeze(0)
return image_embedding
def process_text(text_input):
# Tokenize text input
input_encoding = tokenizer(
text_input,
return_tensors='pt',
# padding='max_length',
truncation=True,
# max_length=256-49 # Set this to match your model's input size
)
return input_encoding
def audio_process(path_file):
result = pipe(path_file,generate_kwargs={"language": "english"})
return result['text']
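# Chat handler: transcribe audio (if any) and append it to the text prompt, embed
# the image (if any) with CLIP, then generate a reply with the LoRA-adapted model.
# The raw media are echoed into the history as data URLs for display.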
def chat(message, history, audio=None, image=None):
image_embedding = None
response = "" if message is None else message
input_message = message
if audio is not None:
print("audio")
base64_audio = audio_to_base64(audio)
data_url = f"data:audio/wav;base64,{base64_audio}"
input_message += f"<audio controls><source src='{data_url}' type='audio/wav'>Your browser does not support the audio element.</audio>"
response += audio_process(audio)
print("audio_processed")
if image is not None:
        base64_image = image_to_base64(image)  # renamed to avoid shadowing the base64 module
        data_url = f"data:image/jpeg;base64,{base64_image}"
input_message += f" ![]({data_url})"
image_embedding = get_clip_embedding(image)
print("image_processed")
input_encoding = process_text(response)
print("inference start")
    outputs = loaded_model.generate(
        **input_encoding,
        image_embeddings=image_embedding,
        max_new_tokens=64,
        temperature=0.01,
        top_p=1,
    )
# Decode output to text
print("inference end")
response_new = tokenizer.decode(outputs[0], skip_special_tokens=True)
history.append((input_message, response_new))
return history
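# Gradio UI: a chatbot pane, a text box with a send button, and two buttons that
# reveal the (initially hidden) microphone/upload and image inputs.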
with gr.Blocks() as iface:
chatbot = gr.Chatbot()
state = gr.State([])
with gr.Row():
with gr.Column(scale=20):
msg = gr.Textbox(show_label=False, placeholder="Type a message...", container=False)
with gr.Column(min_width=70, scale=1):
submit = gr.Button("➤", variant="primary")
with gr.Column(min_width=50, scale=1):
audio_btn = gr.Button("🎤")
with gr.Column(min_width=50, scale=1):
file_btn = gr.Button("📎")
audio = gr.Audio(sources=["microphone","upload"], type="filepath", visible=False)
image = gr.Image(type="filepath", visible=False)
def process_input(message, history, audio_file, image_file):
history = chat(message, history, audio_file, image_file)
return "", history
submit.click(process_input, inputs=[msg, state, audio, image], outputs=[msg, chatbot])
msg.submit(process_input, inputs=[msg, state, audio, image], outputs=[msg, chatbot])
    # Visibility is keyed off the component's current value: an empty component
    # is shown, and one that already holds a file is hidden again.
    def toggle_audio(audio_value):
        return gr.update(visible=not audio_value)
    def toggle_image(image_value):
        return gr.update(visible=not image_value)
audio_btn.click(toggle_audio, inputs=[audio], outputs=[audio])
file_btn.click(toggle_image, inputs=[image], outputs=[image])
iface.launch()