RinInori commited on
Commit
dba79d0
1 Parent(s): e68b314

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -5
app.py CHANGED
@@ -7,15 +7,11 @@ from transformers import Trainer
7
 
8
  BASE_MODEL = "TheBloke/vicuna-7B-1.1-HF"
9
 
10
- # Create a custom device map
11
- # This will vary based on the architecture of model and the memory capacity of GPU and CPU
12
- device_map = {0: [0, 1, 2], 1: [3, 4, 5]}
13
-
14
  model = LlamaForCausalLM.from_pretrained(
15
  BASE_MODEL,
16
  torch_dtype=torch.float16,
17
  load_in_8bit=True,
18
- device_map = {0: [0, 1, 2], 1: [3, 4, 5]},
19
  offload_folder="./cache",
20
  )
21
 
 
7
 
8
  BASE_MODEL = "TheBloke/vicuna-7B-1.1-HF"
9
 
 
 
 
 
10
  model = LlamaForCausalLM.from_pretrained(
11
  BASE_MODEL,
12
  torch_dtype=torch.float16,
13
  load_in_8bit=True,
14
+ device_map = "auto",
15
  offload_folder="./cache",
16
  )
17