# Origin: app.py by KhadijaAsehnoune12, commit 543d7d6 ("Update app.py", verified), 3.27 kB.
# (These two lines replace page-header residue copied from the file viewer.)
import gradio as gr
import torch
from transformers import ViTFeatureExtractor, ViTForImageClassification
from PIL import Image
# Load the fine-tuned ViT image classifier and its preprocessor from the same
# model repository.
model_name = "KhadijaAsehnoune12/OrangeLeafDiseaseDetector"
model = ViTForImageClassification.from_pretrained(
    model_name,
    num_labels=10,  # one output per class listed in id2label
    # NOTE(review): ignore_mismatched_sizes=True makes transformers silently
    # re-initialise any checkpoint weights whose shape differs from this
    # config (e.g. the classification head), which would give random
    # predictions. Confirm the checkpoint really has a 10-way head.
    ignore_mismatched_sizes=True,
)
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
# Map each model output index to its (French) class name. Keys are the index
# as a *string* ("0".."9"); the position in the tuple below defines the index.
id2label = {
    str(class_idx): class_name
    for class_idx, class_name in enumerate(
        (
            "Aleurocanthus spiniferus",
            "Chancre citrique",
            "Cochenille blanche",
            "Dépérissement des agrumes",
            "Feuille saine",
            "Jaunissement des feuilles",
            "Maladie de l'oïdium",
            "Maladie du dragon jaune",
            "Mineuse des agrumes",
            "Trou de balle",
        )
    )
}
# Example images for the gallery: one {"image": file name, "label": caption}
# entry per class, listed in the same order as id2label.
image_slider = [
    {"image": file_name, "label": caption}
    for file_name, caption in (
        ("aleurocanthus_spiniferus.jpg", "Aleurocanthus spiniferus"),
        ("chancre_citrique.jpg", "Chancre citrique"),
        ("cochenille_blanche.jpg", "Cochenille blanche"),
        ("deperissement_des_agrumes.jpg", "Dépérissement des agrumes"),
        ("feuille_saine.jpg", "Feuille saine"),
        ("jaunissement_des_feuilles.jpg", "Jaunissement des feuilles"),
        ("maladie_de_loidium.jpg", "Maladie de l'oïdium"),
        ("maladie_du_dragon_jaune.jpg", "Maladie du dragon jaune"),
        ("mineuse_des_agrumes.jpg", "Mineuse des agrumes"),
        ("trou_de_balle.jpg", "Trou de balle"),
    )
]
def predict(image):
    """Classify a citrus-leaf image with the fine-tuned ViT model.

    Args:
        image: The uploaded image. The UI's ``gr.Image`` is created with
            ``type="pil"``, so this is a PIL image, or ``None`` when the user
            clicks *Classify* without uploading anything.

    Returns:
        A ``(label, confidence_text)`` tuple:
        - ``label`` is the class name looked up in ``id2label``.
        - ``confidence_text`` is a string such as ``"Confidence: 0.97"``
          (the softmax probability of the predicted class, 2 decimals).

    Raises:
        gr.Error: If no image was provided. Gradio then shows a readable
            message instead of a preprocessing traceback.
    """
    if image is None:
        raise gr.Error("Please upload an image of a citrus leaf first.")

    # Convert the image into model-ready PyTorch tensors.
    inputs = feature_extractor(images=image, return_tensors="pt")

    # Pure inference: disable autograd so no computation graph is built and
    # no activations are retained for a backward pass.
    with torch.no_grad():
        logits = model(**inputs).logits

    # One image was passed, so row 0 holds its scores.
    predicted_class_idx = logits.argmax(-1).item()
    confidence_score = torch.nn.functional.softmax(logits, dim=-1)[0, predicted_class_idx].item()

    # id2label is keyed by the class index as a string.
    predicted_label = id2label[str(predicted_class_idx)]
    return predicted_label, f"Confidence: {confidence_score:.2f}"
# Build the Gradio UI: a description, a gallery of example images, an upload
# widget, and a button that runs predict() on the uploaded image.
with gr.Blocks() as demo:
    gr.Markdown("""
<h1>Citrus Leaf Disease Classification</h1>
<p>Upload an image of a citrus leaf to classify its disease.</p>
<p>Supported diseases:</p>
<ul>
<li>Aleurocanthus spiniferus</li>
<li>Chancre citrique</li>
<li>Cochenille blanche</li>
<li>Dépérissement des agrumes</li>
<li>Feuille saine</li>
<li>Jaunissement des feuilles</li>
<li>Maladie de l'oïdium</li>
<li>Maladie du dragon jaune</li>
<li>Mineuse des agrumes</li>
<li>Trou de balle</li>
</ul>
<p>Example images:</p>
""")
    # gr.Gallery accepts images or (image, caption) tuples, not dicts. Passing
    # the raw image_slider dicts makes Gallery raise "Cannot process type as
    # image", so convert each entry to that tuple shape.
    image_slider_component = gr.Gallery(
        [(example["image"], example["label"]) for example in image_slider],
        label="Example Images",
    )
    image_input = gr.Image(type="pil", label="Upload Image")
    label_output = gr.Textbox(label="Prediction")
    # predict() returns (label, confidence_text). Give each value its own
    # component; with a single output, the Textbox would show the tuple's repr.
    confidence_output = gr.Textbox(label="Confidence")
    btn = gr.Button("Classify")
    btn.click(fn=predict, inputs=image_input, outputs=[label_output, confidence_output])
demo.launch()