CUDA kernels used by 4-bit/8-bit quantized models are incompatible with standard PyTorch device movement, necessitating device-specific handling

#416
Files changed (1) hide show
  1. geneformer/perturber_utils.py +60 -70
geneformer/perturber_utils.py CHANGED
@@ -117,83 +117,73 @@ def load_model(model_type, num_classes, model_directory, mode, quantize=False):
117
  model_type = "MTLCellClassifier"
118
  quantize = True
119
 
120
- if mode == "eval":
121
- output_hidden_states = True
122
- elif mode == "train":
123
- output_hidden_states = False
124
 
125
- if quantize is True:
 
126
  if model_type == "MTLCellClassifier":
127
- quantize = {
128
- "peft_config": None,
129
- "bnb_config": BitsAndBytesConfig(
130
- load_in_8bit=True,
131
- ),
132
- }
133
  else:
134
- quantize = {
135
- "peft_config": LoraConfig(
136
- lora_alpha=128,
137
- lora_dropout=0.1,
138
- r=64,
139
- bias="none",
140
- task_type="TokenClassification",
141
- ),
142
- "bnb_config": BitsAndBytesConfig(
143
- load_in_4bit=True,
144
- bnb_4bit_use_double_quant=True,
145
- bnb_4bit_quant_type="nf4",
146
- bnb_4bit_compute_dtype=torch.bfloat16,
147
- ),
148
- }
149
- elif quantize is False:
150
- quantize = {"bnb_config": None}
151
-
152
- if model_type == "Pretrained":
153
- model = BertForMaskedLM.from_pretrained(
154
- model_directory,
155
- output_hidden_states=output_hidden_states,
156
- output_attentions=False,
157
- quantization_config=quantize["bnb_config"],
158
- )
159
- elif model_type == "GeneClassifier":
160
- model = BertForTokenClassification.from_pretrained(
161
- model_directory,
162
- num_labels=num_classes,
163
- output_hidden_states=output_hidden_states,
164
- output_attentions=False,
165
- quantization_config=quantize["bnb_config"],
166
- )
167
- elif model_type == "CellClassifier":
168
- model = BertForSequenceClassification.from_pretrained(
169
- model_directory,
170
- num_labels=num_classes,
171
- output_hidden_states=output_hidden_states,
172
- output_attentions=False,
173
- quantization_config=quantize["bnb_config"],
174
- )
175
- elif model_type == "MTLCellClassifier":
176
- model = BertForMaskedLM.from_pretrained(
177
- model_directory,
178
- num_labels=num_classes,
179
- output_hidden_states=output_hidden_states,
180
- output_attentions=False,
181
- quantization_config=quantize["bnb_config"],
182
- )
183
- # if eval mode, put the model in eval mode for fwd pass
184
  if mode == "eval":
185
  model.eval()
186
- if (
187
- (quantize is False)
188
- or (quantize == {"bnb_config": None})
189
- or (model_type == "MTLCellClassifier")
190
- ):
191
- model = model.to("cuda")
192
- else:
 
193
  model.enable_input_require_grads()
194
- model = get_peft_model(model, quantize["peft_config"])
195
- return model
196
 
 
197
 
198
  def quant_layers(model):
199
  layer_nums = []
 
117
  model_type = "MTLCellClassifier"
118
  quantize = True
119
 
120
+ output_hidden_states = (mode == "eval")
 
 
 
121
 
122
+ # Quantization logic
123
+ if quantize:
124
  if model_type == "MTLCellClassifier":
125
+ quantize_config = BitsAndBytesConfig(load_in_8bit=True)
126
+ peft_config = None
 
 
 
 
127
  else:
128
+ quantize_config = BitsAndBytesConfig(
129
+ load_in_4bit=True,
130
+ bnb_4bit_use_double_quant=True,
131
+ bnb_4bit_quant_type="nf4",
132
+ bnb_4bit_compute_dtype=torch.bfloat16,
133
+ )
134
+ peft_config = LoraConfig(
135
+ lora_alpha=128,
136
+ lora_dropout=0.1,
137
+ r=64,
138
+ bias="none",
139
+ task_type="TokenClassification",
140
+ )
141
+ else:
142
+ quantize_config = None
143
+ peft_config = None
144
+
145
+ # Model class selection
146
+ model_classes = {
147
+ "Pretrained": BertForMaskedLM,
148
+ "GeneClassifier": BertForTokenClassification,
149
+ "CellClassifier": BertForSequenceClassification,
150
+ "MTLCellClassifier": BertForMaskedLM
151
+ }
152
+
153
+ model_class = model_classes.get(model_type)
154
+ if not model_class:
155
+ raise ValueError(f"Unknown model type: {model_type}")
156
+
157
+ # Model loading
158
+ model_args = {
159
+ "pretrained_model_name_or_path": model_directory,
160
+ "output_hidden_states": output_hidden_states,
161
+ "output_attentions": False,
162
+ }
163
+
164
+ if model_type != "Pretrained":
165
+ model_args["num_labels"] = num_classes
166
+
167
+ if quantize_config:
168
+ model_args["quantization_config"] = quantize_config
169
+
170
+ # Load the model
171
+ model = model_class.from_pretrained(**model_args)
172
+
 
 
 
 
 
173
  if mode == "eval":
174
  model.eval()
175
+
176
+ # Handle device placement and PEFT
177
+ if not quantize:
178
+ # Only move non-quantized models
179
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
180
+ model = model.to(device)
181
+ elif peft_config:
182
+ # Apply PEFT for quantized models (except MTLCellClassifier)
183
  model.enable_input_require_grads()
184
+ model = get_peft_model(model, peft_config)
 
185
 
186
+ return model
187
 
188
  def quant_layers(model):
189
  layer_nums = []