|
import streamlit as st |
|
from transformers import ViltProcessor, ViltForQuestionAnswering |
|
from PIL import Image |
|
|
|
def load_model(): |
|
processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa") |
|
model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa") |
|
return processor, model |
|
|
|
def predict(image, text, processor, model): |
|
encoding = processor(image, text, return_tensors="pt") |
|
outputs = model(**encoding) |
|
logits = outputs.logits |
|
idx = logits.argmax(-1).item() |
|
return model.config.id2label[idx] |
|
|
|
def main(): |
|
st.title("VQA") |
|
st.write("Upload an image and input a question to get an answer.") |
|
|
|
uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"]) |
|
|
|
if uploaded_image is not None: |
|
image = Image.open(uploaded_image) |
|
question = st.text_input("Question about the image:") |
|
|
|
if question: |
|
processor, model = load_model() |
|
answer = predict(image, question, processor, model) |
|
|
|
col1, col2 = st.columns(2) |
|
with col1: |
|
st.image(image, caption='Uploaded Image.', use_column_width=True) |
|
with col2: |
|
st.write(f"**Question:** {question}") |
|
st.write(f"**Answer:** {answer}") |
|
|
|
if __name__ == "__main__": |
|
main() |
|
|