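"""Gradio entry point for the LLM-As-Chatbot project.

The app walks the user through choosing a model from the pool (or bringing
their own), loads it, and then serves a multi-channel chat UI.
"""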
import os
import time
import json
import copy
import types

from os import listdir
from os.path import isfile, join

import argparse
import gradio as gr

import global_vars
from chats import central
from transformers import AutoModelForCausalLM
from miscs.styles import MODEL_SELECTION_CSS
from miscs.js import GET_LOCAL_STORAGE, UPDATE_LEFT_BTNS_STATE
from utils import get_chat_manager, get_global_context

from pingpong.pingpong import PingPong
from pingpong.gradio import (
    GradioAlpacaChatPPManager,
    GradioKoAlpacaChatPPManager,
    GradioStableLMChatPPManager,
    GradioFlanAlpacaChatPPManager,
    GradioOSStableLMChatPPManager,
    GradioVicunaChatPPManager,
    GradioStableVicunaChatPPManager,
    GradioStarChatPPManager,
    GradioMPTChatPPManager,
    GradioRedPajamaChatPPManager,
    GradioBaizeChatPPManager,
)
# no cpu for
# - falcon families (too slow)
load_mode_list = ["cpu"]
with open("examples.txt", "r") as ex_file:
    examples = ex_file.read().split("\n")
ex_btns = []

with open("channels.txt", "r") as chl_file:
    channels = chl_file.read().split("\n")
channel_btns = []
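# Dummy conversation used to preview how each prompt style is structured
# in the Bring-Your-Own-Model view.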
default_ppm = GradioAlpacaChatPPManager()
default_ppm.ctx = "Context at top"
default_ppm.pingpongs = [
    PingPong("user input #1...", "bot response #1..."),
    PingPong("user input #2...", "bot response #2..."),
]
chosen_ppm = copy.deepcopy(default_ppm)
prompt_styles = {
    "Alpaca": default_ppm,
    "Baize": GradioBaizeChatPPManager(),
    "Koalpaca": GradioKoAlpacaChatPPManager(),
    "MPT": GradioMPTChatPPManager(),
    "OpenAssistant StableLM": GradioOSStableLMChatPPManager(),
    "RedPajama": GradioRedPajamaChatPPManager(),
    "StableVicuna": GradioStableVicunaChatPPManager(),
    "StableLM": GradioStableLMChatPPManager(),
    "StarChat": GradioStarChatPPManager(),
    "Vicuna": GradioVicunaChatPPManager(),
}
response_configs = [
    f"configs/response_configs/{f}"
    for f in listdir("configs/response_configs")
    if isfile(join("configs/response_configs", f))
]

summarization_configs = [
    f"configs/summarization_configs/{f}"
    for f in listdir("configs/summarization_configs")
    if isfile(join("configs/summarization_configs", f))
]

with open("model_cards.json") as model_cards_file:
    model_info = json.load(model_cards_file)
###

def move_to_model_select_view():
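    """Hide the landing view and reveal the model-selection view."""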
    return (
        "move to model select view",
        gr.update(visible=False),
        gr.update(visible=True),
    )


def use_chosen_model():
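    """Jump straight back into the chat view with the previously loaded model.

    Raises gr.Error if no model has been chosen in this session yet.
    """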
    try:
        _ = global_vars.model
    except AttributeError:
        raise gr.Error("There is no previously chosen model")
    gen_config = global_vars.gen_config
    gen_sum_config = global_vars.gen_config_summarization

    if global_vars.model_type == "custom":
        ppmanager_type = chosen_ppm
    else:
        ppmanager_type = get_chat_manager(global_vars.model_type)

    return (
        "Preparation done!",
        gr.update(visible=False),
        gr.update(visible=True),
        gr.update(label=global_vars.model_type),
        {
            "ppmanager_type": ppmanager_type,
            "model_type": global_vars.model_type,
        },
        get_global_context(global_vars.model_type),
        gen_config.temperature,
        gen_config.top_p,
        gen_config.top_k,
        gen_config.repetition_penalty,
        gen_config.max_new_tokens,
        gen_config.num_beams,
        gen_config.use_cache,
        gen_config.do_sample,
        gen_config.eos_token_id,
        gen_config.pad_token_id,
        gen_sum_config.temperature,
        gen_sum_config.top_p,
        gen_sum_config.top_k,
        gen_sum_config.repetition_penalty,
        gen_sum_config.max_new_tokens,
        gen_sum_config.num_beams,
        gen_sum_config.use_cache,
        gen_sum_config.do_sample,
        gen_sum_config.eos_token_id,
        gen_sum_config.pad_token_id,
    )

def move_to_byom_view():
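    """Open the Bring-Your-Own-Model view.

    Only the load modes supported by the current hardware are offered.
    """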
    load_mode_list = []

    if global_vars.cuda_availability:
        load_mode_list.extend(["gpu(half)", "gpu(load_in_8bit)", "gpu(load_in_4bit)"])

    if global_vars.mps_availability:
        load_mode_list.append("apple silicon")

    load_mode_list.append("cpu")

    return (
        "move to the byom view",
        gr.update(visible=False),
        gr.update(visible=True),
        gr.update(choices=load_mode_list, value=load_mode_list[0])
    )

def prompt_style_change(key):
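    """Refresh the prompt-style preview and remember the chosen manager
    so the BYOM flow can use it later."""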
    global chosen_ppm

    ppm = prompt_styles[key]
    ppm.ctx = "Context at top"
    ppm.pingpongs = [
        PingPong("user input #1...", "bot response #1..."),
        PingPong("user input #2...", "bot response #2..."),
    ]

    chosen_ppm = copy.deepcopy(ppm)
    chosen_ppm.ctx = ""
    chosen_ppm.pingpongs = []

    return ppm.build_prompts()

def byom_load(
    base, ckpt, model_cls, tokenizer_cls,
    bos_token_id, eos_token_id, pad_token_id,
    load_mode,
):
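    """Load a user-provided model via global_vars.initialize_globals_byom.

    The boolean mode flags passed below are derived from the selected
    load-mode radio choice.
    """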
    # mode_cpu, mode_mps, mode_8bit, mode_4bit, mode_full_gpu
    global_vars.initialize_globals_byom(
        base, ckpt, model_cls, tokenizer_cls,
        bos_token_id, eos_token_id, pad_token_id,
        load_mode == "cpu",
        load_mode == "apple silicon",
        load_mode == "gpu(load_in_8bit)",
        load_mode == "gpu(load_in_4bit)",
        load_mode == "gpu(half)",
    )

    return ""

def channel_num(btn_title):
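    """Return the index of the channel matching the clicked button (0 if none)."""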
    choice = 0

    for idx, channel in enumerate(channels):
        if channel == btn_title:
            choice = idx

    return choice


def set_chatbot(btn, ld, state):
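    """Rebuild the chat UI for the selected channel from the locally stored history."""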
    choice = channel_num(btn)

    res = [state["ppmanager_type"].from_json(json.dumps(ppm_str)) for ppm_str in ld]
    empty = len(res[choice].pingpongs) == 0
    return (res[choice].build_uis(), choice, gr.update(visible=empty), gr.update(interactive=not empty))


def set_example(btn):
    return btn, gr.update(visible=False)


def set_popup_visibility(ld, example_block):
    return example_block


def move_to_second_view(btn):
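    """Show the model confirmation view for the clicked model card.

    Only load modes whose VRAM requirement (plus a guard margin) fits into
    the available VRAM are offered.
    """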
    info = model_info[btn]

    guard_vram = 5 * 1024.
    vram_req_full = int(info["vram(full)"]) + guard_vram
    vram_req_8bit = int(info["vram(8bit)"]) + guard_vram
    vram_req_4bit = int(info["vram(4bit)"]) + guard_vram

    load_mode_list = []

    if global_vars.cuda_availability:
        print(f"total vram = {global_vars.available_vrams_mb}")
        print(f"required vram(full={info['vram(full)']}, 8bit={info['vram(8bit)']}, 4bit={info['vram(4bit)']})")

        if global_vars.available_vrams_mb >= vram_req_full:
            load_mode_list.append("gpu(half)")

        if global_vars.available_vrams_mb >= vram_req_8bit:
            load_mode_list.append("gpu(load_in_8bit)")

        if global_vars.available_vrams_mb >= vram_req_4bit:
            load_mode_list.append("gpu(load_in_4bit)")

    if global_vars.mps_availability:
        load_mode_list.append("apple silicon")

    load_mode_list.extend(["cpu"])

    return (
        gr.update(visible=False),
        gr.update(visible=True),
        info["thumb"],
        f"## {btn}",
        f"**Parameters**\n: Approx. {info['parameters']}",
        f"**🤗 Hub(base)**\n: {info['hub(base)']}",
        f"**🤗 Hub(LoRA)**\n: {info['hub(ckpt)']}",
        info['desc'],
        f"""**Min VRAM requirements** :
| half precision | load_in_8bit | load_in_4bit |
| ------------------------------------- | ---------------------------------- | ---------------------------------- |
| {round(vram_req_full/1024., 1)}GiB | {round(vram_req_8bit/1024., 1)}GiB | {round(vram_req_4bit/1024., 1)}GiB |
""",
        info['default_gen_config'],
        info['example1'],
        info['example2'],
        info['example3'],
        info['example4'],
        info['thumb-tiny'],
        gr.update(choices=load_mode_list, value=load_mode_list[0]),
        "",
    )

def move_to_first_view():
    return (gr.update(visible=True), gr.update(visible=False))


def download_completed(
    model_name,
    model_base,
    model_ckpt,
    gen_config_path,
    gen_config_sum_path,
    load_mode,
    thumbnail_tiny,
    force_download,
):
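    """Build a namespace of download/load options from the UI widgets and
    initialize the chosen model.

    Raises gr.Error if the model does not fit into GPU memory.
    """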
    global local_files_only

    tmp_args = types.SimpleNamespace()
    tmp_args.base_url = model_base.split(":")[-1].split("</p")[0].strip()
    tmp_args.ft_ckpt_url = model_ckpt.split(":")[-1].split("</p")[0].strip()
    tmp_args.gen_config_path = gen_config_path
    tmp_args.gen_config_summarization_path = gen_config_sum_path
    tmp_args.force_download_ckpt = force_download
    tmp_args.thumbnail_tiny = thumbnail_tiny

    tmp_args.mode_cpu = load_mode == "cpu"
    tmp_args.mode_mps = load_mode == "apple silicon"
    tmp_args.mode_8bit = load_mode == "gpu(load_in_8bit)"
    tmp_args.mode_4bit = load_mode == "gpu(load_in_4bit)"
    tmp_args.mode_full_gpu = load_mode == "gpu(half)"
    tmp_args.local_files_only = local_files_only
    try:
        global_vars.initialize_globals(tmp_args)
    except RuntimeError as e:
        raise gr.Error("GPU memory is not enough to load this model.") from e

    return "Download completed!"

def move_to_third_view():
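    """Switch to the chat view once a model has been loaded.

    The control panel widgets are populated from the generation configs.
    """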
    gen_config = global_vars.gen_config
    gen_sum_config = global_vars.gen_config_summarization

    if global_vars.model_type == "custom":
        ppmanager_type = chosen_ppm
    else:
        ppmanager_type = get_chat_manager(global_vars.model_type)

    return (
        "Preparation done!",
        gr.update(visible=False),
        gr.update(visible=True),
        gr.update(label=global_vars.model_type),
        {
            "ppmanager_type": ppmanager_type,
            "model_type": global_vars.model_type,
        },
        get_global_context(global_vars.model_type),
        gen_config.temperature,
        gen_config.top_p,
        gen_config.top_k,
        gen_config.repetition_penalty,
        gen_config.max_new_tokens,
        gen_config.num_beams,
        gen_config.use_cache,
        gen_config.do_sample,
        gen_config.eos_token_id,
        gen_config.pad_token_id,
        gen_sum_config.temperature,
        gen_sum_config.top_p,
        gen_sum_config.top_k,
        gen_sum_config.repetition_penalty,
        gen_sum_config.max_new_tokens,
        gen_sum_config.num_beams,
        gen_sum_config.use_cache,
        gen_sum_config.do_sample,
        gen_sum_config.eos_token_id,
        gen_sum_config.pad_token_id,
    )


def toggle_inspector(view_selector):
    if view_selector == "with context inspector":
        return gr.update(visible=True)
    else:
        return gr.update(visible=False)


def reset_chat(idx, ld, state):
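    """Clear the current channel's conversation."""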
    res = [state["ppmanager_type"].from_json(json.dumps(ppm_str)) for ppm_str in ld]
    res[idx].pingpongs = []

    return (
        "",
        [],
        str(res),
        gr.update(visible=True),
        gr.update(interactive=False),
    )


def rollback_last(idx, ld, state):
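    """Drop the last exchange and return the user's message for regeneration."""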
    res = [state["ppmanager_type"].from_json(json.dumps(ppm_str)) for ppm_str in ld]
    last_user_message = res[idx].pingpongs[-1].ping
    res[idx].pingpongs = res[idx].pingpongs[:-1]

    return (
        last_user_message,
        res[idx].build_uis(),
        str(res),
        gr.update(interactive=False)
    )

def gradio_main(args):
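    """Build the Gradio UI, wire up all event handlers, and launch the app."""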
    global local_files_only
    local_files_only = args.local_files_only

    with gr.Blocks(css=MODEL_SELECTION_CSS, theme='gradio/soft') as demo:
        with gr.Column(visible=True, elem_id="landing-container") as landing_view:
            gr.Markdown("# Chat with LLM", elem_classes=["center"])

            with gr.Row(elem_id="landing-container-selection"):
                with gr.Column():
gr.Markdown("""This is the landing page of the project, [LLM As Chatbot](https://github.com/deep-diver/LLM-As-Chatbot). This appliction is designed for personal use only. A single model will be selected at a time even if you open up a new browser or a tab. As an initial choice, please select one of the following menu""") | |
gr.Markdown(""" | |
**Bring your own model**: You can chat with arbitrary models. If your own custom model is based on 🤗 Hugging Face's [transformers](https://huggingface.co/docs/transformers/index) library, you will propbably be able to bring it into this application with this menu | |
**Select a model from model pool**: You can chat with one of the popular open source Large Language Model | |
**Use currently selected model**: If you have already selected, but if you came back to this landing page accidently, you can directly go back to the chatting mode with this menu | |
""") | |
                    byom = gr.Button("🫵🏼 Bring your own model", elem_id="go-byom-select", elem_classes=["square", "landing-btn"])
                    select_model = gr.Button("🦙 Select a model from model pool", elem_id="go-model-select", elem_classes=["square", "landing-btn"])
                    chosen_model = gr.Button("↪️ Use currently selected model", elem_id="go-use-selected-model", elem_classes=["square", "landing-btn"])

            with gr.Column(elem_id="landing-bottom"):
                progress_view0 = gr.Textbox(label="Progress", elem_classes=["progress-view"])
                gr.Markdown("""[project](https://github.com/deep-diver/LLM-As-Chatbot)
[developer](https://github.com/deep-diver)
""", elem_classes=["center"])
        with gr.Column(visible=False) as model_choice_view:
            gr.Markdown("# Choose a Model", elem_classes=["center"])

            with gr.Row(elem_id="container"):
                with gr.Column():
                    gr.Markdown("## ~ 10B Parameters")
                    with gr.Row(elem_classes=["sub-container"]):
                        with gr.Column(min_width=20):
                            t5_vicuna_3b = gr.Button("t5-vicuna-3b", elem_id="t5-vicuna-3b", elem_classes=["square"])
                            gr.Markdown("T5 Vicuna", elem_classes=["center"])
                        with gr.Column(min_width=20, visible=False):
                            flan3b = gr.Button("flan-3b", elem_id="flan-3b", elem_classes=["square"])
                            gr.Markdown("Flan-XL", elem_classes=["center"])
                        # with gr.Column(min_width=20):
                        #     replit_3b = gr.Button("replit-3b", elem_id="replit-3b", elem_classes=["square"])
                        #     gr.Markdown("Replit Instruct", elem_classes=["center"])
                        with gr.Column(min_width=20):
                            camel5b = gr.Button("camel-5b", elem_id="camel-5b", elem_classes=["square"])
                            gr.Markdown("Camel", elem_classes=["center"])
                        with gr.Column(min_width=20):
                            alpaca_lora7b = gr.Button("alpaca-lora-7b", elem_id="alpaca-lora-7b", elem_classes=["square"])
                            gr.Markdown("Alpaca-LoRA", elem_classes=["center"])
                        with gr.Column(min_width=20):
                            stablelm7b = gr.Button("stablelm-7b", elem_id="stablelm-7b", elem_classes=["square"])
                            gr.Markdown("StableLM", elem_classes=["center"])
                        with gr.Column(min_width=20, visible=False):
                            os_stablelm7b = gr.Button("os-stablelm-7b", elem_id="os-stablelm-7b", elem_classes=["square"])
                            gr.Markdown("OA+StableLM", elem_classes=["center"])
                        with gr.Column(min_width=20):
                            gpt4_alpaca_7b = gr.Button("gpt4-alpaca-7b", elem_id="gpt4-alpaca-7b", elem_classes=["square"])
                            gr.Markdown("GPT4-Alpaca-LoRA", elem_classes=["center"])
                        with gr.Column(min_width=20):
                            mpt_7b = gr.Button("mpt-7b", elem_id="mpt-7b", elem_classes=["square"])
                            gr.Markdown("MPT", elem_classes=["center"])
                        with gr.Column(min_width=20):
                            redpajama_7b = gr.Button("redpajama-7b", elem_id="redpajama-7b", elem_classes=["square"])
                            gr.Markdown("RedPajama", elem_classes=["center"])
                        with gr.Column(min_width=20, visible=False):
                            redpajama_instruct_7b = gr.Button("redpajama-instruct-7b", elem_id="redpajama-instruct-7b", elem_classes=["square"])
                            gr.Markdown("RedPajama Instruct", elem_classes=["center"])
                        with gr.Column(min_width=20):
                            vicuna_7b = gr.Button("vicuna-7b", elem_id="vicuna-7b", elem_classes=["square"])
                            gr.Markdown("Vicuna", elem_classes=["center"])
                        with gr.Column(min_width=20):
                            vicuna_7b_1_3 = gr.Button("vicuna-7b-1-3", elem_id="vicuna-7b-1-3", elem_classes=["square"])
                            gr.Markdown("Vicuna 1.3", elem_classes=["center"])
                        with gr.Column(min_width=20):
                            llama_deus_7b = gr.Button("llama-deus-7b", elem_id="llama-deus-7b", elem_classes=["square"])
                            gr.Markdown("LLaMA Deus", elem_classes=["center"])
                        with gr.Column(min_width=20):
                            evolinstruct_vicuna_7b = gr.Button("evolinstruct-vicuna-7b", elem_id="evolinstruct-vicuna-7b", elem_classes=["square"])
                            gr.Markdown("EvolInstruct Vicuna", elem_classes=["center"])
                        with gr.Column(min_width=20, visible=False):
                            alpacoom_7b = gr.Button("alpacoom-7b", elem_id="alpacoom-7b", elem_classes=["square"])
                            gr.Markdown("Alpacoom", elem_classes=["center"])
                        with gr.Column(min_width=20):
                            baize_7b = gr.Button("baize-7b", elem_id="baize-7b", elem_classes=["square"])
                            gr.Markdown("Baize", elem_classes=["center"])
                        with gr.Column(min_width=20):
                            guanaco_7b = gr.Button("guanaco-7b", elem_id="guanaco-7b", elem_classes=["square"])
                            gr.Markdown("Guanaco", elem_classes=["center"])
                        with gr.Column(min_width=20):
                            falcon_7b = gr.Button("falcon-7b", elem_id="falcon-7b", elem_classes=["square"])
                            gr.Markdown("Falcon", elem_classes=["center"])
                        with gr.Column(min_width=20):
                            wizard_falcon_7b = gr.Button("wizard-falcon-7b", elem_id="wizard-falcon-7b", elem_classes=["square"])
                            gr.Markdown("Wizard Falcon", elem_classes=["center"])
                        with gr.Column(min_width=20):
                            airoboros_7b = gr.Button("airoboros-7b", elem_id="airoboros-7b", elem_classes=["square"])
                            gr.Markdown("Airoboros", elem_classes=["center"])
                        with gr.Column(min_width=20):
                            samantha_7b = gr.Button("samantha-7b", elem_id="samantha-7b", elem_classes=["square"])
                            gr.Markdown("Samantha", elem_classes=["center"])
                        with gr.Column(min_width=20):
                            openllama_7b = gr.Button("openllama-7b", elem_id="openllama-7b", elem_classes=["square"])
                            gr.Markdown("OpenLLaMA", elem_classes=["center"])
                        with gr.Column(min_width=20):
                            orcamini_7b = gr.Button("orcamini-7b", elem_id="orcamini-7b", elem_classes=["square"])
                            gr.Markdown("Orca Mini", elem_classes=["center"])
                        with gr.Column(min_width=20):
                            xgen_7b = gr.Button("xgen-7b", elem_id="xgen-7b", elem_classes=["square"])
                            gr.Markdown("XGen", elem_classes=["center"])
                        with gr.Column(min_width=20):
                            llama2_7b = gr.Button("llama2-7b", elem_id="llama2-7b", elem_classes=["square"])
                            gr.Markdown("LLaMA 2", elem_classes=["center"])
gr.Markdown("## ~ 20B Parameters") | |
with gr.Row(elem_classes=["sub-container"]): | |
with gr.Column(min_width=20, visible=False): | |
flan11b = gr.Button("flan-11b", elem_id="flan-11b", elem_classes=["square"]) | |
gr.Markdown("Flan-XXL", elem_classes=["center"]) | |
with gr.Column(min_width=20): | |
koalpaca = gr.Button("koalpaca", elem_id="koalpaca", elem_classes=["square"]) | |
gr.Markdown("koalpaca", elem_classes=["center"]) | |
with gr.Column(min_width=20): | |
kullm = gr.Button("kullm", elem_id="kullm", elem_classes=["square"]) | |
gr.Markdown("KULLM", elem_classes=["center"]) | |
with gr.Column(min_width=20): | |
alpaca_lora13b = gr.Button("alpaca-lora-13b", elem_id="alpaca-lora-13b", elem_classes=["square"]) | |
gr.Markdown("Alpaca-LoRA", elem_classes=["center"]) | |
with gr.Column(min_width=20): | |
gpt4_alpaca_13b = gr.Button("gpt4-alpaca-13b", elem_id="gpt4-alpaca-13b", elem_classes=["square"]) | |
gr.Markdown("GPT4-Alpaca-LoRA", elem_classes=["center"]) | |
with gr.Column(min_width=20): | |
stable_vicuna_13b = gr.Button("stable-vicuna-13b", elem_id="stable-vicuna-13b", elem_classes=["square"]) | |
gr.Markdown("Stable-Vicuna", elem_classes=["center"]) | |
with gr.Column(min_width=20): | |
starchat_15b = gr.Button("starchat-15b", elem_id="starchat-15b", elem_classes=["square"]) | |
gr.Markdown("StarChat", elem_classes=["center"]) | |
with gr.Column(min_width=20): | |
starchat_beta_15b = gr.Button("starchat-beta-15b", elem_id="starchat-beta-15b", elem_classes=["square"]) | |
gr.Markdown("StarChat β", elem_classes=["center"]) | |
with gr.Column(min_width=20): | |
vicuna_13b = gr.Button("vicuna-13b", elem_id="vicuna-13b", elem_classes=["square"]) | |
gr.Markdown("Vicuna", elem_classes=["center"]) | |
with gr.Column(min_width=20): | |
vicuna_13b_1_3 = gr.Button("vicuna-13b-1-3", elem_id="vicuna-13b-1-3", elem_classes=["square"]) | |
gr.Markdown("Vicuna 1.3", elem_classes=["center"]) | |
with gr.Column(min_width=20): | |
evolinstruct_vicuna_13b = gr.Button("evolinstruct-vicuna-13b", elem_id="evolinstruct-vicuna-13b", elem_classes=["square"]) | |
gr.Markdown("EvolInstruct Vicuna", elem_classes=["center"]) | |
with gr.Column(min_width=20): | |
baize_13b = gr.Button("baize-13b", elem_id="baize-13b", elem_classes=["square"]) | |
gr.Markdown("Baize", elem_classes=["center"]) | |
with gr.Column(min_width=20): | |
guanaco_13b = gr.Button("guanaco-13b", elem_id="guanaco-13b", elem_classes=["square"]) | |
gr.Markdown("Guanaco", elem_classes=["center"]) | |
with gr.Column(min_width=20): | |
nous_hermes_13b = gr.Button("nous-hermes-13b", elem_id="nous-hermes-13b", elem_classes=["square"]) | |
gr.Markdown("Nous Hermes", elem_classes=["center"]) | |
with gr.Column(min_width=20): | |
airoboros_13b = gr.Button("airoboros-13b", elem_id="airoboros-13b", elem_classes=["square"]) | |
gr.Markdown("Airoboros", elem_classes=["center"]) | |
with gr.Column(min_width=20): | |
samantha_13b = gr.Button("samantha-13b", elem_id="samantha-13b", elem_classes=["square"]) | |
gr.Markdown("Samantha", elem_classes=["center"]) | |
with gr.Column(min_width=20): | |
chronos_13b = gr.Button("chronos-13b", elem_id="chronos-13b", elem_classes=["square"]) | |
gr.Markdown("Chronos", elem_classes=["center"]) | |
with gr.Column(min_width=20): | |
wizardlm_13b = gr.Button("wizardlm-13b", elem_id="wizardlm-13b", elem_classes=["square"]) | |
gr.Markdown("WizardLM", elem_classes=["center"]) | |
with gr.Column(min_width=20): | |
wizard_vicuna_13b = gr.Button("wizard-vicuna-13b", elem_id="wizard-vicuna-13b", elem_classes=["square"]) | |
gr.Markdown("Wizard Vicuna (Uncensored)", elem_classes=["center"]) | |
with gr.Column(min_width=20): | |
wizard_coder_15b = gr.Button("wizard-coder-15b", elem_id="wizard-coder-15b", elem_classes=["square"]) | |
gr.Markdown("Wizard Coder", elem_classes=["center"]) | |
with gr.Column(min_width=20): | |
openllama_13b = gr.Button("openllama-13b", elem_id="openllama-13b", elem_classes=["square"]) | |
gr.Markdown("OpenLLaMA", elem_classes=["center"]) | |
with gr.Column(min_width=20): | |
orcamini_13b = gr.Button("orcamini-13b", elem_id="orcamini-13b", elem_classes=["square"]) | |
gr.Markdown("Orca Mini", elem_classes=["center"]) | |
with gr.Column(min_width=20): | |
llama2_13b = gr.Button("llama2-13b", elem_id="llama2-13b", elem_classes=["square"]) | |
gr.Markdown("LLaMA 2", elem_classes=["center"]) | |
with gr.Column(min_width=20): | |
nous_hermes_13b_v2 = gr.Button("nous-hermes-13b-llama2", elem_id="nous-hermes-13b-llama2", elem_classes=["square"]) | |
gr.Markdown("Nous Hermes v2", elem_classes=["center"]) | |
gr.Markdown("## ~ 30B Parameters", visible=False) | |
with gr.Row(elem_classes=["sub-container"], visible=False): | |
with gr.Column(min_width=20): | |
camel20b = gr.Button("camel-20b", elem_id="camel-20b", elem_classes=["square"]) | |
gr.Markdown("Camel", elem_classes=["center"]) | |
gr.Markdown("## ~ 40B Parameters") | |
with gr.Row(elem_classes=["sub-container"]): | |
with gr.Column(min_width=20): | |
guanaco_33b = gr.Button("guanaco-33b", elem_id="guanaco-33b", elem_classes=["square"]) | |
gr.Markdown("Guanaco", elem_classes=["center"]) | |
with gr.Column(min_width=20): | |
falcon_40b = gr.Button("falcon-40b", elem_id="falcon-40b", elem_classes=["square"]) | |
gr.Markdown("Falcon", elem_classes=["center"]) | |
with gr.Column(min_width=20): | |
wizard_falcon_40b = gr.Button("wizard-falcon-40b", elem_id="wizard-falcon-40b", elem_classes=["square"]) | |
gr.Markdown("Wizard Falcon", elem_classes=["center"]) | |
with gr.Column(min_width=20): | |
samantha_33b = gr.Button("samantha-33b", elem_id="samantha-33b", elem_classes=["square"]) | |
gr.Markdown("Samantha", elem_classes=["center"]) | |
with gr.Column(min_width=20): | |
lazarus_30b = gr.Button("lazarus-30b", elem_id="lazarus-30b", elem_classes=["square"]) | |
gr.Markdown("Lazarus", elem_classes=["center"]) | |
with gr.Column(min_width=20): | |
chronos_33b = gr.Button("chronos-33b", elem_id="chronos-33b", elem_classes=["square"]) | |
gr.Markdown("Chronos", elem_classes=["center"]) | |
with gr.Column(min_width=20): | |
wizardlm_30b = gr.Button("wizardlm-30b", elem_id="wizardlm-30b", elem_classes=["square"]) | |
gr.Markdown("WizardLM", elem_classes=["center"]) | |
with gr.Column(min_width=20): | |
wizard_vicuna_30b = gr.Button("wizard-vicuna-30b", elem_id="wizard-vicuna-30b", elem_classes=["square"]) | |
gr.Markdown("Wizard Vicuna (Uncensored)", elem_classes=["center"]) | |
with gr.Column(min_width=20): | |
vicuna_33b_1_3 = gr.Button("vicuna-33b-1-3", elem_id="vicuna-33b-1-3", elem_classes=["square"]) | |
gr.Markdown("Vicuna 1.3", elem_classes=["center"]) | |
with gr.Column(min_width=20): | |
mpt_30b = gr.Button("mpt-30b", elem_id="mpt-30b", elem_classes=["square"]) | |
gr.Markdown("MPT", elem_classes=["center"]) | |
with gr.Column(min_width=20): | |
upstage_llama_30b = gr.Button("upstage-llama-30b", elem_id="upstage-llama-30b", elem_classes=["square"]) | |
gr.Markdown("Upstage LLaMA", elem_classes=["center"]) | |
gr.Markdown("## ~ 70B Parameters") | |
with gr.Row(elem_classes=["sub-container"]): | |
with gr.Column(min_width=20): | |
free_willy2_70b = gr.Button("free-willy2-70b", elem_id="free-willy2-70b", elem_classes=["square"]) | |
gr.Markdown("Free Willy 2", elem_classes=["center"]) | |
progress_view = gr.Textbox(label="Progress", elem_classes=["progress-view"]) | |
        with gr.Column(visible=False) as byom_input_view:
            with gr.Column(elem_id="container3"):
                gr.Markdown("# Bring Your Own Model", elem_classes=["center"])

                gr.Markdown("### Model configuration")
                byom_base = gr.Textbox(label="Base", placeholder="Enter path or 🤗 hub ID of the base model", interactive=True)
                byom_ckpt = gr.Textbox(label="LoRA ckpt", placeholder="Enter path or 🤗 hub ID of the LoRA checkpoint", interactive=True)

                with gr.Accordion("Advanced options", open=False):
gr.Markdown("If you leave the below textboxes empty, `transformers.AutoModelForCausalLM` and `transformers.AutoTokenizer` classes will be used by default. If you need any specific class, please type them below.") | |
                    byom_model_cls = gr.Textbox(label="Base model class", placeholder="Enter base model class", interactive=True)
                    byom_tokenizer_cls = gr.Textbox(label="Base tokenizer class", placeholder="Enter base tokenizer class", interactive=True)

                    with gr.Column():
gr.Markdown("If you leave the below textboxes empty, any token ids for bos, eos, and pad will not be specified in `GenerationConfig`. If you think that you need to specify them. please type them below in decimal format.") | |
                        with gr.Row():
                            byom_bos_token_id = gr.Textbox(label="bos_token_id", placeholder="for GenConfig")
                            byom_eos_token_id = gr.Textbox(label="eos_token_id", placeholder="for GenConfig")
                            byom_pad_token_id = gr.Textbox(label="pad_token_id", placeholder="for GenConfig")

                with gr.Row():
                    byom_load_mode = gr.Radio(
                        load_mode_list,
                        value=load_mode_list[0],
                        label="load mode",
                        elem_classes=["load-mode-selector"]
                    )

                gr.Markdown("### Prompt configuration")
                prompt_style_selector = gr.Dropdown(
                    label="Prompt style",
                    interactive=True,
                    choices=list(prompt_styles.keys()),
                    value="Alpaca"
                )
                with gr.Accordion("Prompt style preview", open=False):
                    prompt_style_previewer = gr.Textbox(
                        label="How prompt is actually structured",
                        lines=16,
                        value=default_ppm.build_prompts())

                with gr.Row():
                    byom_back_btn = gr.Button("Back")
                    byom_confirm_btn = gr.Button("Confirm")

                with gr.Column(elem_classes=["progress-view"]):
                    txt_view3 = gr.Textbox(label="Status")
                    progress_view3 = gr.Textbox(label="Progress")

        with gr.Column(visible=False) as model_review_view:
            gr.Markdown("# Confirm the chosen model", elem_classes=["center"])

            with gr.Column(elem_id="container2"):
gr.Markdown("Please expect loading time to be longer than expected. Depending on the size of models, it will probably take from 100 to 1000 seconds or so. Please be patient.") | |
                with gr.Row():
                    model_image = gr.Image(None, interactive=False, show_label=False)
                    with gr.Column():
                        model_name = gr.Markdown("**Model name**")
                        model_desc = gr.Markdown("...")
                        model_params = gr.Markdown("Parameters\n: ...")
                        model_base = gr.Markdown("🤗 Hub(base)\n: ...")
                        model_ckpt = gr.Markdown("🤗 Hub(LoRA)\n: ...")
                        model_vram = gr.Markdown(f"""**Minimal VRAM requirement** :
| half precision | load_in_8bit | load_in_4bit |
| ------------------------------ | ------------------------- | ------------------------- |
| {round(7830/1024., 1)}GiB | {round(5224/1024., 1)}GiB | {round(4324/1024., 1)}GiB |
""")

                model_thumbnail_tiny = gr.Textbox("", visible=False)

                with gr.Column():
                    gen_config_path = gr.Dropdown(
                        response_configs,
                        value=response_configs[0],
                        interactive=True,
                        label="Gen Config(response)",
                    )
                    gen_config_sum_path = gr.Dropdown(
                        summarization_configs,
                        value=summarization_configs[0],
                        interactive=True,
                        label="Gen Config(summarization)",
                        visible=False,
                    )
                    with gr.Row():
                        load_mode = gr.Radio(
                            load_mode_list,
                            value=load_mode_list[0],
                            label="load mode",
                            elem_classes=["load-mode-selector"]
                        )
                        force_redownload = gr.Checkbox(label="Force Re-download", interactive=False, visible=False)

                with gr.Accordion("Example showcases", open=False):
                    with gr.Tab("Ex1"):
                        example_showcase1 = gr.Chatbot(
                            [("hello", "world"), ("damn", "good")]
                        )
                    with gr.Tab("Ex2"):
                        example_showcase2 = gr.Chatbot(
                            [("hello", "world"), ("damn", "good")]
                        )
                    with gr.Tab("Ex3"):
                        example_showcase3 = gr.Chatbot(
                            [("hello", "world"), ("damn", "good")]
                        )
                    with gr.Tab("Ex4"):
                        example_showcase4 = gr.Chatbot(
                            [("hello", "world"), ("damn", "good")]
                        )

                with gr.Row():
                    back_to_model_choose_btn = gr.Button("Back")
                    confirm_btn = gr.Button("Confirm")

                with gr.Column(elem_classes=["progress-view"]):
                    txt_view = gr.Textbox(label="Status")
                    progress_view2 = gr.Textbox(label="Progress")

        with gr.Column(visible=False) as chat_view:
            idx = gr.State(0)
            chat_state = gr.State()
            local_data = gr.JSON({}, visible=False)

            with gr.Row():
                with gr.Column(scale=1, min_width=180):
                    gr.Markdown("GradioChat", elem_id="left-top")

                    with gr.Column(elem_id="left-pane"):
                        chat_back_btn = gr.Button("Back", elem_id="chat-back-btn")

                        with gr.Accordion("Histories", elem_id="chat-history-accordion", open=False):
                            channel_btns.append(gr.Button(channels[0], elem_classes=["custom-btn-highlight"]))

                            for channel in channels[1:]:
                                channel_btns.append(gr.Button(channel, elem_classes=["custom-btn"]))

                with gr.Column(scale=8, elem_id="right-pane"):
                    with gr.Column(
                        elem_id="initial-popup", visible=False
                    ) as example_block:
                        with gr.Row(scale=1):
                            with gr.Column(elem_id="initial-popup-left-pane"):
                                gr.Markdown("GradioChat", elem_id="initial-popup-title")
                                gr.Markdown("Making the community's best AI chat models available to everyone.")
                            with gr.Column(elem_id="initial-popup-right-pane"):
                                gr.Markdown("Chat UI is now open sourced on Hugging Face Hub")
                                gr.Markdown("check out the [↗ repository](https://huggingface.co/spaces/chansung/test-multi-conv)")

                        with gr.Column(scale=1):
                            gr.Markdown("Examples")

                            with gr.Row():
                                for example in examples:
                                    ex_btns.append(gr.Button(example, elem_classes=["example-btn"]))

                    with gr.Column(elem_id="aux-btns-popup", visible=True):
                        with gr.Row():
                            stop = gr.Button("Stop", elem_classes=["aux-btn"])
                            regenerate = gr.Button("Regen", interactive=False, elem_classes=["aux-btn"])
                            clean = gr.Button("Clean", elem_classes=["aux-btn"])

                    with gr.Accordion("Context Inspector", elem_id="aux-viewer", open=False):
                        context_inspector = gr.Textbox(
                            "",
                            elem_id="aux-viewer-inspector",
                            label="",
                            lines=30,
                            max_lines=50,
                        )

                    chatbot = gr.Chatbot(elem_id='chatbot')
                    instruction_txtbox = gr.Textbox(placeholder="Ask anything", label="", elem_id="prompt-txt")

                    with gr.Accordion("Control Panel", open=False) as control_panel:
                        with gr.Column():
                            with gr.Column():
                                gr.Markdown("#### Global context")
with gr.Accordion("global context will persist during conversation, and it is placed at the top of the prompt", open=False): | |
                                    global_context = gr.Textbox(
                                        "global context",
                                        lines=5,
                                        max_lines=10,
                                        interactive=True,
                                        elem_id="global-context"
                                    )

                                gr.Markdown("#### Internet search")
                                with gr.Row():
                                    internet_option = gr.Radio(choices=["on", "off"], value="off", label="mode")
                                    serper_api_key = gr.Textbox(
                                        value="" if args.serper_api_key is None else args.serper_api_key,
                                        placeholder="Get one by visiting serper.dev",
                                        label="Serper api key"
                                    )

                                gr.Markdown("#### GenConfig for **response** text generation")
                                with gr.Row():
                                    res_temp = gr.Slider(0.0, 2.0, 0, step=0.1, label="temp", interactive=True)
                                    res_topp = gr.Slider(0.0, 2.0, 0, step=0.1, label="top_p", interactive=True)
                                    res_topk = gr.Slider(20, 1000, 0, step=1, label="top_k", interactive=True)
                                    res_rpen = gr.Slider(0.0, 2.0, 0, step=0.1, label="rep_penalty", interactive=True)
                                    res_mnts = gr.Slider(64, 2048, 0, step=1, label="new_tokens", interactive=True)
                                    res_beams = gr.Slider(1, 4, 0, step=1, label="beams")
                                    res_cache = gr.Radio([True, False], value=0, label="cache", interactive=True)
                                    res_sample = gr.Radio([True, False], value=0, label="sample", interactive=True)
                                    res_eosid = gr.Number(value=0, visible=False, precision=0)
                                    res_padid = gr.Number(value=0, visible=False, precision=0)

                            with gr.Column(visible=False):
                                gr.Markdown("#### GenConfig for **summary** text generation")
                                with gr.Row():
                                    sum_temp = gr.Slider(0.0, 2.0, 0, step=0.1, label="temp", interactive=True)
                                    sum_topp = gr.Slider(0.0, 2.0, 0, step=0.1, label="top_p", interactive=True)
                                    sum_topk = gr.Slider(20, 1000, 0, step=1, label="top_k", interactive=True)
                                    sum_rpen = gr.Slider(0.0, 2.0, 0, step=0.1, label="rep_penalty", interactive=True)
                                    sum_mnts = gr.Slider(64, 2048, 0, step=1, label="new_tokens", interactive=True)
                                    sum_beams = gr.Slider(1, 8, 0, step=1, label="beams", interactive=True)
                                    sum_cache = gr.Radio([True, False], value=0, label="cache", interactive=True)
                                    sum_sample = gr.Radio([True, False], value=0, label="sample", interactive=True)
                                    sum_eosid = gr.Number(value=0, visible=False, precision=0)
                                    sum_padid = gr.Number(value=0, visible=False, precision=0)

                            with gr.Column():
                                gr.Markdown("#### Context managements")
                                with gr.Row():
                                    ctx_num_lconv = gr.Slider(2, 10, 3, step=1, label="number of recent talks to keep", interactive=True)
                                    ctx_sum_prompt = gr.Textbox(
                                        "summarize our conversations. what have we discussed about so far?",
                                        label="design a prompt to summarize the conversations",
                                        visible=False
                                    )
        btns = [
            t5_vicuna_3b, flan3b, camel5b, alpaca_lora7b, stablelm7b,
            gpt4_alpaca_7b, os_stablelm7b, mpt_7b, redpajama_7b, redpajama_instruct_7b, llama_deus_7b,
            evolinstruct_vicuna_7b, alpacoom_7b, baize_7b, guanaco_7b, vicuna_7b_1_3,
            falcon_7b, wizard_falcon_7b, airoboros_7b, samantha_7b, openllama_7b, orcamini_7b,
            xgen_7b, llama2_7b,
            flan11b, koalpaca, kullm, alpaca_lora13b, gpt4_alpaca_13b, stable_vicuna_13b,
            starchat_15b, starchat_beta_15b, vicuna_7b, vicuna_13b, evolinstruct_vicuna_13b,
            baize_13b, guanaco_13b, nous_hermes_13b, airoboros_13b, samantha_13b, chronos_13b,
            wizardlm_13b, wizard_vicuna_13b, wizard_coder_15b, vicuna_13b_1_3, openllama_13b, orcamini_13b,
            llama2_13b, nous_hermes_13b_v2, camel20b,
            guanaco_33b, falcon_40b, wizard_falcon_40b, samantha_33b, lazarus_30b, chronos_33b,
            wizardlm_30b, wizard_vicuna_30b, vicuna_33b_1_3, mpt_30b, upstage_llama_30b,
            free_willy2_70b
        ]
        for btn in btns:
            btn.click(
                move_to_second_view,
                btn,
                [
                    model_choice_view, model_review_view,
                    model_image, model_name, model_params, model_base, model_ckpt,
                    model_desc, model_vram, gen_config_path,
                    example_showcase1, example_showcase2, example_showcase3, example_showcase4,
                    model_thumbnail_tiny, load_mode,
                    progress_view
                ]
            )

        select_model.click(
            move_to_model_select_view,
            None,
            [progress_view0, landing_view, model_choice_view]
        )

        chosen_model.click(
            use_chosen_model,
            None,
            [progress_view0, landing_view, chat_view, chatbot, chat_state, global_context,
             res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
             sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid]
        )

        byom.click(
            move_to_byom_view,
            None,
            [progress_view0, landing_view, byom_input_view, byom_load_mode]
        )

        byom_back_btn.click(
            move_to_first_view,
            None,
            [landing_view, byom_input_view]
        )

        byom_confirm_btn.click(
            lambda: "Start downloading/loading the model...", None, txt_view3
        ).then(
            byom_load,
            [byom_base, byom_ckpt, byom_model_cls, byom_tokenizer_cls,
             byom_bos_token_id, byom_eos_token_id, byom_pad_token_id,
             byom_load_mode],
            [progress_view3]
        ).then(
            lambda: "Model is fully loaded...", None, txt_view3
        ).then(
            move_to_third_view,
            None,
            [progress_view3, byom_input_view, chat_view, chatbot, chat_state, global_context,
             res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
             sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid]
        )

        prompt_style_selector.change(
            prompt_style_change,
            prompt_style_selector,
            prompt_style_previewer
        )

        back_to_model_choose_btn.click(
            move_to_first_view,
            None,
            [model_choice_view, model_review_view]
        )

        confirm_btn.click(
            lambda: "Start downloading/loading the model...", None, txt_view
        ).then(
            download_completed,
            [model_name, model_base, model_ckpt, gen_config_path, gen_config_sum_path, load_mode, model_thumbnail_tiny, force_redownload],
            [progress_view2]
        ).then(
            lambda: "Model is fully loaded...", None, txt_view
        ).then(
            lambda: time.sleep(2), None, None
        ).then(
            move_to_third_view,
            None,
            [progress_view2, model_review_view, chat_view, chatbot, chat_state, global_context,
             res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
             sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid]
        )
        for btn in channel_btns:
            btn.click(
                set_chatbot,
                [btn, local_data, chat_state],
                [chatbot, idx, example_block, regenerate]
            ).then(
                None, btn, None,
                _js=UPDATE_LEFT_BTNS_STATE
            )

        for btn in ex_btns:
            btn.click(
                set_example,
                [btn],
                [instruction_txtbox, example_block]
            )

        instruction_txtbox.submit(
            lambda: [
                gr.update(visible=False),
                gr.update(interactive=True)
            ],
            None,
            [example_block, regenerate]
        )

        send_event = instruction_txtbox.submit(
            central.chat_stream,
            [idx, local_data, instruction_txtbox, chat_state,
             global_context, ctx_num_lconv, ctx_sum_prompt,
             res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
             sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
             internet_option, serper_api_key],
            [instruction_txtbox, chatbot, context_inspector, local_data],
        )

        instruction_txtbox.submit(
            None, local_data, None,
            _js="(v)=>{ setStorage('local_data',v) }"
        )

        regenerate.click(
            rollback_last,
            [idx, local_data, chat_state],
            [instruction_txtbox, chatbot, local_data, regenerate]
        ).then(
            central.chat_stream,
            [idx, local_data, instruction_txtbox, chat_state,
             global_context, ctx_num_lconv, ctx_sum_prompt,
             res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
             sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
             internet_option, serper_api_key],
            [instruction_txtbox, chatbot, context_inspector, local_data],
        ).then(
            lambda: gr.update(interactive=True),
            None,
            regenerate
        ).then(
            None, local_data, None,
            _js="(v)=>{ setStorage('local_data',v) }"
        )

        stop.click(
            None, None, None,
            cancels=[send_event]
        )

        clean.click(
            reset_chat,
            [idx, local_data, chat_state],
            [instruction_txtbox, chatbot, local_data, example_block, regenerate]
        ).then(
            None, local_data, None,
            _js="(v)=>{ setStorage('local_data',v) }"
        )

        chat_back_btn.click(
            lambda: [gr.update(visible=False), gr.update(visible=True)],
            None,
            [chat_view, landing_view]
        )

        demo.load(
            None,
            inputs=None,
            outputs=[chatbot, local_data],
            _js=GET_LOCAL_STORAGE,
        )

    demo.queue().launch(
        server_port=6006,
        server_name="0.0.0.0",
        debug=args.debug,
        share=args.share,
        root_path=args.root_path
    )
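
# Example launch (the filename "app.py" below is assumed for illustration):
#   python app.py --share --serper-api-key <SERPER_API_KEY>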
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--root-path', default="")
    parser.add_argument('--local-files-only', default=False, action=argparse.BooleanOptionalAction)
    parser.add_argument('--share', default=False, action=argparse.BooleanOptionalAction)
    parser.add_argument('--debug', default=False, action=argparse.BooleanOptionalAction)
    parser.add_argument('--serper-api-key', default=None, type=str)

    args = parser.parse_args()
    gradio_main(args)