import gradio as gr
import torch
from PIL import Image
from torchvision import transforms
from transformers import T5Tokenizer, T5ForConditionalGeneration, ViTModel, ViTImageProcessor
from torch import nn
from typing import Tuple
import sentencepiece  # required by the SentencePiece-based T5Tokenizer

class Encoder(nn.Module):
    def __init__(self, pretrained_model):
        """
        Implements the Encoder.

        Args:
            pretrained_model (str): name of the pretrained ViT model
        """
        super(Encoder, self).__init__()
        self.encoder = ViTModel.from_pretrained(pretrained_model)

    def forward(self, pixel_values):
        # Encode the batch of preprocessed images with the ViT backbone
        out = self.encoder(pixel_values=pixel_values)
        return out

class Decoder(nn.Module):
    def __init__(self, pretrained_model, encoder_modeldim):
        """
        Implements the Decoder.

        Args:
            pretrained_model (str): name of the pretrained T5 model
            encoder_modeldim (int): hidden dimension of the encoder output
        """
        super(Decoder, self).__init__()
        self.decoder = T5ForConditionalGeneration.from_pretrained(pretrained_model)
        self.linear = nn.Linear(self.decoder.model_dim, encoder_modeldim, bias=False)
        self.encoder_modeldim = encoder_modeldim

    def forward(self, output_encoder, targets, decoder_ids=None):
        # Project the encoder output if the encoder and decoder hidden dimensions differ
        if self.decoder.model_dim != self.encoder_modeldim:
            print(f"Changed model hidden dimension from {self.encoder_modeldim} to {self.decoder.model_dim}")
            output_encoder = self.linear(output_encoder)
            print(output_encoder.shape)
        # Validation/testing: condition on the decoder input ids generated so far
        if decoder_ids is not None:
            out = self.decoder(encoder_outputs=output_encoder, decoder_input_ids=decoder_ids)
        # Training: T5 builds the decoder inputs by shifting the labels
        else:
            out = self.decoder(encoder_outputs=output_encoder, labels=targets)
        return out

class EncoderDecoder(nn.Module):
    def __init__(self, pretrained_model: Tuple[str, str], encoder_dmodel=768, eos_token_id=None, pad_token_id=None):
        """
        Implements a model that combines the Encoder and the Decoder.

        Args:
            pretrained_model (tuple): names of the pretrained encoder and decoder models
            encoder_dmodel (int): hidden dimension of the encoder output
            eos_token_id (torch.long): token used for end of sentence
            pad_token_id (torch.long): token used for padding
        """
        super(EncoderDecoder, self).__init__()
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        self.encoder = Encoder(pretrained_model[0])
        self.encoder_dmodel = encoder_dmodel
        # Freeze parameters from encoder
        # for p in self.encoder.parameters():
        #     p.requires_grad = False
        self.decoder = Decoder(pretrained_model[1], self.encoder_dmodel)
        self.decoder_start_token_id = self.decoder.decoder.config.decoder_start_token_id

    def forward(self, images=None, targets=None, decoder_ids=None):
        output_encoder = self.encoder(images)
        out = self.decoder(output_encoder, targets, decoder_ids)
        return out

# Model loading and setting up the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
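# torch.load deserializes the full pickled EncoderDecoder object, so compatible
# class definitions (the ones above) must be available when unpickling.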
model = torch.load("model_vit_ai.pt", map_location=device)
model.to(device)
# Tokenizer and Feature Extractor
tokenizer = T5Tokenizer.from_pretrained('t5-base')
feature_extractor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')
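# The image processor is only used for its normalization statistics
# (image_mean / image_std), which parameterize the torchvision transform below.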
# Define the image preprocessing
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
])

def preprocess_image(image):
    # Gradio passes the image as a NumPy array; convert to PIL, apply the transform, and add a batch dimension
    image = Image.fromarray(image.astype('uint8'), 'RGB')
    image = transform(image)
    return image.unsqueeze(0)

def generate_caption(image):
    model.eval()
    with torch.no_grad():
        image_tensor = preprocess_image(image).to(device)
        # Greedy decoding: start from the decoder start token and append the most likely next token
        decoder_input_ids = torch.full((1, 1), model.decoder_start_token_id, dtype=torch.long, device=device)
        for _ in range(50):
            outputs = model(images=image_tensor, decoder_ids=decoder_input_ids)
            next_token_logits = outputs.logits[:, -1, :]
            next_token_id = next_token_logits.argmax(1, keepdim=True)
            decoder_input_ids = torch.cat([decoder_input_ids, next_token_id], dim=-1)
            # Stop once the end-of-sequence token has been generated
            if torch.eq(next_token_id, tokenizer.eos_token_id).all():
                break
    caption = tokenizer.decode(decoder_input_ids.squeeze(0), skip_special_tokens=True)
    return caption

sample_images = [
    "sample_image1.jpg",
    "sample_image2.jpg",
    "sample_image3.jpg"
]
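# Build the Gradio interface; the sample image paths above are assumed to ship alongside app.py.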
interface = gr.Interface(
    fn=generate_caption,
    inputs="image",  # Specify the input type as "image"
    outputs="text",
    examples=sample_images,
    title="Image Captioning Model",
    description="Upload an image, select a sample image, or use your webcam to take a photo and generate a caption."
)

# Run the interface
interface.launch(debug=True)