# Source: Hugging Face Spaces app (Space status shown when captured: "Sleeping").
import gradio as gr
import requests
import torch
from PIL import Image
from transformers import (
    ViTFeatureExtractor,
    ViTForImageClassification,
    ViTImageProcessor,
)
# Load a pre-trained Vision Transformer checkpoint from the Hugging Face Hub.
# Placeholder: replace with a checkpoint fine-tuned on tobacco-leaf disease
# images (or a similar model) so its class indices match the `labels` table.
model_name = "nateraw/vit-base-patch16-224-in21k"
model = ViTForImageClassification.from_pretrained(model_name)

# ViTImageProcessor supersedes the deprecated ViTFeatureExtractor (which emits a
# FutureWarning and is slated for removal); it is called the same way. The
# variable keeps its historical name because detect_disease() refers to it.
feature_extractor = ViTImageProcessor.from_pretrained(model_name)
# Placeholder disease names, keyed by the classifier's output index
# (position in this list == class index).
labels = dict(
    enumerate(
        [
            "Healthy",
            "Tobacco Mosaic Virus",
            "Brown Spot",
            "Frog Eye Leaf Spot",
            "Other",
        ]
    )
)
def detect_disease(image):
    """Classify a tobacco-leaf image and return the predicted disease as text.

    Args:
        image: Image to classify. When called by the Gradio interface it
            arrives as a PIL image (``type="pil"``); it is ``None`` when the
            input is empty, e.g. after the user clears it in the live UI.

    Returns:
        ``"Disease Detected: <name>"``, where ``<name>`` is looked up in the
        module-level ``labels`` table by predicted class index
        (``"Unknown Disease"`` if the index has no entry), or a prompt to
        supply an image when ``image`` is ``None``.
    """
    # A live interface re-runs on every input change, including clearing the
    # image; the preprocessor cannot handle None, so answer early instead.
    if image is None:
        return "No image provided. Please upload or capture a tobacco leaf image."

    # Preprocess into PyTorch tensors the model accepts.
    inputs = feature_extractor(images=image, return_tensors="pt")

    # Inference only, so skip gradient tracking.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Argmax over the class axis (not the flattened tensor); a single input
    # image yields exactly one prediction.
    predicted_class = logits.argmax(dim=-1).item()

    # The label table is a placeholder and may not cover every index the
    # checkpoint can emit, hence the fallback.
    disease_name = labels.get(predicted_class, "Unknown Disease")
    return f"Disease Detected: {disease_name}"
# Heading and help text displayed at the top of the Gradio page.
title = "Tobacco Leaf Disease Detection"
description = (
    "\n"
    "Upload or take a real-time picture of a tobacco leaf, "
    "and the app will detect the disease (if any).\n"
)
# Create the Gradio interface: an image (uploaded or captured from the webcam)
# goes in, and the predicted disease comes out as text.
# NOTE: written against the Gradio 4+ API, where gr.Image takes a `sources`
# list. The old `source="camera"` was never a valid value ("webcam" is), and
# the 3.x `source=`/`tool=` keywords no longer exist. (Image editing moved to
# gr.ImageEditor, so the former tool="editor" option is dropped.)
iface = gr.Interface(
    fn=detect_disease,
    inputs=gr.Image(sources=["upload", "webcam"], type="pil"),
    outputs="text",
    title=title,
    description=description,
    live=True,  # Re-run detection whenever the input image changes.
)

# Launch the Gradio app only when run as a script (e.g. `python app.py`),
# not as a side effect of importing this module.
if __name__ == "__main__":
    iface.launch()