reab5555 commited on
Commit
f2740f6
1 Parent(s): 07670e2

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +89 -0
app.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torchvision.transforms as transforms
4
+ from PIL import Image
5
+ import gradio as gr
6
+ import plotly.graph_objects as go
7
+ from transformers import ViTImageProcessor, ViTForImageClassification
8
+
9
+ device = "cuda" if torch.cuda.is_available() else "cpu"
10
+
11
+ # Load the ViT model and image processor
12
+ model_name = "google/vit-base-patch16-384"
13
+ image_processor = ViTImageProcessor.from_pretrained(model_name)
14
+
15
+ class CustomViTModel(nn.Module):
16
+ def __init__(self, dropout_rate=0.5):
17
+ super(CustomViTModel, self).__init__()
18
+ self.vit = ViTForImageClassification.from_pretrained(model_name)
19
+
20
+ # Replace the classifier
21
+ num_features = self.vit.config.hidden_size
22
+ self.vit.classifier = nn.Identity() # Remove the original classifier
23
+
24
+ self.classifier = nn.Sequential(
25
+ nn.Dropout(dropout_rate),
26
+ nn.Linear(num_features, 128),
27
+ nn.ReLU(),
28
+ nn.Dropout(dropout_rate),
29
+ nn.Linear(128, 1) # Single output for binary classification
30
+ )
31
+
32
+ def forward(self, pixel_values):
33
+ outputs = self.vit(pixel_values)
34
+ x = outputs.logits
35
+ x = nn.functional.adaptive_avg_pool2d(x.unsqueeze(-1).unsqueeze(-1), (1, 1)).squeeze(-1).squeeze(-1)
36
+ x = self.classifier(x)
37
+ return x.squeeze()
38
+
39
+ # Load the trained model
40
+ model = CustomViTModel()
41
+ model.load_state_dict(torch.load('final_ms_model_2classes_vit_base_224_bce.pth', map_location=device))
42
+ model.to(device)
43
+ model.eval()
44
+
45
+ def predict(image):
46
+ # Preprocess the image
47
+ img = Image.fromarray(image.astype('uint8'), 'RGB')
48
+ img = img.resize((384, 384))
49
+ inputs = image_processor(images=img, return_tensors="pt")
50
+ pixel_values = inputs['pixel_values'].to(device)
51
+
52
+ # Make prediction
53
+ with torch.no_grad():
54
+ output = model(pixel_values)
55
+ probability = torch.sigmoid(output).item()
56
+
57
+ # Prepare results
58
+ ms_prob = probability
59
+ non_ms_prob = 1 - probability
60
+
61
+ # Create the bar chart using Plotly
62
+ fig = go.Figure(data=[
63
+ go.Bar(name='Non-MS', x=['Non-MS'], y=[non_ms_prob * 100], marker_color='blue'),
64
+ go.Bar(name='MS', x=['MS'], y=[ms_prob * 100], marker_color='red')
65
+ ])
66
+
67
+ fig.update_layout(
68
+ title='Prediction Probabilities',
69
+ yaxis_title='Probability (%)',
70
+ barmode='group',
71
+ yaxis=dict(range=[0, 100])
72
+ )
73
+
74
+ prediction = "MS" if ms_prob > 0.5 else "Non-MS"
75
+ confidence = max(ms_prob, non_ms_prob) * 100
76
+
77
+ result_text = f"Prediction: {prediction}\nConfidence: {confidence:.2f}%"
78
+
79
+ return result_text, fig
80
+
81
+ iface = gr.Interface(
82
+ fn=predict,
83
+ inputs=gr.Image(),
84
+ outputs=[gr.Textbox(), gr.Plot()],
85
+ title="MS Prediction",
86
+ description="Upload an image to predict whether it shows MS or Non-MS.",
87
+ )
88
+
89
+ iface.launch()