"""Gradio demo that matches a free-text patient summary to clinical trial "spaces".

Pipeline: embed the patient summary, rank precomputed trial-space embeddings by
cosine similarity, re-check the top candidates with a text-classification
pipeline, and show only the candidates the checker labels POSITIVE.
"""

import tempfile

import gradio as gr
import pandas as pd
import torch
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer
from safetensors import safe_open
from transformers import pipeline, AutoTokenizer

# Bounds for the user-supplied "maximum number of results" field.
DEFAULT_MAX_RESULTS = 10
MIN_MAX_RESULTS = 1
MAX_MAX_RESULTS = 50

# Load trial spaces data
trial_spaces = pd.read_csv('ctgov_all_trials_trial_space_lineitems_10-31-24.csv')

# Load embedding model
embedding_model = SentenceTransformer('ksg-dfci/TrialSpace', trust_remote_code=True)

# Load precomputed trial space embeddings
with safe_open("trial_space_embeddings.safetensors", framework="pt") as f:
    trial_space_embeddings = f.get_tensor("space_embeddings")

# Load checker pipeline
tokenizer = AutoTokenizer.from_pretrained("roberta-large")
checker_pipe = pipeline(
    'text-classification',
    'ksg-dfci/TrialChecker',
    tokenizer=tokenizer,
    truncation=True,
    padding='max_length',
    max_length=512
)


def _parse_max_results(raw) -> int:
    """Convert the raw UI value to an int clamped to the allowed range.

    Anything that cannot be converted (non-numeric text, ``None``) falls back
    to ``DEFAULT_MAX_RESULTS``.
    """
    try:
        value = int(raw)
    except (TypeError, ValueError):
        value = DEFAULT_MAX_RESULTS
    return max(MIN_MAX_RESULTS, min(value, MAX_MAX_RESULTS))


def match_clinical_trials_dropdown(patient_summary: str, max_results_str: str):
    """Find candidate trials for a patient summary.

    Args:
        patient_summary: Free-text description of the patient.
        max_results_str: Maximum number of candidates to consider, as typed by
            the user. Invalid input defaults to 10; the value is clamped to
            the range 1-50.

    Returns:
        A ``(dropdown, results)`` pair: a ``gr.Dropdown`` whose choices are the
        matched trials (formatted like "1. NCT001 - Some Title", first one
        preselected, or empty when nothing matched) and a DataFrame with one
        row per trial the checker labelled POSITIVE.
    """
    max_results = _parse_max_results(max_results_str)

    # 1. Encode user input. The encoder may place its output on a different
    #    device than the stored embeddings, so align the two before comparing.
    patient_embedding = embedding_model.encode([patient_summary], convert_to_tensor=True)
    patient_embedding = patient_embedding.to(trial_space_embeddings.device)

    # 2. Compute similarities
    similarities = F.cosine_similarity(patient_embedding, trial_space_embeddings)

    # 3. Pull top 'max_results' (never more than the number of rows available)
    k = min(max_results, similarities.numel())
    top_indices = torch.topk(similarities, k).indices.cpu().numpy()

    # 4. Build DataFrame
    top_trials = trial_spaces.iloc[top_indices]
    analysis = pd.DataFrame({
        'patient_summary_query': patient_summary,
        'nct_id': top_trials['nct_id'],
        'trial_title': top_trials['title'],
        'trial_brief_summary': top_trials['brief_summary'],
        'trial_eligibility_criteria': top_trials['eligibility_criteria'],
        'this_space': top_trials['this_space'],
    }).reset_index(drop=True)

    # 5. Prepare for checker pipeline
    analysis['pt_trial_pair'] = (
        analysis['this_space']
        + "\nNow here is the patient summary:"
        + analysis['patient_summary_query']
    )

    # 6. Run checker pipeline
    classifier_results = checker_pipe(analysis['pt_trial_pair'].tolist())
    analysis['trial_checker_result'] = [x['label'] for x in classifier_results]
    analysis['trial_checker_score'] = [x['score'] for x in classifier_results]

    # 7. Restrict to POSITIVE results only
    analysis = analysis[analysis['trial_checker_result'] == 'POSITIVE'].reset_index(drop=True)

    # 8. Final columns
    out_df = analysis[[
        'patient_summary_query',
        'nct_id',
        'trial_title',
        'trial_brief_summary',
        'trial_eligibility_criteria',
        'this_space',
        'trial_checker_result',
        'trial_checker_score'
    ]]

    # Build the dropdown choices, e.g., "1. NCT001 - Some Title". The leading
    # number is the 1-based row position that show_selected_trial parses back.
    dropdown_options = [
        f"{position}. {nct_id} - {title}"
        for position, (nct_id, title) in enumerate(
            zip(out_df['nct_id'], out_df['trial_title']), start=1
        )
    ]

    # Preselect the first match; with no matches the dropdown stays empty.
    default_choice = dropdown_options[0] if dropdown_options else None
    return (
        gr.Dropdown(choices=dropdown_options, interactive=True, value=default_choice),
        out_df
    )


def show_selected_trial(selected_option: str, df: pd.DataFrame):
    """Build the detail text for the trial chosen in the dropdown.

    Args:
        selected_option: Dropdown label such as "1. NCT001 - Some Title"; the
            number before the first "." is the 1-based row position in ``df``.
        df: Results DataFrame produced by ``match_clinical_trials_dropdown``
            (``None`` if no search has been run yet).

    Returns:
        A multi-line summary of the selected row, an empty string when nothing
        is selected, or "No data found for the selected trial." when the label
        cannot be resolved to a row.
    """
    if not selected_option:
        return ""

    not_found = "No data found for the selected trial."
    if df is None:
        return not_found

    # Parse the index from "1. NCT001 - Some Title"
    position_text = selected_option.partition(".")[0].strip()
    try:
        chosen_index = int(position_text) - 1
    except ValueError:
        return not_found

    if not 0 <= chosen_index < len(df):
        return not_found

    record = df.iloc[chosen_index].to_dict()
    return (
        f"Patient Summary Query: {record['patient_summary_query']}\n\n"
        f"NCT ID: {record['nct_id']}\n"
        f"Trial Title: {record['trial_title']}\n\n"
        f"Trial Space: {record['this_space']}\n\n"
        f"Trial Checker Result: {record['trial_checker_result']}\n"
        f"Trial Checker Score: {record['trial_checker_score']}\n\n"
        f"Brief Summary: {record['trial_brief_summary']}\n\n"
        f"Full Eligibility Criteria: {record['trial_eligibility_criteria']}\n\n"
    )


def export_results(df: pd.DataFrame):
    """Save the results to a temporary CSV file for download.

    Args:
        df: Results DataFrame held in the UI state (``None`` before the first
            search).

    Returns:
        Path of the written CSV file, or ``None`` when there is nothing to
        export yet.
    """
    if df is None:
        return None

    # Reserve a unique path, then close the handle before pandas writes to it:
    # this avoids leaking the descriptor and re-opening a file that is still open.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".csv") as temp:
        csv_path = temp.name
    df.to_csv(csv_path, index=False)
    return csv_path


# A little CSS for the input boxes
custom_css = """
#input_box textarea {
    width: 600px !important;
    height: 250px !important;
}
"""

with gr.Blocks(css=custom_css) as demo:
    # Intro text
    gr.HTML("""
Based on clinicaltrials.gov cancer trials export 10/31/24.
Queries take approximately 30 seconds to run per ten results returned, since demo is running on a small CPU instance.
Disclaimers:
1. Not a clinical decision support tool. Queries are not saved, but do not input protected health information.
2. AI-extracted trial "spaces" and candidate matches may contain errors
3. Will not necessarily return all trials from clinicaltrials.gov that match a given query
4. Under active development; interface and underlying models will change
5. For better results, spell out cancer types (eg, enter "acute myeloid leukemia" rather than "AML")
    """)

    # Textbox for patient summary
    patient_summary_input = gr.Textbox(
        label="Enter Patient Summary",
        elem_id="input_box",
        value="Cancer type: Non-small cell lung cancer. Histology: Adenocarcinoma. Extent of disease: Metastatic. Prior treatment: Pembrolizumab. Biomarkers: PD-L1 high, KRAS G12C mutant."
    )

    # Textbox for max results
    max_results_input = gr.Textbox(
        label="Enter the maximum number of results to return (1-50)",
        value="10"  # default
    )

    # Button to run the matching
    submit_btn = gr.Button("Find Matches")

    # We'll store the DataFrame in a State for CSV export + reference
    results_state = gr.State()

    # Dropdown (initially empty)
    trial_dropdown = gr.Dropdown(
        label="Select a Trial",
        choices=[],
        value=None,
        interactive=True
    )

    # Textbox for showing details of the selected trial
    trial_details_box = gr.Textbox(
        label="Selected Trial Details",
        lines=12,
        interactive=False
    )

    # Export button + file
    export_btn = gr.Button("Export Results")
    download_file = gr.File()

    # 1) "Find Matches" => updates the dropdown choices and the state
    submit_btn.click(
        fn=match_clinical_trials_dropdown,
        inputs=[patient_summary_input, max_results_input],
        outputs=[trial_dropdown, results_state]
    )

    # 2) Selecting from the dropdown => shows more info
    trial_dropdown.change(
        fn=show_selected_trial,
        inputs=[trial_dropdown, results_state],
        outputs=trial_details_box
    )

    # 3) Export => CSV
    export_btn.click(
        fn=export_results,
        inputs=results_state,
        outputs=download_file
    )

# Enable queue so "Processing..." is shown if logic is slow
demo.queue()

if __name__ == "__main__":
    demo.launch()