{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "81b83fa8-421d-4be5-b9eb-5892f01fd5b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import os\n",
    "#os.environ['CUDA_VISIBLE_DEVICES'] = '2,3'\n",
    "from sentence_transformers import SentenceTransformer, InputExample, losses\n",
    "from torch.utils.data import DataLoader\n",
    "import torch.nn.functional as F\n",
    "import torch\n",
    "from sklearn.metrics import roc_auc_score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "937cbcda-0cd6-47f7-b52e-17ed2bafce3d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Base embedding model, placed on the GPU; it is fine-tuned further down with model.fit.\n",
    "model = SentenceTransformer('dunzhang/stella_en_1.5b_v5', trust_remote_code=True, device='cuda')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "853e6b86-db0b-4650-b98f-f437987baa5a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Checked cohort rows from the synthetic CSV; later concatenated with patient_checks.\n",
    "cohort_checks = pd.read_csv('top_ten_cohorts_checked_synthetic.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "474950e0-869b-414e-823f-df5ba8e5de92",
   "metadata": {},
   "outputs": [],
   "source": [
    "cohort_checks.info()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9835dad-4fc4-4a0e-aba2-d45358edbee9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Binary label: 1 when the llama response contains 'Yes!' or 'YES!', otherwise 0.\n",
    "# na=False makes a missing response count as 0. Without it, str.contains returns NaN\n",
    "# for missing values on object-dtype columns, and np.where treats NaN as truthy,\n",
    "# which would label those rows 1.\n",
    "cohort_checks['mod_eligibility_result'] = np.where(\n",
    "    cohort_checks.llama_response.str.contains('Yes!|YES!', na=False), 1, 0\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "413bddbc-f35c-48ec-bcbb-48405bd2c9c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "cohort_checks.eligibility_result.value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79c2d994-1e39-41b6-aeb3-962ba3ba5611",
   "metadata": {},
   "outputs": [],
   "source": [
    "cohort_checks.mod_eligibility_result.value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c8e6a20-4513-422c-be6e-3459ce98a2be",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Checked patient rows from the synthetic CSV; the second source of training pairs.\n",
    "patient_checks = pd.read_csv('top_twenty_patients_checked_synthetic.csv')"
   ]
  },
  {
   "cell_type": "code",
"execution_count": null,
   "id": "e91074cf-07de-40d1-8bf3-baf825d3f625",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Binary label: 1 when the llama response contains 'Yes!' or 'YES!', otherwise 0.\n",
    "# na=False makes a missing response count as 0. Without it, str.contains returns NaN\n",
    "# for missing values on object-dtype columns, and np.where treats NaN as truthy,\n",
    "# which would label those rows 1.\n",
    "patient_checks['mod_eligibility_result'] = np.where(\n",
    "    patient_checks.llama_response.str.contains('Yes!|YES!', na=False), 1, 0\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a2d45fc-7d92-4a65-aad0-a9ab8f783779",
   "metadata": {},
   "outputs": [],
   "source": [
    "patient_checks.info()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0cf7c705-ba6d-4f01-ac2e-88d345fef7f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "patient_checks.eligibility_result.value_counts(), patient_checks.mod_eligibility_result.value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dec4a8c8-c5db-4164-a06c-27fb59782fa5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rename the text columns to patient_summary / this_space, the names read from the\n",
    "# combined frame when the training examples are built.\n",
    "patient_checks = patient_checks.rename(columns={'this_patient':'patient_summary', 'space_summary':'this_space'})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0bf55e82-c91d-472f-84ad-74c755e9bf29",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stack patient-level and cohort-level checks into one frame.\n",
    "# ignore_index gives the combined frame unique row labels instead of repeating\n",
    "# each source frame's labels.\n",
    "combined_checks = pd.concat([patient_checks, cohort_checks], axis=0, ignore_index=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ebd07c9c-6263-4005-bfb1-2a8468b76a98",
   "metadata": {},
   "outputs": [],
   "source": [
    "combined_checks.info()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1d0126e-a58d-41ca-ad2f-a2d37bc585ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training rows only, keeping those that have both a patient summary and a llama response.\n",
    "train_summaries = combined_checks[combined_checks.split=='train']\n",
    "train_summaries = train_summaries[~train_summaries.patient_summary.isnull()]\n",
    "train_summaries = train_summaries[~train_summaries.llama_response.isnull()]\n",
    "train_summaries.split.value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f6506ed-dcbf-4e0b-8722-6e234c2d4509",
   "metadata": {},
   "outputs": [],
   "source": [
"train_summaries.mod_eligibility_result.value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c678a59a-c301-42d9-83dd-511503cee2fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_summaries.info()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57932264-103a-413b-9a48-43b7be254ac0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Positive-only (patient summary, trial space) pairs for MultipleNegativesRankingLoss.\n",
    "# The pairs are read from the filtered frame itself: indexing train_summaries by\n",
    "# position here would pick the first N training rows regardless of eligibility\n",
    "# instead of the rows with eligibility_result == 1.\n",
    "train_eligibles_only = train_summaries[train_summaries.eligibility_result == 1]\n",
    "example_list = [\n",
    "    InputExample(texts=[patient_summary, trial_space])\n",
    "    for patient_summary, trial_space in zip(train_eligibles_only.patient_summary,\n",
    "                                            train_eligibles_only.this_space)\n",
    "]\n",
    "\n",
    "train_eligibles_only_dataloader = DataLoader(example_list, shuffle=True, batch_size=8)\n",
    "train_eligibles_only_loss = losses.MultipleNegativesRankingLoss(model=model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5482be3-9a13-4ce1-aa8a-429c54bf6be0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Labelled pairs for the contrastive objective: every training row, with the\n",
    "# llama-derived mod_eligibility_result (1/0) as the pair label.\n",
    "contrastive_example_list = [\n",
    "    InputExample(texts=[patient_summary, trial_space], label=label)\n",
    "    for patient_summary, trial_space, label in zip(train_summaries.patient_summary,\n",
    "                                                   train_summaries.this_space,\n",
    "                                                   train_summaries.mod_eligibility_result)\n",
    "]\n",
    "\n",
    "contrastive_dataloader = DataLoader(contrastive_example_list, shuffle=True, batch_size=12)\n",
    "contrastive_train_loss = losses.OnlineContrastiveLoss(model=model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f17ad7a6-8911-4d7d-8495-3e37cb00597d",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "#%%capture\n",
    "# Fine-tune on both objectives (contrastive + multiple-negatives ranking).\n",
    "model.fit(train_objectives=[(contrastive_dataloader, contrastive_train_loss),\n",
    "                            (train_eligibles_only_dataloader, train_eligibles_only_loss)], epochs=2, warmup_steps=100)"
   ]
  },
  {
   "cell_type": "code",
"execution_count": null,
   "id": "c9cb6021-21d8-44bf-b440-980fcdae3b3d",
   "metadata": {},
   "outputs": [],
   "source": [
    "model.save('reranker_round1.model')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bae79a2e-4357-4c90-ba4c-a08b1206a99d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the fine-tuned checkpoint saved above, so evaluation can also start from here.\n",
    "model = SentenceTransformer('reranker_round1.model', trust_remote_code=True, device='cuda')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c6bfb8f7-ca6b-474b-8ce3-ba5acacb6b6a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check the model's ability to do the initial discriminate-among-diseases task\n",
    "# (on PHI)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4172f6ba-b334-4b83-b73e-d05dad6c05f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: this rebinds cohort_checks, which earlier held the synthetic cohort CSV.\n",
    "cohort_checks = pd.read_csv('../v7/space_specific_eligibility_checks_11-6-24.csv')\n",
    "# this cohort_checks file is not provided publicly, since it contains PHI/IP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6b25941-0007-4347-9ef3-899f9258542a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rows whose split name contains 'valid'. na=False treats a missing split as not\n",
    "# matching; otherwise the NaN in the mask makes the boolean indexing raise.\n",
    "validation_set = cohort_checks[cohort_checks.split.str.contains('valid', na=False)]\n",
    "validation_set.info()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b791608-6011-4bf6-914a-9534a08eba5a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Drop rows without a patient summary; there would be nothing to embed for them.\n",
    "validation_set = validation_set[~validation_set.patient_summary.isnull()]\n",
    "validation_set.info()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "479b9905-fcd6-4d37-9b03-7bbbfb88f123",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Embed both sides of the pairs marked eligible (eligibility_result == 1).\n",
    "# Row i of each embedding array comes from the same validation row.\n",
    "eligibles_only = validation_set[validation_set.eligibility_result == 1]\n",
    "patient_summary_embeddings = model.encode(eligibles_only.patient_summary.tolist())\n",
    "trial_summary_embeddings = model.encode(eligibles_only.this_space.tolist())"
   ]
  },
  {
   "cell_type": "code",
"execution_count": null,
   "id": "9b8f3a40-0854-43a5-bd83-a7fe6770f52b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# among patient to trial space candidate matches that pass llama checks, how good is TrialSpace at discriminating between true and random matches?\n",
    "import random\n",
    "\n",
    "rng = random.Random(0)  # local seeded RNG so the sampled pairs, and the AUC, are reproducible\n",
    "n_pairs = trial_summary_embeddings.shape[0]\n",
    "assert n_pairs > 1, 'need at least two eligible pairs to sample a random mismatch'\n",
    "\n",
    "labels = []\n",
    "paired_trial_index = []\n",
    "for i in range(n_pairs):\n",
    "    if rng.choice([0, 1]) == 1:\n",
    "        # true match: the patient's own trial space\n",
    "        paired_trial_index.append(i)\n",
    "        labels.append(1.)\n",
    "    else:\n",
    "        # random match: draw uniformly from the other rows, so a pair labelled 0\n",
    "        # is never the true pair itself\n",
    "        j = rng.randrange(n_pairs - 1)\n",
    "        paired_trial_index.append(j + 1 if j >= i else j)\n",
    "        labels.append(0.)\n",
    "\n",
    "# row-wise cosine similarity between each patient and its assigned trial space\n",
    "similarities = F.cosine_similarity(\n",
    "    torch.tensor(patient_summary_embeddings),\n",
    "    torch.tensor(trial_summary_embeddings)[paired_trial_index],\n",
    ").numpy()\n",
    "roc_auc_score(labels, similarities)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16dd4634-0389-466d-8257-160ddd2659af",
   "metadata": {},
   "outputs": [],
   "source": [
    "# how good are embeddings at discriminating between llama yes and no checks?\n",
    "# (on PHI)\n",
    "# NOTE: rebinds both embedding variables, now as tensors over the whole validation set\n",
    "# rather than arrays over the eligible rows only.\n",
    "patient_summary_embeddings = model.encode(validation_set.patient_summary.tolist(), convert_to_tensor=True)\n",
    "trial_summary_embeddings = model.encode(validation_set.this_space.tolist(), convert_to_tensor=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5bb0bc89-0b4f-451d-9523-550f7344e4d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# AUC of the row-wise cosine similarity against the eligibility_result labels.\n",
    "similarities = F.cosine_similarity(patient_summary_embeddings, trial_summary_embeddings).detach().cpu().numpy()\n",
    "roc_auc_score(validation_set.eligibility_result, similarities)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f587899-0101-4d81-91a7-f8ef72be949f",
"metadata": {}, "outputs": [], "source": [ "# Class balance of the validation labels (eligibility_result).\n", "validation_set.eligibility_result.value_counts()" ] }, { "cell_type": "code", "execution_count": null, "id": "453c2f3c-105a-4b71-851c-372bf29d3fe8", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "69a3fc1d-86f1-49f7-a93a-54f4748c5dbf", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "23d7d1f4-9f1f-42f6-a366-0e39af8893b2", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "0c4415d5-d0fd-48ca-b88c-2e244434561d", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.18" } }, "nbformat": 4, "nbformat_minor": 5 }