import json
import os
import re
from collections import defaultdict
import glob
import numpy as np
import time
import gradio as gr
import torch
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F
from torch import optim
from torch.utils.data import Dataset
from torchvision import transforms
from torch.utils.data import DataLoader
from PIL import Image
from vocab import Vocabulary
from model import *
# Paths
ckpt_dir = "./best_model.pth"               # Path to the trained model
ques_vocab_path = "./question_vocabs.txt"   # Path to question vocabulary
ans_vocab_path = "./annotation_vocabs.txt"  # Path to answer vocabulary

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Model Parameters
FEATURE_SIZE, WORD_EMBED = 1024, 300
MAX_QU_LEN, NUM_HIDDEN, HIDDEN_SIZE = 30, 2, 512

# Load Vocabulary
ques_vocab = Vocabulary(ques_vocab_path)
ans_vocab = Vocabulary(ans_vocab_path)

# Image Preprocessing
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])
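# Note: the normalization constants above are the standard ImageNet statistics.
# If the image encoder in model.py expects a fixed input size (as most
# ImageNet-pretrained backbones do), a resize step may be needed. A possible
# variant (assumption: a 224x224 input is not confirmed by this repo):
# transform = transforms.Compose([
#     transforms.Resize((224, 224)),
#     transforms.ToTensor(),
#     transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
# ])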
# Preprocess Question
def preprocess_question(question, ques_vocab, max_qu_len):
    """Tokenize a question, map tokens to vocabulary indices, and pad or truncate to max_qu_len."""
    tokens = question.lower().split()
    qu_idx = [ques_vocab.word2idx(token) for token in tokens]
    qu_idx = qu_idx[:max_qu_len] + [ques_vocab.word2idx('<pad>')] * (max_qu_len - len(qu_idx))
    return torch.tensor(qu_idx, dtype=torch.long)
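# Illustrative example (actual indices depend on question_vocabs.txt):
#   preprocess_question("what color is the cat", ques_vocab, MAX_QU_LEN)
#   returns a LongTensor of shape (30,): the five token indices followed by
#   twenty-five '<pad>' indices.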
# Load Model and Checkpoint
def load_model():
    model = VQAModel_attn(feature_size=FEATURE_SIZE,
                          ques_vocab_size=ques_vocab.vocabulary_size,
                          ans_vocab_size=ans_vocab.vocabulary_size,
                          word_embed=WORD_EMBED,
                          hidden_size=HIDDEN_SIZE,
                          num_hidden=NUM_HIDDEN).to(device)
    # Fix checkpoint key mismatches
    state_dict = torch.load(ckpt_dir, map_location=device)
    fixed_state_dict = {k.replace("qu_encoder", "ques_encoder"): v for k, v in state_dict.items()}
    model.load_state_dict(fixed_state_dict)
    model.eval()
    return model
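# Note: the rename above only covers the "qu_encoder" -> "ques_encoder" mismatch in
# this checkpoint. A checkpoint saved from a DataParallel-wrapped model would also
# need its "module." prefix stripped (an assumption about other checkpoints, not this one).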
model = load_model()
# Inference Function
def vqa_interface(image, question):
    """Process an image and a question for the Gradio interface and return the predicted answer."""
    # Preprocess Image
    image = Image.open(image).convert('RGB')
    image_tensor = transform(image).unsqueeze(0).to(device)
    # Preprocess Question
    question_tensor = preprocess_question(question, ques_vocab, MAX_QU_LEN).unsqueeze(0).to(device)
    # Model Inference
    with torch.no_grad():
        logits = model(image_tensor, question_tensor)
        prediction = torch.argmax(logits, dim=1)
    answer = ans_vocab.idx2word(prediction.item())
    return answer
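# Quick sanity check without launching the UI (hypothetical local image path):
# print(vqa_interface("./example.jpg", "what is in the picture?"))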
# Gradio Interface
interface = gr.Interface(
    fn=vqa_interface,
    inputs=[
        gr.Image(type="filepath", label="Upload an Image"),
        gr.Textbox(lines=1, placeholder="Enter your question here...", label="Question")
    ],
    outputs=gr.Textbox(label="Predicted Answer"),
    title="Visual Question Answering",
    description="Upload an image and ask a question about it. The model will predict the answer.",
    allow_flagging="never"
)
if __name__ == "__main__":
    interface.launch(share=True)
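# Run directly, this starts a local Gradio server; share=True additionally requests
# a temporary public gradio.live link for the session.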