ctheodoris
commited on
Commit
•
ea428cb
1
Parent(s):
eb2a04b
move dicts to init
Browse files- geneformer/__init__.py +7 -1
- geneformer/classifier.py +1 -1
- geneformer/collator_for_classification.py +6 -1
- geneformer/emb_extractor.py +1 -1
- geneformer/evaluation_utils.py +1 -1
- geneformer/in_silico_perturber.py +3 -7
- geneformer/in_silico_perturber_stats.py +2 -4
- geneformer/perturber_utils.py +1 -5
- geneformer/pretrainer.py +1 -1
- geneformer/tokenizer.py +1 -1
geneformer/__init__.py
CHANGED
@@ -1,4 +1,10 @@
|
|
1 |
# ruff: noqa: F401
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
from . import (
|
3 |
collator_for_classification,
|
4 |
emb_extractor,
|
@@ -18,4 +24,4 @@ from .pretrainer import GeneformerPretrainer
|
|
18 |
from .tokenizer import TranscriptomeTokenizer
|
19 |
|
20 |
from . import classifier # noqa # isort:skip
|
21 |
-
from .classifier import Classifier # noqa # isort:skip
|
|
|
1 |
# ruff: noqa: F401
|
2 |
+
from pathlib import Path
|
3 |
+
|
4 |
+
GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary.pkl"
|
5 |
+
TOKEN_DICTIONARY_FILE = Path(__file__).parent / "token_dictionary.pkl"
|
6 |
+
ENSEMBL_DICTIONARY_FILE = Path(__file__).parent / "gene_name_id_dict.pkl"
|
7 |
+
|
8 |
from . import (
|
9 |
collator_for_classification,
|
10 |
emb_extractor,
|
|
|
24 |
from .tokenizer import TranscriptomeTokenizer
|
25 |
|
26 |
from . import classifier # noqa # isort:skip
|
27 |
+
from .classifier import Classifier # noqa # isort:skip
|
geneformer/classifier.py
CHANGED
@@ -61,7 +61,7 @@ from . import DataCollatorForCellClassification, DataCollatorForGeneClassificati
|
|
61 |
from . import classifier_utils as cu
|
62 |
from . import evaluation_utils as eu
|
63 |
from . import perturber_utils as pu
|
64 |
-
from .
|
65 |
|
66 |
sns.set()
|
67 |
|
|
|
61 |
from . import classifier_utils as cu
|
62 |
from . import evaluation_utils as eu
|
63 |
from . import perturber_utils as pu
|
64 |
+
from . import TOKEN_DICTIONARY_FILE
|
65 |
|
66 |
sns.set()
|
67 |
|
geneformer/collator_for_classification.py
CHANGED
@@ -4,6 +4,7 @@ Geneformer collator for gene and cell classification.
|
|
4 |
Huggingface data collator modified to accommodate single-cell transcriptomics data for gene and cell classification.
|
5 |
"""
|
6 |
import numpy as np
|
|
|
7 |
import torch
|
8 |
import warnings
|
9 |
from enum import Enum
|
@@ -17,7 +18,11 @@ from transformers import (
|
|
17 |
from transformers.utils import is_tf_available, is_torch_available, logging, to_py_obj
|
18 |
from transformers.utils.generic import _is_tensorflow, _is_torch
|
19 |
|
20 |
-
from .
|
|
|
|
|
|
|
|
|
21 |
|
22 |
EncodedInput = List[int]
|
23 |
logger = logging.get_logger(__name__)
|
|
|
4 |
Huggingface data collator modified to accommodate single-cell transcriptomics data for gene and cell classification.
|
5 |
"""
|
6 |
import numpy as np
|
7 |
+
import pickle
|
8 |
import torch
|
9 |
import warnings
|
10 |
from enum import Enum
|
|
|
18 |
from transformers.utils import is_tf_available, is_torch_available, logging, to_py_obj
|
19 |
from transformers.utils.generic import _is_tensorflow, _is_torch
|
20 |
|
21 |
+
from . import TOKEN_DICTIONARY_FILE
|
22 |
+
|
23 |
+
# load token dictionary (Ensembl IDs:token)
|
24 |
+
with open(TOKEN_DICTIONARY_FILE, "rb") as f:
|
25 |
+
token_dictionary = pickle.load(f)
|
26 |
|
27 |
EncodedInput = List[int]
|
28 |
logger = logging.get_logger(__name__)
|
geneformer/emb_extractor.py
CHANGED
@@ -25,7 +25,7 @@ from tdigest import TDigest
|
|
25 |
from tqdm.auto import trange
|
26 |
|
27 |
from . import perturber_utils as pu
|
28 |
-
from .
|
29 |
|
30 |
logger = logging.getLogger(__name__)
|
31 |
|
|
|
25 |
from tqdm.auto import trange
|
26 |
|
27 |
from . import perturber_utils as pu
|
28 |
+
from . import TOKEN_DICTIONARY_FILE
|
29 |
|
30 |
logger = logging.getLogger(__name__)
|
31 |
|
geneformer/evaluation_utils.py
CHANGED
@@ -21,7 +21,7 @@ from sklearn.metrics import (
|
|
21 |
from tqdm.auto import trange
|
22 |
|
23 |
from .emb_extractor import make_colorbar
|
24 |
-
from .
|
25 |
|
26 |
logger = logging.getLogger(__name__)
|
27 |
|
|
|
21 |
from tqdm.auto import trange
|
22 |
|
23 |
from .emb_extractor import make_colorbar
|
24 |
+
from . import TOKEN_DICTIONARY_FILE
|
25 |
|
26 |
logger = logging.getLogger(__name__)
|
27 |
|
geneformer/in_silico_perturber.py
CHANGED
@@ -38,21 +38,17 @@ import logging
|
|
38 |
import os
|
39 |
import pickle
|
40 |
from collections import defaultdict
|
41 |
-
from typing import List
|
42 |
from multiprocess import set_start_method
|
43 |
|
44 |
-
import seaborn as sns
|
45 |
import torch
|
46 |
-
from datasets import Dataset
|
47 |
from tqdm.auto import trange
|
48 |
|
49 |
from . import perturber_utils as pu
|
50 |
from .emb_extractor import get_embs
|
51 |
-
from .
|
52 |
-
|
53 |
-
|
54 |
-
sns.set()
|
55 |
|
|
|
56 |
|
57 |
logger = logging.getLogger(__name__)
|
58 |
|
|
|
38 |
import os
|
39 |
import pickle
|
40 |
from collections import defaultdict
|
|
|
41 |
from multiprocess import set_start_method
|
42 |
|
|
|
43 |
import torch
|
44 |
+
from datasets import Dataset, disable_progress_bars
|
45 |
from tqdm.auto import trange
|
46 |
|
47 |
from . import perturber_utils as pu
|
48 |
from .emb_extractor import get_embs
|
49 |
+
from . import TOKEN_DICTIONARY_FILE
|
|
|
|
|
|
|
50 |
|
51 |
+
disable_progress_bars()
|
52 |
|
53 |
logger = logging.getLogger(__name__)
|
54 |
|
geneformer/in_silico_perturber_stats.py
CHANGED
@@ -38,9 +38,7 @@ from sklearn.mixture import GaussianMixture
|
|
38 |
from tqdm.auto import tqdm, trange
|
39 |
|
40 |
from .perturber_utils import flatten_list, validate_cell_states_to_model
|
41 |
-
from .
|
42 |
-
|
43 |
-
GENE_NAME_ID_DICTIONARY_FILE = Path(__file__).parent / "gene_name_id_dict.pkl"
|
44 |
|
45 |
logger = logging.getLogger(__name__)
|
46 |
|
@@ -673,7 +671,7 @@ class InSilicoPerturberStats:
|
|
673 |
cell_states_to_model=None,
|
674 |
pickle_suffix="_raw.pickle",
|
675 |
token_dictionary_file=TOKEN_DICTIONARY_FILE,
|
676 |
-
gene_name_id_dictionary_file=
|
677 |
):
|
678 |
"""
|
679 |
Initialize in silico perturber stats generator.
|
|
|
38 |
from tqdm.auto import tqdm, trange
|
39 |
|
40 |
from .perturber_utils import flatten_list, validate_cell_states_to_model
|
41 |
+
from . import TOKEN_DICTIONARY_FILE, ENSEMBL_DICTIONARY_FILE
|
|
|
|
|
42 |
|
43 |
logger = logging.getLogger(__name__)
|
44 |
|
|
|
671 |
cell_states_to_model=None,
|
672 |
pickle_suffix="_raw.pickle",
|
673 |
token_dictionary_file=TOKEN_DICTIONARY_FILE,
|
674 |
+
gene_name_id_dictionary_file=ENSEMBL_DICTIONARY_FILE,
|
675 |
):
|
676 |
"""
|
677 |
Initialize in silico perturber stats generator.
|
geneformer/perturber_utils.py
CHANGED
@@ -18,13 +18,9 @@ from transformers import (
|
|
18 |
BertForTokenClassification,
|
19 |
)
|
20 |
|
21 |
-
|
22 |
-
TOKEN_DICTIONARY_FILE = Path(__file__).parent / "token_dictionary.pkl"
|
23 |
-
ENSEMBL_DICTIONARY_FILE = Path(__file__).parent / "gene_name_id_dict.pkl"
|
24 |
|
25 |
|
26 |
-
sns.set()
|
27 |
-
|
28 |
logger = logging.getLogger(__name__)
|
29 |
|
30 |
|
|
|
18 |
BertForTokenClassification,
|
19 |
)
|
20 |
|
21 |
+
from . import GENE_MEDIAN_FILE, TOKEN_DICTIONARY_FILE, ENSEMBL_DICTIONARY_FILE
|
|
|
|
|
22 |
|
23 |
|
|
|
|
|
24 |
logger = logging.getLogger(__name__)
|
25 |
|
26 |
|
geneformer/pretrainer.py
CHANGED
@@ -32,7 +32,7 @@ from transformers.training_args import ParallelMode
|
|
32 |
from transformers.utils import is_tf_available, is_torch_available, logging, to_py_obj
|
33 |
from transformers.utils.generic import _is_tensorflow, _is_torch
|
34 |
|
35 |
-
from .
|
36 |
|
37 |
logger = logging.get_logger(__name__)
|
38 |
EncodedInput = List[int]
|
|
|
32 |
from transformers.utils import is_tf_available, is_torch_available, logging, to_py_obj
|
33 |
from transformers.utils.generic import _is_tensorflow, _is_torch
|
34 |
|
35 |
+
from . import TOKEN_DICTIONARY_FILE
|
36 |
|
37 |
logger = logging.get_logger(__name__)
|
38 |
EncodedInput = List[int]
|
geneformer/tokenizer.py
CHANGED
@@ -52,7 +52,7 @@ import loompy as lp # noqa
|
|
52 |
|
53 |
logger = logging.getLogger(__name__)
|
54 |
|
55 |
-
from .
|
56 |
|
57 |
|
58 |
def rank_genes(gene_vector, gene_tokens):
|
|
|
52 |
|
53 |
logger = logging.getLogger(__name__)
|
54 |
|
55 |
+
from . import GENE_MEDIAN_FILE, TOKEN_DICTIONARY_FILE
|
56 |
|
57 |
|
58 |
def rank_genes(gene_vector, gene_tokens):
|