"""Gradio demo: classify an uploaded image as MS or Non-MS with a fine-tuned ViT."""

import gradio as gr
import plotly.graph_objects as go
import torch
import torch.nn as nn
import torchvision.transforms as transforms  # noqa: F401  (unused here; kept from the original file)
from PIL import Image
from transformers import ViTForImageClassification, ViTImageProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the ViT model and image processor
model_name = "google/vit-base-patch16-384"
image_processor = ViTImageProcessor.from_pretrained(model_name)

# Fine-tuned weights for CustomViTModel.
# NOTE(review): the file name mentions "224" while the backbone and the
# preprocessing below use 384x384 inputs -- confirm this is the intended
# checkpoint.
CHECKPOINT_PATH = 'final_ms_model_2classes_vit_base_224_bce.pth'

# Size the uploaded image is resized to before it reaches the image processor.
IMAGE_SIZE = (384, 384)


class CustomViTModel(nn.Module):
    """ViT backbone followed by a small MLP head that emits one logit.

    The stock classification head of ``ViTForImageClassification`` is
    replaced by ``nn.Identity`` so that ``outputs.logits`` carries the
    feature vector the original head would have consumed; a two-layer MLP
    then maps it to a single value for binary classification.

    Args:
        dropout_rate: Dropout probability used before each linear layer of
            the head.
    """

    def __init__(self, dropout_rate=0.5):
        super().__init__()
        self.vit = ViTForImageClassification.from_pretrained(model_name)

        # Replace the classifier
        num_features = self.vit.config.hidden_size
        self.vit.classifier = nn.Identity()  # Remove the original classifier
        self.classifier = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(num_features, 128),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(128, 1),  # Single output for binary classification
        )

    def forward(self, pixel_values):
        """Return the raw (pre-sigmoid) logit(s) for ``pixel_values``.

        Args:
            pixel_values: Image tensor accepted by the ViT backbone.

        Returns:
            The head output with all singleton dimensions squeezed away, so
            a batch of one image yields a 0-d tensor.
        """
        features = self.vit(pixel_values).logits
        # The original code average-pooled these features over a 1x1 spatial
        # map it had just created with two unsqueeze calls; that is an exact
        # identity, so the features go straight into the head.
        logits = self.classifier(features)
        # squeeze() (rather than squeeze(-1)) is kept so the returned shape
        # stays what existing callers expect.
        return logits.squeeze()


# Load the trained model
model = CustomViTModel()
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source (or pass weights_only=True if the installed torch supports it).
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))
model.to(device)
model.eval()


def _build_probability_chart(ms_prob, non_ms_prob):
    """Build the two-bar Plotly figure showing both class probabilities in percent."""
    fig = go.Figure(data=[
        go.Bar(name='Non-MS', x=['Non-MS'], y=[non_ms_prob * 100], marker_color='blue'),
        go.Bar(name='MS', x=['MS'], y=[ms_prob * 100], marker_color='red'),
    ])
    fig.update_layout(
        title='Prediction Probabilities',
        yaxis_title='Probability (%)',
        barmode='group',
        yaxis=dict(range=[0, 100]),
    )
    return fig


def predict(image):
    """Classify an uploaded image as MS or Non-MS.

    Args:
        image: Image array supplied by the ``gr.Image`` input, or ``None``
            when the user submitted without uploading anything.

    Returns:
        A ``(result_text, figure)`` tuple: a two-line summary with the
        predicted label and its confidence, and a Plotly bar chart of both
        class probabilities.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Without this guard the call below fails with an opaque
        # AttributeError on None.
        raise gr.Error("Please upload an image before requesting a prediction.")

    # Preprocess the image. Letting PIL infer the mode and then converting
    # to RGB also copes with grayscale and RGBA arrays.
    img = Image.fromarray(image.astype('uint8')).convert('RGB')
    # Kept from the original pipeline so the interpolation path (and thus the
    # model input) is unchanged.
    img = img.resize(IMAGE_SIZE)
    inputs = image_processor(images=img, return_tensors="pt")
    pixel_values = inputs['pixel_values'].to(device)

    # Make prediction
    with torch.no_grad():
        output = model(pixel_values)
    ms_prob = torch.sigmoid(output).item()
    non_ms_prob = 1 - ms_prob

    fig = _build_probability_chart(ms_prob, non_ms_prob)

    prediction = "MS" if ms_prob > 0.5 else "Non-MS"
    confidence = max(ms_prob, non_ms_prob) * 100
    result_text = f"Prediction: {prediction}\nConfidence: {confidence:.2f}%"

    return result_text, fig


iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(),
    outputs=[gr.Textbox(), gr.Plot()],
    title="MS Prediction",
    description="Upload an image to predict whether it shows MS or Non-MS.",
)

iface.launch()