from threading import Thread

import gradio as gr
import torch
from transformers import (
    AutoModel,
    AutoProcessor,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)

# Load the UForm-Gen2 DPO vision-language model and its processor.
# `trust_remote_code=True` is required because the model ships custom code.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = AutoModel.from_pretrained("unum-cloud/uform-gen2-dpo", trust_remote_code=True).to(device)
processor = AutoProcessor.from_pretrained("unum-cloud/uform-gen2-dpo", trust_remote_code=True)
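
# Optional smoke test: a minimal single-turn captioning sketch following the
# model card's documented (non-streaming) usage. It assumes this remote-code
# processor accepts `processor(text=..., images=...)`; the helper name is
# ours, it is for debugging only, and it is never called by the app below.
def caption_once(pil_image, prompt="Describe the image."):
    inputs = processor(text=[prompt], images=[pil_image], return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.inference_mode():
        output = model.generate(
            **inputs,
            do_sample=False,
            use_cache=True,
            max_new_tokens=256,
            eos_token_id=151645,
            pad_token_id=processor.tokenizer.pad_token_id,
        )
    # Decode only the newly generated tokens, skipping the prompt.
    prompt_len = inputs["input_ids"].shape[1]
    return processor.tokenizer.batch_decode(
        output[:, prompt_len:], skip_special_tokens=True
    )[0]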
class StopOnTokens(StoppingCriteria):
    """Stop generation once the end-of-turn token is produced."""

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        stop_ids = [151645]  # <|im_end|> in this tokenizer's vocabulary
        for stop_id in stop_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False
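
# Note (assumption): passing `eos_token_id=151645` to `model.generate` would
# likely stop generation the same way; the explicit StoppingCriteria is kept
# because it slots cleanly into the streaming setup below.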
def response(message, history, image):
    stop = StopOnTokens()

    # Rebuild the chat transcript from Gradio's (user, assistant) history pairs.
    messages = [{"role": "system", "content": "You are a helpful assistant."}]
    for user_msg, assistant_msg in history:
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": assistant_msg})

    # On the first turn, prepend the <image> placeholder so the image latents
    # are spliced into the prompt. Keep the user's original text separate so
    # the placeholder never leaks into the visible chat history.
    if len(messages) == 1:
        prompt = f" <image>{message}"
    else:
        prompt = message
    messages.append({"role": "user", "content": prompt})

    model_inputs = processor.tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt"
    )

    # Preprocess the PIL image into a single-item batch of pixel tensors.
    image = processor.feature_extractor(image).unsqueeze(0)

    # The image expands into `num_image_latents` embeddings in place of the
    # single <image> placeholder token, hence the `- 1`.
    attention_mask = torch.ones(
        1, model_inputs.shape[1] + processor.num_image_latents - 1
    )
    model_inputs = {
        "input_ids": model_inputs,
        "images": image,
        "attention_mask": attention_mask
    }
    model_inputs = {k: v.to(device) for k, v in model_inputs.items()}

    # Generate on a background thread and stream tokens back, so the UI can
    # update incrementally instead of blocking until generation finishes.
    streamer = TextIteratorStreamer(
        processor.tokenizer, timeout=30., skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=1024,
        stopping_criteria=StoppingCriteriaList([stop])
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    # Grow the assistant's reply token by token. While streaming, hide the
    # Submit button and show the Stop button.
    history.append([message, ""])
    partial_response = ""
    for new_token in streamer:
        partial_response += new_token
        history[-1][1] = partial_response
        yield history, gr.Button(visible=False), gr.Button(visible=True, interactive=True)
with gr.Blocks() as demo:
    with gr.Row():
        image = gr.Image(type="pil")
        with gr.Column():
            chat = gr.Chatbot(show_label=False)
            message = gr.Textbox(interactive=True, show_label=False, container=False)
            with gr.Row():
                gr.ClearButton([chat, message])
                stop = gr.Button(value="Stop", variant="stop", visible=False)
                submit = gr.Button(value="Submit", variant="primary")
    with gr.Row():
        gr.Examples(
            [
                ["images/interior.jpg", "Describe the image accurately."],
                ["images/cat.jpg", "Describe the image in three sentences."],
                ["images/child.jpg", "Describe the image in one sentence."],
            ],
            [image, message],
            label="Captioning"
        )
        gr.Examples(
            [
                ["images/scream.jpg", "What is the main emotion of this image?"],
                ["images/louvre.jpg", "Where is this landmark located?"],
                ["images/three_people.jpg", "What are these people doing?"]
            ],
            [image, message],
            label="VQA"
        )

    # (fn, inputs, outputs) triples shared by the Enter-key and button paths.
    response_handler = (
        response,
        [message, chat, image],
        [chat, submit, stop]
    )
    # Once streaming finishes, hide Stop and restore Submit.
    postresponse_handler = (
        lambda: (gr.Button(visible=False), gr.Button(visible=True)),
        None,
        [stop, submit]
    )

    event1 = message.submit(*response_handler)
    event1.then(*postresponse_handler)
    event2 = submit.click(*response_handler)
    event2.then(*postresponse_handler)
    # The Stop button cancels whichever generation event is currently running.
    stop.click(None, None, None, cancels=[event1, event2])

# Queueing is required for generator-based (streaming) event handlers.
demo.queue()
demo.launch()