Spaces: Running on Zero
Update app.py
app.py CHANGED
@@ -1,17 +1,14 @@
-import time
+from threading import Thread
 
 import gradio as gr
 import spaces
-from threading import Thread
-from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
 import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 
 MAX_INPUT_LIMIT = 3584
-
 MODEL_NAME = "Azure99/blossom-v5.1-9b"
 
 model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float16, device_map="auto")
-
 tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
 
 GENERATE_CONFIG = dict(
@@ -22,7 +19,6 @@ GENERATE_CONFIG = dict(
     repetition_penalty=1.05
 )
 
-
 def get_input_ids(inst, history):
     prefix = ("A chat between a human and an artificial intelligence bot. "
               "The bot gives helpful, detailed, and polite answers to the human's questions.")
@@ -46,27 +42,17 @@ def chat(inst, history):
     with torch.no_grad():
         streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
         input_ids = get_input_ids(inst, history)
-        print(len(input_ids))
         if len(input_ids) > MAX_INPUT_LIMIT:
             yield "The input is too long, please clear the history."
             return
         generation_kwargs = dict(input_ids=torch.tensor([input_ids]).to(model.device), do_sample=True,
                                  streamer=streamer, **GENERATE_CONFIG)
         Thread(target=model.generate, kwargs=generation_kwargs).start()
-
-        # stop watch
-        start = time.time()
+
         outputs = ""
         for new_text in streamer:
             outputs += new_text
             yield outputs
-        total_time = time.time() - start
-        output_token_len = len(tokenizer.encode(outputs, add_special_tokens=False))
-        speed = output_token_len / total_time
-        print("----------")
-        print(history)
-        print([inst, outputs])
-        print(f"Speed: {speed:.2f} tokens/s")
 
 
 gr.ChatInterface(chat,
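For reference, the streaming setup this commit keeps is the standard Thread plus TextIteratorStreamer pattern from transformers: model.generate runs in a worker thread while the caller iterates over the streamer and yields partial text. Below is a minimal, self-contained sketch of that pattern; the prompt "Hello!" and max_new_tokens value are illustrative placeholders, not part of this commit.

# Minimal sketch of the Thread + TextIteratorStreamer pattern used in chat().
# The prompt and max_new_tokens below are illustrative assumptions.
from threading import Thread

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

MODEL_NAME = "Azure99/blossom-v5.1-9b"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float16, device_map="auto")

input_ids = tokenizer("Hello!", return_tensors="pt").input_ids.to(model.device)
# skip_prompt drops the echoed input; skip_special_tokens cleans the decoded text.
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# model.generate blocks until generation finishes, so it runs in a worker
# thread while the main thread consumes partial text from the streamer.
Thread(target=model.generate,
       kwargs=dict(input_ids=input_ids, max_new_tokens=128, streamer=streamer)).start()

for new_text in streamer:
    print(new_text, end="", flush=True)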