Cell embedding extraction

#126
by DYXDAVE - opened

Thank you for your model!
I'm currently trying to analyze the cell embeddings, and I'm not sure I understand the extraction correctly.
The last-layer output I get from model.predict should be a batch_size x 2048 x 256 tensor, right? So I need to average across the gene dimension to get a batch_size x 256 matrix, which would be my cell embeddings. I'm wondering whether the gene dimension still corresponds to the input length, since if a cell has fewer than 2048 genes, the rest of the input is filled with padding.
In other words, do we sum over the gene dimension and divide by 2048, or should we divide by the number of genes the current cell actually has (for example, if only 2000 genes are detected in a cell, divide by 2000)?

Thank you for your question! Yes, 2048 is the input size of the model and 256 is the embedding dimension parameter. So the gene embeddings would be averaged for each cell to generate a 256-dimensional embedding vector per cell. If you have padding, you could consider excluding the padding positions before averaging over the gene dimension. Here is an example of an approach for this: https://stackoverflow.com/questions/76015844/how-to-efficiently-mean-pool-bert-embeddings-while-excluding-padding
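For illustration, here is a minimal sketch of mean pooling that skips padding. It is not code from the Geneformer repository: the function name `masked_mean_pool`, the use of NumPy, and the 0/1 `mask` array marking real gene tokens are all assumptions, and the toy shapes stand in for the real 2048 x 256.

```python
import numpy as np

def masked_mean_pool(hidden, mask):
    """Average token embeddings per cell, ignoring padded positions.

    hidden: (batch, seq_len, emb_dim) last-layer output
    mask:   (batch, seq_len), 1 for real gene tokens, 0 for padding
    (both names are hypothetical, chosen for this sketch)
    """
    mask = mask.astype(hidden.dtype)[:, :, None]    # (batch, seq_len, 1)
    summed = (hidden * mask).sum(axis=1)            # padding contributes 0
    counts = np.clip(mask.sum(axis=1), 1.0, None)   # real tokens per cell, at least 1
    return summed / counts                          # (batch, emb_dim)

# Toy data: 2 cells, 6 token positions, 4 embedding dimensions.
rng = np.random.default_rng(0)
hidden = rng.normal(size=(2, 6, 4))
mask = np.array([[1, 1, 1, 1, 0, 0],    # first cell has 4 genes, 2 pads
                 [1, 1, 1, 1, 1, 1]])   # second cell fills the input
cell_emb = masked_mean_pool(hidden, mask)
print(cell_emb.shape)
```

With this, each cell is divided by its own number of real genes rather than by the fixed input size.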

Update:
@DYXDAVE
We have now added a function to extract and plot cell embeddings. Please see example here:
https://huggingface.co/ctheodoris/Geneformer/blob/main/examples/extract_and_plot_cell_embeddings.ipynb

ctheodoris changed discussion status to closed

So, from my observation, the last-layer embeddings at the padding token positions are not all zero, so we should remove them to exclude their effect?

We exclude the padding tokens from the embeddings for the in silico perturbation analysis as well. So, yes, you could remove the padding token embeddings before averaging to exclude them from the cell embedding for your application too.
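As a rough illustration of why the padding rows need to be dropped and not only the divisor changed, here is a small sketch. It assumes padding is appended at the end of each sequence and that the number of real tokens per cell is known; `mask_from_lengths` and the toy values are hypothetical and not part of the Geneformer code.

```python
import numpy as np

def mask_from_lengths(lengths, seq_len):
    """Build a (batch, seq_len) 0/1 mask, assuming trailing padding."""
    lengths = np.asarray(lengths)
    return (np.arange(seq_len)[None, :] < lengths[:, None]).astype(np.float64)

seq_len, emb_dim = 8, 3
lengths = [5, 8]                         # real tokens per cell (toy values)
hidden = np.ones((2, seq_len, emb_dim))  # real positions all embed to 1.0
hidden[0, 5:] = 10.0                     # padded positions are non-zero

mask = mask_from_lengths(lengths, seq_len)

# Divide by the full input length: padding values leak into the mean.
naive = hidden.mean(axis=1)
# Divide by the real length but still sum over padding: also biased.
sum_div_len = hidden.sum(axis=1) / np.asarray(lengths)[:, None]
# Zero out padding, then divide by the real length.
masked = (hidden * mask[:, :, None]).sum(axis=1) / mask.sum(axis=1, keepdims=True)

print(naive[0], sum_div_len[0], masked[0])
```

Only the masked version recovers the mean of the real token embeddings for the padded cell; the other two are pulled toward the padding values.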
