# Scraped from a Hugging Face Space: app.py, commit def2b7f (verified),
# commit message "Update app.py", uploaded by diabolic6045.
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForVision2Seq, AutoImageProcessor
from PIL import Image
import spaces
import os
from huggingface_hub import login
# Authenticate against the Hugging Face Hub before any download is attempted.
# The token comes from the HF_KEY environment variable; a missing variable
# raises KeyError right here at start-up.
login(os.environ["HF_KEY"])

# Target device for the input tensors built in `generate_text`.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# The model, image processor and tokenizer all come from the same repository.
MODEL_ID = "stabilityai/japanese-stable-vlm"

model = AutoModelForVision2Seq.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    device_map="auto",
)
processor = AutoImageProcessor.from_pretrained(MODEL_ID, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, device_map="auto")
# Define the helper function to build prompts
# Instruction sentence (Japanese) inserted into the prompt for each task.
TASK2INSTRUCTION = {
    "caption": "画像を詳細に述べてください。",
    "tag": "与えられた単語を使って、画像を詳細に述べてください。",
    "vqa": "与えられた画像を下に、質問に答えてください。",
}


class PromptInputError(ValueError, AssertionError):
    """Raised by `build_prompt` for an unknown task or an unusable `input`.

    It also derives from AssertionError because these checks used to be bare
    `assert` statements; code that caught AssertionError keeps working, while
    the checks are no longer stripped when Python runs with `-O`.
    """


def build_prompt(task="caption", input=None, sep="\n\n### "):
    """Build the instruction-style prompt for one of the supported tasks.

    The prompt is the fixed system message followed by one section per role
    (instruction, optional input, response), each introduced by `sep`.

    Args:
        task: Key of `TASK2INSTRUCTION` ("caption", "tag" or "vqa").
        input: Extra text for the "tag" and "vqa" tasks. For "tag" a list or
            tuple of words is also accepted and joined with "、". A blank
            (empty or whitespace-only) string is treated as "no input", since
            that is what an empty UI text field submits.
        sep: Separator placed before every role section.

    Returns:
        The assembled prompt string, ending with the empty response section.

    Raises:
        PromptInputError: If `task` is unknown, if "tag"/"vqa" is requested
            without usable input, or if input is supplied for a task that
            does not take one.
    """
    if task not in TASK2INSTRUCTION:
        raise PromptInputError(f"Please choose from {list(TASK2INSTRUCTION.keys())}")

    if task == "tag" and isinstance(input, (list, tuple)):
        input = "、".join(input)

    # Normalise blank text to None so that an empty text box does not trip
    # the "caption takes no input" check and cannot pass as a filled-in
    # tag/vqa input.
    if isinstance(input, str) and not input.strip():
        input = None

    if task in ("tag", "vqa"):
        if input is None:
            raise PromptInputError("Please fill in `input`!")
    elif input is not None:
        raise PromptInputError(f"`{task}` mode doesn't support to input questions")

    sys_msg = "以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。"
    roles = ["指示", "応答"]
    msgs = [": \n" + TASK2INSTRUCTION[task], ": \n"]
    if input is not None:
        # The input section sits between the instruction and the response.
        roles.insert(1, "入力")
        msgs.insert(1, ": \n" + input)

    return sys_msg + "".join(sep + role + msg for role, msg in zip(roles, msgs))
# Define the function to generate text from the image and prompt
@spaces.GPU(duration=120)
def generate_text(image, task, input_text=None):
    """Generate text for `image` using the module-level model.

    Args:
        image: Image handed to the image processor (the value of the
            `gr.Image` input).
        task: Task name forwarded to `build_prompt`.
        input_text: Optional text forwarded to `build_prompt` as `input`.
            A blank string is treated as "no text", because an empty text box
            submits "" rather than None.

    Returns:
        The first decoded sequence, with special tokens removed and
        surrounding whitespace stripped.

    Raises:
        gr.Error: If no image was supplied.
    """
    if image is None:
        # Fail with a readable message in the UI instead of an obscure error
        # from inside the image processor.
        raise gr.Error("Please upload an image.")

    # An untouched text box yields "", which `build_prompt` would otherwise
    # reject in caption mode.
    if isinstance(input_text, str) and not input_text.strip():
        input_text = None

    prompt = build_prompt(task=task, input=input_text)

    # Image features first, then the tokenised prompt merged into the same
    # batch. No special tokens are added to the prompt.
    inputs = processor(images=image, return_tensors="pt")
    text_encoding = tokenizer(prompt, add_special_tokens=False, return_tensors="pt")
    inputs.update(text_encoding)

    # Deterministic beam search: 5 beams, no sampling, at most 128 new tokens.
    outputs = model.generate(
        **inputs.to(device=device, dtype=model.dtype),
        do_sample=False,
        num_beams=5,
        max_new_tokens=128,
        min_length=1,
        repetition_penalty=1.5,
    )
    generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0].strip()
    return generated_text
# Define the Gradio interface
# Input widgets, listed in the positional order `generate_text` expects.
image_input = gr.Image(label="Upload an image")
task_input = gr.Radio(
    choices=["caption", "tag", "vqa"],
    value="caption",
    label="Select a task",
)
text_input = gr.Textbox(label="Enter text (for tag or vqa tasks)")

# Output widget showing the generated text.
output = gr.Textbox(label="Generated text")

# One row per example: [image path, task, text].
example_rows = [
    ["examples/example_1.jpeg", "caption", None],
    ["examples/example_2.jpg", "tag", "寿司、箸"],
    ["examples/example_3.jpg", "vqa", "この画像を説明する"],
]

interface = gr.Interface(
    fn=generate_text,
    inputs=[image_input, task_input, text_input],
    outputs=output,
    examples=example_rows,
)

# Start the app as soon as the module is executed.
interface.launch()