|
import gradio as gr |
|
import spaces |
|
import torch |
|
from PIL import Image |
|
from pathlib import Path |
|
from main import load_and_preprocess_image, genomic_plip_predictions, classify_tiles |
|
|
|
@spaces.GPU
def run_load_and_preprocess_image(image_path, clip_processor_path):
    """Load an uploaded tile and preprocess it into a model-ready tensor.

    Args:
        image_path: Path of the uploaded image, or ``None``/empty when the
            user has not uploaded anything yet.
        clip_processor_path: Location of the CLIP processor, passed through
            unchanged to ``load_and_preprocess_image``.

    Returns:
        The preprocessed tensor moved to CUDA when it is available (CPU
        otherwise), or ``None`` when no image was given or preprocessing
        produced nothing.
    """
    # Pressing "Preprocess Image" before uploading yields no path. Report
    # "no tensor" the same way a failed preprocessing step does, instead of
    # handing None to the loader.
    if not image_path:
        return None

    image_tensor = load_and_preprocess_image(image_path, clip_processor_path)
    if image_tensor is None:
        return None

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return image_tensor.to(device)
|
|
|
@spaces.GPU
def run_genomic_plip_predictions(image_tensor, model_path):
    """Extract Genomic PLIP features from a preprocessed image tensor.

    Returns ``None`` without touching the model when there is no tensor, for
    example because preprocessing has not been run yet. Otherwise returns
    whatever ``genomic_plip_predictions`` produces for ``model_path``.
    """
    return (
        None
        if image_tensor is None
        else genomic_plip_predictions(image_tensor, model_path)
    )
|
|
|
@spaces.GPU
def run_classify_tiles(pred_data, model_path):
    """Classify Genomic PLIP features with the classifier at ``model_path``.

    When no feature data is available, the model is not called. Instead, a
    human-readable error string is returned for display in the Result
    textbox.
    """
    if pred_data is not None:
        return classify_tiles(pred_data, model_path)
    return "Error: Prediction data is None."
|
|
|
|
|
def _sample_tile_paths(folder="sample_tiles", pattern="*.jpeg"):
    """Return the sample tile files in *folder* that match *pattern*.

    ``Path.glob`` yields entries in arbitrary filesystem order, so the paths
    are sorted to keep the Examples gallery stable across restarts. They are
    converted to ``str`` to match the value type of a ``type="filepath"``
    Image input. *folder* is resolved against the current working directory.
    """
    return sorted(str(path) for path in Path(folder).glob(pattern))


def _step_done(value):
    """Return True when a pipeline step stored a value (anything but None)."""
    return value is not None


def _classification_done(result):
    """Return True when the Result textbox holds a non-error classification.

    The textbox starts out empty, and ``run_classify_tiles`` signals missing
    features with a message beginning with "Error:". Neither counts as done.
    """
    return bool(result) and not str(result).startswith("Error:")


example_files = _sample_tile_paths()

with gr.Blocks() as demo:
    gr.Markdown("## Cancer Risk Prediction from Tissue Slide")

    image_file = gr.Image(type="filepath", label="Upload Image File")
    clip_processor_path = gr.Textbox(label="CLIP Processor Path", value="./genomic_plip_model")
    genomic_plip_model_path = gr.Textbox(label="Genomic PLIP Model Path", value="./genomic_plip_model")
    classifier_model_path = gr.Textbox(label="Classifier Model Path", value="./models/classifier.pth")

    # Per-session intermediate results handed from one step to the next.
    image_tensor_output = gr.State()
    pred_data_output = gr.State()
    result_output = gr.Textbox(label="Result")

    preprocess_button = gr.Button("Preprocess Image")
    predict_button = gr.Button("Features with Genomic PLIP")
    classify_button = gr.Button("Identify Risk")

    gr.Markdown("## Step by Step Workflow")
    with gr.Row():
        # Read-only progress indicators, driven by the event chains below.
        preprocess_status = gr.Checkbox(label="Preprocessed Image", interactive=False)
        predict_status = gr.Checkbox(label="Features with Genomic PLIP", interactive=False)
        classify_status = gr.Checkbox(label="Identify Risk", interactive=False)

    # Event wiring lives here, after the status checkboxes exist, so each
    # status update can be chained with .then(). A chained update runs only
    # after its step has finished and written its output. A separate .click()
    # listener would fire alongside the step and may read the stale value.
    preprocess_button.click(
        run_load_and_preprocess_image,
        inputs=[image_file, clip_processor_path],
        outputs=image_tensor_output,
    ).then(_step_done, inputs=image_tensor_output, outputs=preprocess_status)

    predict_button.click(
        run_genomic_plip_predictions,
        inputs=[image_tensor_output, genomic_plip_model_path],
        outputs=pred_data_output,
    ).then(_step_done, inputs=pred_data_output, outputs=predict_status)

    classify_button.click(
        run_classify_tiles,
        inputs=[pred_data_output, classifier_model_path],
        outputs=result_output,
    ).then(_classification_done, inputs=result_output, outputs=classify_status)

    # Skip the section when no sample tiles are present, instead of rendering
    # an empty examples gallery.
    if example_files:
        gr.Markdown("## Example Images")
        gr.Examples(example_files, inputs=image_file)

demo.launch()