File size: 3,951 Bytes
fc1301c
 
 
 
 
 
 
9aa57c1
 
 
 
 
 
fc1301c
 
ce7a117
fc1301c
 
 
436f61e
fc1301c
1e9bcab
fc1301c
1e9bcab
fc1301c
 
9eaa1af
d85c6c2
 
9eaa1af
d85c6c2
fc1301c
9eaa1af
 
 
 
fc1301c
 
 
 
 
9eaa1af
05d3b25
421b69f
fc1301c
 
 
811b009
 
 
 
 
 
 
fc1301c
 
 
 
 
 
e8c32f6
fc1301c
 
 
 
 
 
 
 
 
8705a13
fc1301c
8705a13
811b009
fc1301c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
735d6aa
fc1301c
 
 
 
 
 
ce7a117
fc1301c
 
 
 
 
 
ce7a117
fc1301c
 
 
 
 
 
 
 
 
 
4e87751
fc1301c
9eaa1af
 
fc1301c
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import os
from threading import Thread
from typing import Iterator

import gradio as gr
import spaces
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
    LlamaTokenizer,
)

# Upper bound of the "Max new tokens" slider in the chat UI.
MAX_MAX_NEW_TOKENS = 1_024
# Initial value of the "Max new tokens" slider.
DEFAULT_MAX_NEW_TOKENS = 50
# Prompts that tokenize to more than this many tokens are cut down to their
# last MAX_INPUT_TOKEN_LENGTH tokens inside generate().
MAX_INPUT_TOKEN_LENGTH = 512

# Markdown header rendered above the chat widget.
DESCRIPTION = """\
# Phi-3-mini-4k-instruct

This Space demonstrates [microsoft/Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) by Microsoft. Please, check the original model card for details.

For additional detail on the model, including a link to the arXiv paper, refer to the [Hugging Face Paper page for Phi 3](https://huggingface.co/papers/2404.14219) .
"""

# Load the model and its tokenizer once, at import time, from the Hugging Face Hub.
# NOTE(review): trust_remote_code=True allows Python code shipped in the model
# repository to run locally; confirm it is still required by the installed
# transformers version.
model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Phi-3-mini-4k-instruct",
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")

# If the tokenizer defines no pad token, reuse EOS as the pad token on both the
# tokenizer and the model config. Compare against the None singleton with `is`.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id
    model.config.pad_token_id = tokenizer.eos_token_id

def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    max_new_tokens: int = 1024,
    temperature: float = 0.1,
    top_p: float = 0.4,
    top_k: int = 10,
    repetition_penalty: float = 1.4,
) -> Iterator[str]:
    """Stream a sampled completion for ``message``, conditioned on the chat so far.

    The prompt is plain text. Each earlier turn contributes a newline, the user
    text, a newline, and the assistant text. The new message is then appended
    after one more newline; with no history, the message is used as-is. If the
    tokenized prompt is longer than ``MAX_INPUT_TOKEN_LENGTH`` tokens, only the
    last ``MAX_INPUT_TOKEN_LENGTH`` tokens are kept and a Gradio warning is shown.

    NOTE(review): no chat template is applied. An instruct model may respond
    better to ``tokenizer.apply_chat_template``; confirm the raw-text format
    is intentional.

    Args:
        message: The newest user message.
        chat_history: Earlier turns as ``(user, assistant)`` pairs, as passed
            in by ``gr.ChatInterface``. A ``None`` reply is rendered as empty.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        top_k: Number of highest-probability tokens considered when sampling.
        repetition_penalty: Penalty for repeating tokens already in the sequence.

    Yields:
        The complete response text generated so far, growing with every chunk
        produced by the streamer.
    """
    # Build the history prefix with a single join instead of repeated string
    # concatenation; a missing (None) assistant reply becomes an empty line
    # rather than the literal text "None".
    history_text = "".join(
        f"\n{user}\n{'' if assistant is None else assistant}"
        for user, assistant in chat_history
    )
    if history_text:
        message = f"{history_text}\n{message}"

    encoded = tokenizer([message], return_tensors="pt")
    input_ids = encoded.input_ids
    attention_mask = encoded.attention_mask
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep the most recent tokens so the newest message survives trimming,
        # and slice the mask identically so both tensors stay aligned.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        attention_mask = attention_mask[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)
    attention_mask = attention_mask.to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        # Passed explicitly: pad_token_id is set to the EOS id below, so
        # generate() cannot infer the mask from input_ids on its own.
        attention_mask=attention_mask,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        pad_token_id=tokenizer.eos_token_id,
        repetition_penalty=repetition_penalty,
        # Forbid any 5-token sequence from appearing twice.
        no_repeat_ngram_size=5,
        early_stopping=False,
    )
    # Run generation in a background thread; the streamer hands decoded text
    # back to this generator as it is produced.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
    # The streamer only finishes once generation has ended; reap the thread.
    worker.join()


# Chat UI wired to generate(). Gradio passes the values of additional_inputs
# to generate() after (message, history), in the order listed below.
chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Slider(label=caption, minimum=low, maximum=high, step=increment, value=initial)
        for caption, low, high, increment, initial in (
            # (label, minimum, maximum, step, initial value)
            ("Max new tokens", 1, MAX_MAX_NEW_TOKENS, 1, DEFAULT_MAX_NEW_TOKENS),
            ("Temperature", 0.1, 4.0, 0.1, 0.1),
            ("Top-p (nucleus sampling)", 0.05, 1.0, 0.05, 0.5),
            ("Top-k", 1, 1000, 1, 3),
            ("Repetition penalty", 1.0, 2.0, 0.05, 1.4),
        )
    ],
    # No stop button is shown; example outputs are not precomputed.
    stop_btn=None,
    cache_examples=False,
    examples=[
        ["Explain quantum physics in 5 words or less:"],
        ["Question: What do you call a bear with no teeth?\nAnswer:"],
    ],
)

# Page layout: the Markdown description above the chat widget defined earlier.
demo = gr.Blocks(css="style.css")
with demo:
    gr.Markdown(DESCRIPTION)
    chat_interface.render()

if __name__ == "__main__":
    # Cap the request queue at 20 pending events, then start the server.
    demo.queue(max_size=20)
    demo.launch()