# Page banner captured along with the source: "Spaces: Runtime error" (not part of the program).
import torch | |
import torch.nn as nn | |
from torchvision import transforms | |
from PIL import Image | |
import requests | |
import gradio as gr | |
import os | |
# Define the model architecture
class BacterialMorphologyClassifier(nn.Module):
    """Small CNN that maps 224x224 RGB images to class probabilities.

    Two conv -> ReLU -> max-pool stages keep the spatial size through the
    convolutions (3x3, padding 1) and halve it in each pool, so a 224x224
    input reaches the head as a 64 x 56 x 56 feature map. The head flattens
    that map and applies a two-layer MLP ending in a softmax.

    Args:
        num_classes: Size of the output layer. Defaults to 3, the number of
            classes the original architecture was written for.
    """

    def __init__(self, num_classes: int = 3):
        super().__init__()
        # Attribute names and layer positions determine the state_dict keys
        # ("feature_extractor.0.*", "fc.1.*", ...); keep them stable so
        # previously saved checkpoints still load.
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.fc = nn.Sequential(
            nn.Flatten(),
            # 64 channels x 56 x 56 spatial positions (224 halved twice).
            nn.Linear(64 * 56 * 56, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, num_classes),
            # The module emits probabilities, not logits.
            nn.Softmax(dim=1),
        )

    def forward(self, x):
        """Return per-class probabilities of shape (batch, num_classes)."""
        x = self.feature_extractor(x)
        x = self.fc(x)
        return x
# Load the model
MODEL_PATH = "model.pth"
model = BacterialMorphologyClassifier()
try:
    # Download the model if it doesn't exist
    if not os.path.exists(MODEL_PATH):
        print("Downloading the model...")
        url = "https://huggingface.co/yolac/BacterialMorphologyClassification/resolve/main/model.pth"
        # Bounded wait so a stalled connection cannot hang startup forever.
        with requests.get(url, stream=True, timeout=60) as response:
            # Fail on 4xx/5xx instead of saving an HTML error page as weights.
            response.raise_for_status()
            # Write under a temporary name and rename at the end, so an
            # interrupted download is never mistaken for a complete
            # checkpoint on the next start (MODEL_PATH existing skips the
            # download entirely).
            partial_path = MODEL_PATH + ".part"
            with open(partial_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
        os.replace(partial_path, MODEL_PATH)
    # Load the model weights
    # NOTE(review): torch.load unpickles the file, and the file comes from the
    # network. Consider weights_only=True if the installed torch supports it.
    model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device('cpu')))
    print("Model loaded successfully.")
except Exception as e:
    # Best-effort startup: the app still launches, but with untrained weights.
    print(f"Error loading the model: {e}")
finally:
    # Always serve in inference mode so Dropout is disabled even when the
    # weights could not be loaded.
    model.eval()
# Define image preprocessing to match training preprocessing
# NOTE(review): ToTensor already maps 8-bit pixels to [0, 1]. Normalize then
# computes (x - 0) / (1/255), i.e. multiplies by 255, so the model receives
# values in [0, 255]. Presumably this mirrors the training pipeline; confirm
# before changing it.
_CHANNEL_MEAN = [0, 0, 0]
_CHANNEL_STD = [1/255, 1/255, 1/255]
_preprocessing_steps = [
    transforms.Resize((224, 224)),  # spatial size the classifier head expects
    transforms.ToTensor(),  # PIL image -> float tensor, channels first
    transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
]
transform = transforms.Compose(_preprocessing_steps)
# Prediction function
def predict(image):
    """Classify a bacterial micrograph and format the result for display.

    Args:
        image: PIL image handed over by the Gradio input component, or None
            when the user submitted without uploading anything.

    Returns:
        A two-line string with the predicted class and its softmax
        confidence on success, or a string starting with "Error: " when no
        image was supplied or preprocessing/inference raised.
    """
    if image is None:
        return "Error: no image provided."
    try:
        # The first convolution takes 3 channels; convert grayscale, RGBA or
        # palette images so they do not fail with a shape mismatch.
        image_tensor = transform(image.convert("RGB")).unsqueeze(0)
        # Perform prediction
        with torch.no_grad():  # Ensure no gradients are calculated
            output = model(image_tensor)
        # Class mapping
        class_labels = {0: 'cocci', 1: 'bacilli', 2: 'spirilla'}
        # One reduction over the single batch row yields both the top
        # probability and its class index.
        confidence, class_index = output[0].max(dim=0)
        predicted_class = class_labels[class_index.item()]
        return f"Predicted Class: {predicted_class}\nConfidence: {confidence.item():.2f}"
    except Exception as e:
        # UI boundary: surface the failure as text instead of crashing the app.
        return f"Error: {str(e)}"
# Define example images
# Each Gradio example is a row with one value per input component, so every
# URL is wrapped in its own single-element list.
_EXAMPLE_URLS = (
    "https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/resolve/main/img%20290.jpg",
    "https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/resolve/main/img%20565.jpg",
    "https://huggingface.co/datasets/yolac/BacterialMorphologyClassification/resolve/main/img%208.jpg",
)
examples = [[example_url] for example_url in _EXAMPLE_URLS]
# Set up Gradio interface
# Components are built first so the Interface call reads as plain wiring.
_image_input = gr.Image(type="pil")
_prediction_output = gr.Text(label="Prediction")
interface = gr.Interface(
    fn=predict,
    inputs=_image_input,
    outputs=_prediction_output,
    title="Bacterial Morphology Classification",
    description="Upload an image of bacteria to classify it as cocci, bacilli, or spirilla.",
    examples=examples,
)

# Launch the app
# Only start the server when run as a script, not when imported.
if __name__ == "__main__":
    interface.launch()