In [73]:
import torch    
import pickle
import matplotlib.pyplot as plt
import numpy as np
import time

In [75]:
# Resolution fed to the model. Per the original note, the ResNet backbone
# expects inputs of at least 224 px, so the 32x32 images must be resized.
IMAGE_SIZE = 224

# Per-channel normalization statistics (presumably R, G, B order — TODO confirm).
mean = [0.4914, 0.4822, 0.4465]
std = [0.247, 0.243, 0.261]

# Class names, indexed by the integer label.
classes = tuple(
    "airplane automobile bird cat deer dog frog horse ship truck".split()
)

# Route newly created tensors to the GPU whenever one is present.
if torch.cuda.is_available():
    torch.set_default_device('cuda')

def show_data(img):
    """Display one (image, label) sample with matplotlib.

    Args:
        img: indexable pair; ``img[0]`` is a channel-first (C, H, W) image
            tensor and ``img[1]`` is its label (an int or a one-element
            tensor — TODO confirm against callers).
    """
    image_chw, label = img[0], img[1]
    # matplotlib wants channel-last (H, W, C) data on the host; tensors may be
    # on CUDA because of torch.set_default_device, so move to CPU first.
    image_hwc = image_chw.detach().cpu().permute(1, 2, 0)
    # Show the plain label value rather than a tensor repr like "tensor(3)".
    if torch.is_tensor(label) and label.numel() == 1:
        label = label.item()
    plt.imshow(image_hwc)
    plt.title('y = ' + str(label))
    plt.show()
    
# Convert a (C, H, W) image tensor into an (H, W, C) NumPy array: matplotlib's imshow expects channel-last arrays on the host.
def im_convert(tensor, channel_mean=None, channel_std=None):
    """Undo per-channel normalization and return an HWC NumPy image.

    Args:
        tensor: channel-first (C, H, W) image tensor, assumed normalized as
            ``(x - mean) / std`` per channel (TODO confirm for the caller's data).
        channel_mean: per-channel means; defaults to the module-level ``mean``.
        channel_std: per-channel stds; defaults to the module-level ``std``.

    Returns:
        ``numpy.ndarray`` of shape (H, W, C) with values clipped to [0, 1].
    """
    if channel_mean is None:
        channel_mean = mean
    if channel_std is None:
        channel_std = std
    # .numpy() requires a CPU tensor that does not require grad.
    img = tensor.detach().cpu().numpy().transpose(1, 2, 0)
    # Inverse of x_norm = (x - mean) / std  is  x = x_norm * std + mean.
    img = img * np.asarray(channel_std) + np.asarray(channel_mean)
    # Clip so matplotlib receives floats in the valid [0, 1] range.
    return img.clip(0, 1)

In [64]:
def unpickle(file):
    """Load a pickled dict from ``file`` and return it with ``str`` keys.

    Args:
        file: path to the pickle file (here, a dataset batch file).

    Returns:
        dict mapping decoded string keys to the unpickled values.

    WARNING: ``pickle.load`` can execute arbitrary code while loading — only
    call this on files from a trusted source.
    """
    with open(file, 'rb') as fo:
        # encoding='bytes' makes Python 2 pickles load; their str keys arrive as bytes.
        data_dict = pickle.load(fo, encoding='bytes')
    # Decode bytes keys (assumed UTF-8); keys that are already str pass through.
    return {
        (key.decode('utf-8') if isinstance(key, bytes) else key): value
        for key, value in data_dict.items()
    }


In [76]:
decoded_dict = unpickle('./test_batch')
# Each row of 'data' is one flattened image; reshape to channel-first
# (N, 3, 32, 32). -1 infers N instead of hard-coding the batch size.
data = torch.tensor(decoded_dict['data']).reshape([-1, 3, 32, 32])
dataset = {"image": data, "target": torch.tensor(decoded_dict["labels"])}
assert len(dataset["image"]) == len(dataset["target"]), "image/label count mismatch"

In [77]:
# Inspect which fields the unpickled batch dict exposes.
decoded_dict.keys()

dict_keys(['batch_label', 'labels', 'data', 'filenames'])

In [78]:
idx = 0  # position of the sample to inspect
image, label = dataset['image'][idx], dataset["target"][idx].item()

In [79]:
# Look up the human-readable class name for the sample's integer label.
classes[label]

'cat'

In [82]:
# Load model directly
from transformers import AutoImageProcessor, AutoModelForImageClassification

processor = AutoImageProcessor.from_pretrained("heyitskim1912/AML_A2_Q4")
model = AutoModelForImageClassification.from_pretrained("heyitskim1912/AML_A2_Q4")

# The processor converts its input to NumPy, which needs a host tensor
# (`image` may be on CUDA because of torch.set_default_device). The resulting
# batch is then moved to wherever the model's weights live.
inputs = processor(image.cpu(), return_tensors="pt").to(model.device)

# Time only the forward pass + argmax. .item() forces a device sync, so
# asynchronous CUDA work is included in the measurement.
start_time = time.perf_counter()
with torch.no_grad():
    logits = model(**inputs).logits
predicted_label = logits.argmax(-1).item()
end_time = time.perf_counter()

# Map the index through the model's own label table (config.id2label).
# NOTE(review): the original comment claimed 1000 ImageNet classes, but the
# "cat" output suggests a fine-tuned head — confirm via model.config.num_labels.
print(model.config.id2label[predicted_label])
time_taken = round(end_time - start_time, 3)
print(f"Time taken: {time_taken} s")

cat
Time taken: 0.013 s
