Christina Theodoris
committed on
Commit
•
3d06203
1
Parent(s):
3072225
Correct order of state dict in in silico perturber stats and tensor dims of alt state emb in in silico perturber
Browse files
geneformer/in_silico_perturber.py
CHANGED
@@ -266,7 +266,6 @@ def quant_cos_sims(model,
|
|
266 |
def cos_sim_shift(original_emb, minibatch_emb, alt_emb):
|
267 |
cos = torch.nn.CosineSimilarity(dim=2)
|
268 |
original_emb = torch.mean(original_emb,dim=0,keepdim=True)[None, :]
|
269 |
-
alt_emb = alt_emb[None, None, :]
|
270 |
origin_v_end = cos(original_emb,alt_emb)
|
271 |
perturb_v_end = cos(torch.mean(minibatch_emb,dim=1,keepdim=True),alt_emb)
|
272 |
return [(perturb_v_end-origin_v_end).to("cpu")]
|
@@ -483,7 +482,7 @@ class InSilicoPerturber:
|
|
483 |
"only outputs effect on cell embeddings.")
|
484 |
|
485 |
if self.cell_states_to_model is not None:
|
486 |
-
if
|
487 |
for key,value in self.cell_states_to_model.items():
|
488 |
if (len(value) == 3) and isinstance(value, tuple):
|
489 |
if isinstance(value[0],list) and isinstance(value[1],list) and isinstance(value[2],list):
|
|
|
266 |
def cos_sim_shift(original_emb, minibatch_emb, alt_emb):
|
267 |
cos = torch.nn.CosineSimilarity(dim=2)
|
268 |
original_emb = torch.mean(original_emb,dim=0,keepdim=True)[None, :]
|
|
|
269 |
origin_v_end = cos(original_emb,alt_emb)
|
270 |
perturb_v_end = cos(torch.mean(minibatch_emb,dim=1,keepdim=True),alt_emb)
|
271 |
return [(perturb_v_end-origin_v_end).to("cpu")]
|
|
|
482 |
"only outputs effect on cell embeddings.")
|
483 |
|
484 |
if self.cell_states_to_model is not None:
|
485 |
+
if len(self.cell_states_to_model.items()) == 1:
|
486 |
for key,value in self.cell_states_to_model.items():
|
487 |
if (len(value) == 3) and isinstance(value, tuple):
|
488 |
if isinstance(value[0],list) and isinstance(value[1],list) and isinstance(value[2],list):
|
geneformer/in_silico_perturber_stats.py
CHANGED
@@ -108,9 +108,10 @@ def get_impact_component(test_value, gaussian_mixture_model):
|
|
108 |
|
109 |
# stats comparing cos sim shifts towards goal state of test perturbations vs random perturbations
|
110 |
def isp_stats_to_goal_state(cos_sims_df, dict_list, cell_states_to_model):
|
111 |
-
|
|
|
112 |
alt_end_state_exists = False
|
113 |
-
elif (len(cell_states_to_model[
|
114 |
alt_end_state_exists = True
|
115 |
|
116 |
random_tuples = []
|
@@ -120,20 +121,15 @@ def isp_stats_to_goal_state(cos_sims_df, dict_list, cell_states_to_model):
|
|
120 |
random_tuples += dict_i.get((token, "cell_emb"),[])
|
121 |
|
122 |
if alt_end_state_exists == False:
|
123 |
-
goal_end_random_megalist = [goal_end for goal_end
|
124 |
-
start_state_random_megalist = [start_state for goal_end,start_state in random_tuples]
|
125 |
elif alt_end_state_exists == True:
|
126 |
-
goal_end_random_megalist = [goal_end for goal_end,alt_end
|
127 |
-
alt_end_random_megalist = [alt_end for goal_end,alt_end
|
128 |
-
start_state_random_megalist = [start_state for goal_end,alt_end,start_state in random_tuples]
|
129 |
|
130 |
# downsample to improve speed of ranksums
|
131 |
if len(goal_end_random_megalist) > 100_000:
|
132 |
random.seed(42)
|
133 |
goal_end_random_megalist = random.sample(goal_end_random_megalist, k=100_000)
|
134 |
-
if len(start_state_random_megalist) > 100_000:
|
135 |
-
random.seed(42)
|
136 |
-
start_state_random_megalist = random.sample(start_state_random_megalist, k=100_000)
|
137 |
if alt_end_state_exists == True:
|
138 |
if len(alt_end_random_megalist) > 100_000:
|
139 |
random.seed(42)
|
@@ -161,10 +157,10 @@ def isp_stats_to_goal_state(cos_sims_df, dict_list, cell_states_to_model):
|
|
161 |
cos_shift_data += dict_i.get((token, "cell_emb"),[])
|
162 |
|
163 |
if alt_end_state_exists == False:
|
164 |
-
goal_end_cos_sim_megalist = [goal_end for goal_end
|
165 |
elif alt_end_state_exists == True:
|
166 |
-
goal_end_cos_sim_megalist = [goal_end for goal_end,alt_end
|
167 |
-
alt_end_cos_sim_megalist = [alt_end for goal_end,alt_end
|
168 |
mean_alt_end = np.mean(alt_end_cos_sim_megalist)
|
169 |
pval_alt_end = ranksums(alt_end_random_megalist,alt_end_cos_sim_megalist).pvalue
|
170 |
|
@@ -451,7 +447,7 @@ class InSilicoPerturberStats:
|
|
451 |
raise
|
452 |
|
453 |
if self.cell_states_to_model is not None:
|
454 |
-
if
|
455 |
for key,value in self.cell_states_to_model.items():
|
456 |
if (len(value) == 3) and isinstance(value, tuple):
|
457 |
if isinstance(value[0],list) and isinstance(value[1],list) and isinstance(value[2],list):
|
|
|
108 |
|
109 |
# stats comparing cos sim shifts towards goal state of test perturbations vs random perturbations
|
110 |
def isp_stats_to_goal_state(cos_sims_df, dict_list, cell_states_to_model):
|
111 |
+
cell_state_key = list(cell_states_to_model.keys())[0]
|
112 |
+
if cell_states_to_model[cell_state_key][2] == []:
|
113 |
alt_end_state_exists = False
|
114 |
+
elif (len(cell_states_to_model[cell_state_key][2]) > 0) and (cell_states_to_model[cell_state_key][2] != [None]):
|
115 |
alt_end_state_exists = True
|
116 |
|
117 |
random_tuples = []
|
|
|
121 |
random_tuples += dict_i.get((token, "cell_emb"),[])
|
122 |
|
123 |
if alt_end_state_exists == False:
|
124 |
+
goal_end_random_megalist = [goal_end for start_state,goal_end in random_tuples]
|
|
|
125 |
elif alt_end_state_exists == True:
|
126 |
+
goal_end_random_megalist = [goal_end for start_state,goal_end,alt_end in random_tuples]
|
127 |
+
alt_end_random_megalist = [alt_end for start_state,goal_end,alt_end in random_tuples]
|
|
|
128 |
|
129 |
# downsample to improve speed of ranksums
|
130 |
if len(goal_end_random_megalist) > 100_000:
|
131 |
random.seed(42)
|
132 |
goal_end_random_megalist = random.sample(goal_end_random_megalist, k=100_000)
|
|
|
|
|
|
|
133 |
if alt_end_state_exists == True:
|
134 |
if len(alt_end_random_megalist) > 100_000:
|
135 |
random.seed(42)
|
|
|
157 |
cos_shift_data += dict_i.get((token, "cell_emb"),[])
|
158 |
|
159 |
if alt_end_state_exists == False:
|
160 |
+
goal_end_cos_sim_megalist = [goal_end for start_state,goal_end in cos_shift_data]
|
161 |
elif alt_end_state_exists == True:
|
162 |
+
goal_end_cos_sim_megalist = [goal_end for start_state,goal_end,alt_end in cos_shift_data]
|
163 |
+
alt_end_cos_sim_megalist = [alt_end for start_state,goal_end,alt_end in cos_shift_data]
|
164 |
mean_alt_end = np.mean(alt_end_cos_sim_megalist)
|
165 |
pval_alt_end = ranksums(alt_end_random_megalist,alt_end_cos_sim_megalist).pvalue
|
166 |
|
|
|
447 |
raise
|
448 |
|
449 |
if self.cell_states_to_model is not None:
|
450 |
+
if len(self.cell_states_to_model.items()) == 1:
|
451 |
for key,value in self.cell_states_to_model.items():
|
452 |
if (len(value) == 3) and isinstance(value, tuple):
|
453 |
if isinstance(value[0],list) and isinstance(value[1],list) and isinstance(value[2],list):
|