Caslow committed on
Commit
5b7d699
·
1 Parent(s): af99375
Files changed (2) hide show
  1. config.py +1 -1
  2. inference.py +9 -7
config.py CHANGED
@@ -4,7 +4,7 @@ from typing import List, Optional
4
  # Hyperparameters for Model
5
  max_seq_length = 2048 # Choose any! We auto support RoPE Scaling internally!
6
  dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
7
- load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False.
8
  lora_r = 16 # Rank of the LoRA low-rank update matrices
9
  lora_alpha = 16 # Alpha value for LoRA
10
  lora_dropout = 0 # Dropout rate for LoRA
 
4
  # Hyperparameters for Model
5
  max_seq_length = 2048 # Choose any! We auto support RoPE Scaling internally!
6
  dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
7
+ load_in_4bit = False # Use 4bit quantization to reduce memory usage. Can be False.
8
  lora_r = 16 # Rank of the LoRA low-rank update matrices
9
  lora_alpha = 16 # Alpha value for LoRA
10
  lora_dropout = 0 # Dropout rate for LoRA
inference.py CHANGED
@@ -24,7 +24,7 @@ def load_model(
24
 
25
  kwargs = {
26
  "device_map": "cpu",
27
- "torch_dtype": torch.float32,
28
  "low_cpu_mem_usage": True,
29
  "_from_auto": False, # Prevent automatic quantization detection
30
  "quantization_config": None # Explicitly set no quantization
@@ -34,7 +34,7 @@ def load_model(
34
 
35
  model = AutoModelForCausalLM.from_pretrained(
36
  pretrained_model_name_or_path=model_name,
37
- config = kwargs
38
  )
39
 
40
  model.eval() # Set model to evaluation mode
@@ -57,12 +57,14 @@ def prepare_input(
57
  Returns:
58
  torch.Tensor: Prepared input tensor
59
  """
60
- return tokenizer.apply_chat_template(
61
  messages,
62
- tokenize=True,
63
- add_generation_prompt=True,
64
- return_tensors="pt"
65
- ).to(device)
 
 
66
 
67
  def generate_response(
68
  model: AutoModelForCausalLM,
 
24
 
25
  kwargs = {
26
  "device_map": "cpu",
27
+ "torch_dtype": dtype,
28
  "low_cpu_mem_usage": True,
29
  "_from_auto": False, # Prevent automatic quantization detection
30
  "quantization_config": None # Explicitly set no quantization
 
34
 
35
  model = AutoModelForCausalLM.from_pretrained(
36
  pretrained_model_name_or_path=model_name,
37
+ **kwargs
38
  )
39
 
40
  model.eval() # Set model to evaluation mode
 
57
  Returns:
58
  torch.Tensor: Prepared input tensor
59
  """
60
+ return tokenizer(
61
  messages,
62
+ # tokenize=True,
63
+ # add_generation_prompt=True,
64
+ return_tensors="pt",
65
+ padding=True,
66
+ truncation=True,
67
+ )["input_ids"]
68
 
69
  def generate_response(
70
  model: AutoModelForCausalLM,