How to fine-tune your model for data imputation

#507
by cjuigne - opened

Congratulations on your very interesting work!

As the model is pretrained with mask tokens to recover the initial input data, I wonder if I can use it for data imputation. I’ve tried to fine-tune the model with

```python
from transformers import AutoModelForMaskedLM

model_checkpoint = "ctheodoris/Geneformer"
model = AutoModelForMaskedLM.from_pretrained(model_checkpoint)
```

But I'm having trouble feeding the data in correctly (I first tokenized the training data with TranscriptomeTokenizer and passed a GeneformerPreCollator to the collator, but training doesn't work).
Do you think this is possible? Is it even necessary to fine-tune it?

Thank you in advance.

Thanks for your question. We have not used Geneformer for data imputation; the masked-learning objective is intended for the model to gain generalizable knowledge of gene network dynamics, rather than being an end goal in itself. However, if you are interested in testing this, fine-tuning may not be necessary if your setting is already represented in the pretraining corpus. You can load the model as BertForMaskedLM, put it in eval mode, and obtain predictions for the masked genes with a forward pass through the model.
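In outline, that eval-mode forward pass might look like the sketch below. This is not the repository's own code: the tiny randomly initialized config and the mask token id of 1 are stand-in assumptions (check the model's token dictionary); in practice you would load the actual checkpoint with `BertForMaskedLM.from_pretrained("ctheodoris/Geneformer")`.

```python
import torch
from transformers import BertConfig, BertForMaskedLM

# In practice: model = BertForMaskedLM.from_pretrained("ctheodoris/Geneformer")
# A tiny randomly initialized config stands in here so the mechanics run end to end.
config = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=1,
                    num_attention_heads=2, intermediate_size=64)
model = BertForMaskedLM(config)
model.eval()

MASK_ID = 1  # assumed mask token id; check the model's token dictionary

# One cell's rank-value-encoded gene tokens, with the positions to impute masked
input_ids = torch.tensor([[5, MASK_ID, 42, MASK_ID, 7]])
attention_mask = torch.ones_like(input_ids)

with torch.no_grad():
    logits = model(input_ids=input_ids, attention_mask=attention_mask).logits

# Most-likely gene token at each masked position
mask_positions = input_ids == MASK_ID
predicted = logits.argmax(dim=-1)[mask_positions]
```

With a trained checkpoint, `predicted` would hold the imputed gene token for each masked position; with this random stand-in it is meaningless and only demonstrates the shapes involved.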

ctheodoris changed discussion status to closed

Thank you so much for your reply!
Ok, I'll try! Which path should I provide as the correct path for model_directory?

It depends which model you want to use, but the current default model is in the main directory of this repository.

Thank you. If I understand correctly, after tokenizing my data and adding some masks (0), I do:

```python
from geneformer.perturber_utils import load_model

model = load_model(model_type="Pretrained", num_classes=0,
                   model_directory="/Geneformer", mode="eval", quantize=False)
```

Then I give the tokenized data to the model, but I don't know what format the data must be in. How do I transform it into PyTorch tensors and obtain an attention mask?
Thank you!
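For the tensors-and-attention-mask part of the question, one minimal way to batch variable-length tokenized cells is sketched below. The pad id of 0 is an assumption, and `batch_cells` is an illustrative helper, not part of the Geneformer package.

```python
import torch

PAD_ID = 0  # assumed padding token id

def batch_cells(cells):
    """Pad a list of variable-length token-id lists into a batch,
    returning input_ids and the matching attention mask (1 = real token)."""
    max_len = max(len(c) for c in cells)
    input_ids = torch.full((len(cells), max_len), PAD_ID, dtype=torch.long)
    attention_mask = torch.zeros((len(cells), max_len), dtype=torch.long)
    for i, c in enumerate(cells):
        input_ids[i, : len(c)] = torch.tensor(c, dtype=torch.long)
        attention_mask[i, : len(c)] = 1
    return input_ids, attention_mask

ids, mask = batch_cells([[5, 2, 42], [7, 9]])
# ids  -> tensor([[5, 2, 42], [7, 9, 0]])
# mask -> tensor([[1, 1, 1], [1, 1, 0]])
```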

If you'd like to train it, you should load it in train mode rather than eval mode. Pretraining is done with masked modeling, so you can check the pretraining code if you'd like to continue the same objective on a specific dataset.
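A minimal sketch of that train-mode masked objective with a BertForMaskedLM head follows; the tiny config and the token ids are made up for illustration, and the repository's own pretraining code and collator remain the authoritative version.

```python
import torch
from transformers import BertConfig, BertForMaskedLM

# Stand-in config; in practice load the pretrained Geneformer checkpoint
config = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=1,
                    num_attention_heads=2, intermediate_size=64)
model = BertForMaskedLM(config)
model.train()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

MASK_ID = 1  # assumed mask token id
input_ids = torch.tensor([[5, MASK_ID, 42, 8, MASK_ID]])
# labels hold the original ids at masked positions and -100 elsewhere,
# so the cross-entropy loss is computed only on the masked genes
labels = torch.tensor([[-100, 17, -100, -100, 23]])

out = model(input_ids=input_ids,
            attention_mask=torch.ones_like(input_ids),
            labels=labels)
out.loss.backward()
optimizer.step()
```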

As you suggested before "You can put the model in eval mode and obtain the predictions for masked genes with a forward pass through the model."
I wanted to try this, but I don't know what to give to the model (model = load_model(model_type="Pretrained", num_classes=0, model_directory="/Geneformer", mode="eval", quantize=False)).

You can give it the data you mentioned you prepared previously and do a forward pass through the model, extracting the predictions from the logits for the masked genes.

Thank you,
Do you have an idea why, when I mask several tokens in a cell, the logits at all the masked positions always predict the same token/gene?
In other words, whatever the number of masked genes, the model always imputes a single gene without considering the rank.

That predicted token, by the way, may already be in the tokenized cell (and ranked before the mask token).

This is not the usual behavior. Another option is to use the provided pretraining code and supply your dataset of interest as the eval dataset. That way the code is more likely to be set up correctly, and you can adapt it to what you are trying to do.

Thank you for the suggestion!
Actually, I re-tried and examined the results in detail for a test of 5 cells. In fact, the logits don't predict the same token/gene for all masked tokens, but there is a lot of redundancy in the predictions (see unique_predicted_genes compared to nb_masked_genes), and some predicted genes are already present in the transcriptome (not masked).

Did you encounter the same problems when pretraining the model?

Results per cell:

| | nb_genes | nb_masked_genes | unique_predicted_genes | true_positives | new_predicted_genes | predicted_genes_already_present |
|---|---|---|---|---|---|---|
| Cell 1 | 4096 | 399 | 14 | 6 | 4 | 4 |
| Cell 2 | 4096 | 399 | 137 | 20 | 77 | 40 |
| Cell 3 | 3988 | 380 | 221 | 43 | 164 | 14 |
| Cell 4 | 4096 | 418 | 225 | 29 | 159 | 37 |
| Cell 5 | 4096 | 396 | 159 | 25 | 117 | 17 |
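For reference, per-cell counts like these can be reproduced with bookkeeping along the following lines. The function name and arguments are illustrative, and it assumes the three categories partition the unique predictions (true positives, genes present unmasked elsewhere in the cell, genes new to the cell), which is consistent with the numbers reported, e.g. 20 + 77 + 40 = 137 for cell 2.

```python
def summarize(cell_tokens, masked_positions, predicted):
    """Summarize masked-gene predictions for one cell.

    cell_tokens: original (unmasked) token ids of the cell
    masked_positions: indices that were replaced by the mask token
    predicted: predicted token id for each masked position, in order
    """
    masked_true = {cell_tokens[i] for i in masked_positions}
    unmasked = {t for i, t in enumerate(cell_tokens) if i not in masked_positions}
    uniq = set(predicted)
    true_positives = uniq & masked_true
    already_present = (uniq - true_positives) & unmasked
    new_genes = uniq - true_positives - already_present
    return {
        "nb_genes": len(cell_tokens),
        "nb_masked_genes": len(masked_positions),
        "unique_predicted_genes": len(uniq),
        "true_positives": len(true_positives),
        "new_predicted_genes": len(new_genes),
        "predicted_genes_already_present": len(already_present),
    }

stats = summarize(cell_tokens=[5, 8, 42, 9],
                  masked_positions=[1, 3],
                  predicted=[8, 42])
# -> true_positives: 1 (gene 8), predicted_genes_already_present: 1 (gene 42)
```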
