paligemma-vqa / app.py
Scharbhen's picture
Update app.py
3f1d188 verified
raw
history blame
2.74 kB
import gradio as gr
import requests
from PIL import Image
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
import spaces
import os
from huggingface_hub import login
login(os.getenv('hf_token'))
@spaces.GPU
def infer_ocrvqa(image, question):
model = PaliGemmaForConditionalGeneration.from_pretrained("google/paligemma-3b-ft-ocrvqa-896").to("cuda")
processor = PaliGemmaProcessor.from_pretrained("google/paligemma-3b-ft-ocrvqa-896")
systemprompt = "Ты ассистент по анализу финансовых отчетов. Ниже приведены вопросы по данным на изображении. Необходимо отвечать на вопросы по суммам в таблицах максимально точно и обращать внимание на названия колонок таблиц. Вопросы: "
inputs = processor(images=image,text=systemprompt+question, return_tensors="pt").to("cuda")
predictions = model.generate(**inputs, max_new_tokens=100)
return processor.decode(predictions[0], skip_special_tokens=True)[len(question):].lstrip("\n")
@spaces.GPU
def infer_doc(image, question):
model = PaliGemmaForConditionalGeneration.from_pretrained("google/paligemma-3b-ft-docvqa-896").to("cuda")
processor = PaliGemmaProcessor.from_pretrained("google/paligemma-3b-ft-docvqa-896")
inputs = processor(images=image, text=question, return_tensors="pt").to("cuda")
predictions = model.generate(**inputs, max_new_tokens=100)
return processor.decode(predictions[0], skip_special_tokens=True)[len(question):].lstrip("\n")
css = """
#mkd {
height: 500px;
overflow: auto;
border: 1px solid #ccc;
}
"""
with gr.Blocks(css=css) as demo:
gr.HTML("<h1><center>PaliGemma для VQA/OCR 📄<center><h1>")
gr.HTML("<h3><center>Использование модели as is без файнтюнинга на документах. ⚡</h3>")
with gr.Tab(label="Ответы на вопросы по документам"):
with gr.Row():
with gr.Column():
input_img = gr.Image(label="Input Document")
question = gr.Text(label="Question")
submit_btn = gr.Button(value="Submit")
output = gr.Text(label="Answer")
submit_btn.click(infer_doc, [input_img, question], [output])
with gr.Tab(label="Чтение текста со сканов"):
with gr.Row():
with gr.Column():
input_img = gr.Image(label="Input Document")
question = gr.Text(label="Question")
submit_btn = gr.Button(value="Submit")
output = gr.Text(label="Infer")
submit_btn.click(infer_ocrvqa, [input_img, question], [output])
demo.launch(debug=True)