import os
from string import Template
from threading import Thread

import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, BatchEncoding, TextIteratorStreamer


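# Load the tokenizer and fp16 weights; HUGGINGFACE_TOKEN is only required if the
# checkpoint repo is gated (otherwise the cached `huggingface-cli login` token is used).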
auth_token = os.environ.get("HUGGINGFACE_TOKEN")
tokenizer = AutoTokenizer.from_pretrained(
    "stabilityai/stable-vicuna-13b-fp16",
    use_auth_token=auth_token if auth_token else True,
)
model = AutoModelForCausalLM.from_pretrained(
    "stabilityai/stable-vicuna-13b-fp16",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    device_map="auto",
    use_auth_token=auth_token if auth_token else True,
)
model.eval()


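# Context budget: the model's full positional window, with a fixed cap on how many
# new tokens each reply may add.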
max_context_length = model.config.max_position_embeddings
max_new_tokens = 768


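# Vicuna-style turn format: "### Human: ..." followed by "### Assistant: ...".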
prompt_template = Template("""\
### Human: $human
### Assistant: $bot\
""")


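# The system prompt is tokenized once (keeping its <BOS>) and prepended to every
# request; its length is tracked so it can be subtracted from the context budget.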
system_prompt = "### Assistant: I am StableVicuna, a large language model created by Stability AI. I am here to chat!"
system_prompt_tokens = tokenizer([f"{system_prompt}\n\n"], return_tensors="pt")
max_sys_tokens = system_prompt_tokens['input_ids'].size(-1)


def bot(history):
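    """Stream the assistant's reply for the most recent user turn in `history`."""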
    history = history or []

    # Render each (human, bot) turn of the history into the prompt format
    prompt_history = []
    for human, bot in history:
        if bot is not None:
            bot = bot.replace("<br>", "\n")
            bot = bot.rstrip()
        prompt_history.append(
            prompt_template.substitute(
                human=human, bot=bot if bot is not None else "")
        )

    msg_tokens = tokenizer(
        "\n\n".join(prompt_history).strip(),
        return_tensors="pt",
        add_special_tokens=False  # Use <BOS> from the system prompt
    )

    # Keep only the most recent message tokens that still fit in the context window,
    # after reserving room for the system prompt and the tokens about to be generated.
    # `max_tokens` is negative, so slicing with it keeps the last `-max_tokens` tokens.
    max_tokens = -max_context_length + max_new_tokens + max_sys_tokens
    inputs = BatchEncoding({
        k: torch.concat([system_prompt_tokens[k], msg_tokens[k][:, max_tokens:]], dim=-1)
        for k in msg_tokens
    }).to('cuda')
    # Remove `token_type_ids` b/c it's not yet supported for LLaMA `transformers` models
    if inputs.get("token_type_ids", None) is not None:
        inputs.pop("token_type_ids")

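    # Stream decoded tokens as they are generated so the chat window updates incrementally.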
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=1.0,
        temperature=1.0,
    )
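    # Run generation in a background thread; this generator consumes the streamer.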
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    partial_text = ""
    for new_text in streamer:
        # Normalize any <br> tags back to newlines
        new_text = new_text.replace("<br>", "\n")
        # Stop streaming once the model starts a new "### Human:"/"### Assistant:" turn
        if "###" in new_text:
            new_text = new_text.split("###")[0]
            partial_text += new_text.strip()
            history[-1][1] = partial_text
            break
        else:
            # Drop chunks that are only a bare newline so blank lines don't pile up
            if new_text == "\n":
                new_text = new_text.strip()
            partial_text += new_text
            history[-1][1] = partial_text
        yield history
    return partial_text


def user(user_message, history):
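    """Append the user's message to the chat history and clear the input textbox."""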
    return "", history + [[user_message, None]]


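# Gradio UI: a chat window plus a message box and Send / Stop / Clear controls.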
with gr.Blocks() as demo:
    gr.Markdown("StableVicuna by Stability AI")
    gr.HTML("<a href='https://huggingface.co/stabilityai/stable-vicuna-13b-delta'><code>stabilityai/stable-vicuna-13b-delta</a>")
    gr.HTML('''<center><a href="https://huggingface.co/spaces/stabilityai/stable-vicuna?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>Duplicate the Space to skip the queue and run in a private space</center>''')

    chatbot = gr.Chatbot([], elem_id="chatbot").style(height=500)
    state = gr.State([])
    with gr.Row():
        with gr.Column():
            msg = gr.Textbox(
                label="Send a message",
                placeholder="Send a message",
                show_label=False
            ).style(container=False)
        with gr.Column():
            with gr.Row():
                submit = gr.Button("Send")
                stop = gr.Button("Stop")
                clear = gr.Button("Clear History")

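    # Submitting the textbox or clicking Send appends the user turn, then streams the bot reply.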
    submit_event = msg.submit(user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False).then(
        fn=bot, inputs=[chatbot], outputs=[chatbot], queue=True)
    submit_click_event = submit.click(user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False).then(
        fn=bot, inputs=[chatbot], outputs=[chatbot], queue=True)

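    # Stop cancels any in-flight generation; Clear empties the chat window.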
    stop.click(fn=None, inputs=None, outputs=None, cancels=[submit_event, submit_click_event], queue=False)
    clear.click(lambda: None, None, [chatbot], queue=True)

demo.queue(max_size=32, concurrency_count=2)
demo.launch()