KhadijaAsehnoune12 commited on
Commit
5da6a11
·
verified ·
1 Parent(s): efac948

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -1
app.py CHANGED
@@ -2,6 +2,7 @@ import gradio as gr
2
  import torch
3
  from transformers import ViTFeatureExtractor, ViTForImageClassification
4
  from PIL import Image
 
5
 
6
  # Définir le modèle et le feature extractor
7
  model_name ="KhadijaAsehnoune12/ViTOrangeLeafDiseaseClassifier"
@@ -22,7 +23,28 @@ id2label = {
22
  "9": "Trou de balle"
23
  }
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  def predict(image):
 
 
 
26
  # Prétraiter l'image
27
  inputs = feature_extractor(images=image, return_tensors="pt")
28
 
@@ -54,4 +76,4 @@ gr.Interface(fn=predict,
54
  outputs=label,
55
  title="Classification des maladies des agrumes",
56
  description="Téléchargez une image d'une feuille d'agrume pour classer sa maladie. Le modèle est entraîné sur les maladies suivantes : Aleurocanthus spiniferus, Chancre citrique, Cochenille blanche, Dépérissement des agrumes, Feuille saine, Jaunissement des feuilles, Maladie de l'oïdium, Maladie du dragon jaune, Mineuse des agrumes, Trou de balle.",
57
- examples=["maladie_du_dragon_jaune.jpg", "critique.jpg","feuille_saine.jpg"]).launch(share=True)
 
2
  import torch
3
  from transformers import ViTFeatureExtractor, ViTForImageClassification
4
  from PIL import Image
5
+ import numpy as np
6
 
7
  # Définir le modèle et le feature extractor
8
  model_name ="KhadijaAsehnoune12/ViTOrangeLeafDiseaseClassifier"
 
23
  "9": "Trou de balle"
24
  }
25
 
26
+ def remove_background(image):
27
+ # Convert image to numpy array
28
+ img_array = np.array(image)
29
+
30
+ # Calculate the threshold to separate background from foreground
31
+ threshold = np.mean(img_array) + 2 * np.std(img_array)
32
+
33
+ # Create a mask to separate background from foreground
34
+ mask = img_array > threshold
35
+
36
+ # Replace background with white
37
+ img_array[mask] = 255
38
+
39
+ # Convert back to PIL image
40
+ image = Image.fromarray(img_array)
41
+
42
+ return image
43
+
44
  def predict(image):
45
+ # Remove background and replace with white
46
+ image = remove_background(image)
47
+
48
  # Prétraiter l'image
49
  inputs = feature_extractor(images=image, return_tensors="pt")
50
 
 
76
  outputs=label,
77
  title="Classification des maladies des agrumes",
78
  description="Téléchargez une image d'une feuille d'agrume pour classer sa maladie. Le modèle est entraîné sur les maladies suivantes : Aleurocanthus spiniferus, Chancre citrique, Cochenille blanche, Dépérissement des agrumes, Feuille saine, Jaunissement des feuilles, Maladie de l'oïdium, Maladie du dragon jaune, Mineuse des agrumes, Trou de balle.",
79
+ examples=["maladie_du_dragon_jaune.jpg", "critique.jpg","feuille_saine.jpg"]).launch(share=True)