gruhit-patel's picture
Transferred from PyTorch -> TensorFlow
071870f verified
raw
history blame
1.27 kB
import os

# Must run before `model` is imported: TensorFlow (presumably pulled in by
# `model`) reads this variable at import time to suppress its C++ log output.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import json
from io import BytesIO

import numpy as np
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from PIL import Image

from model import get_model
app = FastAPI()

# Resolution the network is built for; uploads are resized to this in preprocess().
IMAGE_WIDTH = 224
IMAGE_HEIGHT = 224
MODEL_WEIGHT_PATH = 'vgg_face_weights2.h5'

# Build the classifier once at import time so every request reuses it.
# NOTE(review): Keras input shapes are conventionally (height, width, channels);
# the order here is harmless only because width == height — confirm against get_model.
model = get_model(
    image_shape=(IMAGE_WIDTH, IMAGE_HEIGHT, 3),
    num_classes=6,  # presumably the number of expression classes — verify against training
    model_weights=MODEL_WEIGHT_PATH,
)
# summary() prints the architecture itself and returns None; wrapping it in
# print() would emit a stray "None" line.
model.summary()
print("Model Loaded Successfully")
######### Utilities #########
def load_image(image_data):
    """Decode raw uploaded bytes into a PIL ``Image``.

    Args:
        image_data: The encoded image file contents, as bytes.

    Returns:
        The ``Image`` object returned by ``PIL.Image.open``.
    """
    buffer = BytesIO(image_data)
    return Image.open(buffer)
def preprocess(image):
    """Turn a PIL image into a single-image batch for the model.

    The image is first converted to RGB so that RGBA, grayscale and palette
    uploads all yield three channels, matching the ``(..., 3)`` input shape
    the model is built with. It is then resized to ``IMAGE_WIDTH`` x
    ``IMAGE_HEIGHT``.

    Args:
        image: A PIL ``Image`` in any mode.

    Returns:
        A uint8 numpy array of shape ``(1, IMAGE_HEIGHT, IMAGE_WIDTH, 3)``.
    """
    # PIL's resize takes (width, height); np.array then yields (height, width, channels).
    rgb = image.convert('RGB').resize((IMAGE_WIDTH, IMAGE_HEIGHT))
    # Add a leading batch axis of size 1 for model.predict.
    return np.expand_dims(np.array(rgb), axis=0)
def get_prediction(image):
    """Run the classifier on a single-image batch and summarise the result.

    Args:
        image: Batch array as produced by ``preprocess`` (leading batch axis
            of size 1).

    Returns:
        dict with ``'pred_probs'``, the model's output row for the first
        sample converted to a plain list (presumably per-class
        probabilities), and ``'label'``, the index of the highest-scoring
        entry as an ``int``.
    """
    # predict() returns one output row per input sample; only one is sent.
    probs = model.predict(image)[0]
    label = np.argmax(probs)
    return {
        # tolist()/int() convert numpy types into JSON-serialisable Python values.
        'pred_probs': probs.tolist(),
        'label': int(label),
    }
@app.get("/")
def foo():
    """Root endpoint: return a static status payload naming the service."""
    payload = {"status": "Face Expression Classifier"}
    return payload
@app.post("/get_prediction")
async def predict(face_image: UploadFile = File(...)):
    """Classify the facial expression in an uploaded image.

    The response is ``{"result": <str>}``, where the string is the
    JSON-encoded dict returned by ``get_prediction``. It is kept as a string
    to preserve the existing response format.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    raw = await face_image.read()
    try:
        image = preprocess(load_image(raw))
    except OSError as err:
        # PIL reports unrecognised or truncated data with OSError subclasses
        # (e.g. UnidentifiedImageError); surface it as a client error, not a 500.
        raise HTTPException(status_code=400, detail="Invalid image file") from err
    # model.predict is synchronous and CPU-bound; run it in the threadpool so
    # it does not block the event loop (and every other request) meanwhile.
    result = await run_in_threadpool(get_prediction, image)
    return {
        "result": json.dumps(result)
    }