Christina Theodoris commited on
Commit
fd93ebf
1 Parent(s): 5d0082c

Add option for modifying chunk size for anndata tokenizer

Browse files
Files changed (1) hide show
  1. geneformer/tokenizer.py +47 -35
geneformer/tokenizer.py CHANGED
@@ -11,18 +11,16 @@ Optional col (cell) attributes: any other cell metadata can be passed on to the
11
  Usage:
12
  from geneformer import TranscriptomeTokenizer
13
  tk = TranscriptomeTokenizer({"cell_type": "cell_type", "organ_major": "organ_major"}, nproc=4)
14
- tk.tokenize_data("loom_data_directory", "output_directory", "output_prefix")
15
  """
16
 
17
  from __future__ import annotations
18
- from typing import Literal
19
- import pickle
20
- from pathlib import Path
21
 
22
  import logging
23
-
24
  import warnings
25
- warnings.filterwarnings("ignore", message=".*The 'nopython' keyword.*")
 
26
 
27
  import anndata as ad
28
  import loompy as lp
@@ -30,6 +28,7 @@ import numpy as np
30
  import scipy.sparse as sp
31
  from datasets import Dataset
32
 
 
33
  logger = logging.getLogger(__name__)
34
 
35
  GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary.pkl"
@@ -61,6 +60,7 @@ class TranscriptomeTokenizer:
61
  self,
62
  custom_attr_name_dict=None,
63
  nproc=1,
 
64
  gene_median_file=GENE_MEDIAN_FILE,
65
  token_dictionary_file=TOKEN_DICTIONARY_FILE,
66
  ):
@@ -75,6 +75,8 @@ class TranscriptomeTokenizer:
75
  Values are the names of the attributes in the dataset.
76
  nproc : int
77
  Number of processes to use for dataset mapping.
 
 
78
  gene_median_file : Path
79
  Path to pickle file containing dictionary of non-zero median
80
  gene expression values across Genecorpus-30M.
@@ -87,6 +89,9 @@ class TranscriptomeTokenizer:
87
  # number of processes for dataset mapping
88
  self.nproc = nproc
89
 
 
 
 
90
  # load dictionary of gene normalization factors
91
  # (non-zero median value of expression across Genecorpus-30M)
92
  with open(gene_median_file, "rb") as f:
@@ -111,11 +116,11 @@ class TranscriptomeTokenizer:
111
  use_generator: bool = False,
112
  ):
113
  """
114
- Tokenize .loom files in loom_data_directory and save as tokenized .dataset in output_directory.
115
 
116
  Parameters
117
  ----------
118
- loom_data_directory : Path
119
  Path to directory containing loom files or anndata files
120
  output_directory : Path
121
  Path to directory where tokenized data will be saved as .dataset
@@ -129,7 +134,9 @@ class TranscriptomeTokenizer:
129
  tokenized_cells, cell_metadata = self.tokenize_files(
130
  Path(data_directory), file_format
131
  )
132
- tokenized_dataset = self.create_dataset(tokenized_cells, cell_metadata, use_generator=use_generator)
 
 
133
 
134
  output_path = (Path(output_directory) / output_prefix).with_suffix(".dataset")
135
  tokenized_dataset.save_to_disk(output_path)
@@ -140,7 +147,9 @@ class TranscriptomeTokenizer:
140
  tokenized_cells = []
141
  if self.custom_attr_name_dict is not None:
142
  cell_attr = [attr_key for attr_key in self.custom_attr_name_dict.keys()]
143
- cell_metadata = {attr_key: [] for attr_key in self.custom_attr_name_dict.values()}
 
 
144
 
145
  # loops through directories to tokenize .loom files
146
  file_found = 0
@@ -155,17 +164,20 @@ class TranscriptomeTokenizer:
155
  tokenized_cells += file_tokenized_cells
156
  if self.custom_attr_name_dict is not None:
157
  for k in cell_attr:
158
- cell_metadata[self.custom_attr_name_dict[k]] += file_cell_metadata[k]
 
 
159
  else:
160
  cell_metadata = None
161
 
162
  if file_found == 0:
163
  logger.error(
164
- f"No .{file_format} files found in directory {data_directory}.")
 
165
  raise
166
  return tokenized_cells, cell_metadata
167
 
168
- def tokenize_anndata(self, adata_file_path, target_sum=10_000, chunk_size=512):
169
  adata = ad.read(adata_file_path, backed="r")
170
 
171
  if self.custom_attr_name_dict is not None:
@@ -195,9 +207,7 @@ class TranscriptomeTokenizer:
195
  var_exists = True
196
 
197
  if var_exists:
198
- filter_pass_loc = np.where(
199
- [i == 1 for i in adata.obs["filter_pass"]]
200
- )[0]
201
  elif not var_exists:
202
  print(
203
  f"{adata_file_path} has no column attribute 'filter_pass'; tokenizing all cells."
@@ -206,12 +216,12 @@ class TranscriptomeTokenizer:
206
 
207
  tokenized_cells = []
208
 
209
- for i in range(0, len(filter_pass_loc), chunk_size):
210
- idx = filter_pass_loc[i:i+chunk_size]
211
 
212
- n_counts = adata[idx].obs['n_counts'].values[:, None]
213
  X_view = adata[idx, coding_miRNA_loc].X
214
- X_norm = (X_view / n_counts * target_sum / norm_factor_vector)
215
  X_norm = sp.csr_matrix(X_norm)
216
 
217
  tokenized_cells += [
@@ -259,9 +269,7 @@ class TranscriptomeTokenizer:
259
  var_exists = True
260
 
261
  if var_exists:
262
- filter_pass_loc = np.where(
263
- [i == 1 for i in data.ca["filter_pass"]]
264
- )[0]
265
  elif not var_exists:
266
  print(
267
  f"{loom_file_path} has no column attribute 'filter_pass'; tokenizing all cells."
@@ -270,7 +278,7 @@ class TranscriptomeTokenizer:
270
 
271
  # scan through .loom files and tokenize cells
272
  tokenized_cells = []
273
- for (_ix, _selection, view) in data.scan(items=filter_pass_loc, axis=1):
274
  # select subview with protein-coding and miRNA genes
275
  subview = view.view[coding_miRNA_loc, :]
276
 
@@ -297,7 +305,13 @@ class TranscriptomeTokenizer:
297
 
298
  return tokenized_cells, file_cell_metadata
299
 
300
- def create_dataset(self, tokenized_cells, cell_metadata, use_generator=False, keep_uncropped_input_ids=False):
 
 
 
 
 
 
301
  print("Creating dataset.")
302
  # create dict for dataset creation
303
  dataset_dict = {"input_ids": tokenized_cells}
@@ -306,30 +320,28 @@ class TranscriptomeTokenizer:
306
 
307
  # create dataset
308
  if use_generator:
 
309
  def dict_generator():
310
  for i in range(len(tokenized_cells)):
311
  yield {k: dataset_dict[k][i] for k in dataset_dict.keys()}
 
312
  output_dataset = Dataset.from_generator(dict_generator, num_proc=self.nproc)
313
  else:
314
  output_dataset = Dataset.from_dict(dataset_dict)
315
-
316
  def format_cell_features(example):
317
  # Store original uncropped input_ids in separate feature
318
  if keep_uncropped_input_ids:
319
- example['input_ids_uncropped'] = example['input_ids']
320
- example['length_uncropped'] = len(example['input_ids'])
321
 
322
  # Truncate/Crop input_ids to size 2,048
323
- example['input_ids'] = example['input_ids'][0:2048]
324
- example['length'] = len(example['input_ids'])
325
 
326
  return example
327
 
328
  output_dataset_truncated = output_dataset.map(
329
- format_cell_features,
330
- num_proc=self.nproc
331
  )
332
  return output_dataset_truncated
333
-
334
-
335
-
 
11
  Usage:
12
  from geneformer import TranscriptomeTokenizer
13
  tk = TranscriptomeTokenizer({"cell_type": "cell_type", "organ_major": "organ_major"}, nproc=4)
14
+ tk.tokenize_data("data_directory", "output_directory", "output_prefix")
15
  """
16
 
17
  from __future__ import annotations
 
 
 
18
 
19
  import logging
20
+ import pickle
21
  import warnings
22
+ from pathlib import Path
23
+ from typing import Literal
24
 
25
  import anndata as ad
26
  import loompy as lp
 
28
  import scipy.sparse as sp
29
  from datasets import Dataset
30
 
31
+ warnings.filterwarnings("ignore", message=".*The 'nopython' keyword.*")
32
  logger = logging.getLogger(__name__)
33
 
34
  GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary.pkl"
 
60
  self,
61
  custom_attr_name_dict=None,
62
  nproc=1,
63
+ chunk_size=512,
64
  gene_median_file=GENE_MEDIAN_FILE,
65
  token_dictionary_file=TOKEN_DICTIONARY_FILE,
66
  ):
 
75
  Values are the names of the attributes in the dataset.
76
  nproc : int
77
  Number of processes to use for dataset mapping.
78
+ chunk_size: int = 512
79
+ Chunk size for anndata tokenizer.
80
  gene_median_file : Path
81
  Path to pickle file containing dictionary of non-zero median
82
  gene expression values across Genecorpus-30M.
 
89
  # number of processes for dataset mapping
90
  self.nproc = nproc
91
 
92
+ # chunk size for anndata tokenizer
93
+ self.chunk_size = chunk_size
94
+
95
  # load dictionary of gene normalization factors
96
  # (non-zero median value of expression across Genecorpus-30M)
97
  with open(gene_median_file, "rb") as f:
 
116
  use_generator: bool = False,
117
  ):
118
  """
119
+ Tokenize .loom files in data_directory and save as tokenized .dataset in output_directory.
120
 
121
  Parameters
122
  ----------
123
+ data_directory : Path
124
  Path to directory containing loom files or anndata files
125
  output_directory : Path
126
  Path to directory where tokenized data will be saved as .dataset
 
134
  tokenized_cells, cell_metadata = self.tokenize_files(
135
  Path(data_directory), file_format
136
  )
137
+ tokenized_dataset = self.create_dataset(
138
+ tokenized_cells, cell_metadata, use_generator=use_generator
139
+ )
140
 
141
  output_path = (Path(output_directory) / output_prefix).with_suffix(".dataset")
142
  tokenized_dataset.save_to_disk(output_path)
 
147
  tokenized_cells = []
148
  if self.custom_attr_name_dict is not None:
149
  cell_attr = [attr_key for attr_key in self.custom_attr_name_dict.keys()]
150
+ cell_metadata = {
151
+ attr_key: [] for attr_key in self.custom_attr_name_dict.values()
152
+ }
153
 
154
  # loops through directories to tokenize .loom files
155
  file_found = 0
 
164
  tokenized_cells += file_tokenized_cells
165
  if self.custom_attr_name_dict is not None:
166
  for k in cell_attr:
167
+ cell_metadata[self.custom_attr_name_dict[k]] += file_cell_metadata[
168
+ k
169
+ ]
170
  else:
171
  cell_metadata = None
172
 
173
  if file_found == 0:
174
  logger.error(
175
+ f"No .{file_format} files found in directory {data_directory}."
176
+ )
177
  raise
178
  return tokenized_cells, cell_metadata
179
 
180
+ def tokenize_anndata(self, adata_file_path, target_sum=10_000):
181
  adata = ad.read(adata_file_path, backed="r")
182
 
183
  if self.custom_attr_name_dict is not None:
 
207
  var_exists = True
208
 
209
  if var_exists:
210
+ filter_pass_loc = np.where([i == 1 for i in adata.obs["filter_pass"]])[0]
 
 
211
  elif not var_exists:
212
  print(
213
  f"{adata_file_path} has no column attribute 'filter_pass'; tokenizing all cells."
 
216
 
217
  tokenized_cells = []
218
 
219
+ for i in range(0, len(filter_pass_loc), self.chunk_size):
220
+ idx = filter_pass_loc[i : i + self.chunk_size]
221
 
222
+ n_counts = adata[idx].obs["n_counts"].values[:, None]
223
  X_view = adata[idx, coding_miRNA_loc].X
224
+ X_norm = X_view / n_counts * target_sum / norm_factor_vector
225
  X_norm = sp.csr_matrix(X_norm)
226
 
227
  tokenized_cells += [
 
269
  var_exists = True
270
 
271
  if var_exists:
272
+ filter_pass_loc = np.where([i == 1 for i in data.ca["filter_pass"]])[0]
 
 
273
  elif not var_exists:
274
  print(
275
  f"{loom_file_path} has no column attribute 'filter_pass'; tokenizing all cells."
 
278
 
279
  # scan through .loom files and tokenize cells
280
  tokenized_cells = []
281
+ for _ix, _selection, view in data.scan(items=filter_pass_loc, axis=1):
282
  # select subview with protein-coding and miRNA genes
283
  subview = view.view[coding_miRNA_loc, :]
284
 
 
305
 
306
  return tokenized_cells, file_cell_metadata
307
 
308
+ def create_dataset(
309
+ self,
310
+ tokenized_cells,
311
+ cell_metadata,
312
+ use_generator=False,
313
+ keep_uncropped_input_ids=False,
314
+ ):
315
  print("Creating dataset.")
316
  # create dict for dataset creation
317
  dataset_dict = {"input_ids": tokenized_cells}
 
320
 
321
  # create dataset
322
  if use_generator:
323
+
324
  def dict_generator():
325
  for i in range(len(tokenized_cells)):
326
  yield {k: dataset_dict[k][i] for k in dataset_dict.keys()}
327
+
328
  output_dataset = Dataset.from_generator(dict_generator, num_proc=self.nproc)
329
  else:
330
  output_dataset = Dataset.from_dict(dataset_dict)
331
+
332
  def format_cell_features(example):
333
  # Store original uncropped input_ids in separate feature
334
  if keep_uncropped_input_ids:
335
+ example["input_ids_uncropped"] = example["input_ids"]
336
+ example["length_uncropped"] = len(example["input_ids"])
337
 
338
  # Truncate/Crop input_ids to size 2,048
339
+ example["input_ids"] = example["input_ids"][0:2048]
340
+ example["length"] = len(example["input_ids"])
341
 
342
  return example
343
 
344
  output_dataset_truncated = output_dataset.map(
345
+ format_cell_features, num_proc=self.nproc
 
346
  )
347
  return output_dataset_truncated