Christina Theodoris commited on
Commit
268e566
1 Parent(s): 57b9778

Fix min_genes to be >= tokens to perturb as a group

Browse files
Files changed (1) hide show
  1. geneformer/in_silico_perturber.py +36 -20
geneformer/in_silico_perturber.py CHANGED
@@ -58,6 +58,16 @@ def measure_length(example):
58
  example["length"] = len(example["input_ids"])
59
  return example
60
 
 
 
 
 
 
 
 
 
 
 
61
  def forward_pass_single_cell(model, example_cell, layer_to_quant):
62
  example_cell.set_format(type="torch")
63
  input_data = example_cell["input_ids"]
@@ -75,8 +85,8 @@ def perturb_emb_by_index(emb, indices):
75
  return emb[mask]
76
 
77
  def delete_indices(example):
78
- indices = example["perturb_index"]
79
- if len(indices)>1:
80
  indices = flatten_list(indices)
81
  for index in sorted(indices, reverse=True):
82
  del example["input_ids"][index]
@@ -84,10 +94,10 @@ def delete_indices(example):
84
 
85
  # for genes_to_perturb = "all" where only genes within cell are overexpressed
86
  def overexpress_indices(example):
87
- indexes = example["perturb_index"]
88
- if len(indexes)>1:
89
- indexes = flatten_list(indexes)
90
- for index in sorted(indexes, reverse=True):
91
  example["input_ids"].insert(0, example["input_ids"].pop(index))
92
  return example
93
 
@@ -165,7 +175,7 @@ def make_comparison_batch(original_emb_batch, indices_to_perturb, perturb_group)
165
  continue
166
  emb_list = []
167
  start = 0
168
- if len(indices)>1 and isinstance(indices[0],list):
169
  indices = flatten_list(indices)
170
  for i in sorted(indices):
171
  emb_list += [original_emb[start:i]]
@@ -724,8 +734,9 @@ class InSilicoPerturber:
724
  state_embs_dict = None
725
  else:
726
  # get dictionary of average cell state embeddings for comparison
 
727
  state_embs_dict = get_cell_state_avg_embs(model,
728
- filtered_input_data,
729
  self.cell_states_to_model,
730
  layer_to_quant,
731
  self.pad_token_id,
@@ -758,14 +769,7 @@ class InSilicoPerturber:
758
  "No cells remain after filtering. Check filtering criteria.")
759
  raise
760
  data_shuffled = data.shuffle(seed=42)
761
- num_cells = len(data_shuffled)
762
- # if max number of cells is defined, then subsample to this max number
763
- if self.max_ncells != None:
764
- num_cells = min(self.max_ncells,num_cells)
765
- data_subset = data_shuffled.select([i for i in range(num_cells)])
766
- # sort dataset with largest cell first to encounter any memory errors earlier
767
- data_sorted = data_subset.sort("length",reverse=True)
768
- return data_sorted
769
 
770
  # load model to GPU
771
  def load_model(self, model_directory):
@@ -804,17 +808,29 @@ class InSilicoPerturber:
804
  if self.anchor_token is not None:
805
  def if_has_tokens_to_perturb(example):
806
  return (len(set(example["input_ids"]).intersection(self.anchor_token))==len(self.anchor_token))
807
- filtered_input_data = filtered_input_data.filter(if_has_tokens_to_perturb, num_proc=self.nproc)
808
- logger.info(f"# cells with anchor gene: {len(filtered_input_data)}")
 
 
 
 
 
 
809
  if (self.tokens_to_perturb != "all") and (self.perturb_type != "overexpress"):
810
  # minimum # genes needed for perturbation test
811
  min_genes = len(self.tokens_to_perturb)
 
812
  def if_has_tokens_to_perturb(example):
813
- return (len(set(example["input_ids"]).intersection(self.tokens_to_perturb))>min_genes)
814
  filtered_input_data = filtered_input_data.filter(if_has_tokens_to_perturb, num_proc=self.nproc)
815
-
 
 
 
 
816
  cos_sims_dict = defaultdict(list)
817
  pickle_batch = -1
 
818
 
819
  # make perturbation batch w/ single perturbation in multiple cells
820
  if self.perturb_group == True:
 
58
  example["length"] = len(example["input_ids"])
59
  return example
60
 
61
+ def downsample_and_sort(data_shuffled, max_ncells):
62
+ num_cells = len(data_shuffled)
63
+ # if max number of cells is defined, then subsample to this max number
64
+ if max_ncells != None:
65
+ num_cells = min(max_ncells,num_cells)
66
+ data_subset = data_shuffled.select([i for i in range(num_cells)])
67
+ # sort dataset with largest cell first to encounter any memory errors earlier
68
+ data_sorted = data_subset.sort("length",reverse=True)
69
+ return data_sorted
70
+
71
  def forward_pass_single_cell(model, example_cell, layer_to_quant):
72
  example_cell.set_format(type="torch")
73
  input_data = example_cell["input_ids"]
 
85
  return emb[mask]
86
 
87
  def delete_indices(example):
88
+ indices = example["perturb_index"]
89
+ if any(isinstance(el, list) for el in indices):
90
  indices = flatten_list(indices)
91
  for index in sorted(indices, reverse=True):
92
  del example["input_ids"][index]
 
94
 
95
  # for genes_to_perturb = "all" where only genes within cell are overexpressed
96
  def overexpress_indices(example):
97
+ indices = example["perturb_index"]
98
+ if any(isinstance(el, list) for el in indices):
99
+ indices = flatten_list(indices)
100
+ for index in sorted(indices, reverse=True):
101
  example["input_ids"].insert(0, example["input_ids"].pop(index))
102
  return example
103
 
 
175
  continue
176
  emb_list = []
177
  start = 0
178
+ if any(isinstance(el, list) for el in indices):
179
  indices = flatten_list(indices)
180
  for i in sorted(indices):
181
  emb_list += [original_emb[start:i]]
 
734
  state_embs_dict = None
735
  else:
736
  # get dictionary of average cell state embeddings for comparison
737
+ downsampled_data = downsample_and_sort(filtered_input_data, self.max_ncells)
738
  state_embs_dict = get_cell_state_avg_embs(model,
739
+ downsampled_data,
740
  self.cell_states_to_model,
741
  layer_to_quant,
742
  self.pad_token_id,
 
769
  "No cells remain after filtering. Check filtering criteria.")
770
  raise
771
  data_shuffled = data.shuffle(seed=42)
772
+ return data_shuffled
 
 
 
 
 
 
 
773
 
774
  # load model to GPU
775
  def load_model(self, model_directory):
 
808
  if self.anchor_token is not None:
809
  def if_has_tokens_to_perturb(example):
810
  return (len(set(example["input_ids"]).intersection(self.anchor_token))==len(self.anchor_token))
811
+ filtered_input_data = filtered_input_data.filter(if_has_tokens_to_perturb, num_proc=self.nproc)
812
+ if len(filtered_input_data) == 0:
813
+ logger.error(
814
+ "No cells in dataset contain anchor gene.")
815
+ raise
816
+ else:
817
+ logger.info(f"# cells with anchor gene: {len(filtered_input_data)}")
818
+
819
  if (self.tokens_to_perturb != "all") and (self.perturb_type != "overexpress"):
820
  # minimum # genes needed for perturbation test
821
  min_genes = len(self.tokens_to_perturb)
822
+
823
  def if_has_tokens_to_perturb(example):
824
+ return (len(set(example["input_ids"]).intersection(self.tokens_to_perturb))>=min_genes)
825
  filtered_input_data = filtered_input_data.filter(if_has_tokens_to_perturb, num_proc=self.nproc)
826
+ if len(filtered_input_data) == 0:
827
+ logger.error(
828
+ "No cells in dataset contain all genes to perturb as a group.")
829
+ raise
830
+
831
  cos_sims_dict = defaultdict(list)
832
  pickle_batch = -1
833
+ filtered_input_data = downsample_and_sort(filtered_input_data, self.max_ncells)
834
 
835
  # make perturbation batch w/ single perturbation in multiple cells
836
  if self.perturb_group == True: