"""Gradio app: cancer risk prediction from a tissue-slide image tile.

The UI exposes a three-step pipeline, each step triggered by its own button:

1. "Preprocess Image"            -> ``run_load_and_preprocess_image``
2. "Features with Genomic PLIP"  -> ``run_genomic_plip_predictions``
3. "Identify Risk"               -> ``run_classify_tiles``

Intermediate results (the image tensor and the prediction data) are stored
per session in ``gr.State`` so each step can consume the previous one.
"""

import gradio as gr
import spaces
import torch
from PIL import Image
from pathlib import Path

from main import load_and_preprocess_image, genomic_plip_predictions, classify_tiles

# Prefix of the error string returned by ``run_classify_tiles``; used to tell
# a failed classification apart from a real result in the status checkbox.
CLASSIFY_ERROR_PREFIX = "Error:"


@spaces.GPU
def run_load_and_preprocess_image(image_path, clip_processor_path):
    """Load and preprocess the uploaded image via ``main.load_and_preprocess_image``.

    Args:
        image_path: Filesystem path from the ``gr.Image(type="filepath")``
            component; ``None``/empty when nothing has been uploaded.
        clip_processor_path: Path passed through to the preprocessing helper.

    Returns:
        The preprocessed tensor moved to CUDA when available (else CPU), or
        ``None`` when there is no image or preprocessing returned ``None``.
    """
    if not image_path:
        # Nothing uploaded yet; don't hand None to the preprocessing helper.
        return None
    image_tensor = load_and_preprocess_image(image_path, clip_processor_path)
    if image_tensor is not None:
        # NOTE(review): this tensor is stored in gr.State and reused by a later
        # @spaces.GPU call; confirm device placement across ZeroGPU calls.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        image_tensor = image_tensor.to(device)
    return image_tensor


@spaces.GPU
def run_genomic_plip_predictions(image_tensor, model_path):
    """Compute Genomic PLIP features for a preprocessed image tensor.

    Returns ``None`` when no tensor is available (preprocessing not run or
    failed); otherwise whatever ``main.genomic_plip_predictions`` returns.
    """
    if image_tensor is None:
        return None
    return genomic_plip_predictions(image_tensor, model_path)


@spaces.GPU
def run_classify_tiles(pred_data, model_path):
    """Classify prediction data into a risk result via ``main.classify_tiles``.

    Returns an error string (starting with ``CLASSIFY_ERROR_PREFIX``) when no
    prediction data is available; otherwise the classifier's output.
    """
    if pred_data is None:
        return "Error: Prediction data is None."
    return classify_tiles(pred_data, model_path)


def _has_result(value):
    """Return True when a pipeline step stored a non-None value."""
    return value is not None


def _classification_succeeded(result):
    """Return True when classification produced a result that is not our error string."""
    if result is None:
        return False
    if isinstance(result, str) and result.startswith(CLASSIFY_ERROR_PREFIX):
        return False
    return True


# Sorted so the example gallery order is stable (glob order is arbitrary).
example_files = sorted(str(path) for path in Path("sample_tiles").glob("*.jpeg"))

with gr.Blocks() as demo:
    gr.Markdown("## Cancer Risk Prediction from Tissue Slide")
    image_file = gr.Image(type="filepath", label="Upload Image File")
    clip_processor_path = gr.Textbox(label="CLIP Processor Path", value="./genomic_plip_model")
    genomic_plip_model_path = gr.Textbox(label="Genomic PLIP Model Path", value="./genomic_plip_model")
    classifier_model_path = gr.Textbox(label="Classifier Model Path", value="./models/classifier.pth")

    # Per-session storage for intermediate pipeline results.
    image_tensor_output = gr.State()
    pred_data_output = gr.State()
    result_output = gr.Textbox(label="Result")

    preprocess_button = gr.Button("Preprocess Image")
    predict_button = gr.Button("Features with Genomic PLIP")
    classify_button = gr.Button("Identify Risk")

    gr.Markdown("## Step by Step Workflow")
    with gr.Row():
        # Read-only indicators: they reflect pipeline progress, not user input.
        preprocess_status = gr.Checkbox(label="Preprocessed Image", interactive=False)
        predict_status = gr.Checkbox(label="Features with Genomic PLIP", interactive=False)
        classify_status = gr.Checkbox(label="Identify Risk", interactive=False)

    # Each status update is chained with .then() so it runs after the step has
    # written its output. A second, independent .click() listener would read
    # the State before the step finished and show the previous run's status.
    preprocess_button.click(
        run_load_and_preprocess_image,
        inputs=[image_file, clip_processor_path],
        outputs=image_tensor_output,
    ).then(_has_result, inputs=image_tensor_output, outputs=preprocess_status)

    predict_button.click(
        run_genomic_plip_predictions,
        inputs=[image_tensor_output, genomic_plip_model_path],
        outputs=pred_data_output,
    ).then(_has_result, inputs=pred_data_output, outputs=predict_status)

    classify_button.click(
        run_classify_tiles,
        inputs=[pred_data_output, classifier_model_path],
        outputs=result_output,
    ).then(_classification_succeeded, inputs=result_output, outputs=classify_status)

    # gr.Examples needs at least one example; skip it if the folder is missing/empty.
    if example_files:
        gr.Markdown("## Example Images")
        gr.Examples(example_files, inputs=image_file)

if __name__ == "__main__":
    demo.launch()