Update geneformer/tokenizer.py

Update to use Ensembl ID mapped throughout.

geneformer/tokenizer.py  (+14 -30)  CHANGED
@@ -63,17 +63,6 @@ logger = logging.getLogger(__name__)
 
 from . import ENSEMBL_MAPPING_FILE, GENE_MEDIAN_FILE, TOKEN_DICTIONARY_FILE
 
-def rename_attr(data_ra_or_ca, old_name, new_name):
-    """ Rename attributes
-    Args:
-        data_ra_or_ca: data as a record array or column attribute
-        old_name (str): old name of attribute
-        new_name (str): new name of attribute
-    """
-    data_ra_or_ca[new_name] = data_ra_or_ca[old_name]
-    if new_name != old_name:
-        del data_ra_or_ca[old_name]
-
 def rank_genes(gene_vector, gene_tokens):
     """
     Rank gene expression vector.
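The deleted rename_attr helper existed only to stash the original IDs under a new key before overwriting "ensembl_id" in place. Because the mapped IDs are now written to a separate "ensembl_id_collapsed" attribute, the original attribute never needs renaming. A minimal sketch of the difference, using a plain dict as a stand-in for loompy's dict-like row attributes (the IDs here are hypothetical):

import numpy as np

# Stand-in for loompy row attributes (data.ra behaves like a dict of arrays)
ra = {"ensembl_id": np.array(["ENSG0000A.1", "ENSG0000B"])}

# Old approach (removed): copy under a backup name, then overwrite in place
# ra["ensembl_id_original"] = ra["ensembl_id"]; ra["ensembl_id"] = collapsed

# New approach: leave "ensembl_id" untouched and add the mapped IDs alongside
ra["ensembl_id_collapsed"] = np.array(["ENSG0000A", "ENSG0000B"])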
@@ -131,18 +120,16 @@ def sum_ensembl_ids(
             ]
 
             if len(set(gene_ids_in_dict)) == len(set(gene_ids_collapsed_in_dict)):
-
-                rename_attr(data.ra, "ensembl_id", "ensembl_id_original")
-                data.ra["ensembl_id"] = gene_ids_collapsed
+                data.ra["ensembl_id_collapsed"] = gene_ids_collapsed
                 return data_directory
             else:
                 dedup_filename = data_directory.with_name(
                     data_directory.stem + "__dedup.loom"
                 )
-                data.ra["gene_ids_collapsed"] = gene_ids_collapsed
+                data.ra["ensembl_id_collapsed"] = gene_ids_collapsed
                 dup_genes = [
                     idx
-                    for idx, count in Counter(data.ra["gene_ids_collapsed"]).items()
+                    for idx, count in Counter(data.ra["ensembl_id_collapsed"]).items()
                     if count > 1
                 ]
                 num_chunks = int(np.ceil(data.shape[1] / chunk_size))
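The duplicate check works because gene_ids_collapsed is the input IDs passed through the Ensembl mapping, so any two input genes that map onto the same target ID appear as a repeated value. A self-contained sketch of that logic with hypothetical IDs and a hypothetical mapping dict:

from collections import Counter

# Hypothetical mapping: two input IDs collapse onto ENSG0000A
gene_mapping_dict = {
    "ENSG0000A": "ENSG0000A",
    "ENSG0000A.1": "ENSG0000A",
    "ENSG0000B": "ENSG0000B",
}
gene_ids = ["ENSG0000A", "ENSG0000A.1", "ENSG0000B"]
gene_ids_collapsed = [gene_mapping_dict.get(g) for g in gene_ids]

# IDs that occur more than once after collapsing need their counts summed
dup_genes = [idx for idx, count in Counter(gene_ids_collapsed).items() if count > 1]
print(dup_genes)  # ['ENSG0000A']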
@@ -153,7 +140,7 @@ def sum_ensembl_ids(
 
                 def process_chunk(view, duplic_genes):
                     data_count_view = pd.DataFrame(
-                        view, index=data.ra["gene_ids_collapsed"]
+                        view, index=data.ra["ensembl_id_collapsed"]
                     )
                     unique_data_df = data_count_view.loc[
                         ~data_count_view.index.isin(duplic_genes)
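process_chunk indexes each chunk of the count matrix by the collapsed IDs so duplicate rows can be split out and summed. The hunk shows only the DataFrame construction and the unique-row split; the groupby below is an assumed illustration of the summing step, with a toy 3-gene by 2-cell matrix where the first two rows share a collapsed ID:

import pandas as pd

counts = pd.DataFrame([[1, 2], [3, 4], [5, 6]],
                      index=["ENSG0000A", "ENSG0000A", "ENSG0000B"])
dup_genes = ["ENSG0000A"]

# Rows with unique IDs pass through; duplicated rows are summed per ID
unique_df = counts.loc[~counts.index.isin(dup_genes)]
summed_df = counts.loc[counts.index.isin(dup_genes)].groupby(level=0).sum()
collapsed = pd.concat([unique_df, summed_df])
print(collapsed)  # ENSG0000B keeps [5, 6]; ENSG0000A becomes [4, 6]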
@@ -179,7 +166,7 @@ def sum_ensembl_ids(
 
                     processed_chunk = process_chunk(view[:, :], dup_genes)
                     processed_array = processed_chunk.to_numpy()
-                    new_row_attrs = {"gene_ids_collapsed": processed_chunk.index.to_numpy()}
+                    new_row_attrs = {"ensembl_id_collapsed": processed_chunk.index.to_numpy()}
 
                     if "n_counts" not in view.ca.keys():
                         total_count_view = np.sum(view[:, :], axis=0).astype(int)
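new_row_attrs supplies the collapsed IDs as the row attributes of the deduplicated loom file, so downstream code can read data.ra["ensembl_id_collapsed"] directly. A minimal sketch of writing such a file with loompy's create function (the matrix, IDs, and output path are hypothetical):

import loompy as lp
import numpy as np

# Hypothetical collapsed counts: 2 genes x 2 cells
matrix = np.array([[4, 6], [5, 6]])
new_row_attrs = {"ensembl_id_collapsed": np.array(["ENSG0000A", "ENSG0000B"])}
col_attrs = {"n_counts": matrix.sum(axis=0).astype(int)}

# lp.create writes a fresh loom file from a matrix plus row/column attributes
lp.create("dedup_example.loom", matrix, new_row_attrs, col_attrs)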
@@ -230,11 +217,11 @@ def sum_ensembl_ids(
            gene for gene in gene_ids_collapsed if gene in gene_token_dict.keys()
        ]
        if len(set(gene_ids_in_dict)) == len(set(gene_ids_collapsed_in_dict)):
-           data.var.index = gene_ids_collapsed
+           data.var["ensembl_id_collapsed"] = data.var.ensembl_id.map(gene_mapping_dict)
            return data
 
        else:
-           data.var["gene_ids_collapsed"] = gene_ids_collapsed
+           data.var["ensembl_id_collapsed"] = gene_ids_collapsed
            data.var_names = gene_ids_collapsed
            data = data[:, ~data.var.index.isna()]
            dup_genes = [
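On the AnnData path the mapping is applied with pandas Series.map, which yields NaN for any ID absent from gene_mapping_dict; the ~data.var.index.isna() filter in the else branch relies on exactly that. A sketch with a toy AnnData and hypothetical IDs:

import anndata as ad
import numpy as np
import pandas as pd

# Toy AnnData: 2 cells x 3 genes, hypothetical Ensembl IDs
data = ad.AnnData(
    X=np.array([[1.0, 3.0, 5.0], [2.0, 4.0, 6.0]]),
    var=pd.DataFrame({"ensembl_id": ["ENSG0000A", "ENSG0000A.1", "ENSG0000B"]}),
)
gene_mapping_dict = {
    "ENSG0000A": "ENSG0000A",
    "ENSG0000A.1": "ENSG0000A",
    "ENSG0000B": "ENSG0000B",
}

# Series.map leaves NaN for IDs missing from the mapping
data.var["ensembl_id_collapsed"] = data.var.ensembl_id.map(gene_mapping_dict)
print(data.var["ensembl_id_collapsed"].tolist())
# ['ENSG0000A', 'ENSG0000A', 'ENSG0000B']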
@@ -265,16 +252,13 @@ def sum_ensembl_ids(
                processed_chunks = pd.concat(processed_chunks, axis=1)
                processed_genes.append(processed_chunks)
            processed_genes = pd.concat(processed_genes, axis=0)
-           var_df = pd.DataFrame({"gene_ids_collapsed": processed_genes.columns})
+           var_df = pd.DataFrame({"ensembl_id_collapsed": processed_genes.columns})
            var_df.index = processed_genes.columns
            processed_genes = sc.AnnData(X=processed_genes, obs=data.obs, var=var_df)
 
            data_dedup = data[:, ~data.var.index.isin(dup_genes)]  # Deduplicated data
            data_dedup = sc.concat([data_dedup, processed_genes], axis=1)
            data_dedup.obs = data.obs
-           data_dedup.var = data_dedup.var.rename(
-               columns={"gene_ids_collapsed": "ensembl_id"}
-           )
            return data_dedup
 
 
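With the collapsed IDs kept under one column name from the start, the final rename of gene_ids_collapsed back to ensembl_id disappears; the unique and summed parts are simply concatenated along the gene axis. A sketch of that assembly step with two toy AnnData blocks (all names hypothetical):

import anndata as ad
import numpy as np
import pandas as pd
import scanpy as sc

# Unique genes and collapsed duplicates, each 2 cells x 1 gene
unique = ad.AnnData(X=np.array([[5.0], [6.0]]),
                    var=pd.DataFrame({"ensembl_id_collapsed": ["ENSG0000B"]},
                                     index=["ENSG0000B"]))
summed = ad.AnnData(X=np.array([[4.0], [6.0]]),
                    var=pd.DataFrame({"ensembl_id_collapsed": ["ENSG0000A"]},
                                     index=["ENSG0000A"]))

# axis=1 concatenates along the variable (gene) axis; obs indices must align
data_dedup = sc.concat([unique, summed], axis=1)
print(data_dedup.var["ensembl_id_collapsed"].tolist())
# ['ENSG0000B', 'ENSG0000A']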
@@ -474,15 +458,15 @@ class TranscriptomeTokenizer:
        }
 
        coding_miRNA_loc = np.where(
-           [self.genelist_dict.get(i, False) for i in adata.var["ensembl_id"]]
+           [self.genelist_dict.get(i, False) for i in adata.var["ensembl_id_collapsed"]]
        )[0]
        norm_factor_vector = np.array(
            [
                self.gene_median_dict[i]
-               for i in adata.var["ensembl_id"][coding_miRNA_loc]
+               for i in adata.var["ensembl_id_collapsed"][coding_miRNA_loc]
            ]
        )
-       coding_miRNA_ids = adata.var["ensembl_id"][coding_miRNA_loc]
+       coding_miRNA_ids = adata.var["ensembl_id_collapsed"][coding_miRNA_loc]
        coding_miRNA_tokens = np.array(
            [self.gene_token_dict[i] for i in coding_miRNA_ids]
        )
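The tokenizer now reads ensembl_id_collapsed everywhere it previously read ensembl_id, on both the AnnData path above and the loom path in the next hunk. A compact, self-contained sketch of the selection logic, with hypothetical filter, median, and token dictionaries:

import numpy as np
import pandas as pd

genelist_dict = {"ENSG0000A": True, "ENSG0000B": True}
gene_median_dict = {"ENSG0000A": 2.0, "ENSG0000B": 4.0}
gene_token_dict = {"ENSG0000A": 10, "ENSG0000B": 11}

var = pd.DataFrame({"ensembl_id_collapsed": ["ENSG0000A", "ENSG0000X", "ENSG0000B"]})

# Positions of genes present in the protein-coding/miRNA filter
coding_miRNA_loc = np.where(
    [genelist_dict.get(i, False) for i in var["ensembl_id_collapsed"]]
)[0]
coding_miRNA_ids = var["ensembl_id_collapsed"][coding_miRNA_loc]
norm_factor_vector = np.array([gene_median_dict[i] for i in coding_miRNA_ids])
coding_miRNA_tokens = np.array([gene_token_dict[i] for i in coding_miRNA_ids])
print(coding_miRNA_loc, coding_miRNA_tokens)  # [0 2] [10 11]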
@@ -546,15 +530,15 @@ class TranscriptomeTokenizer:
        with lp.connect(str(loom_file_path)) as data:
            # define coordinates of detected protein-coding or miRNA genes and vector of their normalization factors
            coding_miRNA_loc = np.where(
-               [self.genelist_dict.get(i, False) for i in data.ra["ensembl_id"]]
+               [self.genelist_dict.get(i, False) for i in data.ra["ensembl_id_collapsed"]]
            )[0]
            norm_factor_vector = np.array(
                [
                    self.gene_median_dict[i]
-                   for i in data.ra["ensembl_id"][coding_miRNA_loc]
+                   for i in data.ra["ensembl_id_collapsed"][coding_miRNA_loc]
                ]
            )
-           coding_miRNA_ids = data.ra["ensembl_id"][coding_miRNA_loc]
+           coding_miRNA_ids = data.ra["ensembl_id_collapsed"][coding_miRNA_loc]
            coding_miRNA_tokens = np.array(
                [self.gene_token_dict[i] for i in coding_miRNA_ids]
            )