ctheodoris ricomnl commited on
Commit
4302f48
1 Parent(s): 39ab62e

anndata_tokenizer (#170)

Browse files

- Added anndata tokenizer and switched to Dataset.from_generator (b6ca56647a16de97f6f95785eb76cb1a965f9960)
- Generalized (5cb733f4fed85e1568ba9671e5b1f9babd7b8491)
- Addressed issues for tokenizer, anndata tokenizer now uses a fraction of memory (b24676d0956565ff7b4f8b57d1e936ad17f8740e)
- Fixed issues (94e8d231c41a89232a55c064e5c8ea95edd51d2e)
- Fixed issue, loom and h5ad now produce same checksums (9c62f4c6e3ab0afd1943079306655bd20eba2246)
- Rename loom tokenizer function and modify example notebook for adata. (4fdb85061180ccf022ab761c2557be6ba67d848e)


Co-authored-by: Rico Meinl <[email protected]>

examples/tokenizing_scRNAseq_data.ipynb CHANGED
@@ -7,7 +7,7 @@
7
  "tags": []
8
  },
9
  "source": [
10
- "## Tokenizing .loom single cell RNA-seq data to rank value encoding .dataset format"
11
  ]
12
  },
13
  {
@@ -15,15 +15,17 @@
15
  "id": "350e6252-b783-494b-9767-f087eb868a15",
16
  "metadata": {},
17
  "source": [
18
- "#### Input data is a directory with .loom files containing raw counts from single cell RNAseq data, including all genes detected in the transcriptome without feature selection. \n",
19
  "\n",
20
- "#### Genes should be labeled with Ensembl IDs (row attribute \"ensembl_id\"), which provide a unique identifer for conversion to tokens. Other forms of gene annotations (e.g. gene names) can be converted to Ensembl IDs via Ensembl Biomart. Cells should be labeled with the total read count in the cell (column attribute \"n_counts\") to be used for normalization.\n",
 
 
21
  "\n",
22
  "#### No cell metadata is required, but custom cell attributes may be passed onto the tokenized dataset by providing a dictionary of custom attributes to be added, which is formatted as loom_col_attr_name : desired_dataset_col_attr_name. For example, if the original .loom dataset has column attributes \"cell_type\" and \"organ_major\" and one would like to retain these attributes as labels in the tokenized dataset with the new names \"cell_type\" and \"organ\", respectively, the following custom attribute dictionary should be provided: {\"cell_type\": \"cell_type\", \"organ_major\": \"organ\"}. \n",
23
  "\n",
24
  "#### Additionally, if the original .loom file contains a cell column attribute called \"filter_pass\", this column will be used as a binary indicator of whether to include these cells in the tokenized data. All cells with \"1\" in this attribute will be tokenized, whereas the others will be excluded. One may use this column to indicate QC filtering or other criteria for selection for inclusion in the final tokenized dataset.\n",
25
  "\n",
26
- "#### If one's data is in other formats besides .loom, one can use the relevant tools (such as Anndata tools) to convert the file to a .loom format prior to running the transcriptome tokenizer."
27
  ]
28
  },
29
  {
@@ -43,8 +45,11 @@
43
  "metadata": {},
44
  "outputs": [],
45
  "source": [
46
- "tk = TranscriptomeTokenizer({\"cell_type\": \"cell_type\", \"organ_major\": \"organ_major\"}, nproc=4)\n",
47
- "tk.tokenize_data(\"loom_data_directory\", \"output_directory\", \"output_prefix\")"
 
 
 
48
  ]
49
  }
50
  ],
 
7
  "tags": []
8
  },
9
  "source": [
10
+ "## Tokenizing .loom or .h5ad single cell RNA-seq data to rank value encoding .dataset format"
11
  ]
12
  },
13
  {
 
15
  "id": "350e6252-b783-494b-9767-f087eb868a15",
16
  "metadata": {},
17
  "source": [
18
+ "#### Input data is a directory with .loom or .h5ad files containing raw counts from single cell RNAseq data, including all genes detected in the transcriptome without feature selection. The input file type is specified by the argument file_format in the tokenize_data function.\n",
19
  "\n",
20
+ "#### The discussion below references the .loom file format, but the analagous labels are required for .h5ad files, just that they will be column instead of row attributes and vice versa due to the transposed format of the two file types.\n",
21
+ "\n",
22
+ "#### Genes should be labeled with Ensembl IDs (loom row attribute \"ensembl_id\"), which provide a unique identifer for conversion to tokens. Other forms of gene annotations (e.g. gene names) can be converted to Ensembl IDs via Ensembl Biomart. Cells should be labeled with the total read count in the cell (loom column attribute \"n_counts\") to be used for normalization.\n",
23
  "\n",
24
  "#### No cell metadata is required, but custom cell attributes may be passed onto the tokenized dataset by providing a dictionary of custom attributes to be added, which is formatted as loom_col_attr_name : desired_dataset_col_attr_name. For example, if the original .loom dataset has column attributes \"cell_type\" and \"organ_major\" and one would like to retain these attributes as labels in the tokenized dataset with the new names \"cell_type\" and \"organ\", respectively, the following custom attribute dictionary should be provided: {\"cell_type\": \"cell_type\", \"organ_major\": \"organ\"}. \n",
25
  "\n",
26
  "#### Additionally, if the original .loom file contains a cell column attribute called \"filter_pass\", this column will be used as a binary indicator of whether to include these cells in the tokenized data. All cells with \"1\" in this attribute will be tokenized, whereas the others will be excluded. One may use this column to indicate QC filtering or other criteria for selection for inclusion in the final tokenized dataset.\n",
27
  "\n",
28
+ "#### If one's data is in other formats besides .loom or .h5ad, one can use the relevant tools (such as Anndata tools) to convert the file to a .loom or .h5ad format prior to running the transcriptome tokenizer."
29
  ]
30
  },
31
  {
 
45
  "metadata": {},
46
  "outputs": [],
47
  "source": [
48
+ "tk = TranscriptomeTokenizer({\"cell_type\": \"cell_type\", \"organ_major\": \"organ\"}, nproc=16)\n",
49
+ "tk.tokenize_data(\"loom_data_directory\", \n",
50
+ " \"output_directory\", \n",
51
+ " \"output_prefix\", \n",
52
+ " file_format=\"loom\")"
53
  ]
54
  }
55
  ],
geneformer/tokenizer.py CHANGED
@@ -14,6 +14,8 @@ Usage:
14
  tk.tokenize_data("loom_data_directory", "output_directory", "output_prefix")
15
  """
16
 
 
 
17
  import pickle
18
  from pathlib import Path
19
 
@@ -22,8 +24,10 @@ import logging
22
  import warnings
23
  warnings.filterwarnings("ignore", message=".*The 'nopython' keyword.*")
24
 
 
25
  import loompy as lp
26
  import numpy as np
 
27
  from datasets import Dataset
28
 
29
  logger = logging.getLogger(__name__)
@@ -32,6 +36,15 @@ GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary.pkl"
32
  TOKEN_DICTIONARY_FILE = Path(__file__).parent / "token_dictionary.pkl"
33
 
34
 
 
 
 
 
 
 
 
 
 
35
  def tokenize_cell(gene_vector, gene_tokens):
36
  """
37
  Convert normalized gene expression vector to tokenized rank value encoding.
@@ -39,11 +52,8 @@ def tokenize_cell(gene_vector, gene_tokens):
39
  # create array of gene vector with token indices
40
  # mask undetected genes
41
  nonzero_mask = np.nonzero(gene_vector)[0]
42
- # sort by median-scaled gene values
43
- sorted_indices = np.argsort(-gene_vector[nonzero_mask])
44
- # tokenize
45
- sentence_tokens = gene_tokens[nonzero_mask][sorted_indices]
46
- return sentence_tokens
47
 
48
 
49
  class TranscriptomeTokenizer:
@@ -92,53 +102,133 @@ class TranscriptomeTokenizer:
92
  # protein-coding and miRNA gene list dictionary for selecting .loom rows for tokenization
93
  self.genelist_dict = dict(zip(self.gene_keys, [True] * len(self.gene_keys)))
94
 
95
- def tokenize_data(self, loom_data_directory, output_directory, output_prefix):
 
 
 
 
 
 
 
96
  """
97
  Tokenize .loom files in loom_data_directory and save as tokenized .dataset in output_directory.
98
 
99
  Parameters
100
  ----------
101
  loom_data_directory : Path
102
- Path to directory containing loom files
103
  output_directory : Path
104
  Path to directory where tokenized data will be saved as .dataset
105
  output_prefix : str
106
  Prefix for output .dataset
 
 
 
 
107
  """
108
- tokenized_cells, cell_metadata = self.tokenize_files(Path(loom_data_directory))
109
- tokenized_dataset = self.create_dataset(tokenized_cells, cell_metadata)
 
 
110
 
111
  output_path = (Path(output_directory) / output_prefix).with_suffix(".dataset")
112
  tokenized_dataset.save_to_disk(output_path)
113
 
114
- def tokenize_files(self, loom_data_directory):
 
 
115
  tokenized_cells = []
116
  if self.custom_attr_name_dict is not None:
117
- loom_cell_attr = [attr_key for attr_key in self.custom_attr_name_dict.keys()]
118
  cell_metadata = {attr_key: [] for attr_key in self.custom_attr_name_dict.values()}
119
 
120
  # loops through directories to tokenize .loom files
121
  file_found = 0
122
- for loom_file_path in loom_data_directory.glob("*.loom"):
 
 
 
 
123
  file_found = 1
124
- print(f"Tokenizing {loom_file_path}")
125
- file_tokenized_cells, file_cell_metadata = self.tokenize_file(
126
- loom_file_path
127
- )
128
  tokenized_cells += file_tokenized_cells
129
  if self.custom_attr_name_dict is not None:
130
- for k in loom_cell_attr:
131
  cell_metadata[self.custom_attr_name_dict[k]] += file_cell_metadata[k]
132
  else:
133
  cell_metadata = None
134
 
135
  if file_found == 0:
136
  logger.error(
137
- f"No .loom files found in directory {loom_data_directory}.")
138
  raise
139
  return tokenized_cells, cell_metadata
140
 
141
- def tokenize_file(self, loom_file_path):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
  if self.custom_attr_name_dict is not None:
143
  file_cell_metadata = {
144
  attr_key: [] for attr_key in self.custom_attr_name_dict.keys()
@@ -168,11 +258,11 @@ class TranscriptomeTokenizer:
168
  else:
169
  var_exists = True
170
 
171
- if var_exists is True:
172
  filter_pass_loc = np.where(
173
- [True if i == 1 else False for i in data.ca["filter_pass"]]
174
  )[0]
175
- elif var_exists is False:
176
  print(
177
  f"{loom_file_path} has no column attribute 'filter_pass'; tokenizing all cells."
178
  )
@@ -189,7 +279,7 @@ class TranscriptomeTokenizer:
189
  subview_norm_array = (
190
  subview[:, :]
191
  / subview.ca.n_counts
192
- * 10_000
193
  / norm_factor_vector[:, None]
194
  )
195
  # tokenize subview gene vectors
@@ -207,18 +297,25 @@ class TranscriptomeTokenizer:
207
 
208
  return tokenized_cells, file_cell_metadata
209
 
210
- def create_dataset(self, tokenized_cells, cell_metadata):
 
211
  # create dict for dataset creation
212
  dataset_dict = {"input_ids": tokenized_cells}
213
  if self.custom_attr_name_dict is not None:
214
  dataset_dict.update(cell_metadata)
215
 
216
  # create dataset
217
- output_dataset = Dataset.from_dict(dataset_dict)
 
 
 
 
 
 
218
 
219
  # truncate dataset
220
  def truncate(example):
221
- example["input_ids"] = example["input_ids"][0:2048]
222
  return example
223
 
224
  output_dataset_truncated = output_dataset.map(truncate, num_proc=self.nproc)
@@ -232,4 +329,4 @@ class TranscriptomeTokenizer:
232
  measure_length, num_proc=self.nproc
233
  )
234
 
235
- return output_dataset_truncated_w_length
 
14
  tk.tokenize_data("loom_data_directory", "output_directory", "output_prefix")
15
  """
16
 
17
+ from __future__ import annotations
18
+ from typing import Literal
19
  import pickle
20
  from pathlib import Path
21
 
 
24
  import warnings
25
  warnings.filterwarnings("ignore", message=".*The 'nopython' keyword.*")
26
 
27
+ import anndata as ad
28
  import loompy as lp
29
  import numpy as np
30
+ import scipy.sparse as sp
31
  from datasets import Dataset
32
 
33
  logger = logging.getLogger(__name__)
 
36
  TOKEN_DICTIONARY_FILE = Path(__file__).parent / "token_dictionary.pkl"
37
 
38
 
39
+ def rank_genes(gene_vector, gene_tokens):
40
+ """
41
+ Rank gene expression vector.
42
+ """
43
+ # sort by median-scaled gene values
44
+ sorted_indices = np.argsort(-gene_vector)
45
+ return gene_tokens[sorted_indices]
46
+
47
+
48
  def tokenize_cell(gene_vector, gene_tokens):
49
  """
50
  Convert normalized gene expression vector to tokenized rank value encoding.
 
52
  # create array of gene vector with token indices
53
  # mask undetected genes
54
  nonzero_mask = np.nonzero(gene_vector)[0]
55
+ # rank by median-scaled gene values
56
+ return rank_genes(gene_vector[nonzero_mask], gene_tokens[nonzero_mask])
 
 
 
57
 
58
 
59
  class TranscriptomeTokenizer:
 
102
  # protein-coding and miRNA gene list dictionary for selecting .loom rows for tokenization
103
  self.genelist_dict = dict(zip(self.gene_keys, [True] * len(self.gene_keys)))
104
 
105
+ def tokenize_data(
106
+ self,
107
+ data_directory: Path | str,
108
+ output_directory: Path | str,
109
+ output_prefix: str,
110
+ file_format: Literal["loom", "h5ad"] = "loom",
111
+ use_generator: bool = False,
112
+ ):
113
  """
114
  Tokenize .loom files in loom_data_directory and save as tokenized .dataset in output_directory.
115
 
116
  Parameters
117
  ----------
118
  loom_data_directory : Path
119
+ Path to directory containing loom files or anndata files
120
  output_directory : Path
121
  Path to directory where tokenized data will be saved as .dataset
122
  output_prefix : str
123
  Prefix for output .dataset
124
+ file_format : str
125
+ Format of input files. Can be "loom" or "h5ad".
126
+ use_generator : bool
127
+ Whether to use generator or dict for tokenization.
128
  """
129
+ tokenized_cells, cell_metadata = self.tokenize_files(
130
+ Path(data_directory), file_format
131
+ )
132
+ tokenized_dataset = self.create_dataset(tokenized_cells, cell_metadata, use_generator=use_generator)
133
 
134
  output_path = (Path(output_directory) / output_prefix).with_suffix(".dataset")
135
  tokenized_dataset.save_to_disk(output_path)
136
 
137
+ def tokenize_files(
138
+ self, data_directory, file_format: Literal["loom", "h5ad"] = "loom"
139
+ ):
140
  tokenized_cells = []
141
  if self.custom_attr_name_dict is not None:
142
+ cell_attr = [attr_key for attr_key in self.custom_attr_name_dict.keys()]
143
  cell_metadata = {attr_key: [] for attr_key in self.custom_attr_name_dict.values()}
144
 
145
  # loops through directories to tokenize .loom files
146
  file_found = 0
147
+ # loops through directories to tokenize .loom or .h5ad files
148
+ tokenize_file_fn = (
149
+ self.tokenize_loom if file_format == "loom" else self.tokenize_anndata
150
+ )
151
+ for file_path in data_directory.glob("*.{}".format(file_format)):
152
  file_found = 1
153
+ print(f"Tokenizing {file_path}")
154
+ file_tokenized_cells, file_cell_metadata = tokenize_file_fn(file_path)
 
 
155
  tokenized_cells += file_tokenized_cells
156
  if self.custom_attr_name_dict is not None:
157
+ for k in cell_attr:
158
  cell_metadata[self.custom_attr_name_dict[k]] += file_cell_metadata[k]
159
  else:
160
  cell_metadata = None
161
 
162
  if file_found == 0:
163
  logger.error(
164
+ f"No .{file_format} files found in directory {data_directory}.")
165
  raise
166
  return tokenized_cells, cell_metadata
167
 
168
+ def tokenize_anndata(self, adata_file_path, target_sum=10_000, chunk_size=512):
169
+ adata = ad.read(adata_file_path, backed="r")
170
+
171
+ if self.custom_attr_name_dict is not None:
172
+ file_cell_metadata = {
173
+ attr_key: [] for attr_key in self.custom_attr_name_dict.keys()
174
+ }
175
+
176
+ coding_miRNA_loc = np.where(
177
+ [self.genelist_dict.get(i, False) for i in adata.var["ensembl_id"]]
178
+ )[0]
179
+ norm_factor_vector = np.array(
180
+ [
181
+ self.gene_median_dict[i]
182
+ for i in adata.var["ensembl_id"][coding_miRNA_loc]
183
+ ]
184
+ )
185
+ coding_miRNA_ids = adata.var["ensembl_id"][coding_miRNA_loc]
186
+ coding_miRNA_tokens = np.array(
187
+ [self.gene_token_dict[i] for i in coding_miRNA_ids]
188
+ )
189
+
190
+ try:
191
+ _ = adata.obs["filter_pass"]
192
+ except KeyError:
193
+ var_exists = False
194
+ else:
195
+ var_exists = True
196
+
197
+ if var_exists:
198
+ filter_pass_loc = np.where(
199
+ [i == 1 for i in adata.obs["filter_pass"]]
200
+ )[0]
201
+ elif not var_exists:
202
+ print(
203
+ f"{adata_file_path} has no column attribute 'filter_pass'; tokenizing all cells."
204
+ )
205
+ filter_pass_loc = np.array([i for i in range(adata.shape[0])])
206
+
207
+ tokenized_cells = []
208
+
209
+ for i in range(0, len(filter_pass_loc), chunk_size):
210
+ idx = filter_pass_loc[i:i+chunk_size]
211
+
212
+ n_counts = adata[idx].obs['n_counts'].values[:, None]
213
+ X_view = adata[idx, coding_miRNA_loc].X
214
+ X_norm = (X_view / n_counts * target_sum / norm_factor_vector)
215
+ X_norm = sp.csr_matrix(X_norm)
216
+
217
+ tokenized_cells += [
218
+ rank_genes(X_norm[i].data, coding_miRNA_tokens[X_norm[i].indices])
219
+ for i in range(X_norm.shape[0])
220
+ ]
221
+
222
+ # add custom attributes for subview to dict
223
+ if self.custom_attr_name_dict is not None:
224
+ for k in file_cell_metadata.keys():
225
+ file_cell_metadata[k] += adata[idx].obs[k].tolist()
226
+ else:
227
+ file_cell_metadata = None
228
+
229
+ return tokenized_cells, file_cell_metadata
230
+
231
+ def tokenize_loom(self, loom_file_path, target_sum=10_000):
232
  if self.custom_attr_name_dict is not None:
233
  file_cell_metadata = {
234
  attr_key: [] for attr_key in self.custom_attr_name_dict.keys()
 
258
  else:
259
  var_exists = True
260
 
261
+ if var_exists:
262
  filter_pass_loc = np.where(
263
+ [i == 1 for i in data.ca["filter_pass"]]
264
  )[0]
265
+ elif not var_exists:
266
  print(
267
  f"{loom_file_path} has no column attribute 'filter_pass'; tokenizing all cells."
268
  )
 
279
  subview_norm_array = (
280
  subview[:, :]
281
  / subview.ca.n_counts
282
+ * target_sum
283
  / norm_factor_vector[:, None]
284
  )
285
  # tokenize subview gene vectors
 
297
 
298
  return tokenized_cells, file_cell_metadata
299
 
300
+ def create_dataset(self, tokenized_cells, cell_metadata, use_generator=False):
301
+ print("Creating dataset.")
302
  # create dict for dataset creation
303
  dataset_dict = {"input_ids": tokenized_cells}
304
  if self.custom_attr_name_dict is not None:
305
  dataset_dict.update(cell_metadata)
306
 
307
  # create dataset
308
+ if use_generator:
309
+ def dict_generator():
310
+ for i in range(len(tokenized_cells)):
311
+ yield {k: dataset_dict[k][i] for k in dataset_dict.keys()}
312
+ output_dataset = Dataset.from_generator(dict_generator, num_proc=self.nproc)
313
+ else:
314
+ output_dataset = Dataset.from_dict(dataset_dict)
315
 
316
  # truncate dataset
317
  def truncate(example):
318
+ example["input_ids"] = example["input_ids"][:2048]
319
  return example
320
 
321
  output_dataset_truncated = output_dataset.map(truncate, num_proc=self.nproc)
 
329
  measure_length, num_proc=self.nproc
330
  )
331
 
332
+ return output_dataset_truncated_w_length