from typing import Tuple

import gradio as gr
import torch
from PIL import Image
from torchvision import transforms
from transformers import (
    T5ForConditionalGeneration,
    T5Tokenizer,
    ViTFeatureExtractor,
    ViTModel,
)
from torch import nn
class Encoder(nn.Module):
    def __init__(self, pretrained_model):
        """
        Implements the Encoder.

        Args:
            pretrained_model (str): name of the pretrained ViT model
        """
        super(Encoder, self).__init__()
        self.encoder = ViTModel.from_pretrained(pretrained_model)

    def forward(self, input):
        # Returns the ViT outputs (including last_hidden_state) for the given pixel values.
        out = self.encoder(pixel_values=input)
        return out
class Decoder(nn.Module):
    def __init__(self, pretrained_model, encoder_modeldim):
        """
        Implements the Decoder.

        Args:
            pretrained_model (str): name of the pretrained T5 model
            encoder_modeldim (int): hidden dimension of the encoder output
        """
        super(Decoder, self).__init__()
        self.decoder = T5ForConditionalGeneration.from_pretrained(pretrained_model)
        # Projects encoder hidden states to the decoder's hidden size when the two differ.
        self.linear = nn.Linear(encoder_modeldim, self.decoder.model_dim, bias=False)
        self.encoder_modeldim = encoder_modeldim

    def forward(self, output_encoder, targets, decoder_ids=None):
        if self.decoder.model_dim != self.encoder_modeldim:
            print(f"Changed model hidden dimension from {self.encoder_modeldim} to {self.decoder.model_dim}")
            output_encoder.last_hidden_state = self.linear(output_encoder.last_hidden_state)
            print(output_encoder.last_hidden_state.shape)
        # Validation/Testing: decode from the provided decoder input ids
        if decoder_ids is not None:
            out = self.decoder(encoder_outputs=output_encoder, decoder_input_ids=decoder_ids)
        # Training: let T5 build the decoder inputs from the labels and compute the loss
        else:
            out = self.decoder(encoder_outputs=output_encoder, labels=targets)
        return out
class EncoderDecoder(nn.Module):
    def __init__(self, pretrained_model: Tuple[str, str], encoder_dmodel=768, eos_token_id=None, pad_token_id=None):
        """
        Implements a model that combines Encoder and Decoder.

        Args:
            pretrained_model (tuple): names of the pretrained (encoder, decoder) models
            encoder_dmodel (int): hidden dimension of the encoder output
            eos_token_id (torch.long): token used for end of sentence
            pad_token_id (torch.long): token used for padding
        """
        super(EncoderDecoder, self).__init__()
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        self.encoder = Encoder(pretrained_model[0])
        self.encoder_dmodel = encoder_dmodel
        # Freeze parameters from encoder
        # for p in self.encoder.parameters():
        #     p.requires_grad = False
        self.decoder = Decoder(pretrained_model[1], self.encoder_dmodel)
        self.decoder_start_token_id = self.decoder.decoder.config.decoder_start_token_id

    def forward(self, images=None, targets=None, decoder_ids=None):
        output_encoder = self.encoder(images)
        out = self.decoder(output_encoder, targets, decoder_ids)
        return out
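# A minimal, hypothetical sketch of how the checkpoint loaded below could have been created,
# assuming ViT-base as the encoder and T5-base as the decoder (the training setup itself is not
# part of this file; the class definitions above are required to unpickle the saved model):
#   model = EncoderDecoder(("google/vit-base-patch16-224-in21k", "t5-base"),
#                          encoder_dmodel=768,
#                          eos_token_id=tokenizer.eos_token_id,
#                          pad_token_id=tokenizer.pad_token_id)
#   # ... training loop ...
#   torch.save(model, "model_vit_ai.pt")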
# Model loading and setting up the device.
# torch.load unpickles the full EncoderDecoder object, so the class definitions above must be in scope.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.load("model_vit_ai.pt", map_location=device)
model.to(device)
# Tokenizer and Feature Extractor
tokenizer = T5Tokenizer.from_pretrained('t5-base')
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
# Define the image preprocessing
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
])
def preprocess_image(image):
    # Convert the numpy array coming from Gradio into a normalized, batched tensor.
    image = Image.fromarray(image.astype('uint8'), 'RGB')
    image = transform(image)
    return image.unsqueeze(0)
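# Note: the manual torchvision pipeline above mirrors the ViT preprocessing. As a sketch, the same
# pixel values could be obtained directly from the feature extractor, e.g.:
#   pixel_values = feature_extractor(images=Image.fromarray(image), return_tensors="pt").pixel_values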
def generate_caption(image):
    model.eval()
    with torch.no_grad():
        image_tensor = preprocess_image(image).to(device)
        # Start from T5's decoder start token and greedily append one token per step (max 50 tokens).
        decoder_input_ids = torch.full((1, 1), model.decoder_start_token_id, dtype=torch.long, device=device)
        for _ in range(50):
            outputs = model(images=image_tensor, decoder_ids=decoder_input_ids)
            next_token_logits = outputs.logits[:, -1, :]
            next_token_id = next_token_logits.argmax(1, keepdim=True)
            decoder_input_ids = torch.cat([decoder_input_ids, next_token_id], dim=-1)
            # Stop once the end-of-sequence token has been generated.
            if torch.eq(next_token_id, tokenizer.eos_token_id).all():
                break
        caption = tokenizer.decode(decoder_input_ids.squeeze(0), skip_special_tokens=True)
    return caption
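# Example usage outside the Gradio UI (a sketch, assuming "sample_image1.jpg" exists locally):
#   import numpy as np
#   img = np.array(Image.open("sample_image1.jpg").convert("RGB"))
#   print(generate_caption(img))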
sample_images = [
"sample_image1.jpg",
"sample_image2.jpg",
"sample_image3.jpg"
]
# Define the Gradio interface (uses the legacy gr.inputs API from Gradio 3.x)
interface = gr.Interface(
    fn=generate_caption,
    inputs=gr.inputs.Image(source="upload", tool='editor', type="numpy", label="Upload an image or take a photo"),
    outputs='text',
    examples=sample_images,
    title="Image Captioning Model",
    description="Upload an image, select a sample image, or use your webcam to take a photo and generate a caption."
)
# Run the interface
interface.launch(debug=True)