# Spaces: Running on CPU Upgrade
import gradio as gr | |
import pandas as pd | |
import torch | |
import torch.nn.functional as F | |
import tempfile | |
from sentence_transformers import SentenceTransformer | |
from safetensors import safe_open | |
from transformers import pipeline, AutoTokenizer | |
# ---------------------------------------------------------------------------
# Resources loaded once at import time and shared by every request.
# ---------------------------------------------------------------------------

# One row per trial "space" line item; the matcher reads the nct_id, title,
# brief_summary, eligibility_criteria and this_space columns.
trial_spaces = pd.read_csv('ctgov_all_trials_trial_space_lineitems_10-31-24.csv')

# Embeds free-text patient summaries for similarity search against the spaces.
embedding_model = SentenceTransformer(
    'ksg-dfci/TrialSpace',
    trust_remote_code=True,
)

# Precomputed space embeddings. The matcher indexes these and `trial_spaces`
# with the same positions, so row order is assumed to line up -- TODO confirm
# against the script that produced the .safetensors file.
with safe_open("trial_space_embeddings.safetensors", framework="pt") as embeddings_file:
    trial_space_embeddings = embeddings_file.get_tensor("space_embeddings")

# Checker model: classifies a combined (trial space, patient summary) text;
# the matcher keeps only inputs labelled 'POSITIVE'. It is paired with the
# stock roberta-large tokenizer, and every input is truncated / padded to
# exactly 512 tokens.
tokenizer = AutoTokenizer.from_pretrained("roberta-large")
checker_pipe = pipeline(
    task='text-classification',
    model='ksg-dfci/TrialChecker',
    tokenizer=tokenizer,
    truncation=True,
    padding='max_length',
    max_length=512,
)
def _parse_max_results(max_results_str, default=10, lower=1, upper=50):
    """Convert the user's "max results" entry to an int clamped to [lower, upper].

    Anything ``int()`` cannot parse (e.g. "", "abc", "10.5", or None) yields
    ``default`` instead.
    """
    try:
        value = int(max_results_str)
    except (TypeError, ValueError):
        return default
    return max(lower, min(upper, value))


def _retrieve_top_spaces(patient_summary, max_results):
    """Return the ``trial_spaces`` rows most cosine-similar to the summary, best first."""
    patient_embedding = embedding_model.encode([patient_summary], convert_to_tensor=True)
    # The query comes back on the embedding model's device, while the stored
    # embeddings stay on whatever device they were loaded to. Move the small
    # query tensor so the two always share a device (a no-op when they do).
    patient_embedding = patient_embedding.to(trial_space_embeddings.device)
    similarities = F.cosine_similarity(patient_embedding, trial_space_embeddings)
    # topk avoids sorting every similarity just to keep a handful; cap k so a
    # small embedding table cannot make topk raise.
    k = min(max_results, similarities.numel())
    top_indices = torch.topk(similarities, k).indices.cpu().numpy()
    return trial_spaces.iloc[top_indices]


def match_clinical_trials_dropdown(patient_summary: str, max_results_str: str):
    """Find candidate trials for a patient summary and populate the trial dropdown.

    Steps:
      1. Parse ``max_results_str`` into an int clamped to [1, 50] (invalid -> 10).
      2. Embed the summary and take the ``max_results`` most cosine-similar
         trial spaces.
      3. Score every (trial space, patient summary) pair with the checker pipeline.
      4. Keep only the pairs the checker labels 'POSITIVE'.

    Args:
        patient_summary: Free-text patient description entered by the user.
        max_results_str: Requested number of candidates to retrieve before checking.

    Returns:
        A tuple ``(dropdown, out_df)``. ``dropdown`` is a ``gr.Dropdown`` whose
        choices are "<rank>. <nct_id> - <title>" strings, with the first choice
        pre-selected, or no choices and value None when nothing passed the checker.
        ``out_df`` is the DataFrame of POSITIVE matches, whose row order matches
        the dropdown ranks (rank = row position + 1).
    """
    max_results = _parse_max_results(max_results_str)
    top_spaces = _retrieve_top_spaces(patient_summary, max_results)

    analysis = pd.DataFrame({
        'patient_summary_query': patient_summary,
        'nct_id': top_spaces.nct_id,
        'trial_title': top_spaces.title,
        'trial_brief_summary': top_spaces.brief_summary,
        'trial_eligibility_criteria': top_spaces.eligibility_criteria,
        'this_space': top_spaces.this_space,
    }).reset_index(drop=True)

    # Checker input: trial-space text followed by the patient summary. The
    # separator text is the checker's prompt format and is kept verbatim.
    analysis['pt_trial_pair'] = (
        analysis['this_space']
        + "\nNow here is the patient summary:"
        + analysis['patient_summary_query']
    )
    pairs = analysis['pt_trial_pair'].tolist()
    classifier_results = checker_pipe(pairs) if pairs else []
    analysis['trial_checker_result'] = [x['label'] for x in classifier_results]
    analysis['trial_checker_score'] = [x['score'] for x in classifier_results]

    # Keep POSITIVE matches only; re-index so row position == dropdown rank - 1.
    out_df = analysis.loc[
        analysis['trial_checker_result'] == 'POSITIVE',
        [
            'patient_summary_query',
            'nct_id',
            'trial_title',
            'trial_brief_summary',
            'trial_eligibility_criteria',
            'this_space',
            'trial_checker_result',
            'trial_checker_score',
        ],
    ].reset_index(drop=True)

    # e.g. "1. NCT001 - Some Title"; show_selected_trial parses the rank back out.
    dropdown_options = [
        f"{rank}. {nct_id} - {title}"
        for rank, (nct_id, title) in enumerate(
            zip(out_df['nct_id'], out_df['trial_title']), start=1
        )
    ]
    return (
        gr.Dropdown(
            choices=dropdown_options,
            interactive=True,
            value=dropdown_options[0] if dropdown_options else None,
        ),
        out_df,
    )
def show_selected_trial(selected_option: str, df: pd.DataFrame):
    """Build the detail text for the trial picked in the dropdown.

    Args:
        selected_option: Dropdown label of the form "<rank>. <nct_id> - <title>",
            where <rank> is the 1-based row position in ``df``. A falsy value
            means nothing is selected.
        df: Results table produced by the matcher, or None before any search
            has stored results.

    Returns:
        - ``""`` when nothing is selected.
        - "No data found for the selected trial." when there are no results,
          the rank cannot be parsed, or it does not address a row of ``df``.
        - Otherwise, a multi-line summary of that row.
    """
    if not selected_option:
        return ""
    not_found = "No data found for the selected trial."
    # The results state holds None until the first search has completed.
    if df is None:
        return not_found
    # The rank is everything before the first "." in the label.
    rank_str, _, _ = selected_option.partition(".")
    try:
        chosen_index = int(rank_str.strip()) - 1
    except ValueError:
        return not_found
    if not 0 <= chosen_index < len(df):
        return not_found
    record = df.iloc[chosen_index].to_dict()
    return (
        f"Patient Summary Query: {record['patient_summary_query']}\n\n"
        f"NCT ID: {record['nct_id']}\n"
        f"Trial Title: {record['trial_title']}\n\n"
        f"Trial Space: {record['this_space']}\n\n"
        f"Trial Checker Result: {record['trial_checker_result']}\n"
        f"Trial Checker Score: {record['trial_checker_score']}\n\n"
        f"Brief Summary: {record['trial_brief_summary']}\n\n"
        f"Full Eligibility Criteria: {record['trial_eligibility_criteria']}\n\n"
    )
def export_results(df: pd.DataFrame):
    """Save the results table to a temporary CSV file so Gradio can offer it for download.

    Args:
        df: Results table from the matcher, or None if no search has stored
            results yet.

    Returns:
        Path of the written CSV, or None when there is nothing to export. The
        file is created with ``delete=False``, so it outlives this call and can
        be served.
    """
    if df is None:
        return None
    # Only reserve a unique path here and close the handle immediately. Leaving
    # it open leaked one descriptor per export. It also made the by-name write
    # below fail on platforms that refuse to reopen a file still held open.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".csv") as temp:
        path = temp.name
    df.to_csv(path, index=False)
    return path
# A little CSS for the input boxes: enlarges the textarea of the component
# whose elem_id is "input_box" (the patient summary box below).
custom_css = """
#input_box textarea {
    width: 600px !important;
    height: 250px !important;
}
"""
with gr.Blocks(css=custom_css) as demo:
    # Intro text and disclaimers
    gr.HTML("""
    <h3>Demonstration version of clinical trial search based on MatchMiner-AI</h3>
    <p>Based on clinicaltrials.gov cancer trials export 10/31/24.</p>
    <p>Queries take approximately 30 seconds to run per ten results returned,
    since demo is running on a small CPU instance.</p>
    <p>Disclaimers:</p>
    <p>1. Not a clinical decision support tool</p>
    <p>2. AI-extracted trial "spaces" and candidate matches may contain errors</p>
    <p>3. Will not necessarily return all trials from clinicaltrials.gov that match a given query</p>
    <p>4. Under active development; interface and underlying models will change</p>
    <p>5. For better results, spell out cancer types (eg, enter "acute myeloid leukemia" rather than "AML")</p>
    """)
    # Textbox for patient summary (styled by custom_css via elem_id)
    patient_summary_input = gr.Textbox(
        label="Enter Patient Summary",
        elem_id="input_box",
        value="Cancer type: Non-small cell lung cancer. Histology: Adenocarcinoma. Extent of disease: Metastatic. Prior treatment: Pembrolizumab. Biomarkers: PD-L1 high, KRAS G12C mutant."
    )
    # Textbox for max results; parsed and clamped by the matcher
    max_results_input = gr.Textbox(
        label="Enter the maximum number of results to return (1-50)",
        value="10"  # default
    )
    # Button to run the matching
    submit_btn = gr.Button("Find Matches")
    # Holds the latest results DataFrame for the details view and CSV export;
    # None until the first search completes.
    results_state = gr.State()
    # Dropdown (initially empty)
    trial_dropdown = gr.Dropdown(
        label="Select a Trial",
        choices=[],
        value=None,
        interactive=True
    )
    # Textbox for showing details of the selected trial
    trial_details_box = gr.Textbox(
        label="Selected Trial Details",
        lines=12,
        interactive=False
    )
    # Export button + file
    export_btn = gr.Button("Export Results")
    download_file = gr.File()
    # 1) "Find Matches" => updates the dropdown choices and the state, then
    #    refreshes the details box. The explicit .then() step is needed because
    #    a .change listener only runs when the dropdown value actually changes.
    #    A new search that pre-selects the same label as before (e.g. the same
    #    top trial) would otherwise leave the previous query's details showing.
    submit_btn.click(
        fn=match_clinical_trials_dropdown,
        inputs=[patient_summary_input, max_results_input],
        outputs=[trial_dropdown, results_state]
    ).then(
        fn=show_selected_trial,
        inputs=[trial_dropdown, results_state],
        outputs=trial_details_box
    )
    # 2) Selecting from the dropdown => shows more info
    trial_dropdown.change(
        fn=show_selected_trial,
        inputs=[trial_dropdown, results_state],
        outputs=trial_details_box
    )
    # 3) Export => CSV
    export_btn.click(
        fn=export_results,
        inputs=results_state,
        outputs=download_file
    )
# Enable queue so "Processing..." is shown if logic is slow
demo.queue()
if __name__ == "__main__":
    demo.launch()