Aakash Vardhan committed
Commit 2910f9f · 1 Parent(s): bb8f386
Files changed (2)
  1. app.py +9 -1
  2. config.yaml +1 -1
app.py CHANGED
@@ -19,8 +19,16 @@ if "torch_dtype" in model_config:
     elif model_config["torch_dtype"] == "bfloat16":
         model_config["torch_dtype"] = torch.bfloat16
 
+# Create quantization config
 quantization_config = BitsAndBytesConfig(load_in_8bit=True)
-model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=quantization_config, **model_config)
+
+# Load the model with quantization config
+model = AutoModelForCausalLM.from_pretrained(
+    model_name,
+    quantization_config=quantization_config,
+    low_cpu_mem_usage=True,
+    **model_config
+)
 
 checkpoint_model = "checkpoint_dir/checkpoint-650"
 
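For orientation, here is a minimal self-contained sketch of the loading path as it stands after this commit. It is not verbatim app.py: the YAML parsing, the float16 branch, and the placeholder model name are assumptions filled in around the lines visible in the diff.

import torch
import yaml
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

model_name = "org/model"  # placeholder; the real name is set elsewhere in app.py

with open("config.yaml") as f:
    model_config = yaml.safe_load(f)["model_config"]

# Map the dtype string from YAML onto a real torch dtype. Only the bfloat16
# branch is visible in the diff; the float16 branch is assumed.
if "torch_dtype" in model_config:
    if model_config["torch_dtype"] == "float16":
        model_config["torch_dtype"] = torch.float16
    elif model_config["torch_dtype"] == "bfloat16":
        model_config["torch_dtype"] = torch.bfloat16

# Create quantization config (8-bit quantization now lives here, not in config.yaml)
quantization_config = BitsAndBytesConfig(load_in_8bit=True)

# Load the model with quantization config
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=quantization_config,
    low_cpu_mem_usage=True,
    **model_config,
)

low_cpu_mem_usage=True tells transformers to skip the usual randomly initialized copy of the weights and materialize parameters only as the checkpoint is read, which lowers peak RAM during loading.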
config.yaml CHANGED
@@ -5,4 +5,4 @@ model_config:
   use_cache: True
   attn_implementation: "eager"
   device_map: "cpu"
-  load_in_8bit: True
+  # Remove the load_in_8bit line
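Dropping load_in_8bit from config.yaml is the companion change: with the flag now expressed through BitsAndBytesConfig in app.py, leaving it in model_config would hand from_pretrained two quantization directives at once, which transformers treats as conflicting and rejects. If stale copies of the YAML might still be in circulation, a defensive pop (hypothetical, not part of this commit) keeps the two sources from clashing:

# Hypothetical guard, not in the commit: ignore a stale load_in_8bit key so it
# cannot conflict with the explicit BitsAndBytesConfig above.
model_config.pop("load_in_8bit", None)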