Christina Theodoris commited on
Commit
a13a2cf
·
1 Parent(s): cfc8cdb

fix exact_mean subselection to be for index rather than col name

Browse files
Files changed (1) hide show
  1. geneformer/emb_extractor.py +2 -6
geneformer/emb_extractor.py CHANGED
@@ -644,19 +644,15 @@ class EmbExtractor:
644
  if self.exact_summary_stat == "exact_mean":
645
  embs = embs.mean(dim=0)
646
  emb_dims = pu.get_model_emb_dims(model)
647
- print(embs_df.shape)
648
- print(embs_df)
649
- print("#######")
650
- print(embs_df.iloc[:, 0 : emb_dims - 1])
651
  embs_df = pd.DataFrame(
652
- embs_df.iloc[:, 0 : emb_dims - 1].mean(axis="rows"),
653
  columns=[self.exact_summary_stat],
654
  ).T
655
  elif self.exact_summary_stat == "exact_median":
656
  embs = torch.median(embs, dim=0)[0]
657
  emb_dims = pu.get_model_emb_dims(model)
658
  embs_df = pd.DataFrame(
659
- embs_df.iloc[:, 0 : emb_dims - 1].median(axis="rows"),
660
  columns=[self.exact_summary_stat],
661
  ).T
662
 
 
644
  if self.exact_summary_stat == "exact_mean":
645
  embs = embs.mean(dim=0)
646
  emb_dims = pu.get_model_emb_dims(model)
 
 
 
 
647
  embs_df = pd.DataFrame(
648
+ embs_df.iloc[:, 0 : emb_dims].mean(axis="rows"),
649
  columns=[self.exact_summary_stat],
650
  ).T
651
  elif self.exact_summary_stat == "exact_median":
652
  embs = torch.median(embs, dim=0)[0]
653
  emb_dims = pu.get_model_emb_dims(model)
654
  embs_df = pd.DataFrame(
655
+ embs_df.iloc[:, 0 : emb_dims].median(axis="rows"),
656
  columns=[self.exact_summary_stat],
657
  ).T
658