"""Gradio demo: 3-class image classifier (ResNet-50) with a SHAP explanation."""

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from torchvision.models import resnet50, ResNet50_Weights

from explain import Explainer


def create_model_from_checkpoint(checkpoint_path="best_model", num_classes=3):
    """Load a fine-tuned ResNet-50 from a checkpoint and put it in eval mode.

    Args:
        checkpoint_path: Path to a file containing the model's ``state_dict``.
        num_classes: Output size of the replaced final fully connected layer;
            it must match the head stored in the checkpoint.

    Returns:
        The ``resnet50`` model with the checkpoint weights loaded, in eval mode.
    """
    # NOTE: this used to be decorated with @staticmethod. On a module-level
    # function that is meaningless, and before Python 3.10 staticmethod
    # objects are not callable, so the call below raised TypeError.
    net = resnet50()
    net.fc = nn.Linear(net.fc.in_features, num_classes)

    # map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only
    # host. NOTE(review): torch.load unpickles the file -- only load
    # checkpoints from a trusted source.
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    net.load_state_dict(state_dict)
    net.eval()
    return net


model = create_model_from_checkpoint()

# Class names passed to Explainer; presumably in the same order as the model's
# output indices -- verify against the training code.
labels = [
    "benign",
    "malignant",
    "normal",
]


def predict(img):
    """Run the explainer on one image and return the values for the two outputs.

    Args:
        img: The input image; the UI delivers it as a PIL image
            (``gr.Image(type="pil")``).

    Returns:
        A two-element list ``[confidences, shap_img]`` feeding the ``gr.Label``
        and ``gr.Image`` output components respectively.
    """
    explainer = Explainer(model, img, labels)
    # shap() is called before reading `confidences`, as in the original code;
    # whether `confidences` depends on that call is not visible from here.
    shap_img = explainer.shap()
    return [explainer.confidences, shap_img]


# Keep `ui` bound to the Interface itself. Chaining `.launch()` onto the
# constructor bound `ui` to launch()'s return value instead, so the later
# `ui.launch(share=True)` raised AttributeError.
ui = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[gr.Label(num_top_classes=3), gr.Image(type="pil")],
    examples=[
        "benign (52).png",
        "benign (243).png",
        "malignant (127).png",
        "malignant (201).png",
        "normal (81).png",
        "normal (101).png",
    ],
)

if __name__ == "__main__":
    # Launch exactly once, with the sharing option the script asked for.
    ui.launch(share=True)