Update geneformer/tokenizer.py

- Add checks for CLS and EOS tokens when special_token = True
- More efficient filter of gene_mapping_dict for values in gene_token_dict
- Remove summing of genes that do not exist in gene_token_dict for loom files

geneformer/tokenizer.py  CHANGED  +14 -12
@@ -94,18 +94,17 @@ def sum_ensembl_ids(data_directory,
 
     if (len(set(gene_ids_collapsed_in_dict)) == len(set(gene_ids_in_dict))) and token_genes_unique:
         return data_directory
-
     else:
         dedup_filename = data_directory.with_name(data_directory.stem + "__dedup.loom")
-        data.ra["ensembl_id"] = gene_ids_collapsed
-        dup_genes = [idx for idx, count in Counter(data.ra["ensembl_id"]).items() if count > 1]
+        data.ra["gene_ids_collapsed"] = gene_ids_collapsed
+        dup_genes = [idx for idx, count in Counter(data.ra["gene_ids_collapsed"]).items() if count > 1]
         num_chunks = int(np.ceil(data.shape[1] / chunk_size))
         first_chunk = True
         for _, _, view in tqdm(data.scan(axis = 1, batch_size = chunk_size), total = num_chunks):
             def process_chunk(view, duplic_genes):
-                data_count_view = pd.DataFrame(view, index=data.ra["ensembl_id"])
+                data_count_view = pd.DataFrame(view, index=data.ra["gene_ids_collapsed"])
                 unique_data_df = data_count_view.loc[~data_count_view.index.isin(duplic_genes)]
-                dup_data_df = data_count_view.loc[data_count_view.index.isin(duplic_genes)]
+                dup_data_df = data_count_view.loc[data_count_view.index.isin([i for i in duplic_genes if "None" not in i])]
                 summed_data = dup_data_df.groupby(dup_data_df.index).sum()
                 if not summed_data.index.is_unique:
                     raise ValueError("Error: Ensembl IDs in summed data frame non-unique.")
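For reference, a minimal standalone sketch of the collapse-and-sum pattern this hunk uses: Counter flags collapsed Ensembl IDs that occur more than once, unique rows pass through untouched, and duplicate rows are summed with a pandas groupby. The matrix, IDs, and variable names below are made up for illustration only.

# Toy sketch of the duplicate-gene collapse (data and IDs are made up)
from collections import Counter

import numpy as np
import pandas as pd

counts = np.array([[1, 2], [3, 4], [5, 6]])          # genes x cells
gene_ids_collapsed = ["ENSG_A", "ENSG_A", "ENSG_B"]  # first two rows collapse to ENSG_A

# duplicated collapsed IDs, found the same way as in the hunk
dup_genes = [idx for idx, count in Counter(gene_ids_collapsed).items() if count > 1]

df = pd.DataFrame(counts, index=gene_ids_collapsed)
unique_df = df.loc[~df.index.isin(dup_genes)]                        # already-unique rows
summed_df = df.loc[df.index.isin(dup_genes)].groupby(level=0).sum()  # sum duplicate rows

processed = pd.concat([unique_df, summed_df])
print(processed)  # ENSG_B keeps [5, 6]; ENSG_A becomes [4, 6]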
@@ -117,12 +116,6 @@ def sum_ensembl_ids(data_directory,
                 processed_array = processed_chunk.to_numpy()
                 new_row_attrs = {"ensembl_id": processed_chunk.index.to_numpy()}
 
-                ra_keys = [k for k in data.ra.keys() if k != "ensembl_id"]
-                for ra_value in ra_keys:
-                    mapping_dict = dict(zip(data.ra["ensembl_id"], data.ra[ra_value]))
-                    values_new = [mapping_dict[i] for i in processed_chunk.index]
-                    new_row_attrs[ra_value] = np.array(values_new)
-
                 if "n_counts" not in view.ca.keys():
                     total_count_view = np.sum(view[:,:], axis=0).astype(int)
                     view.ca["n_counts"] = total_count_view
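The lines kept at the bottom of this hunk compute per-cell totals. A toy illustration of what `np.sum(view[:, :], axis=0)` yields (matrix values made up):

# Summing over the gene axis (axis=0) gives one total count per cell/column
import numpy as np

view = np.array([[1, 0, 2],
                 [3, 1, 0]])           # genes x cells
n_counts = np.sum(view, axis=0).astype(int)
print(n_counts)                        # [4 1 2]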
@@ -263,6 +256,14 @@ class TranscriptomeTokenizer:
         with open(token_dictionary_file, "rb") as f:
             self.gene_token_dict = pickle.load(f)
 
+        # check for special token in gene_token_dict
+        if self.special_token:
+            if ("<cls>" not in self.gene_token_dict.keys()) and ("<eos>" not in self.gene_token_dict.keys()):
+                logger.error(
+                    "<cls> and <eos> required in gene_token_dict when special_token = True."
+                )
+                raise
+
         # if collapsing duplicate gene IDs
         self.collapse_gene_ids = collapse_gene_ids
 
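A minimal sketch of the new guard, pulled out of the class with a made-up dictionary (the real gene_token_dict is unpickled from token_dictionary_file). Note that, as committed, the `and` means the error fires only when both special tokens are absent; a dictionary containing just one of them passes the check.

# Standalone sketch of the special-token validation (dictionary contents made up)
gene_token_dict = {"<pad>": 0, "<mask>": 1, "ENSG00000139618": 2}
special_token = True

if special_token:
    # mirrors the hunk: errors only when BOTH <cls> and <eos> are missing
    if ("<cls>" not in gene_token_dict) and ("<eos>" not in gene_token_dict):
        raise KeyError(
            "<cls> and <eos> required in gene_token_dict when special_token = True."
        )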
@@ -277,7 +278,8 @@ class TranscriptomeTokenizer:
         self.gene_keys = list(self.gene_token_dict.keys())
 
         # Filter gene mapping dict for items that exist in gene_token_dict
-        self.gene_mapping_dict = {k: v for k, v in self.gene_mapping_dict.items() if v in self.gene_keys}
+        gene_keys_set = set(self.gene_token_dict.keys())
+        self.gene_mapping_dict = {k: v for k, v in self.gene_mapping_dict.items() if v in gene_keys_set}
 
         # protein-coding and miRNA gene list dictionary for selecting .loom rows for tokenization
         self.genelist_dict = dict(zip(self.gene_keys, [True] * len(self.gene_keys)))
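A small illustration (all IDs made up) of why the precomputed set matters: membership tests against a set are O(1) on average versus O(n) against a list, so building gene_keys_set once makes filtering a large mapping dictionary far cheaper than testing each value against the full key list.

# Sketch of the set-based filter with toy data
gene_token_dict = {"ENSG00000139618": 10, "ENSG00000141510": 11}
gene_mapping_dict = {
    "ENSG00000139618.15": "ENSG00000139618",   # illustrative versioned -> canonical ID
    "ENSG00000141510.17": "ENSG00000141510",
    "ENSG00000000000.1": "ENSG_NOT_IN_VOCAB",  # dropped by the filter
}

gene_keys_set = set(gene_token_dict.keys())  # one-time build, O(1) lookups afterward
gene_mapping_dict = {k: v for k, v in gene_mapping_dict.items() if v in gene_keys_set}
print(sorted(gene_mapping_dict))  # ['ENSG00000139618.15', 'ENSG00000141510.17']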