KhadijaAsehnoune12 commited on
Commit
c9d6d7a
·
verified ·
1 Parent(s): 0e6738f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -14
app.py CHANGED
@@ -3,13 +3,14 @@ import torch
3
  from transformers import ViTFeatureExtractor, ViTForImageClassification
4
  from PIL import Image
5
  import numpy as np
 
6
 
7
- # Définir le modèle et le feature extractor
8
  model_name ="KhadijaAsehnoune12/ViTOrangeLeafDiseaseClassifier"
9
  model = ViTForImageClassification.from_pretrained(model_name, num_labels=10, ignore_mismatched_sizes=True)
10
  feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
11
 
12
- # Définir la mappage des labels
13
  id2label = {
14
  "0": "Aleurocanthus spiniferus",
15
  "1": "Chancre citrique",
@@ -23,34 +24,52 @@ id2label = {
23
  "9": "Trou de balle"
24
  }
25
 
26
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
  def predict(image):
29
-
 
30
 
31
- # Prétraiter l'image
32
  inputs = feature_extractor(images=image, return_tensors="pt")
33
 
34
- # Passage en avant dans le modèle
35
  outputs = model(**inputs)
36
 
37
- # Obtenir les logits
38
  logits = outputs.logits
39
 
40
- # Calculer les scores de confiance avec softmax
41
  probs = torch.nn.functional.softmax(logits, dim=-1)[0]
42
 
43
- # Obtenir l'indice de la classe la plus probable
44
  predicted_class_idx = probs.argmax().item()
45
 
46
- # Obtenir le label et le score de confiance de la classe la plus probable
47
  predicted_label = id2label[str(predicted_class_idx)]
48
- confidence_score = probs[predicted_class_idx].item() * 100 # Multiplie par 100 pour obtenir un pourcentage
49
 
50
- # Retourner le label et le score de confiance
51
  return f"{predicted_label}: {confidence_score:.2f}%"
52
 
53
- # Créer l'interface Gradio
54
  image = gr.Image(type="pil")
55
  label = gr.Textbox(label="Prediction")
56
 
@@ -59,4 +78,4 @@ gr.Interface(fn=predict,
59
  outputs=label,
60
  title="Classification des maladies des agrumes",
61
  description="Téléchargez une image d'une feuille d'agrume pour classer sa maladie. Le modèle est entraîné sur les maladies suivantes : Aleurocanthus spiniferus, Chancre citrique, Cochenille blanche, Dépérissement des agrumes, Feuille saine, Jaunissement des feuilles, Maladie de l'oïdium, Maladie du dragon jaune, Mineuse des agrumes, Trou de balle.",
62
- examples=["maladie_du_dragon_jaune.jpg", "critique.jpg","feuille_saine.jpg"]).launch(share=True)
 
3
  from transformers import ViTFeatureExtractor, ViTForImageClassification
4
  from PIL import Image
5
  import numpy as np
6
+ import rembg
7
 
8
+ # Define the model and feature extractor
9
  model_name ="KhadijaAsehnoune12/ViTOrangeLeafDiseaseClassifier"
10
  model = ViTForImageClassification.from_pretrained(model_name, num_labels=10, ignore_mismatched_sizes=True)
11
  feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
12
 
13
+ # Define the label mapping
14
  id2label = {
15
  "0": "Aleurocanthus spiniferus",
16
  "1": "Chancre citrique",
 
24
  "9": "Trou de balle"
25
  }
26
 
27
+ def remove_background(image):
28
+ # Convert the image to RGBA
29
+ image = image.convert("RGBA")
30
+
31
+ # Remove the background
32
+ image_np = np.array(image)
33
+ output_np = rembg.remove(image_np)
34
+
35
+ # Create a white background image
36
+ white_bg = Image.new("RGBA", image.size, "WHITE")
37
+
38
+ # Composite the original image over the white background
39
+ output_image = Image.alpha_composite(white_bg, Image.fromarray(output_np))
40
+
41
+ # Convert back to RGB
42
+ output_image = output_image.convert("RGB")
43
+
44
+ return output_image
45
 
46
  def predict(image):
47
+ # Remove the background
48
+ image = remove_background(image)
49
 
50
+ # Preprocess the image
51
  inputs = feature_extractor(images=image, return_tensors="pt")
52
 
53
+ # Forward pass through the model
54
  outputs = model(**inputs)
55
 
56
+ # Get the logits
57
  logits = outputs.logits
58
 
59
+ # Calculate confidence scores with softmax
60
  probs = torch.nn.functional.softmax(logits, dim=-1)[0]
61
 
62
+ # Get the index of the most probable class
63
  predicted_class_idx = probs.argmax().item()
64
 
65
+ # Get the label and confidence score of the most probable class
66
  predicted_label = id2label[str(predicted_class_idx)]
67
+ confidence_score = probs[predicted_class_idx].item() * 100 # Multiply by 100 to get a percentage
68
 
69
+ # Return the label and confidence score
70
  return f"{predicted_label}: {confidence_score:.2f}%"
71
 
72
+ # Create the Gradio interface
73
  image = gr.Image(type="pil")
74
  label = gr.Textbox(label="Prediction")
75
 
 
78
  outputs=label,
79
  title="Classification des maladies des agrumes",
80
  description="Téléchargez une image d'une feuille d'agrume pour classer sa maladie. Le modèle est entraîné sur les maladies suivantes : Aleurocanthus spiniferus, Chancre citrique, Cochenille blanche, Dépérissement des agrumes, Feuille saine, Jaunissement des feuilles, Maladie de l'oïdium, Maladie du dragon jaune, Mineuse des agrumes, Trou de balle.",
81
+ examples=["maladie_du_dragon_jaune.jpg", "critique.jpg", "feuille_saine.jpg"]).launch(share=True)