ctheodoris commited on
Commit
ace12e9
1 Parent(s): 3fe35ba

update get_embs with token_gene_dict arg

Browse files
Files changed (1) hide show
  1. geneformer/in_silico_perturber.py +5 -0
geneformer/in_silico_perturber.py CHANGED
@@ -228,6 +228,7 @@ class InSilicoPerturber:
228
  # load token dictionary (Ensembl IDs:token)
229
  with open(token_dictionary_file, "rb") as f:
230
  self.gene_token_dict = pickle.load(f)
 
231
 
232
  self.pad_token_id = self.gene_token_dict.get("<pad>")
233
 
@@ -560,6 +561,7 @@ class InSilicoPerturber:
560
  layer_to_quant,
561
  self.pad_token_id,
562
  self.forward_batch_size,
 
563
  summary_stat=None,
564
  silent=True,
565
  )
@@ -579,6 +581,7 @@ class InSilicoPerturber:
579
  layer_to_quant,
580
  self.pad_token_id,
581
  self.forward_batch_size,
 
582
  summary_stat=None,
583
  silent=True,
584
  )
@@ -738,6 +741,7 @@ class InSilicoPerturber:
738
  layer_to_quant,
739
  self.pad_token_id,
740
  self.forward_batch_size,
 
741
  summary_stat=None,
742
  silent=True,
743
  )
@@ -765,6 +769,7 @@ class InSilicoPerturber:
765
  layer_to_quant,
766
  self.pad_token_id,
767
  self.forward_batch_size,
 
768
  summary_stat=None,
769
  silent=True,
770
  )
 
228
  # load token dictionary (Ensembl IDs:token)
229
  with open(token_dictionary_file, "rb") as f:
230
  self.gene_token_dict = pickle.load(f)
231
+ self.token_gene_dict = {v: k for k, v in self.gene_token_dict.items()}
232
 
233
  self.pad_token_id = self.gene_token_dict.get("<pad>")
234
 
 
561
  layer_to_quant,
562
  self.pad_token_id,
563
  self.forward_batch_size,
564
+ token_gene_dict=self.token_gene_dict,
565
  summary_stat=None,
566
  silent=True,
567
  )
 
581
  layer_to_quant,
582
  self.pad_token_id,
583
  self.forward_batch_size,
584
+ token_gene_dict=self.token_gene_dict,
585
  summary_stat=None,
586
  silent=True,
587
  )
 
741
  layer_to_quant,
742
  self.pad_token_id,
743
  self.forward_batch_size,
744
+ token_gene_dict=self.token_gene_dict,
745
  summary_stat=None,
746
  silent=True,
747
  )
 
769
  layer_to_quant,
770
  self.pad_token_id,
771
  self.forward_batch_size,
772
+ token_gene_dict=self.token_gene_dict,
773
  summary_stat=None,
774
  silent=True,
775
  )