Spaces:
Runtime error
Runtime error
import gradio as gr | |
from torchvision.models import resnet50, ResNet50_Weights | |
import torch.nn as nn | |
import torch | |
import numpy as np | |
from explain import Explainer | |
def create_model_from_checkpoint():
    """Build a 3-class ResNet-50 and load its weights from ``best_model``.

    The stock ResNet-50 classifier head is replaced with a 3-output linear
    layer before the state dict is loaded, so the checkpoint is expected to
    contain weights for that modified architecture.

    Returns:
        The ResNet-50 module with the checkpoint weights loaded, switched to
        evaluation mode.

    Raises:
        FileNotFoundError: If ``best_model`` is not present in the working
            directory.
        RuntimeError: If the checkpoint's keys or tensor shapes do not match
            the model.
    """
    model = resnet50()
    model.fc = nn.Linear(model.fc.in_features, 3)
    # map_location="cpu" lets a checkpoint that was saved from a CUDA device
    # be deserialized on a CPU-only host; without it torch.load raises when
    # CUDA is unavailable. The model itself is created on CPU, so this does
    # not change where the weights end up.
    state_dict = torch.load("best_model", map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()
    return model
# Class names handed to the Explainer together with the model.
# NOTE(review): the order presumably matches the output indices of the
# 3-unit classifier head -- confirm against the training code.
labels = ["benign", "malignant", "normal"]

# Loaded once at import time so every prediction reuses the same model.
model = create_model_from_checkpoint()
def predict(img):
    """Run the explainer on one image and return its outputs for the UI.

    Args:
        img: The input image; the interface is configured to supply it as a
            PIL image.

    Returns:
        A two-element list: the explainer's ``confidences`` attribute and the
        image returned by its ``shap()`` method, in that order.
    """
    explanation = Explainer(model, img, labels)
    # shap() is invoked before `confidences` is read, preserving the original
    # evaluation order (the attribute may be populated by that call).
    shap_overlay = explanation.shap()
    class_confidences = explanation.confidences
    return [class_confidences, shap_overlay]
# Demo UI: one PIL image in; a 3-class label widget and a PIL image out.
ui = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[gr.Label(num_top_classes=3), gr.Image(type="pil")],
    examples=[
        "benign (52).png",
        "benign (243).png",
        "malignant (127).png",
        "malignant (201).png",
        "normal (81).png",
        "normal (101).png",
    ],
)

# `ui` must stay bound to the Interface object: launch() does not return the
# Interface, so chaining `.launch()` onto the constructor and then calling
# `ui.launch(...)` again fails. Launch exactly once, with sharing enabled.
ui.launch(share=True)