#!/usr/bin/env python

from __future__ import annotations

import os

import gradio as gr
import PIL.Image
import spaces
import torch
from transformers import InstructBlipForConditionalGeneration, InstructBlipProcessor

DESCRIPTION = "# InstructBLIP"

MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "1024"))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_id = "Salesforce/instructblip-vicuna-7b"
processor = InstructBlipProcessor.from_pretrained(model_id)
model = InstructBlipForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
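# Note: the processor handles both image preprocessing and prompt tokenization.
# With device_map="auto" the float16 weights are placed on the available GPU(s)
# automatically; in half precision the 7B checkpoint needs on the order of 15 GB
# of accelerator memory (rough estimate).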


@spaces.GPU  # request a GPU for this call on ZeroGPU Spaces; the `spaces` import is otherwise unused
def run(
    image: PIL.Image.Image,
    prompt: str,
    text_decoding_method: str = "Nucleus sampling",
    num_beams: int = 5,
    max_length: int = 256,
    min_length: int = 1,
    top_p: float = 0.9,
    repetition_penalty: float = 1.5,
    length_penalty: float = 1.0,
    temperature: float = 1.0,
) -> str:
    # Downscale large images so the longer side is at most MAX_IMAGE_SIZE while
    # preserving the aspect ratio. PIL's Image.size is (width, height).
    w, h = image.size
    scale = MAX_IMAGE_SIZE / max(w, h)
    if scale < 1:
        new_w = int(w * scale)
        new_h = int(h * scale)
        image = image.resize((new_w, new_h), resample=PIL.Image.Resampling.LANCZOS)
    inputs = processor(images=image, text=prompt, return_tensors="pt").to(device, torch.float16)
    generated_ids = model.generate(
        **inputs,
        # Nucleus sampling toggles do_sample; beam search relies on num_beams instead.
        do_sample=(text_decoding_method == "Nucleus sampling"),
        num_beams=num_beams,
        max_length=max_length,
        min_length=min_length,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        length_penalty=length_penalty,
        temperature=temperature,
    )
    generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
    return generated_caption
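

# Minimal local usage sketch (illustrative only; assumes an image file
# `example.jpg` exists next to this script):
#
#   caption = run(PIL.Image.open("example.jpg"), "Describe this image in detail.")
#   print(caption)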

with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil")
            prompt = gr.Textbox(label="Prompt")
            run_button = gr.Button()
            with gr.Accordion(label="Advanced options", open=False):
                text_decoding_method = gr.Radio(
                    label="Text Decoding Method",
                    choices=["Beam search", "Nucleus sampling"],
                    value="Nucleus sampling",
                )
                num_beams = gr.Slider(
                    label="Number of Beams",
                    minimum=1,
                    maximum=10,
                    step=1,
                    value=5,
                )
                max_length = gr.Slider(
                    label="Max Length",
                    minimum=1,
                    maximum=512,
                    step=1,
                    value=256,
                )
                min_length = gr.Slider(
                    label="Minimum Length",
                    minimum=1,
                    maximum=64,
                    step=1,
                    value=1,
                )
                top_p = gr.Slider(
                    label="Top P",
                    minimum=0.1,
                    maximum=1.0,
                    step=0.1,
                    value=0.9,
                )
                repetition_penalty = gr.Slider(
                    label="Repetition Penalty",
                    info="A larger value discourages repetition.",
                    minimum=1.0,
                    maximum=5.0,
                    step=0.5,
                    value=1.5,
                )
                length_penalty = gr.Slider(
                    label="Length Penalty",
                    info="Larger values favor longer sequences; used with beam search.",
                    minimum=-1.0,
                    maximum=2.0,
                    step=0.2,
                    value=1.0,
                )
                temperature = gr.Slider(
                    label="Temperature",
                    info="Used with nucleus sampling.",
                    minimum=0.5,
                    maximum=1.0,
                    step=0.1,
                    value=1.0,
                )
        with gr.Column():
            output = gr.Textbox(label="Result")
    gr.on(
        triggers=[prompt.submit, run_button.click],
        fn=run,
        inputs=[
            input_image,
            prompt,
            text_decoding_method,
            num_beams,
            max_length,
            min_length,
            top_p,
            repetition_penalty,
            length_penalty,
            temperature,
        ],
        outputs=output,
        api_name="run",
    )
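
    # Because the event is registered with api_name="run", the demo also exposes a
    # programmatic endpoint at /run. Rough client-side sketch (the exact file-passing
    # helper, e.g. handle_file, depends on the installed gradio_client version, and
    # the URL below is a placeholder):
    #
    #   from gradio_client import Client, handle_file
    #   client = Client("<space-or-server-url>")
    #   caption = client.predict(
    #       handle_file("example.jpg"), "Describe this image.", "Nucleus sampling",
    #       5, 256, 1, 0.9, 1.5, 1.0, 1.0,
    #       api_name="/run",
    #   )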

if __name__ == "__main__":
    demo.queue(max_size=20).launch()