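# NOTE: the header below is repository file-viewer metadata captured with the source, not code.
# Author: VatsalPatel18 | Commit: abb70d4 (verified) | Message: "Update app.py"
# Scan: no virus | Size: 2.98 kB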
import gradio as gr
import spaces
import torch
from PIL import Image
from pathlib import Path
from main import load_and_preprocess_image, genomic_plip_predictions, classify_tiles
@spaces.GPU
def run_load_and_preprocess_image(image_path, clip_processor_path):
    """Preprocess an image and place the resulting tensor on the active device.

    Args:
        image_path: Path of the image file, forwarded to
            ``load_and_preprocess_image``.
        clip_processor_path: CLIP processor location, forwarded to
            ``load_and_preprocess_image``.

    Returns:
        The preprocessed tensor moved to CUDA when it is available (CPU
        otherwise), or ``None`` when ``load_and_preprocess_image`` returned
        ``None``.
    """
    tensor = load_and_preprocess_image(image_path, clip_processor_path)
    if tensor is None:
        # Nothing to move; hand the None through to the caller unchanged.
        return None
    device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
    return tensor.to(torch.device(device_name))
@spaces.GPU
def run_genomic_plip_predictions(image_tensor, model_path):
    """Run ``genomic_plip_predictions`` on a preprocessed image tensor.

    Args:
        image_tensor: Output of the preprocessing step, or ``None`` if that
            step has not produced anything.
        model_path: Model location forwarded to ``genomic_plip_predictions``.

    Returns:
        Whatever ``genomic_plip_predictions`` returns, or ``None`` when
        ``image_tensor`` is ``None`` (the model is not called in that case).
    """
    return (
        None
        if image_tensor is None
        else genomic_plip_predictions(image_tensor, model_path)
    )
@spaces.GPU
def run_classify_tiles(pred_data, model_path):
    """Classify the prediction data produced by the feature-extraction step.

    Args:
        pred_data: Output of ``run_genomic_plip_predictions``, or ``None``
            when no prediction data is available.
        model_path: Classifier location forwarded to ``classify_tiles``.

    Returns:
        The value returned by ``classify_tiles``, or an error message string
        when ``pred_data`` is ``None``.
    """
    if pred_data is not None:
        return classify_tiles(pred_data, model_path)
    # No upstream data: report the problem as text instead of calling the classifier.
    return "Error: Prediction data is None."
# Sorted because glob() yields entries in no guaranteed order; this keeps the
# example gallery stable between runs.
example_files = sorted(Path("sample_tiles").glob("*.jpeg"))


def update_status(status, result):
    """Report whether a workflow step has produced a result.

    Args:
        status: Current value of the step's checkbox. Accepted so the
            listener signature stays unchanged; it does not affect the result.
        result: Value held by the step's output component.

    Returns:
        ``True`` when ``result`` is not ``None``, else ``False``. Exactly one
        value is returned because every listener using this function declares
        a single checkbox as its output.
    """
    return result is not None


with gr.Blocks() as demo:
    gr.Markdown("## Cancer Risk Prediction from Tissue Slide")

    # --- Inputs -----------------------------------------------------------
    image_file = gr.Image(type="filepath", label="Upload Image File")
    clip_processor_path = gr.Textbox(label="CLIP Processor Path", value="./genomic_plip_model")
    genomic_plip_model_path = gr.Textbox(label="Genomic PLIP Model Path", value="./genomic_plip_model")
    classifier_model_path = gr.Textbox(label="Classifier Model Path", value="./models/classifier.pth")

    # --- Intermediate values passed between steps, and the final result ----
    image_tensor_output = gr.State()
    pred_data_output = gr.State()
    result_output = gr.Textbox(label="Result")

    # --- One button per workflow step ---------------------------------------
    preprocess_button = gr.Button("Preprocess Image")
    predict_button = gr.Button("Features with Genomic PLIP")
    classify_button = gr.Button("Identify Risk")

    gr.Markdown("## Step by Step Workflow")
    with gr.Row():
        preprocess_status = gr.Checkbox(label="Preprocessed Image")
        predict_status = gr.Checkbox(label="Features with Genomic PLIP")
        classify_status = gr.Checkbox(label="Identify Risk")

    # --- Event wiring -------------------------------------------------------
    # Each status checkbox is refreshed with .then() so it runs after the
    # step's own handler has finished and sees the value that handler wrote.
    # Registering it as a second, independent .click listener gives no such
    # ordering and can read the previous value instead.
    preprocess_button.click(
        run_load_and_preprocess_image,
        inputs=[image_file, clip_processor_path],
        outputs=image_tensor_output,
    ).then(
        update_status,
        inputs=[preprocess_status, image_tensor_output],
        outputs=preprocess_status,
    )
    predict_button.click(
        run_genomic_plip_predictions,
        inputs=[image_tensor_output, genomic_plip_model_path],
        outputs=pred_data_output,
    ).then(
        update_status,
        inputs=[predict_status, pred_data_output],
        outputs=predict_status,
    )
    # NOTE(review): run_classify_tiles returns an error *string* when it has no
    # input, so this checkbox shows that the step ran, not that it succeeded.
    classify_button.click(
        run_classify_tiles,
        inputs=[pred_data_output, classifier_model_path],
        outputs=result_output,
    ).then(
        update_status,
        inputs=[classify_status, result_output],
        outputs=classify_status,
    )

    gr.Markdown("## Example Images")
    gr.Examples(example_files, inputs=image_file)

demo.launch()