"""Gradio demo: step-by-step cancer risk prediction from a tissue-slide tile.

Three buttons drive a three-step pipeline implemented in ``main``:

1. ``load_and_preprocess_image`` -- preprocess the uploaded image.
2. ``genomic_plip_predictions`` -- run the Genomic PLIP model on the result.
3. ``classify_tiles`` -- classify the Genomic PLIP output.

Intermediate results are held in ``gr.State`` components between clicks, and
one read-only checkbox per step shows which steps have completed.
"""

import functools
from pathlib import Path

import gradio as gr
import torch  # NOTE(review): not referenced in this script
from PIL import Image  # NOTE(review): not referenced in this script

from main import load_and_preprocess_image, genomic_plip_predictions, classify_tiles


def run_load_and_preprocess_image(image_path, clip_processor_path):
    """Step 1: load the uploaded image and preprocess it.

    Args:
        image_path: Path of the uploaded file (the ``gr.Image`` uses
            ``type="filepath"``), or ``None``/empty when nothing is uploaded.
        clip_processor_path: Processor location forwarded to
            ``load_and_preprocess_image``.

    Returns:
        Whatever ``main.load_and_preprocess_image`` returns (presumably an
        image tensor -- TODO confirm).

    Raises:
        gr.Error: If no image has been uploaded.
    """
    if not image_path:
        raise gr.Error("Please upload an image first.")
    return load_and_preprocess_image(image_path, clip_processor_path)


def run_genomic_plip_predictions(image_tensor, model_path):
    """Step 2: run the Genomic PLIP model on the preprocessed image.

    Args:
        image_tensor: Output of step 1, or ``None`` if step 1 has not run yet.
        model_path: Model location forwarded to ``genomic_plip_predictions``.

    Returns:
        Whatever ``main.genomic_plip_predictions`` returns.

    Raises:
        gr.Error: If the image has not been preprocessed yet.
    """
    if image_tensor is None:
        raise gr.Error('Please click "Preprocess Image" first.')
    return genomic_plip_predictions(image_tensor, model_path)


def run_classify_tiles(pred_data, model_path):
    """Step 3: classify the Genomic PLIP output.

    Args:
        pred_data: Output of step 2, or ``None`` if step 2 has not run yet.
        model_path: Classifier location forwarded to ``classify_tiles``.

    Returns:
        Whatever ``main.classify_tiles`` returns; it is shown in a Textbox.

    Raises:
        gr.Error: If Genomic PLIP predictions are not available yet.
    """
    if pred_data is None:
        raise gr.Error('Please click "Predict with Genomic PLIP" first.')
    return classify_tiles(pred_data, model_path)


def update_status(status, result):
    """Return the new value of a step's "done" checkbox.

    Args:
        status: The checkbox's previous value. It is ignored and kept only so
            the signature matches earlier versions of this script.
        result: The value the step produced.

    Returns:
        ``True`` if the step produced a result, ``False`` otherwise.
    """
    return result is not None


def _with_status(step):
    """Wrap a pipeline step so it also returns its status-checkbox value.

    The wrapped callable returns ``(result, done)``. ``done`` is derived from
    the step's own return value. A second listener on the same button would
    not work: Gradio runs it independently, with the *pre-click* state as its
    input. Deriving it here guarantees the checkbox reflects the step that just
    finished.
    """

    # Preserve the step's name/signature, which Gradio inspects (e.g. to name
    # the API endpoint).
    @functools.wraps(step)
    def wrapped(*args):
        result = step(*args)
        return result, update_status(None, result)

    return wrapped


# Sorted string paths give a stable example order and plain values for Gradio.
example_files = sorted(str(path) for path in Path("sample_tiles").glob("*.jpeg"))

with gr.Blocks() as demo:
    gr.Markdown("## Cancer Risk Prediction from Tissue Slide")
    image_file = gr.Image(type="filepath", label="Upload Image File")
    clip_processor_path = gr.Textbox(
        label="CLIP Processor Path", value="./genomic_plip_model"
    )
    genomic_plip_model_path = gr.Textbox(
        label="Genomic PLIP Model Path", value="./genomic_plip_model"
    )
    classifier_model_path = gr.Textbox(
        label="Classifier Model Path", value="./models/classifier.pth"
    )

    # Hidden holders for intermediate results between button clicks.
    # (gr.Variable was a deprecated alias of gr.State, removed in Gradio 4.)
    image_tensor_output = gr.State()
    pred_data_output = gr.State()
    result_output = gr.Textbox(label="Result")

    preprocess_button = gr.Button("Preprocess Image")
    predict_button = gr.Button("Predict with Genomic PLIP")
    classify_button = gr.Button("Classify Tiles")

    gr.Markdown("## Step by Step Workflow")
    with gr.Row():
        # Progress indicators only; they are set by the step handlers below.
        preprocess_status = gr.Checkbox(label="Preprocessed Image", interactive=False)
        predict_status = gr.Checkbox(
            label="Predicted with Genomic PLIP", interactive=False
        )
        classify_status = gr.Checkbox(label="Classified Tiles", interactive=False)

    # One event per button: run the step, then store its result and tick its
    # checkbox together so the status can never lag behind the result.
    preprocess_button.click(
        _with_status(run_load_and_preprocess_image),
        inputs=[image_file, clip_processor_path],
        outputs=[image_tensor_output, preprocess_status],
    )
    predict_button.click(
        _with_status(run_genomic_plip_predictions),
        inputs=[image_tensor_output, genomic_plip_model_path],
        outputs=[pred_data_output, predict_status],
    )
    classify_button.click(
        _with_status(run_classify_tiles),
        inputs=[pred_data_output, classifier_model_path],
        outputs=[result_output, classify_status],
    )

    # Only show the examples section when sample tiles are actually present.
    if example_files:
        gr.Markdown("## Example Images")
        gr.Examples(example_files, inputs=image_file)

if __name__ == "__main__":
    demo.launch()