Spaces:
Runtime error
Runtime error
Add config option to load models in 8-bit precision
Browse files- app/api.py +5 -3
- app/config.json +8 -2
app/api.py
CHANGED
@@ -136,13 +136,13 @@ async def text_generate(
|
|
136 |
return {"generated_text": generated_text, "processing_time": time_diff}
|
137 |
|
138 |
|
139 |
-
def get_text_generator(model_name: str, device: str = "cpu"):
|
140 |
hf_auth_token = os.getenv("HF_AUTH_TOKEN", False)
|
141 |
print(f"hf_auth_token: {hf_auth_token}")
|
142 |
print(f"Loading model with device: {device}...")
|
143 |
tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=hf_auth_token)
|
144 |
model = AutoModelForCausalLM.from_pretrained(model_name, pad_token_id=tokenizer.eos_token_id,
|
145 |
-
load_in_8bit=
|
146 |
# model.to(device)
|
147 |
print("Model loaded")
|
148 |
return model, tokenizer
|
@@ -156,7 +156,9 @@ config = get_config()
|
|
156 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
157 |
text_generator = {}
|
158 |
for model_name in config["text-generator"]:
|
159 |
-
model, tokenizer = get_text_generator(model_name=config["text-generator"][model_name],
|
|
|
|
|
160 |
text_generator[model_name] = {
|
161 |
"model": model,
|
162 |
"tokenizer": tokenizer
|
|
|
136 |
return {"generated_text": generated_text, "processing_time": time_diff}
|
137 |
|
138 |
|
139 |
+
def get_text_generator(model_name: str, load_in_8bit: bool = False, device: str = "cpu"):
|
140 |
hf_auth_token = os.getenv("HF_AUTH_TOKEN", False)
|
141 |
print(f"hf_auth_token: {hf_auth_token}")
|
142 |
print(f"Loading model with device: {device}...")
|
143 |
tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=hf_auth_token)
|
144 |
model = AutoModelForCausalLM.from_pretrained(model_name, pad_token_id=tokenizer.eos_token_id,
|
145 |
+
load_in_8bit=load_in_8bit, device_map="auto", use_auth_token=hf_auth_token)
|
146 |
# model.to(device)
|
147 |
print("Model loaded")
|
148 |
return model, tokenizer
|
|
|
156 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
157 |
text_generator = {}
|
158 |
for model_name in config["text-generator"]:
|
159 |
+
model, tokenizer = get_text_generator(model_name=config["text-generator"][model_name]["name"],
|
160 |
+
load_in_8bit=config["text-generator"][model_name]["load_in_8bit"],
|
161 |
+
device=device)
|
162 |
text_generator[model_name] = {
|
163 |
"model": model,
|
164 |
"tokenizer": tokenizer
|
app/config.json
CHANGED
@@ -1,6 +1,12 @@
|
|
1 |
{
|
2 |
"text-generator": {
|
3 |
-
"indochat-tiny":
|
4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
}
|
6 |
}
|
|
|
1 |
{
|
2 |
"text-generator": {
|
3 |
+
"indochat-tiny": {
|
4 |
+
"name": "cahya/indochat-tiny",
|
5 |
+
"load_in_8bit": false
|
6 |
+
},
|
7 |
+
"bloomz-1b1-instruct": {
|
8 |
+
"name": "cahya/bloomz-1b7-instruct",
|
9 |
+
"load_in_8bit": true
|
10 |
+
}
|
11 |
}
|
12 |
}
|