ricomnl committed
Commit: b6ca566
Parent: c4b1f94

Added anndata tokenizer and switched to Dataset.from_generator

examples/tokenizing_scRNAseq_data.ipynb CHANGED
@@ -1,6 +1,7 @@
 {
  "cells": [
   {
+   "attachments": {},
    "cell_type": "markdown",
    "id": "a91bca46-c056-4784-8c6c-b0f5d3f33496",
    "metadata": {
@@ -11,6 +12,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "id": "350e6252-b783-494b-9767-f087eb868a15",
    "metadata": {},
@@ -44,7 +46,7 @@
    "outputs": [],
    "source": [
     "tk = TranscriptomeTokenizer({\"cell_type\": \"cell_type\", \"organ_major\": \"organ_major\"}, nproc=4)\n",
-    "tk.tokenize_data(\"loom_data_directory\", \"output_directory\", \"output_prefix\")"
+    "tk.tokenize_data(\"loom_data_directory\", \"output_directory\", \"output_prefix\", file_format=\"loom\")"
    ]
   }
  ],
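
With the new file_format argument shown in this notebook change, the same call can also point at a directory of AnnData (.h5ad) files. A minimal sketch of that variant (the directory names below are placeholders, not paths from this repo):

from geneformer import TranscriptomeTokenizer

tk = TranscriptomeTokenizer({"cell_type": "cell_type", "organ_major": "organ_major"}, nproc=4)
# pass file_format="h5ad" to glob *.h5ad files and route them through tokenize_anndata
tk.tokenize_data("h5ad_data_directory", "output_directory", "output_prefix", file_format="h5ad")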
geneformer/tokenizer.py CHANGED
@@ -14,6 +14,8 @@ Usage:
   tk.tokenize_data("loom_data_directory", "output_directory", "output_prefix")
 """
 
+from __future__ import annotations
+from typing import Literal
 import pickle
 from pathlib import Path
 
@@ -22,6 +24,7 @@ import logging
 import warnings
 warnings.filterwarnings("ignore", message=".*The 'nopython' keyword.*")
 
+import anndata as ad
 import loompy as lp
 import numpy as np
 from datasets import Dataset
@@ -92,26 +95,38 @@ class TranscriptomeTokenizer:
         # protein-coding and miRNA gene list dictionary for selecting .loom rows for tokenization
         self.genelist_dict = dict(zip(self.gene_keys, [True] * len(self.gene_keys)))
 
-    def tokenize_data(self, loom_data_directory, output_directory, output_prefix):
+    def tokenize_data(
+        self,
+        data_directory: Path | str,
+        output_directory: Path | str,
+        output_prefix: str,
+        file_format: Literal["loom", "h5ad"] = "loom",
+    ):
         """
         Tokenize .loom files in loom_data_directory and save as tokenized .dataset in output_directory.
 
         Parameters
         ----------
         loom_data_directory : Path
-            Path to directory containing loom files
+            Path to directory containing loom files or anndata files
         output_directory : Path
             Path to directory where tokenized data will be saved as .dataset
         output_prefix : str
             Prefix for output .dataset
+        file_format : str
+            Format of input files. Can be "loom" or "h5ad".
         """
-        tokenized_cells, cell_metadata = self.tokenize_files(Path(loom_data_directory))
+        tokenized_cells, cell_metadata = self.tokenize_files(
+            Path(data_directory), file_format
+        )
         tokenized_dataset = self.create_dataset(tokenized_cells, cell_metadata)
 
         output_path = (Path(output_directory) / output_prefix).with_suffix(".dataset")
         tokenized_dataset.save_to_disk(output_path)
 
-    def tokenize_files(self, loom_data_directory):
+    def tokenize_files(
+        self, data_directory, file_format: Literal["loom", "h5ad"] = "loom"
+    ):
         tokenized_cells = []
         if self.custom_attr_name_dict is not None:
             loom_cell_attr = [attr_key for attr_key in self.custom_attr_name_dict.keys()]
@@ -119,12 +134,14 @@ class TranscriptomeTokenizer:
 
         # loops through directories to tokenize .loom files
         file_found = 0
-        for loom_file_path in loom_data_directory.glob("*.loom"):
+        # loops through directories to tokenize .loom or .h5ad files
+        tokenize_file_fn = (
+            self.tokenize_file if file_format == "loom" else self.tokenize_anndata
+        )
+        for file_path in data_directory.glob("*.{}".format(file_format)):
             file_found = 1
-            print(f"Tokenizing {loom_file_path}")
-            file_tokenized_cells, file_cell_metadata = self.tokenize_file(
-                loom_file_path
-            )
+            print(f"Tokenizing {file_path}")
+            file_tokenized_cells, file_cell_metadata = tokenize_file_fn(file_path)
             tokenized_cells += file_tokenized_cells
             if self.custom_attr_name_dict is not None:
                 for k in loom_cell_attr:
@@ -134,10 +151,65 @@ class TranscriptomeTokenizer:
 
         if file_found == 0:
             logger.error(
-                f"No .loom files found in directory {loom_data_directory}.")
+                f"No .{file_format} files found in directory {data_directory}.")
             raise
         return tokenized_cells, cell_metadata
 
+    def tokenize_anndata(self, adata_file_path):
+        adata = ad.read(adata_file_path)
+        file_cell_metadata = {
+            attr_key: [] for attr_key in self.custom_attr_name_dict.keys()
+        }
+
+        coding_miRNA_loc = np.where(
+            [self.genelist_dict.get(i, False) for i in adata.var["ensembl_id"]]
+        )[0]
+        norm_factor_vector = np.array(
+            [
+                self.gene_median_dict[i]
+                for i in adata.var["ensembl_id"][coding_miRNA_loc]
+            ]
+        )
+        coding_miRNA_ids = adata.var["ensembl_id"][coding_miRNA_loc]
+        coding_miRNA_tokens = np.array(
+            [self.gene_token_dict[i] for i in coding_miRNA_ids]
+        )
+
+        try:
+            adata.obs["filter_pass"]
+        except KeyError:
+            var_exists = False
+        else:
+            var_exists = True
+
+        if var_exists is True:
+            filter_pass_loc = np.where(
+                [True if i == 1 else False for i in adata.obs["filter_pass"]]
+            )[0]
+        elif var_exists is False:
+            print(
+                f"{adata_file_path} has no column attribute 'filter_pass'; tokenizing all cells."
+            )
+            filter_pass_loc = np.array([i for i in range(adata.shape[0])])
+
+        tokenized_cells = []
+        adata_filter = adata[
+            filter_pass_loc, coding_miRNA_loc  # filter cells and genes
+        ]
+
+        X_norm = (adata_filter.X / adata.X.sum(1) * 10_000 / norm_factor_vector).tocsr()
+
+        tokenized_cells += [
+            tokenize_cell(X_norm[i, ...].A.flatten(), coding_miRNA_tokens)
+            for i in range(X_norm.shape[0])
+        ]
+
+        # add custom attributes for subview to dict
+        for k in file_cell_metadata.keys():
+            file_cell_metadata[k] += adata_filter.obs[k].tolist()
+
+        return tokenized_cells, file_cell_metadata
+
     def tokenize_file(self, loom_file_path):
         if self.custom_attr_name_dict is not None:
             file_cell_metadata = {
@@ -214,7 +286,13 @@ class TranscriptomeTokenizer:
         dataset_dict.update(cell_metadata)
 
         # create dataset
-        output_dataset = Dataset.from_dict(dataset_dict)
+        def dict_generator():
+            for i in range(len(tokenized_cells)):
+                yield {
+                    'input_ids': dataset_dict['input_ids'][i],
+                    'cell_type': dataset_dict['cell_type'][i]
+                }
+        output_dataset = Dataset.from_generator(dict_generator, num_proc=self.nproc)
 
         # truncate dataset
         def truncate(example):
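
For context on the second part of the commit message: Dataset.from_generator builds the Hugging Face dataset by streaming examples from a Python generator rather than materializing one large dict through Dataset.from_dict, as the dict_generator above does from dataset_dict. A minimal standalone sketch of the pattern (field names and values are illustrative only):

from datasets import Dataset

def gen():
    # yield one example dict per cell; the tokenizer pulls these from dataset_dict instead
    for i in range(3):
        yield {"input_ids": [i, i + 1, i + 2], "cell_type": "example"}

ds = Dataset.from_generator(gen)
print(ds[0])  # {'input_ids': [0, 1, 2], 'cell_type': 'example'}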