Spaces:
Runtime error
Runtime error
#!/usr/bin/env python | |
# coding=utf-8 | |
# Copyright 2023 Bofeng Huang | |
""" | |
Modified from: https://huggingface.co/spaces/mosaicml/mpt-7b-chat/raw/main/app.py | |
Usage: | |
CUDA_VISIBLE_DEVICES=0 | |
python vigogne/demo/demo_chat.py \ | |
--base_model_name_or_path huggyllama/llama-7b \ | |
--lora_model_name_or_path bofenghuang/vigogne-chat-7b | |
""" | |
# import datetime | |
import logging | |
import os | |
import re | |
from threading import Event, Thread | |
from typing import List, Optional | |
# from uuid import uuid4 | |
import json | |
import gradio as gr | |
# import requests | |
import torch | |
from peft import PeftModel | |
from transformers import ( | |
AutoModelForCausalLM, | |
AutoTokenizer, | |
GenerationConfig, | |
StoppingCriteriaList, | |
TextIteratorStreamer, | |
) | |
from vigogne.constants import ASSISTANT, USER | |
from vigogne.preprocess import generate_inference_chat_prompt | |
from vigogne.inference.inference_utils import StopWordsCriteria | |
logging.basicConfig( | |
format="%(asctime)s [%(levelname)s] [%(name)s] %(message)s", | |
datefmt="%Y-%m-%dT%H:%M:%SZ", | |
) | |
logger = logging.getLogger(__name__) | |
logger.setLevel(logging.DEBUG) | |
device = "cuda" if torch.cuda.is_available() else "cpu" | |
try: | |
if torch.backends.mps.is_available(): | |
device = "mps" | |
except: | |
pass | |
logger.info(f"Model will be loaded on device `{device}`") | |
# def log_conversation(conversation_id, history, messages, generate_kwargs): | |
# logging_url = os.getenv("LOGGING_URL", None) | |
# if logging_url is None: | |
# return | |
# timestamp = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S") | |
# data = { | |
# "conversation_id": conversation_id, | |
# "timestamp": timestamp, | |
# "history": history, | |
# "messages": messages, | |
# "generate_kwargs": generate_kwargs, | |
# } | |
# try: | |
# requests.post(logging_url, json=data) | |
# except requests.exceptions.RequestException as e: | |
# print(f"Error logging conversation: {e}") | |
def user(message, history): | |
# Append the user's message to the conversation history | |
return "", history + [[message, ""]] | |
# def get_uuid(): | |
# return str(uuid4()) | |
def main( | |
base_model_name_or_path: str = "huggyllama/llama-7b", | |
lora_model_name_or_path: str = "bofenghuang/vigogne-chat-7b", | |
load_8bit: bool = False, | |
server_name: Optional[str] = "0.0.0.0", | |
server_port: Optional[str] = None, | |
share: bool = False, | |
): | |
tokenizer = AutoTokenizer.from_pretrained(base_model_name_or_path, padding_side="right", use_fast=False) | |
if device == "cuda": | |
model = AutoModelForCausalLM.from_pretrained( | |
base_model_name_or_path, | |
load_in_8bit=load_8bit, | |
torch_dtype=torch.float16, | |
device_map="auto", | |
) | |
model = PeftModel.from_pretrained( | |
model, | |
lora_model_name_or_path, | |
torch_dtype=torch.float16, | |
) | |
elif device == "mps": | |
model = AutoModelForCausalLM.from_pretrained( | |
base_model_name_or_path, | |
device_map={"": device}, | |
torch_dtype=torch.float16, | |
) | |
model = PeftModel.from_pretrained( | |
model, | |
lora_model_name_or_path, | |
device_map={"": device}, | |
torch_dtype=torch.float16, | |
) | |
else: | |
model = AutoModelForCausalLM.from_pretrained(base_model_name_or_path, device_map={"": device}, low_cpu_mem_usage=True) | |
model = PeftModel.from_pretrained( | |
model, | |
lora_model_name_or_path, | |
device_map={"": device}, | |
) | |
if not load_8bit and device != "cpu": | |
model.half() # seems to fix bugs for some users. | |
model.eval() | |
# NB | |
stop_words = [f"<|{ASSISTANT}|>", f"<|{USER}|>"] | |
stop_words_criteria = StopWordsCriteria(stop_words=stop_words, tokenizer=tokenizer) | |
pattern_trailing_stop_words = re.compile(rf'(?:{"|".join([re.escape(stop_word) for stop_word in stop_words])})\W*$') | |
def bot(history, max_new_tokens, temperature, top_p, top_k, repetition_penalty, conversation_id=None): | |
# logger.info(f"History: {json.dumps(history, indent=4, ensure_ascii=False)}") | |
# Construct the input message string for the model by concatenating the current system message and conversation history | |
messages = generate_inference_chat_prompt(history, tokenizer) | |
logger.info(messages) | |
assert messages is not None, "User input is too long!" | |
# Tokenize the messages string | |
input_ids = tokenizer(messages, return_tensors="pt")["input_ids"].to(device) | |
streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True) | |
generate_kwargs = dict( | |
input_ids=input_ids, | |
generation_config=GenerationConfig( | |
temperature=temperature, | |
do_sample=temperature > 0.0, | |
top_p=top_p, | |
top_k=top_k, | |
repetition_penalty=repetition_penalty, | |
max_new_tokens=max_new_tokens, | |
), | |
streamer=streamer, | |
stopping_criteria=StoppingCriteriaList([stop_words_criteria]), | |
) | |
# stream_complete = Event() | |
def generate_and_signal_complete(): | |
model.generate(**generate_kwargs) | |
# stream_complete.set() | |
# def log_after_stream_complete(): | |
# stream_complete.wait() | |
# log_conversation( | |
# conversation_id, | |
# history, | |
# messages, | |
# { | |
# "top_k": top_k, | |
# "top_p": top_p, | |
# "temperature": temperature, | |
# "repetition_penalty": repetition_penalty, | |
# }, | |
# ) | |
t1 = Thread(target=generate_and_signal_complete) | |
t1.start() | |
# t2 = Thread(target=log_after_stream_complete) | |
# t2.start() | |
# Initialize an empty string to store the generated text | |
partial_text = "" | |
for new_text in streamer: | |
# NB | |
new_text = pattern_trailing_stop_words.sub("", new_text) | |
partial_text += new_text | |
history[-1][1] = partial_text | |
yield history | |
logger.info(f"Response: {history[-1][1]}") | |
with gr.Blocks( | |
theme=gr.themes.Soft(), | |
css=".disclaimer {font-variant-caps: all-small-caps;}", | |
) as demo: | |
# conversation_id = gr.State(get_uuid) | |
gr.Markdown( | |
"""<h1><center>π¦ Vigogne Chat</center></h1> | |
This demo is of [Vigogne-Chat-7B](https://huggingface.co/bofenghuang/vigogne-chat-7b). It's based on [LLaMA-7B](https://github.com/facebookresearch/llama) finetuned to conduct French π«π· dialogues between a user and an AI assistant. | |
For more information, please visit the [Github repo](https://github.com/bofenghuang/vigogne) of the Vigogne project. | |
""" | |
) | |
chatbot = gr.Chatbot().style(height=500) | |
with gr.Row(): | |
with gr.Column(): | |
msg = gr.Textbox( | |
label="Chat Message Box", | |
placeholder="Chat Message Box", | |
show_label=False, | |
).style(container=False) | |
with gr.Column(): | |
with gr.Row(): | |
submit = gr.Button("Submit") | |
stop = gr.Button("Stop") | |
clear = gr.Button("Clear") | |
with gr.Row(): | |
with gr.Accordion("Advanced Options:", open=False): | |
with gr.Row(): | |
with gr.Column(): | |
with gr.Row(): | |
max_new_tokens = gr.Slider( | |
label="Max New Tokens", | |
value=512, | |
minimum=0, | |
maximum=1024, | |
step=1, | |
interactive=True, | |
info="The Max number of new tokens to generate.", | |
) | |
with gr.Column(): | |
with gr.Row(): | |
temperature = gr.Slider( | |
label="Temperature", | |
value=0.1, | |
minimum=0.0, | |
maximum=1.0, | |
step=0.1, | |
interactive=True, | |
info="Higher values produce more diverse outputs.", | |
) | |
with gr.Column(): | |
with gr.Row(): | |
top_p = gr.Slider( | |
label="Top-p (nucleus sampling)", | |
value=1.0, | |
minimum=0.0, | |
maximum=1, | |
step=0.01, | |
interactive=True, | |
info=( | |
"Sample from the smallest possible set of tokens whose cumulative probability " | |
"exceeds top_p. Set to 1 to disable and sample from all tokens." | |
), | |
) | |
with gr.Column(): | |
with gr.Row(): | |
top_k = gr.Slider( | |
label="Top-k", | |
value=0, | |
minimum=0.0, | |
maximum=200, | |
step=1, | |
interactive=True, | |
info="Sample from a shortlist of top-k tokens β 0 to disable and sample from all tokens.", | |
) | |
with gr.Column(): | |
with gr.Row(): | |
repetition_penalty = gr.Slider( | |
label="Repetition Penalty", | |
value=1.0, | |
minimum=1.0, | |
maximum=2.0, | |
step=0.1, | |
interactive=True, | |
info="Penalize repetition β 1.0 to disable.", | |
) | |
with gr.Row(): | |
gr.Markdown( | |
"Disclaimer: Vigogne is still under development, and there are many limitations that have to be addressed. Please note that it is possible that the model generates harmful or biased content, incorrect information or generally unhelpful answers.", | |
elem_classes=["disclaimer"], | |
) | |
with gr.Row(): | |
gr.Markdown( | |
"Acknowledgements: This demo is built on top of [MPT-7B-Chat](https://huggingface.co/spaces/mosaicml/mpt-7b-chat). Thanks for their contribution!", | |
elem_classes=["disclaimer"], | |
) | |
submit_event = msg.submit( | |
fn=user, | |
inputs=[msg, chatbot], | |
outputs=[msg, chatbot], | |
queue=False, | |
).then( | |
fn=bot, | |
inputs=[ | |
chatbot, | |
max_new_tokens, | |
temperature, | |
top_p, | |
top_k, | |
repetition_penalty, | |
# conversation_id, | |
], | |
outputs=chatbot, | |
queue=True, | |
) | |
submit_click_event = submit.click( | |
fn=user, | |
inputs=[msg, chatbot], | |
outputs=[msg, chatbot], | |
queue=False, | |
).then( | |
fn=bot, | |
inputs=[ | |
chatbot, | |
max_new_tokens, | |
temperature, | |
top_p, | |
top_k, | |
repetition_penalty, | |
# conversation_id, | |
], | |
outputs=chatbot, | |
queue=True, | |
) | |
stop.click( | |
fn=None, | |
inputs=None, | |
outputs=None, | |
cancels=[submit_event, submit_click_event], | |
queue=False, | |
) | |
clear.click(lambda: None, None, chatbot, queue=False) | |
demo.queue(max_size=128, concurrency_count=2) | |
demo.launch(enable_queue=True, share=share, server_name=server_name, server_port=server_port) | |
main() | |