"""Gradio chat demo for the ``unum-cloud/uform-gen2-dpo`` image-chat model.

The script loads the model and processor at import time, defines a streaming
chat callback (``response``), builds the Blocks UI and launches it.
"""

import sys
from threading import Thread

import gradio as gr
import torch
from transformers import AutoModel, AutoProcessor
from transformers import StoppingCriteria, TextIteratorStreamer, StoppingCriteriaList

MODEL_ID = "unum-cloud/uform-gen2-dpo"

device = "cuda:0" if torch.cuda.is_available() else "cpu"

model = AutoModel.from_pretrained(MODEL_ID, trust_remote_code=True).to(device)
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)


class StopOnTokens(StoppingCriteria):
    """Stopping criterion that fires when the newest token is a stop token.

    Only the first sequence of the batch (``input_ids[0]``) is inspected.

    Args:
        stop_ids: Token ids that end generation. The default, ``151645``, is
            presumably the end-of-turn token of this model's tokenizer —
            TODO confirm against the tokenizer vocabulary.
    """

    def __init__(self, stop_ids=(151645,)):
        super().__init__()
        self.stop_ids = frozenset(stop_ids)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when the last generated token id is one of the stop ids."""
        return int(input_ids[0][-1]) in self.stop_ids


@torch.no_grad()
def response(message, history, image):
    """Stream the model's reply to ``message`` into the chat history.

    Args:
        message: Text of the new user turn (wired to the ``message`` textbox).
        history: Previous turns as ``[user_msg, assistant_msg]`` pairs (wired
            to the ``chat`` component). The list is extended in place with the
            new turn.
        image: Image from the ``gr.Image(type="pil")`` component; ``None``
            when nothing has been uploaded.

    Yields:
        ``(history, submit_update, stop_update)`` after every streamed chunk:
        the history with the partial reply, an update hiding the Submit
        button and an update showing the Stop button.

    Raises:
        gr.Error: If no image was provided.
    """
    # Fail with a readable message instead of an opaque feature-extractor
    # error when the user has not uploaded an image.
    if image is None:
        raise gr.Error("Please upload an image first.")

    stop_criterion = StopOnTokens()

    messages = [{"role": "system", "content": "You are a helpful assistant."}]
    for user_msg, assistant_msg in history:
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": assistant_msg})

    # First turn of the conversation: the message gets a prefix.
    # NOTE(review): the attention mask below is sized as if one prompt token
    # expands into ``processor.num_image_latents`` positions; confirm that this
    # prefix carries the image placeholder the model expects.
    if not history:
        message = f" {message}"

    messages.append({"role": "user", "content": message})

    input_ids = processor.tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    )

    # Add a leading batch dimension to the extracted image features.
    image = processor.feature_extractor(image).unsqueeze(0)

    attention_mask = torch.ones(
        1, input_ids.shape[1] + processor.num_image_latents - 1
    )

    model_inputs = {
        "input_ids": input_ids,
        "images": image,
        "attention_mask": attention_mask,
    }
    model_inputs = {k: v.to(device) for k, v in model_inputs.items()}

    streamer = TextIteratorStreamer(
        processor.tokenizer,
        timeout=30.,
        skip_prompt=True,
        skip_special_tokens=True,
    )
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=1024,
        stopping_criteria=StoppingCriteriaList([stop_criterion]),
    )

    # Generation runs in a background thread so the streamer can be consumed
    # here while tokens are still being produced.
    # NOTE(review): the thread is never joined or signalled; presumably it keeps
    # generating after a Stop click cancels this callback — confirm.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    history.append([message, ""])
    partial_response = ""
    for new_token in streamer:
        partial_response += new_token
        history[-1][1] = partial_response
        yield history, gr.Button(visible=False), gr.Button(visible=True, interactive=True)


# NOTE(review): the nesting of the layout containers below was reconstructed
# from a whitespace-collapsed source — confirm it matches the intended UI.
with gr.Blocks() as demo:
    with gr.Row():
        image = gr.Image(type="pil")

        with gr.Column():
            chat = gr.Chatbot(show_label=False)
            message = gr.Textbox(interactive=True, show_label=False, container=False)

            with gr.Row():
                gr.ClearButton([chat, message])
                stop = gr.Button(value="Stop", variant="stop", visible=False)
                submit = gr.Button(value="Submit", variant="primary")

    with gr.Row():
        gr.Examples(
            [
                ["images/interior.jpg", "Describe the image accurately."],
                ["images/cat.jpg", "Describe the image in three sentences."],
                ["images/child.jpg", "Describe the image in one sentence."],
            ],
            [image, message],
            label="Captioning"
        )
        gr.Examples(
            [
                ["images/scream.jpg", "What is the main emotion of this image?"],
                ["images/louvre.jpg", "Where is this landmark located?"],
                ["images/three_people.jpg", "What are these people doing?"]
            ],
            [image, message],
            label="VQA"
        )

    # (fn, inputs, outputs) triples shared by the textbox-submit and
    # button-click listeners.
    response_handler = (
        response,
        [message, chat, image],
        [chat, submit, stop]
    )
    # After a response finishes: hide Stop, show Submit again.
    postresponse_handler = (
        lambda: (gr.Button(visible=False), gr.Button(visible=True)),
        None,
        [stop, submit]
    )

    event1 = message.submit(*response_handler)
    event1.then(*postresponse_handler)
    event2 = submit.click(*response_handler)
    event2.then(*postresponse_handler)

    # Stop cancels whichever response event is currently running.
    stop.click(None, None, None, cancels=[event1, event2])

demo.queue()
demo.launch()