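# Gradio Space: multimodal chat demo.
# Pipeline: CLIP ViT-B/32 patch embeddings -> MLP projector -> LoRA-adapted
# Llama-3.2-1B-Instruct for text generation, plus Whisper-small for speech-to-text.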
import base64
import os
import random

import gradio as gr
import torch
import torch.nn as nn
from PIL import Image
from peft import PeftModel, PeftConfig
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    CLIPModel,
    CLIPProcessor,
    GenerationConfig,
    pipeline,
)

# Retrieve the Hugging Face token from the Space's secrets (needed for the gated Llama checkpoint)
hf_token = os.environ.get("HUGGINGFACE_HUB_TOKEN")

model_name = "meta-llama/Llama-3.2-1B-Instruct"
print("loading_model")
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float32,
    trust_remote_code=True,
    token=hf_token,
)
print("loaded_model")

tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
tokenizer.pad_token = tokenizer.eos_token

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
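# Llama-3.2-1B-Instruct has a hidden size of 2048; that is the target dimension
# the image-patch embeddings are projected into further below.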
select_feature = 'patch'

def feature_select(image_forward_outs):
    image_features = image_forward_outs.hidden_states[-1]
    if select_feature == 'patch':
        image_features = image_features[:, 1:]  # Skip CLS token if selecting patch
    elif select_feature == 'cls_patch':
        image_features = image_features  # Keep CLS + patch tokens
    else:
        raise ValueError(f'Unexpected select feature: {select_feature}')
    return image_features
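# For CLIP ViT-B/32 at 224x224 input this yields 49 patch tokens of dimension 768
# (the CLS token at position 0 is dropped when select_feature == 'patch').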
class MLPProjection(nn.Module):
    def __init__(self, input_dim, output_dim, hidden_dim=768, depth=2):
        super(MLPProjection, self).__init__()
        modules = []
        modules.append(nn.Linear(input_dim, hidden_dim, bias=False))
        for _ in range(1, depth):
            modules.append(nn.GELU())
            modules.append(nn.Linear(hidden_dim, output_dim, bias=False))
        self.mlp = nn.Sequential(*modules)

    def forward(self, x):
        return self.mlp(x)
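# PHI2WithMLP wraps the language model and, LLaVA-style, prepends the projected
# image-patch embeddings to the text token embeddings before the forward pass.
# Despite the name, the model wrapped here is Llama-3.2-1B-Instruct, not Phi-2.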
class PHI2WithMLP(nn.Module):
    def __init__(self, phi2_model, mlp_projection):
        super(PHI2WithMLP, self).__init__()
        self.phi2_model = phi2_model
        self.mlp_projection = mlp_projection
        self.config = phi2_model.config

    def forward(self, image_embeddings=None,
                inputs_embeds=None,
                input_ids=None,
                attention_mask=None,
                labels=None,
                output_attentions=False,
                output_hidden_states=False,
                **kwargs):  # Catch any additional arguments
        if input_ids is not None:
            token_embeddings = self.phi2_model.get_input_embeddings()(input_ids)
        elif inputs_embeds is not None:
            token_embeddings = inputs_embeds
        else:
            raise ValueError("You must provide either input_ids or inputs_embeds.")

        if image_embeddings is not None:
            # Apply the MLP to map image embeddings into the text embedding space
            projected_image_embeddings = self.mlp_projection(image_embeddings).to(device=token_embeddings.device)
            # Sequence length contributed by the image embeddings
            image_embedding_length = projected_image_embeddings.size(1)
            batch_size, text_sequence_length = attention_mask.shape
            # Extend the attention mask with ones for the image embedding positions
            new_attention_mask = torch.cat(
                [torch.ones((batch_size, image_embedding_length), device=attention_mask.device), attention_mask], dim=1
            )
            # Concatenate image and token embeddings along the sequence dimension
            combined_embeddings = torch.cat([projected_image_embeddings, token_embeddings], dim=1)
        else:
            # No image embeddings: use only token embeddings and the original attention mask
            combined_embeddings = token_embeddings
            new_attention_mask = attention_mask

        if labels is not None and image_embeddings is not None:
            # Labels must match the combined sequence length; pad the image positions
            # with -100 so they are ignored by the loss
            label_padding = torch.full(
                (batch_size, image_embedding_length), -100, device=labels.device
            )
            new_labels = torch.cat([label_padding, labels], dim=1)
        else:
            new_labels = labels

        # Pass the combined embeddings through the wrapped model with the (updated or original) attention mask
        outputs = self.phi2_model(
            inputs_embeds=combined_embeddings,
            attention_mask=new_attention_mask,
            labels=new_labels,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            **kwargs,
        )
        return outputs
    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, image_embeddings=None, **kwargs):
        # Build the inputs_embeds / attention_mask pair consumed by generate()
        if image_embeddings is not None:
            projected_image_embeddings = self.mlp_projection(image_embeddings)
            projected_image_embeddings = projected_image_embeddings.unsqueeze(0)  # Add a batch dimension
            token_embeddings = self.phi2_model.get_input_embeddings()(input_ids)
            combined_embeddings = torch.cat([projected_image_embeddings, token_embeddings], dim=1)
            image_embedding_length = projected_image_embeddings.size(1)
            batch_size, text_sequence_length = attention_mask.shape
            # Extend the attention mask with ones for the image embedding positions
            new_attention_mask = torch.cat(
                [torch.ones((batch_size, image_embedding_length), device=attention_mask.device), attention_mask], dim=1
            )
        else:
            combined_embeddings = self.phi2_model.get_input_embeddings()(input_ids)
            new_attention_mask = attention_mask
        return {
            "inputs_embeds": combined_embeddings,
            "attention_mask": new_attention_mask,
            **kwargs,
        }

    def generate(self, input_ids, attention_mask=None, image_embeddings=None, **kwargs):
        self.eval()  # Set to evaluation mode
        # Prepare inputs for generation
        inputs = self.prepare_inputs_for_generation(input_ids, attention_mask, image_embeddings, **kwargs)
        # Use the wrapped model's built-in generate method
        return self.phi2_model.generate(**inputs)
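# Note: prepare_inputs_for_generation unsqueezes the image embeddings, so image
# inputs are expected unbatched (shape [num_patches, clip_dim]) and generation
# with an image effectively runs with batch size 1.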
def create_phi2_model_with_lora(mlp_projection, lan_model):
    # Ensure the projector weights are trainable before wrapping the language model
    for param in mlp_projection.parameters():
        param.requires_grad = True
    # Return the language model wrapped with the MLP projection
    return PHI2WithMLP(lan_model, mlp_projection)

model_embedding_dim = model.config.hidden_size  # Depends on the language model architecture

# Build the projector: CLIP image embeddings -> language-model embedding space
input_dim = 768                   # Dimension of the CLIP ViT-B/32 patch embeddings
output_dim = model_embedding_dim  # Target dimension of the text embeddings
hidden_dim = 1024                 # Hidden layer dimension of the MLP
mlp_projection = MLPProjection(input_dim, output_dim, hidden_dim, depth=2).to(device)

combined_model = create_phi2_model_with_lora(mlp_projection, model)

# Load the LoRA adapter and the trained projector weights
peft_model_id = "Kartheekb7/results1"
loaded_model = PeftModel.from_pretrained(combined_model, peft_model_id)
loaded_mlp_weights = torch.load("mlp_projection_weights.pth", map_location=torch.device('cpu'))
loaded_model.base_model.model.mlp_projection.load_state_dict(loaded_mlp_weights)

# Create a new GenerationConfig with the desired settings
generation_config = GenerationConfig(max_new_tokens=128, temperature=0.01, top_p=1)
loaded_model.generation_config = generation_config

# Load the CLIP vision encoder/processor and the Whisper ASR pipeline
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
pipe = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-small",
    chunk_length_s=30,
    device=device,
)
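# Whisper-small transcribes recorded or uploaded audio; chunk_length_s=30 splits
# longer clips into 30-second chunks, and transcription is forced to English in
# audio_process() via generate_kwargs={"language": "english"}.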
def image_to_base64(image_path):
    with open(image_path, 'rb') as img:
        encoded_string = base64.b64encode(img.read())
    return encoded_string.decode('utf-8')

def audio_to_base64(audio_path):
    with open(audio_path, 'rb') as audio_file:
        encoded_string = base64.b64encode(audio_file.read())
    return encoded_string.decode('utf-8')

def get_clip_embedding(image_path):
    image = Image.open(image_path)
    inputs = processor(images=image, return_tensors="pt", padding=True)
    with torch.no_grad():
        image_features = clip_model.get_image_features(**inputs)
        image_forward_outs = clip_model.vision_model(**inputs, output_hidden_states=True)
        image_features = feature_select(image_forward_outs)
    image_embedding = image_features.squeeze(0)
    return image_embedding

def process_text(text_input):
    # Tokenize the text input
    input_encoding = tokenizer(
        text_input,
        return_tensors='pt',
        # padding='max_length',
        truncation=True,
        # max_length=256-49  # Set this to match your model's input size
    )
    return input_encoding

def audio_process(path_file):
    result = pipe(path_file, generate_kwargs={"language": "english"})
    return result['text']
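# The base64 helpers turn uploaded files into data: URLs so the audio (and image)
# can be rendered inline in the chat history instead of being referenced by path.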
def chat(message, history, audio=None, image=None):
    image_embedding = None
    response = "" if message is None else message
    input_message = "" if message is None else message
    if audio is not None:
        print("audio")
        # Embed the recorded audio as an inline data URL so it shows up in the chat history
        base64_audio = audio_to_base64(audio)
        data_url = f"data:audio/wav;base64,{base64_audio}"
        input_message += f"<audio controls><source src='{data_url}' type='audio/wav'>Your browser does not support the audio element.</audio>"
        # Transcribe the audio and append the text to the prompt
        response += audio_process(audio)
        print("audio_processed")
    if image is not None:
        # Embed the image as an inline data URL, mirroring the audio branch
        base64_image = image_to_base64(image)
        data_url = f"data:image/jpeg;base64,{base64_image}"
        input_message += f"<img src='{data_url}' alt='uploaded image'/>"
        image_embedding = get_clip_embedding(image)
        print("image_processed")
    input_encoding = process_text(response)
    print("inference start")
    outputs = loaded_model.generate(**input_encoding, image_embeddings=image_embedding, max_new_tokens=64, temperature=0.01, top_p=1)
    # Decode the generated output to text
    print("inference end")
    response_new = tokenizer.decode(outputs[0], skip_special_tokens=True)
    history.append((input_message, response_new))
    return history
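# End-to-end flow: spoken input is transcribed and appended to the typed message,
# an uploaded image is turned into CLIP patch embeddings, and both are passed to
# the projector + LoRA-adapted Llama model to produce the reply shown in the chat.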
with gr.Blocks() as iface:
    chatbot = gr.Chatbot()
    state = gr.State([])

    with gr.Row():
        with gr.Column(scale=20):
            msg = gr.Textbox(show_label=False, placeholder="Type a message...", container=False)
        with gr.Column(min_width=70, scale=1):
            submit = gr.Button("➤", variant="primary")
        with gr.Column(min_width=50, scale=1):
            audio_btn = gr.Button("🎤")
        with gr.Column(min_width=50, scale=1):
            file_btn = gr.Button("📎")

    # Hidden inputs revealed by the 🎤 and 📎 buttons
    audio = gr.Audio(sources=["microphone", "upload"], type="filepath", visible=False)
    image = gr.Image(type="filepath", visible=False)

    def process_input(message, history, audio_file, image_file):
        history = chat(message, history, audio_file, image_file)
        # Clear the textbox and update both the visible chatbot and the stored state
        return "", history, history

    submit.click(process_input, inputs=[msg, state, audio, image], outputs=[msg, chatbot, state])
    msg.submit(process_input, inputs=[msg, state, audio, image], outputs=[msg, chatbot, state])

    def toggle_audio(audio_value):
        # Receives the component's current value: show the input while it is empty,
        # hide it once a recording or upload is present
        return gr.update(visible=not audio_value)

    def toggle_image(image_value):
        return gr.update(visible=not image_value)

    audio_btn.click(toggle_audio, inputs=[audio], outputs=[audio])
    file_btn.click(toggle_image, inputs=[image], outputs=[image])

iface.launch()
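# In a Hugging Face Space this file is typically saved as app.py; the Space needs
# HUGGINGFACE_HUB_TOKEN set as a secret and mlp_projection_weights.pth present in
# the repository for the projector weights to load.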