Christina Theodoris committed
Commit 7b591f6
1 Parent(s): 69e6887

add quantization for pretrained model

geneformer/in_silico_perturber.py CHANGED
@@ -62,7 +62,7 @@ class InSilicoPerturber:
         "genes_to_perturb": {"all", list},
         "combos": {0, 1},
         "anchor_gene": {None, str},
-        "model_type": {"Pretrained", "GeneClassifier", "CellClassifier", "MTLCellClassifier", "MTLCellClassifier-Quantized"},
+        "model_type": {"Pretrained", "GeneClassifier", "CellClassifier", "MTLCellClassifier", "Pretrained-Quantized", "MTLCellClassifier-Quantized"},
         "num_classes": {int},
         "emb_mode": {"cls", "cell", "cls_and_gene", "cell_and_gene"},
         "cell_emb_style": {"mean_pool"},
@@ -132,7 +132,7 @@ class InSilicoPerturber:
         | ENSEMBL ID of gene to use as anchor in combination perturbations.
         | For example, if combos=1 and anchor_gene="ENSG00000148400":
         | anchor gene will be perturbed in combination with each other gene.
-    model_type : {"Pretrained", "GeneClassifier", "CellClassifier", "MTLCellClassifier", "MTLCellClassifier-Quantized"}
+    model_type : {"Pretrained", "GeneClassifier", "CellClassifier", "MTLCellClassifier", "Pretrained-Quantized", "MTLCellClassifier-Quantized"}
         | Whether model is the pretrained Geneformer or a fine-tuned gene, cell, or multitask cell classifier (+/- 8bit quantization).
     num_classes : int
         | If model is a gene or cell classifier, specify number of classes it was trained to classify.
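With this change, the pretrained Geneformer can be loaded with 8-bit quantization from the in silico perturbation interface by passing model_type="Pretrained-Quantized". A minimal usage sketch follows; apart from the option names shown in the diff above, the constructor and perturb_data() arguments (including all paths) are assumptions based on the class's general interface and may differ.

# Hypothetical usage sketch: in silico perturbation with the pretrained
# Geneformer loaded in 8-bit. Only "model_type" and the option names shown
# in the diff above come from this commit; the remaining arguments and paths
# are placeholders/assumptions.
from geneformer import InSilicoPerturber

isp = InSilicoPerturber(
    perturb_type="delete",               # assumption: standard perturbation type
    genes_to_perturb="all",
    combos=0,
    anchor_gene=None,
    model_type="Pretrained-Quantized",   # new option added by this commit
    num_classes=0,
    emb_mode="cls",
    cell_emb_style="mean_pool",
    forward_batch_size=64,
    nproc=4,
)

isp.perturb_data(
    model_directory="/path/to/pretrained_geneformer",   # assumption: placeholder path
    input_data_file="/path/to/tokenized_dataset",
    output_directory="/path/to/output",
    output_prefix="isp_quantized",
)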
geneformer/perturber_utils.py CHANGED
@@ -113,15 +113,22 @@ def slice_by_inds_to_perturb(filtered_input_data, cell_inds_to_perturb):
 
 # load model to GPU
 def load_model(model_type, num_classes, model_directory, mode, quantize=False):
-    if model_type == "MTLCellClassifier-Quantized":
+    if model_type == "Pretrained-Quantized":
+        inference_only = True
+        model_type = "Pretrained"
+        quantize = True
+    elif model_type == "MTLCellClassifier-Quantized":
+        inference_only = True
         model_type = "MTLCellClassifier"
         quantize = True
+    else:
+        inference_only = False
 
     output_hidden_states = (mode == "eval")
 
     # Quantization logic
     if quantize:
-        if model_type == "MTLCellClassifier":
+        if inference_only:
             quantize_config = BitsAndBytesConfig(load_in_8bit=True)
             peft_config = None
         else:
@@ -179,7 +186,7 @@ def load_model(model_type, num_classes, model_directory, mode, quantize=False):
         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         model = model.to(device)
     elif peft_config:
-        # Apply PEFT for quantized models (except MTLCellClassifier)
+        # Apply PEFT for quantized models (except MTLCellClassifier and CellClassifier-QuantInf)
         model.enable_input_require_grads()
         model = get_peft_model(model, peft_config)
 
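For context, the inference_only branch above only builds an 8-bit BitsAndBytesConfig and skips the PEFT adapter setup. A minimal sketch of what that load path amounts to for the pretrained model, assuming the checkpoint is a BertForMaskedLM and that model_directory points to it (neither assumption is stated in this diff):

# Sketch of the inference-only quantized load path, under the assumptions above.
from transformers import BertForMaskedLM, BitsAndBytesConfig

def load_pretrained_8bit(model_directory, mode="eval"):
    # 8-bit weights via bitsandbytes; no PEFT adapters are attached because the
    # model is used for inference only (inference_only=True in load_model above).
    quantize_config = BitsAndBytesConfig(load_in_8bit=True)
    model = BertForMaskedLM.from_pretrained(
        model_directory,
        output_hidden_states=(mode == "eval"),
        quantization_config=quantize_config,
    )
    if mode == "eval":
        model.eval()
    # 8-bit models are placed on the GPU by from_pretrained/accelerate,
    # so no explicit model.to(device) call is made here.
    return model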