embs_df with all model embeddings
#363
by
hchen725
- opened
- geneformer/emb_extractor.py +6 -14
geneformer/emb_extractor.py
CHANGED
@@ -49,10 +49,8 @@ def get_embs(
|
|
49 |
if summary_stat is None:
|
50 |
embs_list = []
|
51 |
elif summary_stat is not None:
|
52 |
-
#
|
53 |
-
|
54 |
-
example.set_format(type="torch")
|
55 |
-
emb_dims = test_emb(model, example["input_ids"], layer_to_quant)
|
56 |
if emb_mode == "cell":
|
57 |
# initiate tdigests for # of emb dims
|
58 |
embs_tdigests = [TDigest() for _ in range(emb_dims)]
|
@@ -239,14 +237,6 @@ def tdigest_median(embs_tdigests, emb_dims):
|
|
239 |
return [embs_tdigests[i].percentile(50) for i in range(emb_dims)]
|
240 |
|
241 |
|
242 |
-
def test_emb(model, example, layer_to_quant):
|
243 |
-
with torch.no_grad():
|
244 |
-
outputs = model(input_ids=example.to("cuda"))
|
245 |
-
|
246 |
-
embs_test = outputs.hidden_states[layer_to_quant]
|
247 |
-
return embs_test.size()[2]
|
248 |
-
|
249 |
-
|
250 |
def label_cell_embs(embs, downsampled_data, emb_labels):
|
251 |
embs_df = pd.DataFrame(embs.cpu().numpy())
|
252 |
if emb_labels is not None:
|
@@ -632,13 +622,15 @@ class EmbExtractor:
|
|
632 |
|
633 |
if self.exact_summary_stat == "exact_mean":
|
634 |
embs = embs.mean(dim=0)
|
|
|
635 |
embs_df = pd.DataFrame(
|
636 |
-
embs_df[0:
|
637 |
).T
|
638 |
elif self.exact_summary_stat == "exact_median":
|
639 |
embs = torch.median(embs, dim=0)[0]
|
|
|
640 |
embs_df = pd.DataFrame(
|
641 |
-
embs_df[0:
|
642 |
).T
|
643 |
|
644 |
if cell_state is not None:
|
|
|
49 |
if summary_stat is None:
|
50 |
embs_list = []
|
51 |
elif summary_stat is not None:
|
52 |
+
# get # of emb dims
|
53 |
+
emb_dims = pu.get_model_emb_dims(model)
|
|
|
|
|
54 |
if emb_mode == "cell":
|
55 |
# initiate tdigests for # of emb dims
|
56 |
embs_tdigests = [TDigest() for _ in range(emb_dims)]
|
|
|
237 |
return [embs_tdigests[i].percentile(50) for i in range(emb_dims)]
|
238 |
|
239 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
240 |
def label_cell_embs(embs, downsampled_data, emb_labels):
|
241 |
embs_df = pd.DataFrame(embs.cpu().numpy())
|
242 |
if emb_labels is not None:
|
|
|
622 |
|
623 |
if self.exact_summary_stat == "exact_mean":
|
624 |
embs = embs.mean(dim=0)
|
625 |
+
emb_dims = pu.get_model_embedding_dimensions(model)
|
626 |
embs_df = pd.DataFrame(
|
627 |
+
embs_df[0:emb_dims-1].mean(axis="rows"), columns=[self.exact_summary_stat]
|
628 |
).T
|
629 |
elif self.exact_summary_stat == "exact_median":
|
630 |
embs = torch.median(embs, dim=0)[0]
|
631 |
+
emb_dims = pu.get_model_embedding_dimensions(model)
|
632 |
embs_df = pd.DataFrame(
|
633 |
+
embs_df[0:emb_dims-1].median(axis="rows"), columns=[self.exact_summary_stat]
|
634 |
).T
|
635 |
|
636 |
if cell_state is not None:
|