Christina Theodoris
commited on
Commit
•
624349c
1
Parent(s):
4302f48
Add option to output embs as tensor
Browse files
examples/extract_and_plot_cell_embeddings.ipynb
CHANGED
@@ -29,6 +29,7 @@
|
|
29 |
" nproc=16)\n",
|
30 |
"\n",
|
31 |
"# extracts embedding from input data\n",
|
|
|
32 |
"# example dataset: https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset\n",
|
33 |
"embs = embex.extract_embs(\"../fine_tuned_models/geneformer-6L-30M_CellClassifier_cardiomyopathies_220224\",\n",
|
34 |
" \"path/to/input_data/\",\n",
|
|
|
29 |
" nproc=16)\n",
|
30 |
"\n",
|
31 |
"# extracts embedding from input data\n",
|
32 |
+
"# input data is tokenized rank value encodings generated by Geneformer tokenizer (see tokenizing_scRNAseq_data.ipynb)\n",
|
33 |
"# example dataset: https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset\n",
|
34 |
"embs = embex.extract_embs(\"../fine_tuned_models/geneformer-6L-30M_CellClassifier_cardiomyopathies_220224\",\n",
|
35 |
" \"path/to/input_data/\",\n",
|
geneformer/emb_extractor.py
CHANGED
@@ -40,7 +40,7 @@ import seaborn as sns
|
|
40 |
import torch
|
41 |
from collections import Counter
|
42 |
from pathlib import Path
|
43 |
-
from tqdm.
|
44 |
from transformers import BertForMaskedLM, BertForTokenClassification, BertForSequenceClassification
|
45 |
|
46 |
from .tokenizer import TOKEN_DICTIONARY_FILE
|
@@ -64,7 +64,6 @@ def get_embs(model,
|
|
64 |
pad_token_id,
|
65 |
forward_batch_size,
|
66 |
summary_stat):
|
67 |
-
|
68 |
model_input_size = get_model_input_size(model)
|
69 |
total_batch_length = len(filtered_input_data)
|
70 |
|
@@ -138,7 +137,7 @@ def test_emb(model, example, layer_to_quant):
|
|
138 |
return embs_test.size()[2]
|
139 |
|
140 |
def label_embs(embs, downsampled_data, emb_labels):
|
141 |
-
embs_df = pd.DataFrame(embs.cpu())
|
142 |
if emb_labels is not None:
|
143 |
for label in emb_labels:
|
144 |
emb_label = downsampled_data[label]
|
@@ -367,7 +366,8 @@ class EmbExtractor:
|
|
367 |
model_directory,
|
368 |
input_data_file,
|
369 |
output_directory,
|
370 |
-
output_prefix
|
|
|
371 |
"""
|
372 |
Extract embeddings from input data and save as results in output_directory.
|
373 |
|
@@ -381,6 +381,9 @@ class EmbExtractor:
|
|
381 |
Path to directory where embedding data will be saved as csv
|
382 |
output_prefix : str
|
383 |
Prefix for output file
|
|
|
|
|
|
|
384 |
"""
|
385 |
|
386 |
filtered_input_data = load_and_filter(self.filter_data, self.nproc, input_data_file)
|
@@ -398,13 +401,16 @@ class EmbExtractor:
|
|
398 |
if self.summary_stat is None:
|
399 |
embs_df = label_embs(embs, downsampled_data, self.emb_label)
|
400 |
elif self.summary_stat is not None:
|
401 |
-
embs_df = pd.DataFrame(embs.cpu()).T
|
402 |
|
403 |
# save embeddings to output_path
|
404 |
output_path = (Path(output_directory) / output_prefix).with_suffix(".csv")
|
405 |
embs_df.to_csv(output_path)
|
406 |
-
|
407 |
-
|
|
|
|
|
|
|
408 |
|
409 |
def plot_embs(self,
|
410 |
embs,
|
|
|
40 |
import torch
|
41 |
from collections import Counter
|
42 |
from pathlib import Path
|
43 |
+
from tqdm.auto import trange
|
44 |
from transformers import BertForMaskedLM, BertForTokenClassification, BertForSequenceClassification
|
45 |
|
46 |
from .tokenizer import TOKEN_DICTIONARY_FILE
|
|
|
64 |
pad_token_id,
|
65 |
forward_batch_size,
|
66 |
summary_stat):
|
|
|
67 |
model_input_size = get_model_input_size(model)
|
68 |
total_batch_length = len(filtered_input_data)
|
69 |
|
|
|
137 |
return embs_test.size()[2]
|
138 |
|
139 |
def label_embs(embs, downsampled_data, emb_labels):
|
140 |
+
embs_df = pd.DataFrame(embs.cpu().numpy())
|
141 |
if emb_labels is not None:
|
142 |
for label in emb_labels:
|
143 |
emb_label = downsampled_data[label]
|
|
|
366 |
model_directory,
|
367 |
input_data_file,
|
368 |
output_directory,
|
369 |
+
output_prefix,
|
370 |
+
output_torch_embs=False):
|
371 |
"""
|
372 |
Extract embeddings from input data and save as results in output_directory.
|
373 |
|
|
|
381 |
Path to directory where embedding data will be saved as csv
|
382 |
output_prefix : str
|
383 |
Prefix for output file
|
384 |
+
output_torch_embs : bool
|
385 |
+
Whether or not to also output the embeddings as a tensor.
|
386 |
+
Note, if true, will output embeddings as both dataframe and tensor.
|
387 |
"""
|
388 |
|
389 |
filtered_input_data = load_and_filter(self.filter_data, self.nproc, input_data_file)
|
|
|
401 |
if self.summary_stat is None:
|
402 |
embs_df = label_embs(embs, downsampled_data, self.emb_label)
|
403 |
elif self.summary_stat is not None:
|
404 |
+
embs_df = pd.DataFrame(embs.cpu().numpy()).T
|
405 |
|
406 |
# save embeddings to output_path
|
407 |
output_path = (Path(output_directory) / output_prefix).with_suffix(".csv")
|
408 |
embs_df.to_csv(output_path)
|
409 |
+
|
410 |
+
if output_torch_embs == True:
|
411 |
+
return embs_df, embs
|
412 |
+
else:
|
413 |
+
return embs_df
|
414 |
|
415 |
def plot_embs(self,
|
416 |
embs,
|
geneformer/in_silico_perturber.py
CHANGED
@@ -34,7 +34,7 @@ import seaborn as sns; sns.set()
|
|
34 |
import torch
|
35 |
from collections import defaultdict
|
36 |
from datasets import Dataset, load_from_disk
|
37 |
-
from tqdm.
|
38 |
from transformers import BertForMaskedLM, BertForTokenClassification, BertForSequenceClassification
|
39 |
|
40 |
from .tokenizer import TOKEN_DICTIONARY_FILE
|
|
|
34 |
import torch
|
35 |
from collections import defaultdict
|
36 |
from datasets import Dataset, load_from_disk
|
37 |
+
from tqdm.auto import trange
|
38 |
from transformers import BertForMaskedLM, BertForTokenClassification, BertForSequenceClassification
|
39 |
|
40 |
from .tokenizer import TOKEN_DICTIONARY_FILE
|
geneformer/in_silico_perturber_stats.py
CHANGED
@@ -27,7 +27,7 @@ import statsmodels.stats.multitest as smt
|
|
27 |
from pathlib import Path
|
28 |
from scipy.stats import ranksums
|
29 |
from sklearn.mixture import GaussianMixture
|
30 |
-
from tqdm.
|
31 |
|
32 |
from .in_silico_perturber import flatten_list
|
33 |
|
|
|
27 |
from pathlib import Path
|
28 |
from scipy.stats import ranksums
|
29 |
from sklearn.mixture import GaussianMixture
|
30 |
+
from tqdm.auto import trange, tqdm
|
31 |
|
32 |
from .in_silico_perturber import flatten_list
|
33 |
|
setup.py
CHANGED
@@ -16,6 +16,7 @@ setup(
|
|
16 |
"datasets",
|
17 |
"loompy",
|
18 |
"numpy",
|
|
|
19 |
"transformers",
|
20 |
],
|
21 |
)
|
|
|
16 |
"datasets",
|
17 |
"loompy",
|
18 |
"numpy",
|
19 |
+
"tdigest",
|
20 |
"transformers",
|
21 |
],
|
22 |
)
|