DuckyBlender committed
Commit a755228
1 Parent(s): 1bf6bb5

fixed cpu inference?

Files changed (1)
  1. app.py +15 -8
app.py CHANGED
@@ -11,27 +11,34 @@ import torch
 if torch.cuda.is_available():
     device = torch.device("cuda")
     print(f"Using GPU: {torch.cuda.get_device_name(device)}")
+
+    # Install the Flash Attention library
+    print("Installing Flash Attention library...")
     import subprocess
     subprocess.run(
         "pip install flash_attn --no-build-isolation --break-system-packages",
         env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
         shell=True,
     )
+    print("Flash Attention library installed")
+
+    # Configure 4-bit quantization for model loading
+    bnb_config = BitsAndBytesConfig(
+        load_in_4bit=True,
+        bnb_4bit_use_double_quant=True,
+        bnb_4bit_quant_type="nf4",
+        bnb_4bit_compute_dtype=torch.bfloat16,
+        attn_implementation="flash_attention_2",
+    )
+
 else:
     device = torch.device("cpu")
+    bnb_config = None
     print("Using CPU")
 
 # Uncomment and set your Hugging Face token if needed
 token = os.environ["HF_TOKEN"]
 
-# Configure 4-bit quantization for model loading
-bnb_config = BitsAndBytesConfig(
-    load_in_4bit=True,
-    bnb_4bit_use_double_quant=True,
-    bnb_4bit_quant_type="nf4",
-    bnb_4bit_compute_dtype=torch.bfloat16,
-    attn_implementation="flash_attention_2",
-)
 
 # Load the Phi-3 model and tokenizer
 print("Loading model and tokenizer...")