ctheodoris
commited on
Commit
•
ace12e9
1
Parent(s):
3fe35ba
update get_embs with token_gene_dict arg
Browse files
geneformer/in_silico_perturber.py
CHANGED
@@ -228,6 +228,7 @@ class InSilicoPerturber:
|
|
228 |
# load token dictionary (Ensembl IDs:token)
|
229 |
with open(token_dictionary_file, "rb") as f:
|
230 |
self.gene_token_dict = pickle.load(f)
|
|
|
231 |
|
232 |
self.pad_token_id = self.gene_token_dict.get("<pad>")
|
233 |
|
@@ -560,6 +561,7 @@ class InSilicoPerturber:
|
|
560 |
layer_to_quant,
|
561 |
self.pad_token_id,
|
562 |
self.forward_batch_size,
|
|
|
563 |
summary_stat=None,
|
564 |
silent=True,
|
565 |
)
|
@@ -579,6 +581,7 @@ class InSilicoPerturber:
|
|
579 |
layer_to_quant,
|
580 |
self.pad_token_id,
|
581 |
self.forward_batch_size,
|
|
|
582 |
summary_stat=None,
|
583 |
silent=True,
|
584 |
)
|
@@ -738,6 +741,7 @@ class InSilicoPerturber:
|
|
738 |
layer_to_quant,
|
739 |
self.pad_token_id,
|
740 |
self.forward_batch_size,
|
|
|
741 |
summary_stat=None,
|
742 |
silent=True,
|
743 |
)
|
@@ -765,6 +769,7 @@ class InSilicoPerturber:
|
|
765 |
layer_to_quant,
|
766 |
self.pad_token_id,
|
767 |
self.forward_batch_size,
|
|
|
768 |
summary_stat=None,
|
769 |
silent=True,
|
770 |
)
|
|
|
228 |
# load token dictionary (Ensembl IDs:token)
|
229 |
with open(token_dictionary_file, "rb") as f:
|
230 |
self.gene_token_dict = pickle.load(f)
|
231 |
+
self.token_gene_dict = {v: k for k, v in self.gene_token_dict.items()}
|
232 |
|
233 |
self.pad_token_id = self.gene_token_dict.get("<pad>")
|
234 |
|
|
|
561 |
layer_to_quant,
|
562 |
self.pad_token_id,
|
563 |
self.forward_batch_size,
|
564 |
+
token_gene_dict=self.token_gene_dict,
|
565 |
summary_stat=None,
|
566 |
silent=True,
|
567 |
)
|
|
|
581 |
layer_to_quant,
|
582 |
self.pad_token_id,
|
583 |
self.forward_batch_size,
|
584 |
+
token_gene_dict=self.token_gene_dict,
|
585 |
summary_stat=None,
|
586 |
silent=True,
|
587 |
)
|
|
|
741 |
layer_to_quant,
|
742 |
self.pad_token_id,
|
743 |
self.forward_batch_size,
|
744 |
+
token_gene_dict=self.token_gene_dict,
|
745 |
summary_stat=None,
|
746 |
silent=True,
|
747 |
)
|
|
|
769 |
layer_to_quant,
|
770 |
self.pad_token_id,
|
771 |
self.forward_batch_size,
|
772 |
+
token_gene_dict=self.token_gene_dict,
|
773 |
summary_stat=None,
|
774 |
silent=True,
|
775 |
)
|