#!/usr/bin/env python # coding=utf-8 # Copyright 2023 Bofeng Huang """ Modified from: https://huggingface.co/spaces/mosaicml/mpt-7b-chat/raw/main/app.py Usage: CUDA_VISIBLE_DEVICES=0 python vigogne/demo/demo_chat.py \ --base_model_name_or_path huggyllama/llama-7b \ --lora_model_name_or_path bofenghuang/vigogne-chat-7b """ # import datetime import logging import os import re from threading import Event, Thread from typing import List, Optional # from uuid import uuid4 import json import gradio as gr # import requests import torch from peft import PeftModel from transformers import ( AutoModelForCausalLM, AutoTokenizer, GenerationConfig, StoppingCriteriaList, TextIteratorStreamer, ) from vigogne.constants import ASSISTANT, USER from vigogne.preprocess import generate_inference_chat_prompt from vigogne.inference.inference_utils import StopWordsCriteria logging.basicConfig( format="%(asctime)s [%(levelname)s] [%(name)s] %(message)s", datefmt="%Y-%m-%dT%H:%M:%SZ", ) logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) device = "cuda" if torch.cuda.is_available() else "cpu" try: if torch.backends.mps.is_available(): device = "mps" except: pass logger.info(f"Model will be loaded on device `{device}`") # def log_conversation(conversation_id, history, messages, generate_kwargs): # logging_url = os.getenv("LOGGING_URL", None) # if logging_url is None: # return # timestamp = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S") # data = { # "conversation_id": conversation_id, # "timestamp": timestamp, # "history": history, # "messages": messages, # "generate_kwargs": generate_kwargs, # } # try: # requests.post(logging_url, json=data) # except requests.exceptions.RequestException as e: # print(f"Error logging conversation: {e}") def user(message, history): # Append the user's message to the conversation history return "", history + [[message, ""]] # def get_uuid(): # return str(uuid4()) def main( base_model_name_or_path: str = "huggyllama/llama-7b", lora_model_name_or_path: str = "bofenghuang/vigogne-chat-7b", load_8bit: bool = False, server_name: Optional[str] = "0.0.0.0", server_port: Optional[str] = None, share: bool = False, ): tokenizer = AutoTokenizer.from_pretrained(base_model_name_or_path, padding_side="right", use_fast=False) if device == "cuda": model = AutoModelForCausalLM.from_pretrained( base_model_name_or_path, load_in_8bit=load_8bit, torch_dtype=torch.float16, device_map="auto", ) model = PeftModel.from_pretrained( model, lora_model_name_or_path, torch_dtype=torch.float16, ) elif device == "mps": model = AutoModelForCausalLM.from_pretrained( base_model_name_or_path, device_map={"": device}, torch_dtype=torch.float16, ) model = PeftModel.from_pretrained( model, lora_model_name_or_path, device_map={"": device}, torch_dtype=torch.float16, ) else: model = AutoModelForCausalLM.from_pretrained(base_model_name_or_path, device_map={"": device}, low_cpu_mem_usage=True) model = PeftModel.from_pretrained( model, lora_model_name_or_path, device_map={"": device}, ) if not load_8bit and device != "cpu": model.half() # seems to fix bugs for some users. model.eval() # NB stop_words = [f"<|{ASSISTANT}|>", f"<|{USER}|>"] stop_words_criteria = StopWordsCriteria(stop_words=stop_words, tokenizer=tokenizer) pattern_trailing_stop_words = re.compile(rf'(?:{"|".join([re.escape(stop_word) for stop_word in stop_words])})\W*$') def bot(history, max_new_tokens, temperature, top_p, top_k, repetition_penalty, conversation_id=None): # logger.info(f"History: {json.dumps(history, indent=4, ensure_ascii=False)}") # Construct the input message string for the model by concatenating the current system message and conversation history messages = generate_inference_chat_prompt(history, tokenizer) logger.info(messages) assert messages is not None, "User input is too long!" # Tokenize the messages string input_ids = tokenizer(messages, return_tensors="pt")["input_ids"].to(device) streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True) generate_kwargs = dict( input_ids=input_ids, generation_config=GenerationConfig( temperature=temperature, do_sample=temperature > 0.0, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty, max_new_tokens=max_new_tokens, ), streamer=streamer, stopping_criteria=StoppingCriteriaList([stop_words_criteria]), ) # stream_complete = Event() def generate_and_signal_complete(): model.generate(**generate_kwargs) # stream_complete.set() # def log_after_stream_complete(): # stream_complete.wait() # log_conversation( # conversation_id, # history, # messages, # { # "top_k": top_k, # "top_p": top_p, # "temperature": temperature, # "repetition_penalty": repetition_penalty, # }, # ) t1 = Thread(target=generate_and_signal_complete) t1.start() # t2 = Thread(target=log_after_stream_complete) # t2.start() # Initialize an empty string to store the generated text partial_text = "" for new_text in streamer: # NB new_text = pattern_trailing_stop_words.sub("", new_text) partial_text += new_text history[-1][1] = partial_text yield history logger.info(f"Response: {history[-1][1]}") with gr.Blocks( theme=gr.themes.Soft(), css=".disclaimer {font-variant-caps: all-small-caps;}", ) as demo: # conversation_id = gr.State(get_uuid) gr.Markdown( """