|
import gradio as gr |
|
import torch |
|
from transformers import ViTFeatureExtractor, ViTForImageClassification |
|
from PIL import Image |
|
|
|
|
|
# Hugging Face Hub checkpoint that supplies both the classifier weights and
# the matching image-preprocessing configuration.
model_name = "KhadijaAsehnoune12/OrangeLeafDiseaseDetector"

# ViT classifier with a 10-way head (one output per entry in ``id2label``).
# NOTE(review): with ignore_mismatched_sizes=True, a checkpoint whose head is
# not 10-wide is loaded with a freshly initialised head instead of raising,
# which would make predictions meaningless — confirm the checkpoint really
# has 10 labels and consider dropping this flag.
model = ViTForImageClassification.from_pretrained(
    model_name,
    num_labels=10,
    ignore_mismatched_sizes=True,
)

# Image preprocessor loaded from the same checkpoint as the model.
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
|
|
|
|
|
# Lookup from the model's predicted class index (as a *string* key, which is
# how ``predict`` indexes it) to the French display label.  The position of
# each name in the tuple below is its class index, so order matters.
id2label = {
    str(class_index): class_name
    for class_index, class_name in enumerate(
        (
            "Aleurocanthus spiniferus",
            "Chancre citrique",
            "Cochenille blanche",
            "Dépérissement des agrumes",
            "Feuille saine",
            "Jaunissement des feuilles",
            "Maladie de l'oïdium",
            "Maladie du dragon jaune",
            "Mineuse des agrumes",
            # NOTE(review): "Trou de balle" is vulgar French slang; if this
            # class is citrus "shot hole" disease, a proper term such as
            # "Maladie criblée" may be intended — confirm with the model card.
            "Trou de balle",
        )
    )
}
|
|
|
|
|
# Example images for the gallery: each entry pairs an image file (relative
# path — presumably resolved against the app's working directory; verify)
# with the class label it illustrates.
image_slider = [
    {"image": file_name, "label": caption}
    for file_name, caption in (
        ("aleurocanthus_spiniferus.jpg", "Aleurocanthus spiniferus"),
        ("chancre_citrique.jpg", "Chancre citrique"),
        ("cochenille_blanche.jpg", "Cochenille blanche"),
        ("deperissement_des_agrumes.jpg", "Dépérissement des agrumes"),
        ("feuille_saine.jpg", "Feuille saine"),
        ("jaunissement_des_feuilles.jpg", "Jaunissement des feuilles"),
        ("maladie_de_loidium.jpg", "Maladie de l'oïdium"),
        ("maladie_du_dragon_jaune.jpg", "Maladie du dragon jaune"),
        ("mineuse_des_agrumes.jpg", "Mineuse des agrumes"),
        ("trou_de_balle.jpg", "Trou de balle"),
    )
]
|
|
|
def predict(image):
    """Classify a citrus-leaf image with the ViT model.

    Args:
        image: The uploaded image (a PIL image when the input component uses
            ``type="pil"``), or ``None`` when nothing has been uploaded.

    Returns:
        A ``(label, confidence_text)`` tuple: the predicted class name looked
        up in ``id2label`` and a string such as ``"Confidence: 0.97"`` holding
        the softmax probability of that class.  If ``image`` is ``None``, a
        prompt message and an empty string are returned without running the
        model.
    """
    # Clicking "Classify" with no upload passes None; answer politely instead
    # of letting the preprocessor raise.
    if image is None:
        return "Please upload an image.", ""

    # The preprocessor presumably expects 3-channel RGB input; convert
    # RGBA (e.g. PNG with alpha) or greyscale uploads first.  Objects without
    # a ``mode`` attribute (non-PIL inputs) are passed through unchanged.
    if getattr(image, "mode", "RGB") != "RGB":
        image = image.convert("RGB")

    inputs = feature_extractor(images=image, return_tensors="pt")

    # Pure inference: disable autograd so no computation graph is built.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Single image in the batch, hence ``.item()`` and row index 0.
    predicted_class_idx = logits.argmax(-1).item()
    probabilities = torch.nn.functional.softmax(logits, dim=-1)
    confidence_score = probabilities[0, predicted_class_idx].item()

    # id2label is keyed by the class index as a string.
    predicted_label = id2label[str(predicted_class_idx)]
    return predicted_label, f"Confidence: {confidence_score:.2f}"
|
|
|
|
|
# gr.Gallery accepts images or (image, caption) pairs, not dicts, so convert
# the example entries into (file, caption) tuples.
_gallery_items = [(entry["image"], entry["label"]) for entry in image_slider]

# Build the supported-classes list from id2label so the description can never
# drift out of sync with the labels the model actually returns.
_supported_items_html = "\n".join(f"<li>{name}</li>" for name in id2label.values())

# Kept at column 0 so Markdown does not treat indented lines as a code block.
_intro_markdown = f"""
<h1>Citrus Leaf Disease Classification</h1>
<p>Upload an image of a citrus leaf to classify its disease.</p>
<p>Supported diseases:</p>
<ul>
{_supported_items_html}
</ul>
<p>Example images:</p>
"""

with gr.Blocks() as demo:
    gr.Markdown(_intro_markdown)

    image_slider_component = gr.Gallery(_gallery_items, label="Example Images")

    image_input = gr.Image(type="pil", label="Upload Image")

    label_output = gr.Textbox(label="Prediction")

    # predict returns (label, confidence_text); give each value its own box
    # instead of rendering the whole tuple in a single textbox.
    confidence_output = gr.Textbox(label="Confidence")

    btn = gr.Button("Classify")

    btn.click(
        fn=predict,
        inputs=image_input,
        outputs=[label_output, confidence_output],
    )

# Only start the server when run as a script, not when the module is imported.
if __name__ == "__main__":
    demo.launch()