# LLM-As-Chatbot / app.py
# koonmania's picture
# Upload folder using huggingface_hub
# 4df8249
import os
import time
import json
import copy
import types
from os import listdir
from os.path import isfile, join
import argparse
import gradio as gr
import global_vars
from chats import central
from transformers import AutoModelForCausalLM
from miscs.styles import MODEL_SELECTION_CSS
from miscs.js import GET_LOCAL_STORAGE, UPDATE_LEFT_BTNS_STATE
from utils import get_chat_manager, get_global_context
from pingpong.pingpong import PingPong
from pingpong.gradio import GradioAlpacaChatPPManager
from pingpong.gradio import GradioKoAlpacaChatPPManager
from pingpong.gradio import GradioStableLMChatPPManager
from pingpong.gradio import GradioFlanAlpacaChatPPManager
from pingpong.gradio import GradioOSStableLMChatPPManager
from pingpong.gradio import GradioVicunaChatPPManager
from pingpong.gradio import GradioStableVicunaChatPPManager
from pingpong.gradio import GradioStarChatPPManager
from pingpong.gradio import GradioMPTChatPPManager
from pingpong.gradio import GradioRedPajamaChatPPManager
from pingpong.gradio import GradioBaizeChatPPManager
# Default load mode(s) used to seed the "load mode" radios built in
# gradio_main() before the hardware-aware views repopulate them.
# Note: no cpu for
# - falcon families (too slow)
load_mode_list = ["cpu"]

# One example prompt per line; each becomes an example button in the chat view.
# Files are read inside `with` so the handles are closed immediately.
with open("examples.txt", "r", encoding="utf-8") as ex_file:
    examples = ex_file.read().split("\n")
ex_btns = []

# One chat channel title per line; each becomes a history button.
with open("channels.txt", "r", encoding="utf-8") as chl_file:
    channels = chl_file.read().split("\n")
channel_btns = []

# Alpaca-style manager filled with placeholder turns; used as the initial
# content of the prompt-style preview.
default_ppm = GradioAlpacaChatPPManager()
default_ppm.ctx = "Context at top"
default_ppm.pingpongs = [
    PingPong("user input #1...", "bot response #1..."),
    PingPong("user input #2...", "bot response #2..."),
]

# Prompt manager used for "custom" (bring-your-own) models. It starts as an
# empty copy of the Alpaca manager, mirroring the reset that
# prompt_style_change() applies to the style the user selects.
chosen_ppm = copy.deepcopy(default_ppm)
chosen_ppm.ctx = ""
chosen_ppm.pingpongs = []

# Prompt styles selectable in the bring-your-own-model view.
prompt_styles = {
    "Alpaca": default_ppm,
    "Baize": GradioBaizeChatPPManager(),
    "Koalpaca": GradioKoAlpacaChatPPManager(),
    "MPT": GradioMPTChatPPManager(),
    "OpenAssistant StableLM": GradioOSStableLMChatPPManager(),
    "RedPajama": GradioRedPajamaChatPPManager(),
    # StableVicuna has its own manager (imported above); it was previously
    # mapped to the plain Vicuna manager by mistake.
    "StableVicuna": GradioStableVicunaChatPPManager(),
    "StableLM": GradioStableLMChatPPManager(),
    "StarChat": GradioStarChatPPManager(),
    "Vicuna": GradioVicunaChatPPManager(),
}

# Generation-config files offered in the "Gen Config" dropdowns
# (order follows os.listdir, which is not guaranteed to be sorted).
response_configs = [
    f"configs/response_configs/{f}"
    for f in listdir("configs/response_configs")
    if isfile(join("configs/response_configs", f))
]
summarization_configs = [
    f"configs/summarization_configs/{f}"
    for f in listdir("configs/summarization_configs")
    if isfile(join("configs/summarization_configs", f))
]

# Per-model metadata (VRAM figures, hub ids, thumbnails, examples, ...)
# keyed by the model-pool button title.
with open("model_cards.json", encoding="utf-8") as model_cards_file:
    model_info = json.load(model_cards_file)
###
def move_to_model_select_view():
    """Switch from the landing page to the model-pool selection view.

    Returns:
        A 3-tuple for (progress textbox, landing view, model-choice view):
        a status message, an update hiding the landing view, and an update
        showing the model-choice view.
    """
    hide_landing = gr.update(visible=False)
    show_model_choice = gr.update(visible=True)
    status = "move to model select view"
    return status, hide_landing, show_model_choice
def use_chosen_model():
    """Jump back to the chat view using the model that is already loaded.

    Raises:
        gr.Error: if ``global_vars`` has no ``model`` attribute, i.e. no
            model has been chosen/loaded yet.

    Returns:
        Exactly the output tuple produced by ``move_to_third_view``: status
        message, view visibility updates, chatbot label update, chat state,
        global context, and the response / summarization generation-config
        values.
    """
    # `global_vars.model` is treated as the marker that a model was loaded.
    if not hasattr(global_vars, "model"):
        raise gr.Error("There is no previously chosen model")
    # The remaining work is identical to finishing a fresh model load, so
    # reuse it instead of duplicating the 30-item output tuple.
    return move_to_third_view()
def move_to_byom_view():
    """Switch from the landing page to the bring-your-own-model view.

    The load-mode radio is repopulated from the detected hardware: the three
    GPU modes when CUDA is available, "apple silicon" when MPS is available,
    and always "cpu" last. The first available mode is pre-selected.

    Returns:
        (status message, landing-view hide update, BYOM-view show update,
        load-mode radio update).
    """
    gpu_modes = (
        ["gpu(half)", "gpu(load_in_8bit)", "gpu(load_in_4bit)"]
        if global_vars.cuda_availability
        else []
    )
    mps_modes = ["apple silicon"] if global_vars.mps_availability else []
    modes = gpu_modes + mps_modes + ["cpu"]

    return (
        "move to the byom view",
        gr.update(visible=False),
        gr.update(visible=True),
        gr.update(choices=modes, value=modes[0]),
    )
def prompt_style_change(key):
    """Preview the selected prompt style and remember it for custom models.

    Args:
        key: a key of ``prompt_styles`` (the prompt-style dropdown value).

    Returns:
        The prompt text built from placeholder context and turns
        (presumably shown in the prompt-style preview textbox).

    Side effects:
        Fills ``prompt_styles[key]`` with placeholder ctx/turns, and rebinds
        the module-level ``chosen_ppm`` to an empty deep copy of it. That
        copy is what use_chosen_model / move_to_third_view hand out for
        "custom" models.
    """
    # Required so the assignment below rebinds the module-level manager
    # instead of creating a function-local variable that is thrown away.
    global chosen_ppm

    ppm = prompt_styles[key]
    # Placeholder content used only to render the preview.
    ppm.ctx = "Context at top"
    ppm.pingpongs = [
        PingPong("user input #1...", "bot response #1..."),
        PingPong("user input #2...", "bot response #2..."),
    ]

    # Keep an independent, empty manager of the same style for chatting.
    chosen_ppm = copy.deepcopy(ppm)
    chosen_ppm.ctx = ""
    chosen_ppm.pingpongs = []
    return ppm.build_prompts()
def byom_load(
    base, ckpt, model_cls, tokenizer_cls,
    bos_token_id, eos_token_id, pad_token_id,
    load_mode,
):
    """Load a user-supplied ("bring your own") model into ``global_vars``.

    Args:
        base: path or hub id of the base model.
        ckpt: path or hub id of the LoRA checkpoint.
        model_cls, tokenizer_cls: optional class names (strings from the UI).
        bos_token_id, eos_token_id, pad_token_id: optional token ids
            (strings from the UI).
        load_mode: the load-mode radio label, one of "cpu", "apple silicon",
            "gpu(load_in_8bit)", "gpu(load_in_4bit)" or "gpu(half)".

    Returns:
        An empty string.
    """
    # The radio labels are the "gpu(load_in_*)" strings produced by
    # move_to_byom_view(); bare "8bit"/"4bit" are still accepted for
    # backward compatibility with any caller passing the old values.
    mode_cpu = load_mode == "cpu"
    mode_mps = load_mode == "apple silicon"
    mode_8bit = load_mode in ("gpu(load_in_8bit)", "8bit")
    mode_4bit = load_mode in ("gpu(load_in_4bit)", "4bit")
    mode_full_gpu = load_mode == "gpu(half)"

    # Positional flag order: mode_cpu, mode_mps, mode_8bit, mode_4bit, mode_full_gpu
    global_vars.initialize_globals_byom(
        base, ckpt, model_cls, tokenizer_cls,
        bos_token_id, eos_token_id, pad_token_id,
        mode_cpu,
        mode_mps,
        mode_8bit,
        mode_4bit,
        mode_full_gpu,
    )
    return ""
def channel_num(btn_title):
    """Return the index in ``channels`` of the channel titled ``btn_title``.

    If several channels share the title, the last one wins. If none
    matches, 0 (the first channel) is returned.
    """
    matching = [
        position
        for position, title in enumerate(channels)
        if title == btn_title
    ]
    return matching[-1] if matching else 0
def set_chatbot(btn, ld, state):
    """Switch the chat pane to the channel named by ``btn``.

    Args:
        btn: title of the clicked channel button.
        ld: list of JSON-serializable prompt managers, one per channel
            (presumably read from browser storage -- confirm at the caller).
        state: chat-state dict whose ``"ppmanager_type"`` deserializes ``ld``.

    Returns:
        (chat history for the UI, selected channel index, an update that is
        visible only when the channel has no turns, an update that is
        interactive only when the channel has turns).
    """
    selected = channel_num(btn)
    managers = [
        state["ppmanager_type"].from_json(json.dumps(entry))
        for entry in ld
    ]
    current = managers[selected]
    has_history = len(current.pingpongs) > 0
    return (
        current.build_uis(),
        selected,
        gr.update(visible=not has_history),
        gr.update(interactive=has_history),
    )
def set_example(btn):
    """Return the clicked example's text plus an update that hides a component.

    The hidden component is presumably the example popup -- confirm at the
    call site.
    """
    hide = gr.update(visible=False)
    return (btn, hide)
def set_popup_visibility(ld, example_block):
    """Pass ``example_block`` through unchanged.

    ``ld`` is accepted only to match the event's input signature and is
    ignored.
    """
    del ld  # intentionally unused
    return example_block
def move_to_second_view(btn):
    """Show the review page for the model-pool entry named ``btn``.

    Looks ``btn`` up in ``model_info``. For each precision, the minimum VRAM
    (MiB) is the card's figure plus a 5 * 1024 MiB safety margin. GPU load
    modes are offered only when CUDA is available and
    ``global_vars.available_vrams_mb`` covers that requirement.
    "apple silicon" is offered when MPS is available, and "cpu" is always
    offered last.

    Returns:
        17 values for, in order: model-choice view, model-review view,
        image, name, parameters, base hub, LoRA hub, description, VRAM
        table, gen-config dropdown, four example chatbots, tiny thumbnail,
        load-mode radio, and progress textbox.
    """
    card = model_info[btn]
    safety_margin_mib = 5 * 1024.
    vram_full, vram_8bit, vram_4bit = (
        int(card[key]) + safety_margin_mib
        for key in ("vram(full)", "vram(8bit)", "vram(4bit)")
    )

    modes = []
    if global_vars.cuda_availability:
        available = global_vars.available_vrams_mb
        print(f"total vram = {available}")
        print(f"required vram(full={card['vram(full)']}, 8bit={card['vram(8bit)']}, 4bit={card['vram(4bit)']})")
        candidates = (
            ("gpu(half)", vram_full),
            ("gpu(load_in_8bit)", vram_8bit),
            ("gpu(load_in_4bit)", vram_4bit),
        )
        for mode, needed in candidates:
            if available >= needed:
                modes.append(mode)
    if global_vars.mps_availability:
        modes.append("apple silicon")
    modes.append("cpu")

    def to_gib(mib):
        return round(mib / 1024., 1)

    vram_table = f"""**Min VRAM requirements** :
| half precision | load_in_8bit | load_in_4bit |
| ------------------------------------- | ---------------------------------- | ---------------------------------- |
| {to_gib(vram_full)}GiB | {to_gib(vram_8bit)}GiB | {to_gib(vram_4bit)}GiB |
"""

    return (
        gr.update(visible=False),
        gr.update(visible=True),
        card["thumb"],
        f"## {btn}",
        f"**Parameters**\n: Approx. {card['parameters']}",
        f"**🤗 Hub(base)**\n: {card['hub(base)']}",
        f"**🤗 Hub(LoRA)**\n: {card['hub(ckpt)']}",
        card['desc'],
        vram_table,
        card['default_gen_config'],
        card['example1'],
        card['example2'],
        card['example3'],
        card['example4'],
        card['thumb-tiny'],
        gr.update(choices=modes, value=modes[0]),
        "",
    )
def move_to_first_view():
    """Go back to the landing page.

    Returns:
        (update showing the first output component, update hiding the second).
    """
    show = gr.update(visible=True)
    hide = gr.update(visible=False)
    return show, hide
def download_completed(
    model_name,
    model_base,
    model_ckpt,
    gen_config_path,
    gen_config_sum_path,
    load_mode,
    thumbnail_tiny,
    force_download,
):
    """Download (if needed) and load the selected model-pool model.

    Args:
        model_name: the rendered model-name value (not used here).
        model_base, model_ckpt: the rendered "🤗 Hub" values. The hub id is
            taken as the text after the last ':' with anything from "</p"
            onwards stripped.
        gen_config_path, gen_config_sum_path: generation-config file paths.
        load_mode: load-mode radio label ("cpu", "apple silicon",
            "gpu(load_in_8bit)", "gpu(load_in_4bit)" or "gpu(half)").
        thumbnail_tiny: small thumbnail value forwarded to global_vars.
        force_download: whether to force re-downloading the checkpoint.

    Returns:
        "Download completed!" on success.

    Raises:
        gr.Error: when ``global_vars.initialize_globals`` raises RuntimeError
            (reported to the user as insufficient GPU memory). The original
            error is chained as the cause.
    """
    def hub_id(rendered):
        # Drop the label before the last ':' and any trailing "</p..." markup.
        return rendered.split(":")[-1].split("</p")[0].strip()

    # `local_files_only` is the module-level flag set by gradio_main(); it is
    # only read here, so no `global` declaration is needed.
    tmp_args = types.SimpleNamespace(
        base_url=hub_id(model_base),
        ft_ckpt_url=hub_id(model_ckpt),
        gen_config_path=gen_config_path,
        gen_config_summarization_path=gen_config_sum_path,
        force_download_ckpt=force_download,
        thumbnail_tiny=thumbnail_tiny,
        mode_cpu=load_mode == "cpu",
        mode_mps=load_mode == "apple silicon",
        mode_8bit=load_mode == "gpu(load_in_8bit)",
        mode_4bit=load_mode == "gpu(load_in_4bit)",
        mode_full_gpu=load_mode == "gpu(half)",
        local_files_only=local_files_only,
    )
    try:
        global_vars.initialize_globals(tmp_args)
    except RuntimeError as e:
        # Keep the underlying error as __cause__ for debugging.
        raise gr.Error("GPU memory is not enough to load this model.") from e
    return "Download completed!"
def move_to_third_view():
    """Finish model preparation and switch to the chat view.

    Returns:
        A tuple of:
        - a status message;
        - hide/show updates for the previous and chat views;
        - a chatbot update labelled with the model type;
        - the chat-state dict (prompt manager + model type);
        - the model's global context;
        - ten response generation-config values;
        - the same ten values for the summarization config.
    """
    def config_fields(cfg):
        # Order must match the res_* / sum_* output components.
        return (
            cfg.temperature,
            cfg.top_p,
            cfg.top_k,
            cfg.repetition_penalty,
            cfg.max_new_tokens,
            cfg.num_beams,
            cfg.use_cache,
            cfg.do_sample,
            cfg.eos_token_id,
            cfg.pad_token_id,
        )

    response_cfg = global_vars.gen_config
    summary_cfg = global_vars.gen_config_summarization
    model_type = global_vars.model_type

    # Custom (BYOM) models use the prompt manager picked in the BYOM view;
    # every other model type gets its registered manager.
    if model_type == "custom":
        manager = chosen_ppm
    else:
        manager = get_chat_manager(model_type)

    return (
        "Preparation done!",
        gr.update(visible=False),
        gr.update(visible=True),
        gr.update(label=model_type),
        {"ppmanager_type": manager, "model_type": model_type},
        get_global_context(model_type),
        *config_fields(response_cfg),
        *config_fields(summary_cfg),
    )
def toggle_inspector(view_selector):
    """Show the target component only for the "with context inspector" option."""
    show_inspector = view_selector == "with context inspector"
    return gr.update(visible=show_inspector)
def reset_chat(idx, ld, state):
    """Clear every turn of channel ``idx``.

    Returns:
        (empty string, empty chat history, ``str()`` of the deserialized
        managers with channel ``idx`` cleared, an update setting
        ``visible=True``, an update setting ``interactive=False``).
    """
    managers = [
        state["ppmanager_type"].from_json(json.dumps(entry))
        for entry in ld
    ]
    managers[idx].pingpongs = []
    return "", [], str(managers), gr.update(visible=True), gr.update(interactive=False)
def rollback_last(idx, ld, state):
    """Remove the last turn of channel ``idx`` and return its user message.

    Args:
        idx: index of the active channel.
        ld: list of JSON-serializable prompt managers, one per channel.
        state: chat-state dict whose ``"ppmanager_type"`` deserializes ``ld``.

    Returns:
        (the removed user message -- or a no-op update when the channel has
        no turns, the updated chat history, ``str()`` of the managers, an
        update setting ``interactive=False``).
    """
    managers = [
        state["ppmanager_type"].from_json(json.dumps(entry))
        for entry in ld
    ]
    current = managers[idx]

    # Guard: indexing pingpongs[-1] on an empty history would raise IndexError.
    # Leave the first output untouched (gr.update() with no arguments).
    if not current.pingpongs:
        return (
            gr.update(),
            current.build_uis(),
            str(managers),
            gr.update(interactive=False),
        )

    last_user_message = current.pingpongs[-1].ping
    current.pingpongs = current.pingpongs[:-1]
    return (
        last_user_message,
        current.build_uis(),
        str(managers),
        gr.update(interactive=False),
    )
def gradio_main(args):
global local_files_only
local_files_only = args.local_files_only
with gr.Blocks(css=MODEL_SELECTION_CSS, theme='gradio/soft') as demo:
with gr.Column(visible=True, elem_id="landing-container") as landing_view:
gr.Markdown("# Chat with LLM", elem_classes=["center"])
with gr.Row(elem_id="landing-container-selection"):
with gr.Column():
gr.Markdown("""This is the landing page of the project, [LLM As Chatbot](https://github.com/deep-diver/LLM-As-Chatbot). This appliction is designed for personal use only. A single model will be selected at a time even if you open up a new browser or a tab. As an initial choice, please select one of the following menu""")
gr.Markdown("""
**Bring your own model**: You can chat with arbitrary models. If your own custom model is based on 🤗 Hugging Face's [transformers](https://huggingface.co/docs/transformers/index) library, you will propbably be able to bring it into this application with this menu
**Select a model from model pool**: You can chat with one of the popular open source Large Language Model
**Use currently selected model**: If you have already selected, but if you came back to this landing page accidently, you can directly go back to the chatting mode with this menu
""")
byom = gr.Button("🫵🏼 Bring your own model", elem_id="go-byom-select", elem_classes=["square", "landing-btn"])
select_model = gr.Button("🦙 Select a model from model pool", elem_id="go-model-select", elem_classes=["square", "landing-btn"])
chosen_model = gr.Button("↪️ Use currently selected model", elem_id="go-use-selected-model", elem_classes=["square", "landing-btn"])
with gr.Column(elem_id="landing-bottom"):
progress_view0 = gr.Textbox(label="Progress", elem_classes=["progress-view"])
gr.Markdown("""[project](https://github.com/deep-diver/LLM-As-Chatbot)
[developer](https://github.com/deep-diver)
""", elem_classes=["center"])
with gr.Column(visible=False) as model_choice_view:
gr.Markdown("# Choose a Model", elem_classes=["center"])
with gr.Row(elem_id="container"):
with gr.Column():
gr.Markdown("## ~ 10B Parameters")
with gr.Row(elem_classes=["sub-container"]):
with gr.Column(min_width=20):
t5_vicuna_3b = gr.Button("t5-vicuna-3b", elem_id="t5-vicuna-3b", elem_classes=["square"])
gr.Markdown("T5 Vicuna", elem_classes=["center"])
with gr.Column(min_width=20, visible=False):
flan3b = gr.Button("flan-3b", elem_id="flan-3b", elem_classes=["square"])
gr.Markdown("Flan-XL", elem_classes=["center"])
# with gr.Column(min_width=20):
# replit_3b = gr.Button("replit-3b", elem_id="replit-3b", elem_classes=["square"])
# gr.Markdown("Replit Instruct", elem_classes=["center"])
with gr.Column(min_width=20):
camel5b = gr.Button("camel-5b", elem_id="camel-5b", elem_classes=["square"])
gr.Markdown("Camel", elem_classes=["center"])
with gr.Column(min_width=20):
alpaca_lora7b = gr.Button("alpaca-lora-7b", elem_id="alpaca-lora-7b", elem_classes=["square"])
gr.Markdown("Alpaca-LoRA", elem_classes=["center"])
with gr.Column(min_width=20):
stablelm7b = gr.Button("stablelm-7b", elem_id="stablelm-7b", elem_classes=["square"])
gr.Markdown("StableLM", elem_classes=["center"])
with gr.Column(min_width=20, visible=False):
os_stablelm7b = gr.Button("os-stablelm-7b", elem_id="os-stablelm-7b", elem_classes=["square"])
gr.Markdown("OA+StableLM", elem_classes=["center"])
with gr.Column(min_width=20):
gpt4_alpaca_7b = gr.Button("gpt4-alpaca-7b", elem_id="gpt4-alpaca-7b", elem_classes=["square"])
gr.Markdown("GPT4-Alpaca-LoRA", elem_classes=["center"])
with gr.Column(min_width=20):
mpt_7b = gr.Button("mpt-7b", elem_id="mpt-7b", elem_classes=["square"])
gr.Markdown("MPT", elem_classes=["center"])
with gr.Column(min_width=20):
redpajama_7b = gr.Button("redpajama-7b", elem_id="redpajama-7b", elem_classes=["square"])
gr.Markdown("RedPajama", elem_classes=["center"])
with gr.Column(min_width=20, visible=False):
redpajama_instruct_7b = gr.Button("redpajama-instruct-7b", elem_id="redpajama-instruct-7b", elem_classes=["square"])
gr.Markdown("RedPajama Instruct", elem_classes=["center"])
with gr.Column(min_width=20):
vicuna_7b = gr.Button("vicuna-7b", elem_id="vicuna-7b", elem_classes=["square"])
gr.Markdown("Vicuna", elem_classes=["center"])
with gr.Column(min_width=20):
vicuna_7b_1_3 = gr.Button("vicuna-7b-1-3", elem_id="vicuna-7b-1-3", elem_classes=["square"])
gr.Markdown("Vicuna 1.3", elem_classes=["center"])
with gr.Column(min_width=20):
llama_deus_7b = gr.Button("llama-deus-7b", elem_id="llama-deus-7b",elem_classes=["square"])
gr.Markdown("LLaMA Deus", elem_classes=["center"])
with gr.Column(min_width=20):
evolinstruct_vicuna_7b = gr.Button("evolinstruct-vicuna-7b", elem_id="evolinstruct-vicuna-7b", elem_classes=["square"])
gr.Markdown("EvolInstruct Vicuna", elem_classes=["center"])
with gr.Column(min_width=20, visible=False):
alpacoom_7b = gr.Button("alpacoom-7b", elem_id="alpacoom-7b", elem_classes=["square"])
gr.Markdown("Alpacoom", elem_classes=["center"])
with gr.Column(min_width=20):
baize_7b = gr.Button("baize-7b", elem_id="baize-7b", elem_classes=["square"])
gr.Markdown("Baize", elem_classes=["center"])
with gr.Column(min_width=20):
guanaco_7b = gr.Button("guanaco-7b", elem_id="guanaco-7b", elem_classes=["square"])
gr.Markdown("Guanaco", elem_classes=["center"])
with gr.Column(min_width=20):
falcon_7b = gr.Button("falcon-7b", elem_id="falcon-7b", elem_classes=["square"])
gr.Markdown("Falcon", elem_classes=["center"])
with gr.Column(min_width=20):
wizard_falcon_7b = gr.Button("wizard-falcon-7b", elem_id="wizard-falcon-7b", elem_classes=["square"])
gr.Markdown("Wizard Falcon", elem_classes=["center"])
with gr.Column(min_width=20):
airoboros_7b = gr.Button("airoboros-7b", elem_id="airoboros-7b", elem_classes=["square"])
gr.Markdown("Airoboros", elem_classes=["center"])
with gr.Column(min_width=20):
samantha_7b = gr.Button("samantha-7b", elem_id="samantha-7b", elem_classes=["square"])
gr.Markdown("Samantha", elem_classes=["center"])
with gr.Column(min_width=20):
openllama_7b = gr.Button("openllama-7b", elem_id="openllama-7b", elem_classes=["square"])
gr.Markdown("OpenLLaMA", elem_classes=["center"])
with gr.Column(min_width=20):
orcamini_7b = gr.Button("orcamini-7b", elem_id="orcamini-7b", elem_classes=["square"])
gr.Markdown("Orca Mini", elem_classes=["center"])
with gr.Column(min_width=20):
xgen_7b = gr.Button("xgen-7b", elem_id="xgen-7b", elem_classes=["square"])
gr.Markdown("XGen", elem_classes=["center"])
with gr.Column(min_width=20):
llama2_7b = gr.Button("llama2-7b", elem_id="llama2-7b", elem_classes=["square"])
gr.Markdown("LLaMA 2", elem_classes=["center"])
gr.Markdown("## ~ 20B Parameters")
with gr.Row(elem_classes=["sub-container"]):
with gr.Column(min_width=20, visible=False):
flan11b = gr.Button("flan-11b", elem_id="flan-11b", elem_classes=["square"])
gr.Markdown("Flan-XXL", elem_classes=["center"])
with gr.Column(min_width=20):
koalpaca = gr.Button("koalpaca", elem_id="koalpaca", elem_classes=["square"])
gr.Markdown("koalpaca", elem_classes=["center"])
with gr.Column(min_width=20):
kullm = gr.Button("kullm", elem_id="kullm", elem_classes=["square"])
gr.Markdown("KULLM", elem_classes=["center"])
with gr.Column(min_width=20):
alpaca_lora13b = gr.Button("alpaca-lora-13b", elem_id="alpaca-lora-13b", elem_classes=["square"])
gr.Markdown("Alpaca-LoRA", elem_classes=["center"])
with gr.Column(min_width=20):
gpt4_alpaca_13b = gr.Button("gpt4-alpaca-13b", elem_id="gpt4-alpaca-13b", elem_classes=["square"])
gr.Markdown("GPT4-Alpaca-LoRA", elem_classes=["center"])
with gr.Column(min_width=20):
stable_vicuna_13b = gr.Button("stable-vicuna-13b", elem_id="stable-vicuna-13b", elem_classes=["square"])
gr.Markdown("Stable-Vicuna", elem_classes=["center"])
with gr.Column(min_width=20):
starchat_15b = gr.Button("starchat-15b", elem_id="starchat-15b", elem_classes=["square"])
gr.Markdown("StarChat", elem_classes=["center"])
with gr.Column(min_width=20):
starchat_beta_15b = gr.Button("starchat-beta-15b", elem_id="starchat-beta-15b", elem_classes=["square"])
gr.Markdown("StarChat β", elem_classes=["center"])
with gr.Column(min_width=20):
vicuna_13b = gr.Button("vicuna-13b", elem_id="vicuna-13b", elem_classes=["square"])
gr.Markdown("Vicuna", elem_classes=["center"])
with gr.Column(min_width=20):
vicuna_13b_1_3 = gr.Button("vicuna-13b-1-3", elem_id="vicuna-13b-1-3", elem_classes=["square"])
gr.Markdown("Vicuna 1.3", elem_classes=["center"])
with gr.Column(min_width=20):
evolinstruct_vicuna_13b = gr.Button("evolinstruct-vicuna-13b", elem_id="evolinstruct-vicuna-13b", elem_classes=["square"])
gr.Markdown("EvolInstruct Vicuna", elem_classes=["center"])
with gr.Column(min_width=20):
baize_13b = gr.Button("baize-13b", elem_id="baize-13b", elem_classes=["square"])
gr.Markdown("Baize", elem_classes=["center"])
with gr.Column(min_width=20):
guanaco_13b = gr.Button("guanaco-13b", elem_id="guanaco-13b", elem_classes=["square"])
gr.Markdown("Guanaco", elem_classes=["center"])
with gr.Column(min_width=20):
nous_hermes_13b = gr.Button("nous-hermes-13b", elem_id="nous-hermes-13b", elem_classes=["square"])
gr.Markdown("Nous Hermes", elem_classes=["center"])
with gr.Column(min_width=20):
airoboros_13b = gr.Button("airoboros-13b", elem_id="airoboros-13b", elem_classes=["square"])
gr.Markdown("Airoboros", elem_classes=["center"])
with gr.Column(min_width=20):
samantha_13b = gr.Button("samantha-13b", elem_id="samantha-13b", elem_classes=["square"])
gr.Markdown("Samantha", elem_classes=["center"])
with gr.Column(min_width=20):
chronos_13b = gr.Button("chronos-13b", elem_id="chronos-13b", elem_classes=["square"])
gr.Markdown("Chronos", elem_classes=["center"])
with gr.Column(min_width=20):
wizardlm_13b = gr.Button("wizardlm-13b", elem_id="wizardlm-13b", elem_classes=["square"])
gr.Markdown("WizardLM", elem_classes=["center"])
with gr.Column(min_width=20):
wizard_vicuna_13b = gr.Button("wizard-vicuna-13b", elem_id="wizard-vicuna-13b", elem_classes=["square"])
gr.Markdown("Wizard Vicuna (Uncensored)", elem_classes=["center"])
with gr.Column(min_width=20):
wizard_coder_15b = gr.Button("wizard-coder-15b", elem_id="wizard-coder-15b", elem_classes=["square"])
gr.Markdown("Wizard Coder", elem_classes=["center"])
with gr.Column(min_width=20):
openllama_13b = gr.Button("openllama-13b", elem_id="openllama-13b", elem_classes=["square"])
gr.Markdown("OpenLLaMA", elem_classes=["center"])
with gr.Column(min_width=20):
orcamini_13b = gr.Button("orcamini-13b", elem_id="orcamini-13b", elem_classes=["square"])
gr.Markdown("Orca Mini", elem_classes=["center"])
with gr.Column(min_width=20):
llama2_13b = gr.Button("llama2-13b", elem_id="llama2-13b", elem_classes=["square"])
gr.Markdown("LLaMA 2", elem_classes=["center"])
with gr.Column(min_width=20):
nous_hermes_13b_v2 = gr.Button("nous-hermes-13b-llama2", elem_id="nous-hermes-13b-llama2", elem_classes=["square"])
gr.Markdown("Nous Hermes v2", elem_classes=["center"])
gr.Markdown("## ~ 30B Parameters", visible=False)
with gr.Row(elem_classes=["sub-container"], visible=False):
with gr.Column(min_width=20):
camel20b = gr.Button("camel-20b", elem_id="camel-20b", elem_classes=["square"])
gr.Markdown("Camel", elem_classes=["center"])
gr.Markdown("## ~ 40B Parameters")
with gr.Row(elem_classes=["sub-container"]):
with gr.Column(min_width=20):
guanaco_33b = gr.Button("guanaco-33b", elem_id="guanaco-33b", elem_classes=["square"])
gr.Markdown("Guanaco", elem_classes=["center"])
with gr.Column(min_width=20):
falcon_40b = gr.Button("falcon-40b", elem_id="falcon-40b", elem_classes=["square"])
gr.Markdown("Falcon", elem_classes=["center"])
with gr.Column(min_width=20):
wizard_falcon_40b = gr.Button("wizard-falcon-40b", elem_id="wizard-falcon-40b", elem_classes=["square"])
gr.Markdown("Wizard Falcon", elem_classes=["center"])
with gr.Column(min_width=20):
samantha_33b = gr.Button("samantha-33b", elem_id="samantha-33b", elem_classes=["square"])
gr.Markdown("Samantha", elem_classes=["center"])
with gr.Column(min_width=20):
lazarus_30b = gr.Button("lazarus-30b", elem_id="lazarus-30b", elem_classes=["square"])
gr.Markdown("Lazarus", elem_classes=["center"])
with gr.Column(min_width=20):
chronos_33b = gr.Button("chronos-33b", elem_id="chronos-33b", elem_classes=["square"])
gr.Markdown("Chronos", elem_classes=["center"])
with gr.Column(min_width=20):
wizardlm_30b = gr.Button("wizardlm-30b", elem_id="wizardlm-30b", elem_classes=["square"])
gr.Markdown("WizardLM", elem_classes=["center"])
with gr.Column(min_width=20):
wizard_vicuna_30b = gr.Button("wizard-vicuna-30b", elem_id="wizard-vicuna-30b", elem_classes=["square"])
gr.Markdown("Wizard Vicuna (Uncensored)", elem_classes=["center"])
with gr.Column(min_width=20):
vicuna_33b_1_3 = gr.Button("vicuna-33b-1-3", elem_id="vicuna-33b-1-3", elem_classes=["square"])
gr.Markdown("Vicuna 1.3", elem_classes=["center"])
with gr.Column(min_width=20):
mpt_30b = gr.Button("mpt-30b", elem_id="mpt-30b", elem_classes=["square"])
gr.Markdown("MPT", elem_classes=["center"])
with gr.Column(min_width=20):
upstage_llama_30b = gr.Button("upstage-llama-30b", elem_id="upstage-llama-30b", elem_classes=["square"])
gr.Markdown("Upstage LLaMA", elem_classes=["center"])
gr.Markdown("## ~ 70B Parameters")
with gr.Row(elem_classes=["sub-container"]):
with gr.Column(min_width=20):
free_willy2_70b = gr.Button("free-willy2-70b", elem_id="free-willy2-70b", elem_classes=["square"])
gr.Markdown("Free Willy 2", elem_classes=["center"])
progress_view = gr.Textbox(label="Progress", elem_classes=["progress-view"])
with gr.Column(visible=False) as byom_input_view:
with gr.Column(elem_id="container3"):
gr.Markdown("# Bring Your Own Model", elem_classes=["center"])
gr.Markdown("### Model configuration")
byom_base = gr.Textbox(label="Base", placeholder="Enter path or 🤗 hub ID of the base model", interactive=True)
byom_ckpt = gr.Textbox(label="LoRA ckpt", placeholder="Enter path or 🤗 hub ID of the LoRA checkpoint", interactive=True)
with gr.Accordion("Advanced options", open=False):
gr.Markdown("If you leave the below textboxes empty, `transformers.AutoModelForCausalLM` and `transformers.AutoTokenizer` classes will be used by default. If you need any specific class, please type them below.")
byom_model_cls = gr.Textbox(label="Base model class", placeholder="Enter base model class", interactive=True)
byom_tokenizer_cls = gr.Textbox(label="Base tokenizer class", placeholder="Enter base tokenizer class", interactive=True)
with gr.Column():
gr.Markdown("If you leave the below textboxes empty, any token ids for bos, eos, and pad will not be specified in `GenerationConfig`. If you think that you need to specify them. please type them below in decimal format.")
with gr.Row():
byom_bos_token_id = gr.Textbox(label="bos_token_id", placeholder="for GenConfig")
byom_eos_token_id = gr.Textbox(label="eos_token_id", placeholder="for GenConfig")
byom_pad_token_id = gr.Textbox(label="pad_token_id", placeholder="for GenConfig")
with gr.Row():
byom_load_mode = gr.Radio(
load_mode_list,
value=load_mode_list[0],
label="load mode",
elem_classes=["load-mode-selector"]
)
gr.Markdown("### Prompt configuration")
prompt_style_selector = gr.Dropdown(
label="Prompt style",
interactive=True,
choices=list(prompt_styles.keys()),
value="Alpaca"
)
with gr.Accordion("Prompt style preview", open=False):
prompt_style_previewer = gr.Textbox(
label="How prompt is actually structured",
lines=16,
value=default_ppm.build_prompts())
with gr.Row():
byom_back_btn = gr.Button("Back")
byom_confirm_btn = gr.Button("Confirm")
with gr.Column(elem_classes=["progress-view"]):
txt_view3 = gr.Textbox(label="Status")
progress_view3 = gr.Textbox(label="Progress")
with gr.Column(visible=False) as model_review_view:
gr.Markdown("# Confirm the chosen model", elem_classes=["center"])
with gr.Column(elem_id="container2"):
gr.Markdown("Please expect loading time to be longer than expected. Depending on the size of models, it will probably take from 100 to 1000 seconds or so. Please be patient.")
with gr.Row():
model_image = gr.Image(None, interactive=False, show_label=False)
with gr.Column():
model_name = gr.Markdown("**Model name**")
model_desc = gr.Markdown("...")
model_params = gr.Markdown("Parameters\n: ...")
model_base = gr.Markdown("🤗 Hub(base)\n: ...")
model_ckpt = gr.Markdown("🤗 Hub(LoRA)\n: ...")
model_vram = gr.Markdown(f"""**Minimal VRAM requirement** :
| half precision | load_in_8bit | load_in_4bit |
| ------------------------------ | ------------------------- | ------------------------- |
| {round(7830/1024., 1)}GiB | {round(5224/1024., 1)}GiB | {round(4324/1024., 1)}GiB |
""")
model_thumbnail_tiny = gr.Textbox("", visible=False)
with gr.Column():
gen_config_path = gr.Dropdown(
response_configs,
value=response_configs[0],
interactive=True,
label="Gen Config(response)",
)
gen_config_sum_path = gr.Dropdown(
summarization_configs,
value=summarization_configs[0],
interactive=True,
label="Gen Config(summarization)",
visible=False,
)
with gr.Row():
load_mode = gr.Radio(
load_mode_list,
value=load_mode_list[0],
label="load mode",
elem_classes=["load-mode-selector"]
)
force_redownload = gr.Checkbox(label="Force Re-download", interactive=False, visible=False)
with gr.Accordion("Example showcases", open=False):
with gr.Tab("Ex1"):
example_showcase1 = gr.Chatbot(
[("hello", "world"), ("damn", "good")]
)
with gr.Tab("Ex2"):
example_showcase2 = gr.Chatbot(
[("hello", "world"), ("damn", "good")]
)
with gr.Tab("Ex3"):
example_showcase3 = gr.Chatbot(
[("hello", "world"), ("damn", "good")]
)
with gr.Tab("Ex4"):
example_showcase4 = gr.Chatbot(
[("hello", "world"), ("damn", "good")]
)
with gr.Row():
back_to_model_choose_btn = gr.Button("Back")
confirm_btn = gr.Button("Confirm")
with gr.Column(elem_classes=["progress-view"]):
txt_view = gr.Textbox(label="Status")
progress_view2 = gr.Textbox(label="Progress")
with gr.Column(visible=False) as chat_view:
idx = gr.State(0)
chat_state = gr.State()
local_data = gr.JSON({}, visible=False)
with gr.Row():
with gr.Column(scale=1, min_width=180):
gr.Markdown("GradioChat", elem_id="left-top")
with gr.Column(elem_id="left-pane"):
chat_back_btn = gr.Button("Back", elem_id="chat-back-btn")
with gr.Accordion("Histories", elem_id="chat-history-accordion", open=False):
channel_btns.append(gr.Button(channels[0], elem_classes=["custom-btn-highlight"]))
for channel in channels[1:]:
channel_btns.append(gr.Button(channel, elem_classes=["custom-btn"]))
with gr.Column(scale=8, elem_id="right-pane"):
with gr.Column(
elem_id="initial-popup", visible=False
) as example_block:
with gr.Row(scale=1):
with gr.Column(elem_id="initial-popup-left-pane"):
gr.Markdown("GradioChat", elem_id="initial-popup-title")
gr.Markdown("Making the community's best AI chat models available to everyone.")
with gr.Column(elem_id="initial-popup-right-pane"):
gr.Markdown("Chat UI is now open sourced on Hugging Face Hub")
gr.Markdown("check out the [↗ repository](https://huggingface.co/spaces/chansung/test-multi-conv)")
with gr.Column(scale=1):
gr.Markdown("Examples")
with gr.Row():
for example in examples:
ex_btns.append(gr.Button(example, elem_classes=["example-btn"]))
with gr.Column(elem_id="aux-btns-popup", visible=True):
with gr.Row():
stop = gr.Button("Stop", elem_classes=["aux-btn"])
regenerate = gr.Button("Regen", interactive=False, elem_classes=["aux-btn"])
clean = gr.Button("Clean", elem_classes=["aux-btn"])
with gr.Accordion("Context Inspector", elem_id="aux-viewer", open=False):
context_inspector = gr.Textbox(
"",
elem_id="aux-viewer-inspector",
label="",
lines=30,
max_lines=50,
)
chatbot = gr.Chatbot(elem_id='chatbot')
instruction_txtbox = gr.Textbox(placeholder="Ask anything", label="", elem_id="prompt-txt")
with gr.Accordion("Control Panel", open=False) as control_panel:
with gr.Column():
with gr.Column():
gr.Markdown("#### Global context")
with gr.Accordion("global context will persist during conversation, and it is placed at the top of the prompt", open=False):
global_context = gr.Textbox(
"global context",
lines=5,
max_lines=10,
interactive=True,
elem_id="global-context"
)
gr.Markdown("#### Internet search")
with gr.Row():
internet_option = gr.Radio(choices=["on", "off"], value="off", label="mode")
serper_api_key = gr.Textbox(
value= "" if args.serper_api_key is None else args.serper_api_key,
placeholder="Get one by visiting serper.dev",
label="Serper api key"
)
gr.Markdown("#### GenConfig for **response** text generation")
with gr.Row():
res_temp = gr.Slider(0.0, 2.0, 0, step=0.1, label="temp", interactive=True)
res_topp = gr.Slider(0.0, 2.0, 0, step=0.1, label="top_p", interactive=True)
res_topk = gr.Slider(20, 1000, 0, step=1, label="top_k", interactive=True)
res_rpen = gr.Slider(0.0, 2.0, 0, step=0.1, label="rep_penalty", interactive=True)
res_mnts = gr.Slider(64, 2048, 0, step=1, label="new_tokens", interactive=True)
res_beams = gr.Slider(1, 4, 0, step=1, label="beams")
res_cache = gr.Radio([True, False], value=0, label="cache", interactive=True)
res_sample = gr.Radio([True, False], value=0, label="sample", interactive=True)
res_eosid = gr.Number(value=0, visible=False, precision=0)
res_padid = gr.Number(value=0, visible=False, precision=0)
with gr.Column(visible=False):
gr.Markdown("#### GenConfig for **summary** text generation")
with gr.Row():
sum_temp = gr.Slider(0.0, 2.0, 0, step=0.1, label="temp", interactive=True)
sum_topp = gr.Slider(0.0, 2.0, 0, step=0.1, label="top_p", interactive=True)
sum_topk = gr.Slider(20, 1000, 0, step=1, label="top_k", interactive=True)
sum_rpen = gr.Slider(0.0, 2.0, 0, step=0.1, label="rep_penalty", interactive=True)
sum_mnts = gr.Slider(64, 2048, 0, step=1, label="new_tokens", interactive=True)
sum_beams = gr.Slider(1, 8, 0, step=1, label="beams", interactive=True)
sum_cache = gr.Radio([True, False], value=0, label="cache", interactive=True)
sum_sample = gr.Radio([True, False], value=0, label="sample", interactive=True)
sum_eosid = gr.Number(value=0, visible=False, precision=0)
sum_padid = gr.Number(value=0, visible=False, precision=0)
with gr.Column():
gr.Markdown("#### Context managements")
with gr.Row():
ctx_num_lconv = gr.Slider(2, 10, 3, step=1, label="number of recent talks to keep", interactive=True)
ctx_sum_prompt = gr.Textbox(
"summarize our conversations. what have we discussed about so far?",
label="design a prompt to summarize the conversations",
visible=False
)
btns = [
t5_vicuna_3b, flan3b, camel5b, alpaca_lora7b, stablelm7b,
gpt4_alpaca_7b, os_stablelm7b, mpt_7b, redpajama_7b, redpajama_instruct_7b, llama_deus_7b,
evolinstruct_vicuna_7b, alpacoom_7b, baize_7b, guanaco_7b, vicuna_7b_1_3,
falcon_7b, wizard_falcon_7b, airoboros_7b, samantha_7b, openllama_7b, orcamini_7b,
xgen_7b,llama2_7b,
flan11b, koalpaca, kullm, alpaca_lora13b, gpt4_alpaca_13b, stable_vicuna_13b,
starchat_15b, starchat_beta_15b, vicuna_7b, vicuna_13b, evolinstruct_vicuna_13b,
baize_13b, guanaco_13b, nous_hermes_13b, airoboros_13b, samantha_13b, chronos_13b,
wizardlm_13b, wizard_vicuna_13b, wizard_coder_15b, vicuna_13b_1_3, openllama_13b, orcamini_13b,
llama2_13b, nous_hermes_13b_v2, camel20b,
guanaco_33b, falcon_40b, wizard_falcon_40b, samantha_33b, lazarus_30b, chronos_33b,
wizardlm_30b, wizard_vicuna_30b, vicuna_33b_1_3, mpt_30b, upstage_llama_30b,
free_willy2_70b
]
for btn in btns:
btn.click(
move_to_second_view,
btn,
[
model_choice_view, model_review_view,
model_image, model_name, model_params, model_base, model_ckpt,
model_desc, model_vram, gen_config_path,
example_showcase1, example_showcase2, example_showcase3, example_showcase4,
model_thumbnail_tiny, load_mode,
progress_view
]
)
select_model.click(
move_to_model_select_view,
None,
[progress_view0, landing_view, model_choice_view]
)
chosen_model.click(
use_chosen_model,
None,
[progress_view0, landing_view, chat_view, chatbot, chat_state, global_context,
res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid]
)
byom.click(
move_to_byom_view,
None,
[progress_view0, landing_view, byom_input_view, byom_load_mode]
)
byom_back_btn.click(
move_to_first_view,
None,
[landing_view, byom_input_view]
)
byom_confirm_btn.click(
lambda: "Start downloading/loading the model...", None, txt_view3
).then(
byom_load,
[byom_base, byom_ckpt, byom_model_cls, byom_tokenizer_cls,
byom_bos_token_id, byom_eos_token_id, byom_pad_token_id,
byom_load_mode],
[progress_view3]
).then(
lambda: "Model is fully loaded...", None, txt_view3
).then(
move_to_third_view,
None,
[progress_view3, byom_input_view, chat_view, chatbot, chat_state, global_context,
res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid]
)
prompt_style_selector.change(
prompt_style_change,
prompt_style_selector,
prompt_style_previewer
)
back_to_model_choose_btn.click(
move_to_first_view,
None,
[model_choice_view, model_review_view]
)
confirm_btn.click(
lambda: "Start downloading/loading the model...", None, txt_view
).then(
download_completed,
[model_name, model_base, model_ckpt, gen_config_path, gen_config_sum_path, load_mode, model_thumbnail_tiny, force_redownload],
[progress_view2]
).then(
lambda: "Model is fully loaded...", None, txt_view
).then(
lambda: time.sleep(2), None, None
).then(
move_to_third_view,
None,
[progress_view2, model_review_view, chat_view, chatbot, chat_state, global_context,
res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid]
)
for btn in channel_btns:
btn.click(
set_chatbot,
[btn, local_data, chat_state],
[chatbot, idx, example_block, regenerate]
).then(
None, btn, None,
_js=UPDATE_LEFT_BTNS_STATE
)
for btn in ex_btns:
btn.click(
set_example,
[btn],
[instruction_txtbox, example_block]
)
instruction_txtbox.submit(
lambda: [
gr.update(visible=False),
gr.update(interactive=True)
],
None,
[example_block, regenerate]
)
send_event = instruction_txtbox.submit(
central.chat_stream,
[idx, local_data, instruction_txtbox, chat_state,
global_context, ctx_num_lconv, ctx_sum_prompt,
res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
internet_option, serper_api_key],
[instruction_txtbox, chatbot, context_inspector, local_data],
)
instruction_txtbox.submit(
None, local_data, None,
_js="(v)=>{ setStorage('local_data',v) }"
)
regenerate.click(
rollback_last,
[idx, local_data, chat_state],
[instruction_txtbox, chatbot, local_data, regenerate]
).then(
central.chat_stream,
[idx, local_data, instruction_txtbox, chat_state,
global_context, ctx_num_lconv, ctx_sum_prompt,
res_temp, res_topp, res_topk, res_rpen, res_mnts, res_beams, res_cache, res_sample, res_eosid, res_padid,
sum_temp, sum_topp, sum_topk, sum_rpen, sum_mnts, sum_beams, sum_cache, sum_sample, sum_eosid, sum_padid,
internet_option, serper_api_key],
[instruction_txtbox, chatbot, context_inspector, local_data],
).then(
lambda: gr.update(interactive=True),
None,
regenerate
).then(
None, local_data, None,
_js="(v)=>{ setStorage('local_data',v) }"
)
stop.click(
None, None, None,
cancels=[send_event]
)
clean.click(
reset_chat,
[idx, local_data, chat_state],
[instruction_txtbox, chatbot, local_data, example_block, regenerate]
).then(
None, local_data, None,
_js="(v)=>{ setStorage('local_data',v) }"
)
chat_back_btn.click(
lambda: [gr.update(visible=False), gr.update(visible=True)],
None,
[chat_view, landing_view]
)
demo.load(
None,
inputs=None,
outputs=[chatbot, local_data],
_js=GET_LOCAL_STORAGE,
)
demo.queue().launch(
server_port=6006,
server_name="0.0.0.0",
debug=args.debug,
share=args.share,
root_path=f"{args.root_path}"
)
def _build_arg_parser():
    """Create the command-line parser for launching the chatbot UI.

    Flags (in registration order, which also fixes the ``--help`` order):
      --root-path          forwarded as ``root_path`` to ``demo.launch``.
      --local-files-only   on/off switch (``--no-local-files-only`` also
                           accepted); presumably consumed by model-loading
                           code elsewhere in this file -- TODO confirm.
      --share              forwarded as ``share`` to ``demo.launch``.
      --debug              forwarded as ``debug`` to ``demo.launch``.
      --serper-api-key     pre-fills the "Serper api key" textbox when given.

    Returns:
        argparse.ArgumentParser: a parser ready for ``parse_args``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('--root-path', default="")
    # BooleanOptionalAction yields paired --flag / --no-flag options,
    # all defaulting to off.
    for switch in ('--local-files-only', '--share', '--debug'):
        cli.add_argument(switch, default=False, action=argparse.BooleanOptionalAction)
    cli.add_argument('--serper-api-key', default=None, type=str)
    return cli


if __name__ == "__main__":
    gradio_main(_build_arg_parser().parse_args())