"""Gradio demo that classifies citrus-leaf diseases with a fine-tuned ViT model."""

import gradio as gr
import torch
from PIL import Image
from transformers import ViTFeatureExtractor, ViTForImageClassification

# Define the model and feature extractor
model_name = "KhadijaAsehnoune12/OrangeLeafDiseaseDetector"
# NOTE(review): ignore_mismatched_sizes=True makes from_pretrained silently
# re-initialise the classification head when the checkpoint's head does not
# have 10 outputs, which would yield effectively random predictions. Confirm
# the checkpoint really has 10 labels and consider dropping the flag.
model = ViTForImageClassification.from_pretrained(
    model_name, num_labels=10, ignore_mismatched_sizes=True
)
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)

# Class index (as a string key) -> label shown in the UI.
# Presumably follows the class order used in training; confirm against
# model.config.id2label.
id2label = {
    "0": "Aleurocanthus spiniferus",
    "1": "Chancre citrique",
    "2": "Cochenille blanche",
    "3": "Dépérissement des agrumes",
    "4": "Feuille saine",
    "5": "Jaunissement des feuilles",
    "6": "Maladie de l'oïdium",
    "7": "Maladie du dragon jaune",
    "8": "Mineuse des agrumes",
    # NOTE(review): "Trou de balle" is also vulgar French slang; a term such as
    # "Criblure" (shot-hole disease) may be intended — confirm wording.
    "9": "Trou de balle",
}

# One example image per class for the gallery. Paths are relative to the
# directory the app is launched from.
image_slider = [
    {"image": "aleurocanthus_spiniferus.jpg", "label": "Aleurocanthus spiniferus"},
    {"image": "chancre_citrique.jpg", "label": "Chancre citrique"},
    {"image": "cochenille_blanche.jpg", "label": "Cochenille blanche"},
    {"image": "deperissement_des_agrumes.jpg", "label": "Dépérissement des agrumes"},
    {"image": "feuille_saine.jpg", "label": "Feuille saine"},
    {"image": "jaunissement_des_feuilles.jpg", "label": "Jaunissement des feuilles"},
    {"image": "maladie_de_loidium.jpg", "label": "Maladie de l'oïdium"},
    {"image": "maladie_du_dragon_jaune.jpg", "label": "Maladie du dragon jaune"},
    {"image": "mineuse_des_agrumes.jpg", "label": "Mineuse des agrumes"},
    {"image": "trou_de_balle.jpg", "label": "Trou de balle"},
]


def predict(image):
    """Classify a citrus-leaf image.

    Args:
        image: PIL image from the upload component, or None when the user
            clicks Classify without uploading anything.

    Returns:
        A ``(label, confidence_text)`` tuple, e.g.
        ``("Feuille saine", "Confidence: 0.97")``.

    Raises:
        gr.Error: if no image was provided; Gradio shows it to the user.
    """
    if image is None:
        raise gr.Error("Please upload an image before clicking Classify.")

    # Preprocess the image. The processor normalises three channels, so make
    # sure RGBA / greyscale uploads are converted first.
    inputs = feature_extractor(images=image.convert("RGB"), return_tensors="pt")

    # Forward pass; inference only, so skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Predicted class and its softmax probability (batch of one image).
    predicted_class_idx = logits.argmax(-1).item()
    probabilities = torch.nn.functional.softmax(logits, dim=-1)
    confidence_score = probabilities[0, predicted_class_idx].item()

    predicted_label = id2label[str(predicted_class_idx)]
    return predicted_label, f"Confidence: {confidence_score:.2f}"


# Markdown bullet list of every class the model can predict.
_supported_diseases = "\n".join(f"- {label}" for label in id2label.values())

# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown(
        "# Citrus Leaf Disease Classification\n\n"
        "Upload an image of a citrus leaf to classify its disease.\n\n"
        f"Supported diseases:\n\n{_supported_diseases}\n\n"
        "Example images:\n"
    )
    # gr.Gallery accepts images or (image, caption) pairs, not dicts.
    image_slider_component = gr.Gallery(
        [(item["image"], item["label"]) for item in image_slider],
        label="Example Images",
    )
    image_input = gr.Image(type="pil", label="Upload Image")
    label_output = gr.Textbox(label="Prediction")
    confidence_output = gr.Textbox(label="Confidence")
    btn = gr.Button("Classify")
    # predict returns two values, so it needs two output components.
    btn.click(
        fn=predict,
        inputs=image_input,
        outputs=[label_output, confidence_output],
    )

if __name__ == "__main__":
    demo.launch()