# Hugging Face Spaces status: Running on Zero
import sys | |
import os | |
import argparse | |
import time | |
import subprocess | |
import spaces | |
import cumo.serve.gradio_web_server as gws | |
import datetime | |
import json | |
import gradio as gr | |
import requests | |
from PIL import Image | |
from cumo.conversation import (default_conversation, conv_templates, SeparatorStyle) | |
from cumo.constants import LOGDIR | |
from cumo.utils import (build_logger, server_error_msg, violates_moderation, moderation_msg) | |
import hashlib | |
import torch | |
import io | |
from cumo.constants import WORKER_HEART_BEAT_INTERVAL | |
from cumo.utils import (build_logger, server_error_msg, | |
pretty_print_semaphore) | |
from cumo.model.builder import load_pretrained_model | |
from cumo.mm_utils import process_images, load_image_from_base64, tokenizer_image_token | |
from cumo.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN | |
from transformers import TextIteratorStreamer | |
from threading import Thread | |
# Request headers plus the three button states that the event handlers return
# to toggle the UI buttons.
headers = {"User-Agent": "CuMo"}
no_change_btn = gr.Button()
enable_btn = gr.Button(interactive=True)
disable_btn = gr.Button(interactive=False)

# Use the GPU when one is visible to torch, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# CuMo checkpoint built on Mistral-7B-Instruct, loaded once at import time
# without 8-bit/4-bit quantization and without flash-attention.
model_path, model_base, model_name = (
    './checkpoints/CuMo-mistral-7b',
    'mistralai/Mistral-7B-Instruct-v0.2',
    'CuMo-mistral-7b',
)
conv_mode = 'mistral_instruct_system'  # template used for new conversations
load_8bit = load_4bit = False
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path, model_base, model_name, load_8bit, load_4bit,
    device=device, use_flash_attn=False,
)
model.config.training = False
def upvote_last_response(state):
    """Handle the Upvote button.

    Clears the textbox and disables the upvote/downvote/flag buttons.
    ``state`` is accepted to match the listener's inputs but is not read,
    so the vote itself is not recorded.
    """
    return ("", disable_btn, disable_btn, disable_btn)
def downvote_last_response(state):
    """Handle the Downvote button.

    Clears the textbox and disables the upvote/downvote/flag buttons.
    ``state`` is accepted to match the listener's inputs but is not read,
    so the vote itself is not recorded.
    """
    return ("", disable_btn, disable_btn, disable_btn)
def flag_last_response(state):
    """Handle the Flag button.

    Clears the textbox and disables the upvote/downvote/flag buttons.
    ``state`` is accepted to match the listener's inputs but is not read,
    so the flag itself is not recorded.
    """
    return ("", disable_btn, disable_btn, disable_btn)
def clear_history():
    """Reset the chat to an empty conversation.

    The fresh state is a copy of ``conv_templates[conv_mode]``, the same
    template ``add_text`` starts a conversation with. Previously this used
    ``default_conversation``. Because ``add_text`` only installs the
    ``conv_mode`` template when ``state is None``, every conversation after a
    Clear kept that other template instead of the one configured for the
    loaded model.

    Returns:
        ``(state, chatbot_history, "", None)`` followed by five disabled
        button updates, matching ``[state, chatbot, textbox, imagebox] +
        btn_list``.
    """
    state = conv_templates[conv_mode].copy()
    return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
def add_text(state, imagebox, textbox, image_process_mode):
    """Append the user's turn (with an optional image) to the conversation.

    Args:
        state: Current conversation, or ``None`` before the first turn, in
            which case a copy of ``conv_templates[conv_mode]`` is started.
        imagebox: File path of the uploaded image (the component is
            ``gr.Image(type="filepath")``), or ``None`` for a text-only turn.
        textbox: The user's message text.
        image_process_mode: Preprocessing mode stored alongside the image.

    Yields:
        A single ``(state, chatbot_history, "", None, *buttons)`` tuple. It
        clears the textbox and image input, disables upvote/downvote/flag and
        enables regenerate/clear.
    """
    if state is None:
        state = conv_templates[conv_mode].copy()
    if imagebox is not None:
        # convert() loads the pixels into a new in-memory image, so the file
        # handle can be closed right away instead of waiting for GC.
        with Image.open(imagebox) as raw_image:
            image = raw_image.convert('RGB')
        textbox = (DEFAULT_IMAGE_TOKEN + '\n' + textbox, image, image_process_mode)
    state.append_message(state.roles[0], textbox)
    # Empty assistant slot; generate() writes the streamed reply into it.
    state.append_message(state.roles[1], None)
    yield (state, state.to_gradio_chatbot(), "", None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
def delete_text(state, image_process_mode):
    """Blank the last assistant reply so it can be generated again.

    This handler is wired to the Regenerate button and chained into
    ``generate``. If the preceding user message holds an image tuple, its
    stored preprocessing mode is replaced with ``image_process_mode``.

    Yields:
        One ``(state, chatbot_history, "", None, *buttons)`` tuple, with
        upvote/downvote/flag disabled and regenerate/clear enabled.
    """
    messages = state.messages
    messages[-1][-1] = None
    last_user_turn = messages[-2]
    content = last_user_turn[1]
    if type(content) in (list, tuple):
        last_user_turn[1] = (*content[:2], image_process_mode)
    buttons = (disable_btn,) * 3 + (enable_btn,) * 2
    yield (state, state.to_gradio_chatbot(), "", None) + buttons
def regenerate(state, image_process_mode):
    """Blank the last assistant reply and reset ``state.skip_next``.

    The same message edit as ``delete_text``, but it returns (not yields) its
    outputs, with all five buttons disabled. The Regenerate button in this
    file is wired to ``delete_text``, not to this function.
    """
    state.messages[-1][-1] = None
    human_turn = state.messages[-2]
    if type(human_turn[1]) in (list, tuple):
        human_turn[1] = (*human_turn[1][:2], image_process_mode)
    state.skip_next = False
    chatbot_history = state.to_gradio_chatbot()
    return (state, chatbot_history, "", None) + (disable_btn,) * 5
def _prepare_image_inputs(prompt, images):
    """Convert the conversation's PIL images into ``model.generate`` kwargs.

    Args:
        prompt: Rendered conversation prompt (``state.get_prompt()``).
        images: PIL images from ``state.get_images(return_pil=True)``. May be
            ``None`` or empty for text-only conversations.

    Returns:
        A tuple ``(prompt, image_args, num_image_tokens)``:
        - ``prompt`` has image placeholders wrapped in start/end tokens when
          ``model.config.mm_use_im_start_end`` is set.
        - ``image_args`` holds the extra kwargs for ``model.generate``
          (empty for text-only prompts).
        - ``num_image_tokens`` is the number of context positions taken by
          image patches.

    Raises:
        ValueError: if the image count differs from the number of
            ``DEFAULT_IMAGE_TOKEN`` placeholders in ``prompt``.
    """
    if not images:
        return prompt, {}, 0
    if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
        raise ValueError("Number of images does not match number of <image> tokens in prompt")
    # Record the original PIL sizes before the images are replaced by
    # preprocessed tensors.
    image_sizes = [image.size for image in images]
    pixel_values = process_images(images, image_processor, model.config)
    # process_images may return either a list of tensors or a single tensor.
    if isinstance(pixel_values, list):
        pixel_values = [t.to(model.device, dtype=torch.float16) for t in pixel_values]
    else:
        pixel_values = pixel_values.to(model.device, dtype=torch.float16)
    replace_token = DEFAULT_IMAGE_TOKEN
    if getattr(model.config, 'mm_use_im_start_end', False):
        replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
        prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
    num_image_tokens = prompt.count(replace_token) * model.get_vision_tower().num_patches
    return prompt, {"images": pixel_values, "image_sizes": image_sizes}, num_image_tokens


def generate(state, imagebox, textbox, image_process_mode, temperature, top_p, max_output_tokens):
    """Stream the model's reply to the last user turn into the chatbot.

    This is the Gradio generator handler chained after ``add_text`` /
    ``delete_text``. ``imagebox``, ``textbox`` and ``image_process_mode`` are
    accepted to match the listener's inputs but are unused: the prompt and
    images come from ``state``.

    Args:
        state: Conversation whose last message is an empty assistant slot.
        temperature: Sampling temperature. Values <= 0.001 switch to greedy
            decoding.
        top_p: Nucleus-sampling threshold forwarded to ``model.generate``.
        max_output_tokens: Requested limit on new tokens (the "Max output
            tokens" slider). It is further capped by the room left in the
            context window.

    Yields:
        ``(state, chatbot_history, "", None, *buttons)`` tuples matching the
        ``[state, chatbot, textbox, imagebox] + btn_list`` outputs.

    Raises:
        ValueError: if the number of images and ``<image>`` placeholders differ.
    """
    # NOTE(review): `spaces` is imported at the top of this file, but this
    # handler is not wrapped in @spaces.GPU. On ZeroGPU hardware that is
    # presumably required for CUDA access -- confirm before deploying there.
    prompt, image_args, num_image_tokens = _prepare_image_inputs(
        state.get_prompt(), state.get_images(return_pil=True))

    max_context_length = getattr(model.config, 'max_position_embeddings', 2048)
    do_sample = temperature > 0.001
    stop_str = state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2
    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)

    # Honour the slider, which was previously ignored in favour of a
    # hard-coded 512. Clamp to >= 1 so a zero slider is not reported as a
    # context overflow.
    requested_tokens = max(1, int(max_output_tokens))
    max_new_tokens = min(requested_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens)
    if max_new_tokens < 1:
        # Report through the chat UI: this handler's outputs are Gradio
        # components, not the model-worker JSON byte stream.
        state.messages[-1][-1] = "Exceeds max token length. Please start a new conversation, thanks."
        yield (state, state.to_gradio_chatbot(), "", None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
        return

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
    thread = Thread(target=model.generate, kwargs=dict(
        inputs=input_ids,
        do_sample=do_sample,
        temperature=temperature,
        top_p=top_p,
        max_new_tokens=max_new_tokens,
        streamer=streamer,
        use_cache=True,
        pad_token_id=tokenizer.eos_token_id,
        **image_args,
    ))
    thread.start()

    generated_text = ''
    for new_text in streamer:
        generated_text += new_text
        # Guard on stop_str: an empty separator makes endswith() true, and
        # the slice [:-0] would then erase the whole reply.
        if stop_str and generated_text.endswith(stop_str):
            generated_text = generated_text[:-len(stop_str)]
        state.messages[-1][-1] = generated_text
        yield (state, state.to_gradio_chatbot(), "", None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
    # The streamer is exhausted once generation ends; reap the worker thread.
    thread.join()
    yield (state, state.to_gradio_chatbot(), "", None) + (enable_btn,) * 5
    torch.cuda.empty_cache()
# Markdown shown above the chat: paper title and project links.
# Fixed a mis-encoded emoji ("π") and a stray closing bracket after the
# arXiv link.
title_markdown = ("""
# CuMo: Scaling Multimodal LLM with Co-Upcycled Mixture-of-Experts
[[Project Page](https://chrisjuniorli.github.io/project/CuMo/)] [[Code](https://github.com/SHI-Labs/CuMo)] [[Model](https://huggingface.co/shi-labs/CuMo-mistral-7b)] | 📚 [[Arxiv](https://arxiv.org/pdf/2405.05949)]
""")

# Terms-of-use notice rendered below the chat.
tos_markdown = ("""
### Terms of use
By using this service, users are required to agree to the following terms:
The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
""")

# License notice. Removed the dangling "subject to the." fragment, which
# referred to nothing.
learn_more_markdown = ("""
### License
The service is a research preview intended for non-commercial use only. Please contact us if you find any potential violation.
""")

# Keep the action-button row from squeezing buttons below 120px.
block_css = """
#buttons button {
    min-width: min(120px,100%);
}
"""
# The textbox is created outside the Blocks context so gr.Examples can target
# it before it is rendered inside the chat column.
textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
with gr.Blocks(title="CuMo", theme=gr.themes.Default(), css=block_css) as demo:
    state = gr.State()
    gr.Markdown(title_markdown)
    with gr.Row():
        with gr.Column(scale=3):
            imagebox = gr.Image(label="Input Image", type="filepath")
            image_process_mode = gr.Radio(
                ["Crop", "Resize", "Pad", "Default"],
                value="Default",
                label="Preprocess for non-square image", visible=False)
            # Example assets are resolved relative to the working directory.
            cur_dir = './cumo/serve'
            gr.Examples(examples=[
                [f"{cur_dir}/examples/aveger.jpg", "Can you introduce this movie based on the poster?"],
                [f"{cur_dir}/examples/fridge.webp", "Can you describe what groceries are presented in this fridge?"],
                [f"{cur_dir}/examples/su7_4.jpg", "What car is it in this image?"],
                [f"{cur_dir}/examples/nvidia.jpeg", "Can you tell me what happened in this image?"],
                [f"{cur_dir}/examples/animal.webp", "What animals are in this image?"],
                [f"{cur_dir}/examples/disney.jpeg", "How many characters in this image?"],
                [f"{cur_dir}/examples/reka_6.jpeg", "What colour is my hat (im sitting on the bear)?"],
            ], inputs=[imagebox, textbox], cache_examples=False)
            with gr.Accordion("Parameters", open=False) as parameter_row:
                temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature",)
                top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",)
                max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)
        with gr.Column(scale=8):
            chatbot = gr.Chatbot(
                elem_id="chatbot",
                label="CuMo Chatbot",
                height=650,
                layout="panel",
            )
            with gr.Row():
                with gr.Column(scale=8):
                    textbox.render()
                with gr.Column(scale=1, min_width=50):
                    submit_btn = gr.Button(value="Send", variant="primary")
            # Button labels: the emoji were previously mis-encoded
            # ("π Upvote", "β οΈ Flag", ...).
            with gr.Row(elem_id="buttons") as button_row:
                upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
                downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
                flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
                regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
                clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
    gr.Markdown(tos_markdown)
    gr.Markdown(learn_more_markdown)
    url_params = gr.JSON(visible=False)

    # Register listeners. The order of btn_list must match the five trailing
    # button values returned by the chat handlers.
    btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
    vote_outputs = [textbox, upvote_btn, downvote_btn, flag_btn]
    chat_outputs = [state, chatbot, textbox, imagebox] + btn_list
    add_text_inputs = [state, imagebox, textbox, image_process_mode]
    generate_inputs = [state, imagebox, textbox, image_process_mode, temperature, top_p, max_output_tokens]

    upvote_btn.click(upvote_last_response, [state], vote_outputs)
    downvote_btn.click(downvote_last_response, [state], vote_outputs)
    flag_btn.click(flag_last_response, [state], vote_outputs)
    clear_btn.click(clear_history, None, chat_outputs, queue=False)
    regenerate_btn.click(
        delete_text,
        [state, image_process_mode],
        chat_outputs,
    ).then(generate, generate_inputs, chat_outputs)
    # Pressing ENTER and clicking "Send" run the same add-then-generate chain.
    textbox.submit(add_text, add_text_inputs, chat_outputs).then(generate, generate_inputs, chat_outputs)
    submit_btn.click(add_text, add_text_inputs, chat_outputs).then(generate, generate_inputs, chat_outputs)

demo.queue(
    status_update_rate=10,
    api_open=False
).launch()