KhadijaAsehnoune12
commited on
Update app.py
Browse files
app.py
CHANGED
@@ -3,13 +3,14 @@ import torch
|
|
3 |
from transformers import ViTFeatureExtractor, ViTForImageClassification
|
4 |
from PIL import Image
|
5 |
import numpy as np
|
|
|
6 |
|
7 |
-
#
|
8 |
model_name ="KhadijaAsehnoune12/ViTOrangeLeafDiseaseClassifier"
|
9 |
model = ViTForImageClassification.from_pretrained(model_name, num_labels=10, ignore_mismatched_sizes=True)
|
10 |
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
|
11 |
|
12 |
-
#
|
13 |
id2label = {
|
14 |
"0": "Aleurocanthus spiniferus",
|
15 |
"1": "Chancre citrique",
|
@@ -23,34 +24,52 @@ id2label = {
|
|
23 |
"9": "Trou de balle"
|
24 |
}
|
25 |
|
26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
|
28 |
def predict(image):
|
29 |
-
|
|
|
30 |
|
31 |
-
#
|
32 |
inputs = feature_extractor(images=image, return_tensors="pt")
|
33 |
|
34 |
-
#
|
35 |
outputs = model(**inputs)
|
36 |
|
37 |
-
#
|
38 |
logits = outputs.logits
|
39 |
|
40 |
-
#
|
41 |
probs = torch.nn.functional.softmax(logits, dim=-1)[0]
|
42 |
|
43 |
-
#
|
44 |
predicted_class_idx = probs.argmax().item()
|
45 |
|
46 |
-
#
|
47 |
predicted_label = id2label[str(predicted_class_idx)]
|
48 |
-
confidence_score = probs[predicted_class_idx].item() * 100 #
|
49 |
|
50 |
-
#
|
51 |
return f"{predicted_label}: {confidence_score:.2f}%"
|
52 |
|
53 |
-
#
|
54 |
image = gr.Image(type="pil")
|
55 |
label = gr.Textbox(label="Prediction")
|
56 |
|
@@ -59,4 +78,4 @@ gr.Interface(fn=predict,
|
|
59 |
outputs=label,
|
60 |
title="Classification des maladies des agrumes",
|
61 |
description="Téléchargez une image d'une feuille d'agrume pour classer sa maladie. Le modèle est entraîné sur les maladies suivantes : Aleurocanthus spiniferus, Chancre citrique, Cochenille blanche, Dépérissement des agrumes, Feuille saine, Jaunissement des feuilles, Maladie de l'oïdium, Maladie du dragon jaune, Mineuse des agrumes, Trou de balle.",
|
62 |
-
examples=["maladie_du_dragon_jaune.jpg", "critique.jpg","feuille_saine.jpg"]).launch(share=True)
|
|
|
3 |
from transformers import ViTFeatureExtractor, ViTForImageClassification
|
4 |
from PIL import Image
|
5 |
import numpy as np
|
6 |
+
import rembg
|
7 |
|
8 |
+
# Define the model and feature extractor
|
9 |
model_name ="KhadijaAsehnoune12/ViTOrangeLeafDiseaseClassifier"
|
10 |
model = ViTForImageClassification.from_pretrained(model_name, num_labels=10, ignore_mismatched_sizes=True)
|
11 |
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
|
12 |
|
13 |
+
# Define the label mapping
|
14 |
id2label = {
|
15 |
"0": "Aleurocanthus spiniferus",
|
16 |
"1": "Chancre citrique",
|
|
|
24 |
"9": "Trou de balle"
|
25 |
}
|
26 |
|
27 |
+
def remove_background(image):
|
28 |
+
# Convert the image to RGBA
|
29 |
+
image = image.convert("RGBA")
|
30 |
+
|
31 |
+
# Remove the background
|
32 |
+
image_np = np.array(image)
|
33 |
+
output_np = rembg.remove(image_np)
|
34 |
+
|
35 |
+
# Create a white background image
|
36 |
+
white_bg = Image.new("RGBA", image.size, "WHITE")
|
37 |
+
|
38 |
+
# Composite the original image over the white background
|
39 |
+
output_image = Image.alpha_composite(white_bg, Image.fromarray(output_np))
|
40 |
+
|
41 |
+
# Convert back to RGB
|
42 |
+
output_image = output_image.convert("RGB")
|
43 |
+
|
44 |
+
return output_image
|
45 |
|
46 |
def predict(image):
|
47 |
+
# Remove the background
|
48 |
+
image = remove_background(image)
|
49 |
|
50 |
+
# Preprocess the image
|
51 |
inputs = feature_extractor(images=image, return_tensors="pt")
|
52 |
|
53 |
+
# Forward pass through the model
|
54 |
outputs = model(**inputs)
|
55 |
|
56 |
+
# Get the logits
|
57 |
logits = outputs.logits
|
58 |
|
59 |
+
# Calculate confidence scores with softmax
|
60 |
probs = torch.nn.functional.softmax(logits, dim=-1)[0]
|
61 |
|
62 |
+
# Get the index of the most probable class
|
63 |
predicted_class_idx = probs.argmax().item()
|
64 |
|
65 |
+
# Get the label and confidence score of the most probable class
|
66 |
predicted_label = id2label[str(predicted_class_idx)]
|
67 |
+
confidence_score = probs[predicted_class_idx].item() * 100 # Multiply by 100 to get a percentage
|
68 |
|
69 |
+
# Return the label and confidence score
|
70 |
return f"{predicted_label}: {confidence_score:.2f}%"
|
71 |
|
72 |
+
# Create the Gradio interface
|
73 |
image = gr.Image(type="pil")
|
74 |
label = gr.Textbox(label="Prediction")
|
75 |
|
|
|
78 |
outputs=label,
|
79 |
title="Classification des maladies des agrumes",
|
80 |
description="Téléchargez une image d'une feuille d'agrume pour classer sa maladie. Le modèle est entraîné sur les maladies suivantes : Aleurocanthus spiniferus, Chancre citrique, Cochenille blanche, Dépérissement des agrumes, Feuille saine, Jaunissement des feuilles, Maladie de l'oïdium, Maladie du dragon jaune, Mineuse des agrumes, Trou de balle.",
|
81 |
+
examples=["maladie_du_dragon_jaune.jpg", "critique.jpg", "feuille_saine.jpg"]).launch(share=True)
|