"""Gradio demo that captions images with a ViT encoder + T5 decoder model."""

import inspect
from typing import Tuple

import gradio as gr
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from transformers import (
    T5ForConditionalGeneration,
    T5Tokenizer,
    ViTFeatureExtractor,
    ViTModel,
)

# Path of the fully pickled EncoderDecoder checkpoint loaded at startup.
MODEL_PATH = "model_vit_ai.pt"
# Upper bound on the number of tokens generated for one caption.
MAX_CAPTION_TOKENS = 50


class Encoder(nn.Module):
    def __init__(self, pretrained_model):
        """Implements the Encoder.

        Args:
            pretrained_model (str): name of the pretrained ViT model
        """
        super(Encoder, self).__init__()
        self.encoder = ViTModel.from_pretrained(pretrained_model)

    def forward(self, input):
        """Run the ViT on a batch of preprocessed images.

        Args:
            input (torch.Tensor): pixel values, passed to ViTModel as ``pixel_values``.

        Returns:
            The ViTModel output object; element ``[0]`` is the last hidden state.
        """
        return self.encoder(pixel_values=input)


class Decoder(nn.Module):
    def __init__(self, pretrained_model, encoder_modeldim):
        """Implements the Decoder.

        Args:
            pretrained_model (str): name of the pretrained T5 model
            encoder_modeldim (int): hidden dimension of the encoder output
        """
        super(Decoder, self).__init__()
        self.decoder = T5ForConditionalGeneration.from_pretrained(pretrained_model)
        # Projects encoder hidden states (encoder_modeldim features) into the
        # decoder's hidden size; only applied when the two sizes differ.
        # NOTE(review): models pickled with the previous (reversed) layer shape
        # are only usable when both dimensions are equal.
        self.linear = nn.Linear(encoder_modeldim, self.decoder.model_dim, bias=False)
        self.encoder_modeldim = encoder_modeldim

    def forward(self, output_encoder, targets, decoder_ids=None):
        """Run the T5 decoder on top of the encoder's hidden states.

        Args:
            output_encoder: encoder output; element ``[0]`` must be the hidden states.
            targets (torch.Tensor | None): label ids, used when ``decoder_ids`` is None
                (training path, the T5 model computes the loss).
            decoder_ids (torch.Tensor | None): decoder input ids for
                validation/testing/generation.

        Returns:
            The T5ForConditionalGeneration output (``logits``, and ``loss`` when
            ``targets`` is given).
        """
        hidden_states = output_encoder[0]
        if self.decoder.model_dim != self.encoder_modeldim:
            # Project the hidden-state tensor, not the whole ModelOutput object.
            hidden_states = self.linear(hidden_states)
        # T5 accepts a tuple whose first element is the encoder's last hidden state.
        encoder_outputs = (hidden_states,)

        # Validation/Testing
        if decoder_ids is not None:
            return self.decoder(encoder_outputs=encoder_outputs, decoder_input_ids=decoder_ids)
        # Training
        return self.decoder(encoder_outputs=encoder_outputs, labels=targets)


class EncoderDecoder(nn.Module):
    def __init__(self, pretrained_model: Tuple[str, str], encoder_dmodel=768,
                 eos_token_id=None, pad_token_id=None):
        """Implements a model that combines Encoder and Decoder.

        Args:
            pretrained_model (tuple): (encoder model name, decoder model name)
            encoder_dmodel (int): hidden dimension of the encoder output
            eos_token_id (torch.long): token used for end of sentence
            pad_token_id (torch.long): token used for padding
        """
        super(EncoderDecoder, self).__init__()
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        # The encoder is left trainable (its parameters are not frozen).
        self.encoder = Encoder(pretrained_model[0])
        self.encoder_dmodel = encoder_dmodel
        self.decoder = Decoder(pretrained_model[1], self.encoder_dmodel)
        self.decoder_start_token_id = self.decoder.decoder.config.decoder_start_token_id

    def forward(self, images=None, targets=None, decoder_ids=None):
        """Encode ``images`` and decode with ``targets`` or ``decoder_ids``."""
        output_encoder = self.encoder(images)
        return self.decoder(output_encoder, targets, decoder_ids)


def _load_model(path, map_location):
    """Load a fully pickled model checkpoint from ``path``.

    SECURITY NOTE: this unpickles arbitrary Python objects; only load trusted files.
    """
    if "weights_only" in inspect.signature(torch.load).parameters:
        # torch>=2.6 defaults to weights_only=True, which cannot unpickle a
        # whole nn.Module, so request full unpickling explicitly.
        return torch.load(path, map_location=map_location, weights_only=False)
    # Older torch versions have no weights_only argument.
    return torch.load(path, map_location=map_location)


# Model loading and setting up the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = _load_model(MODEL_PATH, device)
model.to(device)
model.eval()

# Tokenizer and Feature Extractor
tokenizer = T5Tokenizer.from_pretrained('t5-base')
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')

# Define the image preprocessing (normalization stats taken from the ViT extractor)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std),
])


def preprocess_image(image):
    """Convert a numpy image into a normalized (1, 3, 224, 224) tensor.

    Grayscale and RGBA arrays are converted to RGB rather than forced into an
    RGB mode they do not match.
    """
    pil_image = Image.fromarray(image.astype('uint8')).convert('RGB')
    return transform(pil_image).unsqueeze(0)


def generate_caption(image):
    """Greedily decode a caption for ``image``.

    Args:
        image (numpy.ndarray | None): image supplied by the Gradio component.

    Returns:
        str: the decoded caption, or "" when no image was provided.
    """
    if image is None:
        return ""

    model.eval()
    with torch.no_grad():
        image_tensor = preprocess_image(image).to(device)
        # The encoder output does not depend on the tokens generated so far,
        # so compute it once instead of on every decoding step.
        encoder_output = model.encoder(image_tensor)
        decoder_input_ids = torch.full(
            (1, 1), model.decoder_start_token_id, dtype=torch.long, device=device
        )
        for _ in range(MAX_CAPTION_TOKENS):
            outputs = model.decoder(encoder_output, None, decoder_input_ids)
            next_token_id = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)
            decoder_input_ids = torch.cat([decoder_input_ids, next_token_id], dim=-1)
            if next_token_id.item() == tokenizer.eos_token_id:
                break

    return tokenizer.decode(decoder_input_ids.squeeze(0), skip_special_tokens=True)


sample_images = [
    "sample_image1.jpg",
    "sample_image2.jpg",
    "sample_image3.jpg",
]

# Define Gradio interface (gr.Image works on both Gradio 3.x and 4.x;
# the old gr.inputs namespace was removed in 4.0).
interface = gr.Interface(
    fn=generate_caption,
    inputs=gr.Image(type="numpy", label="Upload an image or take a photo"),
    outputs='text',
    examples=sample_images,
    title="Image Captioning Model",
    description="Upload an image, select a sample image, or use your webcam to take a photo and generate a caption.",
)

# Run the interface
if __name__ == "__main__":
    interface.launch(debug=True)