# Source: Hugging Face Space file "app.py", uploaded by reab5555.
# Commit f2740f6 ("Create app.py", verified); file size 2.83 kB; no virus detected.
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import gradio as gr
import plotly.graph_objects as go
from transformers import ViTImageProcessor, ViTForImageClassification
# Checkpoint identifier used for both the image processor below and the
# ViT backbone built inside CustomViTModel.
model_name = "google/vit-base-patch16-384"
# Preprocessing pipeline loaded from that checkpoint.
image_processor = ViTImageProcessor.from_pretrained(model_name)

# Prefer the GPU whenever CUDA is usable; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
class CustomViTModel(nn.Module):
    """ViT backbone followed by a small MLP head emitting one binary logit.

    The backbone's own classifier is swapped for ``nn.Identity`` so that its
    ``logits`` output is the un-projected feature vector of size
    ``config.hidden_size``; that vector feeds the custom head.

    The submodule names ``vit`` and ``classifier`` determine the state-dict
    keys, so they must stay as they are for saved checkpoints to load.

    Args:
        dropout_rate: Dropout probability used before each linear layer of
            the head.
    """

    def __init__(self, dropout_rate=0.5):
        super().__init__()
        self.vit = ViTForImageClassification.from_pretrained(model_name)
        num_features = self.vit.config.hidden_size
        # Remove the original classifier so the backbone returns features.
        self.vit.classifier = nn.Identity()
        self.classifier = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(num_features, 128),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(128, 1),  # Single output for binary classification
        )

    def forward(self, pixel_values):
        """Return the raw (pre-sigmoid) logit for each image in the batch.

        Args:
            pixel_values: Preprocessed image batch accepted by the backbone.

        Returns:
            The head's output with every size-1 dimension squeezed out, so a
            single-image batch yields a 0-d tensor.
        """
        features = self.vit(pixel_values).logits
        # Earlier code average-pooled these features over a 1x1 grid, which
        # is an exact identity; the head is applied to them directly instead.
        return self.classifier(features).squeeze()
# Build the network, then overwrite its parameters with the trained
# checkpoint; tensors are mapped onto the selected device while loading.
model = CustomViTModel()
model.load_state_dict(
    torch.load(
        'final_ms_model_2classes_vit_base_224_bce.pth',
        map_location=device,
    )
)
# Move to the target device and switch to inference mode (disables dropout).
model.to(device).eval()
def predict(image):
    """Classify an uploaded image as MS or Non-MS and chart the probabilities.

    Args:
        image: Image array supplied by the Gradio image input, or ``None``
            when the form is submitted without an image.

    Returns:
        A ``(result_text, figure)`` tuple. ``result_text`` names the predicted
        class and its confidence in percent; ``figure`` is a Plotly bar chart
        of both class probabilities. When no image was supplied the text asks
        for an upload and the figure is ``None``.
    """
    if image is None:
        # Nothing to classify: answer with a message instead of failing on
        # ``None.astype`` further down.
        return "Please upload an image before requesting a prediction.", None

    # Let Pillow infer the mode from the array and normalise to RGB afterwards,
    # so grayscale or RGBA arrays are handled too rather than forced into 'RGB'.
    img = Image.fromarray(image.astype('uint8')).convert('RGB')
    img = img.resize((384, 384))
    inputs = image_processor(images=img, return_tensors="pt")
    pixel_values = inputs['pixel_values'].to(device)

    # Inference only: no gradients needed.
    with torch.no_grad():
        output = model(pixel_values)
        # The model emits a single logit; sigmoid maps it to P(MS).
        ms_prob = torch.sigmoid(output).item()
    non_ms_prob = 1 - ms_prob

    # Bar chart of both class probabilities, in percent.
    fig = go.Figure(data=[
        go.Bar(name='Non-MS', x=['Non-MS'], y=[non_ms_prob * 100], marker_color='blue'),
        go.Bar(name='MS', x=['MS'], y=[ms_prob * 100], marker_color='red'),
    ])
    fig.update_layout(
        title='Prediction Probabilities',
        yaxis_title='Probability (%)',
        barmode='group',
        yaxis=dict(range=[0, 100]),
    )

    prediction = "MS" if ms_prob > 0.5 else "Non-MS"
    confidence = max(ms_prob, non_ms_prob) * 100
    result_text = f"Prediction: {prediction}\nConfidence: {confidence:.2f}%"
    return result_text, fig
# Gradio front end: a single image input wired to `predict`, whose two return
# values fill a text box (verdict) and a plot (probability chart).
iface = gr.Interface(
    predict,
    gr.Image(),
    [gr.Textbox(), gr.Plot()],
    title="MS Prediction",
    description="Upload an image to predict whether it shows MS or Non-MS.",
)

# Serve the interface.
iface.launch()