Single gene perturbation error

#243
by swang12 - opened

Hi, I'm running InSilicoPerturber to delete one gene at a time:

import os
import shutil

from tqdm import tqdm
from geneformer import InSilicoPerturber, InSilicoPerturberStats

# genes_to_perturb, df, work_directory and root_directory are defined earlier in the script (not shown)
for gene in tqdm(genes_to_perturb):
    # one output directory and one copy of the tokenized dataset per gene
    output_directory = f'{work_directory}{gene}/'
    os.mkdir(output_directory)

    gene_id = df.at[gene, 'Ensembl_ID']

    shutil.copytree(f'{root_directory}tokenized_.dataset', f'{output_directory}tokenized_copy.dataset')

    isp = InSilicoPerturber(perturb_type='delete',
                            perturb_rank_shift=None,
                            genes_to_perturb=[gene_id],
                            combos=0,
                            anchor_gene=None,
                            model_type='Pretrained',
                            num_classes=0,
                            emb_mode='cell_and_gene',
                            cell_emb_style='mean_pool',
                            cell_states_to_model=None,
                            max_ncells=2000,
                            emb_layer=-1,
                            forward_batch_size=32,
                            nproc=16
                            )

    isp.perturb_data(model_directory='Geneformer/',
                     input_data_file=f'{output_directory}tokenized_copy.dataset',
                     output_directory=output_directory,
                     output_prefix='perturbed'
                     )

    ispstats = InSilicoPerturberStats(mode='mixture_model',
                                      combos=0,
                                      anchor_gene=None,
                                      cell_states_to_model=None
                                      )

    ispstats.get_stats(input_data_directory=output_directory,
                       null_dist_data_directory=None,
                       output_directory=output_directory,
                       output_prefix='stats'
                       )

I tried both Python 3.8.10 and Python 3.10.12, and for every gene I get a similar error:

Traceback (most recent call last):
  File "/raid/swang12/Transformer31012/perturb_one.py", line 80, in <module>
    isp.perturb_data(model_directory='Geneformer/',
  File "/home/swang12/anaconda3/envs/myenv/lib/python3.10/site-packages/geneformer/in_silico_perturber.py", line 981, in perturb_data
    self.in_silico_perturb(model,
  File "/home/swang12/anaconda3/envs/myenv/lib/python3.10/site-packages/geneformer/in_silico_perturber.py", line 1059, in in_silico_perturb
    cos_sims_data = quant_cos_sims(model,
  File "/home/swang12/anaconda3/envs/myenv/lib/python3.10/site-packages/geneformer/in_silico_perturber.py", line 445, in quant_cos_sims
    cos_sims += [cos(minibatch_emb, minibatch_comparison).to("cpu")]
  File "/home/swang12/anaconda3/envs/myenv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/swang12/anaconda3/envs/myenv/lib/python3.10/site-packages/torch/nn/modules/distance.py", line 87, in forward
    return F.cosine_similarity(x1, x2, self.dim, self.eps)
RuntimeError: The size of tensor a (2047) must match the size of tensor b (2046) at non-singleton dimension 1

Could you help me with this? Thank you very much!

Sincerely,
Su
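The RuntimeError in the traceback comes from the cosine-similarity step: the original and perturbed embeddings appear to be compared gene position by gene position, and torch refuses because dimension 1 (2047 vs 2046 positions) neither matches nor broadcasts. The pure-Python sketch below only illustrates that constraint; `cosine` and `per_position_cosine` are hypothetical helpers, not Geneformer's implementation.

```python
import math
from typing import List, Sequence

# Hypothetical helpers for illustration only; not Geneformer's implementation.
Vector = Sequence[float]


def cosine(u: Vector, v: Vector, eps: float = 1e-8) -> float:
    """Cosine similarity of two embedding vectors; eps guards against a zero norm."""
    dot = sum(x * y for x, y in zip(u, v))
    norm = math.sqrt(sum(x * x for x in u)) * math.sqrt(sum(y * y for y in v))
    return dot / max(norm, eps)


def per_position_cosine(original: Sequence[Vector],
                        perturbed: Sequence[Vector]) -> List[float]:
    """Compare two per-gene embedding matrices position by position.

    Each matrix has shape (positions, hidden_size). torch would also accept a
    broadcastable size of 1; this sketch simply rejects any length mismatch.
    """
    if len(original) != len(perturbed):
        raise ValueError(
            f"size of original ({len(original)}) must match "
            f"size of perturbed ({len(perturbed)}) at the position dimension"
        )
    return [cosine(u, v) for u, v in zip(original, perturbed)]


# Three gene positions compared against two, mirroring the 2047 vs 2046 error.
original = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
perturbed = [[1.0, 0.0], [0.0, 1.0]]
try:
    per_position_cosine(original, perturbed)
except ValueError as err:
    print(err)
```

When both inputs have the same number of positions the function returns one similarity per gene position, which is the shape of result the perturbation step expects.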

I'm also having the same issue; see discussion #85.

@swang12 @MinieRosie I found that this error only occurs at certain cell indices. While I haven't solved it, I did find a workaround that lets the function process more cells instead of failing at the start: set, for example, cell_inds_to_perturb={'start':0, 'end':2} and forward_batch_size=2, then loop the command so that it writes one pickle file per 2 cells. I had to run the loop from bash so that each chunk runs in its own process, which avoided out-of-memory (OOM) errors. See the example below. Hope this helps!

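#!/usr/bin/env bash
# in_silico_perturbation_loop.sh
# Each iteration perturbs cells i to j and writes output with suffix k,
# starting a fresh Python process so memory is released between chunks.

i=0   # start cell index
j=2   # end cell index
k=1   # output file suffix

for a in {1..193}
do
  python3 in_silico_perturbation_loop.py -i "$i" -j "$j" -k "$k"
  i=$((i + 2))
  j=$((j + 2))
  k=$((k + 1))
  # remove the cache files written to the dataset directory during the run
  rm -f /data/genecorpus_filtered_hep/cache*
  echo "$i $j $k"
done

# aggregate the per-chunk outputs
python3 in_silico_perturbation_stats.py
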
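The bash loop above can also be written as a Python driver that still launches each chunk in its own process. This is a sketch that assumes the same script names, chunk count (193), chunk size (2) and cache path as the bash version; the function names and the `--run` flag are hypothetical.

```python
import glob
import os
import subprocess
import sys
from typing import Iterator, List, Tuple

# Assumed to match the bash script: 193 chunks of 2 cells, the same script
# names, and dataset cache files matching CACHE_GLOB removed after each run.
N_CHUNKS = 193
CHUNK_SIZE = 2
CACHE_GLOB = "/data/genecorpus_filtered_hep/cache*"


def chunks(n_chunks: int, size: int = CHUNK_SIZE) -> Iterator[Tuple[int, int, int]]:
    """Yield (start, end, suffix), the same values the bash counters i, j, k take."""
    for k in range(1, n_chunks + 1):
        start = (k - 1) * size
        yield start, start + size, k


def loop_command(start: int, end: int, suffix: int) -> List[str]:
    """Build the command line for one chunk of in_silico_perturbation_loop.py."""
    return [sys.executable, "in_silico_perturbation_loop.py",
            "-i", str(start), "-j", str(end), "-k", str(suffix)]


def run_all() -> None:
    for start, end, suffix in chunks(N_CHUNKS):
        # A separate process per chunk frees GPU and host memory when it exits.
        subprocess.run(loop_command(start, end, suffix), check=True)
        for path in glob.glob(CACHE_GLOB):
            os.remove(path)
        print(f"finished cells {start}-{end} (suffix {suffix})")
    # Aggregate all per-chunk outputs once every chunk has run.
    subprocess.run([sys.executable, "in_silico_perturbation_stats.py"], check=True)


# "--run" is a hypothetical opt-in flag so that importing or testing this file
# does not launch 193 perturbation jobs.
if __name__ == "__main__" and "--run" in sys.argv:
    run_all()
```

Unlike the bash loop, `check=True` stops at the first chunk that fails instead of silently moving on; drop it if you prefer the bash behavior of skipping failing chunks.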
#!/usr/bin/env python3
# in_silico_perturbation_loop.py

# Import libraries
import argparse
import gc

import torch
from geneformer import InSilicoPerturber

def main():
    args = get_cli_args()
    i = args.i
    j = args.j
    k = args.k

    # 1. Set parameters for perturbation function
    gc.collect()
    torch.cuda.empty_cache()

    isp = InSilicoPerturber(perturb_type="delete",
                            perturb_rank_shift=None,
                            genes_to_perturb=["ENSG00000118520"],
                            combos=0,
                            anchor_gene=None,
                            model_type="Pretrained",
                            num_classes=0,
                            emb_mode="cell",
                            cell_emb_style="mean_pool",
                            filter_data=None,
                            cell_states_to_model=None,
                            max_ncells=None,
                            cell_inds_to_perturb={"start":i,
                                                  "end":j},
                            emb_layer=-1,
                            forward_batch_size=2,
                            nproc=2,
                            token_dictionary_file="/home/ubuntu/Geneformer/geneformer/token_dictionary.pkl")

    # 2. Perturb data
    isp.perturb_data(model_directory="/home/ubuntu/Geneformer/",
                     input_data_file="/data/genecorpus_filtered_hep/",
                     output_directory="/data/genecorpus_filtered_hep/delete_cell/",
                     output_prefix=f"cell_ARG1_{k}")

def get_cli_args():
    parser = argparse.ArgumentParser()

    parser.add_argument('-i', dest='i',
                        type=int, help="Starting cell index",
                        required=True)
    parser.add_argument('-j', dest='j',
                        type=int, help="Ending cell index",
                        required=True)
    parser.add_argument('-k', dest='k',
                        type=int, help="Pickle file suffix",
                        required=True)

    return parser.parse_args()

if __name__ == '__main__':
    main()

#!/usr/bin/env python3
# in_silico_perturbation_stats.py

from geneformer import InSilicoPerturberStats

# 1. Set parameters for perturbation statistics
ispstats = InSilicoPerturberStats(mode="aggregate_data",
                                  genes_perturbed=["ENSG00000118520"],
                                  combos=0,
                                  anchor_gene=None,
                                  cell_states_to_model=None,
                                  token_dictionary_file="/home/ubuntu/Geneformer/geneformer/token_dictionary.pkl")

# 2. Get perturbation stats
ispstats.get_stats(input_data_directory="/data/genecorpus_filtered_hep/delete_cell/",
                   null_dist_data_directory=None,
                   output_directory="/data/genecorpus_filtered_hep/delete_cell/",
                   output_prefix="delete_cell_ARG1")

Thank you for your interest in Geneformer and for your patience! We pushed an update that should resolve this issue. If you continue to face errors after pulling the updated code, please let us know by either reopening this discussion if it's the same error or opening a new discussion if it's a new error. Thank you!

ctheodoris changed discussion status to closed
