madhavanvenkatesh commited on
Commit
76101b4
1 Parent(s): 11bcee7

fixed bug related to dynamic ranges in dictionary with 'min' and 'max' value mismatch in optuna suggest fn

Browse files
Files changed (1) hide show
  1. geneformer/mtl/train.py +9 -14
geneformer/mtl/train.py CHANGED
@@ -9,7 +9,7 @@ from tqdm import tqdm
9
 
10
  from .imports import *
11
  from .model import GeneformerMultiTask
12
- from .utils import calculate_task_specific_metrics
13
 
14
 
15
  def set_seed(seed):
@@ -280,7 +280,7 @@ def objective(
280
  "lr_scheduler_type", config["hyperparameters"]["lr_scheduler_type"]["choices"]
281
  )
282
  config["use_attention_pooling"] = trial.suggest_categorical(
283
- "use_attention_pooling", [True, False]
284
  )
285
 
286
  if config["use_task_weights"]:
@@ -299,18 +299,13 @@ def objective(
299
  else:
300
  config["task_weights"] = None
301
 
302
- # Fix for max_layers_to_freeze
303
- if isinstance(config["max_layers_to_freeze"], dict):
304
- config["max_layers_to_freeze"] = trial.suggest_int(
305
- "max_layers_to_freeze",
306
- config["max_layers_to_freeze"]["min"],
307
- config["max_layers_to_freeze"]["max"],
308
- )
309
- elif isinstance(config["max_layers_to_freeze"], int):
310
- # If it's already an int, we don't need to suggest it
311
- pass
312
- else:
313
- raise ValueError("Invalid type for max_layers_to_freeze. Expected dict or int.")
314
 
315
  model = create_model(config, num_labels_list, device)
316
  total_steps = len(train_loader) * config["epochs"]
 
9
 
10
  from .imports import *
11
  from .model import GeneformerMultiTask
12
+ from .utils import calculate_task_specific_metrics, get_layer_freeze_range
13
 
14
 
15
  def set_seed(seed):
 
280
  "lr_scheduler_type", config["hyperparameters"]["lr_scheduler_type"]["choices"]
281
  )
282
  config["use_attention_pooling"] = trial.suggest_categorical(
283
+ "use_attention_pooling", [False]
284
  )
285
 
286
  if config["use_task_weights"]:
 
299
  else:
300
  config["task_weights"] = None
301
 
302
+ # Dynamic range for max_layers_to_freeze
303
+ freeze_range = get_layer_freeze_range(config["pretrained_path"])
304
+ config["max_layers_to_freeze"] = trial.suggest_int(
305
+ "max_layers_to_freeze",
306
+ freeze_range["min"],
307
+ freeze_range["max"]
308
+ )
 
 
 
 
 
309
 
310
  model = create_model(config, num_labels_list, device)
311
  total_steps = len(train_loader) * config["epochs"]