Update app.py
app.py
CHANGED
@@ -45,7 +45,7 @@ h1 {
 
 # Load the tokenizer and model
 tokenizer = AutoTokenizer.from_pretrained("llm-jp/llm-jp-13b-instruct-full-ac_001_16x-dolly-ichikara_004_001_single-oasst-oasst2-v2.0")
-model = AutoModelForCausalLM.from_pretrained("llm-jp/llm-jp-13b-instruct-full-ac_001_16x-dolly-ichikara_004_001_single-oasst-oasst2-v2.0", device_map="
+model = AutoModelForCausalLM.from_pretrained("llm-jp/llm-jp-13b-instruct-full-ac_001_16x-dolly-ichikara_004_001_single-oasst-oasst2-v2.0", device_map="cuda", torch_dtype=torch.bfloat16)
 model=model.eval()
 
 @spaces.GPU()
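For context: the changed lines depend on imports that sit above this hunk unchanged, so the diff does not show them. A minimal sketch of what that header presumably contains, inferred only from the names the diff uses, not copied from app.py:

import torch                   # torch.bfloat16 passed to from_pretrained
import spaces                  # @spaces.GPU() (Hugging Face ZeroGPU helper)
import gradio as gr            # gr.Blocks / gr.ChatInterface / gr.Slider
from threading import Thread   # runs model.generate in the background
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer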
@@ -70,16 +70,30 @@ def chat_llm_jp_v2(message: str,
         conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
     conversation.append({"role": "user", "content": message})
 
-
-
-
+    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+
+    generate_kwargs = dict(
+        input_ids=input_ids,
+        streamer=streamer,
         max_new_tokens=max_new_tokens,
         do_sample=True,
-        top_p=0.95,
         temperature=temperature,
-
-
-
+        top_p=0.95,
+        repetition_penalty=1.1,
+        eos_token_id=terminators,
+    )
+    # This enforces greedy generation (do_sample=False) when the temperature is set to 0, avoiding a crash.
+    if temperature == 0:
+        generate_kwargs['do_sample'] = False
+
+    t = Thread(target=model.generate, kwargs=generate_kwargs)
+    t.start()
+
+    outputs = []
+    for text in streamer:
+        outputs.append(text)
+        print(outputs)
+    yield "".join(outputs)
 
 
 # Gradio block
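The removed lines are not legible in this view, but the added lines are the standard transformers streaming pattern: model.generate runs on a worker thread and feeds a TextIteratorStreamer, while the chat function iterates over the streamer and yields the accumulated text so the UI can update as tokens arrive. (The eos_token_id=terminators argument presumes a terminators list defined elsewhere in app.py.) A self-contained sketch of the same pattern, with gpt2 as a stand-in for the 13B checkpoint:

from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# "gpt2" is a stand-in so the sketch runs anywhere; the app uses the 13B checkpoint.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

input_ids = tokenizer("The quick brown fox", return_tensors="pt").input_ids

# skip_prompt=True suppresses the echoed prompt; the timeout keeps the consumer
# from blocking forever if generation dies.
streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

# generate() blocks until finished, so it runs on a worker thread; the streamer
# is an iterator that yields decoded text chunks as tokens are produced.
Thread(target=model.generate,
       kwargs=dict(input_ids=input_ids, streamer=streamer, max_new_tokens=32)).start()

outputs = []
for chunk in streamer:  # same consume-and-accumulate loop as in the diff
    outputs.append(chunk)
print("".join(outputs))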
@@ -97,7 +111,7 @@ with gr.Blocks(fill_height=True, css=css) as demo:
                 additional_inputs=[
                     gr.Slider(minimum=0.1,
                               maximum=1,
-                              step=0.
+                              step=0.0,
                               value=0.7,
                               label="Temperature",
                               render=False),
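On the Gradio side nothing else needs to change for streaming: because chat_llm_jp_v2 is now a generator, gr.ChatInterface shows each yielded string as the in-progress reply, and the Temperature slider registered via additional_inputs arrives as an extra argument. A rough, runnable sketch of that wiring, with a dummy generator in place of the real model (everything except the slider settings shown in the diff is an assumption):

import time

import gradio as gr

# Dummy stand-in for chat_llm_jp_v2: any generator function streams in ChatInterface.
def chat_fn(message, history, temperature):
    reply = ""
    for word in f"(temperature={temperature}) echo: {message}".split():
        reply += word + " "
        time.sleep(0.05)  # simulate tokens arriving
        yield reply       # each yield replaces the in-progress message

with gr.Blocks(fill_height=True) as demo:
    gr.ChatInterface(
        fn=chat_fn,
        additional_inputs=[
            # render=False keeps the slider out of the Blocks layout;
            # ChatInterface re-renders it in its own accordion.
            gr.Slider(minimum=0.1, maximum=1, value=0.7,
                      label="Temperature", render=False),
        ],
    )

demo.launch()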