cahya committed on
Commit
84d80f5
1 Parent(s): 02ae971

add option for 8bit

Browse files
Files changed (2) hide show
  1. app/api.py +5 -3
  2. app/config.json +8 -2
app/api.py CHANGED
@@ -136,13 +136,13 @@ async def text_generate(
136
  return {"generated_text": generated_text, "processing_time": time_diff}
137
 
138
 
139
- def get_text_generator(model_name: str, device: str = "cpu"):
140
  hf_auth_token = os.getenv("HF_AUTH_TOKEN", False)
141
  print(f"hf_auth_token: {hf_auth_token}")
142
  print(f"Loading model with device: {device}...")
143
  tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=hf_auth_token)
144
  model = AutoModelForCausalLM.from_pretrained(model_name, pad_token_id=tokenizer.eos_token_id,
145
- load_in_8bit=True, device_map="auto", use_auth_token=hf_auth_token)
146
  # model.to(device)
147
  print("Model loaded")
148
  return model, tokenizer
@@ -156,7 +156,9 @@ config = get_config()
156
  device = "cuda" if torch.cuda.is_available() else "cpu"
157
  text_generator = {}
158
  for model_name in config["text-generator"]:
159
- model, tokenizer = get_text_generator(model_name=config["text-generator"][model_name], device=device)
 
 
160
  text_generator[model_name] = {
161
  "model": model,
162
  "tokenizer": tokenizer
 
136
  return {"generated_text": generated_text, "processing_time": time_diff}
137
 
138
 
139
def get_text_generator(model_name: str, load_in_8bit: bool = False, device: str = "cpu"):
    """Load a causal language model and its tokenizer from the Hugging Face Hub.

    Args:
        model_name: Hub repo id of the model to load (e.g. "cahya/indochat-tiny").
        load_in_8bit: If True, load the weights with 8-bit quantization
            (requires bitsandbytes support in transformers).
        device: Requested device label. Currently informational only —
            actual placement is delegated to ``device_map="auto"`` below.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    hf_auth_token = os.getenv("HF_AUTH_TOKEN", False)
    # Never print the token value itself — it is a secret and would leak into logs.
    print(f"hf_auth_token is {'set' if hf_auth_token else 'not set'}")
    print(f"Loading model with device: {device}...")
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=hf_auth_token)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        pad_token_id=tokenizer.eos_token_id,
        load_in_8bit=load_in_8bit,
        device_map="auto",
        use_auth_token=hf_auth_token,
    )
    # model.to(device)  # not needed: device_map="auto" handles placement
    print("Model loaded")
    return model, tokenizer
 
156
  device = "cuda" if torch.cuda.is_available() else "cpu"
157
  text_generator = {}
158
  for model_name in config["text-generator"]:
159
+ model, tokenizer = get_text_generator(model_name=config["text-generator"][model_name]["name"],
160
+ load_in_8bit=config["text-generator"][model_name]["load_in_8bit"],
161
+ device=device)
162
  text_generator[model_name] = {
163
  "model": model,
164
  "tokenizer": tokenizer
app/config.json CHANGED
@@ -1,6 +1,12 @@
1
  {
2
  "text-generator": {
3
- "indochat-tiny": "cahya/indochat-tiny",
4
- "bloomz-1b1-instruct": "cahya/bloomz-1b7-instruct"
 
 
 
 
 
 
5
  }
6
  }
 
1
  {
2
  "text-generator": {
3
+ "indochat-tiny": {
4
+ "name": "cahya/indochat-tiny",
5
+ "load_in_8bit": false
6
+ },
7
+ "bloomz-1b1-instruct": {
8
+ "name": "cahya/bloomz-1b7-instruct",
9
+ "load_in_8bit": true
10
+ }
11
  }
12
  }