"""
Geneformer tokenizer.

Input data:
Required format: raw counts scRNAseq data without feature selection as .loom file
Required row (gene) attribute: "ensembl_id"; Ensembl ID for each gene
Required col (cell) attribute: "n_counts"; total read counts in that cell
Optional col (cell) attribute: "filter_pass"; binary indicator of whether cell should be tokenized based on user-defined filtering criteria
Optional col (cell) attributes: any other cell metadata can be passed on to the tokenized dataset as a custom attribute dictionary as shown below
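
For example, a compatible .loom file could be written with loompy as follows
(a minimal sketch; the counts matrix, Ensembl IDs, and metadata values are placeholders):

import loompy as lp
import numpy as np

counts = np.array([[4, 0, 7], [0, 2, 1]])  # genes x cells, raw counts
row_attrs = {"ensembl_id": np.array(["ENSG00000000003", "ENSG00000000005"])}
col_attrs = {
    "n_counts": counts.sum(axis=0),
    "cell_type": np.array(["T cell", "B cell", "T cell"]),
}
lp.create("my_data.loom", counts, row_attrs, col_attrs)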

Usage:

from geneformer import TranscriptomeTokenizer

tk = TranscriptomeTokenizer({"cell_type": "cell_type", "organ_major": "organ_major"}, nproc=4)
tk.tokenize_data("loom_data_directory", "output_directory", "output_prefix")
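
To tokenize without custom cell attributes, omit the attribute dictionary
(custom_attr_name_dict defaults to None):

tk = TranscriptomeTokenizer(nproc=4)
tk.tokenize_data("loom_data_directory", "output_directory", "output_prefix")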
"""


import pickle
from pathlib import Path

import loompy as lp
import numpy as np
from datasets import Dataset

GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary.pkl"
TOKEN_DICTIONARY_FILE = Path(__file__).parent / "token_dictionary.pkl"


def tokenize_cell(gene_vector, gene_tokens):
    """
    Convert normalized gene expression vector to tokenized rank value encoding.
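
    For example (illustrative values), a normalized gene_vector of [0.0, 2.5, 0.1]
    with gene_tokens [10, 11, 12] yields [11, 12]: undetected genes are dropped and
    the remaining tokens are ordered by descending normalized expression.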
    """
    # mask genes that are not detected in this cell
    nonzero_mask = np.nonzero(gene_vector)[0]
    # rank detected genes by descending median-normalized expression
    sorted_indices = np.argsort(-gene_vector[nonzero_mask])
    # map the ranked genes to their token ids
    sentence_tokens = gene_tokens[nonzero_mask][sorted_indices]
    return sentence_tokens


class TranscriptomeTokenizer:
    def __init__(
        self,
        custom_attr_name_dict=None,
        nproc=1,
        gene_median_file=GENE_MEDIAN_FILE,
        token_dictionary_file=TOKEN_DICTIONARY_FILE,
    ):
        """
        Initialize tokenizer.

        Parameters
        ----------
        custom_attr_name_dict : None, dict
            Dictionary of custom attributes to be added to the dataset.
            Keys are the names of the attributes in the loom file.
            Values are the names of the attributes in the dataset.
        nproc : int
            Number of processes to use for dataset mapping.
        gene_median_file : Path
            Path to pickle file containing dictionary of non-zero median
            gene expression values across Genecorpus-30M.
        token_dictionary_file : Path
            Path to pickle file containing token dictionary (Ensembl IDs:token).
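
        Both pickles are plain Python dictionaries keyed by Ensembl ID;
        for illustration only (values are placeholders):

        gene_median_dict = {"ENSG00000000003": 1.37, "ENSG00000000005": 0.02, ...}
        gene_token_dict = {"ENSG00000000003": 2, "ENSG00000000005": 3, ...}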
        """
        self.custom_attr_name_dict = custom_attr_name_dict
        self.nproc = nproc

        # load dictionary of gene normalization factors
        # (non-zero median of expression across Genecorpus-30M)
        with open(gene_median_file, "rb") as f:
            self.gene_median_dict = pickle.load(f)

        # load token dictionary (Ensembl IDs : token)
        with open(token_dictionary_file, "rb") as f:
            self.gene_token_dict = pickle.load(f)

        # gene keys for full vocabulary
        self.gene_keys = list(self.gene_median_dict.keys())

        # dictionary marking the genes eligible for tokenization,
        # used to select the matching rows of each .loom file
        self.genelist_dict = dict(zip(self.gene_keys, [True] * len(self.gene_keys)))

    def tokenize_data(self, loom_data_directory, output_directory, output_prefix):
        """
        Tokenize .loom files in loom_data_directory and save as tokenized .dataset in output_directory.

        Parameters
        ----------
        loom_data_directory : Path
            Path to directory containing loom files
        output_directory : Path
            Path to directory where tokenized data will be saved as .dataset
        output_prefix : str
            Prefix for output .dataset
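
        The output is written to <output_directory>/<output_prefix>.dataset and can be
        reloaded later, for example (paths are placeholders matching the usage example
        in the module docstring):

        from datasets import load_from_disk
        ds = load_from_disk("output_directory/output_prefix.dataset")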
        """
        tokenized_cells, cell_metadata = self.tokenize_files(Path(loom_data_directory))
        tokenized_dataset = self.create_dataset(tokenized_cells, cell_metadata)

        output_path = (Path(output_directory) / output_prefix).with_suffix(".dataset")
        tokenized_dataset.save_to_disk(output_path)

    def tokenize_files(self, loom_data_directory):
        tokenized_cells = []
        if self.custom_attr_name_dict is not None:
            loom_cell_attr = [attr_key for attr_key in self.custom_attr_name_dict.keys()]
            cell_metadata = {attr_key: [] for attr_key in self.custom_attr_name_dict.values()}
        else:
            # no custom attributes requested, so no cell metadata is collected
            cell_metadata = None

        # tokenize each .loom file in the input directory
        for loom_file_path in loom_data_directory.glob("*.loom"):
            print(f"Tokenizing {loom_file_path}")
            file_tokenized_cells, file_cell_metadata = self.tokenize_file(loom_file_path)
            tokenized_cells += file_tokenized_cells
            if self.custom_attr_name_dict is not None:
                # map loom attribute names to the requested dataset attribute names
                for k in loom_cell_attr:
                    cell_metadata[self.custom_attr_name_dict[k]] += file_cell_metadata[k]

        return tokenized_cells, cell_metadata

    def tokenize_file(self, loom_file_path):
        if self.custom_attr_name_dict is not None:
            file_cell_metadata = {
                attr_key: [] for attr_key in self.custom_attr_name_dict.keys()
            }
        else:
            file_cell_metadata = None

        with lp.connect(str(loom_file_path)) as data:
            # locate the genes in this file that are part of the tokenizer vocabulary
            # and assemble their normalization factors and token ids
            coding_miRNA_loc = np.where(
                [self.genelist_dict.get(i, False) for i in data.ra["ensembl_id"]]
            )[0]
            norm_factor_vector = np.array(
                [
                    self.gene_median_dict[i]
                    for i in data.ra["ensembl_id"][coding_miRNA_loc]
                ]
            )
            coding_miRNA_ids = data.ra["ensembl_id"][coding_miRNA_loc]
            coding_miRNA_tokens = np.array(
                [self.gene_token_dict[i] for i in coding_miRNA_ids]
            )

            # if present, use the "filter_pass" column attribute to restrict tokenization
            # to cells that passed the user-defined filtering criteria
            try:
                data.ca["filter_pass"]
            except AttributeError:
                var_exists = False
            else:
                var_exists = True

            if var_exists:
                filter_pass_loc = np.where(
                    [i == 1 for i in data.ca["filter_pass"]]
                )[0]
            else:
                print(
                    f"{loom_file_path} has no column attribute 'filter_pass'; tokenizing all cells."
                )
                filter_pass_loc = np.arange(data.shape[1])

            # scan through the passing cells and tokenize each one
            tokenized_cells = []
            for (_ix, _selection, view) in data.scan(items=filter_pass_loc, axis=1):
                # restrict the view to the genes in the tokenizer vocabulary
                subview = view.view[coding_miRNA_loc, :]

                # normalize each cell by its total counts, scale by 10,000,
                # then divide each gene by its non-zero median expression
                # across Genecorpus-30M
                subview_norm_array = (
                    subview[:, :]
                    / subview.ca.n_counts
                    * 10_000
                    / norm_factor_vector[:, None]
                )

                # convert each normalized cell vector to its rank value encoding
                tokenized_cells += [
                    tokenize_cell(subview_norm_array[:, i], coding_miRNA_tokens)
                    for i in range(subview_norm_array.shape[1])
                ]

                # collect the requested custom cell attributes for this batch of cells
                if self.custom_attr_name_dict is not None:
                    for k in file_cell_metadata.keys():
                        file_cell_metadata[k] += subview.ca[k].tolist()

        return tokenized_cells, file_cell_metadata

    def create_dataset(self, tokenized_cells, cell_metadata):
        # assemble the dataset columns: rank value encodings plus any custom cell attributes
        dataset_dict = {"input_ids": tokenized_cells}
        if self.custom_attr_name_dict is not None:
            dataset_dict.update(cell_metadata)

        # create Hugging Face dataset
        output_dataset = Dataset.from_dict(dataset_dict)

        # truncate rank value encodings to 2,048 tokens (Geneformer's maximum input size)
        def truncate(example):
            example["input_ids"] = example["input_ids"][0:2048]
            return example

        output_dataset_truncated = output_dataset.map(truncate, num_proc=self.nproc)

        # record the length of each rank value encoding
        def measure_length(example):
            example["length"] = len(example["input_ids"])
            return example

        output_dataset_truncated_w_length = output_dataset_truncated.map(
            measure_length, num_proc=self.nproc
        )

        return output_dataset_truncated_w_length