# paligemma-vqa / app.py
# Scharbhen's picture
# Update app.py
# 2f21f91 verified
# raw
# history blame
# 2.3 kB
import gradio as gr
import requests
from PIL import Image
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
import spaces
import os
from huggingface_hub import login
# Authenticate against the Hugging Face Hub with the token stored in the
# `hf_token` environment variable. When the variable is unset, skip the call:
# per the huggingface_hub docs, login() without a token falls back to an
# interactive prompt, which cannot be answered in a headless deployment.
_hf_token = os.getenv('hf_token')
if _hf_token:
    login(_hf_token)
@spaces.GPU
def infer_ocrvqa(image, question):
    """Answer ``question`` about ``image`` with the OCR-VQA PaliGemma checkpoint.

    Mirrors ``infer_doc``; the only difference is the checkpoint that is loaded
    (``google/paligemma-3b-ft-ocrvqa-896``).

    Args:
        image: Value forwarded from the ``gr.Image`` input component; ``None``
            when the user submitted without providing an image.
        question: Prompt text forwarded from the ``gr.Text`` input component.

    Returns:
        The text generated by the model (at most 100 new tokens), decoded with
        special tokens removed and leading newlines stripped.

    Raises:
        gr.Error: If ``image`` is ``None``.
    """
    if image is None:
        # Fail before the expensive checkpoint load and surface a readable
        # message in the UI instead of a processor traceback.
        raise gr.Error("Please upload an image before submitting.")

    model_id = "google/paligemma-3b-ft-ocrvqa-896"
    model = PaliGemmaForConditionalGeneration.from_pretrained(model_id).to("cuda")
    processor = PaliGemmaProcessor.from_pretrained(model_id)

    inputs = processor(images=image, text=question, return_tensors="pt").to("cuda")
    # generate() returns the prompt tokens followed by the new tokens. Cut the
    # prompt off by token count rather than by len(question) characters on the
    # decoded string, which breaks whenever the decoded prompt is not a
    # character-for-character copy of the question.
    prompt_len = inputs["input_ids"].shape[-1]
    predictions = model.generate(**inputs, max_new_tokens=100)
    answer_ids = predictions[0][prompt_len:]
    return processor.decode(answer_ids, skip_special_tokens=True).lstrip("\n")
@spaces.GPU
def infer_doc(image, question):
    """Answer ``question`` about ``image`` with the DocVQA PaliGemma checkpoint.

    Mirrors ``infer_ocrvqa``; the only difference is the checkpoint that is
    loaded (``google/paligemma-3b-ft-docvqa-896``).

    Args:
        image: Value forwarded from the ``gr.Image`` input component; ``None``
            when the user submitted without providing an image.
        question: Prompt text forwarded from the ``gr.Text`` input component.

    Returns:
        The text generated by the model (at most 100 new tokens), decoded with
        special tokens removed and leading newlines stripped.

    Raises:
        gr.Error: If ``image`` is ``None``.
    """
    if image is None:
        # Fail before the expensive checkpoint load and surface a readable
        # message in the UI instead of a processor traceback.
        raise gr.Error("Please upload an image before submitting.")

    model_id = "google/paligemma-3b-ft-docvqa-896"
    model = PaliGemmaForConditionalGeneration.from_pretrained(model_id).to("cuda")
    processor = PaliGemmaProcessor.from_pretrained(model_id)

    inputs = processor(images=image, text=question, return_tensors="pt").to("cuda")
    # generate() returns the prompt tokens followed by the new tokens. Cut the
    # prompt off by token count rather than by len(question) characters on the
    # decoded string, which breaks whenever the decoded prompt is not a
    # character-for-character copy of the question.
    prompt_len = inputs["input_ids"].shape[-1]
    predictions = model.generate(**inputs, max_new_tokens=100)
    answer_ids = predictions[0][prompt_len:]
    return processor.decode(answer_ids, skip_special_tokens=True).lstrip("\n")
# Custom stylesheet handed to gr.Blocks(css=...) below. It styles the element
# with id "mkd" as a 500px-high, scrollable box with a thin grey border.
# NOTE(review): no component visible in this file sets elem_id="mkd" — confirm
# whether this rule is still needed.
css = """
#mkd {
height: 500px;
overflow: auto;
border: 1px solid #ccc;
}
"""
# Build the two-tab Gradio UI. Each tab has an image input, a question box and
# a submit button in one column, with the answer box beside it; the button
# runs the tab's inference function on (image, question) and writes the result
# to the answer box.
with gr.Blocks(css=css) as demo:
    # Closing tags fixed: the original emitted "<center><h1>" instead of
    # "</center></h1>" and never closed <center> in the subtitle.
    gr.HTML("<h1><center>PaliGemma для VQA/OCR 📄</center></h1>")
    gr.HTML("<h3><center>Использование модели as is без файнтюнинга на документах. ⚡</center></h3>")

    # Tab 1: document question answering (DocVQA checkpoint).
    with gr.Tab(label="Ответы на вопросы по документам"):
        with gr.Row():
            with gr.Column():
                doc_image = gr.Image(label="Input Document")
                doc_question = gr.Text(label="Question")
                doc_submit = gr.Button(value="Submit")
            doc_answer = gr.Text(label="Answer")
        doc_submit.click(infer_doc, [doc_image, doc_question], [doc_answer])

    # Tab 2: reading text from scans (OCR-VQA checkpoint).
    with gr.Tab(label="Чтение текста со сканов"):
        with gr.Row():
            with gr.Column():
                ocr_image = gr.Image(label="Input Document")
                ocr_question = gr.Text(label="Question")
                ocr_submit = gr.Button(value="Submit")
            ocr_answer = gr.Text(label="Infer")
        ocr_submit.click(infer_ocrvqa, [ocr_image, ocr_question], [ocr_answer])

# Start the Gradio server with debug output enabled.
demo.launch(debug=True)