Spaces: Running on Zero

Update app.py
app.py CHANGED
@@ -31,29 +31,27 @@ def get_input_ids(inst, history):
     return input_ids
 
 
-
-def chat(inst, history, temperature, top_p, repetition_penalty):
+def generate(generation_kwargs):
     with torch.no_grad():
-        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
-        input_ids = get_input_ids(inst, history)
-        if len(input_ids) > MAX_INPUT_LIMIT:
-            yield "The input is too long, please clear the history."
-            return
-        generate_config = dict(
-            max_new_tokens=MAX_NEW_TOKENS,
-            temperature=temperature,
-            top_p=top_p,
-            repetition_penalty=repetition_penalty
-        )
-        print(generate_config)
-        generation_kwargs = dict(input_ids=torch.tensor([input_ids]).to(model.device), do_sample=True,
-                                 streamer=streamer, **generate_config)
         Thread(target=model.generate, kwargs=generation_kwargs).start()
 
-        outputs = ""
-        for new_text in streamer:
-            outputs += new_text
-            yield outputs
+
+@spaces.GPU
+def chat(inst, history, temperature, top_p, repetition_penalty):
+    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+    input_ids = get_input_ids(inst, history)
+    if len(input_ids) > MAX_INPUT_LIMIT:
+        yield "The input is too long, please clear the history."
+        return
+    generation_kwargs = dict(input_ids=torch.tensor([input_ids]).to(model.device),
+                             streamer=streamer, do_sample=True, max_new_tokens=MAX_NEW_TOKENS,
+                             temperature=temperature, top_p=top_p, repetition_penalty=repetition_penalty)
+    generate(generation_kwargs)
+
+    outputs = ""
+    for new_text in streamer:
+        outputs += new_text
+        yield outputs
 
 
 additional_inputs = [
@@ -93,7 +91,8 @@ gr.ChatInterface(chat,
                  description='Hello, I am Blossom, an open source conversational large language model.🌠'
                              '<a href="https://github.com/Azure99/BlossomLM">GitHub</a>',
                  theme="soft",
-                 examples=[["Hello"], ["What is MBTI"], ["用Python实现二分查找"]],
+                 examples=[["Hello"], ["What is MBTI"], ["用Python实现二分查找"],
+                           ["为switch写一篇小红书种草文案,带上emoji"]],
                  additional_inputs=additional_inputs,
                  additional_inputs_accordion=gr.Accordion(label="Config", open=True),
                  clear_btn="🗑️Clear",
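Note on the first hunk: the commit factors the model.generate call out of chat into a standalone generate helper so that chat can carry @spaces.GPU, the decorator ZeroGPU Spaces use to request a GPU for the duration of a call (this Space runs on Zero, per the header). It also drops the debug print and flattens generate_config into generation_kwargs. Below is a minimal, self-contained sketch of the same threaded-streaming pattern, with "gpt2" as a placeholder checkpoint rather than the Space's Blossom model and no chat template or history. One caveat is folded in: grad mode is thread-local in PyTorch, so the sketch enters torch.no_grad() inside the worker thread rather than around Thread(...).start() as the diff does (recent transformers versions run generate() under no_grad internally anyway):

# Minimal sketch of the threaded-streaming pattern adopted in this commit.
# Assumption: "gpt2" stands in for the Space's Blossom checkpoint.
from threading import Thread

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")


def generate(generation_kwargs):
    # Run generation on a background thread so the caller can consume tokens
    # from the streamer while generation is still in progress. Grad mode is
    # thread-local, so no_grad() is entered inside the worker rather than
    # around Thread(...).start().
    def run():
        with torch.no_grad():
            model.generate(**generation_kwargs)

    Thread(target=run).start()


def chat(inst):
    # skip_prompt=True drops the echoed prompt; the streamer then yields
    # decoded text fragments as the model produces them.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    input_ids = tokenizer(inst, return_tensors="pt").input_ids.to(model.device)
    generate(dict(input_ids=input_ids, streamer=streamer,
                  do_sample=True, max_new_tokens=64))

    outputs = ""
    for new_text in streamer:  # blocks until the next fragment arrives
        outputs += new_text
        yield outputs  # yield the accumulated text, as the Space's chat() does


for partial in chat("Hello, world"):
    pass
print(partial)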
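Note on the second hunk: it adds a fourth example prompt. The Chinese examples read, in English, "implement binary search in Python" and "write a Xiaohongshu (RED) promo post for the Switch, with emoji". For context, here is a minimal sketch of how a generator-style chat function and these examples wire into gr.ChatInterface under Gradio 4.x (which the clear_btn kwarg implies); the echo body and the slider ranges/defaults are placeholders, since the Space's real additional_inputs block sits outside this diff:

# Minimal sketch of the gr.ChatInterface wiring the new examples plug into.
# Assumptions: the echo body and slider ranges/defaults are placeholders.
import gradio as gr


def chat(inst, history, temperature, top_p, repetition_penalty):
    yield f"(echo) {inst}"  # stand-in for the streamed model output


additional_inputs = [
    gr.Slider(0.0, 1.0, value=0.5, label="temperature"),
    gr.Slider(0.0, 1.0, value=0.85, label="top_p"),
    gr.Slider(1.0, 2.0, value=1.05, label="repetition_penalty"),
]

gr.ChatInterface(chat,
                 theme="soft",
                 examples=[["Hello"], ["What is MBTI"], ["用Python实现二分查找"],
                           ["为switch写一篇小红书种草文案,带上emoji"]],
                 additional_inputs=additional_inputs,
                 additional_inputs_accordion=gr.Accordion(label="Config", open=True),
                 clear_btn="🗑️Clear").launch()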