File size: 2,344 Bytes
8d12dce
bc18618
 
 
 
1af1578
efac948
6bb1d3a
bc18618
 
1af1578
bc18618
 
 
 
 
 
 
 
 
 
 
 
 
 
7e0ee60
bc18618
 
7e0ee60
bc18618
 
7e0ee60
bc18618
6bb1d3a
7e0ee60
 
6bb1d3a
b26ba26
 
7e0ee60
b26ba26
 
da42d5c
7e0ee60
b26ba26
da42d5c
fc5c4f6
1af1578
9da9b42
 
fc5c4f6
00b5ecc
1af1578
 
 
 
f0ffeca
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import gradio as gr
import torch
from transformers import ViTFeatureExtractor, ViTForImageClassification
from PIL import Image

# Définir le modèle et le feature extractor
model_name ="KhadijaAsehnoune12/ViTOrangeLeafDiseaseClassifier"
model = ViTForImageClassification.from_pretrained(model_name, num_labels=10, ignore_mismatched_sizes=True)
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)

# Définir la mappage des labels
id2label = {
    "0": "Aleurocanthus spiniferus",
    "1": "Chancre citrique",
    "2": "Cochenille blanche",
    "3": "Dépérissement des agrumes",
    "4": "Feuille saine",
    "5": "Jaunissement des feuilles",
    "6": "Maladie de l'oïdium",
    "7": "Maladie du dragon jaune",
    "8": "Mineuse des agrumes",
    "9": "Trou de balle"
}

def predict(image):
    # Prétraiter l'image
    inputs = feature_extractor(images=image, return_tensors="pt")
    
    # Passage en avant dans le modèle
    outputs = model(**inputs)
    
    # Obtenir les logits
    logits = outputs.logits
    
    # Calculer les scores de confiance avec softmax
    probs = torch.nn.functional.softmax(logits, dim=-1)[0]
    
    # Obtenir l'indice de la classe la plus probable
    predicted_class_idx = probs.argmax().item()
    
    # Obtenir le label et le score de confiance de la classe la plus probable
    predicted_label = id2label[str(predicted_class_idx)]
    confidence_score = probs[predicted_class_idx].item() * 100  # Multiplie par 100 pour obtenir un pourcentage
    
    # Retourner le label et le score de confiance
    return f"{predicted_label}: {confidence_score:.2f}%"

# Créer l'interface Gradio
image = gr.Image(type="pil") 
label = gr.Textbox(label="Prediction")

gr.Interface(fn=predict,
             inputs=image,
             outputs=label,
             title="Classification des maladies des agrumes",
             description="Téléchargez une image d'une feuille d'agrume pour classer sa maladie. Le modèle est entraîné sur les maladies suivantes : Aleurocanthus spiniferus, Chancre citrique, Cochenille blanche, Dépérissement des agrumes, Feuille saine, Jaunissement des feuilles, Maladie de l'oïdium, Maladie du dragon jaune, Mineuse des agrumes, Trou de balle.",
             examples=["maladie_du_dragon_jaune.jpg", "critique.jpg","feuille_saine.jpg"]).launch(share=True)