# Repository page header captured with the file (not part of the program):
#   njgroene - "Update app.py" - commit 097721a - 3.6 kB
import gradio as gr
from PIL import Image
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torchvision
from torchvision import datasets, models, transforms
from torch_mtcnn import detect_faces
from torch_mtcnn import show_bboxes
# Labels indexed by the argmax over the matching slice of the model output.
_GENDER_LABELS = ('Male', 'Female')
_AGE_LABELS = ('0-2', '3-9', '10-19', '20-29', '30-39',
               '40-49', '50-59', '60-69', '70+')

# Checkpoint file with the fine-tuned weights for the 18-logit classifier.
_FAIR_7_WEIGHTS = 'res34_fair_align_multi_7_20190809.pt'

# Loaded models keyed by device string, so the checkpoint is read from disk
# once per process instead of once per request.
_fair_7_models = {}


def _get_fair_7_model(device):
    """Return the ResNet-34 classifier on *device*, loading it on first use."""
    key = str(device)
    if key not in _fair_7_models:
        # No ImageNet weights are requested: the checkpoint loaded right below
        # overwrites every parameter, so fetching them would be wasted work.
        model = torchvision.models.resnet34()
        model.fc = nn.Linear(model.fc.in_features, 18)
        model.load_state_dict(
            torch.load(_FAIR_7_WEIGHTS, map_location=torch.device('cpu')))
        model = model.to(device)
        model.eval()
        _fair_7_models[key] = model
    return _fair_7_models[key]


def _face_to_batch(face):
    """Convert a cropped RGB PIL face into a normalized batch of one image."""
    trans = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    # unsqueeze adds the batch dimension the model expects.
    return trans(face).unsqueeze(0)


def pipeline(img):
    """Describe the gender and age bracket of the single face in *img*.

    Args:
        img: PIL image that should contain exactly one face.

    Returns:
        A sentence of the form "A <gender> in the age range of <age>".

    Raises:
        Exception: if no image is given, or if the face detector finds
            zero faces or more than one face.
    """
    if img is None:
        raise Exception("No image provided, try uploading a picture!")
    # The normalization below is defined for three channels; RGBA, palette
    # or grayscale uploads would otherwise fail further down.
    if img.mode != 'RGB':
        img = img.convert('RGB')

    bounding_boxes, _ = detect_faces(img)
    # len() instead of truthiness: the detector result may be an array, whose
    # truth value is ambiguous.
    if len(bounding_boxes) == 0:
        raise Exception("Didn't find any faces, try another image!")
    if len(bounding_boxes) > 1:
        raise Exception("Found more than one face, try a profile picture with only one person in it!")

    # The first four entries of the detection are used as the crop box.
    face = img.crop(tuple(bounding_boxes[0][:4]))

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = _get_fair_7_model(device)
    batch = _face_to_batch(face).to(device)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(batch)
    logits = np.squeeze(logits.cpu().numpy())

    # Logits 7-8 are the gender head and 9-17 the age head; indices 0-6 are
    # not used here (presumably the 7-class head the checkpoint name refers
    # to -- TODO confirm). Softmax is monotonic, so taking the argmax of the
    # raw logits selects the same class as the argmax of the probabilities.
    gender = _GENDER_LABELS[int(np.argmax(logits[7:9]))]
    age = _AGE_LABELS[int(np.argmax(logits[9:18]))]

    return "A " + gender + " in the age range of " + age
def predict(image):
    """Run the age/gender pipeline on *image* and always return text.

    Args:
        image: the uploaded picture, passed through to ``pipeline``.

    Returns:
        The sentence produced by ``pipeline`` on success; otherwise the
        error message as a string, so the text output can display it.
    """
    try:
        return pipeline(image)
    except Exception as e:
        # UI boundary: show the reason instead of crashing the request.
        # Convert to str because the output component is declared as text;
        # fall back to the exception class name when the message is empty.
        return str(e) or type(e).__name__
# Build the web UI: one image upload in, one line of text out.
image_input = gr.inputs.Image(
    label="Upload a profile picture of a single person",
    type="pil",
)
example_files = ["ex0.jpg", "ex4.jpg", "ex1.jpg", "ex2.jpg", "ex3.jpg", "ex5.jpg"]

demo = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs="text",
    title="Estimate age and gender from profile picture",
    examples=example_files,
)
# Start serving as soon as the script is executed.
demo.launch()