File size: 6,571 Bytes
2d4cda5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3d6ff53-2176-44aa-8590-ec0aa301342d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from vllm import LLM, SamplingParams\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import torch.nn.functional as F\n",
    "import torch\n",
    "from transformers import AutoTokenizer\n",
    "from transformers import AutoModelForCausalLM\n",
    "import re\n",
    "import os\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '0,2'\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62669512-19e7-43cd-a518-4572eea700af",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Load the AWQ INT4 build of Llama 3.1 70B Instruct through vLLM.\n",
    "# tensor_parallel_size = 2 lines up with the two device ids listed in CUDA_VISIBLE_DEVICES in the import cell.\n",
    "# download_dir is a relative path, so it resolves against the notebook's working directory.\n",
    "llama = LLM(model='hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4', tensor_parallel_size = 2, \n",
    "            gpu_memory_utilization=0.50,\n",
    "            download_dir = \"../../\", max_model_len=5000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16bca2af-0cf4-41f2-ae28-d2c669a1af21",
   "metadata": {},
   "outputs": [],
   "source": [
    "def ask_about_trials_loosely(patient_summaries, trial_summaries, llama_model):\n",
    "    \"\"\"Ask the model whether each trial is a reasonable consideration for its paired patient.\n",
    "\n",
    "    One chat prompt is built per (patient summary, trial summary) pair, all prompts are\n",
    "    generated in a single batched call, and each response is reduced to a binary verdict.\n",
    "\n",
    "    Args:\n",
    "        patient_summaries: iterable of patient summary texts.\n",
    "        trial_summaries: iterable of trial summary texts, paired by position with\n",
    "            patient_summaries.\n",
    "        llama_model: object exposing get_tokenizer() and generate(); in this notebook\n",
    "            it is the vllm LLM instance.\n",
    "\n",
    "    Returns:\n",
    "        Tuple (responses, response_texts, eligibility_results): the objects returned by\n",
    "        generate(), the text of each response's first output, and one float per pair\n",
    "        (1.0 when the final verdict in the text is 'Yes!', otherwise 0.0).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the two inputs do not have the same number of items.\n",
    "    \"\"\"\n",
    "    patient_summaries = list(patient_summaries)\n",
    "    trial_summaries = list(trial_summaries)\n",
    "\n",
    "    # zip() would silently drop the unpaired tail, mislabeling nothing but losing rows.\n",
    "    if len(patient_summaries) != len(trial_summaries):\n",
    "        raise ValueError(\n",
    "            f\"patient_summaries and trial_summaries must have the same length \"\n",
    "            f\"(got {len(patient_summaries)} and {len(trial_summaries)})\"\n",
    "        )\n",
    "\n",
    "    # Nothing to ask: skip the model call entirely.\n",
    "    if not patient_summaries:\n",
    "        return [], [], []\n",
    "\n",
    "    tokenizer = llama_model.get_tokenizer()\n",
    "\n",
    "    system_prompt = \"\"\"You are a brilliant oncologist with encyclopedic knowledge about cancer and its treatment. \n",
    "    Your job is to evaluate whether a given clinical trial is a reasonable consideration for a patient, given a clinical trial summary and a patient summary.\\n\"\"\"\n",
    "\n",
    "    instructions = \"\"\"\n",
    "Base your judgment on whether the patient generally fits the cancer type(s), cancer burden, prior treatment(s), and biomarker criteria specified for the trial.\n",
    "You do not have to determine if the patient is actually eligible; instead please just evaluate whether it is reasonable for the trial to be considered further by the patient's oncologist.\n",
    "Some trials have biomarker requirements that are not assessed until formal eligibility screening begins; please ignore these requirements.\n",
    "Reason step by step, then answer the question \"Is this trial a reasonable consideration for this patient?\" with a one-word \"Yes!\" or \"No!\" answer.\n",
    "Make sure to include the exclamation point in your final one-word answer.\"\"\"\n",
    "\n",
    "    prompts = []\n",
    "    for patient_summary, trial_summary in zip(patient_summaries, trial_summaries):\n",
    "        # str() keeps a non-string summary (e.g. a NaN read from the CSV) from raising\n",
    "        # TypeError on concatenation; the caller already casts the trial text the same way.\n",
    "        user_content = (\n",
    "            \"Here is a summary of the clinical trial:\\n\" + str(trial_summary)\n",
    "            + \"\\nHere is a summary of the patient:\\n\" + str(patient_summary)\n",
    "            + instructions\n",
    "        )\n",
    "        messages = [\n",
    "            {'role': 'system', 'content': system_prompt},\n",
    "            {'role': 'user', 'content': user_content},\n",
    "        ]\n",
    "        prompts.append(\n",
    "            tokenizer.apply_chat_template(conversation=messages, add_generation_prompt=True, tokenize=False)\n",
    "        )\n",
    "\n",
    "    responses = llama_model.generate(\n",
    "        prompts,\n",
    "        SamplingParams(\n",
    "            temperature=0.0,\n",
    "            top_p=0.2,\n",
    "            max_tokens=2048,\n",
    "            repetition_penalty=1.2,\n",
    "            # Stop on either the tokenizer's EOS token or the Llama 3 end-of-turn token.\n",
    "            stop_token_ids=[tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids(\"<|eot_id|>\")],\n",
    "        ),\n",
    "    )\n",
    "\n",
    "    response_texts = [x.outputs[0].text for x in responses]\n",
    "\n",
    "    # The prompt asks for reasoning first and the one-word verdict last, so the LAST\n",
    "    # 'Yes!'/'No!' token is the answer. Matching 'Yes!' anywhere would score a response\n",
    "    # that mentions 'Yes!' while reasoning and then concludes 'No!' as a positive.\n",
    "    verdict_pattern = re.compile(r\"\\b(yes|no)!\", re.IGNORECASE)\n",
    "\n",
    "    eligibility_results = []\n",
    "    for response_text in response_texts:\n",
    "        verdicts = verdict_pattern.findall(response_text)\n",
    "        final_is_yes = bool(verdicts) and verdicts[-1].lower() == \"yes\"\n",
    "        eligibility_results.append(1.0 if final_is_yes else 0.0)\n",
    "\n",
    "    return responses, response_texts, eligibility_results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ce4cce6-3833-451a-98c6-d7f4c7b948c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the candidate rows to be checked; the path is relative to the notebook's working directory.\n",
    "patient_cohort_candidates = pd.read_csv('top_twenty_patients_tocheck_synthetic.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dbb846d2-20e3-4361-b69b-82aa31c1f789",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rename to the column names the batching cell reads: 'patient_summary' and 'this_space'.\n",
    "patient_cohort_candidates = patient_cohort_candidates.rename(columns={'this_patient':'patient_summary', 'space_summary':'this_space'})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1fb1aa2f-6c28-4a4e-a4e1-bfd69d0b39a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check after the rename: column names, dtypes and non-null counts.\n",
    "patient_cohort_candidates.info()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d30bf018-e135-40be-b636-0ba17acf8e61",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture\n",
    "# %%capture keeps the per-batch generation output from flooding the notebook.\n",
    "BATCH_SIZE = 500  # rows sent to the model per generate() call\n",
    "OUTPUT_CSV = 'top_twenty_patients_checked_synthetic.csv'\n",
    "\n",
    "output_list = []\n",
    "num_rows = patient_cohort_candidates.shape[0]\n",
    "\n",
    "for batch_start in range(0, num_rows, BATCH_SIZE):\n",
    "    # Slice the batch in one step instead of concatenating single-row frames;\n",
    "    # .copy() so the new columns are not written into a view of the source frame.\n",
    "    output = patient_cohort_candidates.iloc[batch_start:batch_start + BATCH_SIZE].copy()\n",
    "\n",
    "    _responses, response_texts, eligibility_results = ask_about_trials_loosely(\n",
    "        output['patient_summary'].tolist(),\n",
    "        output['this_space'].astype(str).tolist(),\n",
    "        llama,\n",
    "    )\n",
    "    output['llama_response'] = response_texts\n",
    "    output['eligibility_result'] = eligibility_results\n",
    "    output_list.append(output)\n",
    "\n",
    "    # Checkpoint right after every finished batch so an interrupted run keeps its results;\n",
    "    # the last iteration writes the complete file.\n",
    "    output_file = pd.concat(output_list, axis=0)\n",
    "    output_file.to_csv(OUTPUT_CSV)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91534c0e-4873-4eda-9a69-53660a84b4df",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eaebffcc-4b62-4ab6-a077-69a6e4340773",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}