File size: 2,944 Bytes
8d12dce bc18618 5da6a11 bc18618 1af1578 efac948 6bb1d3a bc18618 1af1578 bc18618 5da6a11 bc18618 5da6a11 7e0ee60 bc18618 7e0ee60 bc18618 7e0ee60 bc18618 6bb1d3a 7e0ee60 6bb1d3a b26ba26 7e0ee60 b26ba26 da42d5c 7e0ee60 b26ba26 da42d5c fc5c4f6 1af1578 9da9b42 fc5c4f6 00b5ecc 1af1578 5da6a11 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 |
import gradio as gr
import torch
from transformers import ViTFeatureExtractor, ViTForImageClassification
from PIL import Image
import numpy as np
# Définir le modèle et le feature extractor
model_name ="KhadijaAsehnoune12/ViTOrangeLeafDiseaseClassifier"
model = ViTForImageClassification.from_pretrained(model_name, num_labels=10, ignore_mismatched_sizes=True)
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
# Définir la mappage des labels
id2label = {
"0": "Aleurocanthus spiniferus",
"1": "Chancre citrique",
"2": "Cochenille blanche",
"3": "Dépérissement des agrumes",
"4": "Feuille saine",
"5": "Jaunissement des feuilles",
"6": "Maladie de l'oïdium",
"7": "Maladie du dragon jaune",
"8": "Mineuse des agrumes",
"9": "Trou de balle"
}
def remove_background(image):
# Convert image to numpy array
img_array = np.array(image)
# Calculate the threshold to separate background from foreground
threshold = np.mean(img_array) + 2 * np.std(img_array)
# Create a mask to separate background from foreground
mask = img_array > threshold
# Replace background with white
img_array[mask] = 255
# Convert back to PIL image
image = Image.fromarray(img_array)
return image
def predict(image):
# Remove background and replace with white
image = remove_background(image)
# Prétraiter l'image
inputs = feature_extractor(images=image, return_tensors="pt")
# Passage en avant dans le modèle
outputs = model(**inputs)
# Obtenir les logits
logits = outputs.logits
# Calculer les scores de confiance avec softmax
probs = torch.nn.functional.softmax(logits, dim=-1)[0]
# Obtenir l'indice de la classe la plus probable
predicted_class_idx = probs.argmax().item()
# Obtenir le label et le score de confiance de la classe la plus probable
predicted_label = id2label[str(predicted_class_idx)]
confidence_score = probs[predicted_class_idx].item() * 100 # Multiplie par 100 pour obtenir un pourcentage
# Retourner le label et le score de confiance
return f"{predicted_label}: {confidence_score:.2f}%"
# Créer l'interface Gradio
image = gr.Image(type="pil")
label = gr.Textbox(label="Prediction")
gr.Interface(fn=predict,
inputs=image,
outputs=label,
title="Classification des maladies des agrumes",
description="Téléchargez une image d'une feuille d'agrume pour classer sa maladie. Le modèle est entraîné sur les maladies suivantes : Aleurocanthus spiniferus, Chancre citrique, Cochenille blanche, Dépérissement des agrumes, Feuille saine, Jaunissement des feuilles, Maladie de l'oïdium, Maladie du dragon jaune, Mineuse des agrumes, Trou de balle.",
examples=["maladie_du_dragon_jaune.jpg", "critique.jpg","feuille_saine.jpg"]).launch(share=True) |