gaochangkuan
committed on
Commit
•
5bebd03
1
Parent(s):
f27de23
Update app.py
Browse files
app.py
CHANGED
@@ -4,7 +4,9 @@ import gradio as gr
|
|
4 |
from transformers import GemmaTokenizer, AutoModelForCausalLM
|
5 |
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer, TextIteratorStreamer
|
6 |
from threading import Thread
|
|
|
7 |
|
|
|
8 |
# Set an environment variable
|
9 |
token = os.getenv('HUGGINGFACE_TOKEN')
|
10 |
|
@@ -15,13 +17,13 @@ model= AutoModelForCausalLM.from_pretrained(
|
|
15 |
torch_dtype= torch.bfloat16,
|
16 |
low_cpu_mem_usage= True,
|
17 |
token=token,
|
18 |
-
|
19 |
-
|
20 |
)
|
21 |
|
22 |
|
23 |
model = torch.compile(model)
|
24 |
-
model.to("cuda")
|
25 |
model = model.eval()
|
26 |
|
27 |
|
@@ -93,7 +95,7 @@ def chat_zhuji(
|
|
93 |
conversation.extend([{"role": "system","content": "",},{"role": "user", "content": user}, {"role": "<|assistant|>", "content": assistant}])
|
94 |
conversation.append({"role": "user", "content": message})
|
95 |
|
96 |
-
input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt").to(
|
97 |
|
98 |
streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
|
99 |
|
|
|
4 |
from transformers import GemmaTokenizer, AutoModelForCausalLM
|
5 |
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer, TextIteratorStreamer
|
6 |
from threading import Thread
|
7 |
+
import subprocess
|
8 |
|
9 |
+
subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
|
10 |
# Set an environment variable
|
11 |
token = os.getenv('HUGGINGFACE_TOKEN')
|
12 |
|
|
|
17 |
torch_dtype= torch.bfloat16,
|
18 |
low_cpu_mem_usage= True,
|
19 |
token=token,
|
20 |
+
attn_implementation="flash_attention_2",
|
21 |
+
device_map= "auto"
|
22 |
)
|
23 |
|
24 |
|
25 |
model = torch.compile(model)
|
26 |
+
#model.to("cuda")
|
27 |
model = model.eval()
|
28 |
|
29 |
|
|
|
95 |
conversation.extend([{"role": "system","content": "",},{"role": "user", "content": user}, {"role": "<|assistant|>", "content": assistant}])
|
96 |
conversation.append({"role": "user", "content": message})
|
97 |
|
98 |
+
input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt").to(model.device)
|
99 |
|
100 |
streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
|
101 |
|