File size: 2,944 Bytes
8d12dce
bc18618
 
 
5da6a11
bc18618
1af1578
efac948
6bb1d3a
bc18618
 
1af1578
bc18618
 
 
 
 
 
 
 
 
 
 
 
 
5da6a11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bc18618
5da6a11
 
 
7e0ee60
bc18618
 
7e0ee60
bc18618
 
7e0ee60
bc18618
6bb1d3a
7e0ee60
 
6bb1d3a
b26ba26
 
7e0ee60
b26ba26
 
da42d5c
7e0ee60
b26ba26
da42d5c
fc5c4f6
1af1578
9da9b42
 
fc5c4f6
00b5ecc
1af1578
 
 
 
5da6a11
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import gradio as gr
import torch
from transformers import ViTFeatureExtractor, ViTForImageClassification
from PIL import Image
import numpy as np

# Définir le modèle et le feature extractor
model_name ="KhadijaAsehnoune12/ViTOrangeLeafDiseaseClassifier"
model = ViTForImageClassification.from_pretrained(model_name, num_labels=10, ignore_mismatched_sizes=True)
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)

# Définir la mappage des labels
id2label = {
    "0": "Aleurocanthus spiniferus",
    "1": "Chancre citrique",
    "2": "Cochenille blanche",
    "3": "Dépérissement des agrumes",
    "4": "Feuille saine",
    "5": "Jaunissement des feuilles",
    "6": "Maladie de l'oïdium",
    "7": "Maladie du dragon jaune",
    "8": "Mineuse des agrumes",
    "9": "Trou de balle"
}

def remove_background(image):
    # Convert image to numpy array
    img_array = np.array(image)
    
    # Calculate the threshold to separate background from foreground
    threshold = np.mean(img_array) + 2 * np.std(img_array)
    
    # Create a mask to separate background from foreground
    mask = img_array > threshold
    
    # Replace background with white
    img_array[mask] = 255
    
    # Convert back to PIL image
    image = Image.fromarray(img_array)
    
    return image

def predict(image):
    # Remove background and replace with white
    image = remove_background(image)
    
    # Prétraiter l'image
    inputs = feature_extractor(images=image, return_tensors="pt")
    
    # Passage en avant dans le modèle
    outputs = model(**inputs)
    
    # Obtenir les logits
    logits = outputs.logits
    
    # Calculer les scores de confiance avec softmax
    probs = torch.nn.functional.softmax(logits, dim=-1)[0]
    
    # Obtenir l'indice de la classe la plus probable
    predicted_class_idx = probs.argmax().item()
    
    # Obtenir le label et le score de confiance de la classe la plus probable
    predicted_label = id2label[str(predicted_class_idx)]
    confidence_score = probs[predicted_class_idx].item() * 100  # Multiplie par 100 pour obtenir un pourcentage
    
    # Retourner le label et le score de confiance
    return f"{predicted_label}: {confidence_score:.2f}%"

# Créer l'interface Gradio
image = gr.Image(type="pil") 
label = gr.Textbox(label="Prediction")

gr.Interface(fn=predict,
             inputs=image,
             outputs=label,
             title="Classification des maladies des agrumes",
             description="Téléchargez une image d'une feuille d'agrume pour classer sa maladie. Le modèle est entraîné sur les maladies suivantes : Aleurocanthus spiniferus, Chancre citrique, Cochenille blanche, Dépérissement des agrumes, Feuille saine, Jaunissement des feuilles, Maladie de l'oïdium, Maladie du dragon jaune, Mineuse des agrumes, Trou de balle.",
             examples=["maladie_du_dragon_jaune.jpg", "critique.jpg","feuille_saine.jpg"]).launch(share=True)