import gradio as gr
import torch
from PIL import Image
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor

# Load the DocVQA-finetuned Pix2Struct checkpoint and its processor.
model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-docvqa-large")
processor = Pix2StructProcessor.from_pretrained("google/pix2struct-docvqa-large")

# Run on GPU when one is available; fall back to CPU otherwise.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
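# The demo description below quotes an average inference time of about 60 s
# per question for this large checkpoint.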

def generate(img: Image.Image, questions: list[str]) -> dict[str, str]:
    """Answer a batch of questions about a single document image."""
    # Replicate the image once per question so everything runs as one batch.
    inputs = processor(images=[img] * len(questions), text=questions, return_tensors="pt").to(device)
    predictions = model.generate(**inputs, max_new_tokens=256)
    # Return a question -> answer mapping, which the JSON output can serialize.
    return dict(zip(questions, processor.batch_decode(predictions, skip_special_tokens=True)))
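
# `generate` accepts any number of questions against the same image in one
# batch; a quick sketch of calling it directly (the second question is
# hypothetical, for illustration):
#
#   answers = generate(Image.open("example_1.png"),
#                      ["When is the coffee break?", "Who is the speaker?"])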

def process_document(image: Image.Image, question: str) -> dict[str, str]:
    """Single-question entry point wired to the Gradio interface."""
    return generate(image, [question])

description = "Gradio demo for Pix2Struct (`Pix2StructForConditionalGeneration`), an image-to-text encoder-decoder fine-tuned on DocVQA (document visual question answering). Upload an image, type a question, and click 'Submit', or click one of the examples to load it. Read more at the links below.\nNote: average inference time is about 60 s."
article = "<p style='text-align: center'><a href='https://www.linkedin.com/in/khadke-chetan/' target='_blank'>Chetan Khadke</a> | <a href='https://arxiv.org/abs/2111.15664' target='_blank'>Donut paper</a> | <a href='https://arxiv.org/pdf/2210.03347.pdf' target='_blank'>Pix2Struct paper</a></p>"

demo = gr.Interface(
    fn=process_document,
    inputs=[gr.Image(type="pil"), "text"],
    outputs="json",
    title="Demo: Pix2Struct for DocVQA",
    description=description,
    article=article,
    examples=[
        ["example_1.png", "When is the coffee break?"],
        ["example_2.jpeg", "What's the population of Stoddard?"],
    ],
    cache_examples=False,
)
demo.queue()
demo.launch()
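
# When run directly (`python app.py`), Gradio prints a local URL serving the
# demo; the example images are assumed to sit next to this script.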