{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Geneformer Fine-Tuning for Classification of Dosage-Sensitive vs. -Insensitive Transcription Factors (TFs)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import os\n", "GPU_NUMBER = [0]\n", "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \",\".join([str(s) for s in GPU_NUMBER])\n", "os.environ[\"NCCL_DEBUG\"] = \"INFO\"" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# imports\n", "import datetime\n", "import subprocess\n", "import math\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import pandas as pd\n", "from datasets import load_from_disk\n", "from sklearn import preprocessing\n", "from sklearn.metrics import accuracy_score, auc, confusion_matrix, ConfusionMatrixDisplay, roc_curve\n", "from sklearn.model_selection import StratifiedKFold\n", "import torch\n", "from transformers import BertForTokenClassification\n", "from transformers import Trainer\n", "from transformers.training_args import TrainingArguments\n", "from tqdm.notebook import tqdm\n", "\n", "from geneformer import DataCollatorForGeneClassification\n", "from geneformer.pretrainer import token_dictionary" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load Gene Attribute Information" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "# table of corresponding Ensembl IDs, gene names, and gene types (e.g. 
coding, miRNA, etc.)\n",
 "gene_info = pd.read_csv(\"/path/to/gene_info_table.csv\", index_col=0)\n",
 "\n",
 "# create dictionaries for corresponding attributes\n",
 "gene_id_type_dict = dict(zip(gene_info[\"ensembl_id\"],gene_info[\"gene_type\"]))\n",
 "gene_name_id_dict = dict(zip(gene_info[\"gene_name\"],gene_info[\"ensembl_id\"]))\n",
 "gene_id_name_dict = {v: k for k,v in gene_name_id_dict.items()}"
 ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load Training Data and Class Labels" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [
 "# function for preparing targets and labels\n",
 "def prep_inputs(genegroup1, genegroup2, id_type):\n",
 "    \"\"\"Convert two gene groups into token-id targets and binary labels for cross-validation.\n",
 "\n",
 "    genegroup1 genes get label 0 and genegroup2 genes label 1; genes whose Ensembl ID is not a\n",
 "    key of token_dictionary are dropped. id_type (\"gene_name\" or \"ensembl_id\") says how the\n",
 "    groups are identified. Returns (targets, labels, nsplits), where nsplits is the number of\n",
 "    cross-validation folds: at most 5 and strictly fewer than the smaller group's size.\n",
 "\n",
 "    Raises ValueError for an unsupported id_type.\n",
 "    \"\"\"\n",
 "    if id_type == \"gene_name\":\n",
 "        targets1 = [gene_name_id_dict[gene] for gene in genegroup1 if gene_name_id_dict.get(gene) in token_dictionary]\n",
 "        targets2 = [gene_name_id_dict[gene] for gene in genegroup2 if gene_name_id_dict.get(gene) in token_dictionary]\n",
 "    elif id_type == \"ensembl_id\":\n",
 "        targets1 = [gene for gene in genegroup1 if gene in token_dictionary]\n",
 "        targets2 = [gene for gene in genegroup2 if gene in token_dictionary]\n",
 "    else:\n",
 "        # reject unknown id types here instead of failing later with a NameError on targets1\n",
 "        raise ValueError(f\"id_type must be 'gene_name' or 'ensembl_id', got {id_type!r}\")\n",
 "\n",
 "    targets1_id = [token_dictionary[gene] for gene in targets1]\n",
 "    targets2_id = [token_dictionary[gene] for gene in targets2]\n",
 "\n",
 "    targets = np.array(targets1_id + targets2_id)\n",
 "    labels = np.array([0]*len(targets1_id) + [1]*len(targets2_id))\n",
 "    # at most 5 folds, and strictly fewer than the smaller class's size\n",
 "    nsplits = min(5, min(len(targets1_id), len(targets2_id))-1)\n",
 "    assert nsplits > 2\n",
 "    print(f\"# targets1: {len(targets1_id)}\\n# targets2: {len(targets2_id)}\\n# splits: {nsplits}\")\n",
 "    return targets, labels, nsplits"
 ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "# preparing targets and labels for dosage sensitive vs insensitive TFs\n",
 "dosage_tfs = pd.read_csv(\"/path/to/dosage_sens_tf_labels.csv\", header=0)\n",
 "sensitive = dosage_tfs[\"dosage_sensitive\"].dropna()\n",
 "insensitive = 
dosage_tfs[\"dosage_insensitive\"].dropna()\n",
 "targets, labels, nsplits = prep_inputs(sensitive, insensitive, \"ensembl_id\")"
 ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [
 "# load training dataset\n",
 "train_dataset=load_from_disk(\"/path/to/gene_train_data.dataset\")\n",
 "shuffled_train_dataset = train_dataset.shuffle(seed=42)\n",
 "subsampled_train_dataset = shuffled_train_dataset.select([i for i in range(50_000)])"
 ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Define Functions for Training and Cross-Validating Classifier" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [
 "def preprocess_classifier_batch(cell_batch):\n",
 "    \"\"\"Right-pad every example of a datasets.Dataset batch to the batch's longest input_ids.\n",
 "\n",
 "    labels are padded with -100 (ignored by the loss and filtered out in classifier_predict),\n",
 "    input_ids with the \"<pad>\" token id, and an attention_mask is added that is 0 exactly where\n",
 "    input_ids equals that pad id. Returns the padded dataset.\n",
 "    \"\"\"\n",
 "    max_batch_len = max([len(i) for i in cell_batch[\"input_ids\"]])\n",
 "    # look the pad id up once and fail loudly if it is missing (None would be handed to np.pad)\n",
 "    pad_token_id = token_dictionary.get(\"<pad>\")\n",
 "    if pad_token_id is None:\n",
 "        raise KeyError(\"token_dictionary has no '<pad>' token\")\n",
 "\n",
 "    def pad_label_example(example):\n",
 "        pad_width = (0, max_batch_len-len(example[\"input_ids\"]))\n",
 "        example[\"labels\"] = np.pad(example[\"labels\"], pad_width,\n",
 "                                   mode='constant', constant_values=-100)\n",
 "        example[\"input_ids\"] = np.pad(example[\"input_ids\"], pad_width,\n",
 "                                      mode='constant', constant_values=pad_token_id)\n",
 "        example[\"attention_mask\"] = (example[\"input_ids\"] != pad_token_id).astype(int)\n",
 "        return example\n",
 "    padded_batch = cell_batch.map(pad_label_example)\n",
 "    return padded_batch\n",
 "\n",
 "# forward batch size is batch size for model inference (e.g. 200)\n",
 "def classifier_predict(model, evalset, forward_batch_size, mean_fpr):\n",
 "    \"\"\"Score a fine-tuned token classifier on evalset and compute this split's ROC metrics.\n",
 "\n",
 "    Runs inference on CUDA in chunks of forward_batch_size cells, keeps only token positions\n",
 "    whose label is not -100, plots this split's ROC curve, and returns\n",
 "    (fpr, tpr, interp_tpr, conf_mat) where interp_tpr is tpr interpolated onto mean_fpr.\n",
 "    \"\"\"\n",
 "    predict_logits = []\n",
 "    predict_labels = []\n",
 "    model.eval()\n",
 "    for i in range(0, len(evalset), forward_batch_size):\n",
 "        max_range = min(i+forward_batch_size, len(evalset))\n",
 "        batch_evalset = evalset.select(list(range(i, max_range)))\n",
 "        padded_batch = preprocess_classifier_batch(batch_evalset)\n",
 "        padded_batch.set_format(type=\"torch\")\n",
 "\n",
 "        input_data_batch = padded_batch[\"input_ids\"]\n",
 "        attn_msk_batch = padded_batch[\"attention_mask\"]\n",
 "        label_batch = padded_batch[\"labels\"]\n",
 "        with torch.no_grad():\n",
 "            outputs = model(\n",
 "                input_ids = input_data_batch.to(\"cuda\"),\n",
 "                attention_mask = attn_msk_batch.to(\"cuda\"),\n",
 "                labels = label_batch.to(\"cuda\"),\n",
 "            )\n",
 "        # flatten per batch to (tokens, num_labels) and (tokens,): each batch is padded to its own\n",
 "        # max length, so (batch, seq, num_labels) tensors cannot be concatenated once lengths differ,\n",
 "        # and squeezing a final batch of one cell would drop its batch dimension\n",
 "        batch_logits = outputs.logits.to(\"cpu\")\n",
 "        predict_logits.append(batch_logits.reshape(-1, batch_logits.shape[-1]))\n",
 "        predict_labels.append(label_batch.to(\"cpu\").reshape(-1))\n",
 "\n",
 "    all_logits = torch.cat(predict_logits)\n",
 "    all_labels = torch.cat(predict_labels)\n",
 "    logit_label_paired = [item for item in zip(all_logits.tolist(), all_labels.tolist()) if item[1]!=-100]\n",
 "    y_pred = [vote(item[0]) for item in logit_label_paired]\n",
 "    y_true = [item[1] for item in logit_label_paired]\n",
 "    logits_list = [item[0] for item in logit_label_paired]\n",
 "    # probability of class 1\n",
 "    y_score = [py_softmax(item)[1] for item in logits_list]\n",
 "    # fix the label set so every split yields a 2x2 matrix that can be summed across splits\n",
 "    conf_mat = confusion_matrix(y_true, y_pred, labels=[0, 1])\n",
 "    fpr, tpr, _ = roc_curve(y_true, y_score)\n",
 "    # plot roc_curve for this split\n",
 "    plt.plot(fpr, tpr)\n",
 "    plt.xlim([0.0, 1.0])\n",
 "    plt.ylim([0.0, 1.05])\n",
 "    plt.xlabel('False Positive Rate')\n",
 "    plt.ylabel('True Positive Rate')\n",
 "    plt.title('ROC')\n",
 "    plt.show()\n",
 "    # interpolate to graph\n",
 "    interp_tpr = np.interp(mean_fpr, fpr, tpr)\n",
 "    interp_tpr[0] = 0.0\n",
 "    return fpr, tpr, interp_tpr, conf_mat\n",
 "\n",
 "def vote(logit_pair):\n",
 "    \"\"\"Return the predicted class (0 or 1) for a pair of logits, or \"tie\" if they are exactly equal.\"\"\"\n",
 "    a, b = logit_pair\n",
 "    if a > b:\n",
 "        return 0\n",
 "    elif b > a:\n",
 "        return 1\n",
 "    elif a == b:\n",
 "        return \"tie\"\n",
 "\n",
 "def py_softmax(vector):\n",
 "    \"\"\"Softmax of a 1-D sequence of logits, returned as a numpy array.\"\"\"\n",
 "    # shifting by the max leaves the result unchanged but keeps np.exp from overflowing to inf/nan\n",
 "    e = np.exp(np.asarray(vector) - np.max(vector))\n",
 "    return e / e.sum()\n",
 "\n",
 "# get cross-validated mean and sd metrics\n",
 "def get_cross_valid_metrics(all_tpr, all_roc_auc, all_tpr_wt):\n",
 "    \"\"\"Combine per-split results into a weighted mean TPR curve, mean ROC AUC and AUC SD.\n",
 "\n",
 "    Split k is weighted by all_tpr_wt[k] / sum(all_tpr_wt). Returns (mean_tpr, roc_auc, roc_auc_sd).\n",
 "    \"\"\"\n",
 "    total_wt = sum(all_tpr_wt)\n",
 "    wts = [count/total_wt for count in all_tpr_wt]\n",
 "    print(wts)\n",
 "    all_weighted_tpr = [a*b for a,b in zip(all_tpr, wts)]\n",
 "    mean_tpr = np.sum(all_weighted_tpr, axis=0)\n",
 "    # force the averaged curve to end at TPR = 1\n",
 "    mean_tpr[-1] = 1.0\n",
 "    all_weighted_roc_auc = [a*b for a,b in zip(all_roc_auc, wts)]\n",
 "    roc_auc = np.sum(all_weighted_roc_auc)\n",
 "    # convert explicitly: list - scalar only works through numpy's reflected subtraction\n",
 "    roc_auc_sd = math.sqrt(np.average((np.asarray(all_roc_auc)-roc_auc)**2, weights=wts))\n",
 "    return mean_tpr, roc_auc, roc_auc_sd"
 ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [
 "# cross-validate gene classifier\n",
 "def cross_validate(data, targets, labels, nsplits, subsample_size, training_args, freeze_layers, output_dir, num_proc):\n",
 "    \"\"\"Stratified k-fold fine-tuning and evaluation of the gene classifier.\n",
 "\n",
 "    For each split: keeps cells in data that contain a train (resp. eval) gene, labels those genes'\n",
 "    tokens (all other tokens -100), fine-tunes a fresh BertForTokenClassification (first\n",
 "    freeze_layers encoder layers frozen when freeze_layers is not None), saves it to\n",
 "    output_dir/ksplit<n>/models/ and evaluates it on the out-of-sample half of the selected eval\n",
 "    cells. With more than 500 labels only 3 splits are run.\n",
 "    Note: training_args is modified in place (its \"output_dir\" is set per split).\n",
 "\n",
 "    Returns (all_roc_auc, roc_auc, roc_auc_sd, mean_fpr, mean_tpr, confusion, label_dicts).\n",
 "    \"\"\"\n",
 "    # check if output directory already written to\n",
 "    # ensure not overwriting previously saved model\n",
 "    model_dir_test = os.path.join(output_dir, \"ksplit0/models/pytorch_model.bin\")\n",
 "    if os.path.isfile(model_dir_test):\n",
 "        raise Exception(\"Model already saved to this directory.\")\n",
 "\n",
 "    # initiate eval metrics to return\n",
 "    num_classes = len(set(labels))\n",
 "    mean_fpr = np.linspace(0, 1, 100)\n",
 "    all_tpr = []\n",
 "    all_roc_auc = []\n",
 "    all_tpr_wt = []\n",
 "    label_dicts = []\n",
 "    confusion = np.zeros((num_classes,num_classes))\n",
 "\n",
 "    # set up cross-validation splits\n",
 "    skf = StratifiedKFold(n_splits=nsplits, random_state=0, shuffle=True)\n",
 "    # train and evaluate\n",
 "    iteration_num = 0\n",
 "    for train_index, eval_index in tqdm(skf.split(targets, labels)):\n",
 "        if len(labels) > 500:\n",
 "            print(\"early stopping activated due to large # of training examples\")\n",
 "            nsplits = 3\n",
 "            if iteration_num == 3:\n",
 "                break\n",
 "        print(f\"****** Crossval split: {iteration_num}/{nsplits-1} ******\\n\")\n",
 "        # generate cross-validation splits\n",
 "        targets_train, targets_eval = targets[train_index], targets[eval_index]\n",
 "        labels_train, labels_eval = labels[train_index], labels[eval_index]\n",
 "        label_dict_train = dict(zip(targets_train, labels_train))\n",
 "        label_dict_eval = dict(zip(targets_eval, labels_eval))\n",
 "        # NOTE(review): += with a tuple extends the list with these 5 items one by one, so\n",
 "        # label_dicts ends up flat (5 entries per split) rather than one tuple per split\n",
 "        label_dicts += (iteration_num, targets_train, targets_eval, labels_train, labels_eval)\n",
 "\n",
 "        # gene token ids of this split, built once instead of once per filtered example\n",
 "        train_gene_ids = set(label_dict_train.keys())\n",
 "        eval_gene_ids = set(label_dict_eval.keys())\n",
 "\n",
 "        # function to filter by whether contains train or eval labels\n",
 "        def if_contains_train_label(example):\n",
 "            return not train_gene_ids.isdisjoint(example['input_ids'])\n",
 "\n",
 "        def if_contains_eval_label(example):\n",
 "            return not eval_gene_ids.isdisjoint(example['input_ids'])\n",
 "\n",
 "        # filter dataset for examples containing classes for this split\n",
 "        print(f\"Filtering training data\")\n",
 "        trainset = data.filter(if_contains_train_label, num_proc=num_proc)\n",
 "        print(f\"Filtered {round((1-len(trainset)/len(data))*100)}%; {len(trainset)} remain\\n\")\n",
 "        print(f\"Filtering evalation data\")\n",
 "        evalset = data.filter(if_contains_eval_label, num_proc=num_proc)\n",
 "        print(f\"Filtered {round((1-len(evalset)/len(data))*100)}%; {len(evalset)} remain\\n\")\n",
 "\n",
 "        # minimize to smaller training sample\n",
 "        training_size = min(subsample_size, len(trainset))\n",
 "        trainset_min = trainset.select([i for i in range(training_size)])\n",
 "        eval_size = min(training_size, len(evalset))\n",
 "        half_training_size = round(eval_size/2)\n",
 "        evalset_train_min = evalset.select([i for i in range(half_training_size)])\n",
 "        evalset_oos_min = evalset.select([i for i in range(half_training_size, eval_size)])\n",
 "\n",
 "        # label conversion functions\n",
 "        def generate_train_labels(example):\n",
 "            example[\"labels\"] = [label_dict_train.get(token_id, -100) for token_id in example[\"input_ids\"]]\n",
 "            return example\n",
 "\n",
 "        def generate_eval_labels(example):\n",
 "            example[\"labels\"] = [label_dict_eval.get(token_id, -100) for token_id in example[\"input_ids\"]]\n",
 "            return example\n",
 "\n",
 "        # label datasets\n",
 "        print(f\"Labeling training data\")\n",
 "        trainset_labeled = trainset_min.map(generate_train_labels)\n",
 "        print(f\"Labeling evaluation data\")\n",
 "        evalset_train_labeled = evalset_train_min.map(generate_eval_labels)\n",
 "        print(f\"Labeling evaluation OOS data\")\n",
 "        evalset_oos_labeled = evalset_oos_min.map(generate_eval_labels)\n",
 "\n",
 "        # create output directories\n",
 "        ksplit_output_dir = os.path.join(output_dir, f\"ksplit{iteration_num}\")\n",
 "        ksplit_model_dir = os.path.join(ksplit_output_dir, \"models/\")\n",
 "\n",
 "        # ensure not overwriting previously saved model\n",
 "        model_output_file = os.path.join(ksplit_model_dir, \"pytorch_model.bin\")\n",
 "        if os.path.isfile(model_output_file):\n",
 "            raise Exception(\"Model already saved to this directory.\")\n",
 "\n",
 "        # make training and model output directories; makedirs also creates any missing\n",
 "        # parents and needs no shell, unlike calling `mkdir` through subprocess\n",
 "        os.makedirs(ksplit_model_dir, exist_ok=True)\n",
 "\n",
 "        # load model\n",
 "        model = BertForTokenClassification.from_pretrained(\n",
 "            \"/path/to/pretrained_model/\",\n",
 "            num_labels=2,\n",
 "            output_attentions = False,\n",
 "            output_hidden_states = False\n",
 "        )\n",
 "        if freeze_layers is not None:\n",
 "            modules_to_freeze = model.bert.encoder.layer[:freeze_layers]\n",
 "            for module in modules_to_freeze:\n",
 "                for param in module.parameters():\n",
 "                    param.requires_grad = False\n",
 "\n",
 "        model = model.to(\"cuda:0\")\n",
 "\n",
 "        # add output directory to training args and initiate\n",
 "        training_args[\"output_dir\"] = ksplit_output_dir\n",
 "        training_args_init = TrainingArguments(**training_args)\n",
 "\n",
 "        # create the trainer\n",
 "        trainer = Trainer(\n",
 "            model=model,\n",
 "            args=training_args_init,\n",
 "            data_collator=DataCollatorForGeneClassification(),\n",
 "            train_dataset=trainset_labeled,\n",
 "            eval_dataset=evalset_train_labeled\n",
 "        )\n",
 "\n",
 "        # train the gene classifier\n",
 "        trainer.train()\n",
 "\n",
 "        # save model\n",
 "        trainer.save_model(ksplit_model_dir)\n",
 "\n",
 "        # evaluate model\n",
 "        fpr, tpr, interp_tpr, conf_mat = classifier_predict(trainer.model, evalset_oos_labeled, 200, mean_fpr)\n",
 "\n",
 "        # append to tpr and roc lists\n",
 "        confusion = confusion + conf_mat\n",
 "        all_tpr.append(interp_tpr)\n",
 "        all_roc_auc.append(auc(fpr, tpr))\n",
 "        # weight for this split in the averaged metrics; note len(tpr) is the number of ROC\n",
 "        # thresholds returned by roc_curve, not the number of eval examples\n",
 "        all_tpr_wt.append(len(tpr))\n",
 "\n",
 "        iteration_num = iteration_num + 1\n",
 "\n",
 "    # get overall metrics for cross-validation\n",
 "    mean_tpr, roc_auc, roc_auc_sd = get_cross_valid_metrics(all_tpr, all_roc_auc, all_tpr_wt)\n",
 "    return all_roc_auc, roc_auc, roc_auc_sd, mean_fpr, mean_tpr, confusion, label_dicts"
 ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Define Functions for Plotting Results" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [
 "# plot ROC curve\n",
 "def plot_ROC(bundled_data, title):\n",
 "    \"\"\"Plot one mean ROC curve per entry of bundled_data on a single figure.\n",
 "\n",
 "    Each entry is (roc_auc, roc_auc_sd, mean_fpr, mean_tpr, sample, color); the legend reads\n",
 "    \"sample (AUC mean +/- sd)\" and a dashed diagonal marks chance performance.\n",
 "    \"\"\"\n",
 "    plt.figure()\n",
 "    lw = 2\n",
 "    for roc_auc, roc_auc_sd, mean_fpr, mean_tpr, sample, color in bundled_data:\n",
 "        # raw string: \"\\pm\" is matplotlib mathtext for +/-, not a Python escape sequence\n",
 "        plt.plot(mean_fpr, mean_tpr, color=color,\n",
 "                 lw=lw, label=r\"{0} (AUC {1:0.2f} $\\pm$ {2:0.2f})\".format(sample, roc_auc, roc_auc_sd))\n",
 "    plt.plot([0, 1], [0, 1], color='black', lw=lw, linestyle='--')\n",
 "    plt.xlim([0.0, 1.0])\n",
 "    plt.ylim([0.0, 1.05])\n",
 "    plt.xlabel('False Positive Rate')\n",
 "    plt.ylabel('True Positive Rate')\n",
 "    plt.title(title)\n",
 "    plt.legend(loc=\"lower right\")\n",
 "    plt.show()\n",
 "\n",
 "# plot confusion matrix\n",
 "def plot_confusion_matrix(classes_list, conf_mat, title):\n",
 "    \"\"\"Show conf_mat with each row (true class) normalised to sum to 1.\n",
 "\n",
 "    Tick labels are the class name followed by n=<column sum of conf_mat>, i.e. the number of\n",
 "    examples predicted as that class.\n",
 "    \"\"\"\n",
 "    # NOTE(review): rows are normalised but n counts predictions (columns); use conf_mat[i, :]\n",
 "    # instead if n is meant to be the number of true examples per class\n",
 "    display_labels = [\"{0}\\nn={1:.0f}\".format(label, sum(conf_mat[:,i]))\n",
 "                      for i, label in enumerate(classes_list)]\n",
 "    display = ConfusionMatrixDisplay(confusion_matrix=preprocessing.normalize(conf_mat, norm=\"l1\"),\n",
 "                                     display_labels=display_labels)\n",
 "    display.plot(cmap=\"Blues\",values_format=\".2g\")\n",
 "    plt.title(title)"
 ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Fine-Tune With Gene Classification Learning Objective and Quantify Predictive Performance" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "# set model parameters\n",
 "# max input size\n",
 "max_input_size = 2 ** 11 # 2048\n",
 "\n",
 "# set training parameters\n",
 "# max learning rate\n",
 "max_lr = 5e-5\n",
 "# how many pretrained layers to freeze\n",
 "freeze_layers = 4\n",
 "# number gpus\n",
 "num_gpus = 1\n",
 "# number cpu cores\n",
 "num_proc = 24\n",
 "# batch size for training and eval\n",
 "geneformer_batch_size = 12\n",
 "# learning schedule\n",
 "lr_schedule_fn = \"linear\"\n",
 "# warmup steps\n",
 "warmup_steps = 500\n",
 "# number of epochs\n",
 "epochs = 1\n",
 "# optimizer\n",
 "# NOTE(review): not included in training_args below; only used in the output directory name\n",
 "optimizer = \"adamw\""
 ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "tags": [] }, "outputs": [], "source": [
 "# set training arguments\n",
 "subsample_size = 10_000\n",
 "training_args = {\n",
 "    \"learning_rate\": max_lr,\n",
 "    \"do_train\": True,\n",
 "    \"evaluation_strategy\": \"no\",\n",
 "    \"save_strategy\": \"epoch\",\n",
 "    \"logging_steps\": 100,\n",
 "    \"group_by_length\": True,\n",
 "    \"length_column_name\": \"length\",\n",
 "    \"disable_tqdm\": False,\n",
 "    \"lr_scheduler_type\": lr_schedule_fn,\n",
 "    \"warmup_steps\": warmup_steps,\n",
 "    \"weight_decay\": 0.001,\n",
 "    \"per_device_train_batch_size\": geneformer_batch_size,\n",
 "    \"per_device_eval_batch_size\": geneformer_batch_size,\n",
 "    \"num_train_epochs\": epochs,\n",
 "}"
 ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "# define 
output directory path\n", "current_date = datetime.datetime.now()\n", "datestamp = f\"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}\"\n", "training_output_dir = f\"/path/to/models/{datestamp}_geneformer_GeneClassifier_dosageTF_L{max_input_size}_B{geneformer_batch_size}_LR{max_lr}_LS{lr_schedule_fn}_WU{warmup_steps}_E{epochs}_O{optimizer}_n{subsample_size}_F{freeze_layers}/\"\n", "\n", "# ensure not overwriting previously saved model\n", "ksplit_model_test = os.path.join(training_output_dir, \"ksplit0/models/pytorch_model.bin\")\n", "if os.path.isfile(ksplit_model_test) == True:\n", " raise Exception(\"Model already saved to this directory.\")\n", "\n", "# make output directory\n", "subprocess.call(f'mkdir {training_output_dir}', shell=True)" ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "tags": [] }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "3da0ae9f71de4f8b982948a2a9807dfd", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-3224634f88c19116.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-5534ad8f3f0cf000.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-bfb98c01d951ae8d.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-29ac8ab551fb8961.arrow\n", "Loading cached processed dataset at 
/n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-03912be57f358581.arrow\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "****** Crossval split: 0/4 ******\n", "\n", "Filtering training data\n", "Filtered 36%; 31897 remain\n", "\n", "Filtering evalation data\n", "Filtered 49%; 25258 remain\n", "\n", "Labeling training data\n", "Labeling evaluation data\n", "Labeling evaluation OOS data\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Some weights of the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ were not used when initializing BertForTokenClassification: ['cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight']\n", "- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", "Some weights of BertForTokenClassification were not initialized from the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ and are newly initialized: ['classifier.weight', 'classifier.bias']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", ":45: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n" ] }, { "data": { "text/html": [ "\n", "
\n", " \n", " \n", " [834/834 01:33, Epoch 1/1]\n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
StepTraining Loss
1000.684000
2000.617600
3000.477400
4000.334300
5000.229500
6000.152700
7000.125600
8000.104900

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-4d8947ed4c65f4a4.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-8a83f628e23d5548.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-c6c437341faa1cfe.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-2010c177e27e09d1.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-15543d980ad3cbb0.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-a81a942ab15e4aa3.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-5d2c963673bb1115.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-6c7cc476a9d722c3.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-e274abd189113bba.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-1aedba9e0b982e5c.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-6668161997480231.arrow\n", "Loading cached processed dataset at 
/n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-d802b8093fb9c6f7.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-3ea48baa5fe880e2.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-86024b6184e99afe.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-7a47db2c9f9758a4.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-af1f6b8f743677db.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-67cffffa35fa22f7.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-81ed63bd02a44ee5.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-6e5a21d4d57e333d.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-eecde81c07e6d036.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-fcc19fab82bb7115.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-ea856d7fa4e78b24.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-698344adb3749f61.arrow\n", "Loading cached processed dataset at 
/n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-ee3f9e89abdbee4c.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-d98fd9d7fda61d3b.arrow\n" ] }, { "data": { "image/png": "", "text/plain": [ "

" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-1cc2a7963b74376c.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-25d39eb14def0850.arrow\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "****** Crossval split: 1/4 ******\n", "\n", "Filtering training data\n", "Filtered 35%; 32406 remain\n", "\n", "Filtering evalation data\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-407cdf2a13a57414.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-5b5ee37df8a97b60.arrow\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Filtered 52%; 23996 remain\n", "\n", "Labeling training data\n", "Labeling evaluation data\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-26e9dc90c3620d42.arrow\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Labeling evaluation OOS data\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Some weights of the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ were not used when initializing BertForTokenClassification: ['cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.bias', 
'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight']\n", "- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", "Some weights of BertForTokenClassification were not initialized from the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ and are newly initialized: ['classifier.weight', 'classifier.bias']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", ":45: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n" ] }, { "data": { "text/html": [ "\n", "
\n", " \n", " \n", " [834/834 01:33, Epoch 1/1]\n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
StepTraining Loss
1000.658900
2000.585400
3000.474600
4000.346600
5000.257400
6000.185800
7000.134200
8000.114500

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-cbfcb02a16dd9d81.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-b151d664d8c68613.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-52266cf801a76344.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-5c7ceff44bad692c.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-81bcbb23e61bfc0c.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-e99a8c7eedd34769.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-6d7d5150907035d9.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-735b525b0abf0f74.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-9a47cf8290cd2f6b.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-56deb15eec02ca33.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-2aea162267b33f73.arrow\n", "Loading cached processed dataset at 
/n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-3bc7a169c841323d.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-1f67206928846c7a.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-88375062775280fb.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-bb45ebd2db699b53.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-fd6e4344cc2f8033.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-b8a9338cde5e5801.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-c013876f43a71ad7.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-148c328cb89da5c3.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-488b3d116a6d3b19.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-835e3e1538e24397.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-d176e8ab14f1ce28.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-3451fb13f869a5b0.arrow\n", "Loading cached processed dataset at 
/n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-56f270f895acc3ff.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-db497551e7a1e808.arrow\n" ] }, { "data": { "image/png": "", "text/plain": [ "

" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-b9477826fb507d36.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-f814e2d804a22203.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-24ae0c22f739e6fa.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-7447dd57147cebd3.arrow\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "****** Crossval split: 2/4 ******\n", "\n", "Filtering training data\n", "Filtered 35%; 32462 remain\n", "\n", "Filtering evalation data\n", "Filtered 52%; 24113 remain\n", "\n", "Labeling training data\n", "Labeling evaluation data\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-3d0888fca1887e80.arrow\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Labeling evaluation OOS data\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Some weights of the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ were not used when initializing BertForTokenClassification: ['cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight']\n", "- This IS expected if you are initializing 
BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", "Some weights of BertForTokenClassification were not initialized from the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ and are newly initialized: ['classifier.weight', 'classifier.bias']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", ":45: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n" ] }, { "data": { "text/html": [ "\n", "
\n", " \n", " \n", " [834/834 01:33, Epoch 1/1]\n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
StepTraining Loss
1000.645900
2000.582800
3000.461700
4000.350200
5000.262800
6000.180400
7000.140900
8000.109600

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-8e85e7414566994a.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-e2704cdfc217c3e3.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-e213b038886d7cd4.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-d6c9eba9fe9ffafc.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-442181417de57bb6.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-0d8563be811b9c30.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-85690e0bf5863858.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-3bdda0a32e054f19.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-3abe0ffb170c29f0.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-b132478871346000.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-09db8f6a69301008.arrow\n", "Loading cached processed dataset at 
/n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-34ae599619e2ced6.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-c74b97625f913f63.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-228b6002a6690208.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-d644cc9c55478a2a.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-d3d097800ebd687c.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-2e536900ba2b88cc.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-0434f2adbb78af27.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-926036de71570e84.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-d7f012de8332824e.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-57a002ae2aa9ba42.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-0476d5fed302e1c5.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-69341790285e8ce2.arrow\n", "Loading cached processed dataset at 
/n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-ee190fa69ba78df3.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-4b3dc879e23e8e63.arrow\n" ] }, { "data": { "image/png": "", "text/plain": [ "

" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-3c8713ea9ca7fcf8.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-c51c509a283b1c08.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-e3bf280f62a1ecd0.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-8a4c3c27f7ce74ce.arrow\n", "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-11b5a95b53a4e86b.arrow\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "****** Crossval split: 3/4 ******\n", "\n", "Filtering training data\n", "Filtered 35%; 32464 remain\n", "\n", "Filtering evalation data\n", "Filtered 53%; 23712 remain\n", "\n", "Labeling training data\n", "Labeling evaluation data\n", "Labeling evaluation OOS data\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Some weights of the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ were not used when initializing BertForTokenClassification: ['cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight']\n", "- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", "Some weights of BertForTokenClassification were not initialized from the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ and are newly initialized: ['classifier.weight', 'classifier.bias']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", ":45: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n" ] }, { "data": { "text/html": [ "\n", "
\n", " \n", " \n", " [834/834 01:32, Epoch 1/1]\n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
StepTraining Loss
1000.660300
2000.588000
3000.465400
4000.331400
5000.241100
6000.168800
7000.136600
8000.113900

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "Loading cached processed dataset at /n/holyscratch01/xiaoleliu_lab/Users/ctheodoris/datasets/geneformer_corpus_2048_sorted.dataset/cache-c438e6f7f8463bbc.arrow\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "6f8a9dd0a5754dec845c0022470a8c96", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "eaa8acd785b34fe8ab7e2853b745bf9c", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "55815cca43374fe1867219af483785e4", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "46388a65e68440928be961d7ae57bd05", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "17799d65feac4638a0071df44f6432db", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" ] }, "metadata": {}, "output_type": 
"display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "e103daf395794272989c209b32c12afc", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "81053043727a4c1dbe23304e5ad6282a", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "5d1d3f2835b74004b267d67d04c24663", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "14f38354b0354bc187be9db34990fcce", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "4e3d47f0ecdc489ca34de778ebfb3021", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "5997f34a471f4a918fd32043fc519bb3", "version_major": 2, "version_minor": 0 
}, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "affe20b63e08414cb0863e1f6c1aad18", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "fca7f8cafa504738b7eaddd3f7b708fc", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "11f299f23b124674ab9e334bdbe09288", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "01a88ef05cb64f24adecfb5674265a02", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "2f88e6525cbd486c9f03491a04681283", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { 
"application/vnd.jupyter.widget-view+json": { "model_id": "8bb884df7370471d986c51c10431ba10", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "4b82e5fe600b4270bb6268e68f76d093", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "cd15c803ecc34a8d878df577ffd80252", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "246cac7b5a0b4fd799e7e2081badbdbf", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "fbc93f4256724314a5141ac29062bae9", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "b38551b3ac134fef8aa0c6ea3b7fa2a0", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" ] }, 
"metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "16ddc360a6b64906bd3f1d1adcc94efe", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "44b3af87a1794fc09d00dd3743c4705d", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "image/png": "", "text/plain": [ "

" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "****** Crossval split: 4/4 ******\n", "\n", "Filtering training data\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "be5426abaf5b41ebb51e2567dd73b0a4", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=50.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Filtered 35%; 32428 remain\n", "\n", "Filtering evalation data\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "ff5aad423e4f4bbab54518bc5f0fd028", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=50.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Filtered 53%; 23660 remain\n", "\n", "Labeling training data\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "78c25d0976854653be92baf65ca71158", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=10000.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Labeling evaluation data\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "c445de0805e145249f4647e5552292a2", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=5000.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Labeling evaluation OOS data\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "c553f188f56e47acafa77fab9cb2b21f", "version_major": 2, "version_minor": 0 }, "text/plain": [ 
"HBox(children=(FloatProgress(value=0.0, max=5000.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Some weights of the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ were not used when initializing BertForTokenClassification: ['cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight']\n", "- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", "Some weights of BertForTokenClassification were not initialized from the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ and are newly initialized: ['classifier.weight', 'classifier.bias']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", ":45: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n" ] }, { "data": { "text/html": [ "\n", "
\n", " \n", " \n", " [834/834 01:35, Epoch 1/1]\n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
StepTraining Loss
1000.663500
2000.601800
3000.486200
4000.340400
5000.242700
6000.202300
7000.153600
8000.124400

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "0e1c475ab2ff4bfa8c65a24d587c8ad0", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "2ee8ff99342d4741a3f4ec4176b5d746", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "78a1a6af9439481ebe87731bb2d37c95", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "411ed284d33740eca1f0cef18df500a4", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "aafdf3014691426c9c6acca3834c45f2", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "5aa3add5de134f589eaab69087b66549", "version_major": 2, "version_minor": 0 }, 
"text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "7d255e53e1c2408697da1fa08860c9c0", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "29b8945f64354ae1b840a1dc316dedbf", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "de251d1fba3d4a67893047ee8275d606", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "8928cf69ea8746b2bef14028c0c0274a", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "0c0c4e21626f4ab99ce0696ee9322e0c", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { 
"application/vnd.jupyter.widget-view+json": { "model_id": "9e3499a2376d43bab0086cba34d1b522", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "f33d4f879c294c6a8a6455b3692488d5", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "38dd78e3ebf44c2bad58f9576a525ab3", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "b052e8b179584043945b49de9af31676", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "e3e11781b4394db1a01454ef37a490f2", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "915efb0adfb44c5caa01cf213c3cd56b", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" ] }, 
"metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "ceb10f0f87d044ebab534aefef5ec69c", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "31f4bd65079e4983b8a1937901cfbace", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "ccb5be44b5494de8862488f82bf01741", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "9da6bd7370db44889cab2fb81dcebe11", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "12bddf69336d481fb0076dced187523c", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "b89b616cd8064d248b37cc642a09b9bf", 
"version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "9346181e5b8b4f1b9a562ca676f87d38", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "de9f0442fc1e43f8bb06e4cecf719d67", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "image/png": "", "text/plain": [ "

" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "[0.24272061700106187, 0.1890124629743475, 0.1665455764824233, 0.212820656122506, 0.18890068741966132]\n" ] } ], "source": [ "# cross-validate gene classifier\n", "all_roc_auc, roc_auc, roc_auc_sd, mean_fpr, mean_tpr, confusion, label_dicts \\\n", " = cross_validate(subsampled_train_dataset, targets, labels, nsplits, subsample_size, training_args, freeze_layers, training_output_dir, 1)" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "# bundle data for plotting\n", "bundled_data = []\n", "bundled_data += [(roc_auc, roc_auc_sd, mean_fpr, mean_tpr, \"Geneformer\", \"red\")]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# plot ROC curve\n", "plot_ROC(bundled_data, 'Dosage Sensitive vs Insensitive TFs')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# plot confusion matrix\n", "classes_list = [\"Dosage Sensitive\", \"Dosage Insensitive\"]\n", "plot_confusion_matrix(classes_list, confusion, \"Geneformer\")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.11" }, "vscode": { "interpreter": { "hash": "eba1599a1f7e611c14c87ccff6793920aa63510b01fc0e229d6dd014149b8829" } } }, "nbformat": 4, "nbformat_minor": 4 }