# Copyright 2023 MosaicML spaces authors # SPDX-License-Identifier: Apache-2.0 import datetime import os from threading import Event, Thread from uuid import uuid4 from peft import PeftModel import gradio as gr import requests import torch import transformers from transformers import ( AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer, ) # model_name = "lmsys/vicuna-7b-delta-v1.1" model_name = "timdettmers/guanaco-33b-merged" max_new_tokens = 1536 auth_token = os.getenv("HF_TOKEN", None) print(f"Starting to load the model {model_name} into memory") m = AutoModelForCausalLM.from_pretrained( model_name, quantization_config=transformers.BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type='nf4' # {'fp4', 'nf4'} ), torch_dtype=torch.bfloat16, device_map={"": 0} ) #m = PeftModel.from_pretrained(m, 'timdettmers/guanaco-65b') m.eval() tok = LlamaTokenizer.from_pretrained("decapoda-research/llama-33b-hf") tok.bos_token_id = 1 stop_token_ids = [0] print(f"Successfully loaded the model {model_name} into memory") start_message = """A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.""" class StopOnTokens(StoppingCriteria): def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: for stop_id in stop_token_ids: if input_ids[0][-1] == stop_id: return True return False def convert_history_to_text(history): text = start_message + "".join( [ "".join( [ f"### Human: {item[0]}\n", f"### Assistant: {item[1]}\n", ] ) for item in history[:-1] ] ) text += "".join( [ "".join( [ f"### Human: {history[-1][0]}\n", f"### Assistant: {history[-1][1]}\n", ] ) ] ) return text def log_conversation(conversation_id, history, messages, generate_kwargs): logging_url = os.getenv("LOGGING_URL", None) if logging_url is None: return timestamp = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S") data = { "conversation_id": conversation_id, "timestamp": timestamp, "history": history, "messages": messages, "generate_kwargs": generate_kwargs, } try: requests.post(logging_url, json=data) except requests.exceptions.RequestException as e: print(f"Error logging conversation: {e}") def user(message, history): # Append the user's message to the conversation history return "", history + [[message, ""]] def bot(history, temperature, top_p, top_k, repetition_penalty, conversation_id): print(f"history: {history}") # Initialize a StopOnTokens object stop = StopOnTokens() # Construct the input message string for the model by concatenating the current system message and conversation history messages = convert_history_to_text(history) # Tokenize the messages string input_ids = tok(messages, return_tensors="pt").input_ids input_ids = input_ids.to(m.device) streamer = TextIteratorStreamer(tok, timeout=10.0, skip_prompt=True, skip_special_tokens=True) generate_kwargs = dict( input_ids=input_ids, max_new_tokens=max_new_tokens, temperature=temperature, do_sample=temperature > 0.0, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty, streamer=streamer, stopping_criteria=StoppingCriteriaList([stop]), ) stream_complete = Event() def generate_and_signal_complete(): m.generate(**generate_kwargs) stream_complete.set() def log_after_stream_complete(): stream_complete.wait() log_conversation( conversation_id, history, messages, { "top_k": top_k, "top_p": top_p, "temperature": temperature, "repetition_penalty": repetition_penalty, }, ) t1 = Thread(target=generate_and_signal_complete) t1.start() t2 = Thread(target=log_after_stream_complete) t2.start() # Initialize an empty string to store the generated text partial_text = "" for new_text in streamer: partial_text += new_text history[-1][1] = partial_text yield history def get_uuid(): return str(uuid4()) with gr.Blocks( theme=gr.themes.Soft(), css=".disclaimer {font-variant-caps: all-small-caps;}", ) as demo: conversation_id = gr.State(get_uuid) gr.Markdown( """