Addressed issues for tokenizer; anndata tokenizer now uses a fraction of the memory
geneformer/tokenizer.py  CHANGED  (+46 -30)
@@ -27,6 +27,7 @@ warnings.filterwarnings("ignore", message=".*The 'nopython' keyword.*")
 import anndata as ad
 import loompy as lp
 import numpy as np
+import scipy.sparse as sp
 from datasets import Dataset

 logger = logging.getLogger(__name__)
@@ -35,6 +36,15 @@ GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary.pkl"
 TOKEN_DICTIONARY_FILE = Path(__file__).parent / "token_dictionary.pkl"


+def rank_genes(gene_vector, gene_tokens):
+    """
+    Rank gene expression vector.
+    """
+    # sort by median-scaled gene values
+    sorted_indices = np.argsort(-gene_vector)
+    return gene_tokens[sorted_indices]
+
+
 def tokenize_cell(gene_vector, gene_tokens):
     """
     Convert normalized gene expression vector to tokenized rank value encoding.
@@ -42,11 +52,8 @@ def tokenize_cell(gene_vector, gene_tokens):
     # create array of gene vector with token indices
     # mask undetected genes
     nonzero_mask = np.nonzero(gene_vector)[0]
-    # sort by median-scaled gene values
-    sorted_indices = np.argsort(-gene_vector[nonzero_mask])
-    # tokenize
-    sentence_tokens = gene_tokens[nonzero_mask][sorted_indices]
-    return sentence_tokens
+    # rank by median-scaled gene values
+    return rank_genes(gene_vector[nonzero_mask], gene_tokens[nonzero_mask])


 class TranscriptomeTokenizer:
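For readers skimming the diff: the ranking step now lives in the new rank_genes helper, which tokenize_cell calls on the detected (nonzero) genes only. A tiny self-contained sketch of the behavior, with made-up token ids (the real ids come from token_dictionary.pkl):

    import numpy as np

    def rank_genes(gene_vector, gene_tokens):
        # order genes by descending median-scaled expression and return their tokens
        return gene_tokens[np.argsort(-gene_vector)]

    # three detected genes with hypothetical token ids 5, 9, 2
    gene_vector = np.array([0.4, 2.1, 1.3])
    gene_tokens = np.array([5, 9, 2])
    print(rank_genes(gene_vector, gene_tokens))  # [9 2 5] -> most highly expressed gene first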
@@ -101,6 +108,7 @@ class TranscriptomeTokenizer:
         output_directory: Path | str,
         output_prefix: str,
         file_format: Literal["loom", "h5ad"] = "loom",
+        use_generator: bool = False,
     ):
         """
         Tokenize .loom files in loom_data_directory and save as tokenized .dataset in output_directory.
@@ -115,11 +123,13 @@ class TranscriptomeTokenizer:
             Prefix for output .dataset
         file_format : str
             Format of input files. Can be "loom" or "h5ad".
+        use_generator : bool
+            Whether to use generator or dict for tokenization.
         """
         tokenized_cells, cell_metadata = self.tokenize_files(
             Path(data_directory), file_format
         )
-        tokenized_dataset = self.create_dataset(tokenized_cells, cell_metadata)
+        tokenized_dataset = self.create_dataset(tokenized_cells, cell_metadata, use_generator=use_generator)

         output_path = (Path(output_directory) / output_prefix).with_suffix(".dataset")
         tokenized_dataset.save_to_disk(output_path)
@@ -129,7 +139,7 @@ class TranscriptomeTokenizer:
     ):
         tokenized_cells = []
         if self.custom_attr_name_dict is not None:
-            loom_cell_attr = [attr_key for attr_key in self.custom_attr_name_dict.keys()]
+            cell_attr = [attr_key for attr_key in self.custom_attr_name_dict.keys()]
             cell_metadata = {attr_key: [] for attr_key in self.custom_attr_name_dict.values()}

         # loops through directories to tokenize .loom files
@@ -144,7 +154,7 @@ class TranscriptomeTokenizer:
             file_tokenized_cells, file_cell_metadata = tokenize_file_fn(file_path)
             tokenized_cells += file_tokenized_cells
             if self.custom_attr_name_dict is not None:
-                for k in loom_cell_attr:
+                for k in cell_attr:
                     cell_metadata[self.custom_attr_name_dict[k]] += file_cell_metadata[k]
             else:
                 cell_metadata = None
@@ -155,8 +165,8 @@ class TranscriptomeTokenizer:
             raise
         return tokenized_cells, cell_metadata

-    def tokenize_anndata(self, adata_file_path):
-        adata = ad.read(adata_file_path)
+    def tokenize_anndata(self, adata_file_path, target_sum=10_000, chunk_size=512):
+        adata = ad.read(adata_file_path, backed="r")
         file_cell_metadata = {
             attr_key: [] for attr_key in self.custom_attr_name_dict.keys()
         }
@@ -176,7 +186,7 @@ class TranscriptomeTokenizer:
         )

         try:
-            adata.obs["filter_pass"]
+            _ = adata.obs["filter_pass"]
         except KeyError:
             var_exists = False
         else:
@@ -193,24 +203,26 @@ class TranscriptomeTokenizer:
             filter_pass_loc = np.array([i for i in range(adata.shape[0])])

         tokenized_cells = []
-        adata_filter = adata[
-            filter_pass_loc, coding_miRNA_loc  # filter cells and genes
-        ]

-        X_norm = adata_filter.X / adata_filter.X.sum(axis=1) * 10_000 / norm_factor_vector
+        for i in range(0, len(filter_pass_loc), chunk_size):
+            idx = filter_pass_loc[i:i+chunk_size]
+            X = adata[idx].X
+
+            X_norm = (X / X[:, coding_miRNA_loc].sum(axis=1) * target_sum / norm_factor_vector)
+            X_norm = sp.csr_matrix(X_norm)

-        tokenized_cells += [
-            tokenize_cell(X_norm[i, ...], coding_miRNA_tokens)
-            for i in range(X_norm.shape[0])
-        ]
+            tokenized_cells += [
+                rank_genes(X_norm[i].data, coding_miRNA_tokens[X_norm[i].indices])
+                for i in range(X_norm.shape[0])
+            ]

-        # add custom attributes for subview to dict
-        for k in file_cell_metadata.keys():
-            file_cell_metadata[k] += adata_filter.obs[k].tolist()
+            # add custom attributes for subview to dict
+            for k in file_cell_metadata.keys():
+                file_cell_metadata[k] += adata[idx].obs[k].tolist()

         return tokenized_cells, file_cell_metadata

-    def tokenize_file(self, loom_file_path):
+    def tokenize_file(self, loom_file_path, target_sum=10_000):
         if self.custom_attr_name_dict is not None:
             file_cell_metadata = {
                 attr_key: [] for attr_key in self.custom_attr_name_dict.keys()
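This hunk is the main source of the memory saving in the commit message: the AnnData file is opened in backed mode and cells are normalized, sparsified, and ranked chunk_size rows at a time, instead of slicing and densifying the whole matrix at once. A stand-alone sketch of that access pattern (file name and chunk size are illustrative; the normalization itself is elided):

    import anndata as ad
    import numpy as np
    import scipy.sparse as sp

    adata = ad.read("example.h5ad", backed="r")  # expression matrix stays on disk
    chunk_size = 512
    cell_indices = np.arange(adata.shape[0])

    for start in range(0, len(cell_indices), chunk_size):
        idx = cell_indices[start:start + chunk_size]
        X = sp.csr_matrix(adata[idx].X)  # only chunk_size rows are read into memory
        # normalize and rank each of these rows here; peak memory now scales with
        # chunk_size rather than with the number of cells in the file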
@@ -261,7 +273,7 @@ class TranscriptomeTokenizer:
                 subview_norm_array = (
                     subview[:, :]
                     / subview.ca.n_counts
-                    * 10_000
+                    * target_sum
                     / norm_factor_vector[:, None]
                 )
                 # tokenize subview gene vectors
@@ -279,21 +291,25 @@ class TranscriptomeTokenizer:

         return tokenized_cells, file_cell_metadata

-    def create_dataset(self, tokenized_cells, cell_metadata):
+    def create_dataset(self, tokenized_cells, cell_metadata, use_generator=False):
+        print("Creating dataset...")
         # create dict for dataset creation
         dataset_dict = {"input_ids": tokenized_cells}
         if self.custom_attr_name_dict is not None:
             dataset_dict.update(cell_metadata)

         # create dataset
-        def dict_generator():
-            for i in range(len(tokenized_cells)):
-                yield {k: dataset_dict[k][i] for k in dataset_dict.keys()}
-        output_dataset = Dataset.from_generator(dict_generator, num_proc=self.nproc)
+        if use_generator:
+            def dict_generator():
+                for i in range(len(tokenized_cells)):
+                    yield {k: dataset_dict[k][i] for k in dataset_dict.keys()}
+            output_dataset = Dataset.from_generator(dict_generator, num_proc=self.nproc)
+        else:
+            output_dataset = Dataset.from_dict(dataset_dict)

         # truncate dataset
         def truncate(example):
-            example["input_ids"] = example["input_ids"][0:2048]
+            example["input_ids"] = example["input_ids"][:2048]
             return example

         output_dataset_truncated = output_dataset.map(truncate, num_proc=self.nproc)
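create_dataset now takes a use_generator flag: Dataset.from_dict is the default and Dataset.from_generator is opt-in. Feeding Hugging Face Datasets from a generator writes examples to the Arrow cache incrementally, which can avoid holding a second full copy of the token lists while the table is built. A toy sketch of the two paths (data and column names are made up; the saving is only illustrative here, since the source dict already sits in memory):

    from datasets import Dataset

    dataset_dict = {"input_ids": [[9, 2, 5], [7, 3]], "cell_type": ["T cell", "B cell"]}

    # dict path: the full table is materialized from an in-memory mapping
    ds_dict = Dataset.from_dict(dataset_dict)

    # generator path: examples are yielded one at a time and written incrementally
    def dict_generator():
        for i in range(len(dataset_dict["input_ids"])):
            yield {k: dataset_dict[k][i] for k in dataset_dict}

    ds_gen = Dataset.from_generator(dict_generator)

    print(ds_dict[0])  # {'input_ids': [9, 2, 5], 'cell_type': 'T cell'}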
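For completeness, a usage sketch of the new flag, following the package's documented entry point (the directory paths and the obs column name are hypothetical):

    from geneformer import TranscriptomeTokenizer

    # map input obs attributes to output dataset columns (names are illustrative)
    tk = TranscriptomeTokenizer({"cell_type": "cell_type"}, nproc=4)

    tk.tokenize_data(
        "data/h5ad_dir",     # data_directory containing .h5ad files
        "output",            # output_directory
        "my_corpus",         # output_prefix -> output/my_corpus.dataset
        file_format="h5ad",
        use_generator=True,  # build the Dataset from a generator instead of one large dict
    )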