import gradio as gr
import torch
from torchvision import transforms
from PIL import Image

from vocab import Vocabulary
from model import *

# Paths
ckpt_dir = "./best_model.pth"               # Trained model checkpoint
ques_vocab_path = "./question_vocabs.txt"   # Question vocabulary
ans_vocab_path = "./annotation_vocabs.txt"  # Answer vocabulary

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Model hyperparameters (must match the values used at training time)
FEATURE_SIZE, WORD_EMBED = 1024, 300
MAX_QU_LEN, NUM_HIDDEN, HIDDEN_SIZE = 30, 2, 512

# Load vocabularies
ques_vocab = Vocabulary(ques_vocab_path)
ans_vocab = Vocabulary(ans_vocab_path)

# Image preprocessing (ImageNet normalization statistics)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])


def preprocess_question(question, ques_vocab, max_qu_len):
    """Tokenize a question and pad/truncate it to a fixed-length index tensor."""
    tokens = question.lower().split()
    qu_idx = [ques_vocab.word2idx(token) for token in tokens][:max_qu_len]
    # '<pad>' is assumed to be the padding token in the question vocabulary
    qu_idx += [ques_vocab.word2idx('<pad>')] * (max_qu_len - len(qu_idx))
    return torch.tensor(qu_idx, dtype=torch.long)


def load_model():
    """Build the attention VQA model and load the trained checkpoint."""
    model = VQAModel_attn(feature_size=FEATURE_SIZE,
                          ques_vocab_size=ques_vocab.vocabulary_size,
                          ans_vocab_size=ans_vocab.vocabulary_size,
                          word_embed=WORD_EMBED,
                          hidden_size=HIDDEN_SIZE,
                          num_hidden=NUM_HIDDEN).to(device)
    # The checkpoint was saved with "qu_encoder" as the submodule name;
    # rename its keys to match the current model's "ques_encoder".
    state_dict = torch.load(ckpt_dir, map_location=device)
    fixed_state_dict = {k.replace("qu_encoder", "ques_encoder"): v
                        for k, v in state_dict.items()}
    model.load_state_dict(fixed_state_dict)
    model.eval()
    return model


model = load_model()


def vqa_interface(image, question):
    """Gradio handler: answer a free-form question about an uploaded image."""
    # Preprocess image: load from file path, normalize, add a batch dimension
    image = Image.open(image).convert('RGB')
    image_tensor = transform(image).unsqueeze(0).to(device)

    # Preprocess question: fixed-length index tensor with a batch dimension
    question_tensor = preprocess_question(question, ques_vocab, MAX_QU_LEN).unsqueeze(0).to(device)

    # Inference: pick the highest-scoring answer from the answer vocabulary
    with torch.no_grad():
        logits = model(image_tensor, question_tensor)
    prediction = torch.argmax(logits, dim=1)
    return ans_vocab.idx2word(prediction.item())


# Gradio interface
interface = gr.Interface(
    fn=vqa_interface,
    inputs=[
        gr.Image(type="filepath", label="Upload an Image"),
        gr.Textbox(lines=1, placeholder="Enter your question here...", label="Question")
    ],
    outputs=gr.Textbox(label="Predicted Answer"),
    title="Visual Question Answering",
    description="Upload an image and ask a question about it. The model will predict the answer.",
    allow_flagging="never"
)

if __name__ == "__main__":
    interface.launch(share=True)
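
# ---------------------------------------------------------------------------
# For reference, a minimal sketch of the Vocabulary interface this script
# relies on (the real class lives in vocab.py, which is not shown here, so
# the implementation below is an assumption: a plain-text file with one
# token per line, with unknown words mapped to an assumed '<unk>' entry):
#
#   class Vocabulary:
#       def __init__(self, vocab_path):
#           with open(vocab_path) as f:
#               self.idx_to_word = [line.strip() for line in f]
#           self.word_to_idx = {w: i for i, w in enumerate(self.idx_to_word)}
#           self.vocabulary_size = len(self.idx_to_word)
#
#       def word2idx(self, word):
#           return self.word_to_idx.get(word, self.word_to_idx.get('<unk>', 0))
#
#       def idx2word(self, idx):
#           return self.idx_to_word[idx]
# ---------------------------------------------------------------------------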
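
# ---------------------------------------------------------------------------
# Direct (non-UI) usage sketch. Because vqa_interface takes a file path and
# a question string, it can be exercised from a REPL without launching the
# Gradio app. The module name, image path, and output below are hypothetical
# placeholders, not values from this repository:
#
#   >>> from app import vqa_interface
#   >>> vqa_interface("./samples/example.jpg", "What color is the cat?")
#   'black'
# ---------------------------------------------------------------------------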