import os
import subprocess
from threading import Thread

import torch
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Install flash-attn at startup. FLASH_ATTENTION_SKIP_CUDA_BUILD skips compiling
# the CUDA kernels; merging os.environ keeps PATH and the rest of the environment
# visible to pip.
subprocess.run('pip install flash-attn --no-build-isolation', env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': 'TRUE'}, shell=True)
BANNER_HTML = """
<p align="center">
<a href="https://github.com/ymcui/Chinese-LLaMA-Alpaca-3">
<img src="https://ymcui.com/images/chinese-llama-alpaca-3-banner.png" width="600"/>
</a>
</p>
<h3>
<center>Check our <a href='https://github.com/ymcui/Chinese-LLaMA-Alpaca-3' target='_blank'>Chinese-LLaMA-Alpaca-3 GitHub Project</a> for more information.
</center>
</h3>
<p>
<center><em>This demo is mainly for academic purposes; illegal usage is prohibited. Default model: <a href="https://huggingface.co/hfl/llama-3-chinese-8b-instruct-v3">hfl/llama-3-chinese-8b-instruct-v3</a></em></center>
</p>
"""
DEFAULT_SYSTEM_PROMPT = "You are a helpful assistant. 你是一个乐于助人的助手。"
# Instruct model checkpoints selectable from the UI.
MODEL_NAMES = {
    "v1": "hfl/llama-3-chinese-8b-instruct",
    "v2": "hfl/llama-3-chinese-8b-instruct-v2",
    "v3": "hfl/llama-3-chinese-8b-instruct-v3",
}

# Load the instruct model matching the selected version into module globals.
# An unknown version now fails fast with a KeyError instead of leaving
# model_name undefined.
def load_model(version):
    global tokenizer, model
    model_name = MODEL_NAMES[version]
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        attn_implementation="flash_attention_2",
    )
    return f"Model {model_name} loaded."
@spaces.GPU(duration=50)
def stream_chat(message: str, history: list, system_prompt: str, model_version: str, temperature: float, max_new_tokens: int):
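    # model_version comes from a read-only radio in the UI; the model itself is
    # loaded once at startup by load_model(), so it is not reloaded per request.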
conversation = [{"role": "system", "content": system_prompt or DEFAULT_SYSTEM_PROMPT}]
for prompt, answer in history:
conversation.extend([{"role": "user", "content": prompt}, {"role": "assistant", "content": answer}])
conversation.append({"role": "user", "content": message})
input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt").to(model.device)
streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
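    # Llama-3 ends each turn with <|eot_id|> rather than the base EOS token,
    # so treat both as valid stop ids.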
terminators = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>")]
generate_kwargs = {
"input_ids": input_ids,
"streamer": streamer,
"eos_token_id": terminators,
"pad_token_id": tokenizer.eos_token_id,
"max_new_tokens": max_new_tokens,
"temperature": temperature,
"top_k": 40,
"top_p": 0.9,
"num_beams": 1,
"repetition_penalty": 1.1,
"do_sample": temperature != 0,
}
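    # Generate on a background thread so tokens can be yielded as they stream in.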
generation_thread = Thread(target=model.generate, kwargs=generate_kwargs)
generation_thread.start()
output = ""
for new_token in streamer:
output += new_token
yield output
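
# UI: project banner plus a ChatInterface wired to the streaming generator above.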
chatbot = gr.Chatbot(height=500)
with gr.Blocks() as demo:
gr.HTML(BANNER_HTML)
gr.ChatInterface(
fn=stream_chat,
chatbot=chatbot,
fill_height=True,
additional_inputs_accordion=gr.Accordion(label="Parameters / 参数设置", open=False, render=False),
additional_inputs=[
gr.Text(value=DEFAULT_SYSTEM_PROMPT, label="System Prompt / 系统提示词", render=False),
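            # Read-only: switching model versions at runtime is not supported here.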
gr.Radio(choices=["v1", "v2", "v3"], label="Model Version / 模型版本", value="v3", interactive=False, render=False),
gr.Slider(minimum=0, maximum=1.5, step=0.1, value=0.6, label="Temperature / 温度系数", render=False),
gr.Slider(minimum=128, maximum=2048, step=1, value=512, label="Max new tokens / 最大生成长度", render=False),
],
cache_examples=False,
submit_btn="Send / 发送",
stop_btn="Stop / 停止",
retry_btn="🔄 Retry / 重试",
undo_btn="↩️ Undo / 撤销",
clear_btn="🗑️ Clear / 清空",
)
if __name__ == "__main__":
load_model("v3") # Load the default model
demo.launch()