KhadijaAsehnoune12's picture
Update app.py
5da6a11 verified
raw
history blame
2.94 kB
import gradio as gr
import torch
from transformers import ViTFeatureExtractor, ViTForImageClassification
from PIL import Image
import numpy as np
# Définir le modèle et le feature extractor
model_name ="KhadijaAsehnoune12/ViTOrangeLeafDiseaseClassifier"
model = ViTForImageClassification.from_pretrained(model_name, num_labels=10, ignore_mismatched_sizes=True)
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
# Définir la mappage des labels
id2label = {
"0": "Aleurocanthus spiniferus",
"1": "Chancre citrique",
"2": "Cochenille blanche",
"3": "Dépérissement des agrumes",
"4": "Feuille saine",
"5": "Jaunissement des feuilles",
"6": "Maladie de l'oïdium",
"7": "Maladie du dragon jaune",
"8": "Mineuse des agrumes",
"9": "Trou de balle"
}
def remove_background(image):
# Convert image to numpy array
img_array = np.array(image)
# Calculate the threshold to separate background from foreground
threshold = np.mean(img_array) + 2 * np.std(img_array)
# Create a mask to separate background from foreground
mask = img_array > threshold
# Replace background with white
img_array[mask] = 255
# Convert back to PIL image
image = Image.fromarray(img_array)
return image
def predict(image):
# Remove background and replace with white
image = remove_background(image)
# Prétraiter l'image
inputs = feature_extractor(images=image, return_tensors="pt")
# Passage en avant dans le modèle
outputs = model(**inputs)
# Obtenir les logits
logits = outputs.logits
# Calculer les scores de confiance avec softmax
probs = torch.nn.functional.softmax(logits, dim=-1)[0]
# Obtenir l'indice de la classe la plus probable
predicted_class_idx = probs.argmax().item()
# Obtenir le label et le score de confiance de la classe la plus probable
predicted_label = id2label[str(predicted_class_idx)]
confidence_score = probs[predicted_class_idx].item() * 100 # Multiplie par 100 pour obtenir un pourcentage
# Retourner le label et le score de confiance
return f"{predicted_label}: {confidence_score:.2f}%"
# Créer l'interface Gradio
image = gr.Image(type="pil")
label = gr.Textbox(label="Prediction")
gr.Interface(fn=predict,
inputs=image,
outputs=label,
title="Classification des maladies des agrumes",
description="Téléchargez une image d'une feuille d'agrume pour classer sa maladie. Le modèle est entraîné sur les maladies suivantes : Aleurocanthus spiniferus, Chancre citrique, Cochenille blanche, Dépérissement des agrumes, Feuille saine, Jaunissement des feuilles, Maladie de l'oïdium, Maladie du dragon jaune, Mineuse des agrumes, Trou de balle.",
examples=["maladie_du_dragon_jaune.jpg", "critique.jpg","feuille_saine.jpg"]).launch(share=True)