import argparse
from threading import Thread

import gradio as gr
import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

parser = argparse.ArgumentParser()
parser.add_argument(
    "--model_path_or_id",
    type=str,
    default="NousResearch/Llama-2-7b-hf",
    required=False,
    help="Model ID or path to saved model",
)
parser.add_argument(
    "--lora_path",
    type=str,
    default=None,
    required=False,
    help="Path to the saved lora adapter",
)
args = parser.parse_args()

if args.lora_path:
    # load base LLM model with PEFT adapter
    model = AutoPeftModelForCausalLM.from_pretrained(
        args.lora_path,
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16,
        load_in_4bit=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(args.lora_path)
else:
    # load the base model without an adapter
    model = AutoModelForCausalLM.from_pretrained(
        args.model_path_or_id,
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16,
        load_in_4bit=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(args.model_path_or_id)

with gr.Blocks() as demo:
    gr.HTML(f"""

        <h1>Instruction Chat Bot Demo</h1>
        <h3>Model ID : {args.model_path_or_id}</h3>
        <h3>Peft Adapter : {args.lora_path}</h3>

""") chat_history = gr.Chatbot(label = "Instruction Bot") msg = gr.Textbox(label = "Instruction") with gr.Accordion(label = "Generation Parameters", open = False): prompt_format = gr.Textbox( label = "Formatting prompt", value = "{instruction}", lines = 8) with gr.Row(): max_new_tokens = gr.Number(minimum = 25, maximum = 500, value = 100, label = "Max New Tokens") temperature = gr.Slider(minimum = 0, maximum = 1.0, value = 0.7, label = "Temperature") clear = gr.ClearButton([msg, chat_history]) def user(user_message, history): return "", [[user_message, None]] def bot(chat_history, prompt_format, max_new_tokens, temperature): # Format the instruction using the format string with key # {instruction} formatted_inst = prompt_format.format( instruction = chat_history[-1][0] ) # Tokenize the input input_ids = tokenizer( formatted_inst, return_tensors="pt", truncation=True).input_ids.cuda() # Support for streaming of tokens within generate requires # generation to run in a separate thread streamer = TextIteratorStreamer(tokenizer, skip_prompt = True) generation_kwargs = dict( input_ids = input_ids, streamer = streamer, max_new_tokens=max_new_tokens, do_sample=True, top_p=0.9, temperature=temperature, use_cache=True ) thread = Thread(target = model.generate, kwargs = generation_kwargs) thread.start() chat_history[-1][1] = "" for new_text in streamer: chat_history[-1][1] += new_text yield chat_history msg.submit(user,[msg, chat_history], [msg, chat_history], queue = False).then( bot, [chat_history, prompt_format, max_new_tokens, temperature], chat_history ) demo.queue() demo.launch()