Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torchvision.transforms as transforms
|
4 |
+
from PIL import Image
|
5 |
+
import gradio as gr
|
6 |
+
import plotly.graph_objects as go
|
7 |
+
from transformers import ViTImageProcessor, ViTForImageClassification
|
8 |
+
|
9 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
10 |
+
|
11 |
+
# Load the ViT model and image processor
|
12 |
+
model_name = "google/vit-base-patch16-384"
|
13 |
+
image_processor = ViTImageProcessor.from_pretrained(model_name)
|
14 |
+
|
15 |
+
class CustomViTModel(nn.Module):
|
16 |
+
def __init__(self, dropout_rate=0.5):
|
17 |
+
super(CustomViTModel, self).__init__()
|
18 |
+
self.vit = ViTForImageClassification.from_pretrained(model_name)
|
19 |
+
|
20 |
+
# Replace the classifier
|
21 |
+
num_features = self.vit.config.hidden_size
|
22 |
+
self.vit.classifier = nn.Identity() # Remove the original classifier
|
23 |
+
|
24 |
+
self.classifier = nn.Sequential(
|
25 |
+
nn.Dropout(dropout_rate),
|
26 |
+
nn.Linear(num_features, 128),
|
27 |
+
nn.ReLU(),
|
28 |
+
nn.Dropout(dropout_rate),
|
29 |
+
nn.Linear(128, 1) # Single output for binary classification
|
30 |
+
)
|
31 |
+
|
32 |
+
def forward(self, pixel_values):
|
33 |
+
outputs = self.vit(pixel_values)
|
34 |
+
x = outputs.logits
|
35 |
+
x = nn.functional.adaptive_avg_pool2d(x.unsqueeze(-1).unsqueeze(-1), (1, 1)).squeeze(-1).squeeze(-1)
|
36 |
+
x = self.classifier(x)
|
37 |
+
return x.squeeze()
|
38 |
+
|
39 |
+
# Load the trained model
|
40 |
+
model = CustomViTModel()
|
41 |
+
model.load_state_dict(torch.load('final_ms_model_2classes_vit_base_224_bce.pth', map_location=device))
|
42 |
+
model.to(device)
|
43 |
+
model.eval()
|
44 |
+
|
45 |
+
def predict(image):
|
46 |
+
# Preprocess the image
|
47 |
+
img = Image.fromarray(image.astype('uint8'), 'RGB')
|
48 |
+
img = img.resize((384, 384))
|
49 |
+
inputs = image_processor(images=img, return_tensors="pt")
|
50 |
+
pixel_values = inputs['pixel_values'].to(device)
|
51 |
+
|
52 |
+
# Make prediction
|
53 |
+
with torch.no_grad():
|
54 |
+
output = model(pixel_values)
|
55 |
+
probability = torch.sigmoid(output).item()
|
56 |
+
|
57 |
+
# Prepare results
|
58 |
+
ms_prob = probability
|
59 |
+
non_ms_prob = 1 - probability
|
60 |
+
|
61 |
+
# Create the bar chart using Plotly
|
62 |
+
fig = go.Figure(data=[
|
63 |
+
go.Bar(name='Non-MS', x=['Non-MS'], y=[non_ms_prob * 100], marker_color='blue'),
|
64 |
+
go.Bar(name='MS', x=['MS'], y=[ms_prob * 100], marker_color='red')
|
65 |
+
])
|
66 |
+
|
67 |
+
fig.update_layout(
|
68 |
+
title='Prediction Probabilities',
|
69 |
+
yaxis_title='Probability (%)',
|
70 |
+
barmode='group',
|
71 |
+
yaxis=dict(range=[0, 100])
|
72 |
+
)
|
73 |
+
|
74 |
+
prediction = "MS" if ms_prob > 0.5 else "Non-MS"
|
75 |
+
confidence = max(ms_prob, non_ms_prob) * 100
|
76 |
+
|
77 |
+
result_text = f"Prediction: {prediction}\nConfidence: {confidence:.2f}%"
|
78 |
+
|
79 |
+
return result_text, fig
|
80 |
+
|
81 |
+
iface = gr.Interface(
|
82 |
+
fn=predict,
|
83 |
+
inputs=gr.Image(),
|
84 |
+
outputs=[gr.Textbox(), gr.Plot()],
|
85 |
+
title="MS Prediction",
|
86 |
+
description="Upload an image to predict whether it shows MS or Non-MS.",
|
87 |
+
)
|
88 |
+
|
89 |
+
iface.launch()
|