embs_df with all model embeddings

#363
Files changed (1) hide show
  1. geneformer/emb_extractor.py +6 -14
geneformer/emb_extractor.py CHANGED
@@ -49,10 +49,8 @@ def get_embs(
49
  if summary_stat is None:
50
  embs_list = []
51
  elif summary_stat is not None:
52
- # test embedding extraction for example cell and extract # emb dims
53
- example = filtered_input_data.select([i for i in range(1)])
54
- example.set_format(type="torch")
55
- emb_dims = test_emb(model, example["input_ids"], layer_to_quant)
56
  if emb_mode == "cell":
57
  # initiate tdigests for # of emb dims
58
  embs_tdigests = [TDigest() for _ in range(emb_dims)]
@@ -239,14 +237,6 @@ def tdigest_median(embs_tdigests, emb_dims):
239
  return [embs_tdigests[i].percentile(50) for i in range(emb_dims)]
240
 
241
 
242
- def test_emb(model, example, layer_to_quant):
243
- with torch.no_grad():
244
- outputs = model(input_ids=example.to("cuda"))
245
-
246
- embs_test = outputs.hidden_states[layer_to_quant]
247
- return embs_test.size()[2]
248
-
249
-
250
  def label_cell_embs(embs, downsampled_data, emb_labels):
251
  embs_df = pd.DataFrame(embs.cpu().numpy())
252
  if emb_labels is not None:
@@ -632,13 +622,15 @@ class EmbExtractor:
632
 
633
  if self.exact_summary_stat == "exact_mean":
634
  embs = embs.mean(dim=0)
 
635
  embs_df = pd.DataFrame(
636
- embs_df[0:255].mean(axis="rows"), columns=[self.exact_summary_stat]
637
  ).T
638
  elif self.exact_summary_stat == "exact_median":
639
  embs = torch.median(embs, dim=0)[0]
 
640
  embs_df = pd.DataFrame(
641
- embs_df[0:255].median(axis="rows"), columns=[self.exact_summary_stat]
642
  ).T
643
 
644
  if cell_state is not None:
 
49
  if summary_stat is None:
50
  embs_list = []
51
  elif summary_stat is not None:
52
+ # get # of emb dims
53
+ emb_dims = pu.get_model_emb_dims(model)
 
 
54
  if emb_mode == "cell":
55
  # initiate tdigests for # of emb dims
56
  embs_tdigests = [TDigest() for _ in range(emb_dims)]
 
237
  return [embs_tdigests[i].percentile(50) for i in range(emb_dims)]
238
 
239
 
 
 
 
 
 
 
 
 
240
  def label_cell_embs(embs, downsampled_data, emb_labels):
241
  embs_df = pd.DataFrame(embs.cpu().numpy())
242
  if emb_labels is not None:
 
622
 
623
  if self.exact_summary_stat == "exact_mean":
624
  embs = embs.mean(dim=0)
625
+ emb_dims = pu.get_model_embedding_dimensions(model)
626
  embs_df = pd.DataFrame(
627
+ embs_df[0:emb_dims-1].mean(axis="rows"), columns=[self.exact_summary_stat]
628
  ).T
629
  elif self.exact_summary_stat == "exact_median":
630
  embs = torch.median(embs, dim=0)[0]
631
+ emb_dims = pu.get_model_embedding_dimensions(model)
632
  embs_df = pd.DataFrame(
633
+ embs_df[0:emb_dims-1].median(axis="rows"), columns=[self.exact_summary_stat]
634
  ).T
635
 
636
  if cell_state is not None: