# Author: e1010101
# Commit fba7380 (verified): "Fix classification issue"
import gradio as gr
from transformers import AutoImageProcessor, AutoModelForImageClassification
import torch
from PIL import Image
# Fine-tuned ViT checkpoint on the Hugging Face Hub.
model_name = 'e1010101/vit-384-tongue-image'

# Class index -> tongue-feature name; label2id below is its exact inverse.
_ID2LABEL = {0: 'Crack', 1: 'Red-Dots', 2: 'Toothmark'}

# Preprocessing is taken from the base ViT-384 checkpoint rather than from
# model_name (presumably the fine-tuned repo uses identical preprocessing —
# confirm).
processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-384")

model = AutoModelForImageClassification.from_pretrained(
    model_name,
    num_labels=len(_ID2LABEL),
    # Independent per-class decisions (sigmoid), not a single softmax choice.
    problem_type="multi_label_classification",
    # NOTE(review): if the checkpoint's classifier head does not have 3
    # outputs, this re-initialises it with random weights instead of
    # failing — confirm the checkpoint was trained with these 3 labels.
    ignore_mismatched_sizes=True,
    id2label=dict(_ID2LABEL),
    label2id={name: idx for idx, name in _ID2LABEL.items()},
)
def classify_image(image, threshold=0.5):
    """Score an image against every label and keep those above a threshold.

    The model is multi-label, so each class gets an independent sigmoid
    probability rather than a share of a softmax over all classes.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading an image.
        threshold: Minimum probability (inclusive) a label needs in order to
            appear in the output. ``0`` keeps every label.

    Returns:
        dict mapping label name -> probability (float), restricted to labels
        whose probability is >= ``threshold`` and ordered from most to least
        likely. Empty when no image was given or no label clears the
        threshold.
    """
    if image is None:
        # Nothing uploaded: show an empty result instead of crashing in the
        # image processor.
        return {}

    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits

    # Per-class probabilities for the single image in the batch.
    probs = torch.sigmoid(logits)[0].tolist()

    # Look each label up by its class index instead of relying on the
    # iteration order of id2label matching the logit column order.
    id2label = model.config.id2label
    result = {
        id2label[idx]: float(prob)
        for idx, prob in enumerate(probs)
        if prob >= threshold  # honour the "Probability Threshold" slider
    }

    # Most likely label first.
    return dict(sorted(result.items(), key=lambda item: item[1], reverse=True))
# Input widgets, passed positionally to classify_image(image, threshold).
_image_input = gr.Image(type="pil")
_threshold_input = gr.Slider(minimum=0, maximum=1, value=0.5, label="Probability Threshold")

# Show every returned label/probability pair, not just the top few.
_label_output = gr.Label(num_top_classes=None)

interface = gr.Interface(
    fn=classify_image,
    inputs=[_image_input, _threshold_input],
    outputs=_label_output,
    title="Multi-Label Image Classification",
    description="Upload an image to get classification results.",
)

if __name__ == "__main__":
    # Start the Gradio server only when executed as a script.
    interface.launch()