kenlkehl commited on
Commit
2d4cda5
·
verified ·
1 Parent(s): d3a3d42
0_summarize_ctgov_trials.ipynb ADDED
@@ -0,0 +1,423 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "ee78bb6d-4e3c-4751-b042-12c358d89cac",
7
+ "metadata": {
8
+ "scrolled": true
9
+ },
10
+ "outputs": [],
11
+ "source": [
12
+ "import numpy as np\n",
13
+ "import pandas as pd\n",
14
+ "import json\n",
15
+ "from vllm import LLM, SamplingParams\n",
16
+ "from transformers import AutoTokenizer\n",
17
+ "import torch\n",
18
+ "import os\n",
19
+ "#os.environ['CUDA_VISIBLE_DEVICES'] = '2,3'\n"
20
+ ]
21
+ },
22
+ {
23
+ "cell_type": "code",
24
+ "execution_count": null,
25
+ "id": "e8eeb339-6aca-4d3f-96fb-24a1caf26b34",
26
+ "metadata": {
27
+ "scrolled": true
28
+ },
29
+ "outputs": [],
30
+ "source": []
31
+ },
32
+ {
33
+ "cell_type": "code",
34
+ "execution_count": null,
35
+ "id": "7129a989-04e9-475d-9260-d1fdb1ab7faa",
36
+ "metadata": {},
37
+ "outputs": [],
38
+ "source": [
39
+ "llama = LLM(model='hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4', tensor_parallel_size = 2, \n",
40
+ " gpu_memory_utilization = 0.5,\n",
41
+ " download_dir = \"../../..\", max_model_len=6000)"
42
+ ]
43
+ },
44
+ {
45
+ "cell_type": "code",
46
+ "execution_count": null,
47
+ "id": "d9d7d1c4-50ed-4614-9855-8e6cc86bbb0e",
48
+ "metadata": {},
49
+ "outputs": [],
50
+ "source": []
51
+ },
52
+ {
53
+ "cell_type": "code",
54
+ "execution_count": null,
55
+ "id": "73897bb9-0738-4446-b332-9b9bf46ad043",
56
+ "metadata": {},
57
+ "outputs": [],
58
+ "source": [
59
+ "def summarize_trials_multi_cohort(eligibility_texts, llama_model):\n",
60
+ "\n",
61
+ " tokenizer = llama.get_tokenizer()\n",
62
+ " prompts = []\n",
63
+ " for trial in eligibility_texts:\n",
64
+ " messages = [\n",
65
+ " {'role':'system', 'content': \"\"\"You are an expert clinical oncologist with an encyclopedic knowledge of cancer and its treatments.\n",
66
+ " Your job is to review a clinical trial document and extract a list of structured clinical spaces that are eligible for that trial.\n",
67
+ " A clinical space is defined as a unique combination of cancer primary site, histology, which treatments a patient must have received, which treatments a patient must not have received, cancer burden (eg presence of metastatic disease), and tumor biomarkers (such as germline or somatic gene mutations or alterations, or protein expression on tumor) that a patient must have or must not have; that renders a patient eligible for the trial.\n",
68
+ " Trials often specify that a particular treatment is excluded only if it was given within a short period of time, for example 14 days, one month, etc , prior to trial start. Do not include this type of time-specific treatment eligibility criteria in your output at all.\n",
69
+ " Some trials have only one space, while others have several. Do not output a space that contains multiple cancer types and/or histologies. Instead, generate separate spaces for each cancer type/histology combination.\n",
70
+ " For biomarkers, if the trial specifies whether the biomarker will be assessed during screening, note that.\n",
71
+ " Spell out cancer types; do not abbreviate them. For example, write \"non-small cell lung cancer\" rather than \"NSCLC\".\n",
72
+ " Structure your output like this, as a list of spaces, with spaces separated by newlines, as below:\n",
73
+ " 1. Cancer type allowed: <cancer_type_allowed>. Histology allowed: <histology_allowed>. Cancer burden allowed: <cancer_burden_allowed>. Prior treatment required: <prior_treatments_requred>. Prior treatment excluded: <prior_treatments_excluded>. Biomarkers required: <biomarkers_required>. Biomarkers excluded: <biomarkers_excluded>.\n",
74
+ " 2. Cancer type allowed: <cancer_type_allowed>, etc.\n",
75
+ " If a particular concept is not mentioned in the trial text, do not include it in your definition of trial space(s).\n",
76
+ " \"\"\"}, \n",
77
+ " \n",
78
+ " {'role':'user', 'content': \"Here is a clinical trial document: \\n\" + trial + \"\\n\" + \"\"\"Now, generate your list of the trial space(s), formatted as above.\n",
79
+ " Do not provide any introductory, explanatory, concluding, or disclaimer text.\n",
80
+ " Reminder: Treatment history is an important component of trial space definitions, but treatment history requirements that are described as applying only in a given period of time prior to trial treatment MUST BE IGNORED.\"\"\"\n",
81
+ " }\n",
82
+ " ]\n",
83
+ " \n",
84
+ " prompts.append(tokenizer.apply_chat_template(conversation=messages, add_generation_prompt=True, tokenize=False))\n",
85
+ " \n",
86
+ "\n",
87
+ " \n",
88
+ " responses = llama_model.generate(\n",
89
+ " prompts, \n",
90
+ " SamplingParams(\n",
91
+ " temperature=0.0,\n",
92
+ " top_p=0.9,\n",
93
+ " max_tokens=3096,\n",
94
+ " stop_token_ids=[tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids(\"<|eot_id|>\")], # KEYPOINT HERE\n",
95
+ " ))\n",
96
+ "\n",
97
+ " response_texts = [x.outputs[0].text for x in responses]\n",
98
+ "\n",
99
+ "\n",
100
+ " return responses, response_texts"
101
+ ]
102
+ },
103
+ {
104
+ "cell_type": "code",
105
+ "execution_count": null,
106
+ "id": "ca683840-842b-4346-8eef-b66bc52d26af",
107
+ "metadata": {},
108
+ "outputs": [],
109
+ "source": [
110
+ "trials = pd.read_csv('./ctgov_cancer_trials.csv')"
111
+ ]
112
+ },
113
+ {
114
+ "cell_type": "code",
115
+ "execution_count": null,
116
+ "id": "aa51de7d-74e0-4822-b7e1-2c9a3bc31260",
117
+ "metadata": {},
118
+ "outputs": [],
119
+ "source": []
120
+ },
121
+ {
122
+ "cell_type": "code",
123
+ "execution_count": null,
124
+ "id": "4816dbf0-bd92-4742-912a-477e545e330b",
125
+ "metadata": {},
126
+ "outputs": [],
127
+ "source": [
128
+ "trial_cohorts = summarize_trials_multi_cohort(trials.trial_text.tolist(), llama)"
129
+ ]
130
+ },
131
+ {
132
+ "cell_type": "code",
133
+ "execution_count": null,
134
+ "id": "8283c587-c909-4548-804d-4d88b4ed7255",
135
+ "metadata": {},
136
+ "outputs": [],
137
+ "source": [
138
+ "trials['spaces'] = trial_cohorts[1]"
139
+ ]
140
+ },
141
+ {
142
+ "cell_type": "code",
143
+ "execution_count": null,
144
+ "id": "2ca75bab-7273-4ab0-86cd-1e0373546fce",
145
+ "metadata": {},
146
+ "outputs": [],
147
+ "source": [
148
+ "trials.to_csv('ctgov_all_trials_unique_trial_spaces_10-31-24.csv')"
149
+ ]
150
+ },
151
+ {
152
+ "cell_type": "code",
153
+ "execution_count": null,
154
+ "id": "0291913f-f3b9-4b39-99ab-954cb7237255",
155
+ "metadata": {},
156
+ "outputs": [],
157
+ "source": []
158
+ },
159
+ {
160
+ "cell_type": "code",
161
+ "execution_count": null,
162
+ "id": "16563812-6967-4788-a123-0af5fd701ede",
163
+ "metadata": {},
164
+ "outputs": [],
165
+ "source": []
166
+ },
167
+ {
168
+ "cell_type": "code",
169
+ "execution_count": null,
170
+ "id": "95776dbe-1a25-44bd-90f8-5c1573b6e92a",
171
+ "metadata": {},
172
+ "outputs": [],
173
+ "source": [
174
+ "import pandas as pd\n",
175
+ "import numpy as np\n",
176
+ "output = pd.read_csv('ctgov_all_trials_unique_trial_spaces_10-31-24.csv')"
177
+ ]
178
+ },
179
+ {
180
+ "cell_type": "code",
181
+ "execution_count": null,
182
+ "id": "cf647a1f-5a8c-4958-9032-440806a306d5",
183
+ "metadata": {},
184
+ "outputs": [],
185
+ "source": [
186
+ "# example of a trial and extracted spaces\n",
187
+ "i = 1000\n",
188
+ "output.trial_text.iloc[i], output.spaces.iloc[i]"
189
+ ]
190
+ },
191
+ {
192
+ "cell_type": "code",
193
+ "execution_count": null,
194
+ "id": "9cc06840-5647-4524-a7bf-a1ad53a07b7c",
195
+ "metadata": {},
196
+ "outputs": [],
197
+ "source": [
198
+ "frames = []\n",
199
+ "for i in range(trials.shape[0]):\n",
200
+ " cohorts = pd.Series(trials.iloc[i].spaces.split(\"\\n\"))\n",
201
+ " cohorts = cohorts[~((cohorts.isnull()) | (cohorts == \"\\n\") | (cohorts == ''))].reset_index(drop=True)\n",
202
+ " frame = pd.DataFrame(np.repeat(trials.iloc[[i]], len(cohorts), axis=0), columns=trials.columns)\n",
203
+ " frame['this_space'] = cohorts\n",
204
+ " frame['space_number'] = frame.index\n",
205
+ " frames.append(frame)\n",
206
+ " "
207
+ ]
208
+ },
209
+ {
210
+ "cell_type": "code",
211
+ "execution_count": null,
212
+ "id": "541669eb-f92e-49f3-9a36-b6625448c1a4",
213
+ "metadata": {},
214
+ "outputs": [],
215
+ "source": [
216
+ "cohort_level_trials = pd.concat(frames, axis=0)"
217
+ ]
218
+ },
219
+ {
220
+ "cell_type": "code",
221
+ "execution_count": null,
222
+ "id": "51a04e84-7483-4398-b4a0-d0cdab790609",
223
+ "metadata": {},
224
+ "outputs": [],
225
+ "source": [
226
+ "cohort_level_trials.info()"
227
+ ]
228
+ },
229
+ {
230
+ "cell_type": "code",
231
+ "execution_count": null,
232
+ "id": "648f0e1e-ef81-4983-8f03-1fbdb138f649",
233
+ "metadata": {},
234
+ "outputs": [],
235
+ "source": [
236
+ "cohort_level_trials.this_space.str[0].isin(['1','2','3','4','5','6','7','8','9']).value_counts()"
237
+ ]
238
+ },
239
+ {
240
+ "cell_type": "code",
241
+ "execution_count": null,
242
+ "id": "9ea048c1-c4ef-4202-a9be-a4658c4f1058",
243
+ "metadata": {},
244
+ "outputs": [],
245
+ "source": [
246
+ "cohort_level_trials = cohort_level_trials[cohort_level_trials.this_space.str[0].isin(['1','2','3','4','5','6','7','8','9'])]"
247
+ ]
248
+ },
249
+ {
250
+ "cell_type": "code",
251
+ "execution_count": null,
252
+ "id": "852aee9d-ad97-4374-932f-6cae378dde2a",
253
+ "metadata": {},
254
+ "outputs": [],
255
+ "source": []
256
+ },
257
+ {
258
+ "cell_type": "code",
259
+ "execution_count": null,
260
+ "id": "00d2220a-627a-4b67-be28-c42561c3c964",
261
+ "metadata": {},
262
+ "outputs": [],
263
+ "source": [
264
+ "cohort_level_trials.to_csv('ctgov_all_trials_trial_space_lineitems_10-31-24.csv')"
265
+ ]
266
+ },
267
+ {
268
+ "cell_type": "code",
269
+ "execution_count": null,
270
+ "id": "a130e909-6629-4408-b1ad-201b319d5e0f",
271
+ "metadata": {},
272
+ "outputs": [],
273
+ "source": [
274
+ "temp = pd.read_csv('ctgov_all_trials_trial_space_lineitems_10-31-24.csv')"
275
+ ]
276
+ },
277
+ {
278
+ "cell_type": "code",
279
+ "execution_count": null,
280
+ "id": "ad078444-33e1-4398-92b8-2e7f9f1a4031",
281
+ "metadata": {},
282
+ "outputs": [],
283
+ "source": [
284
+ "temp.this_space.nunique()"
285
+ ]
286
+ },
287
+ {
288
+ "cell_type": "code",
289
+ "execution_count": null,
290
+ "id": "be264ecb-12e7-4fd4-a16b-5a4b2f44d2aa",
291
+ "metadata": {},
292
+ "outputs": [],
293
+ "source": [
294
+ "import pandas as pd\n",
295
+ "out = pd.read_csv('ctgov_all_trials_trial_space_lineitems_10-31-24.csv')"
296
+ ]
297
+ },
298
+ {
299
+ "cell_type": "code",
300
+ "execution_count": null,
301
+ "id": "d38ca13b-f4c4-47f1-abd6-3289abbd5f64",
302
+ "metadata": {},
303
+ "outputs": [],
304
+ "source": [
305
+ "out.info()"
306
+ ]
307
+ },
308
+ {
309
+ "cell_type": "code",
310
+ "execution_count": null,
311
+ "id": "6849b44d-df0d-464f-bbce-f8fc1f789d3a",
312
+ "metadata": {},
313
+ "outputs": [],
314
+ "source": [
315
+ "# this component and following cells will not run without access to the DFCI private dataset\n",
316
+ "\n",
317
+ "import pandas as pd\n",
318
+ "dfci_trials = pd.read_csv(\"../space_specific_eligibility_checks_11-6-24.csv\")\n",
319
+ "dfci_trials.info()"
320
+ ]
321
+ },
322
+ {
323
+ "cell_type": "code",
324
+ "execution_count": null,
325
+ "id": "869690c3-2a80-4403-8933-f8f042c4ae35",
326
+ "metadata": {},
327
+ "outputs": [],
328
+ "source": [
329
+ "non_dfci_ctgov_trials = out[~out.nct_id.isin(dfci_trials.nct_id)]"
330
+ ]
331
+ },
332
+ {
333
+ "cell_type": "code",
334
+ "execution_count": null,
335
+ "id": "3d28ed5c-d152-40a0-ab14-4aa748f3f8ee",
336
+ "metadata": {},
337
+ "outputs": [],
338
+ "source": [
339
+ "non_dfci_ctgov_trials.info()"
340
+ ]
341
+ },
342
+ {
343
+ "cell_type": "code",
344
+ "execution_count": null,
345
+ "id": "efdaaf5b-edfb-4900-b85b-dde7eb1f92df",
346
+ "metadata": {},
347
+ "outputs": [],
348
+ "source": [
349
+ "unique_trials = non_dfci_ctgov_trials.groupby('nct_id').first().reset_index()[['nct_id']]\n",
350
+ "unique_trials.shape[0]"
351
+ ]
352
+ },
353
+ {
354
+ "cell_type": "code",
355
+ "execution_count": null,
356
+ "id": "41a63a73-4822-4c1d-820d-389252c0c56f",
357
+ "metadata": {},
358
+ "outputs": [],
359
+ "source": [
360
+ "unique_trial_sample = unique_trials.nct_id.sample(n=500, random_state=42)"
361
+ ]
362
+ },
363
+ {
364
+ "cell_type": "code",
365
+ "execution_count": null,
366
+ "id": "4cbbffe7-72ca-45b4-a11f-bb2d278bcfb7",
367
+ "metadata": {},
368
+ "outputs": [],
369
+ "source": [
370
+ "sample_spaces = non_dfci_ctgov_trials[non_dfci_ctgov_trials.nct_id.isin(unique_trial_sample)]"
371
+ ]
372
+ },
373
+ {
374
+ "cell_type": "code",
375
+ "execution_count": null,
376
+ "id": "c2639cc4-3472-463c-8519-ce0a9a1d845c",
377
+ "metadata": {},
378
+ "outputs": [],
379
+ "source": [
380
+ "sample_spaces.info()"
381
+ ]
382
+ },
383
+ {
384
+ "cell_type": "code",
385
+ "execution_count": null,
386
+ "id": "bc6def48-cacc-437b-ac19-2af9418821c2",
387
+ "metadata": {},
388
+ "outputs": [],
389
+ "source": [
390
+ "sample_spaces.to_csv('sample_spaces.csv')"
391
+ ]
392
+ },
393
+ {
394
+ "cell_type": "code",
395
+ "execution_count": null,
396
+ "id": "2b2370cf-e2ec-4e54-8dd0-6dde6d0fb041",
397
+ "metadata": {},
398
+ "outputs": [],
399
+ "source": []
400
+ }
401
+ ],
402
+ "metadata": {
403
+ "kernelspec": {
404
+ "display_name": "Python 3 (ipykernel)",
405
+ "language": "python",
406
+ "name": "python3"
407
+ },
408
+ "language_info": {
409
+ "codemirror_mode": {
410
+ "name": "ipython",
411
+ "version": 3
412
+ },
413
+ "file_extension": ".py",
414
+ "mimetype": "text/x-python",
415
+ "name": "python",
416
+ "nbconvert_exporter": "python",
417
+ "pygments_lexer": "ipython3",
418
+ "version": "3.9.18"
419
+ }
420
+ },
421
+ "nbformat": 4,
422
+ "nbformat_minor": 5
423
+ }
10_train_trialspace_round2.ipynb ADDED
@@ -0,0 +1,443 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "81b83fa8-421d-4be5-b9eb-5892f01fd5b0",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import pandas as pd\n",
11
+ "import numpy as np\n",
12
+ "import os\n",
13
+ "#os.environ['CUDA_VISIBLE_DEVICES'] = '2,3'\n",
14
+ "from sentence_transformers import SentenceTransformer, InputExample, losses\n",
15
+ "from torch.utils.data import DataLoader\n",
16
+ "import torch.nn.functional as F\n",
17
+ "import torch\n",
18
+ "from sklearn.metrics import roc_auc_score"
19
+ ]
20
+ },
21
+ {
22
+ "cell_type": "code",
23
+ "execution_count": null,
24
+ "id": "937cbcda-0cd6-47f7-b52e-17ed2bafce3d",
25
+ "metadata": {},
26
+ "outputs": [],
27
+ "source": [
28
+ "model = SentenceTransformer('reranker_round1.model', trust_remote_code=True, device='cuda')\n"
29
+ ]
30
+ },
31
+ {
32
+ "cell_type": "code",
33
+ "execution_count": null,
34
+ "id": "853e6b86-db0b-4650-b98f-f437987baa5a",
35
+ "metadata": {},
36
+ "outputs": [],
37
+ "source": [
38
+ "cohort_checks = pd.read_csv('top_ten_cohorts_checked_synthetic_round2.csv')"
39
+ ]
40
+ },
41
+ {
42
+ "cell_type": "code",
43
+ "execution_count": null,
44
+ "id": "474950e0-869b-414e-823f-df5ba8e5de92",
45
+ "metadata": {},
46
+ "outputs": [],
47
+ "source": [
48
+ "cohort_checks.info()"
49
+ ]
50
+ },
51
+ {
52
+ "cell_type": "code",
53
+ "execution_count": null,
54
+ "id": "c9835dad-4fc4-4a0e-aba2-d45358edbee9",
55
+ "metadata": {},
56
+ "outputs": [],
57
+ "source": [
58
+ "cohort_checks['mod_eligibility_result'] = np.where(cohort_checks.llama_response.str.contains('Yes!|YES!'), 1, 0)"
59
+ ]
60
+ },
61
+ {
62
+ "cell_type": "code",
63
+ "execution_count": null,
64
+ "id": "413bddbc-f35c-48ec-bcbb-48405bd2c9c9",
65
+ "metadata": {},
66
+ "outputs": [],
67
+ "source": [
68
+ "cohort_checks.eligibility_result.value_counts()"
69
+ ]
70
+ },
71
+ {
72
+ "cell_type": "code",
73
+ "execution_count": null,
74
+ "id": "79c2d994-1e39-41b6-aeb3-962ba3ba5611",
75
+ "metadata": {},
76
+ "outputs": [],
77
+ "source": [
78
+ "cohort_checks.mod_eligibility_result.value_counts()"
79
+ ]
80
+ },
81
+ {
82
+ "cell_type": "code",
83
+ "execution_count": null,
84
+ "id": "9c8e6a20-4513-422c-be6e-3459ce98a2be",
85
+ "metadata": {},
86
+ "outputs": [],
87
+ "source": [
88
+ "patient_checks = pd.read_csv('top_twenty_patients_checked_synthetic_round2.csv')"
89
+ ]
90
+ },
91
+ {
92
+ "cell_type": "code",
93
+ "execution_count": null,
94
+ "id": "e91074cf-07de-40d1-8bf3-baf825d3f625",
95
+ "metadata": {},
96
+ "outputs": [],
97
+ "source": [
98
+ "patient_checks['mod_eligibility_result'] = np.where(patient_checks.llama_response.str.contains('Yes!|YES!'), 1, 0)"
99
+ ]
100
+ },
101
+ {
102
+ "cell_type": "code",
103
+ "execution_count": null,
104
+ "id": "3a2d45fc-7d92-4a65-aad0-a9ab8f783779",
105
+ "metadata": {},
106
+ "outputs": [],
107
+ "source": [
108
+ "patient_checks.info()"
109
+ ]
110
+ },
111
+ {
112
+ "cell_type": "code",
113
+ "execution_count": null,
114
+ "id": "0cf7c705-ba6d-4f01-ac2e-88d345fef7f6",
115
+ "metadata": {},
116
+ "outputs": [],
117
+ "source": [
118
+ "patient_checks.eligibility_result.value_counts(), patient_checks.mod_eligibility_result.value_counts()"
119
+ ]
120
+ },
121
+ {
122
+ "cell_type": "code",
123
+ "execution_count": null,
124
+ "id": "dec4a8c8-c5db-4164-a06c-27fb59782fa5",
125
+ "metadata": {},
126
+ "outputs": [],
127
+ "source": [
128
+ "patient_checks = patient_checks.rename(columns={'this_patient':'patient_summary', 'space_summary':'this_space'})"
129
+ ]
130
+ },
131
+ {
132
+ "cell_type": "code",
133
+ "execution_count": null,
134
+ "id": "0bf55e82-c91d-472f-84ad-74c755e9bf29",
135
+ "metadata": {},
136
+ "outputs": [],
137
+ "source": [
138
+ "combined_checks = pd.concat([patient_checks, cohort_checks], axis=0)"
139
+ ]
140
+ },
141
+ {
142
+ "cell_type": "code",
143
+ "execution_count": null,
144
+ "id": "ebd07c9c-6263-4005-bfb1-2a8468b76a98",
145
+ "metadata": {},
146
+ "outputs": [],
147
+ "source": [
148
+ "combined_checks.info()"
149
+ ]
150
+ },
151
+ {
152
+ "cell_type": "code",
153
+ "execution_count": null,
154
+ "id": "49f59429-c9f4-43df-a1b2-750a3c94517a",
155
+ "metadata": {},
156
+ "outputs": [],
157
+ "source": []
158
+ },
159
+ {
160
+ "cell_type": "code",
161
+ "execution_count": null,
162
+ "id": "e1d0126e-a58d-41ca-ad2f-a2d37bc585ad",
163
+ "metadata": {},
164
+ "outputs": [],
165
+ "source": [
166
+ "train_summaries = combined_checks[combined_checks.split=='train']\n",
167
+ "train_summaries = train_summaries[~train_summaries.patient_summary.isnull()]\n",
168
+ "train_summaries = train_summaries[~train_summaries.llama_response.isnull()]\n",
169
+ "train_summaries.split.value_counts()"
170
+ ]
171
+ },
172
+ {
173
+ "cell_type": "code",
174
+ "execution_count": null,
175
+ "id": "2f6506ed-dcbf-4e0b-8722-6e234c2d4509",
176
+ "metadata": {},
177
+ "outputs": [],
178
+ "source": [
179
+ "train_summaries.mod_eligibility_result.value_counts()"
180
+ ]
181
+ },
182
+ {
183
+ "cell_type": "code",
184
+ "execution_count": null,
185
+ "id": "c678a59a-c301-42d9-83dd-511503cee2fb",
186
+ "metadata": {
187
+ "scrolled": true
188
+ },
189
+ "outputs": [],
190
+ "source": [
191
+ "train_summaries.info()"
192
+ ]
193
+ },
194
+ {
195
+ "cell_type": "code",
196
+ "execution_count": null,
197
+ "id": "57932264-103a-413b-9a48-43b7be254ac0",
198
+ "metadata": {},
199
+ "outputs": [],
200
+ "source": [
201
+ "# mll loss\n",
202
+ "train_eligibles_only = train_summaries[train_summaries.eligibility_result == 1]\n",
203
+ "example_list = []\n",
204
+ "for i in range(train_eligibles_only.shape[0]):\n",
205
+ " example_list.append(InputExample(texts=[train_summaries.patient_summary.iloc[i], train_summaries.this_space.iloc[i]]))\n",
206
+ "\n",
207
+ "train_eligibles_only_dataloader = DataLoader(example_list, shuffle=True, batch_size=8)\n",
208
+ "train_eligibles_only_loss = losses.MultipleNegativesRankingLoss(model=model)"
209
+ ]
210
+ },
211
+ {
212
+ "cell_type": "code",
213
+ "execution_count": null,
214
+ "id": "e5482be3-9a13-4ce1-aa8a-429c54bf6be0",
215
+ "metadata": {},
216
+ "outputs": [],
217
+ "source": [
218
+ "# for attempt at contrastive loss\n",
219
+ "# note 'Yes' is considered positive even without !\n",
220
+ "contrastive_example_list = []\n",
221
+ "for i in range(train_summaries.shape[0]):\n",
222
+ " contrastive_example_list.append(InputExample(texts=[train_summaries.patient_summary.iloc[i], train_summaries.this_space.iloc[i]],\n",
223
+ " label=train_summaries.mod_eligibility_result.iloc[i]))\n",
224
+ "\n",
225
+ "contrastive_dataloader = DataLoader(contrastive_example_list, shuffle=True, batch_size=12)\n",
226
+ "contrastive_train_loss = losses.OnlineContrastiveLoss(model=model)"
227
+ ]
228
+ },
229
+ {
230
+ "cell_type": "code",
231
+ "execution_count": null,
232
+ "id": "4e825dae-a5a9-4f87-af35-63ac2d73de33",
233
+ "metadata": {},
234
+ "outputs": [],
235
+ "source": []
236
+ },
237
+ {
238
+ "cell_type": "code",
239
+ "execution_count": null,
240
+ "id": "f17ad7a6-8911-4d7d-8495-3e37cb00597d",
241
+ "metadata": {
242
+ "scrolled": true
243
+ },
244
+ "outputs": [],
245
+ "source": [
246
+ "#%%capture\n",
247
+ "model.fit(train_objectives=[(contrastive_dataloader, contrastive_train_loss),\n",
248
+ " (train_eligibles_only_dataloader, train_eligibles_only_loss)], epochs=2, warmup_steps=100)"
249
+ ]
250
+ },
251
+ {
252
+ "cell_type": "code",
253
+ "execution_count": null,
254
+ "id": "c9cb6021-21d8-44bf-b440-980fcdae3b3d",
255
+ "metadata": {},
256
+ "outputs": [],
257
+ "source": [
258
+ "model.save('reranker_round2.model')"
259
+ ]
260
+ },
261
+ {
262
+ "cell_type": "code",
263
+ "execution_count": null,
264
+ "id": "bae79a2e-4357-4c90-ba4c-a08b1206a99d",
265
+ "metadata": {},
266
+ "outputs": [],
267
+ "source": [
268
+ "model = SentenceTransformer('reranker_round2.model', trust_remote_code=True, device='cuda')"
269
+ ]
270
+ },
271
+ {
272
+ "cell_type": "code",
273
+ "execution_count": null,
274
+ "id": "f5517caa-c45b-4b62-ae8d-0af61b61fd25",
275
+ "metadata": {},
276
+ "outputs": [],
277
+ "source": []
278
+ },
279
+ {
280
+ "cell_type": "code",
281
+ "execution_count": null,
282
+ "id": "c6bfb8f7-ca6b-474b-8ce3-ba5acacb6b6a",
283
+ "metadata": {},
284
+ "outputs": [],
285
+ "source": [
286
+ "# check model's ability to do initial discriminate among diseases task\n",
287
+ "# (on PHI)\n"
288
+ ]
289
+ },
290
+ {
291
+ "cell_type": "code",
292
+ "execution_count": null,
293
+ "id": "4172f6ba-b334-4b83-b73e-d05dad6c05f0",
294
+ "metadata": {},
295
+ "outputs": [],
296
+ "source": [
297
+ "# this file is not uploaded, since it contains PHI/IP\n",
298
+ "cohort_checks = pd.read_csv('../v7/space_specific_eligibility_checks_11-6-24.csv')"
299
+ ]
300
+ },
301
+ {
302
+ "cell_type": "code",
303
+ "execution_count": null,
304
+ "id": "d6b25941-0007-4347-9ef3-899f9258542a",
305
+ "metadata": {},
306
+ "outputs": [],
307
+ "source": [
308
+ "validation_set = cohort_checks[cohort_checks.split.str.contains('valid')]\n",
309
+ "validation_set.info()\n"
310
+ ]
311
+ },
312
+ {
313
+ "cell_type": "code",
314
+ "execution_count": null,
315
+ "id": "4b791608-6011-4bf6-914a-9534a08eba5a",
316
+ "metadata": {},
317
+ "outputs": [],
318
+ "source": [
319
+ "validation_set = validation_set[~validation_set.patient_summary.isnull()]\n",
320
+ "validation_set.info()"
321
+ ]
322
+ },
323
+ {
324
+ "cell_type": "code",
325
+ "execution_count": null,
326
+ "id": "479b9905-fcd6-4d37-9b03-7bbbfb88f123",
327
+ "metadata": {},
328
+ "outputs": [],
329
+ "source": [
330
+ "\n",
331
+ "eligibles_only = validation_set[validation_set.eligibility_result == 1]\n",
332
+ "patient_summary_embeddings = model.encode(eligibles_only.patient_summary.tolist())\n",
333
+ "trial_summary_embeddings = model.encode(eligibles_only.this_space.tolist())"
334
+ ]
335
+ },
336
+ {
337
+ "cell_type": "code",
338
+ "execution_count": null,
339
+ "id": "9b8f3a40-0854-43a5-bd83-a7fe6770f52b",
340
+ "metadata": {},
341
+ "outputs": [],
342
+ "source": [
343
+ "import random\n",
344
+ "labels = []\n",
345
+ "similarities = []\n",
346
+ "for i in range(trial_summary_embeddings.shape[0]):\n",
347
+ " if random.choice([0,1]) == 1:\n",
348
+ " similarities.append(F.cosine_similarity(torch.tensor(patient_summary_embeddings[i,:]).unsqueeze(0), torch.tensor(trial_summary_embeddings[i, :]).unsqueeze(0)))\n",
349
+ " labels.append(1.)\n",
350
+ " else:\n",
351
+ " random_index = random.choice([x for x in range(0,trial_summary_embeddings.shape[0])])\n",
352
+ " similarities.append(F.cosine_similarity(torch.tensor(patient_summary_embeddings[i,:]).unsqueeze(0), torch.tensor(trial_summary_embeddings[random_index, :]).unsqueeze(0)))\n",
353
+ " labels.append(0.)\n",
354
+ "roc_auc_score(labels, np.array([x.numpy() for x in similarities]))"
355
+ ]
356
+ },
357
+ {
358
+ "cell_type": "code",
359
+ "execution_count": null,
360
+ "id": "16dd4634-0389-466d-8257-160ddd2659af",
361
+ "metadata": {},
362
+ "outputs": [],
363
+ "source": [
364
+ "# how good are embeddings at discriminating between llama yes and no checks on original enrollments?\n",
365
+ "# (for PHI)\n",
366
+ "patient_summary_embeddings = model.encode(validation_set.patient_summary.tolist(), convert_to_tensor=True)\n",
367
+ "trial_summary_embeddings = model.encode(validation_set.this_space.tolist(), convert_to_tensor=True)"
368
+ ]
369
+ },
370
+ {
371
+ "cell_type": "code",
372
+ "execution_count": null,
373
+ "id": "5bb0bc89-0b4f-451d-9523-550f7344e4d9",
374
+ "metadata": {},
375
+ "outputs": [],
376
+ "source": [
377
+ "similarities = F.cosine_similarity(patient_summary_embeddings, trial_summary_embeddings).detach().cpu().numpy()\n",
378
+ "roc_auc_score(validation_set.eligibility_result, similarities)"
379
+ ]
380
+ },
381
+ {
382
+ "cell_type": "code",
383
+ "execution_count": null,
384
+ "id": "c6035e62-8d28-49c5-8d0a-049633edd553",
385
+ "metadata": {},
386
+ "outputs": [],
387
+ "source": []
388
+ },
389
+ {
390
+ "cell_type": "code",
391
+ "execution_count": null,
392
+ "id": "453c2f3c-105a-4b71-851c-372bf29d3fe8",
393
+ "metadata": {},
394
+ "outputs": [],
395
+ "source": []
396
+ },
397
+ {
398
+ "cell_type": "code",
399
+ "execution_count": null,
400
+ "id": "69a3fc1d-86f1-49f7-a93a-54f4748c5dbf",
401
+ "metadata": {},
402
+ "outputs": [],
403
+ "source": []
404
+ },
405
+ {
406
+ "cell_type": "code",
407
+ "execution_count": null,
408
+ "id": "23d7d1f4-9f1f-42f6-a366-0e39af8893b2",
409
+ "metadata": {},
410
+ "outputs": [],
411
+ "source": []
412
+ },
413
+ {
414
+ "cell_type": "code",
415
+ "execution_count": null,
416
+ "id": "0c4415d5-d0fd-48ca-b88c-2e244434561d",
417
+ "metadata": {},
418
+ "outputs": [],
419
+ "source": []
420
+ }
421
+ ],
422
+ "metadata": {
423
+ "kernelspec": {
424
+ "display_name": "Python 3 (ipykernel)",
425
+ "language": "python",
426
+ "name": "python3"
427
+ },
428
+ "language_info": {
429
+ "codemirror_mode": {
430
+ "name": "ipython",
431
+ "version": 3
432
+ },
433
+ "file_extension": ".py",
434
+ "mimetype": "text/x-python",
435
+ "name": "python",
436
+ "nbconvert_exporter": "python",
437
+ "pygments_lexer": "ipython3",
438
+ "version": "3.9.18"
439
+ }
440
+ },
441
+ "nbformat": 4,
442
+ "nbformat_minor": 5
443
+ }
11_train_roberta_checker.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
12_example_patient_query.ipynb ADDED
@@ -0,0 +1,749 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "da78b443-8f9a-426d-8ac1-2320dc10f1d6",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import pandas as pd\n",
11
+ "import numpy as np\n",
12
+ "import os\n",
13
+ "# os.environ['CUDA_VISIBLE_DEVICES'] = '2'\n"
14
+ ]
15
+ },
16
+ {
17
+ "cell_type": "code",
18
+ "execution_count": 2,
19
+ "id": "6ed86222-b6a4-4f3c-8dd6-ae6ac2fd7545",
20
+ "metadata": {},
21
+ "outputs": [],
22
+ "source": [
23
+ "trial_spaces = pd.read_csv('ctgov_all_trials_trial_space_lineitems_10-31-24.csv')"
24
+ ]
25
+ },
26
+ {
27
+ "cell_type": "code",
28
+ "execution_count": 3,
29
+ "id": "df199321-94ac-4998-a3ad-bb90705485f9",
30
+ "metadata": {},
31
+ "outputs": [
32
+ {
33
+ "name": "stdout",
34
+ "output_type": "stream",
35
+ "text": [
36
+ "<class 'pandas.core.frame.DataFrame'>\n",
37
+ "RangeIndex: 38140 entries, 0 to 38139\n",
38
+ "Data columns (total 10 columns):\n",
39
+ " # Column Non-Null Count Dtype \n",
40
+ "--- ------ -------------- ----- \n",
41
+ " 0 Unnamed: 0.1 38140 non-null int64 \n",
42
+ " 1 Unnamed: 0 38140 non-null int64 \n",
43
+ " 2 nct_id 38140 non-null object\n",
44
+ " 3 title 38140 non-null object\n",
45
+ " 4 brief_summary 38140 non-null object\n",
46
+ " 5 eligibility_criteria 38140 non-null object\n",
47
+ " 6 trial_text 38140 non-null object\n",
48
+ " 7 spaces 38140 non-null object\n",
49
+ " 8 this_space 38140 non-null object\n",
50
+ " 9 space_number 38140 non-null int64 \n",
51
+ "dtypes: int64(3), object(7)\n",
52
+ "memory usage: 2.9+ MB\n"
53
+ ]
54
+ }
55
+ ],
56
+ "source": [
57
+ "trial_spaces.info()"
58
+ ]
59
+ },
60
+ {
61
+ "cell_type": "code",
62
+ "execution_count": 4,
63
+ "id": "47b983df-d6f7-41d8-8c98-354f395c098e",
64
+ "metadata": {},
65
+ "outputs": [
66
+ {
67
+ "name": "stderr",
68
+ "output_type": "stream",
69
+ "text": [
70
+ "/homes10/klkehl/miniconda3/envs/vllm2/lib/python3.12/site-packages/sentence_transformers/cross_encoder/CrossEncoder.py:13: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
71
+ " from tqdm.autonotebook import tqdm, trange\n",
72
+ "Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████| 2/2 [00:03<00:00, 1.86s/it]\n"
73
+ ]
74
+ }
75
+ ],
76
+ "source": [
77
+ "from sentence_transformers import SentenceTransformer\n",
78
+ "import torch\n",
79
+ "\n",
80
+ "embedding_model = SentenceTransformer('reranker_round2.model', trust_remote_code=True, device='cuda')"
81
+ ]
82
+ },
83
+ {
84
+ "cell_type": "code",
85
+ "execution_count": 5,
86
+ "id": "9ae0aa51-92f4-4e44-9a95-bf76196b1b7b",
87
+ "metadata": {},
88
+ "outputs": [],
89
+ "source": [
90
+ "# only needs to be run once to generate and save trial embeddings\n",
91
+ "\n",
92
+ "# with torch.no_grad():\n",
93
+ "# trial_space_embeddings = embedding_model.encode(trial_spaces.this_space.tolist(), convert_to_tensor=True)\n",
94
+ "\n",
95
+ "# from safetensors.torch import save_file\n",
96
+ "# output_trial_file = {\"space_embeddings\": trial_space_embeddings}\n",
97
+ "# save_file(output_trial_file, \"trial_space_embeddings.safetensors\")\n",
98
+ "\n",
99
+ "# trial_space_embeddings.shape"
100
+ ]
101
+ },
102
+ {
103
+ "cell_type": "code",
104
+ "execution_count": null,
105
+ "id": "a6abc0de-b919-41df-88be-b838e0998f51",
106
+ "metadata": {},
107
+ "outputs": [],
108
+ "source": []
109
+ },
110
+ {
111
+ "cell_type": "code",
112
+ "execution_count": 6,
113
+ "id": "e6a25767-1939-48dd-a509-0bb2e3598f06",
114
+ "metadata": {},
115
+ "outputs": [],
116
+ "source": [
117
+ "from safetensors import safe_open\n",
118
+ "with safe_open(\"trial_space_embeddings.safetensors\", framework=\"pt\", device=0) as f:\n",
119
+ " trial_space_embeddings = f.get_tensor(\"space_embeddings\")"
120
+ ]
121
+ },
122
+ {
123
+ "cell_type": "code",
124
+ "execution_count": 7,
125
+ "id": "f83d55e8-4876-4048-8aa9-8aaf766b6722",
126
+ "metadata": {},
127
+ "outputs": [],
128
+ "source": [
129
+ "from transformers import pipeline, AutoTokenizer\n",
130
+ "tokenizer = AutoTokenizer.from_pretrained(\"roberta-large\")\n",
131
+ "\n",
132
+ "pipe = pipeline('text-classification', './roberta-checker', tokenizer=tokenizer, truncation=True, padding='max_length', max_length=512, device='cuda') \n"
133
+ ]
134
+ },
135
+ {
136
+ "cell_type": "code",
137
+ "execution_count": 8,
138
+ "id": "0318b38e-8ee8-493c-a3fa-391142ce9696",
139
+ "metadata": {},
140
+ "outputs": [],
141
+ "source": [
142
+ "patient_summary = \"metastatic lung adenocarcinoma, PD-L1 75%, KRAS G12C mutant, prior pembrolizumab, prior carboplatin/pemetrexed\""
143
+ ]
144
+ },
145
+ {
146
+ "cell_type": "code",
147
+ "execution_count": 9,
148
+ "id": "6d84d3a6-d965-49f2-ad6c-c8f9571da616",
149
+ "metadata": {},
150
+ "outputs": [],
151
+ "source": [
152
+ "patient_embedding = embedding_model.encode([patient_summary], convert_to_tensor=True)"
153
+ ]
154
+ },
155
+ {
156
+ "cell_type": "code",
157
+ "execution_count": null,
158
+ "id": "08c76a33-fdd0-4b1c-8a93-3a9e1f8b9bd6",
159
+ "metadata": {},
160
+ "outputs": [],
161
+ "source": []
162
+ },
163
+ {
164
+ "cell_type": "code",
165
+ "execution_count": null,
166
+ "id": "d8bf2e6f-fc8a-4f1a-9c68-f0e5fb939f5a",
167
+ "metadata": {},
168
+ "outputs": [],
169
+ "source": []
170
+ },
171
+ {
172
+ "cell_type": "code",
173
+ "execution_count": 10,
174
+ "id": "0b36fc28-28f1-4cba-b4fc-679d457e8da4",
175
+ "metadata": {},
176
+ "outputs": [],
177
+ "source": [
178
+ "import torch.nn.functional as F\n",
179
+ "similarities = F.cosine_similarity(patient_embedding, trial_space_embeddings)"
180
+ ]
181
+ },
182
+ {
183
+ "cell_type": "code",
184
+ "execution_count": 11,
185
+ "id": "774db345-8fd4-4a5d-b76e-2aa44123b4f7",
186
+ "metadata": {},
187
+ "outputs": [
188
+ {
189
+ "data": {
190
+ "text/plain": [
191
+ "torch.Size([38140])"
192
+ ]
193
+ },
194
+ "execution_count": 11,
195
+ "metadata": {},
196
+ "output_type": "execute_result"
197
+ }
198
+ ],
199
+ "source": [
200
+ "similarities.shape"
201
+ ]
202
+ },
203
+ {
204
+ "cell_type": "code",
205
+ "execution_count": 12,
206
+ "id": "b77b3b0b-a474-4a19-a4b8-2cc0990a3a45",
207
+ "metadata": {},
208
+ "outputs": [],
209
+ "source": [
210
+ "# pull top ten spaces for the patient\n",
211
+ "sorted_similarities, sorted_indices = torch.sort(similarities, descending=True)\n",
212
+ "relevant_spaces = trial_spaces.iloc[sorted_indices[0:10].cpu().numpy()].this_space\n",
213
+ "relevant_nctid = trial_spaces.iloc[sorted_indices[0:10].cpu().numpy()].nct_id\n",
214
+ "relevant_title = trial_spaces.iloc[sorted_indices[0:10].cpu().numpy()].title\n",
215
+ "relevant_brief_summary = trial_spaces.iloc[sorted_indices[0:10].cpu().numpy()].brief_summary\n",
216
+ "relevant_eligibility_criteria = trial_spaces.iloc[sorted_indices[0:10].cpu().numpy()].eligibility_criteria\n",
217
+ "relevant_space_embeddings = trial_space_embeddings[sorted_indices[0:10], :]"
218
+ ]
219
+ },
220
+ {
221
+ "cell_type": "code",
222
+ "execution_count": 13,
223
+ "id": "653c772f-6031-4c35-9978-5f23eb9f9301",
224
+ "metadata": {},
225
+ "outputs": [
226
+ {
227
+ "data": {
228
+ "text/html": [
229
+ "<div>\n",
230
+ "<style scoped>\n",
231
+ " .dataframe tbody tr th:only-of-type {\n",
232
+ " vertical-align: middle;\n",
233
+ " }\n",
234
+ "\n",
235
+ " .dataframe tbody tr th {\n",
236
+ " vertical-align: top;\n",
237
+ " }\n",
238
+ "\n",
239
+ " .dataframe thead th {\n",
240
+ " text-align: right;\n",
241
+ " }\n",
242
+ "</style>\n",
243
+ "<table border=\"1\" class=\"dataframe\">\n",
244
+ " <thead>\n",
245
+ " <tr style=\"text-align: right;\">\n",
246
+ " <th></th>\n",
247
+ " <th>patient_summary</th>\n",
248
+ " <th>this_space</th>\n",
249
+ " <th>nct_id</th>\n",
250
+ " <th>trial_title</th>\n",
251
+ " <th>trial_brief_summary</th>\n",
252
+ " <th>trial_eligibility_criteria</th>\n",
253
+ " <th>pt_trial_pair</th>\n",
254
+ " </tr>\n",
255
+ " </thead>\n",
256
+ " <tbody>\n",
257
+ " <tr>\n",
258
+ " <th>0</th>\n",
259
+ " <td>metastatic lung adenocarcinoma, PD-L1 75%, KRA...</td>\n",
260
+ " <td>5. Cancer type allowed: non-small cell lung ca...</td>\n",
261
+ " <td>NCT06253520</td>\n",
262
+ " <td>A Phase Ib Clinical Trial to Evaluate the Admi...</td>\n",
263
+ " <td>Background:\\n\\nMany cancer cells produce subst...</td>\n",
264
+ " <td>* INCLUSION CRITERIA:\\n* Participants with an ...</td>\n",
265
+ " <td>5. Cancer type allowed: non-small cell lung ca...</td>\n",
266
+ " </tr>\n",
267
+ " <tr>\n",
268
+ " <th>1</th>\n",
269
+ " <td>metastatic lung adenocarcinoma, PD-L1 75%, KRA...</td>\n",
270
+ " <td>1. Cancer type allowed: non-small cell lung ca...</td>\n",
271
+ " <td>NCT05853575</td>\n",
272
+ " <td>A Randomized Study of Two Dosing Regimens of A...</td>\n",
273
+ " <td>This study will evaluate the efficacy of two d...</td>\n",
274
+ " <td>Key Inclusion Criteria:\\n\\n* Are at least 18 y...</td>\n",
275
+ " <td>1. Cancer type allowed: non-small cell lung ca...</td>\n",
276
+ " </tr>\n",
277
+ " <tr>\n",
278
+ " <th>2</th>\n",
279
+ " <td>metastatic lung adenocarcinoma, PD-L1 75%, KRA...</td>\n",
280
+ " <td>3. Cancer type allowed: non-small cell lung ca...</td>\n",
281
+ " <td>NCT06128551</td>\n",
282
+ " <td>Phase 1b, Multicenter, Open-Label, Dose Escala...</td>\n",
283
+ " <td>This study is to evaluate the safety, tolerabi...</td>\n",
284
+ " <td>Inclusion Criteria:\\n\\n* 18 years of age\\n* Hi...</td>\n",
285
+ " <td>3. Cancer type allowed: non-small cell lung ca...</td>\n",
286
+ " </tr>\n",
287
+ " <tr>\n",
288
+ " <th>3</th>\n",
289
+ " <td>metastatic lung adenocarcinoma, PD-L1 75%, KRA...</td>\n",
290
+ " <td>1. Cancer type allowed: Non-small cell lung ca...</td>\n",
291
+ " <td>NCT05788926</td>\n",
292
+ " <td>A Phase I Dose-escalation Trial of TG6050 Admi...</td>\n",
293
+ " <td>This is a phase I, open-label, dose-escalation...</td>\n",
294
+ " <td>Inclusion Criteria:\\n\\n1. Signed written infor...</td>\n",
295
+ " <td>1. Cancer type allowed: Non-small cell lung ca...</td>\n",
296
+ " </tr>\n",
297
+ " <tr>\n",
298
+ " <th>4</th>\n",
299
+ " <td>metastatic lung adenocarcinoma, PD-L1 75%, KRA...</td>\n",
300
+ " <td>1. Cancer type allowed: non-small cell lung ca...</td>\n",
301
+ " <td>NCT05375084</td>\n",
302
+ " <td>A Phase 1 Study of the SHP2 Inhibitor BBP-398 ...</td>\n",
303
+ " <td>This is a Phase 1 study of BBP-398, a SHP2 inh...</td>\n",
304
+ " <td>Key Inclusion Criteria:\\n\\n* Patients must hav...</td>\n",
305
+ " <td>1. Cancer type allowed: non-small cell lung ca...</td>\n",
306
+ " </tr>\n",
307
+ " </tbody>\n",
308
+ "</table>\n",
309
+ "</div>"
310
+ ],
311
+ "text/plain": [
312
+ " patient_summary \\\n",
313
+ "0 metastatic lung adenocarcinoma, PD-L1 75%, KRA... \n",
314
+ "1 metastatic lung adenocarcinoma, PD-L1 75%, KRA... \n",
315
+ "2 metastatic lung adenocarcinoma, PD-L1 75%, KRA... \n",
316
+ "3 metastatic lung adenocarcinoma, PD-L1 75%, KRA... \n",
317
+ "4 metastatic lung adenocarcinoma, PD-L1 75%, KRA... \n",
318
+ "\n",
319
+ " this_space nct_id \\\n",
320
+ "0 5. Cancer type allowed: non-small cell lung ca... NCT06253520 \n",
321
+ "1 1. Cancer type allowed: non-small cell lung ca... NCT05853575 \n",
322
+ "2 3. Cancer type allowed: non-small cell lung ca... NCT06128551 \n",
323
+ "3 1. Cancer type allowed: Non-small cell lung ca... NCT05788926 \n",
324
+ "4 1. Cancer type allowed: non-small cell lung ca... NCT05375084 \n",
325
+ "\n",
326
+ " trial_title \\\n",
327
+ "0 A Phase Ib Clinical Trial to Evaluate the Admi... \n",
328
+ "1 A Randomized Study of Two Dosing Regimens of A... \n",
329
+ "2 Phase 1b, Multicenter, Open-Label, Dose Escala... \n",
330
+ "3 A Phase I Dose-escalation Trial of TG6050 Admi... \n",
331
+ "4 A Phase 1 Study of the SHP2 Inhibitor BBP-398 ... \n",
332
+ "\n",
333
+ " trial_brief_summary \\\n",
334
+ "0 Background:\\n\\nMany cancer cells produce subst... \n",
335
+ "1 This study will evaluate the efficacy of two d... \n",
336
+ "2 This study is to evaluate the safety, tolerabi... \n",
337
+ "3 This is a phase I, open-label, dose-escalation... \n",
338
+ "4 This is a Phase 1 study of BBP-398, a SHP2 inh... \n",
339
+ "\n",
340
+ " trial_eligibility_criteria \\\n",
341
+ "0 * INCLUSION CRITERIA:\\n* Participants with an ... \n",
342
+ "1 Key Inclusion Criteria:\\n\\n* Are at least 18 y... \n",
343
+ "2 Inclusion Criteria:\\n\\n* 18 years of age\\n* Hi... \n",
344
+ "3 Inclusion Criteria:\\n\\n1. Signed written infor... \n",
345
+ "4 Key Inclusion Criteria:\\n\\n* Patients must hav... \n",
346
+ "\n",
347
+ " pt_trial_pair \n",
348
+ "0 5. Cancer type allowed: non-small cell lung ca... \n",
349
+ "1 1. Cancer type allowed: non-small cell lung ca... \n",
350
+ "2 3. Cancer type allowed: non-small cell lung ca... \n",
351
+ "3 1. Cancer type allowed: Non-small cell lung ca... \n",
352
+ "4 1. Cancer type allowed: non-small cell lung ca... "
353
+ ]
354
+ },
355
+ "execution_count": 13,
356
+ "metadata": {},
357
+ "output_type": "execute_result"
358
+ }
359
+ ],
360
+ "source": [
361
+ "analysis = pd.DataFrame({'patient_summary':patient_summary, 'this_space':relevant_spaces,\n",
362
+ " 'nct_id':relevant_nctid, 'trial_title':relevant_title,\n",
363
+ " 'trial_brief_summary':relevant_brief_summary, 'trial_eligibility_criteria':relevant_eligibility_criteria}).reset_index(drop=True)\n",
364
+ "analysis['pt_trial_pair'] = analysis['this_space'] + \"\\nNow here is the patient summary:\" + analysis['patient_summary']\n",
365
+ "analysis.head()"
366
+ ]
367
+ },
368
+ {
369
+ "cell_type": "code",
370
+ "execution_count": 14,
371
+ "id": "38439e90-59bc-4d8a-bef7-8731966ff015",
372
+ "metadata": {},
373
+ "outputs": [],
374
+ "source": [
375
+ "pipe = pipeline('text-classification', model='./roberta-checker', device='cuda')"
376
+ ]
377
+ },
378
+ {
379
+ "cell_type": "code",
380
+ "execution_count": 15,
381
+ "id": "4d6c761a-1d9f-4890-beb8-c20ba5523f87",
382
+ "metadata": {},
383
+ "outputs": [],
384
+ "source": [
385
+ "classifier_results = pipe(analysis.pt_trial_pair.tolist())\n",
386
+ "analysis['roberta_check_result'] = [x['label'] for x in classifier_results]\n",
387
+ "analysis['roberta_check_score'] = [x['score'] for x in classifier_results]\n"
388
+ ]
389
+ },
390
+ {
391
+ "cell_type": "code",
392
+ "execution_count": null,
393
+ "id": "d9b9cb51-a054-4950-bef7-220c4378757d",
394
+ "metadata": {},
395
+ "outputs": [],
396
+ "source": []
397
+ },
398
+ {
399
+ "cell_type": "code",
400
+ "execution_count": 16,
401
+ "id": "86d23c7b-d10a-4efb-bec5-ab95ef6dfb0d",
402
+ "metadata": {},
403
+ "outputs": [
404
+ {
405
+ "data": {
406
+ "text/html": [
407
+ "<div>\n",
408
+ "<style scoped>\n",
409
+ " .dataframe tbody tr th:only-of-type {\n",
410
+ " vertical-align: middle;\n",
411
+ " }\n",
412
+ "\n",
413
+ " .dataframe tbody tr th {\n",
414
+ " vertical-align: top;\n",
415
+ " }\n",
416
+ "\n",
417
+ " .dataframe thead th {\n",
418
+ " text-align: right;\n",
419
+ " }\n",
420
+ "</style>\n",
421
+ "<table border=\"1\" class=\"dataframe\">\n",
422
+ " <thead>\n",
423
+ " <tr style=\"text-align: right;\">\n",
424
+ " <th></th>\n",
425
+ " <th>patient_summary</th>\n",
426
+ " <th>this_space</th>\n",
427
+ " <th>nct_id</th>\n",
428
+ " <th>trial_title</th>\n",
429
+ " <th>trial_brief_summary</th>\n",
430
+ " <th>trial_eligibility_criteria</th>\n",
431
+ " <th>pt_trial_pair</th>\n",
432
+ " <th>roberta_check_result</th>\n",
433
+ " <th>roberta_check_score</th>\n",
434
+ " </tr>\n",
435
+ " </thead>\n",
436
+ " <tbody>\n",
437
+ " <tr>\n",
438
+ " <th>0</th>\n",
439
+ " <td>metastatic lung adenocarcinoma, PD-L1 75%, KRA...</td>\n",
440
+ " <td>5. Cancer type allowed: non-small cell lung ca...</td>\n",
441
+ " <td>NCT06253520</td>\n",
442
+ " <td>A Phase Ib Clinical Trial to Evaluate the Admi...</td>\n",
443
+ " <td>Background:\\n\\nMany cancer cells produce subst...</td>\n",
444
+ " <td>* INCLUSION CRITERIA:\\n* Participants with an ...</td>\n",
445
+ " <td>5. Cancer type allowed: non-small cell lung ca...</td>\n",
446
+ " <td>NEGATIVE</td>\n",
447
+ " <td>0.834101</td>\n",
448
+ " </tr>\n",
449
+ " <tr>\n",
450
+ " <th>1</th>\n",
451
+ " <td>metastatic lung adenocarcinoma, PD-L1 75%, KRA...</td>\n",
452
+ " <td>1. Cancer type allowed: non-small cell lung ca...</td>\n",
453
+ " <td>NCT05853575</td>\n",
454
+ " <td>A Randomized Study of Two Dosing Regimens of A...</td>\n",
455
+ " <td>This study will evaluate the efficacy of two d...</td>\n",
456
+ " <td>Key Inclusion Criteria:\\n\\n* Are at least 18 y...</td>\n",
457
+ " <td>1. Cancer type allowed: non-small cell lung ca...</td>\n",
458
+ " <td>POSITIVE</td>\n",
459
+ " <td>0.910206</td>\n",
460
+ " </tr>\n",
461
+ " <tr>\n",
462
+ " <th>2</th>\n",
463
+ " <td>metastatic lung adenocarcinoma, PD-L1 75%, KRA...</td>\n",
464
+ " <td>3. Cancer type allowed: non-small cell lung ca...</td>\n",
465
+ " <td>NCT06128551</td>\n",
466
+ " <td>Phase 1b, Multicenter, Open-Label, Dose Escala...</td>\n",
467
+ " <td>This study is to evaluate the safety, tolerabi...</td>\n",
468
+ " <td>Inclusion Criteria:\\n\\n* 18 years of age\\n* Hi...</td>\n",
469
+ " <td>3. Cancer type allowed: non-small cell lung ca...</td>\n",
470
+ " <td>POSITIVE</td>\n",
471
+ " <td>0.915395</td>\n",
472
+ " </tr>\n",
473
+ " <tr>\n",
474
+ " <th>3</th>\n",
475
+ " <td>metastatic lung adenocarcinoma, PD-L1 75%, KRA...</td>\n",
476
+ " <td>1. Cancer type allowed: Non-small cell lung ca...</td>\n",
477
+ " <td>NCT05788926</td>\n",
478
+ " <td>A Phase I Dose-escalation Trial of TG6050 Admi...</td>\n",
479
+ " <td>This is a phase I, open-label, dose-escalation...</td>\n",
480
+ " <td>Inclusion Criteria:\\n\\n1. Signed written infor...</td>\n",
481
+ " <td>1. Cancer type allowed: Non-small cell lung ca...</td>\n",
482
+ " <td>POSITIVE</td>\n",
483
+ " <td>0.914168</td>\n",
484
+ " </tr>\n",
485
+ " <tr>\n",
486
+ " <th>4</th>\n",
487
+ " <td>metastatic lung adenocarcinoma, PD-L1 75%, KRA...</td>\n",
488
+ " <td>1. Cancer type allowed: non-small cell lung ca...</td>\n",
489
+ " <td>NCT05375084</td>\n",
490
+ " <td>A Phase 1 Study of the SHP2 Inhibitor BBP-398 ...</td>\n",
491
+ " <td>This is a Phase 1 study of BBP-398, a SHP2 inh...</td>\n",
492
+ " <td>Key Inclusion Criteria:\\n\\n* Patients must hav...</td>\n",
493
+ " <td>1. Cancer type allowed: non-small cell lung ca...</td>\n",
494
+ " <td>POSITIVE</td>\n",
495
+ " <td>0.877930</td>\n",
496
+ " </tr>\n",
497
+ " <tr>\n",
498
+ " <th>5</th>\n",
499
+ " <td>metastatic lung adenocarcinoma, PD-L1 75%, KRA...</td>\n",
500
+ " <td>2. Cancer type allowed: non-small cell lung ca...</td>\n",
501
+ " <td>NCT06128551</td>\n",
502
+ " <td>Phase 1b, Multicenter, Open-Label, Dose Escala...</td>\n",
503
+ " <td>This study is to evaluate the safety, tolerabi...</td>\n",
504
+ " <td>Inclusion Criteria:\\n\\n* 18 years of age\\n* Hi...</td>\n",
505
+ " <td>2. Cancer type allowed: non-small cell lung ca...</td>\n",
506
+ " <td>POSITIVE</td>\n",
507
+ " <td>0.926033</td>\n",
508
+ " </tr>\n",
509
+ " <tr>\n",
510
+ " <th>6</th>\n",
511
+ " <td>metastatic lung adenocarcinoma, PD-L1 75%, KRA...</td>\n",
512
+ " <td>2. Cancer type allowed: Non-Small Cell Lung Ca...</td>\n",
513
+ " <td>NCT06447662</td>\n",
514
+ " <td>A Phase 1 Open-Label Study of PF-07934040 as a...</td>\n",
515
+ " <td>The purpose of this study is to learn about th...</td>\n",
516
+ " <td>Inclusion Criteria:\\n\\n* Histological or cytol...</td>\n",
517
+ " <td>2. Cancer type allowed: Non-Small Cell Lung Ca...</td>\n",
518
+ " <td>POSITIVE</td>\n",
519
+ " <td>0.506948</td>\n",
520
+ " </tr>\n",
521
+ " <tr>\n",
522
+ " <th>7</th>\n",
523
+ " <td>metastatic lung adenocarcinoma, PD-L1 75%, KRA...</td>\n",
524
+ " <td>1. Cancer type allowed: non-small cell lung ca...</td>\n",
525
+ " <td>NCT06127940</td>\n",
526
+ " <td>K-SAB Trial - Sotorasib Followed by SBRT to 1-...</td>\n",
527
+ " <td>The goal of this interventional study is to le...</td>\n",
528
+ " <td>Main inclusion criteria:\\n\\n1. Histological or...</td>\n",
529
+ " <td>1. Cancer type allowed: non-small cell lung ca...</td>\n",
530
+ " <td>POSITIVE</td>\n",
531
+ " <td>0.952771</td>\n",
532
+ " </tr>\n",
533
+ " <tr>\n",
534
+ " <th>8</th>\n",
535
+ " <td>metastatic lung adenocarcinoma, PD-L1 75%, KRA...</td>\n",
536
+ " <td>1. Cancer type allowed: non-small cell lung ca...</td>\n",
537
+ " <td>NCT06343402</td>\n",
538
+ " <td>A Phase 1a/1b Open-Label Study of BBO-8520 As ...</td>\n",
539
+ " <td>A first in human study to evaluate the safety,...</td>\n",
540
+ " <td>Inclusion Criteria:\\n\\n* Histologically docume...</td>\n",
541
+ " <td>1. Cancer type allowed: non-small cell lung ca...</td>\n",
542
+ " <td>POSITIVE</td>\n",
543
+ " <td>0.949954</td>\n",
544
+ " </tr>\n",
545
+ " <tr>\n",
546
+ " <th>9</th>\n",
547
+ " <td>metastatic lung adenocarcinoma, PD-L1 75%, KRA...</td>\n",
548
+ " <td>1. Cancer type allowed: non-small cell lung ca...</td>\n",
549
+ " <td>NCT05815173</td>\n",
550
+ " <td>Phase I/II Study of Ladarixin and Sotorasib in...</td>\n",
551
+ " <td>This is a phase I/II, open-label, study of twi...</td>\n",
552
+ " <td>Inclusion Criteria:\\n\\n* Written informed cons...</td>\n",
553
+ " <td>1. Cancer type allowed: non-small cell lung ca...</td>\n",
554
+ " <td>POSITIVE</td>\n",
555
+ " <td>0.937962</td>\n",
556
+ " </tr>\n",
557
+ " </tbody>\n",
558
+ "</table>\n",
559
+ "</div>"
560
+ ],
561
+ "text/plain": [
562
+ " patient_summary \\\n",
563
+ "0 metastatic lung adenocarcinoma, PD-L1 75%, KRA... \n",
564
+ "1 metastatic lung adenocarcinoma, PD-L1 75%, KRA... \n",
565
+ "2 metastatic lung adenocarcinoma, PD-L1 75%, KRA... \n",
566
+ "3 metastatic lung adenocarcinoma, PD-L1 75%, KRA... \n",
567
+ "4 metastatic lung adenocarcinoma, PD-L1 75%, KRA... \n",
568
+ "5 metastatic lung adenocarcinoma, PD-L1 75%, KRA... \n",
569
+ "6 metastatic lung adenocarcinoma, PD-L1 75%, KRA... \n",
570
+ "7 metastatic lung adenocarcinoma, PD-L1 75%, KRA... \n",
571
+ "8 metastatic lung adenocarcinoma, PD-L1 75%, KRA... \n",
572
+ "9 metastatic lung adenocarcinoma, PD-L1 75%, KRA... \n",
573
+ "\n",
574
+ " this_space nct_id \\\n",
575
+ "0 5. Cancer type allowed: non-small cell lung ca... NCT06253520 \n",
576
+ "1 1. Cancer type allowed: non-small cell lung ca... NCT05853575 \n",
577
+ "2 3. Cancer type allowed: non-small cell lung ca... NCT06128551 \n",
578
+ "3 1. Cancer type allowed: Non-small cell lung ca... NCT05788926 \n",
579
+ "4 1. Cancer type allowed: non-small cell lung ca... NCT05375084 \n",
580
+ "5 2. Cancer type allowed: non-small cell lung ca... NCT06128551 \n",
581
+ "6 2. Cancer type allowed: Non-Small Cell Lung Ca... NCT06447662 \n",
582
+ "7 1. Cancer type allowed: non-small cell lung ca... NCT06127940 \n",
583
+ "8 1. Cancer type allowed: non-small cell lung ca... NCT06343402 \n",
584
+ "9 1. Cancer type allowed: non-small cell lung ca... NCT05815173 \n",
585
+ "\n",
586
+ " trial_title \\\n",
587
+ "0 A Phase Ib Clinical Trial to Evaluate the Admi... \n",
588
+ "1 A Randomized Study of Two Dosing Regimens of A... \n",
589
+ "2 Phase 1b, Multicenter, Open-Label, Dose Escala... \n",
590
+ "3 A Phase I Dose-escalation Trial of TG6050 Admi... \n",
591
+ "4 A Phase 1 Study of the SHP2 Inhibitor BBP-398 ... \n",
592
+ "5 Phase 1b, Multicenter, Open-Label, Dose Escala... \n",
593
+ "6 A Phase 1 Open-Label Study of PF-07934040 as a... \n",
594
+ "7 K-SAB Trial - Sotorasib Followed by SBRT to 1-... \n",
595
+ "8 A Phase 1a/1b Open-Label Study of BBO-8520 As ... \n",
596
+ "9 Phase I/II Study of Ladarixin and Sotorasib in... \n",
597
+ "\n",
598
+ " trial_brief_summary \\\n",
599
+ "0 Background:\\n\\nMany cancer cells produce subst... \n",
600
+ "1 This study will evaluate the efficacy of two d... \n",
601
+ "2 This study is to evaluate the safety, tolerabi... \n",
602
+ "3 This is a phase I, open-label, dose-escalation... \n",
603
+ "4 This is a Phase 1 study of BBP-398, a SHP2 inh... \n",
604
+ "5 This study is to evaluate the safety, tolerabi... \n",
605
+ "6 The purpose of this study is to learn about th... \n",
606
+ "7 The goal of this interventional study is to le... \n",
607
+ "8 A first in human study to evaluate the safety,... \n",
608
+ "9 This is a phase I/II, open-label, study of twi... \n",
609
+ "\n",
610
+ " trial_eligibility_criteria \\\n",
611
+ "0 * INCLUSION CRITERIA:\\n* Participants with an ... \n",
612
+ "1 Key Inclusion Criteria:\\n\\n* Are at least 18 y... \n",
613
+ "2 Inclusion Criteria:\\n\\n* 18 years of age\\n* Hi... \n",
614
+ "3 Inclusion Criteria:\\n\\n1. Signed written infor... \n",
615
+ "4 Key Inclusion Criteria:\\n\\n* Patients must hav... \n",
616
+ "5 Inclusion Criteria:\\n\\n* 18 years of age\\n* Hi... \n",
617
+ "6 Inclusion Criteria:\\n\\n* Histological or cytol... \n",
618
+ "7 Main inclusion criteria:\\n\\n1. Histological or... \n",
619
+ "8 Inclusion Criteria:\\n\\n* Histologically docume... \n",
620
+ "9 Inclusion Criteria:\\n\\n* Written informed cons... \n",
621
+ "\n",
622
+ " pt_trial_pair roberta_check_result \\\n",
623
+ "0 5. Cancer type allowed: non-small cell lung ca... NEGATIVE \n",
624
+ "1 1. Cancer type allowed: non-small cell lung ca... POSITIVE \n",
625
+ "2 3. Cancer type allowed: non-small cell lung ca... POSITIVE \n",
626
+ "3 1. Cancer type allowed: Non-small cell lung ca... POSITIVE \n",
627
+ "4 1. Cancer type allowed: non-small cell lung ca... POSITIVE \n",
628
+ "5 2. Cancer type allowed: non-small cell lung ca... POSITIVE \n",
629
+ "6 2. Cancer type allowed: Non-Small Cell Lung Ca... POSITIVE \n",
630
+ "7 1. Cancer type allowed: non-small cell lung ca... POSITIVE \n",
631
+ "8 1. Cancer type allowed: non-small cell lung ca... POSITIVE \n",
632
+ "9 1. Cancer type allowed: non-small cell lung ca... POSITIVE \n",
633
+ "\n",
634
+ " roberta_check_score \n",
635
+ "0 0.834101 \n",
636
+ "1 0.910206 \n",
637
+ "2 0.915395 \n",
638
+ "3 0.914168 \n",
639
+ "4 0.877930 \n",
640
+ "5 0.926033 \n",
641
+ "6 0.506948 \n",
642
+ "7 0.952771 \n",
643
+ "8 0.949954 \n",
644
+ "9 0.937962 "
645
+ ]
646
+ },
647
+ "execution_count": 16,
648
+ "metadata": {},
649
+ "output_type": "execute_result"
650
+ }
651
+ ],
652
+ "source": [
653
+ "analysis"
654
+ ]
655
+ },
656
+ {
657
+ "cell_type": "code",
658
+ "execution_count": 17,
659
+ "id": "94ccb775-e2da-47cf-a64d-6e017a5bb11f",
660
+ "metadata": {},
661
+ "outputs": [
662
+ {
663
+ "data": {
664
+ "text/plain": [
665
+ "'5. Cancer type allowed: non-small cell lung cancer. Histology allowed: solid cancer. Cancer burden allowed: metastatic disease. Prior treatment required: at least one platinum-based chemotherapy regimen and at least one FDA-approved targeted treatment. Biomarkers required: KRAS G12V or G12D mutation. Biomarkers to be assessed during screening: HLA match.'"
666
+ ]
667
+ },
668
+ "execution_count": 17,
669
+ "metadata": {},
670
+ "output_type": "execute_result"
671
+ }
672
+ ],
673
+ "source": [
674
+ "analysis.this_space.iloc[0]"
675
+ ]
676
+ },
677
+ {
678
+ "cell_type": "code",
679
+ "execution_count": 18,
680
+ "id": "27f3bdda-a893-4439-bb20-d864c0476b23",
681
+ "metadata": {},
682
+ "outputs": [
683
+ {
684
+ "data": {
685
+ "text/plain": [
686
+ "'1. Cancer type allowed: non-small cell lung cancer. Histology allowed: adenocarcinoma, squamous cell carcinoma, large cell carcinoma, and other subtypes of non-small cell lung cancer. Cancer burden allowed: advanced, metastatic. Prior treatment required: chemotherapy that included cisplatin or carboplatin, immune checkpoint inhibitor. Prior treatment excluded: KRAS G12C targeted therapy. Biomarkers required: KRAS G12C mutation.'"
687
+ ]
688
+ },
689
+ "execution_count": 18,
690
+ "metadata": {},
691
+ "output_type": "execute_result"
692
+ }
693
+ ],
694
+ "source": [
695
+ "analysis.this_space.iloc[1]"
696
+ ]
697
+ },
698
+ {
699
+ "cell_type": "code",
700
+ "execution_count": 19,
701
+ "id": "824794ca-1934-47dd-8844-74cf0ea2db75",
702
+ "metadata": {},
703
+ "outputs": [
704
+ {
705
+ "data": {
706
+ "text/plain": [
707
+ "'3. Cancer type allowed: non-small cell lung cancer. Histology allowed: pathologically documented, KRAS G12C-mutated. Cancer burden allowed: advanced or metastatic. Prior treatment required: immunotherapy, chemotherapy. Biomarkers required: KRAS G12C mutation.'"
708
+ ]
709
+ },
710
+ "execution_count": 19,
711
+ "metadata": {},
712
+ "output_type": "execute_result"
713
+ }
714
+ ],
715
+ "source": [
716
+ "analysis.this_space.iloc[2]"
717
+ ]
718
+ },
719
+ {
720
+ "cell_type": "code",
721
+ "execution_count": null,
722
+ "id": "24a06913-e32e-4f14-ba0a-c9cd6ae6c4fd",
723
+ "metadata": {},
724
+ "outputs": [],
725
+ "source": []
726
+ }
727
+ ],
728
+ "metadata": {
729
+ "kernelspec": {
730
+ "display_name": "Python 3 (ipykernel)",
731
+ "language": "python",
732
+ "name": "python3"
733
+ },
734
+ "language_info": {
735
+ "codemirror_mode": {
736
+ "name": "ipython",
737
+ "version": 3
738
+ },
739
+ "file_extension": ".py",
740
+ "mimetype": "text/x-python",
741
+ "name": "python",
742
+ "nbconvert_exporter": "python",
743
+ "pygments_lexer": "ipython3",
744
+ "version": "3.9.18"
745
+ }
746
+ },
747
+ "nbformat": 4,
748
+ "nbformat_minor": 5
749
+ }
1a_generate_synthetic_imaging_reports.ipynb ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "180f9bc1-03cc-4e31-babe-3f6c6ecb0167",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import pandas as pd"
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "code",
15
+ "execution_count": null,
16
+ "id": "ef6b0609-695f-4975-9970-f8b8350f953d",
17
+ "metadata": {},
18
+ "outputs": [],
19
+ "source": []
20
+ },
21
+ {
22
+ "cell_type": "code",
23
+ "execution_count": null,
24
+ "id": "f986b468-c428-4ca3-9101-1cbabe6ad73f",
25
+ "metadata": {},
26
+ "outputs": [],
27
+ "source": [
28
+ "from vllm import LLM, SamplingParams\n",
29
+ "import pandas as pd\n",
30
+ "import numpy as np\n",
31
+ "import torch.nn.functional as F\n",
32
+ "import torch\n",
33
+ "from transformers import AutoTokenizer\n",
34
+ "from transformers import AutoModelForCausalLM\n",
35
+ "import re\n",
36
+ "import os\n",
37
+ "#os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"1\"\n"
38
+ ]
39
+ },
40
+ {
41
+ "cell_type": "code",
42
+ "execution_count": null,
43
+ "id": "a8138f4e-6f45-4c98-b6b6-19370d53e7ac",
44
+ "metadata": {},
45
+ "outputs": [],
46
+ "source": [
47
+ "# llama = LLM(model='meta-llama/Meta-Llama-3.1-8B-Instruct', tensor_parallel_size = 2, \n",
48
+ "# gpu_memory_utilization=0.95,\n",
49
+ "# download_dir = \"../../\", max_model_len=120000)"
50
+ ]
51
+ },
52
+ {
53
+ "cell_type": "code",
54
+ "execution_count": null,
55
+ "id": "8e0537c8-85cc-4bae-97de-6dd6f70ea5a3",
56
+ "metadata": {},
57
+ "outputs": [],
58
+ "source": [
59
+ "llama = LLM(model='hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4', tensor_parallel_size = 2, download_dir = \"../meta_ai/\", gpu_memory_utilization=0.80, max_model_len=5000)"
60
+ ]
61
+ },
62
+ {
63
+ "cell_type": "code",
64
+ "execution_count": null,
65
+ "id": "e0394142-a749-4995-abc8-bac884fea671",
66
+ "metadata": {},
67
+ "outputs": [],
68
+ "source": [
69
+ "def generate_synthetic_imaging_reports(num_reports, llama_model):\n",
70
+ "\n",
71
+ " tokenizer = llama_model.get_tokenizer()\n",
72
+ " prompts = []\n",
73
+ " scan_types = np.random.choice(['CT scan', 'MRI', 'Nuclear bone scan', 'PET-CT'], size=num_reports)\n",
74
+ " cancer_types = np.random.choice(['breast', 'non-small cell lung', 'small cell lung', 'colorectal', 'pancreatic', 'urothelial', 'prostate', 'gastric', 'esophageal', 'thymoma', 'thymic carcinoma', 'adrenal', 'ovarian', 'endometrial', 'melanoma', 'renal cell', 'sarcoma', 'head and neck', 'Hodgkin lymphoma', 'Non-Hodgkin lymphoma', 'myeloma', 'acute myeloid leukemia', 'chronic myeloid leukemia', 'acute lymphoblastic leukemia', 'chronic lymphocytic leukemia/lymphoma', 'primary brain tumor'], size=num_reports) \n",
75
+ "\n",
76
+ " for i in range(num_reports):\n",
77
+ "\n",
78
+ " messages = [\n",
79
+ " {'role':'system', 'content': \"\"\"Your job is to generate synthetic imaging reports for hypothetical patients with cancer.\n",
80
+ " You know all there is to know about cancer and its treatments, so be detailed. \n",
81
+ " \"\"\"}, \n",
82
+ "\n",
83
+ "\n",
84
+ " {'role':'user', 'content': \"\"\"Imagine a patient with cancer. \n",
85
+ " The cancer type is \"\"\" + cancer_types[i] + \".\" + \"\"\"\n",
86
+ " Then, generate a very detailed imaging report that might have been written about an imaging study performed for the patient. \n",
87
+ " The patient might have any stage of disease and be at any point along the disease trajectory. Use everything you know about cancer, including epidemiology, treatment, and heterogeneity in disease presentations.\n",
88
+ " The imaging study type is \"\"\" + scan_types[i] + \".\" + \"\"\"\n",
89
+ " The report should include a detailed \"Findings\" section followed by an \"Impression\" section.\n",
90
+ " The report should not include any treatment recommendations.\n",
91
+ " The imaging report should be approximately a full page long.\"\"\"}\n",
92
+ " ]\n",
93
+ " \n",
94
+ " prompts.append(tokenizer.apply_chat_template(conversation=messages, add_generation_prompt=True, tokenize=False))\n",
95
+ " \n",
96
+ "\n",
97
+ " \n",
98
+ " responses = llama_model.generate(\n",
99
+ " prompts, \n",
100
+ " SamplingParams(\n",
101
+ " temperature=1.0,\n",
102
+ " top_p=0.9,\n",
103
+ " max_tokens=4000,\n",
104
+ " stop_token_ids=[tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids(\"<|eot_id|>\")], # KEYPOINT HERE\n",
105
+ " ))\n",
106
+ "\n",
107
+ " response_texts = [x.outputs[0].text for x in responses]\n",
108
+ "\n",
109
+ "\n",
110
+ " return pd.DataFrame({'cancer_type':cancer_types, 'scan_type':scan_types, 'synthetic_imaging_report':response_texts})"
111
+ ]
112
+ },
113
+ {
114
+ "cell_type": "code",
115
+ "execution_count": null,
116
+ "id": "109f2208-de29-43cf-a831-4609bdab225e",
117
+ "metadata": {},
118
+ "outputs": [],
119
+ "source": [
120
+ "results = generate_synthetic_imaging_reports(10000, llama)"
121
+ ]
122
+ },
123
+ {
124
+ "cell_type": "code",
125
+ "execution_count": null,
126
+ "id": "443fa19c-8933-41d2-ba04-382902421e08",
127
+ "metadata": {},
128
+ "outputs": [],
129
+ "source": [
130
+ "results.synthetic_imaging_report.sample(n=1).iloc[0]"
131
+ ]
132
+ },
133
+ {
134
+ "cell_type": "code",
135
+ "execution_count": null,
136
+ "id": "b40b06f0-69e9-48cc-9766-61d8d26178d6",
137
+ "metadata": {},
138
+ "outputs": [],
139
+ "source": [
140
+ "results.head()"
141
+ ]
142
+ },
143
+ {
144
+ "cell_type": "code",
145
+ "execution_count": null,
146
+ "id": "6119a0b4-5a63-4f91-a637-484b5e9dc29c",
147
+ "metadata": {},
148
+ "outputs": [],
149
+ "source": [
150
+ "results.to_csv('synthetic_imaging_reports.csv')"
151
+ ]
152
+ },
153
+ {
154
+ "cell_type": "code",
155
+ "execution_count": null,
156
+ "id": "71a24b19-5e1c-4c24-b13b-ac04c1e94bd2",
157
+ "metadata": {},
158
+ "outputs": [],
159
+ "source": []
160
+ }
161
+ ],
162
+ "metadata": {
163
+ "kernelspec": {
164
+ "display_name": "Python 3 (ipykernel)",
165
+ "language": "python",
166
+ "name": "python3"
167
+ },
168
+ "language_info": {
169
+ "codemirror_mode": {
170
+ "name": "ipython",
171
+ "version": 3
172
+ },
173
+ "file_extension": ".py",
174
+ "mimetype": "text/x-python",
175
+ "name": "python",
176
+ "nbconvert_exporter": "python",
177
+ "pygments_lexer": "ipython3",
178
+ "version": "3.12.5"
179
+ }
180
+ },
181
+ "nbformat": 4,
182
+ "nbformat_minor": 5
183
+ }
1b_generate_synthetic_clinical_notes.ipynb ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "f986b468-c428-4ca3-9101-1cbabe6ad73f",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "from vllm import LLM, SamplingParams\n",
11
+ "import pandas as pd\n",
12
+ "import numpy as np\n",
13
+ "import torch.nn.functional as F\n",
14
+ "import torch\n",
15
+ "from transformers import AutoTokenizer\n",
16
+ "from transformers import AutoModelForCausalLM\n",
17
+ "import re\n",
18
+ "import os\n",
19
+ "#os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"1\"\n"
20
+ ]
21
+ },
22
+ {
23
+ "cell_type": "code",
24
+ "execution_count": null,
25
+ "id": "8e0537c8-85cc-4bae-97de-6dd6f70ea5a3",
26
+ "metadata": {},
27
+ "outputs": [],
28
+ "source": [
29
+ "llama = LLM(model='hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4', tensor_parallel_size = 2, download_dir = \"../meta_ai/\", gpu_memory_utilization=0.80, max_model_len=10000)"
30
+ ]
31
+ },
32
+ {
33
+ "cell_type": "code",
34
+ "execution_count": null,
35
+ "id": "6119a0b4-5a63-4f91-a637-484b5e9dc29c",
36
+ "metadata": {},
37
+ "outputs": [],
38
+ "source": []
39
+ },
40
+ {
41
+ "cell_type": "code",
42
+ "execution_count": null,
43
+ "id": "71a24b19-5e1c-4c24-b13b-ac04c1e94bd2",
44
+ "metadata": {},
45
+ "outputs": [],
46
+ "source": [
47
+ "def generate_synthetic_clinical_notes(num_reports, llama_model):\n",
48
+ "\n",
49
+ " tokenizer = llama_model.get_tokenizer()\n",
50
+ " prompts = []\n",
51
+ " cancer_types = np.random.choice(['breast', 'non-small cell lung', 'small cell lung', 'colorectal', 'pancreatic', 'urothelial', 'prostate', 'gastric', 'esophageal', 'thymoma', 'thymic carcinoma', 'adrenal', 'ovarian', 'endometrial', 'melanoma', 'renal cell', 'sarcoma', 'head and neck', 'Hodgkin lymphoma', 'Non-Hodgkin lymphoma', 'myeloma', 'acute myeloid leukemia', 'chronic myeloid leukemia', 'acute lymphoblastic leukemia', 'chronic lymphocytic leukemia/lymphoma', 'primary brain tumor'], size=num_reports) \n",
52
+ "\n",
53
+ " for i in range(num_reports):\n",
54
+ " messages = [\n",
55
+ " {'role':'system', 'content': \"\"\"Your job is to generate synthetic oncologist clinical progress notes for hypothetical patients with cancer.\n",
56
+ " You know all there is to know about cancer and its treatments, so be detailed. \n",
57
+ " \"\"\"}, \n",
58
+ " \n",
59
+ " {'role':'user', 'content': \"\"\"Imagine a patient with cancer. \n",
60
+ " The cancer type is\"\"\" + cancer_types[i] + \".\" + \"\"\"\n",
61
+ " The patient might have any stage of disease. Use everything you know about cancer, including biomarkers, epidemiology, and heterogeneity in disease presentations.\n",
62
+ " The note might correspond to any point along the disease trajectory, from initial diagnosis to curative intent treatment to palliative intent treatment.\n",
63
+ " The note should include a chief complaint, oncologic history including prior treatments, past medical history/comorbidities, current subjective clinical status and physical exam including vital signs and ECOG performance status, laboratory values, radiology excerpts, and an assessment and plan.\n",
64
+ " The note should should be approximately two pages long. It will not be used for clinical care, so do not include disclaimers.\"\"\"}\n",
65
+ " ]\n",
66
+ " \n",
67
+ " prompts.append(tokenizer.apply_chat_template(conversation=messages, add_generation_prompt=True, tokenize=False))\n",
68
+ " \n",
69
+ "\n",
70
+ " \n",
71
+ " responses = llama_model.generate(\n",
72
+ " prompts, \n",
73
+ " SamplingParams(\n",
74
+ " temperature=1.0,\n",
75
+ " top_p=0.9,\n",
76
+ " max_tokens=4000,\n",
77
+ " stop_token_ids=[tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids(\"<|eot_id|>\")], # KEYPOINT HERE\n",
78
+ " ))\n",
79
+ "\n",
80
+ " response_texts = [x.outputs[0].text for x in responses]\n",
81
+ "\n",
82
+ "\n",
83
+ " return pd.DataFrame({'cancer_type':cancer_types, 'synthetic_clinical_note':response_texts})"
84
+ ]
85
+ },
86
+ {
87
+ "cell_type": "code",
88
+ "execution_count": null,
89
+ "id": "63642559-3583-4771-88f1-5d4e8b6e033f",
90
+ "metadata": {},
91
+ "outputs": [],
92
+ "source": [
93
+ "results = generate_synthetic_clinical_notes(10000, llama)"
94
+ ]
95
+ },
96
+ {
97
+ "cell_type": "code",
98
+ "execution_count": null,
99
+ "id": "6eeb2f9c-4ca8-424b-b30e-dcee57323fe2",
100
+ "metadata": {},
101
+ "outputs": [],
102
+ "source": [
103
+ "results.to_csv('synthetic_clinical_notes.csv')"
104
+ ]
105
+ },
106
+ {
107
+ "cell_type": "code",
108
+ "execution_count": null,
109
+ "id": "0470b150-dda1-4900-adb6-a90e5cf74e39",
110
+ "metadata": {},
111
+ "outputs": [],
112
+ "source": []
113
+ }
114
+ ],
115
+ "metadata": {
116
+ "kernelspec": {
117
+ "display_name": "Python 3 (ipykernel)",
118
+ "language": "python",
119
+ "name": "python3"
120
+ },
121
+ "language_info": {
122
+ "codemirror_mode": {
123
+ "name": "ipython",
124
+ "version": 3
125
+ },
126
+ "file_extension": ".py",
127
+ "mimetype": "text/x-python",
128
+ "name": "python",
129
+ "nbconvert_exporter": "python",
130
+ "pygments_lexer": "ipython3",
131
+ "version": "3.12.5"
132
+ }
133
+ },
134
+ "nbformat": 4,
135
+ "nbformat_minor": 5
136
+ }
1c_generate_synthetic_path_reports.ipynb ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "f986b468-c428-4ca3-9101-1cbabe6ad73f",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "from vllm import LLM, SamplingParams\n",
11
+ "import pandas as pd\n",
12
+ "import numpy as np\n",
13
+ "import torch.nn.functional as F\n",
14
+ "import torch\n",
15
+ "from transformers import AutoTokenizer\n",
16
+ "from transformers import AutoModelForCausalLM\n",
17
+ "import re\n",
18
+ "import os\n",
19
+ "#os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"1\"\n"
20
+ ]
21
+ },
22
+ {
23
+ "cell_type": "code",
24
+ "execution_count": null,
25
+ "id": "8e0537c8-85cc-4bae-97de-6dd6f70ea5a3",
26
+ "metadata": {},
27
+ "outputs": [],
28
+ "source": [
29
+ "llama = LLM(model='hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4', tensor_parallel_size = 2, download_dir = \"../meta_ai/\", gpu_memory_utilization=0.80, max_model_len=10000)"
30
+ ]
31
+ },
32
+ {
33
+ "cell_type": "code",
34
+ "execution_count": null,
35
+ "id": "e0394142-a749-4995-abc8-bac884fea671",
36
+ "metadata": {},
37
+ "outputs": [],
38
+ "source": [
39
+ "def generate_synthetic_path_reports(num_reports, llama_model):\n",
40
+ "\n",
41
+ " tokenizer = llama_model.get_tokenizer()\n",
42
+ " prompts = []\n",
43
+ " cancer_types = np.random.choice(['breast', 'non-small cell lung', 'small cell lung', 'colorectal', 'pancreatic', 'urothelial', 'prostate', 'gastric', 'esophageal', 'thymoma', 'thymic carcinoma', 'adrenal', 'ovarian', 'endometrial', 'melanoma', 'renal cell', 'sarcoma', 'head and neck', 'Hodgkin lymphoma', 'Non-Hodgkin lymphoma', 'myeloma', 'acute myeloid leukemia', 'chronic myeloid leukemia', 'acute lymphoblastic leukemia', 'chronic lymphocytic leukemia/lymphoma', 'primary brain tumor'], size=num_reports) \n",
44
+ "\n",
45
+ " for i in range(num_reports):\n",
46
+ " messages = [\n",
47
+ " {'role':'system', 'content': \"\"\"Your job is to generate synthetic pathology reports for hypothetical patients with cancer.\n",
48
+ " You know all there is to know about cancer and its treatments, so be detailed. \n",
49
+ " \"\"\"}, \n",
50
+ " \n",
51
+ " {'role':'user', 'content': \"\"\"Imagine a patient with cancer. \n",
52
+ " The cancer type is\"\"\" + cancer_types[i] + \".\" + \"\"\"\n",
53
+ " Then, generate a very detailed pathology report that might have been written about a specimen collected from the patient. The patient might have any stage of disease. Use everything you know about cancer, including biomarkers, epidemiology, and heterogeneity in disease presentations.\n",
54
+ " The report might be from a cytology specimen, anatomic pathology specimen, genomic sequencing analysis, bone marrow biopsy, flow cytometry, SPEP, etc.\n",
55
+ " The report should not include any treatment recommendations.\n",
56
+ " The pathology report should be approximately a full page long.\"\"\"}\n",
57
+ " ]\n",
58
+ " \n",
59
+ " prompts.append(tokenizer.apply_chat_template(conversation=messages, add_generation_prompt=True, tokenize=False))\n",
60
+ " \n",
61
+ "\n",
62
+ " \n",
63
+ " responses = llama_model.generate(\n",
64
+ " prompts, \n",
65
+ " SamplingParams(\n",
66
+ " temperature=1.0,\n",
67
+ " top_p=0.9,\n",
68
+ " max_tokens=4000,\n",
69
+ " stop_token_ids=[tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids(\"<|eot_id|>\")], # KEYPOINT HERE\n",
70
+ " ))\n",
71
+ "\n",
72
+ " response_texts = [x.outputs[0].text for x in responses]\n",
73
+ "\n",
74
+ "\n",
75
+ " return pd.DataFrame({'cancer_type':cancer_types, 'synthetic_path_report':response_texts})"
76
+ ]
77
+ },
78
+ {
79
+ "cell_type": "code",
80
+ "execution_count": null,
81
+ "id": "109f2208-de29-43cf-a831-4609bdab225e",
82
+ "metadata": {},
83
+ "outputs": [],
84
+ "source": [
85
+ "results = generate_synthetic_path_reports(10000, llama)"
86
+ ]
87
+ },
88
+ {
89
+ "cell_type": "code",
90
+ "execution_count": null,
91
+ "id": "b40b06f0-69e9-48cc-9766-61d8d26178d6",
92
+ "metadata": {},
93
+ "outputs": [],
94
+ "source": [
95
+ "results.to_csv('synthetic_path_reports.csv')"
96
+ ]
97
+ },
98
+ {
99
+ "cell_type": "code",
100
+ "execution_count": null,
101
+ "id": "6119a0b4-5a63-4f91-a637-484b5e9dc29c",
102
+ "metadata": {},
103
+ "outputs": [],
104
+ "source": []
105
+ },
106
+ {
107
+ "cell_type": "code",
108
+ "execution_count": null,
109
+ "id": "0470b150-dda1-4900-adb6-a90e5cf74e39",
110
+ "metadata": {},
111
+ "outputs": [],
112
+ "source": []
113
+ }
114
+ ],
115
+ "metadata": {
116
+ "kernelspec": {
117
+ "display_name": "Python 3 (ipykernel)",
118
+ "language": "python",
119
+ "name": "python3"
120
+ },
121
+ "language_info": {
122
+ "codemirror_mode": {
123
+ "name": "ipython",
124
+ "version": 3
125
+ },
126
+ "file_extension": ".py",
127
+ "mimetype": "text/x-python",
128
+ "name": "python",
129
+ "nbconvert_exporter": "python",
130
+ "pygments_lexer": "ipython3",
131
+ "version": "3.12.5"
132
+ }
133
+ },
134
+ "nbformat": 4,
135
+ "nbformat_minor": 5
136
+ }
2a_tag_chunks_of_synthetic_notes.ipynb ADDED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "a3d6ff53-2176-44aa-8590-ec0aa301342d",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "from vllm import LLM, SamplingParams\n",
11
+ "import pandas as pd\n",
12
+ "import numpy as np\n",
13
+ "import torch.nn.functional as F\n",
14
+ "import torch\n",
15
+ "from transformers import AutoTokenizer\n",
16
+ "from transformers import AutoModelForCausalLM\n",
17
+ "import re\n",
18
+ "import os\n",
19
+ "#os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0,2\"\n"
20
+ ]
21
+ },
22
+ {
23
+ "cell_type": "code",
24
+ "execution_count": null,
25
+ "id": "62669512-19e7-43cd-a518-4572eea700af",
26
+ "metadata": {},
27
+ "outputs": [],
28
+ "source": [
29
+ "llama = LLM(model='hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4', tensor_parallel_size = 2, \n",
30
+ " gpu_memory_utilization=0.50,\n",
31
+ " download_dir = \"../../\", max_model_len=12000)"
32
+ ]
33
+ },
34
+ {
35
+ "cell_type": "code",
36
+ "execution_count": null,
37
+ "id": "f87a20c9-fbea-4a09-9ffc-99c64bcd3709",
38
+ "metadata": {},
39
+ "outputs": [],
40
+ "source": []
41
+ },
42
+ {
43
+ "cell_type": "code",
44
+ "execution_count": null,
45
+ "id": "aa6a0fb4-b22b-4e24-a4c1-55b95182fe60",
46
+ "metadata": {},
47
+ "outputs": [],
48
+ "source": []
49
+ },
50
+ {
51
+ "cell_type": "code",
52
+ "execution_count": null,
53
+ "id": "f32e5e86-2769-4dad-972c-00e6ddddb95a",
54
+ "metadata": {},
55
+ "outputs": [],
56
+ "source": [
57
+ "import re\n",
58
+ "def tag_chunks(patient_texts, llama_model):\n",
59
+ " \n",
60
+ "\n",
61
+ "\n",
62
+ " tokenizer = llama_model.get_tokenizer()\n",
63
+ "\n",
64
+ " prompts = []\n",
65
+ " for the_patient in patient_texts:\n",
66
+ " temp_patient = re.sub(\"\\n|\\r\", \" \", the_patient.strip())\n",
67
+ " temp_patient = re.sub(r'\\s+', \" \", temp_patient)\n",
68
+ " sentences = \"<excerpt break>\" + re.sub(\"\\\\. \", \"<excerpt break>\", temp_patient) + \"<excerpt break>\"\n",
69
+ " \n",
70
+ " messages = [{'role':'system', 'content': \"\"\"You are an oncology clinical note data extraction bot.\n",
71
+ " Your job is to review a list of excerpts from a clinical document and extract the excerpts relevant to a list of questions.\n",
72
+ " \"\"\" \n",
73
+ " \n",
74
+ " },\n",
75
+ "\n",
76
+ " {'role':'user', 'content': \"The list of excerpts, separated by <excerpt break>, is: \" + sentences + \n",
77
+ " \"\"\"Now, list the excerpts relevant to any of the following questions.\n",
78
+ " Format your answer as JSON, tagging each excerpt that is relevant to at least one question with each tag to which it is relevant.\n",
79
+ " Here is the list of questions:\n",
80
+ " What type of cancer (primary site and histology) does the patient have? (Tag: cancer_type )\n",
81
+ " What was the stage at diagnosis? (Tag: stage_at_diagnosis)\n",
82
+ " What treatments (including surgery, radiation, or systemic therapy) has the patient received? (Tag: treatment)\n",
83
+ " How widespread is the cancer currently? (Tag: cancer_burden)\n",
84
+ " Is there response to therapy or progressive disease? (Tag: cancer_status)\n",
85
+ " Is the patient experiencing an adverse event of treatment? (Tag: adverse_event)\n",
86
+ " What biomarkers, such as protein expression and genetic mutations/alterations, does the patient's tumor have? (Tag: biomarker)\n",
87
+ " What comorbidities, or diseases other than cancer, does the patient have? (Tag: comorbidity)\n",
88
+ " Here is an example of the output format:\n",
89
+ " [{\"excerpt\": \"80M with metastatic lung adenocarcinoma.\", \"tags\": [\"cancer_type\", \"cancer_burden\"]},\n",
90
+ " {\"excerpt\": \"The tumor was HER2 positive.\", \"tags\": [\"biomarker\"]}\n",
91
+ " ]\n",
92
+ " \n",
93
+ " Do not include excerpts that are not relevant to the questions. \n",
94
+ " Do not abbreviate or alter excerpts that you do include; copy them verbatim from the prompt.\n",
95
+ " Do not add disclaimers or introductory text.\n",
96
+ " If there are no excerpts relevant to the above questions, just output blank JSON {} .\n",
97
+ " \"\"\"}\n",
98
+ " ]\n",
99
+ "\n",
100
+ " prompts.append(messages)\n",
101
+ "\n",
102
+ " long_messages = [x[1]['content'] for x in prompts]\n",
103
+ " trunc_messages = tokenizer.batch_decode([x[-10000:] for x in tokenizer(long_messages, add_special_tokens=False).input_ids])\n",
104
+ "\n",
105
+ " newprompts = []\n",
106
+ " for i, messages in enumerate(prompts):\n",
107
+ " messages[1]['content'] = trunc_messages[i]\n",
108
+ " template_prompt = tokenizer.apply_chat_template(conversation=messages, add_generation_prompt=True, tokenize=False)\n",
109
+ " newprompts.append(template_prompt)\n",
110
+ " \n",
111
+ "\n",
112
+ " \n",
113
+ " responses = llama_model.generate(\n",
114
+ " newprompts, \n",
115
+ " SamplingParams(\n",
116
+ " temperature=0.0,\n",
117
+ " top_p=0.2,\n",
118
+ " max_tokens=5000,\n",
119
+ " repetition_penalty=1.2,\n",
120
+ " stop_token_ids=[tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids(\"<|eot_id|>\")], # KEYPOINT HERE\n",
121
+ " ))\n",
122
+ "\n",
123
+ " response_texts = [x.outputs[0].text for x in responses]\n",
124
+ "\n",
125
+ "\n",
126
+ " return responses, response_texts\n",
127
+ " "
128
+ ]
129
+ },
130
+ {
131
+ "cell_type": "code",
132
+ "execution_count": null,
133
+ "id": "fd501ab8-65de-4098-9019-ead68ab9cd5e",
134
+ "metadata": {},
135
+ "outputs": [],
136
+ "source": []
137
+ },
138
+ {
139
+ "cell_type": "code",
140
+ "execution_count": null,
141
+ "id": "8a070d00-9a45-4360-a38f-ceed8a9360e1",
142
+ "metadata": {},
143
+ "outputs": [],
144
+ "source": []
145
+ },
146
+ {
147
+ "cell_type": "code",
148
+ "execution_count": null,
149
+ "id": "064eef80-feae-407b-b2cd-ad7aa115c0de",
150
+ "metadata": {},
151
+ "outputs": [],
152
+ "source": [
153
+ "# pull in our synthetic notes\n",
154
+ "imaging = pd.read_csv('synthetic_imaging_reports.csv').rename(columns={'synthetic_imaging_report':'text'})\n",
155
+ "medonc = pd.read_csv('synthetic_clinical_notes.csv').rename(columns={'synthetic_clinical_note':'text'})\n",
156
+ "path = pd.read_csv('synthetic_path_reports.csv').rename(columns={'synthetic_path_report':'text'})\n"
157
+ ]
158
+ },
159
+ {
160
+ "cell_type": "code",
161
+ "execution_count": null,
162
+ "id": "088f4db2-ef4e-45e6-bbf9-77cd94224e94",
163
+ "metadata": {},
164
+ "outputs": [],
165
+ "source": [
166
+ "all_reports = pd.concat([imaging, medonc, path], axis=0)\n"
167
+ ]
168
+ },
169
+ {
170
+ "cell_type": "code",
171
+ "execution_count": null,
172
+ "id": "246cbc95-8130-493b-9422-9b204d0a381b",
173
+ "metadata": {},
174
+ "outputs": [],
175
+ "source": [
176
+ "all_reports.info()"
177
+ ]
178
+ },
179
+ {
180
+ "cell_type": "code",
181
+ "execution_count": null,
182
+ "id": "1119b5e8-d8fc-416f-9d72-e3d5a5608066",
183
+ "metadata": {},
184
+ "outputs": [],
185
+ "source": []
186
+ },
187
+ {
188
+ "cell_type": "code",
189
+ "execution_count": null,
190
+ "id": "3f9bece9-fa57-4bdc-a9e5-53f6b4f61caa",
191
+ "metadata": {},
192
+ "outputs": [],
193
+ "source": [
194
+ "response = tag_chunks(all_reports.sample(n=2).text.tolist(), llama)"
195
+ ]
196
+ },
197
+ {
198
+ "cell_type": "code",
199
+ "execution_count": null,
200
+ "id": "bee4a4e1-bf6a-4e9a-98da-828b5ddea3c2",
201
+ "metadata": {},
202
+ "outputs": [],
203
+ "source": [
204
+ "response[1]"
205
+ ]
206
+ },
207
+ {
208
+ "cell_type": "code",
209
+ "execution_count": null,
210
+ "id": "d0e8390b-ef71-4c49-8cef-6e565d28fe56",
211
+ "metadata": {},
212
+ "outputs": [],
213
+ "source": []
214
+ },
215
+ {
216
+ "cell_type": "code",
217
+ "execution_count": null,
218
+ "id": "18b1b677-9798-4172-809e-e763c6f30121",
219
+ "metadata": {
220
+ "scrolled": true
221
+ },
222
+ "outputs": [],
223
+ "source": [
224
+ "output_datasets = []\n",
225
+ "for i in range(0, all_reports.shape[0], 5000):\n",
226
+ " output_dataset = all_reports.iloc[i:(i+5000)]\n",
227
+ " output_dataset['llm_output'] = tag_chunks(output_dataset.text.tolist(), llama)[1]\n",
228
+ " output_datasets.append(output_dataset)\n",
229
+ " fileout = pd.concat(output_datasets, axis=0)\n",
230
+ " fileout.to_parquet('tagged_chunks.parquet')"
231
+ ]
232
+ },
233
+ {
234
+ "cell_type": "code",
235
+ "execution_count": null,
236
+ "id": "92a37f90-f1fe-4a6c-85fb-cc2941456b0d",
237
+ "metadata": {},
238
+ "outputs": [],
239
+ "source": []
240
+ },
241
+ {
242
+ "cell_type": "code",
243
+ "execution_count": null,
244
+ "id": "8a47e9f4-be48-41c8-a3ce-a876c92e2961",
245
+ "metadata": {},
246
+ "outputs": [],
247
+ "source": []
248
+ }
249
+ ],
250
+ "metadata": {
251
+ "kernelspec": {
252
+ "display_name": "Python 3 (ipykernel)",
253
+ "language": "python",
254
+ "name": "python3"
255
+ },
256
+ "language_info": {
257
+ "codemirror_mode": {
258
+ "name": "ipython",
259
+ "version": 3
260
+ },
261
+ "file_extension": ".py",
262
+ "mimetype": "text/x-python",
263
+ "name": "python",
264
+ "nbconvert_exporter": "python",
265
+ "pygments_lexer": "ipython3",
266
+ "version": "3.9.18"
267
+ }
268
+ },
269
+ "nbformat": 4,
270
+ "nbformat_minor": 5
271
+ }
2b_train_tiny_bert_tagger.ipynb ADDED
@@ -0,0 +1,706 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "f8fe89a3-56d5-4a58-9e2f-0a236c4a8409",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import numpy as np\n",
11
+ "import pandas as pd\n",
12
+ "import re\n",
13
+ "import json\n",
14
+ "import os\n",
15
+ "import torch\n",
16
+ "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
17
+ "os.environ['CUDA_VISIBLE_DEVICES'] = '1'\n"
18
+ ]
19
+ },
20
+ {
21
+ "cell_type": "code",
22
+ "execution_count": null,
23
+ "id": "76a8061d-f35e-4d11-83c6-bfc46d68130b",
24
+ "metadata": {},
25
+ "outputs": [],
26
+ "source": [
27
+ "summarized_notes = pd.read_parquet('./tagged_chunks.parquet')"
28
+ ]
29
+ },
30
+ {
31
+ "cell_type": "code",
32
+ "execution_count": null,
33
+ "id": "20bb4efc-4ca1-4179-8f0a-e89d80123358",
34
+ "metadata": {},
35
+ "outputs": [],
36
+ "source": []
37
+ },
38
+ {
39
+ "cell_type": "code",
40
+ "execution_count": null,
41
+ "id": "a75ae543-ecca-4598-91f2-938998086801",
42
+ "metadata": {},
43
+ "outputs": [],
44
+ "source": [
45
+ "summarized_notes.info()"
46
+ ]
47
+ },
48
+ {
49
+ "cell_type": "code",
50
+ "execution_count": null,
51
+ "id": "fdeaf0de-0462-4e1d-b6b2-cd8e10c76ceb",
52
+ "metadata": {},
53
+ "outputs": [],
54
+ "source": [
55
+ "def generate_rowwise_chunk_labels(original_note, llm_output, valid_tags_list):\n",
56
+ " valid_tags_array = np.array(valid_tags_list)\n",
57
+ " chunks = re.sub(\"\\n|\\r\", \" \", original_note.strip())\n",
58
+ " chunks = re.sub(r'\\s+', \" \", chunks)\n",
59
+ " chunks = \"<excerpt break>\" + re.sub(\"\\\\. \", \"<excerpt break>\", chunks) + \"<excerpt break>\"\n",
60
+ " chunks = pd.Series(chunks.split(\"<excerpt break>\")).str.strip()\n",
61
+ " chunks = chunks[chunks != '']\n",
62
+ " chunk_frame = pd.DataFrame({'excerpt':chunks})\n",
63
+ " tag_dict = {}\n",
64
+ " try:\n",
65
+ " json_output = pd.DataFrame.from_records(json.loads(llm_output))\n",
66
+ " json_output['tags'] = json_output['tags'].astype(str).str.strip(\"[|]\")\n",
67
+ "\n",
68
+ " chunk_frame = pd.merge(chunk_frame, json_output, on='excerpt', how='left')\n",
69
+ " chunk_frame['is_tagged'] = np.where(chunk_frame.tags.isnull(), 0, 1)\n",
70
+ " chunk_frame['tags'] = np.where(chunk_frame.tags.isnull(), \"\", chunk_frame.tags)\n",
71
+ " chunk_frame['good_json'] = 1\n",
72
+ " for tag in valid_tags_array[valid_tags_array != 'is_tagged'].tolist():\n",
73
+ " chunk_frame[tag] = np.where(chunk_frame.tags.str.contains(tag), 1, 0)\n",
74
+ " except:\n",
75
+ " chunk_frame['tags'] = \"\"\n",
76
+ " chunk_frame['is_tagged'] = 0\n",
77
+ " chunk_frame['good_json'] = 0\n",
78
+ " for tag in valid_tags_array[valid_tags_array != 'is_tagged'].tolist():\n",
79
+ " chunk_frame[tag] = 0\n",
80
+ " return chunk_frame"
81
+ ]
82
+ },
83
+ {
84
+ "cell_type": "code",
85
+ "execution_count": null,
86
+ "id": "aaeaf6a9-aa16-4770-9b36-b8f4c7bb2327",
87
+ "metadata": {},
88
+ "outputs": [],
89
+ "source": [
90
+ "summarized_notes.info()"
91
+ ]
92
+ },
93
+ {
94
+ "cell_type": "code",
95
+ "execution_count": null,
96
+ "id": "f00b0c00-c215-4254-8555-28aac4e0796c",
97
+ "metadata": {},
98
+ "outputs": [],
99
+ "source": [
100
+ "\n",
101
+ "valid_tags_list = ['is_tagged','cancer_type','stage_at_diagnosis','treatment','cancer_burden','cancer_status','adverse_event','comorbidity','biomarker']"
102
+ ]
103
+ },
104
+ {
105
+ "cell_type": "code",
106
+ "execution_count": null,
107
+ "id": "bd1ff0a9-dfcf-4ffb-9aa9-2dec616bb5fd",
108
+ "metadata": {},
109
+ "outputs": [],
110
+ "source": [
111
+ "summarized_notes.shape[0]"
112
+ ]
113
+ },
114
+ {
115
+ "cell_type": "code",
116
+ "execution_count": null,
117
+ "id": "8a75ed17-01c8-4187-9c10-d32f601cd81b",
118
+ "metadata": {},
119
+ "outputs": [],
120
+ "source": [
121
+ "outputs = []\n",
122
+ "for i in range(summarized_notes.shape[0]):\n",
123
+ " out = generate_rowwise_chunk_labels(summarized_notes.text.iloc[i], summarized_notes.llm_output.iloc[i], valid_tags_list)\n",
124
+ " try:\n",
125
+ " if out['good_json'].iloc[0] == 1:\n",
126
+ " outputs.append(out)\n",
127
+ " except:\n",
128
+ " pass\n"
129
+ ]
130
+ },
131
+ {
132
+ "cell_type": "code",
133
+ "execution_count": null,
134
+ "id": "0aa51527-77ff-4938-8699-c80941da2a3b",
135
+ "metadata": {},
136
+ "outputs": [],
137
+ "source": [
138
+ "excerpts=pd.concat(outputs, axis=0)"
139
+ ]
140
+ },
141
+ {
142
+ "cell_type": "code",
143
+ "execution_count": null,
144
+ "id": "30a3b8a3-6948-4ec7-ae5c-308f9b1af03e",
145
+ "metadata": {},
146
+ "outputs": [],
147
+ "source": [
148
+ "excerpts.shape"
149
+ ]
150
+ },
151
+ {
152
+ "cell_type": "code",
153
+ "execution_count": null,
154
+ "id": "722d6d3d-ceec-4795-b41f-d6afd8d86890",
155
+ "metadata": {},
156
+ "outputs": [],
157
+ "source": []
158
+ },
159
+ {
160
+ "cell_type": "code",
161
+ "execution_count": null,
162
+ "id": "5dbb8472-1f46-4e2e-a0ce-bc5ea28fc6f9",
163
+ "metadata": {},
164
+ "outputs": [],
165
+ "source": []
166
+ },
167
+ {
168
+ "cell_type": "code",
169
+ "execution_count": null,
170
+ "id": "bb556c78-f370-4134-9a64-a8d5263b4e48",
171
+ "metadata": {},
172
+ "outputs": [],
173
+ "source": [
174
+ "from torch.utils import data\n",
175
+ "from transformers import AutoTokenizer\n",
176
+ "\n",
177
+ "class TagDataset(data.Dataset):\n",
178
+ " def __init__(self, pandas_dataset, valid_tags_list):\n",
179
+ " self.data = pandas_dataset.copy().reset_index(drop=True)\n",
180
+ " self.indices = self.data.index.unique()\n",
181
+ " self.tokenizer = AutoTokenizer.from_pretrained('prajjwal1/bert-tiny', max_length=128, truncation_side='left') \n",
182
+ " self.valid_tags_list = valid_tags_list\n",
183
+ " \n",
184
+ " def __len__(self):\n",
185
+ " # how many notes in the dataset\n",
186
+ " return len(self.indices)\n",
187
+ " \n",
188
+ " def __getitem__(self, index):\n",
189
+ " # get data for notes corresponding to indices passed\n",
190
+ " this_index = self.indices[index]\n",
191
+ " pand = self.data.loc[this_index, :]\n",
192
+ " \n",
193
+ " encoded = self.tokenizer(pand['excerpt'], padding='max_length', max_length=128, truncation=True)\n",
194
+ "\n",
195
+ " x_text_tensor = torch.tensor(encoded.input_ids, dtype=torch.long)\n",
196
+ " x_attention_mask = torch.tensor(encoded.attention_mask, dtype=torch.long)\n",
197
+ " y_labels = torch.tensor([torch.tensor(pand[label], dtype=torch.float32) for label in self.valid_tags_list])\n",
198
+ " \n",
199
+ "\n",
200
+ " return x_text_tensor, x_attention_mask, y_labels\n",
201
+ " \n",
202
+ " "
203
+ ]
204
+ },
205
+ {
206
+ "cell_type": "code",
207
+ "execution_count": null,
208
+ "id": "9917bfdf-03f9-4c8f-a201-cef9d4744522",
209
+ "metadata": {},
210
+ "outputs": [],
211
+ "source": [
212
+ "temp_dataset = TagDataset(excerpts.head(10), valid_tags_list)\n",
213
+ "temp_data = data.DataLoader(temp_dataset, shuffle=False, batch_size=1)"
214
+ ]
215
+ },
216
+ {
217
+ "cell_type": "code",
218
+ "execution_count": null,
219
+ "id": "8368008f-2a20-44a1-bf0b-41f35a04e66e",
220
+ "metadata": {},
221
+ "outputs": [],
222
+ "source": [
223
+ "temp_iter = iter(temp_data)\n",
224
+ "result = next(temp_iter)\n",
225
+ "result[0].shape, result[1].shape, torch.unbind(result[2], dim=1)[0].shape"
226
+ ]
227
+ },
228
+ {
229
+ "cell_type": "code",
230
+ "execution_count": null,
231
+ "id": "a75a6e89-8dd5-45db-b5ce-315eca7e409f",
232
+ "metadata": {},
233
+ "outputs": [],
234
+ "source": [
235
+ "excerpts.info()"
236
+ ]
237
+ },
238
+ {
239
+ "cell_type": "code",
240
+ "execution_count": null,
241
+ "id": "b2812618-20f7-4c10-b0e0-7a904ed6f9c3",
242
+ "metadata": {},
243
+ "outputs": [],
244
+ "source": [
245
+ "valid_tags_list"
246
+ ]
247
+ },
248
+ {
249
+ "cell_type": "code",
250
+ "execution_count": null,
251
+ "id": "cb47772b-30da-4b70-959e-87bc57af94fc",
252
+ "metadata": {},
253
+ "outputs": [],
254
+ "source": [
255
+ "\n",
256
+ "from torch.nn import functional as F\n",
257
+ "import torch.nn as nn\n",
258
+ "from torch.utils.data import DataLoader\n",
259
+ "from torch.nn import LSTM, Linear, Embedding, Conv1d, MaxPool1d, GRU, LSTMCell, Dropout, Module, Sequential, ReLU\n",
260
+ "from transformers import AutoModel\n",
261
+ " \n",
262
+ "class TagModel(nn.Module):\n",
263
+ "\n",
264
+ " def __init__(self, num_tags, device):\n",
265
+ " super(TagModel, self).__init__()\n",
266
+ " \n",
267
+ " self.bert = AutoModel.from_pretrained('prajjwal1/bert-tiny').to(device)\n",
268
+ "\n",
269
+ " #self.prediction_head = Sequential(Linear(128, 128), ReLU(), Linear(128,1)).to(device)\n",
270
+ " self.prediction_heads = nn.ModuleList([Sequential(Linear(128, 128), ReLU(), Linear(128,1)).to(device) for x in range(0, num_tags)])\n",
271
+ " \n",
272
+ "\n",
273
+ " def forward(self, x_text_tensor, x_attention_mask):\n",
274
+ " # x should be tuple of input IDs, then attention mask\n",
275
+ " \n",
276
+ " main = self.bert(x_text_tensor, x_attention_mask)\n",
277
+ " main = main.last_hidden_state[:,0,:].squeeze(1)\n",
278
+ "\n",
279
+ " outputs = [x(main) for x in self.prediction_heads]\n",
280
+ " #outputs = [self.prediction_head(main)] \n",
281
+ "\n",
282
+ " return outputs\n"
283
+ ]
284
+ },
285
+ {
286
+ "cell_type": "code",
287
+ "execution_count": null,
288
+ "id": "e99f1867-5d53-48ba-b6c8-89b5470085f9",
289
+ "metadata": {},
290
+ "outputs": [],
291
+ "source": [
292
+ "# train loop\n",
293
+ "from transformers import get_scheduler\n",
294
+ "from torch.optim import AdamW, Adam\n",
295
+ "#, get_linear_schedule_with_warmup\n",
296
+ "\n",
297
+ "\n",
298
+ "def train_model(model, num_epochs, num_tags, trainloader, validloader=None, device='cuda'):\n",
299
+ " \n",
300
+ " \n",
301
+ "\n",
302
+ " optimizer = AdamW(model.parameters(), lr=5e-5)\n",
303
+ " num_training_steps = num_epochs * len(trainloader)\n",
304
+ " lr_scheduler = get_scheduler(\n",
305
+ " name=\"linear\", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps)\n",
306
+ "\n",
307
+ " model.to(device)\n",
308
+ " \n",
309
+ " for epoch in range(num_epochs): \n",
310
+ " running_train_losses = [0.0 for i in range(num_tags)]\n",
311
+ " mean_train_losses = [0.0 for i in range(num_tags)]\n",
312
+ " \n",
313
+ " running_valid_losses = [0.0 for i in range(num_tags)]\n",
314
+ " mean_valid_losses = [0.0 for i in range(num_tags)]\n",
315
+ "\n",
316
+ " num_train_batches = len(trainloader)\n",
317
+ " \n",
318
+ " model.train()\n",
319
+ " \n",
320
+ " for i, batch in enumerate(trainloader, 0):\n",
321
+ " input_ids = batch[0].to(device)\n",
322
+ " input_masks = batch[1].to(device)\n",
323
+ "\n",
324
+ " \n",
325
+ " optimizer.zero_grad()\n",
326
+ " \n",
327
+ " outputs_true = torch.unbind(batch[2].to(device), dim=1)\n",
328
+ " \n",
329
+ " outputs_pred = model(input_ids, input_masks)\n",
330
+ " \n",
331
+ " \n",
332
+ " losses = [F.binary_cross_entropy_with_logits(outputs_pred[x].squeeze(1), outputs_true[x]) for x in range(num_tags)]\n",
333
+ " \n",
334
+ " total_loss = 0.0\n",
335
+ " for j in range(num_tags):\n",
336
+ " total_loss = total_loss + losses[j]\n",
337
+ "\n",
338
+ " \n",
339
+ " total_loss.backward()\n",
340
+ " optimizer.step()\n",
341
+ " lr_scheduler.step()\n",
342
+ " \n",
343
+ " \n",
344
+ " for j in range(num_tags):\n",
345
+ " running_train_losses[j] += losses[j].detach().cpu().numpy()\n",
346
+ " mean_train_losses[j] = running_train_losses[j] / (i+1)\n",
347
+ "\n",
348
+ " if i % 10 == 0:\n",
349
+ " print('Training Epoch: ' + str(epoch+1) + ', batch: ' + str(i + 1) + '/' + str(num_train_batches) + ' this_loss:' + str(total_loss.detach().cpu().numpy()) +', train losses: ' + str([str(x) + ': ' + str(mean_train_losses[x]) + \", \" for x in range(num_tags)]), end='\\r', flush=True)\n",
350
+ " \n",
351
+ " print('\\n')\n",
352
+ " # eval on valid\n",
353
+ " \n",
354
+ " if validloader is not None:\n",
355
+ " num_valid_batches = len(validloader)\n",
356
+ " model.eval()\n",
357
+ " \n",
358
+ " for i, batch in enumerate(validloader, 0):\n",
359
+ " input_ids = batch[0].to(device)\n",
360
+ " input_masks = batch[1].to(device)\n",
361
+ "\n",
362
+ "\n",
363
+ " outputs_true = torch.unbind(batch[2].to(device), dim=1)\n",
364
+ "\n",
365
+ " outputs_pred = model(input_ids, input_masks)\n",
366
+ "\n",
367
+ " losses = [F.binary_cross_entropy_with_logits(outputs_pred[x].squeeze(1), outputs_true[x]) for x in range(num_tags)]\n",
368
+ "\n",
369
+ " total_loss = 0.0\n",
370
+ " for j in range(num_tags):\n",
371
+ " total_loss = total_loss + losses[j]\n",
372
+ "\n",
373
+ " for j in range(num_tags):\n",
374
+ " running_valid_losses[j] += losses[j].detach().cpu().numpy()\n",
375
+ "\n",
376
+ " \n",
377
+ " for j in range(num_tags):\n",
378
+ " mean_valid_losses[j] = running_valid_losses[j] / (i+1)\n",
379
+ " \n",
380
+ "\n",
381
+ " \n",
382
+ " print('Validation Epoch: ' + str(epoch+1) + ', batch: ' + str(i + 1) + '/' + str(num_valid_batches) + ', valid losses: ' + str([str(x) + ': ' + str(mean_valid_losses[x]) + \", \" for x in range(num_tags)]), end='\\r', flush=True)\n",
383
+ " print('\\n')\n",
384
+ "\n",
385
+ " "
386
+ ]
387
+ },
388
+ {
389
+ "cell_type": "code",
390
+ "execution_count": null,
391
+ "id": "98c1aa95-e339-4e89-9c67-0215b399b7a5",
392
+ "metadata": {},
393
+ "outputs": [],
394
+ "source": [
395
+ "import torch\n",
396
+ "temp = torch.tensor([0]).to('cuda')"
397
+ ]
398
+ },
399
+ {
400
+ "cell_type": "code",
401
+ "execution_count": null,
402
+ "id": "7c408392-94f7-46eb-9400-05f28327db95",
403
+ "metadata": {},
404
+ "outputs": [],
405
+ "source": [
406
+ "# actual training code, commented out after model was trained\n",
407
+ "training = excerpts\n",
408
+ "\n",
409
+ "themodel = TagModel(len(valid_tags_list), device='cuda')\n",
410
+ "trainloader = data.DataLoader(TagDataset(training.reset_index(drop=True), valid_tags_list), batch_size=64, num_workers=4, shuffle=True)\n",
411
+ "#validloader = data.DataLoader(TagDataset(validation.reset_index(drop=True), valid_tags_list), batch_size=64, num_workers=4, shuffle=True)\n",
412
+ "train_model(themodel, 5, len(valid_tags_list), trainloader, validloader=None, device='cuda')"
413
+ ]
414
+ },
415
+ {
416
+ "cell_type": "code",
417
+ "execution_count": null,
418
+ "id": "b223e6d9-dad2-4180-8832-9b5ef713ef52",
419
+ "metadata": {},
420
+ "outputs": [],
421
+ "source": [
422
+ "torch.save(themodel.state_dict(), './tiny_bert_tagger_synthetic.pt')"
423
+ ]
424
+ },
425
+ {
426
+ "cell_type": "code",
427
+ "execution_count": null,
428
+ "id": "19dafc4b-f952-444f-b857-3c6743cc871d",
429
+ "metadata": {},
430
+ "outputs": [],
431
+ "source": [
432
+ "# pull PHI notes from v7 for validation\n",
433
+ "phi_summarized_notes = pd.read_parquet(\"../v7/tagged_chunks_enrolled_pt_reports.parquet\")"
434
+ ]
435
+ },
436
+ {
437
+ "cell_type": "code",
438
+ "execution_count": null,
439
+ "id": "bf9cab2b-eeb8-4665-a63f-382950a8c089",
440
+ "metadata": {},
441
+ "outputs": [],
442
+ "source": [
443
+ "validation = phi_summarized_notes[phi_summarized_notes.split.str.contains(\"valid\")]"
444
+ ]
445
+ },
446
+ {
447
+ "cell_type": "code",
448
+ "execution_count": null,
449
+ "id": "7cc4f54c-906c-4309-b900-4fbc6a3db9e5",
450
+ "metadata": {},
451
+ "outputs": [],
452
+ "source": [
453
+ "val_outputs = []\n",
454
+ "for i in range(validation.shape[0]):\n",
455
+ " out = generate_rowwise_chunk_labels(validation.text.iloc[i], validation.llm_output.iloc[i], valid_tags_list)\n",
456
+ " try:\n",
457
+ " if out['good_json'].iloc[0] == 1:\n",
458
+ " val_outputs.append(out)\n",
459
+ " except:\n",
460
+ " pass\n"
461
+ ]
462
+ },
463
+ {
464
+ "cell_type": "code",
465
+ "execution_count": null,
466
+ "id": "76b29328-5254-4b32-8732-981a8696e7cb",
467
+ "metadata": {},
468
+ "outputs": [],
469
+ "source": [
470
+ "val_outputs = pd.concat(val_outputs, axis=0)"
471
+ ]
472
+ },
473
+ {
474
+ "cell_type": "code",
475
+ "execution_count": null,
476
+ "id": "a19b22ed-0811-40e7-8a60-f2092d002808",
477
+ "metadata": {},
478
+ "outputs": [],
479
+ "source": [
480
+ "val_outputs.is_tagged.value_counts()"
481
+ ]
482
+ },
483
+ {
484
+ "cell_type": "code",
485
+ "execution_count": null,
486
+ "id": "6e3a0623-fa86-4c9c-9b7a-dd6cf631d781",
487
+ "metadata": {},
488
+ "outputs": [],
489
+ "source": []
490
+ },
491
+ {
492
+ "cell_type": "code",
493
+ "execution_count": null,
494
+ "id": "621e9ff7-e155-403d-b1ab-96b8b1af2ad3",
495
+ "metadata": {},
496
+ "outputs": [],
497
+ "source": []
498
+ },
499
+ {
500
+ "cell_type": "code",
501
+ "execution_count": null,
502
+ "id": "cf0c2a8f-6fa0-450f-9a3b-ee79c1f3ace8",
503
+ "metadata": {},
504
+ "outputs": [],
505
+ "source": [
506
+ "num_valid_tags = len(valid_tags_list)\n",
507
+ "themodel = TagModel(num_valid_tags, 'cuda')\n",
508
+ "themodel.load_state_dict(torch.load('./tiny_bert_tagger_synthetic.pt'))\n",
509
+ "themodel.to('cuda')\n",
510
+ "\n",
511
+ "themodel.eval()"
512
+ ]
513
+ },
514
+ {
515
+ "cell_type": "code",
516
+ "execution_count": null,
517
+ "id": "a051ea9f-f934-4f05-801e-f470276e6225",
518
+ "metadata": {},
519
+ "outputs": [],
520
+ "source": [
521
+ "# write out actual PHI validation dataset\n",
522
+ "\n",
523
+ "\n",
524
+ "num_valid_tags = len(valid_tags_list)\n",
525
+ "val_output = val_outputs.reset_index(drop=True)\n",
526
+ "no_shuffle_valid_dataset = data.DataLoader(TagDataset(val_outputs, valid_tags_list), batch_size=32, shuffle=False, num_workers=0)\n",
527
+ "\n",
528
+ "output_true_lists = [[] for x in range(num_valid_tags)]\n",
529
+ "output_prediction_lists = [[] for x in range(num_valid_tags)]\n",
530
+ "for batch in no_shuffle_valid_dataset:\n",
531
+ " x_text_ids = batch[0].to('cuda')\n",
532
+ " x_attention_mask = batch[1].to('cuda')\n",
533
+ " label_list = torch.unbind(batch[2], axis=1)\n",
534
+ " with torch.no_grad():\n",
535
+ " predictions = themodel(x_text_ids, x_attention_mask)\n",
536
+ "\n",
537
+ "\n",
538
+ " predictions = themodel(x_text_ids, x_attention_mask)\n",
539
+ "\n",
540
+ " \n",
541
+ " for j in range(num_valid_tags):\n",
542
+ " output_true_lists[j].append(label_list[j].detach().cpu().numpy())\n",
543
+ " output_prediction_lists[j].append(predictions[j].squeeze(1).detach().cpu().numpy())\n",
544
+ "\n",
545
+ "output_true_lists = [np.concatenate(x) for x in output_true_lists] \n",
546
+ "output_prediction_lists = [np.concatenate(x) for x in output_prediction_lists]\n",
547
+ "\n",
548
+ "\n",
549
+ "output_validation = val_outputs.copy()\n",
550
+ "for x in range(num_valid_tags):\n",
551
+ " output_validation['outcome_' + str(x) + '_logit'] = output_prediction_lists[x]\n",
552
+ " output_validation['truth_' + str(x)] = output_true_lists[x]\n"
553
+ ]
554
+ },
555
+ {
556
+ "cell_type": "code",
557
+ "execution_count": null,
558
+ "id": "999ca687-5a27-4718-8d73-f6dcc9d5800d",
559
+ "metadata": {},
560
+ "outputs": [],
561
+ "source": []
562
+ },
563
+ {
564
+ "cell_type": "code",
565
+ "execution_count": null,
566
+ "id": "986def2b-29c0-4807-9a30-e56b21ec535d",
567
+ "metadata": {},
568
+ "outputs": [],
569
+ "source": [
570
+ "from sklearn.metrics import roc_auc_score\n"
571
+ ]
572
+ },
573
+ {
574
+ "cell_type": "code",
575
+ "execution_count": null,
576
+ "id": "d395df39-cd74-4805-ac8d-e145a91e5bfe",
577
+ "metadata": {},
578
+ "outputs": [],
579
+ "source": [
580
+ "valid_tags_list"
581
+ ]
582
+ },
583
+ {
584
+ "cell_type": "code",
585
+ "execution_count": null,
586
+ "id": "cfee9187-8d6e-4741-ac92-cb6c996a1bfb",
587
+ "metadata": {},
588
+ "outputs": [],
589
+ "source": [
590
+ "pd.Series(output_true_lists[0]).value_counts()"
591
+ ]
592
+ },
593
+ {
594
+ "cell_type": "code",
595
+ "execution_count": null,
596
+ "id": "1c6e6939-b217-4821-af67-4a8c54786172",
597
+ "metadata": {},
598
+ "outputs": [],
599
+ "source": [
600
+ "# PHI valset metric AUROCs wer:e\n",
601
+ "# 0.8596166833121998\n",
602
+ "# 0.9855369622435224\n",
603
+ "# 0.9771353402092497\n",
604
+ "# 0.9645691289367063\n",
605
+ "# 0.9389302197266838\n",
606
+ "# 0.9564117372864042\n",
607
+ "# 0.9452735452766257\n",
608
+ "# 0.9234540539394782\n",
609
+ "# 0.9863098212461762"
610
+ ]
611
+ },
612
+ {
613
+ "cell_type": "code",
614
+ "execution_count": null,
615
+ "id": "19623b7f-9485-4973-ac94-52d8c64aa43b",
616
+ "metadata": {},
617
+ "outputs": [],
618
+ "source": []
619
+ },
620
+ {
621
+ "cell_type": "code",
622
+ "execution_count": null,
623
+ "id": "433f9bfa-8468-4914-9a1b-a555887f5367",
624
+ "metadata": {},
625
+ "outputs": [],
626
+ "source": [
627
+ "[print(roc_auc_score(output_true_lists[x], output_prediction_lists[x])) for x in range(num_valid_tags)]"
628
+ ]
629
+ },
630
+ {
631
+ "cell_type": "code",
632
+ "execution_count": null,
633
+ "id": "90c5d6ee-7d1e-4e38-957d-f7378cc6c891",
634
+ "metadata": {},
635
+ "outputs": [],
636
+ "source": [
637
+ "from utils_102023 import eval_model"
638
+ ]
639
+ },
640
+ {
641
+ "cell_type": "code",
642
+ "execution_count": null,
643
+ "id": "98da87ab-d0e1-4610-a7c9-7d47a0d54e8a",
644
+ "metadata": {},
645
+ "outputs": [],
646
+ "source": [
647
+ "best_f1_thresholds = [eval_model(output_prediction_lists[x], output_true_lists[x], graph=False) for x in range(num_valid_tags)]"
648
+ ]
649
+ },
650
+ {
651
+ "cell_type": "code",
652
+ "execution_count": null,
653
+ "id": "01dfb257-62d8-412d-a446-59d81111b2a5",
654
+ "metadata": {},
655
+ "outputs": [],
656
+ "source": [
657
+ "best_f1_thresholds"
658
+ ]
659
+ },
660
+ {
661
+ "cell_type": "code",
662
+ "execution_count": null,
663
+ "id": "afbcbe3d-d261-4f96-80e5-912ac664bf0c",
664
+ "metadata": {},
665
+ "outputs": [],
666
+ "source": []
667
+ },
668
+ {
669
+ "cell_type": "code",
670
+ "execution_count": null,
671
+ "id": "d8b41e47-a901-40b9-84b2-dcdd46af0e89",
672
+ "metadata": {},
673
+ "outputs": [],
674
+ "source": []
675
+ },
676
+ {
677
+ "cell_type": "code",
678
+ "execution_count": null,
679
+ "id": "fe2d639d-568b-47cf-894e-e0ffffbf24e8",
680
+ "metadata": {},
681
+ "outputs": [],
682
+ "source": []
683
+ }
684
+ ],
685
+ "metadata": {
686
+ "kernelspec": {
687
+ "display_name": "Python 3 (ipykernel)",
688
+ "language": "python",
689
+ "name": "python3"
690
+ },
691
+ "language_info": {
692
+ "codemirror_mode": {
693
+ "name": "ipython",
694
+ "version": 3
695
+ },
696
+ "file_extension": ".py",
697
+ "mimetype": "text/x-python",
698
+ "name": "python",
699
+ "nbconvert_exporter": "python",
700
+ "pygments_lexer": "ipython3",
701
+ "version": "3.9.18"
702
+ }
703
+ },
704
+ "nbformat": 4,
705
+ "nbformat_minor": 5
706
+ }
3_generate_synthetic_full_patient_histories.ipynb ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "180f9bc1-03cc-4e31-babe-3f6c6ecb0167",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import pandas as pd"
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "code",
15
+ "execution_count": null,
16
+ "id": "ef6b0609-695f-4975-9970-f8b8350f953d",
17
+ "metadata": {},
18
+ "outputs": [],
19
+ "source": []
20
+ },
21
+ {
22
+ "cell_type": "code",
23
+ "execution_count": null,
24
+ "id": "f986b468-c428-4ca3-9101-1cbabe6ad73f",
25
+ "metadata": {},
26
+ "outputs": [],
27
+ "source": [
28
+ "from vllm import LLM, SamplingParams\n",
29
+ "import pandas as pd\n",
30
+ "import numpy as np\n",
31
+ "import torch.nn.functional as F\n",
32
+ "import torch\n",
33
+ "from transformers import AutoTokenizer\n",
34
+ "from transformers import AutoModelForCausalLM\n",
35
+ "import re\n",
36
+ "import os\n",
37
+ "#os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"1\"\n"
38
+ ]
39
+ },
40
+ {
41
+ "cell_type": "code",
42
+ "execution_count": null,
43
+ "id": "8e0537c8-85cc-4bae-97de-6dd6f70ea5a3",
44
+ "metadata": {},
45
+ "outputs": [],
46
+ "source": [
47
+ "llama = LLM(model='hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4', tensor_parallel_size = 2, download_dir = \"../meta_ai/\", gpu_memory_utilization=0.75, max_model_len=20000)"
48
+ ]
49
+ },
50
+ {
51
+ "cell_type": "code",
52
+ "execution_count": null,
53
+ "id": "e0394142-a749-4995-abc8-bac884fea671",
54
+ "metadata": {},
55
+ "outputs": [],
56
+ "source": [
57
+ "def generate_synthetic_histories(num_histories, llama_model):\n",
58
+ " np.random.seed(42)\n",
59
+ " tokenizer = llama_model.get_tokenizer()\n",
60
+ " prompts = []\n",
61
+ " cancer_types = np.random.choice(['breast', 'non-small cell lung', 'small cell lung', 'colorectal', 'pancreatic', 'urothelial', 'prostate', 'gastric', 'esophageal', 'thymoma', 'thymic carcinoma', 'adrenal', 'ovarian', 'endometrial', 'melanoma', 'renal cell', 'sarcoma', 'head and neck', 'Hodgkin lymphoma', 'Non-Hodgkin lymphoma', 'myeloma', 'acute myeloid leukemia', 'chronic myeloid leukemia', 'acute lymphoblastic leukemia', 'chronic lymphocytic leukemia/lymphoma', 'primary brain tumor'], size=num_histories) \n",
62
+ " splits = np.random.choice(['train','val','test'], size=num_histories, p=[0.8,0.1,0.1])\n",
63
+ " for i in range(num_histories):\n",
64
+ " messages = [\n",
65
+ " {'role':'system', 'content': \"\"\"Your job is to generate synthetic clinical histories for hypothetical patients with cancer.\n",
66
+ " You know all there is to know about cancer and its treatments, so be detailed.\n",
67
+ " The histories should be presented in chronological order as a sequence of events. Each event should begin with a date, and should then include some new development, such as a diagnosis, treatment, adverse event, progression, response to therapy, biomarker ascertainment, symptom burden, recurrence events, and so on.\n",
68
+ " \n",
69
+ " \"\"\"}, \n",
70
+ " \n",
71
+ " {'role':'user', 'content': \"\"\"Imagine a patient with cancer. \n",
72
+ " The cancer type is \"\"\" + cancer_types[i] + \"\"\".\n",
73
+ " Then, generate a very detailed synthetic clinical history for the patient. The patient might have any stage of disease. Use everything you know about cancer, including epidemiology, treatment options, outcomes, and heterogeneity in disease trajectories.\n",
74
+ " Do not mention transitions to hospice or death events.\n",
75
+ " Do not start with any demographics; just launch into the chronological history. Phrase it in the past tense. Dates should be in mm/dd/yyyy format. Output should be plain text, not Markdown.\n",
76
+ " The history should be approximately two pages long.\"\"\"}\n",
77
+ " ]\n",
78
+ " \n",
79
+ " prompts.append(tokenizer.apply_chat_template(conversation=messages, add_generation_prompt=True, tokenize=False))\n",
80
+ " \n",
81
+ "\n",
82
+ " \n",
83
+ " responses = llama_model.generate(\n",
84
+ " prompts, \n",
85
+ " SamplingParams(\n",
86
+ " temperature=1.0,\n",
87
+ " top_p=0.9,\n",
88
+ " max_tokens=5000,\n",
89
+ " stop_token_ids=[tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids(\"<|eot_id|>\")], # KEYPOINT HERE\n",
90
+ " ))\n",
91
+ "\n",
92
+ " response_texts = [x.outputs[0].text for x in responses]\n",
93
+ "\n",
94
+ "\n",
95
+ " return pd.DataFrame({'split':splits, 'cancer_type':cancer_types, 'patient_long_text':response_texts})"
96
+ ]
97
+ },
98
+ {
99
+ "cell_type": "code",
100
+ "execution_count": null,
101
+ "id": "109f2208-de29-43cf-a831-4609bdab225e",
102
+ "metadata": {},
103
+ "outputs": [],
104
+ "source": [
105
+ "results = generate_synthetic_histories(30000, llama)"
106
+ ]
107
+ },
108
+ {
109
+ "cell_type": "code",
110
+ "execution_count": null,
111
+ "id": "6119a0b4-5a63-4f91-a637-484b5e9dc29c",
112
+ "metadata": {},
113
+ "outputs": [],
114
+ "source": [
115
+ "results.to_csv('synthetic_histories_11-22-24.csv')"
116
+ ]
117
+ },
118
+ {
119
+ "cell_type": "code",
120
+ "execution_count": null,
121
+ "id": "71a24b19-5e1c-4c24-b13b-ac04c1e94bd2",
122
+ "metadata": {},
123
+ "outputs": [],
124
+ "source": []
125
+ }
126
+ ],
127
+ "metadata": {
128
+ "kernelspec": {
129
+ "display_name": "Python 3 (ipykernel)",
130
+ "language": "python",
131
+ "name": "python3"
132
+ },
133
+ "language_info": {
134
+ "codemirror_mode": {
135
+ "name": "ipython",
136
+ "version": 3
137
+ },
138
+ "file_extension": ".py",
139
+ "mimetype": "text/x-python",
140
+ "name": "python",
141
+ "nbconvert_exporter": "python",
142
+ "pygments_lexer": "ipython3",
143
+ "version": "3.9.18"
144
+ }
145
+ },
146
+ "nbformat": 4,
147
+ "nbformat_minor": 5
148
+ }
4_summarize_synthetic_histories.ipynb ADDED
@@ -0,0 +1,276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "a3d6ff53-2176-44aa-8590-ec0aa301342d",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "from vllm import LLM, SamplingParams\n",
11
+ "import pandas as pd\n",
12
+ "import numpy as np\n",
13
+ "import torch.nn.functional as F\n",
14
+ "import torch\n",
15
+ "from transformers import AutoTokenizer\n",
16
+ "from transformers import AutoModelForCausalLM\n",
17
+ "import re\n",
18
+ "import os\n",
19
+ "#os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"1\"\n"
20
+ ]
21
+ },
22
+ {
23
+ "cell_type": "code",
24
+ "execution_count": null,
25
+ "id": "8a070d00-9a45-4360-a38f-ceed8a9360e1",
26
+ "metadata": {},
27
+ "outputs": [],
28
+ "source": []
29
+ },
30
+ {
31
+ "cell_type": "code",
32
+ "execution_count": null,
33
+ "id": "0f407048-0eb3-439a-8257-3cb6881ac784",
34
+ "metadata": {},
35
+ "outputs": [],
36
+ "source": [
37
+ "import pandas as pd\n",
38
+ "synthetic_histories = pd.read_csv('synthetic_histories_11-22-24.csv')"
39
+ ]
40
+ },
41
+ {
42
+ "cell_type": "code",
43
+ "execution_count": null,
44
+ "id": "9bc40636-2325-4664-afc3-833b58fe7ba0",
45
+ "metadata": {
46
+ "scrolled": true
47
+ },
48
+ "outputs": [],
49
+ "source": [
50
+ "synthetic_histories.info()"
51
+ ]
52
+ },
53
+ {
54
+ "cell_type": "code",
55
+ "execution_count": null,
56
+ "id": "a9b4cae4-d46d-4a80-841c-8c8f08915b90",
57
+ "metadata": {},
58
+ "outputs": [],
59
+ "source": []
60
+ },
61
+ {
62
+ "cell_type": "code",
63
+ "execution_count": null,
64
+ "id": "ca2b0678-119e-47a7-9a72-28685e97559d",
65
+ "metadata": {
66
+ "scrolled": true
67
+ },
68
+ "outputs": [],
69
+ "source": [
70
+ "llama = LLM(model='hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4', tensor_parallel_size = 2, download_dir = \"../../\", gpu_memory_utilization=0.90, max_model_len=120000)"
71
+ ]
72
+ },
73
+ {
74
+ "cell_type": "code",
75
+ "execution_count": null,
76
+ "id": "f19be1ca-334c-4285-b8b7-0c9fbc83d0d4",
77
+ "metadata": {},
78
+ "outputs": [],
79
+ "source": []
80
+ },
81
+ {
82
+ "cell_type": "code",
83
+ "execution_count": null,
84
+ "id": "02b9f891-4b50-4b64-9954-8481056cba79",
85
+ "metadata": {},
86
+ "outputs": [],
87
+ "source": [
88
+ "def summarize_patients(patient_texts, llama_model):\n",
89
+ " \n",
90
+ "\n",
91
+ " prompts = []\n",
92
+ "\n",
93
+ " tokenizer = llama_model.get_tokenizer()\n",
94
+ "\n",
95
+ " prompts = []\n",
96
+ " for the_patient in patient_texts:\n",
97
+ "\n",
98
+ "\n",
99
+ " \n",
100
+ " messages = [{'role':'system', 'content': \"\"\"You are an experienced clinical oncology history summarization bot.\n",
101
+ " Your job is to construct a summary of the cancer history for a patient based on an excerpt of the patient's electronic health record. The text in the excerpt is provided in chronological order. \n",
102
+ " Document the cancer type/primary site (eg breast cancer, lung cancer, etc); histology (eg adenocarcinoma, squamous carcinoma, etc); current extent (localized, advanced, metastatic, etc); biomarkers (genomic results, protein expression, etc); and treatment history (surgery, radiation, chemotherapy/targeted therapy/immunotherapy, etc, including start and stop dates and best response if known).\n",
103
+ " Do not consider localized basal cell or squamous carcinomas of the skin, or colon polyps, to be cancers for your purposes.\n",
104
+ " Do not include the patient's name, but do include relevant dates whenever documented, including dates of diagnosis and start/stop dates of each treatment.\n",
105
+ " If a patient has a history of more than one cancer, document the cancers one at a time.\n",
106
+ " \"\"\"}, \n",
107
+ " {'role':'user', 'content': \"The excerpt is:\\n\" + the_patient + \"\"\"Now, write your summary. Do not add preceding text before the abstraction, and do not add notes or commentary afterwards. This will not be used for clinical care, so do not write any disclaimers or cautionary notes.\"\"\"}\n",
108
+ "\n",
109
+ " ]\n",
110
+ " \n",
111
+ "\n",
112
+ "\n",
113
+ " prompts.append(messages)\n",
114
+ "\n",
115
+ " long_messages = [x[1]['content'] for x in prompts]\n",
116
+ " trunc_messages = tokenizer.batch_decode([x[-115000:] for x in tokenizer(long_messages, add_special_tokens=False).input_ids])\n",
117
+ "\n",
118
+ " newprompts = []\n",
119
+ " for i, messages in enumerate(prompts):\n",
120
+ " messages[1]['content'] = trunc_messages[i]\n",
121
+ " template_prompt = tokenizer.apply_chat_template(conversation=messages, add_generation_prompt=True, tokenize=False)\n",
122
+ " newprompts.append(template_prompt)\n",
123
+ " \n",
124
+ "\n",
125
+ " \n",
126
+ " responses = llama_model.generate(\n",
127
+ " newprompts, \n",
128
+ " SamplingParams(\n",
129
+ " temperature=0.0,\n",
130
+ " top_p=0.2,\n",
131
+ " max_tokens=4096,\n",
132
+ " repetition_penalty=1.2,\n",
133
+ " stop_token_ids=[tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids(\"<|eot_id|>\")], # KEYPOINT HERE\n",
134
+ " ))\n",
135
+ "\n",
136
+ " response_texts = [x.outputs[0].text for x in responses]\n",
137
+ "\n",
138
+ "\n",
139
+ " return responses, response_texts\n",
140
+ " "
141
+ ]
142
+ },
143
+ {
144
+ "cell_type": "code",
145
+ "execution_count": null,
146
+ "id": "69bc8576-e6d7-452f-b6b0-15df7f4c8922",
147
+ "metadata": {},
148
+ "outputs": [],
149
+ "source": [
150
+ "synthetic_histories.info()"
151
+ ]
152
+ },
153
+ {
154
+ "cell_type": "code",
155
+ "execution_count": null,
156
+ "id": "bd443d34-c5db-414e-9892-eec368ef7ad6",
157
+ "metadata": {},
158
+ "outputs": [],
159
+ "source": [
160
+ "# example summary generation for one synthetic patient\n",
161
+ "patient_summaries = summarize_patients(synthetic_histories.patient_long_text.iloc[10025:10026].tolist(), llama)"
162
+ ]
163
+ },
164
+ {
165
+ "cell_type": "code",
166
+ "execution_count": null,
167
+ "id": "6b5f0b1a-6df4-4d32-9072-efb4136df070",
168
+ "metadata": {},
169
+ "outputs": [],
170
+ "source": [
171
+ "patient_summaries[1]"
172
+ ]
173
+ },
174
+ {
175
+ "cell_type": "code",
176
+ "execution_count": null,
177
+ "id": "dabd98af-947e-40c0-aea8-7805bb5b1c3c",
178
+ "metadata": {},
179
+ "outputs": [],
180
+ "source": []
181
+ },
182
+ {
183
+ "cell_type": "code",
184
+ "execution_count": null,
185
+ "id": "74b2a972-9271-4ed2-9c2c-5ec5793e8650",
186
+ "metadata": {},
187
+ "outputs": [],
188
+ "source": [
189
+ "patient_summaries = summarize_patients(synthetic_histories.patient_long_text.tolist(), llama)"
190
+ ]
191
+ },
192
+ {
193
+ "cell_type": "code",
194
+ "execution_count": null,
195
+ "id": "e6b772c9-c4dd-45c2-8a4a-9e5c17d25e2c",
196
+ "metadata": {},
197
+ "outputs": [],
198
+ "source": [
199
+ "output = synthetic_histories.copy()\n",
200
+ "output['patient_summary'] = patient_summaries[1]\n",
201
+ "output.to_parquet('synthetic_pt_summaries_11-22-24.parquet')"
202
+ ]
203
+ },
204
+ {
205
+ "cell_type": "code",
206
+ "execution_count": null,
207
+ "id": "d30bf018-e135-40be-b636-0ba17acf8e61",
208
+ "metadata": {},
209
+ "outputs": [],
210
+ "source": [
211
+ "import pandas as pd"
212
+ ]
213
+ },
214
+ {
215
+ "cell_type": "code",
216
+ "execution_count": null,
217
+ "id": "9f9ed498-4927-46a1-a23e-bf9f3a0cc544",
218
+ "metadata": {},
219
+ "outputs": [],
220
+ "source": [
221
+ "output = pd.read_parquet('synthetic_pt_summaries_11-22-24.parquet')"
222
+ ]
223
+ },
224
+ {
225
+ "cell_type": "code",
226
+ "execution_count": null,
227
+ "id": "a6e1e9ce-e984-458b-881c-a99e3336e6c6",
228
+ "metadata": {},
229
+ "outputs": [],
230
+ "source": [
231
+ "output.info()"
232
+ ]
233
+ },
234
+ {
235
+ "cell_type": "code",
236
+ "execution_count": null,
237
+ "id": "5baf640d-1a6d-447e-84c2-d09a2a94a65a",
238
+ "metadata": {},
239
+ "outputs": [],
240
+ "source": [
241
+ "output.patient_summary.sample(n=1).iloc[0]"
242
+ ]
243
+ },
244
+ {
245
+ "cell_type": "code",
246
+ "execution_count": null,
247
+ "id": "633ab065-8620-4519-af61-d9e76849cbdf",
248
+ "metadata": {},
249
+ "outputs": [],
250
+ "source": [
251
+ "output['patient_summary'].str.contains(\"Lung\").value_counts()"
252
+ ]
253
+ }
254
+ ],
255
+ "metadata": {
256
+ "kernelspec": {
257
+ "display_name": "Python 3 (ipykernel)",
258
+ "language": "python",
259
+ "name": "python3"
260
+ },
261
+ "language_info": {
262
+ "codemirror_mode": {
263
+ "name": "ipython",
264
+ "version": 3
265
+ },
266
+ "file_extension": ".py",
267
+ "mimetype": "text/x-python",
268
+ "name": "python",
269
+ "nbconvert_exporter": "python",
270
+ "pygments_lexer": "ipython3",
271
+ "version": "3.9.18"
272
+ }
273
+ },
274
+ "nbformat": 4,
275
+ "nbformat_minor": 5
276
+ }
5a_make_top_10_cohorts_llama_check_list.ipynb ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "07d99d54-84c8-4531-8951-133dfaf64c1e",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import pandas as pd\n",
11
+ "import numpy as np\n",
12
+ "import os\n",
13
+ "#os.environ['CUDA_VISIBLE_DEVICES'] = '1'\n",
14
+ "from sentence_transformers import SentenceTransformer, InputExample, losses\n",
15
+ "from torch.utils.data import DataLoader\n",
16
+ "import torch.nn.functional as F\n",
17
+ "import torch\n",
18
+ "from sklearn.metrics import roc_auc_score"
19
+ ]
20
+ },
21
+ {
22
+ "cell_type": "code",
23
+ "execution_count": null,
24
+ "id": "b7a22e15-81dd-4ab4-b7df-1d8da8992685",
25
+ "metadata": {},
26
+ "outputs": [],
27
+ "source": []
28
+ },
29
+ {
30
+ "cell_type": "code",
31
+ "execution_count": null,
32
+ "id": "3553d4f3-c23c-4d21-94ad-da2bcd31a63d",
33
+ "metadata": {},
34
+ "outputs": [],
35
+ "source": [
36
+ "test_spaces = pd.read_csv('test_spaces.csv')\n",
37
+ "test_spaces.info()"
38
+ ]
39
+ },
40
+ {
41
+ "cell_type": "code",
42
+ "execution_count": null,
43
+ "id": "e4757eb7-4a69-4578-9020-075e960817ce",
44
+ "metadata": {},
45
+ "outputs": [],
46
+ "source": [
47
+ "import pandas as pd\n",
48
+ "dfci_trials = pd.read_csv(\"../v7/space_specific_eligibility_checks_11-6-24.csv\")\n",
49
+ "# this will not run out of the box, because the dfci trials file is not included in the upload, since it contains PHI/IP\n",
50
+ "\n",
51
+ "other_trials = pd.read_csv('ctgov_all_trials_trial_space_lineitems_10-31-24.csv')\n",
52
+ "other_trials = other_trials[~other_trials.nct_id.isin(test_spaces.nct_id)]\n",
53
+ "other_trials = other_trials[~other_trials.nct_id.isin(dfci_trials.nct_id)]\n",
54
+ "\n",
55
+ "unique_trials = other_trials.groupby('nct_id').first().reset_index()[['nct_id', 'this_space']]\n",
56
+ "unique_trials.shape[0]\n",
57
+ "\n",
58
+ "unique_trial_sample = unique_trials.nct_id.sample(n=500, random_state=42)\n",
59
+ "\n",
60
+ "valid_spaces = unique_trials[unique_trials.nct_id.isin(unique_trial_sample)]\n",
61
+ "\n",
62
+ "valid_spaces.to_csv('valid_spaces.csv')\n",
63
+ "\n"
64
+ ]
65
+ },
66
+ {
67
+ "cell_type": "code",
68
+ "execution_count": null,
69
+ "id": "8d6b5654-2f5f-4caa-82ff-b02a9007dc49",
70
+ "metadata": {},
71
+ "outputs": [],
72
+ "source": [
73
+ "train_spaces = unique_trials[~unique_trials.nct_id.isin(valid_spaces.nct_id)]"
74
+ ]
75
+ },
76
+ {
77
+ "cell_type": "code",
78
+ "execution_count": null,
79
+ "id": "01e0eb5b-9282-48b8-9917-3147fbf25730",
80
+ "metadata": {},
81
+ "outputs": [],
82
+ "source": [
83
+ "train_spaces.nct_id.isin(valid_spaces.nct_id).value_counts()"
84
+ ]
85
+ },
86
+ {
87
+ "cell_type": "code",
88
+ "execution_count": null,
89
+ "id": "ed9139ac-a5f2-4791-a3a4-fd2127b7af0c",
90
+ "metadata": {},
91
+ "outputs": [],
92
+ "source": [
93
+ "train_spaces.info()"
94
+ ]
95
+ },
96
+ {
97
+ "cell_type": "code",
98
+ "execution_count": null,
99
+ "id": "cfa78d0d-3de1-48f7-a01e-f7d81a6a5b3f",
100
+ "metadata": {},
101
+ "outputs": [],
102
+ "source": [
103
+ "patients = pd.read_parquet('synthetic_pt_summaries_11-23-24.parquet')\n",
104
+ "patients = patients[patients.split == 'train'][['patient_summary','split']]"
105
+ ]
106
+ },
107
+ {
108
+ "cell_type": "code",
109
+ "execution_count": null,
110
+ "id": "fcba54bc-aba3-4d02-b874-c8275badf015",
111
+ "metadata": {},
112
+ "outputs": [],
113
+ "source": [
114
+ "patients.info()"
115
+ ]
116
+ },
117
+ {
118
+ "cell_type": "code",
119
+ "execution_count": null,
120
+ "id": "399e05a5-7a58-4362-a002-0bce62a348ac",
121
+ "metadata": {},
122
+ "outputs": [],
123
+ "source": [
124
+ "train_unique_patient_summaries = patients.patient_summary.unique().tolist()\n",
125
+ "print(len(train_unique_patient_summaries))\n",
126
+ "train_unique_spaces = train_spaces.this_space.unique().tolist()\n",
127
+ "print(len(train_unique_spaces))"
128
+ ]
129
+ },
130
+ {
131
+ "cell_type": "code",
132
+ "execution_count": null,
133
+ "id": "c37f1137-41cf-4a6d-8099-7bed5838a1ee",
134
+ "metadata": {},
135
+ "outputs": [],
136
+ "source": []
137
+ },
138
+ {
139
+ "cell_type": "code",
140
+ "execution_count": null,
141
+ "id": "06e7e0f2-5382-4507-908e-5ccf92b4beae",
142
+ "metadata": {},
143
+ "outputs": [],
144
+ "source": [
145
+ "embedding_model = SentenceTransformer('dunzhang/stella_en_1.5B_v5', trust_remote_code=True, device='cuda')"
146
+ ]
147
+ },
148
+ {
149
+ "cell_type": "code",
150
+ "execution_count": null,
151
+ "id": "77659362-cf51-4906-9e12-5d1e55440b25",
152
+ "metadata": {},
153
+ "outputs": [],
154
+ "source": [
155
+ "with torch.no_grad():\n",
156
+ " train_unique_patient_embeddings = embedding_model.encode(train_unique_patient_summaries, convert_to_tensor=True, prompt_name = \"s2s_query\")"
157
+ ]
158
+ },
159
+ {
160
+ "cell_type": "code",
161
+ "execution_count": null,
162
+ "id": "a0bb8926-3fa8-49bc-b507-6cc36d998600",
163
+ "metadata": {},
164
+ "outputs": [],
165
+ "source": [
166
+ "with torch.no_grad():\n",
167
+ " train_unique_space_embeddings = embedding_model.encode(train_unique_spaces, convert_to_tensor=True, prompt_name = \"s2s_query\")"
168
+ ]
169
+ },
170
+ {
171
+ "cell_type": "code",
172
+ "execution_count": null,
173
+ "id": "bf259f29-4934-41e8-9b3a-87f09d0ef52e",
174
+ "metadata": {},
175
+ "outputs": [],
176
+ "source": [
177
+ "output_list = []\n",
178
+ "train_unique_space_series = pd.Series(train_unique_spaces)\n",
179
+ "for i, patient_summary in enumerate(train_unique_patient_summaries):\n",
180
+ " patient_embedding = train_unique_patient_embeddings[i, :]\n",
181
+ " similarities = F.cosine_similarity(patient_embedding, train_unique_space_embeddings)\n",
182
+ " sorted_similarities, sorted_indices = torch.sort(similarities, descending=True)\n",
183
+ " relevant_spaces = train_unique_space_series.iloc[sorted_indices[0:10].cpu().numpy()]\n",
184
+ " output = pd.DataFrame({'patient_summary':patient_summary, 'this_space':relevant_spaces})\n",
185
+ " output_list.append(output)\n",
186
+ "\n",
187
+ "train_output = pd.concat(output_list, axis=0).reset_index(drop=True)\n",
188
+ "train_output['patient_summary'] = train_output.patient_summary.str.strip()\n",
189
+ "train_output['split'] = 'train'"
190
+ ]
191
+ },
192
+ {
193
+ "cell_type": "code",
194
+ "execution_count": null,
195
+ "id": "52c926cc-d481-4c43-8288-7d1620c1f06f",
196
+ "metadata": {},
197
+ "outputs": [],
198
+ "source": [
199
+ "train_output.to_csv('top_ten_cohorts_tocheck_synthetic.csv')"
200
+ ]
201
+ },
202
+ {
203
+ "cell_type": "code",
204
+ "execution_count": null,
205
+ "id": "0e0e18c1-bfb6-4c20-9aa4-257f0fb0424c",
206
+ "metadata": {},
207
+ "outputs": [],
208
+ "source": []
209
+ }
210
+ ],
211
+ "metadata": {
212
+ "kernelspec": {
213
+ "display_name": "Python 3 (ipykernel)",
214
+ "language": "python",
215
+ "name": "python3"
216
+ },
217
+ "language_info": {
218
+ "codemirror_mode": {
219
+ "name": "ipython",
220
+ "version": 3
221
+ },
222
+ "file_extension": ".py",
223
+ "mimetype": "text/x-python",
224
+ "name": "python",
225
+ "nbconvert_exporter": "python",
226
+ "pygments_lexer": "ipython3",
227
+ "version": "3.9.18"
228
+ }
229
+ },
230
+ "nbformat": 4,
231
+ "nbformat_minor": 5
232
+ }
5b_check_top10_cohorts_synthetic.ipynb ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "a3d6ff53-2176-44aa-8590-ec0aa301342d",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "from vllm import LLM, SamplingParams\n",
11
+ "import pandas as pd\n",
12
+ "import numpy as np\n",
13
+ "import torch.nn.functional as F\n",
14
+ "import torch\n",
15
+ "from transformers import AutoTokenizer\n",
16
+ "from transformers import AutoModelForCausalLM\n",
17
+ "import re\n",
18
+ "import os\n"
19
+ ]
20
+ },
21
+ {
22
+ "cell_type": "code",
23
+ "execution_count": null,
24
+ "id": "62669512-19e7-43cd-a518-4572eea700af",
25
+ "metadata": {
26
+ "scrolled": true
27
+ },
28
+ "outputs": [],
29
+ "source": [
30
+ "llama = LLM(model='hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4', tensor_parallel_size = 2, \n",
31
+ " gpu_memory_utilization=0.80,\n",
32
+ " download_dir = \"../../\", max_model_len=5000)"
33
+ ]
34
+ },
35
+ {
36
+ "cell_type": "code",
37
+ "execution_count": null,
38
+ "id": "16bca2af-0cf4-41f2-ae28-d2c669a1af21",
39
+ "metadata": {},
40
+ "outputs": [],
41
+ "source": [
42
+ "def ask_about_trials_loosely(patient_summaries, trial_summaries, llama_model):\n",
43
+ "\n",
44
+ " tokenizer = llama_model.get_tokenizer()\n",
45
+ "\n",
46
+ " prompts = []\n",
47
+ "\n",
48
+ " for patient_summary, trial_summary in zip(patient_summaries, trial_summaries):\n",
49
+ " messages = [{'role':'system', 'content': \"\"\"You are a brilliant oncologist with encyclopedic knowledge about cancer and its treatment. \n",
50
+ " Your job is to evaluate whether a given clinical trial is a reasonable consideration for a patient, given a clinical trial summary and a patient summary.\\n\"\"\"}, \n",
51
+ " {'role':'user', 'content': \"Here is a summary of the clinical trial:\\n\" + trial_summary + \"\\nHere is a summary of the patient:\\n\" + patient_summary + \"\"\"\n",
52
+ "Base your judgment on whether the patient generally fits the cancer type(s), cancer burden, prior treatment(s), and biomarker criteria specified for the trial.\n",
53
+ "You do not have to determine if the patient is actually eligible; instead please just evaluate whether it is reasonable for the trial to be considered further by the patient's oncologist.\n",
54
+ "Some trials have biomarker requirements that are not assessed until formal eligibility screening begins; please ignore these requirements.\n",
55
+ "Reason step by step, then answer the question \"Is this trial a reasonable consideration for this patient?\" with a one-word \"Yes!\" or \"No!\" answer.\n",
56
+ "Make sure to include the exclamation point in your final one-word answer.\"\"\"}]\n",
57
+ "\n",
58
+ " \n",
59
+ " prompt = tokenizer.apply_chat_template(conversation=messages, add_generation_prompt=True, tokenize=False)\n",
60
+ " prompts.append(prompt)\n",
61
+ " \n",
62
+ " responses = llama_model.generate(\n",
63
+ " prompts, \n",
64
+ " SamplingParams(\n",
65
+ " temperature=0.0,\n",
66
+ " top_p=0.2,\n",
67
+ " max_tokens=2048,\n",
68
+ " repetition_penalty=1.2,\n",
69
+ " stop_token_ids=[tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids(\"<|eot_id|>\")], # KEYPOINT HERE\n",
70
+ " ))\n",
71
+ "\n",
72
+ " response_texts = [x.outputs[0].text for x in responses]\n",
73
+ "\n",
74
+ " eligibility_results = []\n",
75
+ "\n",
76
+ " for response_text in response_texts:\n",
77
+ " if (\"Yes!\" in response_text) or (\"YES!\" in response_text):\n",
78
+ " eligibility_results.append(1.0)\n",
79
+ " else:\n",
80
+ " eligibility_results.append(0.0)\n",
81
+ " \n",
82
+ " return responses, response_texts, eligibility_results\n",
83
+ " \n",
84
+ "\n",
85
+ " \n",
86
+ "\n"
87
+ ]
88
+ },
89
+ {
90
+ "cell_type": "code",
91
+ "execution_count": null,
92
+ "id": "2ce4cce6-3833-451a-98c6-d7f4c7b948c6",
93
+ "metadata": {},
94
+ "outputs": [],
95
+ "source": [
96
+ "patient_cohort_candidates = pd.read_csv('top_ten_cohorts_tocheck_synthetic.csv')"
97
+ ]
98
+ },
99
+ {
100
+ "cell_type": "code",
101
+ "execution_count": null,
102
+ "id": "dbb846d2-20e3-4361-b69b-82aa31c1f789",
103
+ "metadata": {},
104
+ "outputs": [],
105
+ "source": []
106
+ },
107
+ {
108
+ "cell_type": "code",
109
+ "execution_count": null,
110
+ "id": "1fb1aa2f-6c28-4a4e-a4e1-bfd69d0b39a1",
111
+ "metadata": {},
112
+ "outputs": [],
113
+ "source": [
114
+ "patient_cohort_candidates.info()"
115
+ ]
116
+ },
117
+ {
118
+ "cell_type": "code",
119
+ "execution_count": null,
120
+ "id": "d30bf018-e135-40be-b636-0ba17acf8e61",
121
+ "metadata": {},
122
+ "outputs": [],
123
+ "source": [
124
+ "%%capture\n",
125
+ "output_list = []\n",
126
+ "batch_list = []\n",
127
+ "\n",
128
+ "num_in_batch = 0\n",
129
+ "\n",
130
+ "for i in range(0, patient_cohort_candidates.shape[0]):\n",
131
+ " \n",
132
+ " batch_list.append(patient_cohort_candidates.iloc[[i]])\n",
133
+ " num_in_batch += 1\n",
134
+ " \n",
135
+ " if (num_in_batch == 500) or (i == (patient_cohort_candidates.shape[0] - 1)):\n",
136
+ "\n",
137
+ " output = pd.concat(batch_list, axis=0)\n",
138
+ " _, output['llama_response'], output['eligibility_result'] = ask_about_trials_loosely(output['patient_summary'].tolist(), output['this_space'].astype(str).tolist(), llama)\n",
139
+ "\n",
140
+ " output_list.append(output)\n",
141
+ " num_in_batch = 0\n",
142
+ " batch_list = []\n",
143
+ " \n",
144
+ " if (len(output_list) > 0 and (i % 500 == 0)) or (i == (patient_cohort_candidates.shape[0] - 1)):\n",
145
+ " output_file = pd.concat(output_list, axis=0)\n",
146
+ " output_file.to_csv('top_ten_cohorts_checked_synthetic.csv')\n"
147
+ ]
148
+ },
149
+ {
150
+ "cell_type": "code",
151
+ "execution_count": null,
152
+ "id": "91534c0e-4873-4eda-9a69-53660a84b4df",
153
+ "metadata": {},
154
+ "outputs": [],
155
+ "source": []
156
+ },
157
+ {
158
+ "cell_type": "code",
159
+ "execution_count": null,
160
+ "id": "eaebffcc-4b62-4ab6-a077-69a6e4340773",
161
+ "metadata": {},
162
+ "outputs": [],
163
+ "source": []
164
+ }
165
+ ],
166
+ "metadata": {
167
+ "kernelspec": {
168
+ "display_name": "Python 3 (ipykernel)",
169
+ "language": "python",
170
+ "name": "python3"
171
+ },
172
+ "language_info": {
173
+ "codemirror_mode": {
174
+ "name": "ipython",
175
+ "version": 3
176
+ },
177
+ "file_extension": ".py",
178
+ "mimetype": "text/x-python",
179
+ "name": "python",
180
+ "nbconvert_exporter": "python",
181
+ "pygments_lexer": "ipython3",
182
+ "version": "3.9.18"
183
+ }
184
+ },
185
+ "nbformat": 4,
186
+ "nbformat_minor": 5
187
+ }
6a_make_top_10_cohorts_llama_check_list.ipynb ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "07d99d54-84c8-4531-8951-133dfaf64c1e",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import pandas as pd\n",
11
+ "import numpy as np\n",
12
+ "import os\n",
13
+ "#os.environ['CUDA_VISIBLE_DEVICES'] = '1'\n",
14
+ "from sentence_transformers import SentenceTransformer, InputExample, losses\n",
15
+ "from torch.utils.data import DataLoader\n",
16
+ "import torch.nn.functional as F\n",
17
+ "import torch\n",
18
+ "from sklearn.metrics import roc_auc_score"
19
+ ]
20
+ },
21
+ {
22
+ "cell_type": "code",
23
+ "execution_count": null,
24
+ "id": "b7a22e15-81dd-4ab4-b7df-1d8da8992685",
25
+ "metadata": {},
26
+ "outputs": [],
27
+ "source": []
28
+ },
29
+ {
30
+ "cell_type": "code",
31
+ "execution_count": null,
32
+ "id": "3553d4f3-c23c-4d21-94ad-da2bcd31a63d",
33
+ "metadata": {},
34
+ "outputs": [],
35
+ "source": [
36
+ "test_spaces = pd.read_csv('test_spaces.csv')\n",
37
+ "test_spaces.info()"
38
+ ]
39
+ },
40
+ {
41
+ "cell_type": "code",
42
+ "execution_count": null,
43
+ "id": "e4757eb7-4a69-4578-9020-075e960817ce",
44
+ "metadata": {},
45
+ "outputs": [],
46
+ "source": [
47
+ "import pandas as pd\n",
48
+ "dfci_trials = pd.read_csv(\"../v7/space_specific_eligibility_checks_11-6-24.csv\")\n",
49
+ "# this will not run out of the box, because dfci_trials file was not included in upload, since it contains PHI\n",
50
+ "\n",
51
+ "other_trials = pd.read_csv('ctgov_all_trials_trial_space_lineitems_10-31-24.csv')\n",
52
+ "other_trials = other_trials[~other_trials.nct_id.isin(test_spaces.nct_id)]\n",
53
+ "other_trials = other_trials[~other_trials.nct_id.isin(dfci_trials.nct_id)]\n",
54
+ "\n",
55
+ "unique_trials = other_trials.groupby('nct_id').first().reset_index()[['nct_id', 'this_space']]\n",
56
+ "unique_trials.shape[0]\n",
57
+ "\n",
58
+ "unique_trial_sample = unique_trials.nct_id.sample(n=500, random_state=42)\n",
59
+ "\n",
60
+ "valid_spaces = unique_trials[unique_trials.nct_id.isin(unique_trial_sample)]\n",
61
+ "\n",
62
+ "valid_spaces.to_csv('valid_spaces.csv')\n",
63
+ "\n"
64
+ ]
65
+ },
66
+ {
67
+ "cell_type": "code",
68
+ "execution_count": null,
69
+ "id": "8d6b5654-2f5f-4caa-82ff-b02a9007dc49",
70
+ "metadata": {},
71
+ "outputs": [],
72
+ "source": [
73
+ "train_spaces = unique_trials[~unique_trials.nct_id.isin(valid_spaces.nct_id)]"
74
+ ]
75
+ },
76
+ {
77
+ "cell_type": "code",
78
+ "execution_count": null,
79
+ "id": "01e0eb5b-9282-48b8-9917-3147fbf25730",
80
+ "metadata": {},
81
+ "outputs": [],
82
+ "source": [
83
+ "train_spaces.nct_id.isin(valid_spaces.nct_id).value_counts()"
84
+ ]
85
+ },
86
+ {
87
+ "cell_type": "code",
88
+ "execution_count": null,
89
+ "id": "ed9139ac-a5f2-4791-a3a4-fd2127b7af0c",
90
+ "metadata": {},
91
+ "outputs": [],
92
+ "source": [
93
+ "train_spaces.info()"
94
+ ]
95
+ },
96
+ {
97
+ "cell_type": "code",
98
+ "execution_count": null,
99
+ "id": "cfa78d0d-3de1-48f7-a01e-f7d81a6a5b3f",
100
+ "metadata": {},
101
+ "outputs": [],
102
+ "source": [
103
+ "patients = pd.read_parquet('synthetic_pt_summaries_11-23-24.parquet')\n",
104
+ "\n",
105
+ "patients = patients[patients.split == 'train'][['patient_summary','split']]"
106
+ ]
107
+ },
108
+ {
109
+ "cell_type": "code",
110
+ "execution_count": null,
111
+ "id": "fcba54bc-aba3-4d02-b874-c8275badf015",
112
+ "metadata": {},
113
+ "outputs": [],
114
+ "source": [
115
+ "patients.info()"
116
+ ]
117
+ },
118
+ {
119
+ "cell_type": "code",
120
+ "execution_count": null,
121
+ "id": "399e05a5-7a58-4362-a002-0bce62a348ac",
122
+ "metadata": {},
123
+ "outputs": [],
124
+ "source": [
125
+ "train_unique_patient_summaries = patients.patient_summary.unique().tolist()\n",
126
+ "print(len(train_unique_patient_summaries))\n",
127
+ "train_unique_spaces = train_spaces.this_space.unique().tolist()\n",
128
+ "print(len(train_unique_spaces))"
129
+ ]
130
+ },
131
+ {
132
+ "cell_type": "code",
133
+ "execution_count": null,
134
+ "id": "c37f1137-41cf-4a6d-8099-7bed5838a1ee",
135
+ "metadata": {},
136
+ "outputs": [],
137
+ "source": []
138
+ },
139
+ {
140
+ "cell_type": "code",
141
+ "execution_count": null,
142
+ "id": "06e7e0f2-5382-4507-908e-5ccf92b4beae",
143
+ "metadata": {},
144
+ "outputs": [],
145
+ "source": [
146
+ "embedding_model = SentenceTransformer('dunzhang/stella_en_1.5B_v5', trust_remote_code=True, device='cuda')"
147
+ ]
148
+ },
149
+ {
150
+ "cell_type": "code",
151
+ "execution_count": null,
152
+ "id": "77659362-cf51-4906-9e12-5d1e55440b25",
153
+ "metadata": {},
154
+ "outputs": [],
155
+ "source": [
156
+ "with torch.no_grad():\n",
157
+ " train_unique_patient_embeddings = embedding_model.encode(train_unique_patient_summaries, convert_to_tensor=True, prompt_name = \"s2s_query\")"
158
+ ]
159
+ },
160
+ {
161
+ "cell_type": "code",
162
+ "execution_count": null,
163
+ "id": "a0bb8926-3fa8-49bc-b507-6cc36d998600",
164
+ "metadata": {},
165
+ "outputs": [],
166
+ "source": [
167
+ "with torch.no_grad():\n",
168
+ " train_unique_space_embeddings = embedding_model.encode(train_unique_spaces, convert_to_tensor=True, prompt_name = \"s2s_query\")"
169
+ ]
170
+ },
171
+ {
172
+ "cell_type": "code",
173
+ "execution_count": null,
174
+ "id": "bf259f29-4934-41e8-9b3a-87f09d0ef52e",
175
+ "metadata": {},
176
+ "outputs": [],
177
+ "source": [
178
+ "output_list = []\n",
179
+ "train_unique_patient_series = pd.Series(train_unique_patient_summaries)\n",
180
+ "for i, space_summary in enumerate(train_unique_spaces):\n",
181
+ " space_embedding = train_unique_space_embeddings[i, :]\n",
182
+ " similarities = F.cosine_similarity(space_embedding, train_unique_patient_embeddings)\n",
183
+ " sorted_similarities, sorted_indices = torch.sort(similarities, descending=True)\n",
184
+ " relevant_patients = train_unique_patient_series.iloc[sorted_indices[0:20].cpu().numpy()]\n",
185
+ " output = pd.DataFrame({'space_summary':space_summary, 'this_patient':relevant_patients})\n",
186
+ " output_list.append(output)\n",
187
+ "\n",
188
+ "train_output = pd.concat(output_list, axis=0).reset_index(drop=True)\n",
189
+ "train_output['space_summary'] = train_output.space_summary.str.strip()\n",
190
+ "train_output['split'] = 'train'\n"
191
+ ]
192
+ },
193
+ {
194
+ "cell_type": "code",
195
+ "execution_count": null,
196
+ "id": "52c926cc-d481-4c43-8288-7d1620c1f06f",
197
+ "metadata": {},
198
+ "outputs": [],
199
+ "source": [
200
+ "train_output.to_csv('top_twenty_patients_tocheck_synthetic.csv')"
201
+ ]
202
+ },
203
+ {
204
+ "cell_type": "code",
205
+ "execution_count": null,
206
+ "id": "0e0e18c1-bfb6-4c20-9aa4-257f0fb0424c",
207
+ "metadata": {},
208
+ "outputs": [],
209
+ "source": []
210
+ }
211
+ ],
212
+ "metadata": {
213
+ "kernelspec": {
214
+ "display_name": "Python 3 (ipykernel)",
215
+ "language": "python",
216
+ "name": "python3"
217
+ },
218
+ "language_info": {
219
+ "codemirror_mode": {
220
+ "name": "ipython",
221
+ "version": 3
222
+ },
223
+ "file_extension": ".py",
224
+ "mimetype": "text/x-python",
225
+ "name": "python",
226
+ "nbconvert_exporter": "python",
227
+ "pygments_lexer": "ipython3",
228
+ "version": "3.9.18"
229
+ }
230
+ },
231
+ "nbformat": 4,
232
+ "nbformat_minor": 5
233
+ }
6b_check_top20_patient_matches_synthetic.ipynb ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "a3d6ff53-2176-44aa-8590-ec0aa301342d",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "from vllm import LLM, SamplingParams\n",
11
+ "import pandas as pd\n",
12
+ "import numpy as np\n",
13
+ "import torch.nn.functional as F\n",
14
+ "import torch\n",
15
+ "from transformers import AutoTokenizer\n",
16
+ "from transformers import AutoModelForCausalLM\n",
17
+ "import re\n",
18
+ "import os\n",
19
+ "os.environ['CUDA_VISIBLE_DEVICES'] = '0,2'\n",
20
+ "\n"
21
+ ]
22
+ },
23
+ {
24
+ "cell_type": "code",
25
+ "execution_count": null,
26
+ "id": "62669512-19e7-43cd-a518-4572eea700af",
27
+ "metadata": {
28
+ "scrolled": true
29
+ },
30
+ "outputs": [],
31
+ "source": [
32
+ "llama = LLM(model='hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4', tensor_parallel_size = 2, \n",
33
+ " gpu_memory_utilization=0.50,\n",
34
+ " download_dir = \"../../\", max_model_len=5000)"
35
+ ]
36
+ },
37
+ {
38
+ "cell_type": "code",
39
+ "execution_count": null,
40
+ "id": "16bca2af-0cf4-41f2-ae28-d2c669a1af21",
41
+ "metadata": {},
42
+ "outputs": [],
43
+ "source": [
44
+ "def ask_about_trials_loosely(patient_summaries, trial_summaries, llama_model):\n",
45
+ "\n",
46
+ " tokenizer = llama_model.get_tokenizer()\n",
47
+ "\n",
48
+ " prompts = []\n",
49
+ "\n",
50
+ " for patient_summary, trial_summary in zip(patient_summaries, trial_summaries):\n",
51
+ " messages = [{'role':'system', 'content': \"\"\"You are a brilliant oncologist with encyclopedic knowledge about cancer and its treatment. \n",
52
+ " Your job is to evaluate whether a given clinical trial is a reasonable consideration for a patient, given a clinical trial summary and a patient summary.\\n\"\"\"}, \n",
53
+ " {'role':'user', 'content': \"Here is a summary of the clinical trial:\\n\" + trial_summary + \"\\nHere is a summary of the patient:\\n\" + patient_summary + \"\"\"\n",
54
+ "Base your judgment on whether the patient generally fits the cancer type(s), cancer burden, prior treatment(s), and biomarker criteria specified for the trial.\n",
55
+ "You do not have to determine if the patient is actually eligible; instead please just evaluate whether it is reasonable for the trial to be considered further by the patient's oncologist.\n",
56
+ "Some trials have biomarker requirements that are not assessed until formal eligibility screening begins; please ignore these requirements.\n",
57
+ "Reason step by step, then answer the question \"Is this trial a reasonable consideration for this patient?\" with a one-word \"Yes!\" or \"No!\" answer.\n",
58
+ "Make sure to include the exclamation point in your final one-word answer.\"\"\"}]\n",
59
+ "\n",
60
+ " \n",
61
+ " prompt = tokenizer.apply_chat_template(conversation=messages, add_generation_prompt=True, tokenize=False)\n",
62
+ " prompts.append(prompt)\n",
63
+ " \n",
64
+ " responses = llama_model.generate(\n",
65
+ " prompts, \n",
66
+ " SamplingParams(\n",
67
+ " temperature=0.0,\n",
68
+ " top_p=0.2,\n",
69
+ " max_tokens=2048,\n",
70
+ " repetition_penalty=1.2,\n",
71
+ " stop_token_ids=[tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids(\"<|eot_id|>\")], # KEYPOINT HERE\n",
72
+ " ))\n",
73
+ "\n",
74
+ " response_texts = [x.outputs[0].text for x in responses]\n",
75
+ "\n",
76
+ " eligibility_results = []\n",
77
+ "\n",
78
+ " for response_text in response_texts:\n",
79
+ " if (\"Yes!\" in response_text) or (\"YES!\" in response_text):\n",
80
+ " eligibility_results.append(1.0)\n",
81
+ " else:\n",
82
+ " eligibility_results.append(0.0)\n",
83
+ " \n",
84
+ " return responses, response_texts, eligibility_results\n",
85
+ " \n",
86
+ "\n",
87
+ " \n",
88
+ "\n"
89
+ ]
90
+ },
91
+ {
92
+ "cell_type": "code",
93
+ "execution_count": null,
94
+ "id": "2ce4cce6-3833-451a-98c6-d7f4c7b948c6",
95
+ "metadata": {},
96
+ "outputs": [],
97
+ "source": [
98
+ "patient_cohort_candidates = pd.read_csv('top_twenty_patients_tocheck_synthetic.csv')"
99
+ ]
100
+ },
101
+ {
102
+ "cell_type": "code",
103
+ "execution_count": null,
104
+ "id": "dbb846d2-20e3-4361-b69b-82aa31c1f789",
105
+ "metadata": {},
106
+ "outputs": [],
107
+ "source": [
108
+ "patient_cohort_candidates = patient_cohort_candidates.rename(columns={'this_patient':'patient_summary', 'space_summary':'this_space'})"
109
+ ]
110
+ },
111
+ {
112
+ "cell_type": "code",
113
+ "execution_count": null,
114
+ "id": "1fb1aa2f-6c28-4a4e-a4e1-bfd69d0b39a1",
115
+ "metadata": {},
116
+ "outputs": [],
117
+ "source": [
118
+ "patient_cohort_candidates.info()"
119
+ ]
120
+ },
121
+ {
122
+ "cell_type": "code",
123
+ "execution_count": null,
124
+ "id": "d30bf018-e135-40be-b636-0ba17acf8e61",
125
+ "metadata": {},
126
+ "outputs": [],
127
+ "source": [
128
+ "%%capture\n",
129
+ "output_list = []\n",
130
+ "batch_list = []\n",
131
+ "\n",
132
+ "num_in_batch = 0\n",
133
+ "\n",
134
+ "for i in range(0, patient_cohort_candidates.shape[0]):\n",
135
+ " \n",
136
+ " batch_list.append(patient_cohort_candidates.iloc[[i]])\n",
137
+ " num_in_batch += 1\n",
138
+ " \n",
139
+ " if (num_in_batch == 500) or (i == (patient_cohort_candidates.shape[0] - 1)):\n",
140
+ "\n",
141
+ " output = pd.concat(batch_list, axis=0)\n",
142
+ " _, output['llama_response'], output['eligibility_result'] = ask_about_trials_loosely(output['patient_summary'].tolist(), output['this_space'].astype(str).tolist(), llama)\n",
143
+ "\n",
144
+ " output_list.append(output)\n",
145
+ " num_in_batch = 0\n",
146
+ " batch_list = []\n",
147
+ " \n",
148
+ " if (len(output_list) > 0 and (i % 500 == 0)) or (i == (patient_cohort_candidates.shape[0] - 1)):\n",
149
+ " output_file = pd.concat(output_list, axis=0)\n",
150
+ " output_file.to_csv('top_twenty_patients_checked_synthetic.csv')\n"
151
+ ]
152
+ },
153
+ {
154
+ "cell_type": "code",
155
+ "execution_count": null,
156
+ "id": "91534c0e-4873-4eda-9a69-53660a84b4df",
157
+ "metadata": {},
158
+ "outputs": [],
159
+ "source": []
160
+ },
161
+ {
162
+ "cell_type": "code",
163
+ "execution_count": null,
164
+ "id": "eaebffcc-4b62-4ab6-a077-69a6e4340773",
165
+ "metadata": {},
166
+ "outputs": [],
167
+ "source": []
168
+ }
169
+ ],
170
+ "metadata": {
171
+ "kernelspec": {
172
+ "display_name": "Python 3 (ipykernel)",
173
+ "language": "python",
174
+ "name": "python3"
175
+ },
176
+ "language_info": {
177
+ "codemirror_mode": {
178
+ "name": "ipython",
179
+ "version": 3
180
+ },
181
+ "file_extension": ".py",
182
+ "mimetype": "text/x-python",
183
+ "name": "python",
184
+ "nbconvert_exporter": "python",
185
+ "pygments_lexer": "ipython3",
186
+ "version": "3.12.5"
187
+ }
188
+ },
189
+ "nbformat": 4,
190
+ "nbformat_minor": 5
191
+ }
7_train_trialspace_round1.ipynb ADDED
@@ -0,0 +1,451 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "81b83fa8-421d-4be5-b9eb-5892f01fd5b0",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import pandas as pd\n",
11
+ "import numpy as np\n",
12
+ "import os\n",
13
+ "#os.environ['CUDA_VISIBLE_DEVICES'] = '2,3'\n",
14
+ "from sentence_transformers import SentenceTransformer, InputExample, losses\n",
15
+ "from torch.utils.data import DataLoader\n",
16
+ "import torch.nn.functional as F\n",
17
+ "import torch\n",
18
+ "from sklearn.metrics import roc_auc_score"
19
+ ]
20
+ },
21
+ {
22
+ "cell_type": "code",
23
+ "execution_count": null,
24
+ "id": "937cbcda-0cd6-47f7-b52e-17ed2bafce3d",
25
+ "metadata": {},
26
+ "outputs": [],
27
+ "source": [
28
+ "model = SentenceTransformer('dunzhang/stella_en_1.5b_v5', trust_remote_code=True, device='cuda')\n"
29
+ ]
30
+ },
31
+ {
32
+ "cell_type": "code",
33
+ "execution_count": null,
34
+ "id": "853e6b86-db0b-4650-b98f-f437987baa5a",
35
+ "metadata": {},
36
+ "outputs": [],
37
+ "source": [
38
+ "cohort_checks = pd.read_csv('top_ten_cohorts_checked_synthetic.csv')"
39
+ ]
40
+ },
41
+ {
42
+ "cell_type": "code",
43
+ "execution_count": null,
44
+ "id": "474950e0-869b-414e-823f-df5ba8e5de92",
45
+ "metadata": {},
46
+ "outputs": [],
47
+ "source": [
48
+ "cohort_checks.info()"
49
+ ]
50
+ },
51
+ {
52
+ "cell_type": "code",
53
+ "execution_count": null,
54
+ "id": "c9835dad-4fc4-4a0e-aba2-d45358edbee9",
55
+ "metadata": {},
56
+ "outputs": [],
57
+ "source": [
58
+ "cohort_checks['mod_eligibility_result'] = np.where(cohort_checks.llama_response.str.contains('Yes!|YES!'), 1, 0)"
59
+ ]
60
+ },
61
+ {
62
+ "cell_type": "code",
63
+ "execution_count": null,
64
+ "id": "413bddbc-f35c-48ec-bcbb-48405bd2c9c9",
65
+ "metadata": {},
66
+ "outputs": [],
67
+ "source": [
68
+ "cohort_checks.eligibility_result.value_counts()"
69
+ ]
70
+ },
71
+ {
72
+ "cell_type": "code",
73
+ "execution_count": null,
74
+ "id": "79c2d994-1e39-41b6-aeb3-962ba3ba5611",
75
+ "metadata": {},
76
+ "outputs": [],
77
+ "source": [
78
+ "cohort_checks.mod_eligibility_result.value_counts()"
79
+ ]
80
+ },
81
+ {
82
+ "cell_type": "code",
83
+ "execution_count": null,
84
+ "id": "9c8e6a20-4513-422c-be6e-3459ce98a2be",
85
+ "metadata": {},
86
+ "outputs": [],
87
+ "source": [
88
+ "patient_checks = pd.read_csv('top_twenty_patients_checked_synthetic.csv')"
89
+ ]
90
+ },
91
+ {
92
+ "cell_type": "code",
93
+ "execution_count": null,
94
+ "id": "e91074cf-07de-40d1-8bf3-baf825d3f625",
95
+ "metadata": {},
96
+ "outputs": [],
97
+ "source": [
98
+ "patient_checks['mod_eligibility_result'] = np.where(patient_checks.llama_response.str.contains('Yes!|YES!'), 1, 0)"
99
+ ]
100
+ },
101
+ {
102
+ "cell_type": "code",
103
+ "execution_count": null,
104
+ "id": "3a2d45fc-7d92-4a65-aad0-a9ab8f783779",
105
+ "metadata": {},
106
+ "outputs": [],
107
+ "source": [
108
+ "patient_checks.info()"
109
+ ]
110
+ },
111
+ {
112
+ "cell_type": "code",
113
+ "execution_count": null,
114
+ "id": "0cf7c705-ba6d-4f01-ac2e-88d345fef7f6",
115
+ "metadata": {},
116
+ "outputs": [],
117
+ "source": [
118
+ "patient_checks.eligibility_result.value_counts(), patient_checks.mod_eligibility_result.value_counts()"
119
+ ]
120
+ },
121
+ {
122
+ "cell_type": "code",
123
+ "execution_count": null,
124
+ "id": "dec4a8c8-c5db-4164-a06c-27fb59782fa5",
125
+ "metadata": {},
126
+ "outputs": [],
127
+ "source": [
128
+ "patient_checks = patient_checks.rename(columns={'this_patient':'patient_summary', 'space_summary':'this_space'})"
129
+ ]
130
+ },
131
+ {
132
+ "cell_type": "code",
133
+ "execution_count": null,
134
+ "id": "0bf55e82-c91d-472f-84ad-74c755e9bf29",
135
+ "metadata": {},
136
+ "outputs": [],
137
+ "source": [
138
+ "combined_checks = pd.concat([patient_checks, cohort_checks], axis=0)"
139
+ ]
140
+ },
141
+ {
142
+ "cell_type": "code",
143
+ "execution_count": null,
144
+ "id": "ebd07c9c-6263-4005-bfb1-2a8468b76a98",
145
+ "metadata": {},
146
+ "outputs": [],
147
+ "source": [
148
+ "combined_checks.info()"
149
+ ]
150
+ },
151
+ {
152
+ "cell_type": "code",
153
+ "execution_count": null,
154
+ "id": "49f59429-c9f4-43df-a1b2-750a3c94517a",
155
+ "metadata": {},
156
+ "outputs": [],
157
+ "source": []
158
+ },
159
+ {
160
+ "cell_type": "code",
161
+ "execution_count": null,
162
+ "id": "e1d0126e-a58d-41ca-ad2f-a2d37bc585ad",
163
+ "metadata": {},
164
+ "outputs": [],
165
+ "source": [
166
+ "train_summaries = combined_checks[combined_checks.split=='train']\n",
167
+ "train_summaries = train_summaries[~train_summaries.patient_summary.isnull()]\n",
168
+ "train_summaries = train_summaries[~train_summaries.llama_response.isnull()]\n",
169
+ "train_summaries.split.value_counts()"
170
+ ]
171
+ },
172
+ {
173
+ "cell_type": "code",
174
+ "execution_count": null,
175
+ "id": "2f6506ed-dcbf-4e0b-8722-6e234c2d4509",
176
+ "metadata": {},
177
+ "outputs": [],
178
+ "source": [
179
+ "train_summaries.mod_eligibility_result.value_counts()"
180
+ ]
181
+ },
182
+ {
183
+ "cell_type": "code",
184
+ "execution_count": null,
185
+ "id": "c678a59a-c301-42d9-83dd-511503cee2fb",
186
+ "metadata": {},
187
+ "outputs": [],
188
+ "source": [
189
+ "train_summaries.info()"
190
+ ]
191
+ },
192
+ {
193
+ "cell_type": "code",
194
+ "execution_count": null,
195
+ "id": "57932264-103a-413b-9a48-43b7be254ac0",
196
+ "metadata": {},
197
+ "outputs": [],
198
+ "source": [
199
+ "# mll loss\n",
200
+ "train_eligibles_only = train_summaries[train_summaries.eligibility_result == 1]\n",
201
+ "example_list = []\n",
202
+ "for i in range(train_eligibles_only.shape[0]):\n",
203
+ " example_list.append(InputExample(texts=[train_summaries.patient_summary.iloc[i], train_summaries.this_space.iloc[i]]))\n",
204
+ "\n",
205
+ "train_eligibles_only_dataloader = DataLoader(example_list, shuffle=True, batch_size=8)\n",
206
+ "train_eligibles_only_loss = losses.MultipleNegativesRankingLoss(model=model)"
207
+ ]
208
+ },
209
+ {
210
+ "cell_type": "code",
211
+ "execution_count": null,
212
+ "id": "e5482be3-9a13-4ce1-aa8a-429c54bf6be0",
213
+ "metadata": {},
214
+ "outputs": [],
215
+ "source": [
216
+ "# for attempt at contrastive loss\n",
217
+ "contrastive_example_list = []\n",
218
+ "for i in range(train_summaries.shape[0]):\n",
219
+ " contrastive_example_list.append(InputExample(texts=[train_summaries.patient_summary.iloc[i], train_summaries.this_space.iloc[i]],\n",
220
+ " label=train_summaries.mod_eligibility_result.iloc[i]))\n",
221
+ "\n",
222
+ "contrastive_dataloader = DataLoader(contrastive_example_list, shuffle=True, batch_size=12)\n",
223
+ "contrastive_train_loss = losses.OnlineContrastiveLoss(model=model)"
224
+ ]
225
+ },
226
+ {
227
+ "cell_type": "code",
228
+ "execution_count": null,
229
+ "id": "4e825dae-a5a9-4f87-af35-63ac2d73de33",
230
+ "metadata": {},
231
+ "outputs": [],
232
+ "source": []
233
+ },
234
+ {
235
+ "cell_type": "code",
236
+ "execution_count": null,
237
+ "id": "f17ad7a6-8911-4d7d-8495-3e37cb00597d",
238
+ "metadata": {
239
+ "scrolled": true
240
+ },
241
+ "outputs": [],
242
+ "source": [
243
+ "#%%capture\n",
244
+ "model.fit(train_objectives=[(contrastive_dataloader, contrastive_train_loss),\n",
245
+ " (train_eligibles_only_dataloader, train_eligibles_only_loss)], epochs=2, warmup_steps=100)"
246
+ ]
247
+ },
248
+ {
249
+ "cell_type": "code",
250
+ "execution_count": null,
251
+ "id": "c9cb6021-21d8-44bf-b440-980fcdae3b3d",
252
+ "metadata": {},
253
+ "outputs": [],
254
+ "source": [
255
+ "model.save('reranker_round1.model')"
256
+ ]
257
+ },
258
+ {
259
+ "cell_type": "code",
260
+ "execution_count": null,
261
+ "id": "bae79a2e-4357-4c90-ba4c-a08b1206a99d",
262
+ "metadata": {},
263
+ "outputs": [],
264
+ "source": [
265
+ "model = SentenceTransformer('reranker_round1.model', trust_remote_code=True, device='cuda')"
266
+ ]
267
+ },
268
+ {
269
+ "cell_type": "code",
270
+ "execution_count": null,
271
+ "id": "f5517caa-c45b-4b62-ae8d-0af61b61fd25",
272
+ "metadata": {},
273
+ "outputs": [],
274
+ "source": []
275
+ },
276
+ {
277
+ "cell_type": "code",
278
+ "execution_count": null,
279
+ "id": "c6bfb8f7-ca6b-474b-8ce3-ba5acacb6b6a",
280
+ "metadata": {},
281
+ "outputs": [],
282
+ "source": [
283
+ "# check model's ability to do initial discriminate among diseases task\n",
284
+ "# (on PHI)\n"
285
+ ]
286
+ },
287
+ {
288
+ "cell_type": "code",
289
+ "execution_count": null,
290
+ "id": "4172f6ba-b334-4b83-b73e-d05dad6c05f0",
291
+ "metadata": {},
292
+ "outputs": [],
293
+ "source": [
294
+ "cohort_checks = pd.read_csv('../v7/space_specific_eligibility_checks_11-6-24.csv')\n",
295
+ "# this cohort_checks file is not provided publicly, since it contains PHI/IP"
296
+ ]
297
+ },
298
+ {
299
+ "cell_type": "code",
300
+ "execution_count": null,
301
+ "id": "d6b25941-0007-4347-9ef3-899f9258542a",
302
+ "metadata": {},
303
+ "outputs": [],
304
+ "source": [
305
+ "validation_set = cohort_checks[cohort_checks.split.str.contains('valid')]\n",
306
+ "validation_set.info()\n"
307
+ ]
308
+ },
309
+ {
310
+ "cell_type": "code",
311
+ "execution_count": null,
312
+ "id": "4b791608-6011-4bf6-914a-9534a08eba5a",
313
+ "metadata": {},
314
+ "outputs": [],
315
+ "source": [
316
+ "validation_set = validation_set[~validation_set.patient_summary.isnull()]\n",
317
+ "validation_set.info()"
318
+ ]
319
+ },
320
+ {
321
+ "cell_type": "code",
322
+ "execution_count": null,
323
+ "id": "479b9905-fcd6-4d37-9b03-7bbbfb88f123",
324
+ "metadata": {},
325
+ "outputs": [],
326
+ "source": [
327
+ "\n",
328
+ "eligibles_only = validation_set[validation_set.eligibility_result == 1]\n",
329
+ "patient_summary_embeddings = model.encode(eligibles_only.patient_summary.tolist())\n",
330
+ "trial_summary_embeddings = model.encode(eligibles_only.this_space.tolist())"
331
+ ]
332
+ },
333
+ {
334
+ "cell_type": "code",
335
+ "execution_count": null,
336
+ "id": "9b8f3a40-0854-43a5-bd83-a7fe6770f52b",
337
+ "metadata": {},
338
+ "outputs": [],
339
+ "source": [
340
+ "# among patient to trial space candidate matches that pass llama checks, how good is TrialSpace at discriminating between true and random matches?\n",
341
+ "import random\n",
342
+ "labels = []\n",
343
+ "similarities = []\n",
344
+ "for i in range(trial_summary_embeddings.shape[0]):\n",
345
+ " if random.choice([0,1]) == 1:\n",
346
+ " similarities.append(F.cosine_similarity(torch.tensor(patient_summary_embeddings[i,:]).unsqueeze(0), torch.tensor(trial_summary_embeddings[i, :]).unsqueeze(0)))\n",
347
+ " labels.append(1.)\n",
348
+ " else:\n",
349
+ " random_index = random.choice([x for x in range(0,trial_summary_embeddings.shape[0])])\n",
350
+ " similarities.append(F.cosine_similarity(torch.tensor(patient_summary_embeddings[i,:]).unsqueeze(0), torch.tensor(trial_summary_embeddings[random_index, :]).unsqueeze(0)))\n",
351
+ " labels.append(0.)\n",
352
+ "roc_auc_score(labels, np.array([x.numpy() for x in similarities]))"
353
+ ]
354
+ },
355
+ {
356
+ "cell_type": "code",
357
+ "execution_count": null,
358
+ "id": "16dd4634-0389-466d-8257-160ddd2659af",
359
+ "metadata": {},
360
+ "outputs": [],
361
+ "source": [
362
+ "# how good are embeddings at discriminating between llama yes and no checks?\n",
363
+ "# (on PHI)\n",
364
+ "patient_summary_embeddings = model.encode(validation_set.patient_summary.tolist(), convert_to_tensor=True)\n",
365
+ "trial_summary_embeddings = model.encode(validation_set.this_space.tolist(), convert_to_tensor=True)"
366
+ ]
367
+ },
368
+ {
369
+ "cell_type": "code",
370
+ "execution_count": null,
371
+ "id": "5bb0bc89-0b4f-451d-9523-550f7344e4d9",
372
+ "metadata": {},
373
+ "outputs": [],
374
+ "source": [
375
+ "similarities = F.cosine_similarity(patient_summary_embeddings, trial_summary_embeddings).detach().cpu().numpy()\n",
376
+ "roc_auc_score(validation_set.eligibility_result, similarities)"
377
+ ]
378
+ },
379
+ {
380
+ "cell_type": "code",
381
+ "execution_count": null,
382
+ "id": "c6035e62-8d28-49c5-8d0a-049633edd553",
383
+ "metadata": {},
384
+ "outputs": [],
385
+ "source": []
386
+ },
387
+ {
388
+ "cell_type": "code",
389
+ "execution_count": null,
390
+ "id": "4f587899-0101-4d81-91a7-f8ef72be949f",
391
+ "metadata": {},
392
+ "outputs": [],
393
+ "source": [
394
+ "validation_set.eligibility_result.value_counts()"
395
+ ]
396
+ },
397
+ {
398
+ "cell_type": "code",
399
+ "execution_count": null,
400
+ "id": "453c2f3c-105a-4b71-851c-372bf29d3fe8",
401
+ "metadata": {},
402
+ "outputs": [],
403
+ "source": []
404
+ },
405
+ {
406
+ "cell_type": "code",
407
+ "execution_count": null,
408
+ "id": "69a3fc1d-86f1-49f7-a93a-54f4748c5dbf",
409
+ "metadata": {},
410
+ "outputs": [],
411
+ "source": []
412
+ },
413
+ {
414
+ "cell_type": "code",
415
+ "execution_count": null,
416
+ "id": "23d7d1f4-9f1f-42f6-a366-0e39af8893b2",
417
+ "metadata": {},
418
+ "outputs": [],
419
+ "source": []
420
+ },
421
+ {
422
+ "cell_type": "code",
423
+ "execution_count": null,
424
+ "id": "0c4415d5-d0fd-48ca-b88c-2e244434561d",
425
+ "metadata": {},
426
+ "outputs": [],
427
+ "source": []
428
+ }
429
+ ],
430
+ "metadata": {
431
+ "kernelspec": {
432
+ "display_name": "Python 3 (ipykernel)",
433
+ "language": "python",
434
+ "name": "python3"
435
+ },
436
+ "language_info": {
437
+ "codemirror_mode": {
438
+ "name": "ipython",
439
+ "version": 3
440
+ },
441
+ "file_extension": ".py",
442
+ "mimetype": "text/x-python",
443
+ "name": "python",
444
+ "nbconvert_exporter": "python",
445
+ "pygments_lexer": "ipython3",
446
+ "version": "3.9.18"
447
+ }
448
+ },
449
+ "nbformat": 4,
450
+ "nbformat_minor": 5
451
+ }
8a_make_top_10_cohorts_llama_check_list_round2.ipynb ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "07d99d54-84c8-4531-8951-133dfaf64c1e",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import pandas as pd\n",
11
+ "import numpy as np\n",
12
+ "import os\n",
13
+ "#os.environ['CUDA_VISIBLE_DEVICES'] = '1'\n",
14
+ "from sentence_transformers import SentenceTransformer, InputExample, losses\n",
15
+ "from torch.utils.data import DataLoader\n",
16
+ "import torch.nn.functional as F\n",
17
+ "import torch\n",
18
+ "from sklearn.metrics import roc_auc_score"
19
+ ]
20
+ },
21
+ {
22
+ "cell_type": "code",
23
+ "execution_count": null,
24
+ "id": "b7a22e15-81dd-4ab4-b7df-1d8da8992685",
25
+ "metadata": {},
26
+ "outputs": [],
27
+ "source": []
28
+ },
29
+ {
30
+ "cell_type": "code",
31
+ "execution_count": null,
32
+ "id": "3553d4f3-c23c-4d21-94ad-da2bcd31a63d",
33
+ "metadata": {},
34
+ "outputs": [],
35
+ "source": [
36
+ "test_spaces = pd.read_csv('test_spaces.csv')\n",
37
+ "test_spaces.info()"
38
+ ]
39
+ },
40
+ {
41
+ "cell_type": "code",
42
+ "execution_count": null,
43
+ "id": "e4757eb7-4a69-4578-9020-075e960817ce",
44
+ "metadata": {},
45
+ "outputs": [],
46
+ "source": [
47
+ "import pandas as pd\n",
48
+ "dfci_trials = pd.read_csv(\"../v7/space_specific_eligibility_checks_11-6-24.csv\")\n",
49
+ "# this dfci_trials file is not provided for public use, since it contains PHI, so this will not run out of the box\n",
50
+ "\n",
51
+ "other_trials = pd.read_csv('ctgov_all_trials_trial_space_lineitems_10-31-24.csv')\n",
52
+ "other_trials = other_trials[~other_trials.nct_id.isin(test_spaces.nct_id)]\n",
53
+ "other_trials = other_trials[~other_trials.nct_id.isin(dfci_trials.nct_id)]\n",
54
+ "\n",
55
+ "unique_trials = other_trials.groupby('nct_id').first().reset_index()[['nct_id', 'this_space']]\n",
56
+ "unique_trials.shape[0]\n",
57
+ "\n",
58
+ "unique_trial_sample = unique_trials.nct_id.sample(n=500, random_state=42)\n",
59
+ "\n",
60
+ "valid_spaces = unique_trials[unique_trials.nct_id.isin(unique_trial_sample)]\n",
61
+ "\n",
62
+ "valid_spaces.to_csv('valid_spaces.csv')\n",
63
+ "\n"
64
+ ]
65
+ },
66
+ {
67
+ "cell_type": "code",
68
+ "execution_count": null,
69
+ "id": "8d6b5654-2f5f-4caa-82ff-b02a9007dc49",
70
+ "metadata": {},
71
+ "outputs": [],
72
+ "source": [
73
+ "train_spaces = unique_trials[~unique_trials.nct_id.isin(valid_spaces.nct_id)]"
74
+ ]
75
+ },
76
+ {
77
+ "cell_type": "code",
78
+ "execution_count": null,
79
+ "id": "01e0eb5b-9282-48b8-9917-3147fbf25730",
80
+ "metadata": {},
81
+ "outputs": [],
82
+ "source": [
83
+ "train_spaces.nct_id.isin(valid_spaces.nct_id).value_counts()"
84
+ ]
85
+ },
86
+ {
87
+ "cell_type": "code",
88
+ "execution_count": null,
89
+ "id": "ed9139ac-a5f2-4791-a3a4-fd2127b7af0c",
90
+ "metadata": {},
91
+ "outputs": [],
92
+ "source": [
93
+ "train_spaces.info()"
94
+ ]
95
+ },
96
+ {
97
+ "cell_type": "code",
98
+ "execution_count": null,
99
+ "id": "cfa78d0d-3de1-48f7-a01e-f7d81a6a5b3f",
100
+ "metadata": {},
101
+ "outputs": [],
102
+ "source": [
103
+ "patients = pd.read_parquet('synthetic_pt_summaries_11-23-24.parquet')\n",
104
+ "patients = patients[patients.split == 'train'][['patient_summary','split']]"
105
+ ]
106
+ },
107
+ {
108
+ "cell_type": "code",
109
+ "execution_count": null,
110
+ "id": "fcba54bc-aba3-4d02-b874-c8275badf015",
111
+ "metadata": {},
112
+ "outputs": [],
113
+ "source": [
114
+ "patients.info()"
115
+ ]
116
+ },
117
+ {
118
+ "cell_type": "code",
119
+ "execution_count": null,
120
+ "id": "399e05a5-7a58-4362-a002-0bce62a348ac",
121
+ "metadata": {},
122
+ "outputs": [],
123
+ "source": [
124
+ "train_unique_patient_summaries = patients.patient_summary.unique().tolist()\n",
125
+ "print(len(train_unique_patient_summaries))\n",
126
+ "train_unique_spaces = train_spaces.this_space.unique().tolist()\n",
127
+ "print(len(train_unique_spaces))"
128
+ ]
129
+ },
130
+ {
131
+ "cell_type": "code",
132
+ "execution_count": null,
133
+ "id": "c37f1137-41cf-4a6d-8099-7bed5838a1ee",
134
+ "metadata": {},
135
+ "outputs": [],
136
+ "source": []
137
+ },
138
+ {
139
+ "cell_type": "code",
140
+ "execution_count": null,
141
+ "id": "06e7e0f2-5382-4507-908e-5ccf92b4beae",
142
+ "metadata": {},
143
+ "outputs": [],
144
+ "source": [
145
+ "embedding_model = SentenceTransformer('reranker_round1.model', trust_remote_code=True, device='cuda')"
146
+ ]
147
+ },
148
+ {
149
+ "cell_type": "code",
150
+ "execution_count": null,
151
+ "id": "77659362-cf51-4906-9e12-5d1e55440b25",
152
+ "metadata": {},
153
+ "outputs": [],
154
+ "source": [
155
+ "with torch.no_grad():\n",
156
+ " train_unique_patient_embeddings = embedding_model.encode(train_unique_patient_summaries, convert_to_tensor=True, prompt_name = \"s2s_query\")"
157
+ ]
158
+ },
159
+ {
160
+ "cell_type": "code",
161
+ "execution_count": null,
162
+ "id": "a0bb8926-3fa8-49bc-b507-6cc36d998600",
163
+ "metadata": {},
164
+ "outputs": [],
165
+ "source": [
166
+ "with torch.no_grad():\n",
167
+ " train_unique_space_embeddings = embedding_model.encode(train_unique_spaces, convert_to_tensor=True, prompt_name = \"s2s_query\")"
168
+ ]
169
+ },
170
+ {
171
+ "cell_type": "code",
172
+ "execution_count": null,
173
+ "id": "bf259f29-4934-41e8-9b3a-87f09d0ef52e",
174
+ "metadata": {},
175
+ "outputs": [],
176
+ "source": [
177
+ "output_list = []\n",
178
+ "train_unique_space_series = pd.Series(train_unique_spaces)\n",
179
+ "for i, patient_summary in enumerate(train_unique_patient_summaries):\n",
180
+ " patient_embedding = train_unique_patient_embeddings[i, :]\n",
181
+ " similarities = F.cosine_similarity(patient_embedding, train_unique_space_embeddings)\n",
182
+ " sorted_similarities, sorted_indices = torch.sort(similarities, descending=True)\n",
183
+ " relevant_spaces = train_unique_space_series.iloc[sorted_indices[0:10].cpu().numpy()]\n",
184
+ " output = pd.DataFrame({'patient_summary':patient_summary, 'this_space':relevant_spaces})\n",
185
+ " output_list.append(output)\n",
186
+ "\n",
187
+ "train_output = pd.concat(output_list, axis=0).reset_index(drop=True)\n",
188
+ "train_output['patient_summary'] = train_output.patient_summary.str.strip()\n",
189
+ "train_output['split'] = 'train'"
190
+ ]
191
+ },
192
+ {
193
+ "cell_type": "code",
194
+ "execution_count": null,
195
+ "id": "52c926cc-d481-4c43-8288-7d1620c1f06f",
196
+ "metadata": {},
197
+ "outputs": [],
198
+ "source": [
199
+ "train_output.to_csv('top_ten_cohorts_tocheck_synthetic_round2.csv')"
200
+ ]
201
+ },
202
+ {
203
+ "cell_type": "code",
204
+ "execution_count": null,
205
+ "id": "0e0e18c1-bfb6-4c20-9aa4-257f0fb0424c",
206
+ "metadata": {},
207
+ "outputs": [],
208
+ "source": []
209
+ }
210
+ ],
211
+ "metadata": {
212
+ "kernelspec": {
213
+ "display_name": "Python 3 (ipykernel)",
214
+ "language": "python",
215
+ "name": "python3"
216
+ },
217
+ "language_info": {
218
+ "codemirror_mode": {
219
+ "name": "ipython",
220
+ "version": 3
221
+ },
222
+ "file_extension": ".py",
223
+ "mimetype": "text/x-python",
224
+ "name": "python",
225
+ "nbconvert_exporter": "python",
226
+ "pygments_lexer": "ipython3",
227
+ "version": "3.9.18"
228
+ }
229
+ },
230
+ "nbformat": 4,
231
+ "nbformat_minor": 5
232
+ }
8b_check_top10_cohorts_synthetic_round2.ipynb ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "a3d6ff53-2176-44aa-8590-ec0aa301342d",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "from vllm import LLM, SamplingParams\n",
11
+ "import pandas as pd\n",
12
+ "import numpy as np\n",
13
+ "import torch.nn.functional as F\n",
14
+ "import torch\n",
15
+ "from transformers import AutoTokenizer\n",
16
+ "from transformers import AutoModelForCausalLM\n",
17
+ "import re\n",
18
+ "import os\n"
19
+ ]
20
+ },
21
+ {
22
+ "cell_type": "code",
23
+ "execution_count": null,
24
+ "id": "62669512-19e7-43cd-a518-4572eea700af",
25
+ "metadata": {
26
+ "scrolled": true
27
+ },
28
+ "outputs": [],
29
+ "source": [
30
+ "llama = LLM(model='hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4', tensor_parallel_size = 2, \n",
31
+ " gpu_memory_utilization=0.90,\n",
32
+ " download_dir = \"../../\", max_model_len=5000)"
33
+ ]
34
+ },
35
+ {
36
+ "cell_type": "code",
37
+ "execution_count": null,
38
+ "id": "16bca2af-0cf4-41f2-ae28-d2c669a1af21",
39
+ "metadata": {},
40
+ "outputs": [],
41
+ "source": [
42
+ "def ask_about_trials_loosely(patient_summaries, trial_summaries, llama_model):\n",
43
+ "\n",
44
+ " tokenizer = llama_model.get_tokenizer()\n",
45
+ "\n",
46
+ " prompts = []\n",
47
+ "\n",
48
+ " for patient_summary, trial_summary in zip(patient_summaries, trial_summaries):\n",
49
+ " messages = [{'role':'system', 'content': \"\"\"You are a brilliant oncologist with encyclopedic knowledge about cancer and its treatment. \n",
50
+ " Your job is to evaluate whether a given clinical trial is a reasonable consideration for a patient, given a clinical trial summary and a patient summary.\\n\"\"\"}, \n",
51
+ " {'role':'user', 'content': \"Here is a summary of the clinical trial:\\n\" + trial_summary + \"\\nHere is a summary of the patient:\\n\" + patient_summary + \"\"\"\n",
52
+ "Base your judgment on whether the patient generally fits the cancer type(s), cancer burden, prior treatment(s), and biomarker criteria specified for the trial.\n",
53
+ "You do not have to determine if the patient is actually eligible; instead please just evaluate whether it is reasonable for the trial to be considered further by the patient's oncologist.\n",
54
+ "Some trials have biomarker requirements that are not assessed until formal eligibility screening begins; please ignore these requirements.\n",
55
+ "Reason step by step, then answer the question \"Is this trial a reasonable consideration for this patient?\" with a one-word \"Yes!\" or \"No!\" answer.\n",
56
+ "Make sure to include the exclamation point in your final one-word answer.\"\"\"}]\n",
57
+ "\n",
58
+ " \n",
59
+ " prompt = tokenizer.apply_chat_template(conversation=messages, add_generation_prompt=True, tokenize=False)\n",
60
+ " prompts.append(prompt)\n",
61
+ " \n",
62
+ " responses = llama_model.generate(\n",
63
+ " prompts, \n",
64
+ " SamplingParams(\n",
65
+ " temperature=0.0,\n",
66
+ " top_p=0.2,\n",
67
+ " max_tokens=2048,\n",
68
+ " repetition_penalty=1.2,\n",
69
+ " stop_token_ids=[tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids(\"<|eot_id|>\")], # KEYPOINT HERE\n",
70
+ " ))\n",
71
+ "\n",
72
+ " response_texts = [x.outputs[0].text for x in responses]\n",
73
+ "\n",
74
+ " eligibility_results = []\n",
75
+ "\n",
76
+ " for response_text in response_texts:\n",
77
+ " if (\"Yes!\" in response_text) or (\"YES!\" in response_text):\n",
78
+ " eligibility_results.append(1.0)\n",
79
+ " else:\n",
80
+ " eligibility_results.append(0.0)\n",
81
+ " \n",
82
+ " return responses, response_texts, eligibility_results\n",
83
+ " \n",
84
+ "\n",
85
+ " \n",
86
+ "\n"
87
+ ]
88
+ },
89
+ {
90
+ "cell_type": "code",
91
+ "execution_count": null,
92
+ "id": "2ce4cce6-3833-451a-98c6-d7f4c7b948c6",
93
+ "metadata": {},
94
+ "outputs": [],
95
+ "source": [
96
+ "patient_cohort_candidates = pd.read_csv('top_ten_cohorts_tocheck_synthetic_round2.csv')"
97
+ ]
98
+ },
99
+ {
100
+ "cell_type": "code",
101
+ "execution_count": null,
102
+ "id": "dbb846d2-20e3-4361-b69b-82aa31c1f789",
103
+ "metadata": {},
104
+ "outputs": [],
105
+ "source": []
106
+ },
107
+ {
108
+ "cell_type": "code",
109
+ "execution_count": null,
110
+ "id": "1fb1aa2f-6c28-4a4e-a4e1-bfd69d0b39a1",
111
+ "metadata": {},
112
+ "outputs": [],
113
+ "source": [
114
+ "patient_cohort_candidates.info()"
115
+ ]
116
+ },
117
+ {
118
+ "cell_type": "code",
119
+ "execution_count": null,
120
+ "id": "d30bf018-e135-40be-b636-0ba17acf8e61",
121
+ "metadata": {},
122
+ "outputs": [],
123
+ "source": [
124
+ "%%capture\n",
125
+ "output_list = []\n",
126
+ "batch_list = []\n",
127
+ "\n",
128
+ "num_in_batch = 0\n",
129
+ "\n",
130
+ "for i in range(0, patient_cohort_candidates.shape[0]):\n",
131
+ " \n",
132
+ " batch_list.append(patient_cohort_candidates.iloc[[i]])\n",
133
+ " num_in_batch += 1\n",
134
+ " \n",
135
+ " if (num_in_batch == 500) or (i == (patient_cohort_candidates.shape[0] - 1)):\n",
136
+ "\n",
137
+ " output = pd.concat(batch_list, axis=0)\n",
138
+ " _, output['llama_response'], output['eligibility_result'] = ask_about_trials_loosely(output['patient_summary'].tolist(), output['this_space'].astype(str).tolist(), llama)\n",
139
+ "\n",
140
+ " output_list.append(output)\n",
141
+ " num_in_batch = 0\n",
142
+ " batch_list = []\n",
143
+ " \n",
144
+ " if (len(output_list) > 0 and (i % 500 == 0)) or (i == (patient_cohort_candidates.shape[0] - 1)):\n",
145
+ " output_file = pd.concat(output_list, axis=0)\n",
146
+ " output_file.to_csv('top_ten_cohorts_checked_synthetic_round2.csv')\n"
147
+ ]
148
+ },
149
+ {
150
+ "cell_type": "code",
151
+ "execution_count": null,
152
+ "id": "91534c0e-4873-4eda-9a69-53660a84b4df",
153
+ "metadata": {},
154
+ "outputs": [],
155
+ "source": []
156
+ },
157
+ {
158
+ "cell_type": "code",
159
+ "execution_count": null,
160
+ "id": "eaebffcc-4b62-4ab6-a077-69a6e4340773",
161
+ "metadata": {},
162
+ "outputs": [],
163
+ "source": []
164
+ }
165
+ ],
166
+ "metadata": {
167
+ "kernelspec": {
168
+ "display_name": "Python 3 (ipykernel)",
169
+ "language": "python",
170
+ "name": "python3"
171
+ },
172
+ "language_info": {
173
+ "codemirror_mode": {
174
+ "name": "ipython",
175
+ "version": 3
176
+ },
177
+ "file_extension": ".py",
178
+ "mimetype": "text/x-python",
179
+ "name": "python",
180
+ "nbconvert_exporter": "python",
181
+ "pygments_lexer": "ipython3",
182
+ "version": "3.9.18"
183
+ }
184
+ },
185
+ "nbformat": 4,
186
+ "nbformat_minor": 5
187
+ }
9a_make_top_20_patients_llama_check_list_round2.ipynb ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "07d99d54-84c8-4531-8951-133dfaf64c1e",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import pandas as pd\n",
11
+ "import numpy as np\n",
12
+ "import os\n",
13
+ "os.environ['CUDA_VISIBLE_DEVICES'] = '1'\n",
14
+ "from sentence_transformers import SentenceTransformer, InputExample, losses\n",
15
+ "from torch.utils.data import DataLoader\n",
16
+ "import torch.nn.functional as F\n",
17
+ "import torch\n",
18
+ "from sklearn.metrics import roc_auc_score"
19
+ ]
20
+ },
21
+ {
22
+ "cell_type": "code",
23
+ "execution_count": null,
24
+ "id": "b7a22e15-81dd-4ab4-b7df-1d8da8992685",
25
+ "metadata": {},
26
+ "outputs": [],
27
+ "source": []
28
+ },
29
+ {
30
+ "cell_type": "code",
31
+ "execution_count": null,
32
+ "id": "3553d4f3-c23c-4d21-94ad-da2bcd31a63d",
33
+ "metadata": {},
34
+ "outputs": [],
35
+ "source": [
36
+ "test_spaces = pd.read_csv('test_spaces.csv')\n",
37
+ "test_spaces.info()"
38
+ ]
39
+ },
40
+ {
41
+ "cell_type": "code",
42
+ "execution_count": null,
43
+ "id": "e4757eb7-4a69-4578-9020-075e960817ce",
44
+ "metadata": {},
45
+ "outputs": [],
46
+ "source": [
47
+ "import pandas as pd\n",
48
+ "dfci_trials = pd.read_csv(\"../v7/space_specific_eligibility_checks_11-6-24.csv\")\n",
49
+ "# this dfci_trials file is not provided publicly, since it contains PHI, so this cell and dependent cells will not run out of the box\n",
50
+ "\n",
51
+ "other_trials = pd.read_csv('ctgov_all_trials_trial_space_lineitems_10-31-24.csv')\n",
52
+ "other_trials = other_trials[~other_trials.nct_id.isin(test_spaces.nct_id)]\n",
53
+ "other_trials = other_trials[~other_trials.nct_id.isin(dfci_trials.nct_id)]\n",
54
+ "\n",
55
+ "unique_trials = other_trials.groupby('nct_id').first().reset_index()[['nct_id', 'this_space']]\n",
56
+ "unique_trials.shape[0]\n",
57
+ "\n",
58
+ "unique_trial_sample = unique_trials.nct_id.sample(n=500, random_state=42)\n",
59
+ "\n",
60
+ "valid_spaces = unique_trials[unique_trials.nct_id.isin(unique_trial_sample)]\n",
61
+ "\n",
62
+ "valid_spaces.to_csv('valid_spaces.csv')\n",
63
+ "\n"
64
+ ]
65
+ },
66
+ {
67
+ "cell_type": "code",
68
+ "execution_count": null,
69
+ "id": "8d6b5654-2f5f-4caa-82ff-b02a9007dc49",
70
+ "metadata": {},
71
+ "outputs": [],
72
+ "source": [
73
+ "train_spaces = unique_trials[~unique_trials.nct_id.isin(valid_spaces.nct_id)]"
74
+ ]
75
+ },
76
+ {
77
+ "cell_type": "code",
78
+ "execution_count": null,
79
+ "id": "01e0eb5b-9282-48b8-9917-3147fbf25730",
80
+ "metadata": {},
81
+ "outputs": [],
82
+ "source": [
83
+ "train_spaces.nct_id.isin(valid_spaces.nct_id).value_counts()"
84
+ ]
85
+ },
86
+ {
87
+ "cell_type": "code",
88
+ "execution_count": null,
89
+ "id": "ed9139ac-a5f2-4791-a3a4-fd2127b7af0c",
90
+ "metadata": {},
91
+ "outputs": [],
92
+ "source": [
93
+ "train_spaces.info()"
94
+ ]
95
+ },
96
+ {
97
+ "cell_type": "code",
98
+ "execution_count": null,
99
+ "id": "cfa78d0d-3de1-48f7-a01e-f7d81a6a5b3f",
100
+ "metadata": {},
101
+ "outputs": [],
102
+ "source": [
103
+ "patients = pd.read_parquet('synthetic_pt_summaries_11-23-24.parquet')\n",
104
+ "patients = patients[patients.split == 'train'][['patient_summary','split']]"
105
+ ]
106
+ },
107
+ {
108
+ "cell_type": "code",
109
+ "execution_count": null,
110
+ "id": "fcba54bc-aba3-4d02-b874-c8275badf015",
111
+ "metadata": {},
112
+ "outputs": [],
113
+ "source": [
114
+ "patients.info()"
115
+ ]
116
+ },
117
+ {
118
+ "cell_type": "code",
119
+ "execution_count": null,
120
+ "id": "399e05a5-7a58-4362-a002-0bce62a348ac",
121
+ "metadata": {},
122
+ "outputs": [],
123
+ "source": [
124
+ "train_unique_patient_summaries = patients.patient_summary.unique().tolist()\n",
125
+ "print(len(train_unique_patient_summaries))\n",
126
+ "train_unique_spaces = train_spaces.this_space.unique().tolist()\n",
127
+ "print(len(train_unique_spaces))"
128
+ ]
129
+ },
130
+ {
131
+ "cell_type": "code",
132
+ "execution_count": null,
133
+ "id": "c37f1137-41cf-4a6d-8099-7bed5838a1ee",
134
+ "metadata": {},
135
+ "outputs": [],
136
+ "source": []
137
+ },
138
+ {
139
+ "cell_type": "code",
140
+ "execution_count": null,
141
+ "id": "06e7e0f2-5382-4507-908e-5ccf92b4beae",
142
+ "metadata": {},
143
+ "outputs": [],
144
+ "source": [
145
+ "embedding_model = SentenceTransformer('reranker_round1.model', trust_remote_code=True, device='cuda')"
146
+ ]
147
+ },
148
+ {
149
+ "cell_type": "code",
150
+ "execution_count": null,
151
+ "id": "77659362-cf51-4906-9e12-5d1e55440b25",
152
+ "metadata": {},
153
+ "outputs": [],
154
+ "source": [
155
+ "with torch.no_grad():\n",
156
+ " train_unique_patient_embeddings = embedding_model.encode(train_unique_patient_summaries, convert_to_tensor=True, prompt_name = \"s2s_query\")"
157
+ ]
158
+ },
159
+ {
160
+ "cell_type": "code",
161
+ "execution_count": null,
162
+ "id": "a0bb8926-3fa8-49bc-b507-6cc36d998600",
163
+ "metadata": {},
164
+ "outputs": [],
165
+ "source": [
166
+ "with torch.no_grad():\n",
167
+ " train_unique_space_embeddings = embedding_model.encode(train_unique_spaces, convert_to_tensor=True, prompt_name = \"s2s_query\")"
168
+ ]
169
+ },
170
+ {
171
+ "cell_type": "code",
172
+ "execution_count": null,
173
+ "id": "bf259f29-4934-41e8-9b3a-87f09d0ef52e",
174
+ "metadata": {},
175
+ "outputs": [],
176
+ "source": [
177
+ "output_list = []\n",
178
+ "train_unique_patient_series = pd.Series(train_unique_patient_summaries)\n",
179
+ "for i, space_summary in enumerate(train_unique_spaces):\n",
180
+ " space_embedding = train_unique_space_embeddings[i, :]\n",
181
+ " similarities = F.cosine_similarity(space_embedding, train_unique_patient_embeddings)\n",
182
+ " sorted_similarities, sorted_indices = torch.sort(similarities, descending=True)\n",
183
+ " relevant_patients = train_unique_patient_series.iloc[sorted_indices[0:20].cpu().numpy()]\n",
184
+ " output = pd.DataFrame({'space_summary':space_summary, 'this_patient':relevant_patients})\n",
185
+ " output_list.append(output)\n",
186
+ "\n",
187
+ "train_output = pd.concat(output_list, axis=0).reset_index(drop=True)\n",
188
+ "train_output['space_summary'] = train_output.space_summary.str.strip()\n",
189
+ "train_output['split'] = 'train'\n"
190
+ ]
191
+ },
192
+ {
193
+ "cell_type": "code",
194
+ "execution_count": null,
195
+ "id": "52c926cc-d481-4c43-8288-7d1620c1f06f",
196
+ "metadata": {},
197
+ "outputs": [],
198
+ "source": [
199
+ "train_output.to_csv('top_twenty_patients_tocheck_synthetic_round2.csv')"
200
+ ]
201
+ },
202
+ {
203
+ "cell_type": "code",
204
+ "execution_count": null,
205
+ "id": "0e0e18c1-bfb6-4c20-9aa4-257f0fb0424c",
206
+ "metadata": {},
207
+ "outputs": [],
208
+ "source": []
209
+ }
210
+ ],
211
+ "metadata": {
212
+ "kernelspec": {
213
+ "display_name": "Python 3 (ipykernel)",
214
+ "language": "python",
215
+ "name": "python3"
216
+ },
217
+ "language_info": {
218
+ "codemirror_mode": {
219
+ "name": "ipython",
220
+ "version": 3
221
+ },
222
+ "file_extension": ".py",
223
+ "mimetype": "text/x-python",
224
+ "name": "python",
225
+ "nbconvert_exporter": "python",
226
+ "pygments_lexer": "ipython3",
227
+ "version": "3.9.18"
228
+ }
229
+ },
230
+ "nbformat": 4,
231
+ "nbformat_minor": 5
232
+ }
9b_check_top20_patients_synthetic_round2.ipynb ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "a3d6ff53-2176-44aa-8590-ec0aa301342d",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "from vllm import LLM, SamplingParams\n",
11
+ "import pandas as pd\n",
12
+ "import numpy as np\n",
13
+ "import torch.nn.functional as F\n",
14
+ "import torch\n",
15
+ "from transformers import AutoTokenizer\n",
16
+ "from transformers import AutoModelForCausalLM\n",
17
+ "import re\n",
18
+ "import os\n",
19
+ "#os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'\n",
20
+ "\n"
21
+ ]
22
+ },
23
+ {
24
+ "cell_type": "code",
25
+ "execution_count": null,
26
+ "id": "62669512-19e7-43cd-a518-4572eea700af",
27
+ "metadata": {
28
+ "scrolled": true
29
+ },
30
+ "outputs": [],
31
+ "source": [
32
+ "llama = LLM(model='hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4', tensor_parallel_size = 4, \n",
33
+ " gpu_memory_utilization=0.20,\n",
34
+ " download_dir = \"../../\", max_model_len=5000)"
35
+ ]
36
+ },
37
+ {
38
+ "cell_type": "code",
39
+ "execution_count": null,
40
+ "id": "16bca2af-0cf4-41f2-ae28-d2c669a1af21",
41
+ "metadata": {},
42
+ "outputs": [],
43
+ "source": [
44
+ "def ask_about_trials_loosely(patient_summaries, trial_summaries, llama_model):\n",
45
+ "\n",
46
+ " tokenizer = llama_model.get_tokenizer()\n",
47
+ "\n",
48
+ " prompts = []\n",
49
+ "\n",
50
+ " for patient_summary, trial_summary in zip(patient_summaries, trial_summaries):\n",
51
+ " messages = [{'role':'system', 'content': \"\"\"You are a brilliant oncologist with encyclopedic knowledge about cancer and its treatment. \n",
52
+ " Your job is to evaluate whether a given clinical trial is a reasonable consideration for a patient, given a clinical trial summary and a patient summary.\\n\"\"\"}, \n",
53
+ " {'role':'user', 'content': \"Here is a summary of the clinical trial:\\n\" + trial_summary + \"\\nHere is a summary of the patient:\\n\" + patient_summary + \"\"\"\n",
54
+ "Base your judgment on whether the patient generally fits the cancer type(s), cancer burden, prior treatment(s), and biomarker criteria specified for the trial.\n",
55
+ "You do not have to determine if the patient is actually eligible; instead please just evaluate whether it is reasonable for the trial to be considered further by the patient's oncologist.\n",
56
+ "Some trials have biomarker requirements that are not assessed until formal eligibility screening begins; please ignore these requirements.\n",
57
+ "Reason step by step, then answer the question \"Is this trial a reasonable consideration for this patient?\" with a one-word \"Yes!\" or \"No!\" answer.\n",
58
+ "Make sure to include the exclamation point in your final one-word answer.\"\"\"}]\n",
59
+ "\n",
60
+ " \n",
61
+ " prompt = tokenizer.apply_chat_template(conversation=messages, add_generation_prompt=True, tokenize=False)\n",
62
+ " prompts.append(prompt)\n",
63
+ " \n",
64
+ " responses = llama_model.generate(\n",
65
+ " prompts, \n",
66
+ " SamplingParams(\n",
67
+ " temperature=0.0,\n",
68
+ " top_p=0.2,\n",
69
+ " max_tokens=2048,\n",
70
+ " repetition_penalty=1.2,\n",
71
+ " stop_token_ids=[tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids(\"<|eot_id|>\")], # KEYPOINT HERE\n",
72
+ " ))\n",
73
+ "\n",
74
+ " response_texts = [x.outputs[0].text for x in responses]\n",
75
+ "\n",
76
+ " eligibility_results = []\n",
77
+ "\n",
78
+ " for response_text in response_texts:\n",
79
+ " if (\"Yes!\" in response_text) or (\"YES!\" in response_text):\n",
80
+ " eligibility_results.append(1.0)\n",
81
+ " else:\n",
82
+ " eligibility_results.append(0.0)\n",
83
+ " \n",
84
+ " return responses, response_texts, eligibility_results\n",
85
+ " \n",
86
+ "\n",
87
+ " \n",
88
+ "\n"
89
+ ]
90
+ },
91
+ {
92
+ "cell_type": "code",
93
+ "execution_count": null,
94
+ "id": "2ce4cce6-3833-451a-98c6-d7f4c7b948c6",
95
+ "metadata": {},
96
+ "outputs": [],
97
+ "source": [
98
+ "patient_cohort_candidates = pd.read_csv('top_twenty_patients_tocheck_synthetic_round2.csv')"
99
+ ]
100
+ },
101
+ {
102
+ "cell_type": "code",
103
+ "execution_count": null,
104
+ "id": "dbb846d2-20e3-4361-b69b-82aa31c1f789",
105
+ "metadata": {},
106
+ "outputs": [],
107
+ "source": [
108
+ "patient_cohort_candidates = patient_cohort_candidates.rename(columns={'this_patient':'patient_summary', 'space_summary':'this_space'})"
109
+ ]
110
+ },
111
+ {
112
+ "cell_type": "code",
113
+ "execution_count": null,
114
+ "id": "1fb1aa2f-6c28-4a4e-a4e1-bfd69d0b39a1",
115
+ "metadata": {},
116
+ "outputs": [],
117
+ "source": [
118
+ "patient_cohort_candidates.info()"
119
+ ]
120
+ },
121
+ {
122
+ "cell_type": "code",
123
+ "execution_count": null,
124
+ "id": "d30bf018-e135-40be-b636-0ba17acf8e61",
125
+ "metadata": {},
126
+ "outputs": [],
127
+ "source": [
128
+ "%%capture\n",
129
+ "output_list = []\n",
130
+ "batch_list = []\n",
131
+ "\n",
132
+ "num_in_batch = 0\n",
133
+ "\n",
134
+ "for i in range(0, patient_cohort_candidates.shape[0]):\n",
135
+ " \n",
136
+ " batch_list.append(patient_cohort_candidates.iloc[[i]])\n",
137
+ " num_in_batch += 1\n",
138
+ " \n",
139
+ " if (num_in_batch == 500) or (i == (patient_cohort_candidates.shape[0] - 1)):\n",
140
+ "\n",
141
+ " output = pd.concat(batch_list, axis=0)\n",
142
+ " _, output['llama_response'], output['eligibility_result'] = ask_about_trials_loosely(output['patient_summary'].tolist(), output['this_space'].astype(str).tolist(), llama)\n",
143
+ "\n",
144
+ " output_list.append(output)\n",
145
+ " num_in_batch = 0\n",
146
+ " batch_list = []\n",
147
+ " \n",
148
+ " if (len(output_list) > 0 and (i % 500 == 0)) or (i == (patient_cohort_candidates.shape[0] - 1)):\n",
149
+ " output_file = pd.concat(output_list, axis=0)\n",
150
+ " output_file.to_csv('top_twenty_patients_checked_synthetic_round2.csv')\n"
151
+ ]
152
+ },
153
+ {
154
+ "cell_type": "code",
155
+ "execution_count": null,
156
+ "id": "91534c0e-4873-4eda-9a69-53660a84b4df",
157
+ "metadata": {},
158
+ "outputs": [],
159
+ "source": []
160
+ },
161
+ {
162
+ "cell_type": "code",
163
+ "execution_count": null,
164
+ "id": "eaebffcc-4b62-4ab6-a077-69a6e4340773",
165
+ "metadata": {},
166
+ "outputs": [],
167
+ "source": []
168
+ }
169
+ ],
170
+ "metadata": {
171
+ "kernelspec": {
172
+ "display_name": "Python 3 (ipykernel)",
173
+ "language": "python",
174
+ "name": "python3"
175
+ },
176
+ "language_info": {
177
+ "codemirror_mode": {
178
+ "name": "ipython",
179
+ "version": 3
180
+ },
181
+ "file_extension": ".py",
182
+ "mimetype": "text/x-python",
183
+ "name": "python",
184
+ "nbconvert_exporter": "python",
185
+ "pygments_lexer": "ipython3",
186
+ "version": "3.9.18"
187
+ }
188
+ },
189
+ "nbformat": 4,
190
+ "nbformat_minor": 5
191
+ }