import os
from threading import Thread
from typing import Iterator
import gradio as gr
import spaces
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
TextIteratorStreamer,
LlamaTokenizer,
)
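# Generation limits: hard cap on newly generated tokens and on how much of the
# conversation (in tokens) is kept as the model prompt.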
MAX_MAX_NEW_TOKENS = 1024
DEFAULT_MAX_NEW_TOKENS = 50
MAX_INPUT_TOKEN_LENGTH = 512
DESCRIPTION = """\
# Phi-3-mini-4k-instruct
This Space demonstrates [microsoft/Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) by Microsoft. Please check the original model card for details.
For additional detail on the model, including a link to the arXiv paper, see the [Hugging Face paper page for Phi-3](https://huggingface.co/papers/2404.14219).
"""
model = AutoModelForCausalLM.from_pretrained(
"microsoft/Phi-3-mini-4k-instruct",
trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
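# If the tokenizer defines no pad token, fall back to the EOS token so padding
# during generation is well defined.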
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id
model.config.pad_token_id = tokenizer.eos_token_id
def generate(
message: str,
chat_history: list[tuple[str, str]],
max_new_tokens: int = 1024,
temperature: float = 0.1,
top_p: float = 0.4,
top_k: int = 10,
repetition_penalty: float = 1.4,
) -> Iterator[str]:
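    """Stream a completion for `message`, prepending prior chat turns as plain text."""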
historical_text = ""
    # Prepend the entire chat history to the message, with newlines between turns.
for user, assistant in chat_history:
historical_text += f"\n{user}\n{assistant}"
if len(historical_text) > 0:
message = historical_text + f"\n{message}"
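    # Note: history is concatenated as plain text. Since Phi-3-mini-4k-instruct is an
    # instruct-tuned model, an alternative sketch (not what this Space does) would use
    # the tokenizer's chat template, e.g.:
    #   chat = []
    #   for user, assistant in chat_history:
    #       chat += [{"role": "user", "content": user},
    #                {"role": "assistant", "content": assistant}]
    #   chat.append({"role": "user", "content": message})
    #   input_ids = tokenizer.apply_chat_template(
    #       chat, add_generation_prompt=True, return_tensors="pt"
    #   )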
input_ids = tokenizer([message], return_tensors="pt").input_ids
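    # Keep only the most recent MAX_INPUT_TOKEN_LENGTH tokens so long conversations
    # still fit within the prompt budget.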
if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
input_ids = input_ids.to(model.device)
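    # skip_prompt/skip_special_tokens keep the echoed prompt and special tokens out
    # of the streamed text; timeout guards against a stalled generation thread.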
streamer = TextIteratorStreamer(tokenizer, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        pad_token_id=tokenizer.eos_token_id,
        repetition_penalty=repetition_penalty,
        no_repeat_ngram_size=5,
        early_stopping=False,
    )
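    # Run generation in a background thread; iterating the streamer below yields
    # decoded chunks as they are produced, so partial output reaches the UI.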
t = Thread(target=model.generate, kwargs=generate_kwargs)
t.start()
outputs = []
for text in streamer:
outputs.append(text)
yield "".join(outputs)
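# The ChatInterface calls `generate` with (message, history) followed by the values
# of the sliders below, in the same order as the extra parameters of `generate`.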
chat_interface = gr.ChatInterface(
fn=generate,
additional_inputs=[
gr.Slider(
label="Max new tokens",
minimum=1,
maximum=MAX_MAX_NEW_TOKENS,
step=1,
value=DEFAULT_MAX_NEW_TOKENS,
),
gr.Slider(
label="Temperature",
minimum=0.1,
maximum=4.0,
step=0.1,
value=0.1,
),
gr.Slider(
label="Top-p (nucleus sampling)",
minimum=0.05,
maximum=1.0,
step=0.05,
value=0.5,
),
gr.Slider(
label="Top-k",
minimum=1,
maximum=1000,
step=1,
value=3,
),
gr.Slider(
label="Repetition penalty",
minimum=1.0,
maximum=2.0,
step=0.05,
value=1.4,
),
],
stop_btn=None,
cache_examples=False,
examples=[
["Explain quantum physics in 5 words or less:"],
["Question: What do you call a bear with no teeth?\nAnswer:"],
],
)
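# Wrap the chat interface in a Blocks layout so the description (and the Space's
# style.css) render above it.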
with gr.Blocks(css="style.css") as demo:
gr.Markdown(DESCRIPTION)
chat_interface.render()
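# Queue incoming requests (up to 20 waiting) so concurrent users are handled in order.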
if __name__ == "__main__":
demo.queue(max_size=20).launch()