Update utils/inference.py
Browse files · utils/inference.py (+6 −6)
utils/inference.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
import torch
|
2 |
-
from transformers import
|
3 |
from peft import PeftModel
|
4 |
from typing import Iterator
|
5 |
from variables import SYSTEM, HUMAN, AI
|
@@ -24,15 +24,15 @@ def load_tokenizer_and_model(base_model, adapter_model, load_8bit=True):
|
|
24 |
device = "mps"
|
25 |
except:
|
26 |
pass
|
27 |
-
tokenizer =
|
28 |
if device == "cuda":
|
29 |
-
model =
|
30 |
base_model,
|
31 |
load_in_8bit=load_8bit,
|
32 |
torch_dtype=torch.float16
|
33 |
)
|
34 |
elif device == "mps":
|
35 |
-
model =
|
36 |
base_model,
|
37 |
device_map={"": device}
|
38 |
)
|
@@ -44,7 +44,7 @@ def load_tokenizer_and_model(base_model, adapter_model, load_8bit=True):
|
|
44 |
torch_dtype=torch.float16,
|
45 |
)
|
46 |
else:
|
47 |
-
model =
|
48 |
base_model,
|
49 |
device_map={"": device},
|
50 |
low_cpu_mem_usage=True,
|
@@ -76,7 +76,7 @@ shared_state = State()
|
|
76 |
def decode(
|
77 |
input_ids: torch.Tensor,
|
78 |
model: PeftModel,
|
79 |
-
tokenizer:
|
80 |
stop_words: list,
|
81 |
max_length: int,
|
82 |
temperature: float = 1.0,
|
|
|
1 |
import torch
|
2 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
3 |
from peft import PeftModel
|
4 |
from typing import Iterator
|
5 |
from variables import SYSTEM, HUMAN, AI
|
|
|
24 |
device = "mps"
|
25 |
except:
|
26 |
pass
|
27 |
+
tokenizer = AutoTokenizer.from_pretrained(base_model)
|
28 |
if device == "cuda":
|
29 |
+
model = AutoModelForCausalLM.from_pretrained(
|
30 |
base_model,
|
31 |
load_in_8bit=load_8bit,
|
32 |
torch_dtype=torch.float16
|
33 |
)
|
34 |
elif device == "mps":
|
35 |
+
model = AutoModelForCausalLM.from_pretrained(
|
36 |
base_model,
|
37 |
device_map={"": device}
|
38 |
)
|
|
|
44 |
torch_dtype=torch.float16,
|
45 |
)
|
46 |
else:
|
47 |
+
model = AutoModelForCausalLM.from_pretrained(
|
48 |
base_model,
|
49 |
device_map={"": device},
|
50 |
low_cpu_mem_usage=True,
|
|
|
76 |
def decode(
|
77 |
input_ids: torch.Tensor,
|
78 |
model: PeftModel,
|
79 |
+
tokenizer: AutoTokenizer,
|
80 |
stop_words: list,
|
81 |
max_length: int,
|
82 |
temperature: float = 1.0,
|