File size: 1,221 Bytes
62a2783
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import io
import urllib.request

import torch
from datasets import load_dataset
from PIL import Image
from transformers import EfficientNetImageProcessor, EfficientNetForImageClassification

dataset = load_dataset("chriamue/bird-species-dataset")

# Class names in index order, read from the "label" feature of the test split.
labels = dataset["test"].features["label"].names

# Index <-> name lookup tables handed to the model when it is loaded.
# Ids are kept as strings on both sides of the mapping.
id2label = {str(index): name for index, name in enumerate(labels)}
label2id = {name: index for index, name in id2label.items()}

# Preprocessing configuration taken from the base EfficientNet-B2 checkpoint.
preprocessor = EfficientNetImageProcessor.from_pretrained("google/efficientnet-b2")

# Fine-tuned bird classifier, configured with the dataset's label maps.
# NOTE(review): ignore_mismatched_sizes=True lets loading succeed even when the
# checkpoint's classification head has a different shape than len(labels)
# classes; such weights are freshly initialised and predictions would then be
# meaningless. Confirm the checkpoint was trained on this exact label set.
model = EfficientNetForImageClassification.from_pretrained(
    "chriamue/bird-species-classifier",
    num_labels=len(labels),
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
)

# Image to classify: a photo from Wikimedia Commons. To classify a dataset
# sample instead, use:  image = dataset["validation"][0]["image"]
url = "https://upload.wikimedia.org/wikipedia/commons/a/a9/Common_Blackbird.jpg"

# Send an explicit User-Agent: Wikimedia's User-Agent policy asks clients to
# identify themselves, and generic library agents may be refused (HTTP 403).
# TODO: add contact information to the agent string.
request = urllib.request.Request(
    url, headers={"User-Agent": "bird-species-classifier-demo/1.0"}
)
# Read the body into memory instead of using urlretrieve(), which leaves a
# temporary file on disk; the timeout keeps a stalled download from hanging.
with urllib.request.urlopen(request, timeout=30) as response:
    payload = response.read()

# Decode, then force 3-channel RGB so grayscale/CMYK/RGBA files still match the
# colour input the model presumably expects; the context manager closes the
# source image once the converted copy exists.
with Image.open(io.BytesIO(payload)) as raw_image:
    image = raw_image.convert("RGB")


# Prefer the GPU when CUDA is available; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)

# Preprocess into PyTorch tensors and move each tensor onto the chosen device.
inputs = {
    key: tensor.to(device)
    for key, tensor in preprocessor(image, return_tensors="pt").items()
}

# Inference only, so autograd tracking is switched off.
with torch.no_grad():
    logits = model(**inputs).logits
    # Index of the highest-scoring class; .item() needs a single element,
    # which holds here because exactly one image is passed in.
    predicted_label = logits.argmax(dim=-1).item()

print(labels[predicted_label])