"""Gradio demo serving multi-label predictions from the e1010101/vit-384-tongue-image checkpoint."""

import gradio as gr
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification

model_name = 'e1010101/vit-384-tongue-image'
processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-384")
model = AutoModelForImageClassification.from_pretrained(
    model_name,
    num_labels=3,
    problem_type="multi_label_classification",
    # NOTE(review): this lets loading succeed even if the checkpoint's head size
    # differs from num_labels, in which case the mismatched head weights are
    # freshly initialised -- confirm the checkpoint really has a 3-label head.
    ignore_mismatched_sizes=True,
    id2label={0: 'Crack', 1: 'Red-Dots', 2: 'Toothmark'},
    label2id={'Crack': 0, 'Red-Dots': 1, 'Toothmark': 2},
)


def classify_image(image, threshold=0.5):
    """Return per-label sigmoid probabilities for *image*, filtered by *threshold*.

    Args:
        image: PIL image from the Gradio ``Image`` input, or ``None`` when the
            user submits without uploading anything.
        threshold: Minimum probability (inclusive) a label needs in order to
            be included in the output.

    Returns:
        Dict mapping label name to probability, sorted by descending
        probability. Empty if no image was given or no label reaches the
        threshold.
    """
    if image is None:
        return {}

    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits

    # Multi-label: each label gets an independent sigmoid probability rather
    # than a softmax across labels. .tolist() yields plain Python floats and
    # does not require the tensor to live on the CPU (unlike .numpy()).
    probs = torch.sigmoid(logits)[0].tolist()

    # Look labels up by logit index so names always line up with their
    # columns, independent of the dict's insertion order.
    id2label = model.config.id2label
    scored = (
        (id2label[idx], prob)
        for idx, prob in enumerate(probs)
        if prob >= threshold
    )
    return dict(sorted(scored, key=lambda item: item[1], reverse=True))


interface = gr.Interface(
    fn=classify_image,
    inputs=[
        gr.Image(type="pil"),
        gr.Slider(minimum=0, maximum=1, value=0.5, label="Probability Threshold"),
    ],
    outputs=gr.Label(num_top_classes=None),
    title="Multi-Label Image Classification",
    description="Upload an image to get classification results.",
)

if __name__ == "__main__":
    interface.launch()