Add function for summing duplicate Ensembl IDs
geneformer/tokenizer.py  CHANGED  (+135 −4)
@@ -36,14 +36,21 @@ Geneformer tokenizer.
 
 from __future__ import annotations
 
+import os
 import logging
 import pickle
+import sys
 import warnings
 from pathlib import Path
 from typing import Literal
+from tqdm import tqdm
+from collections import Counter
 
-import anndata as ad
 import numpy as np
+import scanpy as sc
+import loompy as lp
+import pandas as pd
+import anndata as ad
 import scipy.sparse as sp
 from datasets import Dataset
 
@@ -52,7 +59,7 @@ import loompy as lp # noqa
 
 logger = logging.getLogger(__name__)
 
-from . import GENE_MEDIAN_FILE, TOKEN_DICTIONARY_FILE
+from . import GENE_MEDIAN_FILE, TOKEN_DICTIONARY_FILE, ENSEMBL_MAPPING_FILE
 
 
 def rank_genes(gene_vector, gene_tokens):
@@ -74,6 +81,115 @@ def tokenize_cell(gene_vector, gene_tokens):
     # rank by median-scaled gene values
     return rank_genes(gene_vector[nonzero_mask], gene_tokens[nonzero_mask])
 
+def sum_ensembl_ids(data_directory,
+                    gene_mapping_dict,
+                    file_format = "loom",
+                    chunk_size = 512):
+    if file_format == "loom":
+        """
+        Map Ensembl IDs from gene mapping dictionary. If duplicate Ensembl IDs are found, sum counts together.
+        """
+        with lp.connect(data_directory) as data:
+            assert "ensembl_id" in data.ra.keys(), "'ensembl_id' column missing from data.ra.keys()"
+            gene_ids_collapsed = [gene_mapping_dict.get(gene_id.upper()) for gene_id in data.ra.ensembl_id]
+
+            if len(set(gene_ids_collapsed)) == len(set(data.ra.ensembl_id)):
+                return data_directory
+
+            else:
+                dedup_filename = data_directory.with_name(data_directory.stem + "__dedup.loom")
+                dup_genes = [idx for idx, count in Counter(data.ra["ensembl_id"]).items() if count > 1]
+                num_chunks = int(np.ceil(data.shape[1] / chunk_size))
+                first_chunk = True
+                for _, _, view in tqdm(data.scan(axis = 1, batch_size = chunk_size), total = num_chunks):
+                    def process_chunk(view, duplic_genes):
+                        data_count_view = pd.DataFrame(view, index=data.ra["ensembl_id"])
+                        unique_data_df = data_count_view.loc[~data_count_view.index.isin(duplic_genes)]
+                        dup_data_df = data_count_view.loc[data_count_view.index.isin(duplic_genes)]
+                        summed_data = dup_data_df.groupby(dup_data_df.index).sum()
+                        if not summed_data.index.is_unique:
+                            raise ValueError("Error: summed data frame non-unique.")
+                        data_count_view = pd.concat([unique_data_df, summed_data], axis=0)
+                        if not data_count_view.index.is_unique:
+                            raise ValueError("Error: final data frame non-unique.")
+                        return data_count_view
+                    processed_chunk = process_chunk(view[:, :], dup_genes)
+                    processed_array = processed_chunk.to_numpy()
+                    new_row_attrs = {"ensembl_id": processed_chunk.index.to_numpy()}
+
+                    ra_keys = [k for k in data.ra.keys() if k != "ensembl_id"]
+                    for ra_value in ra_keys:
+                        mapping_dict = dict(zip(data.ra["ensembl_id"], data.ra[ra_value]))
+                        values_new = [mapping_dict[i] for i in processed_chunk.index]
+                        new_row_attrs[ra_value] = np.array(values_new)
+
+                    if "n_counts" not in view.ca.keys():
+                        total_count_view = np.sum(view[:,:], axis=0).astype(int)
+                        view.ca["n_counts"] = total_count_view
+
+                    if first_chunk: # Create the Loom file with the first chunk
+                        lp.create(f"{dedup_filename}", processed_array, row_attrs=new_row_attrs, col_attrs=view.ca)
+                        first_chunk = False
+                    else: # Append subsequent chunks
+                        with lp.connect(dedup_filename, mode='r+') as dsout:
+                            dsout.add_columns(processed_array, col_attrs=view.ca)
+                return dedup_filename
+
+    elif file_format == "h5ad":
+        """
+        Map Ensembl IDs from gene mapping dictionary. If duplicate Ensembl IDs are found, sum counts together.
+        Returns adata object with deduplicated Ensembl IDs.
+        """
+
+        data = sc.read_h5ad(str(data_directory))
+
+        assert "ensembl_id" in data.var.columns, "'ensembl_id' column missing from data.var"
+
+        gene_ids_collapsed = [gene_mapping_dict.get(gene_id.upper()) for gene_id in data.var.ensembl_id]
+
+        if len(set(gene_ids_collapsed)) == len(set(data.var.ensembl_id)):
+            return data
+
+        else:
+            data.var["gene_ids_collapsed"] = gene_ids_collapsed
+            data.var_names = gene_ids_collapsed
+            data = data[:, ~data.var.index.isna()]
+            dup_genes = [idx for idx, count in Counter(data.var_names).items() if count > 1]
+
+            num_chunks = int(np.ceil(data.shape[0] / chunk_size))
+
+            processed_genes = []
+            for i in tqdm(range(num_chunks)):
+
+                start_idx = i * chunk_size
+                end_idx = min((i + 1) * chunk_size, data.shape[0])
+                data_chunk = data[start_idx:end_idx, :]
+
+                processed_chunks = []
+                for dup_gene in dup_genes:
+                    data_dup_gene = data_chunk[:, data_chunk.var_names == dup_gene]
+                    df = pd.DataFrame.sparse.from_spmatrix(data_dup_gene.X,
+                                                           index=data_dup_gene.obs_names,
+                                                           columns=data_dup_gene.var_names)
+                    df_sum = pd.DataFrame(df.sum(axis=1))
+                    df_sum.columns = [dup_gene]
+                    df_sum.index = data_dup_gene.obs.index
+                    processed_chunks.append(df_sum)
+
+                processed_chunks = pd.concat(processed_chunks, axis=1)
+                processed_genes.append(processed_chunks)
+            processed_genes = pd.concat(processed_genes, axis = 0)
+            var_df = pd.DataFrame({"gene_ids_collapsed" : processed_genes.columns})
+            var_df.index = processed_genes.columns
+            processed_genes = sc.AnnData(X = processed_genes,
+                                         obs = data.obs,
+                                         var = var_df)
+
+            data_dedup = data[:, ~data.var.index.isin(dup_genes)] # Deduplicated data
+            data_dedup = sc.concat([data_dedup, processed_genes], axis = 1)
+            data_dedup.obs = data.obs
+            data_dedup.var = data_dedup.var.rename(columns = {"gene_ids_collapsed" : "ensembl_id"})
+            return data_dedup
 
 class TranscriptomeTokenizer:
     def __init__(
@@ -85,6 +201,7 @@ class TranscriptomeTokenizer:
         special_token=False,
         gene_median_file=GENE_MEDIAN_FILE,
         token_dictionary_file=TOKEN_DICTIONARY_FILE,
+        gene_mapping_file=ENSEMBL_MAPPING_FILE,
     ):
         """
         Initialize tokenizer.
@@ -103,11 +220,15 @@ class TranscriptomeTokenizer:
             | Max input size of model to truncate input to.
         special_token : bool = False
            | Adds CLS token before and EOS token after rank value encoding.
+        collapse_gene_ids : bool = False
+            | Whether to collapse gene IDs based on gene mapping dictionary.
         gene_median_file : Path
             | Path to pickle file containing dictionary of non-zero median
             | gene expression values across Genecorpus-30M.
         token_dictionary_file : Path
             | Path to pickle file containing token dictionary (Ensembl IDs:token).
+        gene_mapping_file : Path
+            | Path to pickle file containing dictionary for collapsing gene IDs.
 
         """
         # dictionary of custom attributes {output dataset column name: input .loom column name}
@@ -134,6 +255,10 @@ class TranscriptomeTokenizer:
         with open(token_dictionary_file, "rb") as f:
             self.gene_token_dict = pickle.load(f)
 
+        # load gene mappings dictionary (Ensembl IDs:Ensembl ID)
+        with open(gene_mapping_file, "rb") as f:
+            self.gene_mapping_dict = pickle.load(f)
+
         # gene keys for full vocabulary
         self.gene_keys = list(self.gene_token_dict.keys())
 
@@ -214,7 +339,7 @@ class TranscriptomeTokenizer:
         return tokenized_cells, cell_metadata
 
     def tokenize_anndata(self, adata_file_path, target_sum=10_000):
-        adata =
+        adata = sum_ensembl_ids(adata_file_path, self.gene_mapping_dict, file_format = "h5ad", chunk_size = self.chunk_size)
 
         if self.custom_attr_name_dict is not None:
             file_cell_metadata = {
@@ -256,7 +381,8 @@ class TranscriptomeTokenizer:
             idx = filter_pass_loc[i : i + self.chunk_size]
 
             n_counts = adata[idx].obs["n_counts"].values[:, None]
-
+            X_view0 = adata[idx,:].X
+            X_view = X_view0[:, coding_miRNA_loc]
             X_norm = X_view / n_counts * target_sum / norm_factor_vector
             X_norm = sp.csr_matrix(X_norm)
 
@@ -280,6 +406,8 @@ class TranscriptomeTokenizer:
                 attr_key: [] for attr_key in self.custom_attr_name_dict.keys()
             }
 
+        loom_file_path = sum_ensembl_ids(loom_file_path, self.gene_mapping_dict, file_format = "loom", chunk_size = self.chunk_size)
+
         with lp.connect(str(loom_file_path)) as data:
             # define coordinates of detected protein-coding or miRNA genes and vector of their normalization factors
             coding_miRNA_loc = np.where(
@@ -341,6 +469,9 @@ class TranscriptomeTokenizer:
         else:
             file_cell_metadata = None
 
+        if "__dedup" in str(loom_file_path):
+            os.remove(str(loom_file_path))
+
         return tokenized_cells, file_cell_metadata
 
     def create_dataset(
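For orientation, a minimal usage sketch follows; it is not part of the diff above and assumes the rest of the TranscriptomeTokenizer API is unchanged. The directory paths and the cell_type attribute are hypothetical. After this change, tokenize_loom and tokenize_anndata first pass the input through sum_ensembl_ids, which collapses duplicate Ensembl IDs by summing their counts using the mapping loaded from gene_mapping_file (defaulting to the packaged ENSEMBL_MAPPING_FILE).

    from geneformer import TranscriptomeTokenizer

    # Hypothetical paths and metadata column; gene_mapping_file may be omitted
    # to use the default ENSEMBL_MAPPING_FILE shipped with the package.
    tk = TranscriptomeTokenizer(
        custom_attr_name_dict={"cell_type": "cell_type"},
        nproc=4,
        # gene_mapping_file="custom_ensembl_mapping.pkl",  # optional override (hypothetical file)
    )
    tk.tokenize_data("input_h5ad_dir", "output_dir", "my_dataset", file_format="h5ad")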