# Classify a single bird photo with a fine-tuned EfficientNet checkpoint.
#
# Steps: read the class names from the bird-species dataset, load the image
# processor and the classifier, download one test photo, run a forward pass
# and print the name of the highest-scoring class.
import io
import urllib.request

import torch
from datasets import load_dataset
from PIL import Image
from transformers import EfficientNetImageProcessor, EfficientNetForImageClassification

dataset = load_dataset("chriamue/bird-species-dataset")

#####

# Class names come from the dataset's label feature; the two mappings use
# stringified indices, exactly as they are handed to the model config below.
labels = dataset["test"].features["label"].names
label2id = {label: str(i) for i, label in enumerate(labels)}
id2label = {str(i): label for i, label in enumerate(labels)}

preprocessor = EfficientNetImageProcessor.from_pretrained("google/efficientnet-b2")
model = EfficientNetForImageClassification.from_pretrained(
    "chriamue/bird-species-classifier",
    num_labels=len(labels),
    id2label=id2label,
    label2id=label2id,
    # Let loading proceed even if a checkpoint tensor's shape differs from
    # the shape implied by num_labels.
    ignore_mismatched_sizes=True,
)

# Input image: a photo fetched over HTTP. A sample from the dataset, e.g. the
# "image" field of the first "validation" row, can be used instead when
# working offline.
url = 'https://upload.wikimedia.org/wikipedia/commons/a/a9/Common_Blackbird.jpg'
# Send a descriptive User-Agent: Wikimedia's User-Agent policy says generic
# library defaults may be rejected.
request = urllib.request.Request(
    url, headers={"User-Agent": "bird-species-classifier-example/1.0"}
)
# Read the bytes inside a context manager so the connection is always closed,
# and bound the wait so a stalled download cannot hang the script.
with urllib.request.urlopen(request, timeout=30) as response:
    image_bytes = response.read()
# Decode from memory (no temporary file is left behind) and normalise to
# three channels so grayscale/palette/RGBA inputs are handled as well.
image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# The processor returns a mapping of tensors; each one must live on the same
# device as the model.
inputs = preprocessor(image, return_tensors="pt")
inputs = {k: v.to(device) for k, v in inputs.items()}

# Inference only: skip gradient tracking.
with torch.no_grad():
    logits = model(**inputs).logits

# Index of the largest logit along the last (class) dimension.
predicted_label = logits.argmax(-1).item()
print(labels[predicted_label])