booksouls committed
Commit 0a1317c · verified · 1 Parent(s): 8fb1919

add 4-bit quantization to handler.py

Files changed (1):
  handler.py (+16, -6)
handler.py CHANGED

@@ -1,14 +1,24 @@
 import torch
-from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, BitsAndBytesConfig
 from typing import Any
 
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
 class EndpointHandler():
     def __init__(self, path=""):
-        self.model = AutoModelForSeq2SeqLM.from_pretrained(path).to(device)
+        # bitsandbytes quantization is only supported on CUDA devices.
+        bits_and_bytes_config = BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_compute_dtype=torch.bfloat16,
+        )
+        quantization_config = bits_and_bytes_config if torch.cuda.is_available() else None
+
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.model = AutoModelForSeq2SeqLM.from_pretrained(
+            path,
+            quantization_config=quantization_config,
+            device_map="auto",
+        )
         self.tokenizer = AutoTokenizer.from_pretrained(path)
-
+
     def __call__(self, data: dict[str, Any]) -> dict[str, Any]:
         inputs = data.get("inputs")
         parameters = data.get("parameters")
@@ -32,7 +42,7 @@ class EndpointHandler():
         )
 
         # Ensure the input_ids and the model are on the same device to prevent errors.
-        input_ids = tokens.input_ids.to(device)
+        input_ids = tokens.input_ids.to(self.device)
 
         # Gradient calculation is not needed for inference.
         with torch.no_grad():
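
For context, here is a minimal local smoke test of the quantized handler. This is a sketch under assumptions: the payload keys ("inputs", "parameters") match what __call__ reads above, but the model path and the max_new_tokens value are placeholders, not part of this commit.

# Sketch: local smoke test for the handler. The model path and generation
# parameter below are assumptions, not from this commit.
from handler import EndpointHandler

# On a CUDA machine the model loads in 4-bit via bitsandbytes; on CPU,
# quantization_config falls back to None and the model loads unquantized.
handler = EndpointHandler(path=".")  # path to a local model checkout

payload = {
    "inputs": "Once upon a time, a lighthouse keeper found a bottle.",
    "parameters": {"max_new_tokens": 128},  # hypothetical parameter
}
print(handler(payload))

Note that device_map="auto" relies on the accelerate package and 4-bit loading on bitsandbytes, so both would need to be in the endpoint's requirements.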