# Uploaded to Hugging Face by kenlkehl ("Upload app.py", commit 1aa4ab6, verified).
import gradio as gr
import pandas as pd
import torch
import torch.nn.functional as F
import tempfile
from sentence_transformers import SentenceTransformer
from safetensors import safe_open
from transformers import pipeline, AutoTokenizer
# --- Shared resources, loaded once at import time and used by the handlers below ---

# Load trial spaces data: one row per trial "space". The search handler reads the
# columns nct_id, title, brief_summary, eligibility_criteria and this_space.
trial_spaces = pd.read_csv('ctgov_all_trials_trial_space_lineitems_10-31-24.csv')

# Load embedding model used to embed the free-text patient summary.
# trust_remote_code=True allows the Hub repository to run its own model code.
embedding_model = SentenceTransformer('ksg-dfci/TrialSpace', trust_remote_code=True)

# Load precomputed trial space embeddings.
# NOTE(review): assumed to be row-aligned with trial_spaces (embedding i <-> row i),
# because search hits are mapped back with trial_spaces.iloc -- confirm when regenerating.
with safe_open("trial_space_embeddings.safetensors", framework="pt") as f:
    trial_space_embeddings = f.get_tensor("space_embeddings")

# Load checker pipeline: a text classifier applied to each
# "trial space + patient summary" string.
# The tokenizer comes from the base "roberta-large" checkpoint rather than the
# checker repo -- presumably TrialChecker is a fine-tuned roberta-large; confirm.
tokenizer = AutoTokenizer.from_pretrained("roberta-large")
checker_pipe = pipeline(
    'text-classification',
    'ksg-dfci/TrialChecker',
    tokenizer=tokenizer,
    # Truncate/pad every input to 512 tokens (assumes the pipeline forwards
    # these kwargs to the tokenizer -- verify for the installed transformers version).
    truncation=True,
    padding='max_length',
    max_length=512
)
# Columns of the results DataFrame returned by match_clinical_trials_dropdown
# (kept in the results State, shown in the details box and exported to CSV).
_RESULT_COLUMNS = [
    'patient_summary_query',
    'nct_id',
    'trial_title',
    'trial_brief_summary',
    'trial_eligibility_criteria',
    'this_space',
    'trial_checker_result',
    'trial_checker_score',
]


def _parse_max_results(max_results_str, default=10, lower=1, upper=50):
    """Return ``max_results_str`` as an int clamped to ``[lower, upper]``.

    Anything ``int()`` cannot convert (non-numeric text, ``""``, ``None``)
    falls back to ``default`` before clamping.
    """
    try:
        value = int(max_results_str)
    except (TypeError, ValueError):
        value = default
    return max(lower, min(value, upper))


def _top_trial_indices(patient_summary, k):
    """Return positional indices of the ``k`` trial spaces most similar to the summary.

    Indices refer to rows of ``trial_spaces``. They are ordered by descending
    cosine similarity, and at most one index per precomputed embedding is returned.
    """
    patient_embedding = embedding_model.encode([patient_summary], convert_to_tensor=True)
    # encode() returns a tensor on the model's device; move the precomputed
    # embeddings to the same device/dtype (a no-op when they already match)
    # so the similarity also works when the model runs on a GPU.
    space_embeddings = trial_space_embeddings.to(
        device=patient_embedding.device, dtype=patient_embedding.dtype
    )
    similarities = F.cosine_similarity(patient_embedding, space_embeddings)
    # topk returns values in descending order (like sort + slice) without
    # sorting every trial; k may not exceed the number of trials.
    k = min(k, similarities.shape[0])
    _, top_indices = torch.topk(similarities, k)
    return top_indices.cpu().numpy()


def match_clinical_trials_dropdown(patient_summary: str, max_results_str: str):
    """Search for trials matching a patient summary and build the selection dropdown.

    1) ``max_results_str`` is converted to an int in [1, 50] (10 if it is not a number).
    2) That many trial spaces are retrieved by embedding cosine similarity.
    3) The TrialChecker classifier scores each (trial space, patient summary)
       pair, and only pairs labelled 'POSITIVE' are kept.

    Returns:
        A ``gr.Dropdown`` with choices "<n>. <nct_id> - <title>" (first one
        selected; empty when nothing matched). Also returns the results DataFrame
        with columns ``_RESULT_COLUMNS``. In each dropdown choice, <n> is the
        1-based row position in that DataFrame.
    """
    max_results = _parse_max_results(max_results_str)

    # Nothing to match: skip the slow embedding and checker calls.
    if not patient_summary or not patient_summary.strip():
        empty_df = pd.DataFrame(columns=_RESULT_COLUMNS)
        return gr.Dropdown(choices=[], interactive=True, value=None), empty_df

    candidates = trial_spaces.iloc[_top_trial_indices(patient_summary, max_results)]
    analysis = pd.DataFrame({
        'patient_summary_query': patient_summary,
        'nct_id': candidates.nct_id,
        'trial_title': candidates.title,
        'trial_brief_summary': candidates.brief_summary,
        'trial_eligibility_criteria': candidates.eligibility_criteria,
        'this_space': candidates.this_space,
    }).reset_index(drop=True)

    # Checker input: trial space, then the patient summary. The separator text
    # is kept exactly as-is -- presumably it matches the checker's training format.
    analysis['pt_trial_pair'] = (
        analysis['this_space']
        + "\nNow here is the patient summary:"
        + analysis['patient_summary_query']
    )
    classifier_results = checker_pipe(analysis['pt_trial_pair'].tolist())
    analysis['trial_checker_result'] = [r['label'] for r in classifier_results]
    analysis['trial_checker_score'] = [r['score'] for r in classifier_results]

    # Keep POSITIVE pairs only, renumbered 0..n-1 so dropdown labels map to rows.
    is_positive = analysis['trial_checker_result'] == 'POSITIVE'
    out_df = analysis.loc[is_positive, _RESULT_COLUMNS].reset_index(drop=True)

    dropdown_options = [
        f"{n}. {nct_id} - {title}"
        for n, (nct_id, title) in enumerate(
            zip(out_df['nct_id'], out_df['trial_title']), start=1
        )
    ]
    # Pre-select the first match; leave the dropdown empty when nothing matched.
    default_choice = dropdown_options[0] if dropdown_options else None
    return (
        gr.Dropdown(choices=dropdown_options, interactive=True, value=default_choice),
        out_df,
    )
def show_selected_trial(selected_option: str, df: pd.DataFrame) -> str:
    """Return a readable description of the trial chosen in the dropdown.

    ``selected_option`` looks like "<n>. <nct_id> - <title>", where <n> is the
    1-based row position in ``df`` (the results DataFrame from the last search).

    Returns "" when nothing is selected. Returns
    "No data found for the selected trial." when the option cannot be parsed,
    points outside ``df``, or no results exist yet (``df`` is None).
    """
    if not selected_option:
        return ""
    not_found = "No data found for the selected trial."
    # The results State holds None until a search has completed.
    if df is None:
        return not_found
    # Everything before the first "." is the 1-based row number.
    number_text = selected_option.partition(".")[0].strip()
    try:
        chosen_index = int(number_text) - 1
    except ValueError:
        return not_found
    if not 0 <= chosen_index < len(df):
        return not_found
    record = df.iloc[chosen_index].to_dict()
    return (
        f"Patient Summary Query: {record['patient_summary_query']}\n\n"
        f"NCT ID: {record['nct_id']}\n"
        f"Trial Title: {record['trial_title']}\n\n"
        f"Trial Space: {record['this_space']}\n\n"
        f"Trial Checker Result: {record['trial_checker_result']}\n"
        f"Trial Checker Score: {record['trial_checker_score']}\n\n"
        f"Brief Summary: {record['trial_brief_summary']}\n\n"
        f"Full Eligibility Criteria: {record['trial_eligibility_criteria']}\n\n"
    )
def export_results(df: pd.DataFrame):
    """Write ``df`` to a new temporary CSV file and return its path for download.

    The file is created with ``delete=False`` so it outlives this call and
    Gradio can serve it. Returns None (no file) when there are no results yet.
    """
    # The results State holds None until a search has completed.
    if df is None:
        return None
    # Write through the handle and close it before returning. The original left
    # the handle open (fd leak) and reopened the file by name, which fails on Windows.
    # newline="" is what pandas expects for text handles passed to to_csv.
    with tempfile.NamedTemporaryFile(
        mode="w", delete=False, suffix=".csv", newline="", encoding="utf-8"
    ) as tmp:
        df.to_csv(tmp, index=False)
    return tmp.name
# CSS for the patient-summary textbox (the Textbox created with elem_id="input_box").
custom_css = """
#input_box textarea {
    width: 600px !important;
    height: 250px !important;
}
"""

with gr.Blocks(css=custom_css) as demo:
    # Intro text and usage disclaimers.
    gr.HTML("""
    <h3>Demonstration version of clinical trial search based on MatchMiner-AI</h3>
    <p>Based on clinicaltrials.gov cancer trials export 10/31/24.</p>
    <p>Queries take approximately 30 seconds to run per ten results returned,
    since demo is running on a small CPU instance.</p>
    <p>Disclaimers:</p>
    <p>1. Not a clinical decision support tool. Queries are not saved, but do not input protected health information.</p>
    <p>2. AI-extracted trial "spaces" and candidate matches may contain errors</p>
    <p>3. Will not necessarily return all trials from clinicaltrials.gov that match a given query</p>
    <p>4. Under active development; interface and underlying models will change</p>
    <p>5. For better results, spell out cancer types (eg, enter "acute myeloid leukemia" rather than "AML")</p>
    """)

    # Free-text patient summary, pre-filled with an example query.
    patient_summary_input = gr.Textbox(
        label="Enter Patient Summary",
        elem_id="input_box",
        value="Cancer type: Non-small cell lung cancer. Histology: Adenocarcinoma. Extent of disease: Metastatic. Prior treatment: Pembrolizumab. Biomarkers: PD-L1 high, KRAS G12C mutant."
    )

    # Requested result count as free text; match_clinical_trials_dropdown
    # parses it and clamps it to 1-50.
    max_results_input = gr.Textbox(
        label="Enter the maximum number of results to return (1-50)",
        value="10"  # default
    )

    submit_btn = gr.Button("Find Matches")

    # Results DataFrame of the last search (no initial value), read by the
    # details view and by CSV export.
    results_state = gr.State()

    # Dropdown of matched trials; starts empty and is filled by each search.
    trial_dropdown = gr.Dropdown(
        label="Select a Trial",
        choices=[],
        value=None,
        interactive=True
    )

    # Read-only details of the currently selected trial.
    trial_details_box = gr.Textbox(
        label="Selected Trial Details",
        lines=12,
        interactive=False
    )

    # Export button and the file component that offers the CSV for download.
    export_btn = gr.Button("Export Results")
    download_file = gr.File()

    # 1) "Find Matches" -> repopulate the dropdown and store the results DataFrame.
    submit_btn.click(
        fn=match_clinical_trials_dropdown,
        inputs=[patient_summary_input, max_results_input],
        outputs=[trial_dropdown, results_state]
    )

    # 2) Selecting from the dropdown -> show that trial's details.
    trial_dropdown.change(
        fn=show_selected_trial,
        inputs=[trial_dropdown, results_state],
        outputs=trial_details_box
    )

    # 3) Export -> write the stored results to a CSV file for download.
    export_btn.click(
        fn=export_results,
        inputs=results_state,
        outputs=download_file
    )

# Enable the request queue so a processing indicator is shown while slow queries run.
demo.queue()

if __name__ == "__main__":
    demo.launch()