# Hugging Face Spaces page header captured with this file ("Spaces: Sleeping"); not part of the program.
from typing import Tuple

import gradio as gr
import sentencepiece  # noqa: F401  (not referenced directly; kept from the original imports)
import torch
from PIL import Image
from torch import nn
from torchvision import transforms
from transformers import (
    T5ForConditionalGeneration,
    T5Tokenizer,
    ViTImageProcessor,
    ViTModel,
)
class Encoder(nn.Module):
    """Image encoder: a thin wrapper around a pretrained ``ViTModel``."""

    def __init__(self, pretrained_model):
        """
        Load the pretrained vision backbone.

        Args:
            pretrained_model (str): name of the pretrained model
        """
        super().__init__()
        self.encoder = ViTModel.from_pretrained(pretrained_model)

    def forward(self, input):
        """Feed ``input`` to the backbone as ``pixel_values`` and return its output unchanged."""
        return self.encoder(pixel_values=input)
class Decoder(nn.Module):
    """Text decoder: a pretrained ``T5ForConditionalGeneration`` fed with external encoder states."""

    def __init__(self, pretrained_model, encoder_modeldim):
        """
        Implements the Decoder.

        Args:
            pretrained_model (str): name of the pretrained model
            encoder_modeldim (int): hidden dimension of the encoder output passed to ``forward``
        """
        super().__init__()
        self.decoder = T5ForConditionalGeneration.from_pretrained(pretrained_model)
        # Projection maps encoder features INTO the decoder's hidden size, so the
        # input width is the encoder dimension and the output width is the
        # decoder dimension (the original had the two swapped).
        self.linear = nn.Linear(encoder_modeldim, self.decoder.model_dim, bias=False)
        self.encoder_modeldim = encoder_modeldim

    def forward(self, output_encoder, targets, decoder_ids=None):
        """
        Run the decoder on top of the encoder output.

        Args:
            output_encoder: encoder output; either a tensor of hidden states or a
                tuple-like model output whose first element is the hidden states
            targets: label ids, used only when ``decoder_ids`` is None (training)
            decoder_ids: decoder input ids (validation/testing); when given,
                ``targets`` is ignored

        Returns:
            Whatever the wrapped decoder returns for the chosen mode.
        """
        if self.decoder.model_dim != self.encoder_modeldim:
            print(f"Changed model hidden dimension from {self.encoder_modeldim} to {self.decoder.model_dim}")
            # nn.Linear needs the hidden-state tensor, not the whole output object.
            if torch.is_tensor(output_encoder):
                hidden_states = output_encoder
            else:
                hidden_states = output_encoder[0]
            hidden_states = self.linear(hidden_states)
            print(hidden_states.shape)
            # The decoder reads the hidden states from index 0 of
            # ``encoder_outputs``, so hand the projected tensor back as a 1-tuple.
            output_encoder = (hidden_states,)

        # Validation/Testing
        if decoder_ids is not None:
            return self.decoder(encoder_outputs=output_encoder, decoder_input_ids=decoder_ids)
        # Training
        return self.decoder(encoder_outputs=output_encoder, labels=targets)
class EncoderDecoder(nn.Module):
    """Captioning model that chains ``Encoder`` (images) into ``Decoder`` (text)."""

    def __init__(self, pretrained_model: Tuple[str], encoder_dmodel=768, eos_token_id=None, pad_token_id=None):
        """
        Build the encoder and the decoder from their pretrained checkpoints.

        Args:
            pretrained_model (tuple): pretrained model names; index 0 is used for
                the encoder and index 1 for the decoder
            encoder_dmodel (int): hidden dimension of the encoder output
            eos_token_id (torch.long): token used for end of sentence
            pad_token_id (torch.long): token used for padding
        """
        super().__init__()
        encoder_name = pretrained_model[0]
        decoder_name = pretrained_model[1]

        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id

        # The encoder is left trainable; the original carried a disabled snippet
        # that would have frozen its parameters (requires_grad=False).
        self.encoder = Encoder(encoder_name)
        self.encoder_dmodel = encoder_dmodel

        self.decoder = Decoder(decoder_name, self.encoder_dmodel)
        # Start token comes from the wrapped decoder model's own config.
        self.decoder_start_token_id = self.decoder.decoder.config.decoder_start_token_id

    def forward(self, images = None, targets = None, decoder_ids = None):
        """Encode ``images`` and pass the result, ``targets`` and ``decoder_ids`` to the decoder."""
        return self.decoder(self.encoder(images), targets, decoder_ids)
# Pick the compute device: CUDA when torch reports it as available, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Load the serialized model object and place it on the chosen device.
# NOTE(review): torch.load unpickles the file, so the model classes defined in
# this module must be importable and the checkpoint must come from a trusted source.
model = torch.load("model_vit_ai.pt", map_location=device)
model.to(device)

# Text tokenizer and image processor (the latter supplies the normalization statistics below).
tokenizer = T5Tokenizer.from_pretrained('t5-base')
feature_extractor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')

# Image preprocessing: resize to 224x224, convert to a tensor, then normalize
# with the image processor's mean/std.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=feature_extractor.image_mean,
            std=feature_extractor.image_std,
        ),
    ]
)
def preprocess_image(image):
    """
    Turn a numpy image array into a normalized, batched tensor for the model.

    Args:
        image: numpy array holding the image pixels (cast to uint8 here)

    Returns:
        The result of the module-level ``transform`` with a leading batch
        dimension of size 1 added.
    """
    pixels = image.astype('uint8')
    # Let PIL infer the mode from the array shape and then convert to RGB,
    # rather than forcing an RGB interpretation onto the buffer: grayscale and
    # RGBA arrays are then handled instead of being misread or rejected.
    pil_image = Image.fromarray(pixels).convert('RGB')
    return transform(pil_image).unsqueeze(0)
def generate_caption(image, max_length=50):
    """
    Generate a caption for ``image`` with greedy (argmax) decoding.

    Args:
        image: numpy image array as delivered by the Gradio image input
        max_length (int): maximum number of decoding steps (default 50)

    Returns:
        str: the decoded caption with special tokens removed

    Raises:
        gr.Error: if no image was provided
    """
    if image is None:
        # Surface a readable message in the UI instead of an AttributeError.
        raise gr.Error("Please provide an image to caption.")

    model.eval()
    with torch.no_grad():
        image_tensor = preprocess_image(image).to(device)
        # The encoder output does not depend on the tokens generated so far, so
        # compute it once instead of re-running the encoder at every step.
        encoder_output = model.encoder(image_tensor)
        decoder_input_ids = torch.full(
            (1, 1), model.decoder_start_token_id, dtype=torch.long, device=device
        )
        for _ in range(max_length):
            # Same call the model's forward makes: (encoder output, targets, decoder ids).
            outputs = model.decoder(encoder_output, None, decoder_input_ids)
            next_token_logits = outputs.logits[:, -1, :]
            next_token_id = next_token_logits.argmax(1, keepdim=True)
            decoder_input_ids = torch.cat([decoder_input_ids, next_token_id], dim=-1)
            # Stop as soon as the end-of-sequence token has been produced.
            if torch.eq(next_token_id, tokenizer.eos_token_id).all():
                break
    return tokenizer.decode(decoder_input_ids.squeeze(0), skip_special_tokens=True)
# Example image files offered in the UI.
sample_images = ["sample_image1.jpg", "sample_image2.jpg", "sample_image3.jpg"]

# Wire the captioning function into a Gradio interface: image in, text out.
interface = gr.Interface(
    title="Image Captioning Model",
    description="Upload an image, select a sample image, or use your webcam to take a photo and generate a caption.",
    fn=generate_caption,
    inputs="image",
    outputs="text",
    examples=sample_images,
)

# Start the app with debug mode enabled.
interface.launch(debug=True)