gaochangkuan committed
Commit 5bebd03
Parent: f27de23

Update app.py

Files changed (1): app.py (+6, -4)
app.py CHANGED
@@ -4,7 +4,9 @@ import gradio as gr
 from transformers import GemmaTokenizer, AutoModelForCausalLM
 from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer, TextIteratorStreamer
 from threading import Thread
+import subprocess
 
+subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
 # Set an environment variable
 token = os.getenv('HUGGINGFACE_TOKEN')
 
@@ -15,13 +17,13 @@ model= AutoModelForCausalLM.from_pretrained(
     torch_dtype= torch.bfloat16,
     low_cpu_mem_usage= True,
     token=token,
-    #attn_implementation="flash_attention_2",
-    #device_map= "auto"
+    attn_implementation="flash_attention_2",
+    device_map= "auto"
 )
 
 
 model = torch.compile(model)
-model.to("cuda")
+#model.to("cuda")
 model = model.eval()
 
 
@@ -93,7 +95,7 @@ def chat_zhuji(
     conversation.extend([{"role": "system","content": "",},{"role": "user", "content": user}, {"role": "<|assistant|>", "content": assistant}])
     conversation.append({"role": "user", "content": message})
 
-    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt").to("cuda")
+    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt").to(model.device)
 
     streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
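
Net effect of the commit, as a minimal sketch: flash-attn is installed at Space startup with the CUDA build skipped, the model is loaded with FlashAttention-2 and device_map="auto" (so the manual .to("cuda") is dropped), and input tensors follow model.device instead of a hard-coded "cuda". The checkpoint name and the rest of app.py are not shown in this diff, so model_id below is a hypothetical placeholder, and the env handling is adjusted slightly from the commit (see the comment).

import os
import subprocess
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Install flash-attn at startup, skipping the CUDA build so pip uses a prebuilt wheel.
# Note: the commit passes env={...} directly, which replaces the whole environment;
# merging with os.environ preserves PATH and friends.
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
    shell=True,
)

token = os.getenv('HUGGINGFACE_TOKEN')
model_id = "your-org/your-model"  # hypothetical placeholder; the real checkpoint is not in this diff

tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    token=token,
    attn_implementation="flash_attention_2",  # enabled by this commit
    device_map="auto",                        # accelerate places weights; no manual .to("cuda")
)
model = torch.compile(model)
model = model.eval()

# Inputs follow the model's placement rather than hard-coding "cuda":
input_ids = tokenizer.apply_chat_template(
    [{"role": "user", "content": "hello"}], return_tensors="pt"
).to(model.device)

Using model.device keeps the tokenization path correct whether accelerate sharded the model across GPUs or left it on a single device.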