Spaces:
Runtime error
Runtime error
import gc | |
import yaml | |
import torch | |
from transformers import GenerationConfig | |
from models import alpaca, stablelm, koalpaca, flan_alpaca, mpt | |
from models import camel, t5_vicuna, vicuna, starchat, redpajama, bloom | |
from models import baize, guanaco, falcon, kullm, replit, airoboros | |
from models import samantha_vicuna, wizard_coder, xgen, freewilly | |
from models import byom | |
cuda_availability = False | |
available_vrams_gb = 0 | |
mps_availability = False | |
if torch.cuda.is_available(): | |
cuda_availability = True | |
available_vrams_mb = sum( | |
[ | |
torch.cuda.get_device_properties(i).total_memory | |
for i in range(torch.cuda.device_count()) | |
] | |
) / 1024. / 1024 | |
if torch.backends.mps.is_available(): | |
mps_availability = True | |
def initialize_globals_byom( | |
base, ckpt, model_cls, tokenizer_cls, | |
bos_token_id, eos_token_id, pad_token_id, | |
mode_cpu, model_mps, mode_8bit, mode_4bit, mode_full_gpu | |
): | |
global model, model_type, stream_model, tokenizer | |
global model_thumbnail_tiny, device | |
global gen_config, gen_config_raw | |
global gen_config_summarization | |
model_type = "custom" | |
model, tokenizer = byom.load_model( | |
base=base, | |
finetuned=ckpt, | |
mode_cpu=mode_cpu, | |
mode_mps=mode_mps, | |
mode_full_gpu=mode_full_gpu, | |
mode_8bit=mode_8bit, | |
mode_4bit=mode_4bit, | |
model_cls=model_cls if model_cls != "" else None, | |
tokenizer_cls=tokenizer_cls if tokenizer_cls != "" else None | |
) | |
stream_model = model | |
gen_config, gen_config_raw = get_generation_config("configs/response_configs/default.yaml") | |
gen_config_summarization, _ = get_generation_config("configs/summarization_configs/default.yaml") | |
if bos_token_id != "" or bos_token_id.isdigit(): | |
gen_config.bos_token_id = int(bos_token_id) | |
if eos_token_id != "" or eos_token_id.isdigit(): | |
gen_config.eos_token_id = int(eos_token_id) | |
if pad_token_id != "" or pad_token_id.isdigit(): | |
gen_config.pad_token_id = int(pad_token_id) | |
def initialize_globals(args): | |
global device, model_thumbnail_tiny | |
global model, model_type, stream_model, tokenizer | |
global gen_config, gen_config_raw | |
global gen_config_summarization | |
model_type_tmp = "alpaca" | |
if "stabilityai/freewilly2" in args.base_url.lower(): | |
model_type_tmp = "free-willy" | |
elif "upstage/llama-" in args.base_url.lower(): | |
model_type_tmp = "upstage-llama" | |
elif "llama-2" in args.base_url.lower(): | |
model_type_tmp = "llama2" | |
elif "xgen" in args.base_url.lower(): | |
model_type_tmp = "xgen" | |
elif "orca_mini" in args.base_url.lower(): | |
model_type_tmp = "orcamini" | |
elif "open-llama" in args.base_url.lower(): | |
model_type_tmp = "openllama" | |
elif "wizardcoder" in args.base_url.lower(): | |
model_type_tmp = "wizard-coder" | |
elif "wizard-vicuna" in args.base_url.lower(): | |
model_type_tmp = "wizard-vicuna" | |
elif "llms/wizardlm" in args.base_url.lower(): | |
model_type_tmp = "wizardlm" | |
elif "chronos" in args.base_url.lower(): | |
model_type_tmp = "chronos" | |
elif "lazarus" in args.base_url.lower(): | |
model_type_tmp = "lazarus" | |
elif "samantha" in args.base_url.lower(): | |
model_type_tmp = "samantha-vicuna" | |
elif "airoboros" in args.base_url.lower(): | |
model_type_tmp = "airoboros" | |
elif "replit" in args.base_url.lower(): | |
model_type_tmp = "replit-instruct" | |
elif "kullm" in args.base_url.lower(): | |
model_type_tmp = "kullm-polyglot" | |
elif "nous-hermes" in args.base_url.lower(): | |
model_type_tmp = "nous-hermes" | |
elif "guanaco" in args.base_url.lower(): | |
model_type_tmp = "guanaco" | |
elif "wizardlm-uncensored-falcon" in args.base_url.lower(): | |
model_type_tmp = "wizard-falcon" | |
elif "falcon" in args.base_url.lower(): | |
model_type_tmp = "falcon" | |
elif "baize" in args.base_url.lower(): | |
model_type_tmp = "baize" | |
elif "stable-vicuna" in args.base_url.lower(): | |
model_type_tmp = "stable-vicuna" | |
elif "vicuna" in args.base_url.lower(): | |
model_type_tmp = "vicuna" | |
elif "mpt" in args.base_url.lower(): | |
model_type_tmp = "mpt" | |
elif "redpajama-incite-7b-instruct" in args.base_url.lower(): | |
model_type_tmp = "redpajama-instruct" | |
elif "redpajama" in args.base_url.lower(): | |
model_type_tmp = "redpajama" | |
elif "starchat" in args.base_url.lower(): | |
model_type_tmp = "starchat" | |
elif "camel" in args.base_url.lower(): | |
model_type_tmp = "camel" | |
elif "flan-alpaca" in args.base_url.lower(): | |
model_type_tmp = "flan-alpaca" | |
elif "openassistant/stablelm" in args.base_url.lower(): | |
model_type_tmp = "os-stablelm" | |
elif "stablelm" in args.base_url.lower(): | |
model_type_tmp = "stablelm" | |
elif "fastchat-t5" in args.base_url.lower(): | |
model_type_tmp = "t5-vicuna" | |
elif "koalpaca-polyglot" in args.base_url.lower(): | |
model_type_tmp = "koalpaca-polyglot" | |
elif "alpacagpt4" in args.ft_ckpt_url.lower(): | |
model_type_tmp = "alpaca-gpt4" | |
elif "alpaca" in args.ft_ckpt_url.lower(): | |
model_type_tmp = "alpaca" | |
elif "llama-deus" in args.ft_ckpt_url.lower(): | |
model_type_tmp = "llama-deus" | |
elif "vicuna-lora-evolinstruct" in args.ft_ckpt_url.lower(): | |
model_type_tmp = "evolinstruct-vicuna" | |
elif "alpacoom" in args.ft_ckpt_url.lower(): | |
model_type_tmp = "alpacoom" | |
elif "guanaco" in args.ft_ckpt_url.lower(): | |
model_type_tmp = "guanaco" | |
else: | |
print("unsupported model type") | |
quit() | |
print(f"determined model type: {model_type_tmp}") | |
device = "cpu" | |
if args.mode_cpu: | |
device = "cpu" | |
elif args.mode_mps: | |
device = "mps" | |
else: | |
device = "cuda" | |
try: | |
if model is not None: | |
del model | |
if stream_model is not None: | |
del stream_model | |
if tokenizer is not None: | |
del tokenizer | |
gc.collect() | |
if device == "cuda": | |
torch.cuda.empty_cache() | |
elif device == "mps": | |
torch.mps.empty_cache() | |
except NameError: | |
pass | |
model_type = model_type_tmp | |
load_model = get_load_model(model_type_tmp) | |
model, tokenizer = load_model( | |
base=args.base_url, | |
finetuned=args.ft_ckpt_url, | |
mode_cpu=args.mode_cpu, | |
mode_mps=args.mode_mps, | |
mode_full_gpu=args.mode_full_gpu, | |
mode_8bit=args.mode_8bit, | |
mode_4bit=args.mode_4bit, | |
force_download_ckpt=args.force_download_ckpt, | |
local_files_only=args.local_files_only | |
) | |
model.eval() | |
model_thumbnail_tiny = args.thumbnail_tiny | |
gen_config, gen_config_raw = get_generation_config(args.gen_config_path) | |
gen_config_summarization, _ = get_generation_config(args.gen_config_summarization_path) | |
stream_model = model | |
def get_load_model(model_type): | |
if model_type == "alpaca" or \ | |
model_type == "alpaca-gpt4" or \ | |
model_type == "llama-deus" or \ | |
model_type == "nous-hermes" or \ | |
model_type == "lazarus" or \ | |
model_type == "chronos" or \ | |
model_type == "wizardlm" or \ | |
model_type == "openllama" or \ | |
model_type == "orcamini" or \ | |
model_type == "llama2" or \ | |
model_type == "upstage-llama": | |
return alpaca.load_model | |
elif model_type == "free-willy": | |
return freewilly.load_model | |
elif model_type == "stablelm" or model_type == "os-stablelm": | |
return stablelm.load_model | |
elif model_type == "koalpaca-polyglot": | |
return koalpaca.load_model | |
elif model_type == "kullm-polyglot": | |
return kullm.load_model | |
elif model_type == "flan-alpaca": | |
return flan_alpaca.load_model | |
elif model_type == "camel": | |
return camel.load_model | |
elif model_type == "t5-vicuna": | |
return t5_vicuna.load_model | |
elif model_type == "stable-vicuna": | |
return vicuna.load_model | |
elif model_type == "starchat": | |
return starchat.load_model | |
elif model_type == "wizard-coder": | |
return wizard_coder.load_model | |
elif model_type == "mpt": | |
return mpt.load_model | |
elif model_type == "redpajama" or \ | |
model_type == "redpajama-instruct": | |
return redpajama.load_model | |
elif model_type == "vicuna": | |
return vicuna.load_model | |
elif model_type == "evolinstruct-vicuna" or \ | |
model_type == "wizard-vicuna": | |
return alpaca.load_model | |
elif model_type == "alpacoom": | |
return bloom.load_model | |
elif model_type == "baize": | |
return baize.load_model | |
elif model_type == "guanaco": | |
return guanaco.load_model | |
elif model_type == "falcon" or model_type == "wizard-falcon": | |
return falcon.load_model | |
elif model_type == "replit-instruct": | |
return replit.load_model | |
elif model_type == "airoboros": | |
return airoboros.load_model | |
elif model_type == "samantha-vicuna": | |
return samantha_vicuna.load_model | |
elif model_type == "xgen": | |
return xgen.load_model | |
else: | |
return None | |
def get_generation_config(path): | |
with open(path, 'rb') as f: | |
generation_config = yaml.safe_load(f.read()) | |
generation_config = generation_config["generation_config"] | |
return GenerationConfig(**generation_config), generation_config | |
def get_constraints_config(path): | |
with open(path, 'rb') as f: | |
constraints_config = yaml.safe_load(f.read()) | |
return ConstraintsConfig(**constraints_config), constraints_config["constraints"] | |