|
import gradio as gr |
|
import torch |
|
from transformers import AutoTokenizer, AutoModelForVision2Seq, AutoImageProcessor |
|
from PIL import Image |
|
import spaces |
|
import os |
|
|
|
from huggingface_hub import login |
|
# Authenticate with the Hugging Face Hub using the token stored in the HF_KEY
# secret. If HF_KEY is unset, skip the explicit login and let huggingface_hub
# fall back to its default credentials (e.g. the HF_TOKEN environment variable
# or a cached login) instead of crashing here with a bare KeyError; a download
# that really needs a token will then fail with an authentication error.
_hf_key = os.environ.get("HF_KEY")
if _hf_key:
    login(_hf_key)

MODEL_ID = "stabilityai/japanese-stable-vlm"

# Target device for the input tensors built in generate_text.
device = "cuda" if torch.cuda.is_available() else "cpu"

# trust_remote_code lets transformers run the model class shipped in the repo.
model = AutoModelForVision2Seq.from_pretrained(
    MODEL_ID, trust_remote_code=True, device_map="auto"
)
# device_map is a model-loading option; the image processor and tokenizer do
# no device placement, so it is not passed to them.
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
|
|
|
|
|
# Instruction text per task. The Japanese wording is kept verbatim on purpose
# (including "画像を下に"): presumably it must match the prompt format the model
# was trained with -- confirm against the model card before rewording.
TASK2INSTRUCTION = {
    "caption": "画像を詳細に述べてください。",
    "tag": "与えられた単語を使って、画像を詳細に述べてください。",
    "vqa": "与えられた画像を下に、質問に答えてください。",
}


def build_prompt(task="caption", input=None, sep="\n\n### "):
    """Build the text prompt for the Japanese Stable VLM.

    The prompt is a fixed system message followed by ``sep``-prefixed
    sections: ``指示`` (instruction), an optional ``入力`` (input), and an
    empty ``応答`` (response) section for the model to continue from.

    Args:
        task: One of the keys of ``TASK2INSTRUCTION`` ("caption", "tag",
            "vqa").
        input: Extra text required by "tag" and "vqa" and forbidden for
            "caption". For "tag" a list of words is also accepted and joined
            with "、". ``None``, an empty string and a whitespace-only string
            all mean "no input" (a Gradio Textbox left empty yields "").
        sep: Separator placed before every section header.

    Returns:
        The assembled prompt string.

    Raises:
        ValueError: If ``task`` is unknown, if "tag"/"vqa" gets no input, or
            if "caption" gets input.
    """
    # Explicit raises instead of asserts: asserts vanish under `python -O`,
    # which would let invalid calls silently build a wrong prompt.
    if task not in TASK2INSTRUCTION:
        raise ValueError(f"Please choose from {list(TASK2INSTRUCTION.keys())}")

    # `input` shadows the builtin, but callers pass it by keyword
    # (`input=...`), so the name is part of the interface and stays.
    if task == "tag" and isinstance(input, list):
        input = "、".join(input)
    # Treat blank text as absent so "caption" works with an empty text box
    # and "tag"/"vqa" reject it instead of silently dropping the input section.
    if isinstance(input, str) and not input.strip():
        input = None

    if task in ("tag", "vqa"):
        if input is None:
            raise ValueError("Please fill in `input`!")
    elif input is not None:
        raise ValueError(f"`{task}` mode doesn't support to input questions")

    sys_msg = "以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。"

    # Sections in order: instruction, optional input, then the empty
    # response header the model completes.
    sections = [("指示", TASK2INSTRUCTION[task])]
    if input is not None:
        sections.append(("入力", input))
    sections.append(("応答", ""))

    return sys_msg + "".join(f"{sep}{role}: \n{text}" for role, text in sections)
|
|
|
|
|
@spaces.GPU(duration=120)
def generate_text(image, task, input_text=None):
    """Run the VLM on one image and return the generated Japanese text.

    Args:
        image: The image delivered by the ``gr.Image`` input (``None`` when
            nothing was uploaded); passed straight to the image processor.
        task: Task key understood by ``build_prompt`` ("caption", "tag" or
            "vqa").
        input_text: Words (tag) or question (vqa). An empty or
            whitespace-only string -- what an untouched Textbox returns --
            is treated as "no input".

    Returns:
        The decoded text of the first returned sequence, stripped.

    Raises:
        gr.Error: If no image was uploaded or the task/input combination is
            invalid; Gradio shows the message to the user.
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    # An empty Textbox arrives as "", which build_prompt would reject for the
    # caption task; normalise blank text to None first.
    if isinstance(input_text, str) and not input_text.strip():
        input_text = None

    # build_prompt signals an invalid task/input combination with an
    # AssertionError (or ValueError); show it to the user as a gr.Error.
    try:
        prompt = build_prompt(task=task, input=input_text)
    except (AssertionError, ValueError) as err:
        raise gr.Error(str(err)) from err

    inputs = processor(images=image, return_tensors="pt")
    inputs.update(tokenizer(prompt, add_special_tokens=False, return_tensors="pt"))

    # BatchFeature.to casts floating-point tensors (the pixel values) to the
    # model dtype and only moves integer tensors such as input_ids.
    outputs = model.generate(
        **inputs.to(device=device, dtype=model.dtype),
        do_sample=False,
        num_beams=5,
        max_new_tokens=128,
        min_length=1,
        repetition_penalty=1.5,
    )

    return tokenizer.batch_decode(outputs, skip_special_tokens=True)[0].strip()
|
|
|
|
|
# Input/output components of the demo UI.
image_input = gr.Image(label="Upload an image")
# Derive the choices from TASK2INSTRUCTION so the UI and the prompt builder
# cannot drift apart (dict order gives "caption", "tag", "vqa").
task_input = gr.Radio(
    choices=list(TASK2INSTRUCTION),
    value="caption",
    label="Select a task",
)
text_input = gr.Textbox(label="Enter text (for tag or vqa tasks)")
output = gr.Textbox(label="Generated text")

interface = gr.Interface(
    fn=generate_text,
    inputs=[image_input, task_input, text_input],
    outputs=output,
    # Each row is [image path, task, text]; caption takes no text.
    examples=[
        ["examples/example_1.jpeg", "caption", None],
        ["examples/example_2.jpg", "tag", "寿司、箸"],
        ["examples/example_3.jpg", "vqa", "この画像を説明する"],
    ],
)

# Start the server only when run as a script, not when the module is imported.
if __name__ == "__main__":
    interface.launch()