import os
import time
import json
import copy
import types
from os import listdir
from os.path import isfile, join

import argparse
import gradio as gr

import global_vars
from chats import central
from transformers import AutoModelForCausalLM
from miscs.styles import MODEL_SELECTION_CSS
from miscs.js import GET_LOCAL_STORAGE, UPDATE_LEFT_BTNS_STATE
from utils import get_chat_manager, get_global_context

from pingpong.pingpong import PingPong
from pingpong.gradio import GradioAlpacaChatPPManager
from pingpong.gradio import GradioKoAlpacaChatPPManager
from pingpong.gradio import GradioStableLMChatPPManager
from pingpong.gradio import GradioFlanAlpacaChatPPManager
from pingpong.gradio import GradioOSStableLMChatPPManager
from pingpong.gradio import GradioVicunaChatPPManager
from pingpong.gradio import GradioStableVicunaChatPPManager
from pingpong.gradio import GradioStarChatPPManager
from pingpong.gradio import GradioMPTChatPPManager
from pingpong.gradio import GradioRedPajamaChatPPManager
from pingpong.gradio import GradioBaizeChatPPManager

# no cpu for
# - falcon families (too slow)
load_mode_list = ["cpu"]

with open("examples.txt", "r") as ex_file:
    examples = ex_file.read().split("\n")
ex_btns = []

with open("channels.txt", "r") as chl_file:
    channels = chl_file.read().split("\n")
channel_btns = []

default_ppm = GradioAlpacaChatPPManager()
default_ppm.ctx = "Context at top"
default_ppm.pingpongs = [
    PingPong("user input #1...", "bot response #1..."),
    PingPong("user input #2...", "bot response #2..."),
]
chosen_ppm = copy.deepcopy(default_ppm)

prompt_styles = {
    "Alpaca": default_ppm,
    "Baize": GradioBaizeChatPPManager(),
    "Koalpaca": GradioKoAlpacaChatPPManager(),
    "MPT": GradioMPTChatPPManager(),
    "OpenAssistant StableLM": GradioOSStableLMChatPPManager(),
    "RedPajama": GradioRedPajamaChatPPManager(),
    "StableVicuna": GradioVicunaChatPPManager(),
    "StableLM": GradioStableLMChatPPManager(),
    "StarChat": GradioStarChatPPManager(),
    "Vicuna": GradioVicunaChatPPManager(),
}

response_configs = [
    f"configs/response_configs/{f}"
    for f in listdir("configs/response_configs")
    if isfile(join("configs/response_configs", f))
]

summarization_configs = [
    f"configs/summarization_configs/{f}"
    for f in listdir("configs/summarization_configs")
    if isfile(join("configs/summarization_configs", f))
]

with open("model_cards.json") as model_cards_file:
    model_info = json.load(model_cards_file)

###

def move_to_model_select_view():
    return (
        "move to model select view",
        gr.update(visible=False),
        gr.update(visible=True),
    )

def use_chosen_model():
    try:
        test = global_vars.model
    except AttributeError:
        raise gr.Error("There is no previously chosen model")

    gen_config = global_vars.gen_config
    gen_sum_config = global_vars.gen_config_summarization

    if global_vars.model_type == "custom":
        ppmanager_type = chosen_ppm
    else:
        ppmanager_type = get_chat_manager(global_vars.model_type)

    return (
        "Preparation done!",
        gr.update(visible=False),
        gr.update(visible=True),
        gr.update(label=global_vars.model_type),
        {
            "ppmanager_type": ppmanager_type,
            "model_type": global_vars.model_type,
        },
        get_global_context(global_vars.model_type),
        gen_config.temperature,
        gen_config.top_p,
        gen_config.top_k,
        gen_config.repetition_penalty,
        gen_config.max_new_tokens,
        gen_config.num_beams,
        gen_config.use_cache,
        gen_config.do_sample,
        gen_config.eos_token_id,
        gen_config.pad_token_id,
        gen_sum_config.temperature,
        gen_sum_config.top_p,
        gen_sum_config.top_k,
        gen_sum_config.repetition_penalty,
        gen_sum_config.max_new_tokens,
        gen_sum_config.num_beams,
        gen_sum_config.use_cache,
        gen_sum_config.do_sample,
        gen_sum_config.eos_token_id,
        gen_sum_config.pad_token_id,
    )

def move_to_byom_view():
    load_mode_list = []
    if global_vars.cuda_availability:
        load_mode_list.extend(["gpu(half)", "gpu(load_in_8bit)", "gpu(load_in_4bit)"])
    if global_vars.mps_availability:
        load_mode_list.append("apple silicon")
    load_mode_list.append("cpu")

    return (
        "move to the byom view",
        gr.update(visible=False),
        gr.update(visible=True),
        gr.update(choices=load_mode_list, value=load_mode_list[0]),
    )

def prompt_style_change(key):
    # update the module-level chosen_ppm so use_chosen_model() picks it up
    # when the "custom" model type is selected
    global chosen_ppm

    ppm = prompt_styles[key]
    ppm.ctx = "Context at top"
    ppm.pingpongs = [
        PingPong("user input #1...", "bot response #1..."),
        PingPong("user input #2...", "bot response #2..."),
    ]

    chosen_ppm = copy.deepcopy(ppm)
    chosen_ppm.ctx = ""
    chosen_ppm.pingpongs = []

    return ppm.build_prompts()

def byom_load(
    base,
    ckpt,
    model_cls,
    tokenizer_cls,
    bos_token_id,
    eos_token_id,
    pad_token_id,
    load_mode,
):
    # flag order: mode_cpu, mode_mps, mode_8bit, mode_4bit, mode_full_gpu
    # load_mode values must match the choices built in move_to_byom_view()
    global_vars.initialize_globals_byom(
        base,
        ckpt,
        model_cls,
        tokenizer_cls,
        bos_token_id,
        eos_token_id,
        pad_token_id,
        load_mode == "cpu",
        load_mode == "apple silicon",
        load_mode == "gpu(load_in_8bit)",
        load_mode == "gpu(load_in_4bit)",
        load_mode == "gpu(half)",
    )

    return ""

def channel_num(btn_title):
    choice = 0

    for idx, channel in enumerate(channels):
        if channel == btn_title:
            choice = idx

    return choice

def set_chatbot(btn, ld, state):
    choice = channel_num(btn)

    res = [state["ppmanager_type"].from_json(json.dumps(ppm_str)) for ppm_str in ld]
    empty = len(res[choice].pingpongs) == 0
    return (
        res[choice].build_uis(),
        choice,
        gr.update(visible=empty),
        gr.update(interactive=not empty),
    )

def set_example(btn):
    return btn, gr.update(visible=False)

def set_popup_visibility(ld, example_block):
    return example_block

def move_to_second_view(btn):
    info = model_info[btn]

    # add headroom on top of each model's reported VRAM requirements
    guard_vram = 5 * 1024.
    vram_req_full = int(info["vram(full)"]) + guard_vram
    vram_req_8bit = int(info["vram(8bit)"]) + guard_vram
    vram_req_4bit = int(info["vram(4bit)"]) + guard_vram

    load_mode_list = []
    if global_vars.cuda_availability:
        print(f"total vram = {global_vars.available_vrams_mb}")
        print(f"required vram(full={info['vram(full)']}, 8bit={info['vram(8bit)']}, 4bit={info['vram(4bit)']})")

        if global_vars.available_vrams_mb >= vram_req_full:
            load_mode_list.append("gpu(half)")

        if global_vars.available_vrams_mb >= vram_req_8bit:
            load_mode_list.append("gpu(load_in_8bit)")

        if global_vars.available_vrams_mb >= vram_req_4bit:
            load_mode_list.append("gpu(load_in_4bit)")

    if global_vars.mps_availability:
        load_mode_list.append("apple silicon")

    load_mode_list.append("cpu")

    return (
        gr.update(visible=False),
        gr.update(visible=True),
        info["thumb"],
        f"## {btn}",
        f"**Parameters**\n: Approx. {info['parameters']}",
{info['parameters']}", f"**🤗 Hub(base)**\n: {info['hub(base)']}", f"**🤗 Hub(LoRA)**\n: {info['hub(ckpt)']}", info['desc'], f"""**Min VRAM requirements** : | half precision | load_in_8bit | load_in_4bit | | ------------------------------------- | ---------------------------------- | ---------------------------------- | | {round(vram_req_full/1024., 1)}GiB | {round(vram_req_8bit/1024., 1)}GiB | {round(vram_req_4bit/1024., 1)}GiB | """, info['default_gen_config'], info['example1'], info['example2'], info['example3'], info['example4'], info['thumb-tiny'], gr.update(choices=load_mode_list, value=load_mode_list[0]), "", ) def move_to_first_view(): return (gr.update(visible=True), gr.update(visible=False)) def download_completed( model_name, model_base, model_ckpt, gen_config_path, gen_config_sum_path, load_mode, thumbnail_tiny, force_download, ): global local_files_only tmp_args = types.SimpleNamespace() tmp_args.base_url = model_base.split(":")[-1].split("