kenlkehl committed on
Commit
787eac9
1 Parent(s): 61826b1

Upload 6 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ ctgov_all_trials_trial_space_lineitems_10-31-24.csv filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -8,7 +8,7 @@ sdk_version: 5.9.1
8
  app_file: app.py
9
  pinned: false
10
  license: cc-by-nc-2.0
11
- short_description: Gradio version of cancer trial search engine
12
  ---
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
8
  app_file: app.py
9
  pinned: false
10
  license: cc-by-nc-2.0
11
+ short_description: For research use only. Not for clinical decision support.
12
  ---
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from sentence_transformers import SentenceTransformer
6
+ from safetensors import safe_open
7
+ from transformers import pipeline, AutoTokenizer
8
+
9
# Load trial spaces data (one row per trial "space"; the columns read later are
# this_space, nct_id, title, brief_summary and eligibility_criteria).
trial_spaces = pd.read_csv('ctgov_all_trials_trial_space_lineitems_10-31-24.csv')

# Pick the compute device once so that both models and the precomputed
# embeddings end up on the same device. Hard-coding device 0 makes the app
# fail at import time on CPU-only hosts.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

# Load embedding model
embedding_model = SentenceTransformer(
    'ksg-dfci/trialspace',
    trust_remote_code=True,
    device=DEVICE,
)

# Load precomputed trial space embeddings directly onto the selected device
with safe_open("trial_space_embeddings.safetensors", framework="pt", device=DEVICE) as f:
    trial_space_embeddings = f.get_tensor("space_embeddings")

# Load checker pipeline. Inputs are truncated/padded to 512 tokens.
tokenizer = AutoTokenizer.from_pretrained("roberta-large")
checker_pipe = pipeline(
    'text-classification',
    'ksg-dfci/TrialChecker',
    tokenizer=tokenizer,
    truncation=True,
    padding='max_length',
    max_length=512,
    device=DEVICE,
)
23
+
24
+
25
+ import gradio as gr
26
+ import pandas as pd
27
+ import torch
28
+ import torch.nn.functional as F
29
+ from sentence_transformers import SentenceTransformer
30
+ from safetensors import safe_open
31
+ from transformers import pipeline, AutoTokenizer
32
+
33
+ # We assume the following objects have already been loaded:
34
+ # trial_spaces (DataFrame), embedding_model (SentenceTransformer),
35
+ # trial_space_embeddings (torch.tensor), checker_pipe (transformers pipeline)
36
+
37
def match_clinical_trials(patient_summary: str, top_k: int = 10) -> pd.DataFrame:
    """Retrieve the trial spaces most similar to a patient summary and score them.

    The summary is embedded with ``embedding_model`` and compared by cosine
    similarity against ``trial_space_embeddings``. The ``top_k`` best rows of
    ``trial_spaces`` are then paired with the summary and passed through
    ``checker_pipe``.

    Args:
        patient_summary: Free-text patient summary entered by the user.
        top_k: Maximum number of trial spaces to return (default 10, matching
            the previous hard-coded behavior).

    Returns:
        A DataFrame with the columns ``nct_id``, ``trial_title``,
        ``trial_brief_summary``, ``trial_eligibility_criteria``,
        ``trial_checker_result`` (pipeline label) and ``trial_checker_score``
        (pipeline score), ordered by decreasing similarity. The frame is empty
        when the summary is blank or ``top_k`` is not positive.
    """
    output_columns = [
        'nct_id',
        'trial_title',
        'trial_brief_summary',
        'trial_eligibility_criteria',
        'trial_checker_result',
        'trial_checker_score',
    ]

    # Nothing to match: skip both models instead of ranking against an empty query.
    if not patient_summary or not patient_summary.strip():
        return pd.DataFrame(columns=output_columns)

    # Encode patient summary
    patient_embedding = embedding_model.encode([patient_summary], convert_to_tensor=True)
    # The precomputed embeddings may live on a different device than the one
    # the encoder produced its output on; align them before comparing.
    patient_embedding = patient_embedding.to(trial_space_embeddings.device)

    # Compute similarities (one value per precomputed trial space)
    similarities = F.cosine_similarity(patient_embedding, trial_space_embeddings)

    # Pull the top-k without sorting the whole similarity vector; clamp k so
    # that a small embedding table cannot make topk raise.
    k = min(top_k, similarities.shape[0])
    if k <= 0:
        return pd.DataFrame(columns=output_columns)
    top_indices = torch.topk(similarities, k=k).indices.cpu().numpy()

    # Select the matching rows once and rename columns for display.
    matches = trial_spaces.iloc[top_indices]
    analysis = pd.DataFrame({
        'this_space': matches['this_space'],
        'nct_id': matches['nct_id'],
        'trial_title': matches['title'],
        'trial_brief_summary': matches['brief_summary'],
        'trial_eligibility_criteria': matches['eligibility_criteria'],
    }).reset_index(drop=True)

    # Text handed to the checker: trial space followed by the patient summary.
    analysis['pt_trial_pair'] = (
        analysis['this_space'] + "\nNow here is the patient summary:" + patient_summary
    )

    # Run checker pipeline
    classifier_results = checker_pipe(analysis['pt_trial_pair'].tolist())
    analysis['trial_checker_result'] = [x['label'] for x in classifier_results]
    analysis['trial_checker_score'] = [x['score'] for x in classifier_results]

    # Return a subset of columns that are most relevant
    return analysis[output_columns]
79
+
80
# Styling for the input box and the results table (targets the elem_id values
# "input_box" and "output_df" used by the UI components).
custom_css = """
#input_box textarea {
    width: 600px !important;
    height: 250px !important;
}

#output_df table {
    width: 100% !important;
    table-layout: auto !important;
    border-collapse: collapse !important;
}

#output_df table td, #output_df table th {
    min-width: 100px;
    overflow: hidden;
    text-overflow: ellipsis;
    white-space: nowrap;
    border: 1px solid #ccc;
    padding: 4px;
}
"""

# JavaScript for enabling colResizable on the results table.
# The table is rendered after page load, so the script polls every 500 ms until
# both the table and the jQuery plugin are available, then stops polling.
# Initialization also runs immediately when the document has already finished
# loading, because a DOMContentLoaded listener registered after that point
# would never fire.
js_script = """
<script src="https://code.jquery.com/jquery-3.6.0.min.js"></script>
<script src="https://cdn.jsdelivr.net/npm/colresizable@1.6.0/colResizable-1.6.min.js"></script>
<script>
(function() {
    function enableColumnResize() {
        var interval = setInterval(function() {
            var table = document.querySelector('#output_df table');
            if (table && typeof jQuery !== 'undefined' && typeof jQuery(table).colResizable === 'function') {
                jQuery('#output_df table').colResizable({liveDrag:true});
                clearInterval(interval);
            }
        }, 500);
    }
    if (document.readyState === 'loading') {
        document.addEventListener('DOMContentLoaded', enableColumnResize);
    } else {
        enableColumnResize();
    }
})();
</script>
"""
118
+
119
# Build the UI. The column-resize script is injected through `head=` because
# <script> tags placed inside a gr.HTML component are inserted as inert markup
# and never execute.
with gr.Blocks(css=custom_css, head=js_script) as demo:
    gr.HTML("<h3>Clinical Trial Matcher</h3>")
    patient_summary_input = gr.Textbox(label="Enter Patient Summary", elem_id="input_box")
    submit_btn = gr.Button("Find Matches")
    # Headers mirror the columns returned by match_clinical_trials.
    output_df = gr.DataFrame(
        headers=[
            "nct_id",
            "trial_title",
            "trial_brief_summary",
            "trial_eligibility_criteria",
            "trial_checker_result",
            "trial_checker_score",
        ],
        elem_id="output_df",
    )

    submit_btn.click(
        fn=match_clinical_trials,
        inputs=patient_summary_input,
        outputs=output_df,
    )

if __name__ == "__main__":
    demo.launch()
ctgov_all_trials_trial_space_lineitems_10-31-24.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5e1253f9d2545f6db6e1c2b66105cbb9d54d2409a3f5ac266da9e30faf9262de
3
+ size 404858338
gradio_app.ipynb ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "d31b58d0-132d-4a98-b199-c3b1d2ed9eb5",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stderr",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "/home/klkehl/miniconda3/envs/vllm/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
14
+ " from .autonotebook import tqdm as notebook_tqdm\n",
15
+ "Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████| 2/2 [00:03<00:00, 1.62s/it]\n"
16
+ ]
17
+ }
18
+ ],
19
+ "source": [
20
+ "import gradio as gr\n",
21
+ "import pandas as pd\n",
22
+ "import torch\n",
23
+ "import torch.nn.functional as F\n",
24
+ "from sentence_transformers import SentenceTransformer\n",
25
+ "from safetensors import safe_open\n",
26
+ "from transformers import pipeline, AutoTokenizer\n",
27
+ "\n",
28
+ "# Load trial spaces data\n",
29
+ "trial_spaces = pd.read_csv('ctgov_all_trials_trial_space_lineitems_10-31-24.csv')\n",
30
+ "\n",
31
+ "# Load embedding model\n",
32
+ "embedding_model = SentenceTransformer('reranker_round2.model', trust_remote_code=True, device='cuda')\n",
33
+ "\n",
34
+ "# Load precomputed trial space embeddings\n",
35
+ "with safe_open(\"trial_space_embeddings.safetensors\", framework=\"pt\", device=0) as f:\n",
36
+ " trial_space_embeddings = f.get_tensor(\"space_embeddings\")\n",
37
+ "\n",
38
+ "# Load checker pipeline\n",
39
+ "tokenizer = AutoTokenizer.from_pretrained(\"roberta-large\")\n",
40
+ "checker_pipe = pipeline('text-classification', './roberta-checker', tokenizer=tokenizer, \n",
41
+ " truncation=True, padding='max_length', max_length=512, device='cuda')\n"
42
+ ]
43
+ },
44
+ {
45
+ "cell_type": "code",
46
+ "execution_count": 11,
47
+ "id": "36d48a31-8514-4b0d-84a9-5fccc7ec7227",
48
+ "metadata": {},
49
+ "outputs": [
50
+ {
51
+ "name": "stdout",
52
+ "output_type": "stream",
53
+ "text": [
54
+ "Running on local URL: http://127.0.0.1:7860\n",
55
+ "\n",
56
+ "To create a public link, set `share=True` in `launch()`.\n"
57
+ ]
58
+ },
59
+ {
60
+ "data": {
61
+ "text/html": [
62
+ "<div><iframe src=\"http://127.0.0.1:7860/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
63
+ ],
64
+ "text/plain": [
65
+ "<IPython.core.display.HTML object>"
66
+ ]
67
+ },
68
+ "metadata": {},
69
+ "output_type": "display_data"
70
+ },
71
+ {
72
+ "data": {
73
+ "text/plain": []
74
+ },
75
+ "execution_count": 11,
76
+ "metadata": {},
77
+ "output_type": "execute_result"
78
+ }
79
+ ],
80
+ "source": [
81
+ "import gradio as gr\n",
82
+ "import pandas as pd\n",
83
+ "import torch\n",
84
+ "import torch.nn.functional as F\n",
85
+ "from sentence_transformers import SentenceTransformer\n",
86
+ "from safetensors import safe_open\n",
87
+ "from transformers import pipeline, AutoTokenizer\n",
88
+ "\n",
89
+ "# We assume the following objects have already been loaded:\n",
90
+ "# trial_spaces (DataFrame), embedding_model (SentenceTransformer),\n",
91
+ "# trial_space_embeddings (torch.tensor), checker_pipe (transformers pipeline)\n",
92
+ "\n",
93
+ "def match_clinical_trials(patient_summary: str):\n",
94
+ " # Encode patient summary\n",
95
+ " patient_embedding = embedding_model.encode([patient_summary], convert_to_tensor=True)\n",
96
+ " \n",
97
+ " # Compute similarities\n",
98
+ " similarities = F.cosine_similarity(patient_embedding, trial_space_embeddings)\n",
99
+ " \n",
100
+ " # Pull top 10\n",
101
+ " sorted_similarities, sorted_indices = torch.sort(similarities, descending=True)\n",
102
+ " top_indices = sorted_indices[0:10].cpu().numpy()\n",
103
+ " \n",
104
+ " relevant_spaces = trial_spaces.iloc[top_indices].this_space\n",
105
+ " relevant_nctid = trial_spaces.iloc[top_indices].nct_id\n",
106
+ " relevant_title = trial_spaces.iloc[top_indices].title\n",
107
+ " relevant_brief_summary = trial_spaces.iloc[top_indices].brief_summary\n",
108
+ " relevant_eligibility_criteria = trial_spaces.iloc[top_indices].eligibility_criteria\n",
109
+ "\n",
110
+ " analysis = pd.DataFrame({\n",
111
+ " 'patient_summary': patient_summary, \n",
112
+ " 'this_space': relevant_spaces,\n",
113
+ " 'nct_id': relevant_nctid, \n",
114
+ " 'trial_title': relevant_title,\n",
115
+ " 'trial_brief_summary': relevant_brief_summary, \n",
116
+ " 'trial_eligibility_criteria': relevant_eligibility_criteria\n",
117
+ " }).reset_index(drop=True)\n",
118
+ " \n",
119
+ " analysis['pt_trial_pair'] = analysis['this_space'] + \"\\nNow here is the patient summary:\" + analysis['patient_summary']\n",
120
+ " \n",
121
+ " # Run checker pipeline\n",
122
+ " classifier_results = checker_pipe(analysis.pt_trial_pair.tolist())\n",
123
+ " analysis['trial_checker_result'] = [x['label'] for x in classifier_results]\n",
124
+ " analysis['trial_checker_score'] = [x['score'] for x in classifier_results]\n",
125
+ " \n",
126
+ " # Return a subset of columns that are most relevant\n",
127
+ " return analysis[[\n",
128
+ " 'nct_id', \n",
129
+ " 'trial_title', \n",
130
+ " 'trial_brief_summary', \n",
131
+ " 'trial_eligibility_criteria', \n",
132
+ " 'trial_checker_result', \n",
133
+ " 'trial_checker_score'\n",
134
+ " ]]\n",
135
+ "\n",
136
+ "custom_css = \"\"\"\n",
137
+ "#input_box textarea {\n",
138
+ " width: 600px !important;\n",
139
+ " height: 250px !important;\n",
140
+ "}\n",
141
+ "\n",
142
+ "#output_df table {\n",
143
+ " width: 100% !important;\n",
144
+ " table-layout: auto !important;\n",
145
+ " border-collapse: collapse !important;\n",
146
+ "}\n",
147
+ "\n",
148
+ "#output_df table td, #output_df table th {\n",
149
+ " min-width: 100px;\n",
150
+ " overflow: hidden;\n",
151
+ " text-overflow: ellipsis;\n",
152
+ " white-space: nowrap;\n",
153
+ " border: 1px solid #ccc;\n",
154
+ " padding: 4px;\n",
155
+ "}\n",
156
+ "\"\"\"\n",
157
+ "\n",
158
+ "# JavaScript for enabling colResizable\n",
159
+ "js_script = \"\"\"\n",
160
+ "<script src=\"https://code.jquery.com/jquery-3.6.0.min.js\"></script>\n",
161
+ "<script src=\"https://cdn.jsdelivr.net/npm/[email protected]/colResizable-1.6.min.js\"></script>\n",
162
+ "<script>\n",
163
+ "document.addEventListener('DOMContentLoaded', function() {\n",
164
+ " var interval = setInterval(function() {\n",
165
+ " var table = document.querySelector('#output_df table');\n",
166
+ " if (table && typeof jQuery !== 'undefined' && typeof jQuery(table).colResizable === 'function') {\n",
167
+ " jQuery('#output_df table').colResizable({liveDrag:true});\n",
168
+ " clearInterval(interval);\n",
169
+ " }\n",
170
+ " }, 500);\n",
171
+ "});\n",
172
+ "</script>\n",
173
+ "\"\"\"\n",
174
+ "\n",
175
+ "with gr.Blocks(css=custom_css) as demo:\n",
176
+ " gr.HTML(\"<h3>Clinical Trial Matcher</h3>\")\n",
177
+ " patient_summary_input = gr.Textbox(label=\"Enter Patient Summary\", elem_id=\"input_box\")\n",
178
+ " submit_btn = gr.Button(\"Find Matches\")\n",
179
+ " output_df = gr.DataFrame(\n",
180
+ " headers=[\n",
181
+ " \"nct_id\", \n",
182
+ " \"trial_title\", \n",
183
+ " \"trial_brief_summary\", \n",
184
+ " \"trial_eligibility_criteria\", \n",
185
+ " \"trial_checker_result\", \n",
186
+ " \"trial_checker_score\"\n",
187
+ " ], \n",
188
+ " elem_id=\"output_df\"\n",
189
+ " )\n",
190
+ "\n",
191
+ " submit_btn.click(fn=match_clinical_trials, \n",
192
+ " inputs=patient_summary_input, \n",
193
+ " outputs=output_df)\n",
194
+ " \n",
195
+ " gr.HTML(js_script)\n",
196
+ "\n",
197
+ "demo.launch()\n"
198
+ ]
199
+ },
200
+ {
201
+ "cell_type": "code",
202
+ "execution_count": 10,
203
+ "id": "80ba3cd2-6a76-44d0-b5f4-6d3debd510ff",
204
+ "metadata": {},
205
+ "outputs": [
206
+ {
207
+ "name": "stdout",
208
+ "output_type": "stream",
209
+ "text": [
210
+ "Closing server running on port: 7860\n"
211
+ ]
212
+ }
213
+ ],
214
+ "source": [
215
+ "demo.close()"
216
+ ]
217
+ },
218
+ {
219
+ "cell_type": "code",
220
+ "execution_count": null,
221
+ "id": "5e43df71-6f06-48d2-8dce-b1f27ab40d6c",
222
+ "metadata": {},
223
+ "outputs": [],
224
+ "source": []
225
+ }
226
+ ],
227
+ "metadata": {
228
+ "kernelspec": {
229
+ "display_name": "Python 3 (ipykernel)",
230
+ "language": "python",
231
+ "name": "python3"
232
+ },
233
+ "language_info": {
234
+ "codemirror_mode": {
235
+ "name": "ipython",
236
+ "version": 3
237
+ },
238
+ "file_extension": ".py",
239
+ "mimetype": "text/x-python",
240
+ "name": "python",
241
+ "nbconvert_exporter": "python",
242
+ "pygments_lexer": "ipython3",
243
+ "version": "3.9.18"
244
+ }
245
+ },
246
+ "nbformat": 4,
247
+ "nbformat_minor": 5
248
+ }
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ gradio
2
+ pandas
3
+ torch
4
+ sentence_transformers
5
+ safetensors
6
+ transformers
trial_space_embeddings.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c03f700caf34d1c30fdf07718fa79b86f16601275fa8fb0372cdd8954bc8c9b2
3
+ size 156221536