# VQA / app.py

import re

import gradio as gr
import torch
from PIL import Image
from torchvision import transforms

from model import VQAModel_attn
from vocab import Vocabulary

# Paths
ckpt_path = "./best_model.pth"             # Trained model checkpoint (a file, not a directory)
ques_vocab_path = "./question_vocabs.txt"  # Question vocabulary
ans_vocab_path = "./annotation_vocabs.txt" # Answer vocabulary

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Model hyperparameters
FEATURE_SIZE = 1024  # image feature dimension
WORD_EMBED = 300     # word embedding dimension
MAX_QU_LEN = 30      # maximum question length in tokens
NUM_HIDDEN = 2       # number of hidden (recurrent) layers in the question encoder
HIDDEN_SIZE = 512    # hidden state size

# Load Vocabulary
ques_vocab = Vocabulary(ques_vocab_path)
ans_vocab = Vocabulary(ans_vocab_path)
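
# `Vocabulary` comes from the local vocab.py, which is not shown here. A
# minimal sketch of the interface this script relies on (word2idx, idx2word,
# vocabulary_size), assuming one token per line in the vocab file; the helper
# attributes below are illustrative, not the real implementation:
#
#     class Vocabulary:
#         def __init__(self, vocab_file):
#             with open(vocab_file) as f:
#                 self.words = [line.strip() for line in f]
#             self._word2idx = {w: i for i, w in enumerate(self.words)}
#             self.vocabulary_size = len(self.words)
#
#         def word2idx(self, word):
#             # Unknown tokens fall back to <unk>
#             return self._word2idx.get(word, self._word2idx.get('<unk>', 0))
#
#         def idx2word(self, idx):
#             return self.words[idx]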

# Image preprocessing (ImageNet mean/std normalization)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])
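
# Note: there is no Resize/CenterCrop in the pipeline above, so images are fed
# at their native resolution. If the backbone in model.py expects a fixed input
# size (224x224 is typical for ResNet-style feature extractors), a
# transforms.Resize((224, 224)) would need to precede ToTensor().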

# Question preprocessing
def preprocess_question(question, ques_vocab, max_qu_len):
    # Lowercase and drop punctuation before whitespace tokenization (assumes
    # the vocabulary was built from lowercased, punctuation-free tokens).
    tokens = re.sub(r"[^a-z0-9 ]", " ", question.lower()).split()
    qu_idx = [ques_vocab.word2idx(token) for token in tokens]
    # Truncate to max_qu_len, then pad with <pad>. len(qu_idx) here is the
    # pre-truncation length, and a negative pad count multiplies to an empty
    # list, so the result is always exactly max_qu_len long.
    qu_idx = qu_idx[:max_qu_len] + [ques_vocab.word2idx('<pad>')] * (max_qu_len - len(qu_idx))
    return torch.tensor(qu_idx, dtype=torch.long)
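
# Example (the indices are hypothetical; they depend on question_vocabs.txt):
#   preprocess_question("what color is the cat", ques_vocab, MAX_QU_LEN)
#   -> tensor([10, 42, 7, 3, 305, pad, pad, ..., pad])   # length 30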

# Load model and checkpoint
def load_model():
    model = VQAModel_attn(feature_size=FEATURE_SIZE,
                          ques_vocab_size=ques_vocab.vocabulary_size,
                          ans_vocab_size=ans_vocab.vocabulary_size,
                          word_embed=WORD_EMBED,
                          hidden_size=HIDDEN_SIZE,
                          num_hidden=NUM_HIDDEN).to(device)
    # The checkpoint was saved with "qu_encoder" in its parameter names, while
    # the current model class uses "ques_encoder"; rename the keys to match.
    state_dict = torch.load(ckpt_path, map_location=device)
    fixed_state_dict = {k.replace("qu_encoder", "ques_encoder"): v
                        for k, v in state_dict.items()}
    model.load_state_dict(fixed_state_dict)
    model.eval()
    return model

model = load_model()
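
# If the key rename above ever drifts out of sync with model.py, calling
# model.load_state_dict(fixed_state_dict, strict=False) and inspecting the
# returned missing_keys/unexpected_keys is a quick way to locate the mismatch.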

# Inference function
def vqa_interface(image, question):
    """Gradio handler: predict an answer to a question about the uploaded image."""
    if image is None or not question or not question.strip():
        return "Please upload an image and enter a question."
    # Preprocess image (gr.Image(type="filepath") passes a file path)
    image = Image.open(image).convert('RGB')
    image_tensor = transform(image).unsqueeze(0).to(device)
    # Preprocess question
    question_tensor = preprocess_question(question, ques_vocab, MAX_QU_LEN).unsqueeze(0).to(device)
    # Model inference
    with torch.no_grad():
        logits = model(image_tensor, question_tensor)
        prediction = torch.argmax(logits, dim=1)
    answer = ans_vocab.idx2word(prediction.item())
    return answer
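
# Quick local smoke test, bypassing Gradio (the image path is a placeholder):
#   print(vqa_interface("./example.jpg", "what is in the picture?"))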

# Gradio interface
interface = gr.Interface(
    fn=vqa_interface,
    inputs=[
        gr.Image(type="filepath", label="Upload an Image"),
        gr.Textbox(lines=1, placeholder="Enter your question here...", label="Question")
    ],
    outputs=gr.Textbox(label="Predicted Answer"),
    title="Visual Question Answering",
    description="Upload an image and ask a question about it. The model will predict the answer.",
    allow_flagging="never"
)

if __name__ == "__main__":
    interface.launch(share=True)