File size: 3,266 Bytes
8d12dce
bc18618
 
 
 
 
 
6bb1d3a
bc18618
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
543d7d6
 
 
 
 
 
 
 
 
 
 
 
 
 
bc18618
 
 
 
 
 
 
6bb1d3a
bc18618
 
185be2e
6bb1d3a
bc18618
 
6bb1d3a
185be2e
 
bc18618
543d7d6
 
 
c1508e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
543d7d6
c1508e5
543d7d6
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import gradio as gr
import torch
from transformers import ViTFeatureExtractor, ViTForImageClassification
from PIL import Image

# Load the fine-tuned ViT classifier and its matching preprocessor from the
# Hugging Face Hub checkpoint named below.
model_name = "KhadijaAsehnoune12/OrangeLeafDiseaseDetector"
model = ViTForImageClassification.from_pretrained(
    model_name,
    num_labels=10,
    # NOTE(review): if the checkpoint's classification head does not already
    # have 10 outputs, this flag makes transformers re-initialize that head
    # with random weights instead of failing loudly — confirm the checkpoint
    # was trained with these 10 classes.
    ignore_mismatched_sizes=True,
)
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)

# Class-index -> disease-name mapping. Keys are the stringified indices
# "0".."9" because `predict` looks labels up with `str(argmax_index)`, so the
# tuple order below must match the classifier's output order.
id2label = {
    str(index): name
    for index, name in enumerate(
        (
            "Aleurocanthus spiniferus",
            "Chancre citrique",
            "Cochenille blanche",
            "Dépérissement des agrumes",
            "Feuille saine",
            "Jaunissement des feuilles",
            "Maladie de l'oïdium",
            "Maladie du dragon jaune",
            "Mineuse des agrumes",
            "Trou de balle",
        )
    )
}

# One example image per class for the gallery: each entry pairs an image
# file name with its caption. The files are referenced by relative path;
# presumably they sit next to this script — confirm they are shipped with it.
image_slider = [
    {"image": file_name, "label": caption}
    for file_name, caption in (
        ("aleurocanthus_spiniferus.jpg", "Aleurocanthus spiniferus"),
        ("chancre_citrique.jpg", "Chancre citrique"),
        ("cochenille_blanche.jpg", "Cochenille blanche"),
        ("deperissement_des_agrumes.jpg", "Dépérissement des agrumes"),
        ("feuille_saine.jpg", "Feuille saine"),
        ("jaunissement_des_feuilles.jpg", "Jaunissement des feuilles"),
        ("maladie_de_loidium.jpg", "Maladie de l'oïdium"),
        ("maladie_du_dragon_jaune.jpg", "Maladie du dragon jaune"),
        ("mineuse_des_agrumes.jpg", "Mineuse des agrumes"),
        ("trou_de_balle.jpg", "Trou de balle"),
    )
]

def predict(image):
    """Classify a citrus-leaf image with the fine-tuned ViT model.

    Args:
        image: The image to classify, as delivered by
            ``gr.Image(type="pil")``, or ``None`` when "Classify" is clicked
            without an upload.

    Returns:
        A ``(label, confidence_text)`` tuple: the predicted class name taken
        from ``id2label`` and a string such as ``"Confidence: 0.87"`` holding
        the softmax probability of that class. If ``image`` is ``None`` the
        tuple is ``("Please upload an image.", "")``.
    """
    # With no upload the input is None, which the feature extractor cannot
    # preprocess; report it to the user instead of raising.
    if image is None:
        return "Please upload an image.", ""

    inputs = feature_extractor(images=image, return_tensors="pt")

    # Pure inference: disable autograd so no computation graph is built.
    with torch.no_grad():
        logits = model(**inputs).logits

    # A single image is passed in, so row 0 holds its class scores.
    predicted_class_idx = logits.argmax(-1).item()
    probabilities = torch.nn.functional.softmax(logits, dim=-1)
    confidence_score = probabilities[0, predicted_class_idx].item()

    # id2label is keyed by the stringified class index.
    predicted_label = id2label[str(predicted_class_idx)]
    return predicted_label, f"Confidence: {confidence_score:.2f}"

# Build the Gradio interface: a header listing the supported classes, a
# gallery of example images, an upload widget, and a button that runs
# `predict` and shows its label and confidence.
with gr.Blocks() as demo:
    # Generate the class list from id2label so the page cannot drift out of
    # sync with the labels `predict` actually returns.
    disease_items = "\n".join(f"<li>{name}</li>" for name in id2label.values())
    gr.Markdown(
        "<h1>Citrus Leaf Disease Classification</h1>\n"
        "<p>Upload an image of a citrus leaf to classify its disease.</p>\n"
        "<p>Supported diseases:</p>\n"
        f"<ul>\n{disease_items}\n</ul>\n"
        "<p>Example images:</p>"
    )
    # gr.Gallery accepts images or (image, caption) tuples, not dicts, so
    # convert each example entry to that form.
    image_slider_component = gr.Gallery(
        [(example["image"], example["label"]) for example in image_slider],
        label="Example Images",
    )
    image_input = gr.Image(type="pil", label="Upload Image")
    label_output = gr.Textbox(label="Prediction")
    # `predict` returns two values (label, confidence text). Give each one its
    # own component instead of stringifying the tuple into a single textbox.
    confidence_output = gr.Textbox(label="Confidence")
    btn = gr.Button("Classify")

    btn.click(
        fn=predict,
        inputs=image_input,
        outputs=[label_output, confidence_output],
    )

demo.launch()