Christina Theodoris committed on
Commit
3d06203
1 Parent(s): 3072225

Correct order of state dict in the in silico perturber stats and tensor dims of alt state emb in the in silico perturber

Browse files
geneformer/in_silico_perturber.py CHANGED
@@ -266,7 +266,6 @@ def quant_cos_sims(model,
266
  def cos_sim_shift(original_emb, minibatch_emb, alt_emb):
267
  cos = torch.nn.CosineSimilarity(dim=2)
268
  original_emb = torch.mean(original_emb,dim=0,keepdim=True)[None, :]
269
- alt_emb = alt_emb[None, None, :]
270
  origin_v_end = cos(original_emb,alt_emb)
271
  perturb_v_end = cos(torch.mean(minibatch_emb,dim=1,keepdim=True),alt_emb)
272
  return [(perturb_v_end-origin_v_end).to("cpu")]
@@ -483,7 +482,7 @@ class InSilicoPerturber:
483
  "only outputs effect on cell embeddings.")
484
 
485
  if self.cell_states_to_model is not None:
486
- if (len(self.cell_states_to_model.items()) == 1):
487
  for key,value in self.cell_states_to_model.items():
488
  if (len(value) == 3) and isinstance(value, tuple):
489
  if isinstance(value[0],list) and isinstance(value[1],list) and isinstance(value[2],list):
 
266
  def cos_sim_shift(original_emb, minibatch_emb, alt_emb):
267
  cos = torch.nn.CosineSimilarity(dim=2)
268
  original_emb = torch.mean(original_emb,dim=0,keepdim=True)[None, :]
 
269
  origin_v_end = cos(original_emb,alt_emb)
270
  perturb_v_end = cos(torch.mean(minibatch_emb,dim=1,keepdim=True),alt_emb)
271
  return [(perturb_v_end-origin_v_end).to("cpu")]
 
482
  "only outputs effect on cell embeddings.")
483
 
484
  if self.cell_states_to_model is not None:
485
+ if len(self.cell_states_to_model.items()) == 1:
486
  for key,value in self.cell_states_to_model.items():
487
  if (len(value) == 3) and isinstance(value, tuple):
488
  if isinstance(value[0],list) and isinstance(value[1],list) and isinstance(value[2],list):
geneformer/in_silico_perturber_stats.py CHANGED
@@ -108,9 +108,10 @@ def get_impact_component(test_value, gaussian_mixture_model):
108
 
109
  # stats comparing cos sim shifts towards goal state of test perturbations vs random perturbations
110
  def isp_stats_to_goal_state(cos_sims_df, dict_list, cell_states_to_model):
111
- if cell_states_to_model["disease"][2] == []:
 
112
  alt_end_state_exists = False
113
- elif (len(cell_states_to_model["disease"][2]) > 0) & (cell_states_to_model["disease"][2] != [None]):
114
  alt_end_state_exists = True
115
 
116
  random_tuples = []
@@ -120,20 +121,15 @@ def isp_stats_to_goal_state(cos_sims_df, dict_list, cell_states_to_model):
120
  random_tuples += dict_i.get((token, "cell_emb"),[])
121
 
122
  if alt_end_state_exists == False:
123
- goal_end_random_megalist = [goal_end for goal_end,start_state in random_tuples]
124
- start_state_random_megalist = [start_state for goal_end,start_state in random_tuples]
125
  elif alt_end_state_exists == True:
126
- goal_end_random_megalist = [goal_end for goal_end,alt_end,start_state in random_tuples]
127
- alt_end_random_megalist = [alt_end for goal_end,alt_end,start_state in random_tuples]
128
- start_state_random_megalist = [start_state for goal_end,alt_end,start_state in random_tuples]
129
 
130
  # downsample to improve speed of ranksums
131
  if len(goal_end_random_megalist) > 100_000:
132
  random.seed(42)
133
  goal_end_random_megalist = random.sample(goal_end_random_megalist, k=100_000)
134
- if len(start_state_random_megalist) > 100_000:
135
- random.seed(42)
136
- start_state_random_megalist = random.sample(start_state_random_megalist, k=100_000)
137
  if alt_end_state_exists == True:
138
  if len(alt_end_random_megalist) > 100_000:
139
  random.seed(42)
@@ -161,10 +157,10 @@ def isp_stats_to_goal_state(cos_sims_df, dict_list, cell_states_to_model):
161
  cos_shift_data += dict_i.get((token, "cell_emb"),[])
162
 
163
  if alt_end_state_exists == False:
164
- goal_end_cos_sim_megalist = [goal_end for goal_end,start_state in cos_shift_data]
165
  elif alt_end_state_exists == True:
166
- goal_end_cos_sim_megalist = [goal_end for goal_end,alt_end,start_state in cos_shift_data]
167
- alt_end_cos_sim_megalist = [alt_end for goal_end,alt_end,start_state in cos_shift_data]
168
  mean_alt_end = np.mean(alt_end_cos_sim_megalist)
169
  pval_alt_end = ranksums(alt_end_random_megalist,alt_end_cos_sim_megalist).pvalue
170
 
@@ -451,7 +447,7 @@ class InSilicoPerturberStats:
451
  raise
452
 
453
  if self.cell_states_to_model is not None:
454
- if (len(self.cell_states_to_model.items()) == 1):
455
  for key,value in self.cell_states_to_model.items():
456
  if (len(value) == 3) and isinstance(value, tuple):
457
  if isinstance(value[0],list) and isinstance(value[1],list) and isinstance(value[2],list):
 
108
 
109
  # stats comparing cos sim shifts towards goal state of test perturbations vs random perturbations
110
  def isp_stats_to_goal_state(cos_sims_df, dict_list, cell_states_to_model):
111
+ cell_state_key = list(cell_states_to_model.keys())[0]
112
+ if cell_states_to_model[cell_state_key][2] == []:
113
  alt_end_state_exists = False
114
+ elif (len(cell_states_to_model[cell_state_key][2]) > 0) and (cell_states_to_model[cell_state_key][2] != [None]):
115
  alt_end_state_exists = True
116
 
117
  random_tuples = []
 
121
  random_tuples += dict_i.get((token, "cell_emb"),[])
122
 
123
  if alt_end_state_exists == False:
124
+ goal_end_random_megalist = [goal_end for start_state,goal_end in random_tuples]
 
125
  elif alt_end_state_exists == True:
126
+ goal_end_random_megalist = [goal_end for start_state,goal_end,alt_end in random_tuples]
127
+ alt_end_random_megalist = [alt_end for start_state,goal_end,alt_end in random_tuples]
 
128
 
129
  # downsample to improve speed of ranksums
130
  if len(goal_end_random_megalist) > 100_000:
131
  random.seed(42)
132
  goal_end_random_megalist = random.sample(goal_end_random_megalist, k=100_000)
 
 
 
133
  if alt_end_state_exists == True:
134
  if len(alt_end_random_megalist) > 100_000:
135
  random.seed(42)
 
157
  cos_shift_data += dict_i.get((token, "cell_emb"),[])
158
 
159
  if alt_end_state_exists == False:
160
+ goal_end_cos_sim_megalist = [goal_end for start_state,goal_end in cos_shift_data]
161
  elif alt_end_state_exists == True:
162
+ goal_end_cos_sim_megalist = [goal_end for start_state,goal_end,alt_end in cos_shift_data]
163
+ alt_end_cos_sim_megalist = [alt_end for start_state,goal_end,alt_end in cos_shift_data]
164
  mean_alt_end = np.mean(alt_end_cos_sim_megalist)
165
  pval_alt_end = ranksums(alt_end_random_megalist,alt_end_cos_sim_megalist).pvalue
166
 
 
447
  raise
448
 
449
  if self.cell_states_to_model is not None:
450
+ if len(self.cell_states_to_model.items()) == 1:
451
  for key,value in self.cell_states_to_model.items():
452
  if (len(value) == 3) and isinstance(value, tuple):
453
  if isinstance(value[0],list) and isinstance(value[1],list) and isinstance(value[2],list):