tianyang committed on
Commit
b143c1f
·
1 Parent(s): 84fc3d3

Update utils/inference.py

Browse files
Files changed (1) hide show
  1. utils/inference.py +6 -6
utils/inference.py CHANGED
@@ -1,5 +1,5 @@
1
  import torch
2
- from transformers import LlamaTokenizer, LlamaForCausalLM
3
  from peft import PeftModel
4
  from typing import Iterator
5
  from variables import SYSTEM, HUMAN, AI
@@ -24,15 +24,15 @@ def load_tokenizer_and_model(base_model, adapter_model, load_8bit=True):
24
  device = "mps"
25
  except:
26
  pass
27
- tokenizer = LlamaTokenizer.from_pretrained(base_model)
28
  if device == "cuda":
29
- model = LlamaForCausalLM.from_pretrained(
30
  base_model,
31
  load_in_8bit=load_8bit,
32
  torch_dtype=torch.float16
33
  )
34
  elif device == "mps":
35
- model = LlamaForCausalLM.from_pretrained(
36
  base_model,
37
  device_map={"": device}
38
  )
@@ -44,7 +44,7 @@ def load_tokenizer_and_model(base_model, adapter_model, load_8bit=True):
44
  torch_dtype=torch.float16,
45
  )
46
  else:
47
- model = LlamaForCausalLM.from_pretrained(
48
  base_model,
49
  device_map={"": device},
50
  low_cpu_mem_usage=True,
@@ -76,7 +76,7 @@ shared_state = State()
76
  def decode(
77
  input_ids: torch.Tensor,
78
  model: PeftModel,
79
- tokenizer: LlamaTokenizer,
80
  stop_words: list,
81
  max_length: int,
82
  temperature: float = 1.0,
 
1
  import torch
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM
3
  from peft import PeftModel
4
  from typing import Iterator
5
  from variables import SYSTEM, HUMAN, AI
 
24
  device = "mps"
25
  except:
26
  pass
27
+ tokenizer = AutoTokenizer.from_pretrained(base_model)
28
  if device == "cuda":
29
+ model = AutoModelForCausalLM.from_pretrained(
30
  base_model,
31
  load_in_8bit=load_8bit,
32
  torch_dtype=torch.float16
33
  )
34
  elif device == "mps":
35
+ model = AutoModelForCausalLM.from_pretrained(
36
  base_model,
37
  device_map={"": device}
38
  )
 
44
  torch_dtype=torch.float16,
45
  )
46
  else:
47
+ model = AutoModelForCausalLM.from_pretrained(
48
  base_model,
49
  device_map={"": device},
50
  low_cpu_mem_usage=True,
 
76
  def decode(
77
  input_ids: torch.Tensor,
78
  model: PeftModel,
79
+ tokenizer: AutoTokenizer,
80
  stop_words: list,
81
  max_length: int,
82
  temperature: float = 1.0,