Caslow commited on
Commit
8a3e945
·
1 Parent(s): 78ad200
Files changed (1) hide show
  1. inference.py +3 -3
inference.py CHANGED
@@ -5,9 +5,9 @@ import torch
5
 
6
  def load_model(
7
  model_name: str,
8
- max_seq_length: int,
9
- dtype: torch.dtype,
10
- load_in_4bit: bool
11
  ) -> Tuple[AutoModelForCausalLM, any]:
12
  """
13
  Load and initialize the language model for inference.
 
5
 
6
  def load_model(
7
  model_name: str,
8
+ max_seq_length: int = 2048,
9
+ dtype: torch.dtype = torch.float32,
10
+ load_in_4bit: bool = False
11
  ) -> Tuple[AutoModelForCausalLM, any]:
12
  """
13
  Load and initialize the language model for inference.