hchen725 committed
Commit ae834bc
Parent: 4bddd45

Update with gene classifier, custom token dict, and str validate options

Files changed (1)
  1. geneformer/classifier.py +72 -38
geneformer/classifier.py CHANGED
@@ -53,7 +53,6 @@ from pathlib import Path
 import numpy as np
 import pandas as pd
 import seaborn as sns
-from sklearn.model_selection import StratifiedKFold
 from tqdm.auto import tqdm, trange
 from transformers import Trainer
 from transformers.training_args import TrainingArguments
@@ -86,6 +85,7 @@ class Classifier:
         "no_eval": {bool},
         "stratify_splits_col": {None, str},
         "forward_batch_size": {int},
+        "token_dictionary_file": {None, str},
         "nproc": {int},
         "ngpu": {int},
     }
@@ -107,6 +107,7 @@ class Classifier:
         stratify_splits_col=None,
         no_eval=False,
         forward_batch_size=100,
+        token_dictionary_file=None,
         nproc=4,
         ngpu=1,
     ):
@@ -175,6 +176,9 @@ class Classifier:
             | Otherwise, will perform eval during training.
         forward_batch_size : int
             | Batch size for forward pass (for evaluation, not training).
+        token_dictionary_file : None, str
+            | Default is to use token dictionary file from Geneformer.
+            | Otherwise, will load custom gene token dictionary.
         nproc : int
             | Number of CPU processes to use.
         ngpu : int
@@ -183,6 +187,10 @@ class Classifier:
         """

         self.classifier = classifier
+        if self.classifier == "cell":
+            self.model_type = "CellClassifier"
+        elif self.classifier == "gene":
+            self.model_type = "GeneClassifier"
         self.cell_state_dict = cell_state_dict
         self.gene_class_dict = gene_class_dict
         self.filter_data = filter_data
@@ -201,6 +209,7 @@ class Classifier:
         self.stratify_splits_col = stratify_splits_col
         self.no_eval = no_eval
         self.forward_batch_size = forward_batch_size
+        self.token_dictionary_file = token_dictionary_file
         self.nproc = nproc
         self.ngpu = ngpu

@@ -222,7 +231,9 @@ class Classifier:
             ] = self.cell_state_dict["states"]

         # load token dictionary (Ensembl IDs:token)
-        with open(TOKEN_DICTIONARY_FILE, "rb") as f:
+        if self.token_dictionary_file is None:
+            self.token_dictionary_file = TOKEN_DICTIONARY_FILE
+        with open(self.token_dictionary_file, "rb") as f:
             self.gene_token_dict = pickle.load(f)

         self.token_gene_dict = {v: k for k, v in self.gene_token_dict.items()}
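This allows the classifier to run against a model trained on a non-default vocabulary. A minimal usage sketch (the pickle path, gene IDs, and class labels below are hypothetical placeholders):

```python
import pickle

from geneformer import Classifier

# Hypothetical custom vocabulary mapping Ensembl IDs to integer tokens.
custom_vocab = {"<pad>": 0, "<mask>": 1, "ENSG00000141510": 2}
with open("custom_token_dictionary.pkl", "wb") as f:
    pickle.dump(custom_vocab, f)

cc = Classifier(
    classifier="gene",
    gene_class_dict={"tf": ["ENSG00000141510"]},  # hypothetical class labels
    token_dictionary_file="custom_token_dictionary.pkl",  # new option in this commit
    nproc=4,
)
```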
@@ -267,7 +278,7 @@ class Classifier:
                 continue
             valid_type = False
             for option in valid_options:
-                if (option in [int, float, list, dict, bool]) and isinstance(
+                if (option in [int, float, list, dict, bool, str]) and isinstance(
                     attr_value, option
                 ):
                     valid_type = True
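Adding `str` to the recognized types lets options such as `token_dictionary_file` pass attribute validation when supplied as a path string. The check in isolation (a simplified sketch of the loop above, not the class's exact code):

```python
def is_valid(attr_value, valid_options):
    # Type options (e.g. str) are matched with isinstance;
    # literal options (e.g. None) are matched by equality.
    for option in valid_options:
        if (option in [int, float, list, dict, bool, str]) and isinstance(
            attr_value, option
        ):
            return True
        if attr_value == option:
            return True
    return False

assert is_valid("custom_token_dictionary.pkl", {None, str})  # str now accepted
assert is_valid(None, {None, str})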
@@ -630,7 +641,6 @@ class Classifier:
             | Number of trials to run for hyperparameter optimization
             | If 0, will not optimize hyperparameters
         """
-
         if self.num_crossval_splits == 0:
             logger.error("num_crossval_splits must be 1 or 5 to validate.")
             raise
@@ -772,17 +782,20 @@ class Classifier:
                 ]
             )
             assert len(targets) == len(labels)
-            n_splits = int(1 / self.eval_size)
-            skf = StratifiedKFold(n_splits=n_splits, random_state=0, shuffle=True)
+            n_splits = int(1 / (1 - self.train_size))
+            skf = cu.StratifiedKFold3(n_splits=n_splits, random_state=0, shuffle=True)
             # (Cross-)validate
-            for train_index, eval_index in tqdm(skf.split(targets, labels)):
+            test_ratio = self.oos_test_size / (self.eval_size + self.oos_test_size)
+            for train_index, eval_index, test_index in tqdm(
+                skf.split(targets, labels, test_ratio)
+            ):
                 print(
                     f"****** Validation split: {iteration_num}/{self.num_crossval_splits} ******\n"
                 )
                 ksplit_output_dir = os.path.join(output_dir, f"ksplit{iteration_num}")
                 # filter data for examples containing classes for this split
                 # subsample to max_ncells and relabel data in column "labels"
-                train_data, eval_data = cu.prep_gene_classifier_split(
+                train_data, eval_data = cu.prep_gene_classifier_train_eval_split(
                     data,
                     targets,
                     labels,
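The split arithmetic: with a `train_size` of 0.8, `n_splits = int(1 / (1 - 0.8)) = 5`, so each fold holds out 20% of genes, and `test_ratio` decides how much of each held-out fold goes to the out-of-sample test set versus the eval set. `cu.StratifiedKFold3` is a `classifier_utils` helper that yields train/eval/test index triples; a rough sketch of the idea (the actual implementation may differ):

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold, train_test_split

class StratifiedKFold3(StratifiedKFold):
    """Like StratifiedKFold, but splits each held-out fold into eval and test."""

    def split(self, targets, labels, test_ratio):
        labels = np.asarray(labels)
        for train_index, other_index in super().split(targets, labels):
            if test_ratio > 0:
                # carve the test set out of the held-out fold, stratified by label
                eval_index, test_index = train_test_split(
                    other_index,
                    test_size=test_ratio,
                    stratify=labels[other_index],
                    random_state=0,
                )
            else:
                eval_index, test_index = other_index, np.array([], dtype=int)
            yield train_index, eval_index, test_index
```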
@@ -793,6 +806,18 @@ class Classifier:
                     self.nproc,
                 )

+                if self.oos_test_size > 0:
+                    test_data = cu.prep_gene_classifier_split(
+                        data,
+                        targets,
+                        labels,
+                        test_index,
+                        "test",
+                        self.max_ncells,
+                        iteration_num,
+                        self.nproc,
+                    )
+
                 if n_hyperopt_trials == 0:
                     trainer = self.train_classifier(
                         model_directory,
@@ -802,6 +827,15 @@ class Classifier:
                         ksplit_output_dir,
                         predict_trainer,
                     )
+                    result = self.evaluate_model(
+                        trainer.model,
+                        num_classes,
+                        id_class_dict,
+                        eval_data,
+                        predict_eval,
+                        ksplit_output_dir,
+                        output_prefix,
+                    )
                 else:
                     trainer = self.hyperopt_classifier(
                         model_directory,
@@ -811,20 +845,27 @@ class Classifier:
                         ksplit_output_dir,
                         n_trials=n_hyperopt_trials,
                     )
-                    if iteration_num == self.num_crossval_splits:
-                        return
+
+                    model = cu.load_best_model(
+                        ksplit_output_dir, self.model_type, num_classes
+                    )
+
+                    if self.oos_test_size > 0:
+                        result = self.evaluate_model(
+                            model,
+                            num_classes,
+                            id_class_dict,
+                            test_data,
+                            predict_eval,
+                            ksplit_output_dir,
+                            output_prefix,
+                        )
                     else:
-                        iteration_num = iteration_num + 1
-                        continue
-                result = self.evaluate_model(
-                    trainer.model,
-                    num_classes,
-                    id_class_dict,
-                    eval_data,
-                    predict_eval,
-                    ksplit_output_dir,
-                    output_prefix,
-                )
+                        if iteration_num == self.num_crossval_splits:
+                            return
+                        else:
+                            iteration_num = iteration_num + 1
+                            continue
                 results += [result]
                 all_conf_mat = all_conf_mat + result["conf_mat"]
                 # break after 1 or 5 splits, each with train/eval proportions dictated by eval_size
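After a hyperparameter search, `cu.load_best_model` reloads the winning checkpoint from the split's output directory before scoring it on the held-out test set. A hedged sketch of what such a helper might do (the real `classifier_utils` version may differ):

```python
import os

from geneformer import perturber_utils as pu

def load_best_model(model_dir, model_type, num_classes):
    # Assumption: with save_total_limit=1, each run leaves a single
    # "checkpoint-*" directory; take the most recently written one.
    checkpoints = [
        os.path.join(root, d)
        for root, dirs, _files in os.walk(model_dir)
        for d in dirs
        if d.startswith("checkpoint")
    ]
    best_checkpoint = max(checkpoints, key=os.path.getmtime)
    return pu.load_model(model_type, num_classes, best_checkpoint, "eval")
```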
@@ -925,12 +966,7 @@ class Classifier:
         subprocess.call(f"mkdir {output_directory}", shell=True)

         ##### Load model and training args #####
-        if self.classifier == "cell":
-            model_type = "CellClassifier"
-        elif self.classifier == "gene":
-            model_type = "GeneClassifier"
-
-        model = pu.load_model(model_type, num_classes, model_directory, "train")
+        model = pu.load_model(self.model_type, num_classes, model_directory, "train")
         def_training_args, def_freeze_layers = cu.get_default_train_args(
             model, self.classifier, train_data, output_directory
         )
@@ -946,6 +982,9 @@ class Classifier:
         if eval_data is None:
             def_training_args["evaluation_strategy"] = "no"
             def_training_args["load_best_model_at_end"] = False
+        def_training_args.update(
+            {"save_strategy": "epoch", "save_total_limit": 1}
+        )  # only save last model for each run
         training_args_init = TrainingArguments(**def_training_args)

         ##### Fine-tune the model #####
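`save_strategy` and `save_total_limit` are standard `transformers.TrainingArguments` fields; capping saves at one checkpoint per run keeps disk usage bounded across cross-validation splits and hyperparameter trials. For example:

```python
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="example_run",  # hypothetical path
    save_strategy="epoch",     # write a checkpoint each epoch...
    save_total_limit=1,        # ...but keep only the most recent one
)
```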
@@ -957,7 +996,9 @@ class Classifier:

         # define function to initiate model
         def model_init():
-            model = pu.load_model(model_type, num_classes, model_directory, "train")
+            model = pu.load_model(
+                self.model_type, num_classes, model_directory, "train"
+            )

             if self.freeze_layers is not None:
                 def_freeze_layers = self.freeze_layers
@@ -1018,6 +1059,7 @@ class Classifier:
                 metric="eval_macro_f1",
                 metric_columns=["loss", "eval_loss", "eval_accuracy", "eval_macro_f1"],
             ),
+            local_dir=output_directory,
         )

         return trainer
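With the Ray backend, extra keyword arguments to `Trainer.hyperparameter_search` are forwarded to `ray.tune.run`, so `local_dir` (valid in the Ray versions contemporary with this commit) redirects Tune's trial logs away from the default `~/ray_results` into the classifier's own output directory. A sketch, with illustrative argument values:

```python
from transformers import Trainer

def run_search(trainer: Trainer, output_directory: str, n_trials: int = 10):
    # kwargs beyond the named parameters are passed through to ray.tune.run
    return trainer.hyperparameter_search(
        backend="ray",
        n_trials=n_trials,
        local_dir=output_directory,  # keep Tune logs beside classifier outputs
    )
```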
@@ -1080,11 +1122,7 @@ class Classifier:
         subprocess.call(f"mkdir {output_directory}", shell=True)

         ##### Load model and training args #####
-        if self.classifier == "cell":
-            model_type = "CellClassifier"
-        elif self.classifier == "gene":
-            model_type = "GeneClassifier"
-        model = pu.load_model(model_type, num_classes, model_directory, "train")
+        model = pu.load_model(self.model_type, num_classes, model_directory, "train")

         def_training_args, def_freeze_layers = cu.get_default_train_args(
             model, self.classifier, train_data, output_directory
@@ -1238,11 +1276,7 @@ class Classifier:
         test_data = pu.load_and_filter(None, self.nproc, test_data_file)

         # load previously fine-tuned model
-        if self.classifier == "cell":
-            model_type = "CellClassifier"
-        elif self.classifier == "gene":
-            model_type = "GeneClassifier"
-        model = pu.load_model(model_type, num_classes, model_directory, "eval")
+        model = pu.load_model(self.model_type, num_classes, model_directory, "eval")

         # evaluate the model
         result = self.evaluate_model(