|
""" |
|
gradio_web_server.py |
|
|
|
Entry point for all VLM-Evaluation interactive demos; specify model and get a gradio UI where you can chat with it! |
|
|
|
This file is adapted from the script that defines the gradio web server in the LLaVA codebase:
https://github.com/haotian-liu/LLaVA/blob/main/llava/serve/gradio_web_server.py with only very minor
modifications.
|
""" |
|
|
|
import argparse |
|
import datetime |
|
import hashlib |
|
import json |
|
import os |
|
import time |
|
|
|
import gradio as gr |
|
import requests |
|
from llava.constants import LOGDIR |
|
from llava.conversation import conv_templates, default_conversation |
|
from llava.utils import build_logger, moderation_msg, server_error_msg, violates_moderation |
|
|
|
from serve import INTERACTION_MODES_MAP, MODEL_ID_TO_NAME |
|
|
|
# Module-wide logger used by every handler below.
logger = build_logger("gradio_web_server", "gradio_web_server.log")

# Extra HTTP headers sent with the streaming request to a model worker.
headers = {"User-Agent": "PrismaticVLMs Client"}

# Reusable button updates returned by the event handlers: leave a button as it
# is, make it clickable, or grey it out.
no_change_btn = gr.Button.update()
enable_btn, disable_btn = (gr.Button.update(interactive=flag) for flag in (True, False))
|
|
|
|
|
def get_conv_log_filename():
    """Return the path of today's conversation log inside ``LOGDIR``.

    The file name is ``<year>-<MM>-<DD>-conv.json``, built from the local date at call time.
    """
    today = datetime.datetime.now()
    basename = f"{today.year}-{today.month:02d}-{today.day:02d}-conv.json"
    return os.path.join(LOGDIR, basename)
|
|
|
|
|
def get_model_list():
    """Fetch the models currently served by the controller, in display order.

    The controller at ``args.controller_url`` is first asked to refresh all of its
    workers and then queried for its model list. Models whose name appears among
    the values of ``MODEL_ID_TO_NAME`` come first, in that mapping's order; all
    other models follow in the order the controller returned them (the sort is stable).

    Returns:
        list: The ordered model names.

    Raises:
        RuntimeError: If the refresh request is not answered with HTTP 200.
    """
    ret = requests.post(args.controller_url + "/refresh_all_workers")
    if ret.status_code != 200:
        # An explicit raise (rather than ``assert``) keeps the check alive under ``python -O``.
        raise RuntimeError(f"Controller refresh failed with HTTP status {ret.status_code}")

    ret = requests.post(args.controller_url + "/list_models")
    models = ret.json()["models"]

    # Build the name -> position lookup once instead of re-scanning the mapping's
    # values for every element being sorted. ``setdefault`` keeps the first
    # position of a repeated name, matching ``list.index``.
    preferred_rank = {}
    for position, known_name in enumerate(MODEL_ID_TO_NAME.values()):
        preferred_rank.setdefault(known_name, position)
    unknown_rank = len(models)

    models = sorted(models, key=lambda name: preferred_rank.get(name, unknown_rank))
    logger.info(f"Models: {models}")
    return models
|
|
|
|
|
# JavaScript run in the browser when the page loads (passed as ``_js`` to
# ``demo.load`` in ``build_demo``). It returns the page's query-string
# parameters as a plain object, which becomes ``load_demo``'s ``url_params``.
# ``url_params`` is declared with ``const`` so the snippet does not create an
# implicit global on ``window``.
get_window_url_params = """
function() {
    const params = new URLSearchParams(window.location.search);
    const url_params = Object.fromEntries(params);
    console.log(url_params);
    return url_params;
}
"""
|
|
|
|
|
def load_demo(url_params, request: gr.Request):
    """Page-load handler for the "once" model-list mode.

    Args:
        url_params: Query-string parameters of the page, as a mapping.
        request: Gradio request; only the client IP is read, for logging.

    Returns:
        A fresh copy of the default conversation and an update for the model
        dropdown. The dropdown is pre-selected when ``url_params["model"]``
        names one of the module-level ``models``.
    """
    logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")

    preselect = "model" in url_params and url_params["model"] in models
    if preselect:
        dropdown_update = gr.Dropdown.update(value=url_params["model"], visible=True)
    else:
        dropdown_update = gr.Dropdown.update(visible=True)

    return default_conversation.copy(), dropdown_update
|
|
|
|
|
def load_demo_refresh_model_list(request: gr.Request):
    """Page-load handler for the "reload" model-list mode.

    Re-queries the controller on every page load and returns a fresh copy of the
    default conversation together with a dropdown update. The update carries the
    new choices and selects the first one, or ``""`` when no model is available.
    """
    logger.info(f"load_demo. ip: {request.client.host}")
    available = get_model_list()
    fresh_state = default_conversation.copy()
    first_choice = available[0] if available else ""
    return fresh_state, gr.Dropdown.update(choices=available, value=first_choice)
|
|
|
|
|
def vote_last_response(state, vote_type, model_selector, request: gr.Request):
    """Append one JSON line recording a vote to today's conversation log.

    Args:
        state: Conversation object; serialized through ``state.dict()``.
        vote_type: Stored under the ``"type"`` key.
        model_selector: Stored under the ``"model"`` key.
        request: Gradio request; the client IP is stored under ``"ip"``.
    """
    with open(get_conv_log_filename(), "a") as log_file:
        record = dict(
            tstamp=round(time.time(), 4),
            type=vote_type,
            model=model_selector,
            state=state.dict(),
            ip=request.client.host,
        )
        log_file.write(f"{json.dumps(record)}\n")
|
|
|
|
|
def regenerate(state, image_process_mode, request: gr.Request):
    """Prepare ``state`` so the last reply can be generated again.

    The content of the last message is cleared. If the message before it carries
    a tuple/list payload, its first two items are kept and the current
    ``image_process_mode`` is attached as the third.

    Returns:
        ``(state, chatbot_messages, "", None)`` followed by five ``disable_btn`` updates.
    """
    logger.info(f"regenerate. ip: {request.client.host}")

    # Blank the reply slot first, then look at the message that prompted it.
    state.messages[-1][-1] = None
    user_turn = state.messages[-2]
    content = user_turn[1]
    if type(content) in (tuple, list):
        user_turn[1] = tuple(content[:2]) + (image_process_mode,)

    state.skip_next = False
    locked = (disable_btn,) * 5
    return (state, state.to_gradio_chatbot(), "", None) + locked
|
|
|
|
|
def clear_history(request: gr.Request):
    """Reset the chat to a fresh copy of the default conversation.

    Returns:
        ``(state, chatbot_messages, "", None)`` followed by five ``disable_btn`` updates.
    """
    logger.info(f"clear_history. ip: {request.client.host}")
    fresh = default_conversation.copy()
    return (fresh, fresh.to_gradio_chatbot(), "", None, *[disable_btn] * 5)
|
|
|
|
|
def add_text(state, text, image, image_process_mode, request: gr.Request):
    """Append the user's text (and optional image) to the conversation.

    Args:
        state: Current conversation object.
        text: Text typed by the user.
        image: Uploaded image, or ``None``.
        image_process_mode: Stored next to the image in the message payload.
        request: Gradio request; only the client IP is read, for logging.

    Returns:
        ``(state, chatbot_messages, textbox_value, None)`` followed by five button
        updates. ``state.skip_next`` is set to ``True`` when nothing was appended
        (empty input, or text rejected by moderation), otherwise to ``False``.
    """
    logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
    unchanged = (no_change_btn,) * 5

    # Nothing typed and nothing uploaded: leave the conversation untouched.
    if image is None and len(text) <= 0:
        state.skip_next = True
        return (state, state.to_gradio_chatbot(), "", None) + unchanged

    # Optional moderation; the moderation notice is placed in the textbox.
    if args.moderate and violates_moderation(text):
        state.skip_next = True
        return (state, state.to_gradio_chatbot(), moderation_msg, None) + unchanged

    message = text[:1536]
    if image is not None:
        # A shorter text limit applies when an image accompanies the message.
        message = message[:1200]
        if "<image>" not in message:
            message = message + "\n<image>"
        message = (message, image, image_process_mode)
        # Start over from the default conversation if the current one already holds an image.
        if len(state.get_images(return_pil=True)) > 0:
            state = default_conversation.copy()

    state.append_message(state.roles[0], message)
    state.append_message(state.roles[1], None)
    state.skip_next = False
    return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
|
|
|
|
|
def http_bot(state, model_selector, interaction_mode, temperature, max_new_tokens, request: gr.Request):
    """Stream the model's reply for the pending turn in ``state`` and log the exchange.

    This is a generator used as a Gradio event handler. Every ``yield`` is
    ``(state, chatbot_messages, regenerate_update, clear_update)``: one entry per
    output wired in ``build_demo`` (``[state, chatbot, *btn_list]``, where
    ``btn_list`` holds the Regenerate and Clear buttons). Gradio assigns outputs
    by position, so the button updates must line up with that list. In
    particular, the error paths must hand *both* buttons an enabling update, or
    the user can neither retry nor clear after a failure.

    Args:
        state: Conversation object; the reply is written into its last message.
        model_selector: Model name chosen in the dropdown; used to look up a worker.
        interaction_mode: Forwarded unchanged to the worker as ``"interaction_mode"``.
        temperature: Forwarded to the worker after conversion with ``float``.
        max_new_tokens: Forwarded to the worker after conversion with ``int``.
        request: Gradio request; only the client IP is read, for logging.

    Yields:
        Progressive UI updates while the worker streams text, then a final update
        that re-enables the buttons.
    """
    logger.info(f"http_bot. ip: {request.client.host}")
    start_tstamp = time.time()
    model_name = model_selector

    # One update per button wired as an output of this handler.
    unchanged_btns = (no_change_btn,) * 2
    locked_btns = (disable_btn,) * 2
    unlocked_btns = (enable_btn,) * 2

    # Block cursor appended to the partial reply while text is still streaming in.
    cursor = "▌"

    if state.skip_next:
        # add_text flagged this turn as a no-op (empty or moderated input).
        yield (state, state.to_gradio_chatbot()) + unchanged_btns
        return

    if len(state.messages) == state.offset + 2:
        # The conversation holds exactly one user/assistant pair beyond the
        # template's offset: rebuild it on the "llava_v1" template.
        new_state = conv_templates["llava_v1"].copy()
        new_state.append_message(new_state.roles[0], state.messages[-2][1])
        new_state.append_message(new_state.roles[1], None)
        state = new_state

    # Ask the controller which worker serves the selected model.
    controller_url = args.controller_url
    ret = requests.post(controller_url + "/get_worker_address", json={"model": model_name})
    worker_addr = ret.json()["address"]
    logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")

    # An empty address means no worker is available for this model.
    if worker_addr == "":
        state.messages[-1][-1] = server_error_msg
        yield (state, state.to_gradio_chatbot()) + unlocked_btns
        return

    prompt = state.get_prompt()

    # Save every image of the conversation under LOGDIR/serve_images/<date>/<md5>.jpg
    # unless a file of that name already exists.
    all_images = state.get_images(return_pil=True)
    all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
    for image, im_hash in zip(all_images, all_image_hash):
        t = datetime.datetime.now()
        filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{im_hash}.jpg")
        if not os.path.isfile(filename):
            os.makedirs(os.path.dirname(filename), exist_ok=True)
            image.save(filename)

    # Fetch the payload form of the images once; it is used for the log summary
    # and for the request body.
    encoded_images = state.get_images()
    pload = {
        "model": model_name,
        "prompt": prompt,
        "interaction_mode": interaction_mode,
        "temperature": float(temperature),
        "max_new_tokens": int(max_new_tokens),
        # Log a short summary in place of the image payload itself.
        "images": f"List of {len(encoded_images)} images: {all_image_hash}",
    }
    logger.info(f"==== request ====\n{pload}")
    pload["images"] = encoded_images

    state.messages[-1][-1] = cursor
    yield (state, state.to_gradio_chatbot()) + locked_btns

    # Stays empty if the worker closes the stream without sending any chunk, so
    # the final log call below never hits an unbound name.
    output = ""
    try:
        # The context manager releases the streamed connection on every exit
        # path, including the early return on a worker-reported error.
        with requests.post(
            worker_addr + "/worker_generate_stream", headers=headers, json=pload, stream=True, timeout=10
        ) as response:
            for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
                if not chunk:
                    continue
                data = json.loads(chunk.decode())
                if data["error_code"] != 0:
                    output = data["text"] + f" (error_code: {data['error_code']})"
                    state.messages[-1][-1] = output
                    yield (state, state.to_gradio_chatbot()) + unlocked_btns
                    return
                # The worker's text starts with the prompt; keep only what follows it.
                output = data["text"][len(prompt) :].strip()
                state.messages[-1][-1] = output + cursor
                yield (state, state.to_gradio_chatbot()) + locked_btns
                time.sleep(0.03)
    except requests.exceptions.RequestException:
        state.messages[-1][-1] = server_error_msg
        yield (state, state.to_gradio_chatbot()) + unlocked_btns
        return

    # Drop the trailing cursor character from the finished reply.
    state.messages[-1][-1] = state.messages[-1][-1][:-1]
    yield (state, state.to_gradio_chatbot()) + unlocked_btns

    finish_tstamp = time.time()
    logger.info(f"{output}")

    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": model_name,
            "start": round(start_tstamp, 4),
            "finish": round(finish_tstamp, 4),
            "state": state.dict(),
            "images": all_image_hash,
            "ip": request.client.host,
        }
        fout.write(json.dumps(data) + "\n")
|
|
|
|
|
# Markdown header shown at the top of the page when the demo is not embedded.
# The repository links carry an explicit ``https://`` scheme; without one,
# Markdown treats the target as a path relative to the demo's own URL.
title_markdown = """
# Prismatic VLMs: Investigating the Design Space of Visually-Conditioned Language Models
[[Training Code](https://github.com/TRI-ML/prismatic-vlms)]
[[Evaluation Code](https://github.com/TRI-ML/vlm-evaluation)]
| 📚 [[Paper](https://arxiv.org/abs/2402.07865)]
"""
|
|
|
# Terms-of-use notice (Markdown); build_demo renders it when the demo is not embedded.
tos_markdown = """
### Terms of use
By using this service, users are required to agree to the following terms:
The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may
generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The
service may collect user dialogue data for future research. For an optimal experience,
please use desktop computers for this demo, as mobile devices may compromise its quality. This website
is heavily inspired by the website released by [LLaVA](https://github.com/haotian-liu/LLaVA).
"""
|
|
|
|
|
# License notice (Markdown); build_demo renders it when the demo is not embedded.
# The backslash after "[Privacy Practices]" joins the link text and its URL into
# one line of the resulting string. Markdown only forms a link when "(" directly
# follows "]", so a plain line break there would show the raw URL instead.
learn_more_markdown = """
### License
The service is a research preview intended for non-commercial use only, subject to the model
[License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA,
[Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI,
and [Privacy Practices]\
(https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb)
of ShareGPT. Please contact us if you find any potential violation.
"""
|
|
|
# Custom CSS: gives every <button> inside the element with id "buttons" a
# minimum width of 120px, capped at the width of its container.
block_css = """

#buttons button {
    min-width: min(120px,100%);
}

"""
|
|
|
|
|
def build_demo(embed_mode):
    """Assemble the Gradio Blocks UI and wire its event handlers.

    Reads the module-level ``models`` (initial dropdown choices) and ``args``
    (``model_list_mode``), both set by the script entry point before this is called.

    Args:
        embed_mode: When true, the title, terms-of-use and license Markdown sections are left out.

    Returns:
        The constructed ``gr.Blocks`` app; queueing and launching are left to the caller.

    Raises:
        ValueError: If ``args.model_list_mode`` is neither ``"once"`` nor ``"reload"``.
    """
    # Created outside the Blocks context so gr.Examples can reference it in the
    # left column; it is rendered explicitly in the right column.
    textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)

    # ``css=block_css`` makes the "#buttons" rule apply to the button row below.
    with gr.Blocks(theme=gr.themes.Default(primary_hue="red", secondary_hue="stone"), css=block_css) as demo:
        state = gr.State()

        if not embed_mode:
            gr.Markdown(title_markdown)

        with gr.Row():
            # Left column: model choice, image input, examples and generation settings.
            with gr.Column(scale=3):
                with gr.Row(elem_id="model_selector_row"):
                    model_selector = gr.Dropdown(
                        choices=models,
                        value=models[0] if len(models) > 0 else "",
                        interactive=True,
                        show_label=False,
                        container=False,
                    )

                imagebox = gr.Image(type="pil")
                image_process_mode = gr.Radio(
                    ["Crop", "Resize", "Pad", "Default"],
                    value="Default",
                    label="Preprocess for non-square image",
                    visible=False,
                )

                cur_dir = os.path.dirname(os.path.abspath(__file__))
                gr.Examples(
                    examples=[
                        [f"{cur_dir}/examples/cows_in_pasture.png", "How many cows are in this image?"],
                        [
                            f"{cur_dir}/examples/monkey_knives.png",
                            "What is happening in this image?",
                        ],
                    ],
                    inputs=[imagebox, textbox],
                )

                with gr.Accordion("Parameters", open=False):
                    temperature = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=0.2,
                        step=0.1,
                        interactive=True,
                        label="Temperature",
                    )
                    max_output_tokens = gr.Slider(
                        minimum=0,
                        maximum=4096,
                        value=2048,
                        step=64,
                        interactive=True,
                        label="Max output tokens",
                    )

                with gr.Accordion("Interaction Mode", open=False):
                    interaction_modes = list(INTERACTION_MODES_MAP.keys())
                    interaction_mode = gr.Dropdown(
                        choices=interaction_modes,
                        value=interaction_modes[0] if len(interaction_modes) > 0 else "Chat",
                        interactive=True,
                        show_label=False,
                        container=False,
                    )

            # Right column: chat transcript, prompt box and action buttons.
            with gr.Column(scale=8):
                chatbot = gr.Chatbot(elem_id="chatbot", label="PrismaticVLMs Chatbot", height=550)
                with gr.Row():
                    with gr.Column(scale=8):
                        textbox.render()
                    with gr.Column(scale=1, min_width=50):
                        submit_btn = gr.Button(value="Generate", variant="primary")
                with gr.Row(elem_id="buttons"):
                    regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
                    clear_btn = gr.Button(value="🗑️ Clear", interactive=False)

        if not embed_mode:
            gr.Markdown(tos_markdown)
            gr.Markdown(learn_more_markdown)
        url_params = gr.JSON(visible=False)

        # ---- Event wiring -------------------------------------------------
        btn_list = [regenerate_btn, clear_btn]

        # Outputs refreshed by the quick, unqueued handlers (add_text, regenerate, clear_history).
        ui_outputs = [state, chatbot, textbox, imagebox, *btn_list]
        # Inputs and outputs of the streaming generation step.
        generation_inputs = [state, model_selector, interaction_mode, temperature, max_output_tokens]
        generation_outputs = [state, chatbot, *btn_list]

        regenerate_btn.click(regenerate, [state, image_process_mode], ui_outputs, queue=False).then(
            http_bot, generation_inputs, generation_outputs
        )

        clear_btn.click(clear_history, None, ui_outputs, queue=False)

        # Pressing ENTER in the textbox and clicking "Generate" trigger the same pipeline.
        for trigger in (textbox.submit, submit_btn.click):
            trigger(
                add_text,
                [state, textbox, imagebox, image_process_mode],
                ui_outputs,
                queue=False,
            ).then(http_bot, generation_inputs, generation_outputs)

        if args.model_list_mode == "once":
            demo.load(load_demo, [url_params], [state, model_selector], _js=get_window_url_params, queue=False)
        elif args.model_list_mode == "reload":
            demo.load(load_demo_refresh_model_list, None, [state, model_selector], queue=False)
        else:
            raise ValueError(f"Unknown model list mode: {args.model_list_mode}")

    return demo
|
|
|
|
|
if __name__ == "__main__":
    # Command-line options. ``args`` and ``models`` are deliberately left as
    # module globals because the handlers and ``build_demo`` read them directly.
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int)
    parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
    parser.add_argument("--concurrency-count", type=int, default=10)
    parser.add_argument("--model-list-mode", type=str, default="once", choices=["once", "reload"])
    for switch in ("--share", "--moderate", "--embed"):
        parser.add_argument(switch, action="store_true")
    args = parser.parse_args()
    logger.info(f"args: {args}")

    # Query the controller once at start-up for the initial model list.
    models = get_model_list()

    logger.info(args)
    demo = build_demo(args.embed)
    queued_demo = demo.queue(concurrency_count=args.concurrency_count, api_open=False)
    queued_demo.launch(server_name=args.host, server_port=args.port, share=args.share)
|
|