diff --git "a/examples/cell_classification_test_pr.ipynb" "b/examples/cell_classification_test_pr.ipynb"
new file mode 100644--- /dev/null
+++ "b/examples/cell_classification_test_pr.ipynb"
@@ -0,0 +1,454 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "65a2b29a-c678-4874-a1bf-5af3a7d00ed9",
+ "metadata": {},
+ "source": [
+ "## Geneformer Fine-Tuning for Classification of Cardiomyopathy Disease States"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "1792e51c-86c3-406f-be5a-273c4e4aec20",
+ "metadata": {},
+ "source": [
+ "### Please note that, as usual with deep learning models, we **highly** recommend tuning learning hyperparameters for all fine-tuning applications, as this can significantly improve model performance. The example below uses previously optimized hyperparameters, but hyperparameters can be optimized by passing the argument `n_hyperopt_trials=n` to `cc.validate()`, where n>0 is the number of hyperparameter optimization trials."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "3dad7564-b464-4d37-9188-17c0ae4ae59f",
+ "metadata": {},
+ "source": [
+ "### Train the cell classifier on 70% of the data (with hyperparameters previously optimized using 15% of the data as a validation set) and evaluate on a held-out test set comprising the remaining 15% of the data"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "9027e51e-7830-4ab8-aebf-b9779b3ea2c1",
+ "metadata": {},
+ "source": [
+ "### Fine-tune the model for cell state classification"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "efe3b79b-aa8f-416c-9755-7f9299d6a81e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import datetime\n",
+ "import os\n",
+ "\n",
+ "from geneformer import Classifier\n",
+ "\n",
+ "# Timestamp the run so that repeated executions write to separate output directories.\n",
+ "current_date = datetime.datetime.now()\n",
+ "# YYMMDDhhmmss: names this run's output directory.\n",
+ "datestamp = current_date.strftime(\"%y%m%d%H%M%S\")\n",
+ "# YYMMDD: used further down to build the path of the saved fine-tuned model.\n",
+ "datestamp_min = current_date.strftime(\"%y%m%d\")\n",
+ "\n",
+ "output_prefix = \"cm_classifier_test\"\n",
+ "output_dir = f\"/path/to/output_dir/{datestamp}\"\n",
+ "# Create the directory (and any missing parents) from Python rather than `!mkdir`:\n",
+ "# this is portable, safe to re-run, and raises if the directory cannot be created\n",
+ "# instead of only printing a shell error and letting later cells fail.\n",
+ "os.makedirs(output_dir, exist_ok=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "f070ab20-1b18-4941-a5c7-89e23b519261",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Cell types handed to the classifier as its `filter_data` argument.\n",
+ "filter_data_dict = {\n",
+ "    \"cell_type\": [\"Cardiomyocyte1\", \"Cardiomyocyte2\", \"Cardiomyocyte3\"],\n",
+ "}\n",
+ "\n",
+ "# Training hyperparameters (previously optimized; see the note at the top of this notebook).\n",
+ "training_args = dict(\n",
+ "    num_train_epochs=0.9,\n",
+ "    learning_rate=0.000804,\n",
+ "    lr_scheduler_type=\"polynomial\",\n",
+ "    warmup_steps=1812,\n",
+ "    weight_decay=0.258828,\n",
+ "    per_device_train_batch_size=12,\n",
+ "    seed=73,\n",
+ ")\n",
+ "\n",
+ "# Cell classifier over all states of the \"disease\" attribute.\n",
+ "cc = Classifier(\n",
+ "    classifier=\"cell\",\n",
+ "    cell_state_dict={\"state_key\": \"disease\", \"states\": \"all\"},\n",
+ "    filter_data=filter_data_dict,\n",
+ "    training_args=training_args,\n",
+ "    max_ncells=None,\n",
+ "    freeze_layers=2,\n",
+ "    num_crossval_splits=1,\n",
+ "    forward_batch_size=200,\n",
+ "    nproc=16,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "0bced2e8-0a49-418e-a7f9-3981be256bd6",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "9c409ca656ed4cb0b280d95e326c1bc7",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Saving the dataset (0/3 shards): 0%| | 0/115367 [00:00, ? examples/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "facb7207b57948aebb3f8681346e17d4",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Saving the dataset (0/1 shards): 0%| | 0/17228 [00:00, ? examples/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "# Splits below were previously balanced with the prepare_data and validate functions,\n",
+ "# with argument attr_to_split set to \"individual\" and attr_to_balance set to\n",
+ "# [\"disease\",\"lvef\",\"age\",\"sex\",\"length\"].\n",
+ "train_ids = [\n",
+ "    \"1447\", \"1600\", \"1462\", \"1558\", \"1300\", \"1508\", \"1358\", \"1678\", \"1561\", \"1304\",\n",
+ "    \"1610\", \"1430\", \"1472\", \"1707\", \"1726\", \"1504\", \"1425\", \"1617\", \"1631\", \"1735\",\n",
+ "    \"1582\", \"1722\", \"1622\", \"1630\", \"1290\", \"1479\", \"1371\", \"1549\", \"1515\",\n",
+ "]\n",
+ "eval_ids = [\"1422\", \"1510\", \"1539\", \"1606\", \"1702\"]\n",
+ "test_ids = [\"1437\", \"1516\", \"1602\", \"1685\", \"1718\"]\n",
+ "\n",
+ "# Train and eval individuals go into the \"train\" side of this split together;\n",
+ "# only the test individuals are held out here.\n",
+ "train_test_id_split_dict = dict(\n",
+ "    attr_key=\"individual\",\n",
+ "    train=train_ids + eval_ids,\n",
+ "    test=test_ids,\n",
+ ")\n",
+ "\n",
+ "# Example input_data_file: https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/tree/main/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset\n",
+ "cc.prepare_data(\n",
+ "    input_data_file=\"/path/to/human_dcm_hcm_nf_2048_w_length.dataset\",\n",
+ "    output_directory=output_dir,\n",
+ "    output_prefix=output_prefix,\n",
+ "    split_id_dict=train_test_id_split_dict,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "73fe8b29-dd8f-4bf8-82c1-53196d73ed49",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "691e875524e441bca22b790a0f4a2a35",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ " 0%| | 0/1 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "****** Validation split: 1/1 ******\n",
+ "\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "c2c4f53aa71a49b89c32c8ba573b0b0c",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Filter (num_proc=16): 0%| | 0/115367 [00:00, ? examples/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "adf76144219747558bf39b7e776a68b3",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Filter (num_proc=16): 0%| | 0/115367 [00:00, ? examples/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /gladstone/theodoris/home/ctheodoris/Geneformer and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'classifier.bias', 'classifier.weight']\n",
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
+ "Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "/gladstone/theodoris/home/ctheodoris/Geneformer/geneformer/collator_for_classification.py:581: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
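+ "    <div>\n",
+ "      \n",
+ "      <progress value='7020' max='7020' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
+ "      [7020/7020 26:02, Epoch 0/1]\n",
+ "    </div>\n",
+ "    <table border=\"1\" class=\"dataframe\">\n",
+ "  <thead>\n",
+ " <tr style=\"text-align: left;\">\n",
+ "      <th>Epoch</th>\n",
+ "      <th>Training Loss</th>\n",
+ "      <th>Validation Loss</th>\n",
+ "      <th>Accuracy</th>\n",
+ "      <th>Macro F1</th>\n",
+ "    </tr>\n",
+ "  </thead>\n",
+ "  <tbody>\n",
+ "    <tr>\n",
+ "      <td>0</td>\n",
+ "      <td>0.142400</td>\n",
+ "      <td>0.389166</td>\n",
+ "      <td>0.889797</td>\n",
+ "      <td>0.693074</td>\n",
+ "    </tr>\n",
+ "  </tbody>\n",
+ "</table><p>"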
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/gladstone/theodoris/home/ctheodoris/Geneformer/geneformer/collator_for_classification.py:581: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "# Split the prepared training data by individual into train and eval subsets.\n",
+ "train_valid_id_split_dict = dict(\n",
+ "    attr_key=\"individual\",\n",
+ "    train=train_ids,\n",
+ "    eval=eval_ids,\n",
+ ")\n",
+ "\n",
+ "# 6 layer Geneformer: https://huggingface.co/ctheodoris/Geneformer/blob/main/model.safetensors\n",
+ "# to optimize hyperparameters, set n_hyperopt_trials=100 (or alternative desired # of trials)\n",
+ "all_metrics = cc.validate(\n",
+ "    model_directory=\"/path/to/Geneformer\",\n",
+ "    prepared_input_data_file=f\"{output_dir}/{output_prefix}_labeled_train.dataset\",\n",
+ "    id_class_dict_file=f\"{output_dir}/{output_prefix}_id_class_dict.pkl\",\n",
+ "    output_directory=output_dir,\n",
+ "    output_prefix=output_prefix,\n",
+ "    split_id_dict=train_valid_id_split_dict,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "6eca8ab4-6f4d-4dd6-9b90-edfb5cc7417c",
+ "metadata": {},
+ "source": [
+ "### Evaluate the model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "f580021e-2b70-4ebc-943c-2bfe6177e1b5",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Hyperparameter tuning is highly recommended for optimal results. No training_args provided; using default hyperparameters.\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Classifier used for evaluation only, so no training_args are passed.\n",
+ "cc = Classifier(\n",
+ "    classifier=\"cell\",\n",
+ "    cell_state_dict={\"state_key\": \"disease\", \"states\": \"all\"},\n",
+ "    forward_batch_size=200,\n",
+ "    nproc=16,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "b05398b4-bca1-44b0-8160-637489f16646",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "8e93a706295b49a1996b275eba3e9f31",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ " 0%| | 0/87 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "# Location of the fine-tuned model; the path is built from the date stamp\n",
+ "# computed in the setup cell.\n",
+ "saved_model_dir = f\"{output_dir}/{datestamp_min}_geneformer_cellClassifier_{output_prefix}/ksplit1/\"\n",
+ "\n",
+ "# Evaluate the saved model on the labeled test dataset written by prepare_data.\n",
+ "all_metrics_test = cc.evaluate_saved_model(\n",
+ "    model_directory=saved_model_dir,\n",
+ "    id_class_dict_file=f\"{output_dir}/{output_prefix}_id_class_dict.pkl\",\n",
+ "    test_data_file=f\"{output_dir}/{output_prefix}_labeled_test.dataset\",\n",
+ "    output_directory=output_dir,\n",
+ "    output_prefix=output_prefix,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "b45404e4-87cc-421d-84f5-1f9cbc09aa31",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "