Update app.py
Browse files
app.py
CHANGED
@@ -2,6 +2,7 @@ import gradio as gr
|
|
2 |
import torch
|
3 |
from transformers import ViTFeatureExtractor, ViTForImageClassification
|
4 |
from PIL import Image
|
|
|
5 |
|
6 |
# Définir le modèle et le feature extractor
|
7 |
model_name ="KhadijaAsehnoune12/ViTOrangeLeafDiseaseClassifier"
|
@@ -22,7 +23,28 @@ id2label = {
|
|
22 |
"9": "Trou de balle"
|
23 |
}
|
24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
def predict(image):
|
|
|
|
|
|
|
26 |
# Prétraiter l'image
|
27 |
inputs = feature_extractor(images=image, return_tensors="pt")
|
28 |
|
@@ -54,4 +76,4 @@ gr.Interface(fn=predict,
|
|
54 |
outputs=label,
|
55 |
title="Classification des maladies des agrumes",
|
56 |
description="Téléchargez une image d'une feuille d'agrume pour classer sa maladie. Le modèle est entraîné sur les maladies suivantes : Aleurocanthus spiniferus, Chancre citrique, Cochenille blanche, Dépérissement des agrumes, Feuille saine, Jaunissement des feuilles, Maladie de l'oïdium, Maladie du dragon jaune, Mineuse des agrumes, Trou de balle.",
|
57 |
-
examples=["maladie_du_dragon_jaune.jpg", "critique.jpg","feuille_saine.jpg"]).launch(share=True)
|
|
|
2 |
import torch
|
3 |
from transformers import ViTFeatureExtractor, ViTForImageClassification
|
4 |
from PIL import Image
|
5 |
+
import numpy as np
|
6 |
|
7 |
# Définir le modèle et le feature extractor
|
8 |
model_name ="KhadijaAsehnoune12/ViTOrangeLeafDiseaseClassifier"
|
|
|
23 |
"9": "Trou de balle"
|
24 |
}
|
25 |
|
26 |
+
def remove_background(image):
|
27 |
+
# Convert image to numpy array
|
28 |
+
img_array = np.array(image)
|
29 |
+
|
30 |
+
# Calculate the threshold to separate background from foreground
|
31 |
+
threshold = np.mean(img_array) + 2 * np.std(img_array)
|
32 |
+
|
33 |
+
# Create a mask to separate background from foreground
|
34 |
+
mask = img_array > threshold
|
35 |
+
|
36 |
+
# Replace background with white
|
37 |
+
img_array[mask] = 255
|
38 |
+
|
39 |
+
# Convert back to PIL image
|
40 |
+
image = Image.fromarray(img_array)
|
41 |
+
|
42 |
+
return image
|
43 |
+
|
44 |
def predict(image):
|
45 |
+
# Remove background and replace with white
|
46 |
+
image = remove_background(image)
|
47 |
+
|
48 |
# Prétraiter l'image
|
49 |
inputs = feature_extractor(images=image, return_tensors="pt")
|
50 |
|
|
|
76 |
outputs=label,
|
77 |
title="Classification des maladies des agrumes",
|
78 |
description="Téléchargez une image d'une feuille d'agrume pour classer sa maladie. Le modèle est entraîné sur les maladies suivantes : Aleurocanthus spiniferus, Chancre citrique, Cochenille blanche, Dépérissement des agrumes, Feuille saine, Jaunissement des feuilles, Maladie de l'oïdium, Maladie du dragon jaune, Mineuse des agrumes, Trou de balle.",
|
79 |
+
examples=["maladie_du_dragon_jaune.jpg", "critique.jpg","feuille_saine.jpg"]).launch(share=True)
|