import gradio as gr
import torch
from PIL import Image
from torchvision import transforms
from transformers import T5Tokenizer, T5ForConditionalGeneration, ViTFeatureExtractor, ViTModel
from torch import nn
from typing import Tuple
class Encoder(nn.Module):
    def __init__(self, pretrained_model):
        """
        Implements the Encoder.

        Args:
            pretrained_model (str): name of the pretrained model
        """
        super(Encoder, self).__init__()
        self.encoder = ViTModel.from_pretrained(pretrained_model)

    def forward(self, input):
        out = self.encoder(pixel_values=input)
        return out
class Decoder(nn.Module):
    def __init__(self, pretrained_model, encoder_modeldim):
        """
        Implements the Decoder.

        Args:
            pretrained_model (str): name of the pretrained model
            encoder_modeldim (int): hidden dimension of the encoder output
        """
        super(Decoder, self).__init__()
        self.decoder = T5ForConditionalGeneration.from_pretrained(pretrained_model)
        # Projects the encoder hidden states to the decoder's hidden size when they differ.
        self.linear = nn.Linear(encoder_modeldim, self.decoder.model_dim, bias=False)
        self.encoder_modeldim = encoder_modeldim

    def forward(self, output_encoder, targets, decoder_ids=None):
        if self.decoder.model_dim != self.encoder_modeldim:
            print(f"Changed model hidden dimension from {self.encoder_modeldim} to {self.decoder.model_dim}")
            output_encoder.last_hidden_state = self.linear(output_encoder.last_hidden_state)
            print(output_encoder.last_hidden_state.shape)
        # Validation/Testing: decode step by step from the provided decoder input ids
        if decoder_ids is not None:
            out = self.decoder(encoder_outputs=output_encoder, decoder_input_ids=decoder_ids)
        # Training: pass the labels and let the model compute the loss
        else:
            out = self.decoder(encoder_outputs=output_encoder, labels=targets)
        return out
class EncoderDecoder(nn.Module):
    def __init__(self, pretrained_model: Tuple[str, str], encoder_dmodel=768, eos_token_id=None, pad_token_id=None):
        """
        Implements a model that combines the Encoder and the Decoder.

        Args:
            pretrained_model (tuple): names of the pretrained encoder and decoder models
            encoder_dmodel (int): hidden dimension of the encoder output
            eos_token_id (int): token used for end of sentence
            pad_token_id (int): token used for padding
        """
        super(EncoderDecoder, self).__init__()
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        self.encoder = Encoder(pretrained_model[0])
        self.encoder_dmodel = encoder_dmodel
        # Freeze parameters from encoder
        # for p in self.encoder.parameters():
        #     p.requires_grad = False
        self.decoder = Decoder(pretrained_model[1], self.encoder_dmodel)
        self.decoder_start_token_id = self.decoder.decoder.config.decoder_start_token_id

    def forward(self, images=None, targets=None, decoder_ids=None):
        output_encoder = self.encoder(images)
        out = self.decoder(output_encoder, targets, decoder_ids)
        return out
# Model loading and setting up the device
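# NOTE: "model_vit_ai.pt" is assumed to be a full pickled EncoderDecoder instance
# (saved with torch.save(model)), which is why the class definitions above must be available here.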
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.load("model_vit_ai.pt", map_location=device)
model.to(device)
# Tokenizer and Feature Extractor
tokenizer = T5Tokenizer.from_pretrained('t5-base')
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
# Define the image preprocessing
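# Resize to the ViT input resolution and normalize with the feature extractor's
# mean/std so inputs match the statistics the encoder was pretrained with.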
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
])
def preprocess_image(image):
    image = Image.fromarray(image.astype('uint8'), 'RGB')
    image = transform(image)
    return image.unsqueeze(0)
def generate_caption(image):
    model.eval()
    with torch.no_grad():
        image_tensor = preprocess_image(image).to(device)
        decoder_input_ids = torch.full((1, 1), model.decoder_start_token_id, dtype=torch.long, device=device)
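        # Greedy decoding: repeatedly feed the tokens generated so far back
        # through the model, append the argmax token, and stop at EOS (or after 50 steps).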
        for _ in range(50):
            outputs = model(images=image_tensor, decoder_ids=decoder_input_ids)
            next_token_logits = outputs.logits[:, -1, :]
            next_token_id = next_token_logits.argmax(1, keepdim=True)
            decoder_input_ids = torch.cat([decoder_input_ids, next_token_id], dim=-1)
            if torch.eq(next_token_id, tokenizer.eos_token_id).all():
                break
        caption = tokenizer.decode(decoder_input_ids.squeeze(0), skip_special_tokens=True)
        return caption
sample_images = [
    "sample_image1.jpg",
    "sample_image2.jpg",
    "sample_image3.jpg"
]
# Define Gradio interface
interface = gr.Interface(
    fn=generate_caption,
    inputs=gr.Image(type="numpy", label="Upload an image or take a photo"),
    outputs="text",
    examples=sample_images,
    title="Image Captioning Model",
    description="Upload an image, select a sample image, or use your webcam to take a photo and generate a caption."
)
# Run the interface
interface.launch(debug=True)