Christina Theodoris
commited on
Commit
·
a13a2cf
1
Parent(s):
cfc8cdb
fix exact_mean subselection to be for index rather than col name
Browse files
geneformer/emb_extractor.py
CHANGED
@@ -644,19 +644,15 @@ class EmbExtractor:
|
|
644 |
if self.exact_summary_stat == "exact_mean":
|
645 |
embs = embs.mean(dim=0)
|
646 |
emb_dims = pu.get_model_emb_dims(model)
|
647 |
-
print(embs_df.shape)
|
648 |
-
print(embs_df)
|
649 |
-
print("#######")
|
650 |
-
print(embs_df.iloc[:, 0 : emb_dims - 1])
|
651 |
embs_df = pd.DataFrame(
|
652 |
-
embs_df.iloc[:, 0 : emb_dims
|
653 |
columns=[self.exact_summary_stat],
|
654 |
).T
|
655 |
elif self.exact_summary_stat == "exact_median":
|
656 |
embs = torch.median(embs, dim=0)[0]
|
657 |
emb_dims = pu.get_model_emb_dims(model)
|
658 |
embs_df = pd.DataFrame(
|
659 |
-
embs_df.iloc[:, 0 : emb_dims
|
660 |
columns=[self.exact_summary_stat],
|
661 |
).T
|
662 |
|
|
|
644 |
if self.exact_summary_stat == "exact_mean":
|
645 |
embs = embs.mean(dim=0)
|
646 |
emb_dims = pu.get_model_emb_dims(model)
|
|
|
|
|
|
|
|
|
647 |
embs_df = pd.DataFrame(
|
648 |
+
embs_df.iloc[:, 0 : emb_dims].mean(axis="rows"),
|
649 |
columns=[self.exact_summary_stat],
|
650 |
).T
|
651 |
elif self.exact_summary_stat == "exact_median":
|
652 |
embs = torch.median(embs, dim=0)[0]
|
653 |
emb_dims = pu.get_model_emb_dims(model)
|
654 |
embs_df = pd.DataFrame(
|
655 |
+
embs_df.iloc[:, 0 : emb_dims].median(axis="rows"),
|
656 |
columns=[self.exact_summary_stat],
|
657 |
).T
|
658 |
|