import torch from torch import nn, optim from torchvision import transforms, models #from torch_snippets import * #from torch.utils.data import DataLoader, Dataset #from torchsummary import summary #import seaborn as sns #import matplotlib.pyplot as plt #from sklearn.model_selection import train_test_split from PIL import Image #import numpy as np #import cv2 #from glob import glob #import pandas as pd import numpy as np #device = 'cuda' if torch.cuda.is_available() else 'cpu' class ActionClassifier(nn.Module): def __init__(self, ntargets): super().__init__() resnet = models.resnet50(pretrained=True, progress=True) modules = list(resnet.children())[:-1] # delete last layer self.resnet = nn.Sequential(*modules) for param in self.resnet.parameters(): param.requires_grad = False self.fc = nn.Sequential( nn.Flatten(), nn.BatchNorm1d(resnet.fc.in_features), nn.Dropout(0.2), nn.Linear(resnet.fc.in_features, 256), nn.ReLU(), nn.BatchNorm1d(256), nn.Dropout(0.2), nn.Linear(256, ntargets) ) def forward(self, x): x = self.resnet(x) x = self.fc(x) return x def get_transform(): transform = transforms.Compose([ transforms.Resize([224, 244]), transforms.ToTensor(), # std multiply by 255 to convert img of [0, 255] # to img of [0, 1] transforms.Normalize((0.485, 0.456, 0.406), (0.229*255, 0.224*255, 0.225*255))] ) return transform def get_model(): model = ActionClassifier(15) model.load_state_dict(torch.load('./classifier_weights.pth', map_location=torch.device('cpu'))) return model def get_class(index): ind2cat = [ 'calling', 'clapping', 'cycling', 'dancing', 'drinking', 'eating', 'fighting', 'hugging', 'laughing', 'listening_to_music', 'running', 'sitting', 'sleeping', 'texting', 'using_laptop' ] return ind2cat[index] # img = Image.open('./inputs/Image_102.jpg').convert('RGB') # #print(transform(img)) # img = transform(img) # img = img.unsqueeze(dim=0) # print(img.shape) # model.eval() # with torch.no_grad(): # out = model(img) # out = nn.Softmax()(out).squeeze() # print(out.shape) # res = torch.argmax(out) # print(ind2cat[res])