# aesthetics_vit / model.py
# (GitHub scrape residue removed: contributor "dveranieto", commit 036a350 "Added usage helpers")
# * SPDX-License-Identifier: Apache-2.0
# * © 2023 ETH Zurich and other contributors, see AUTHORS.txt for details
import os
from typing import List
import torch
from PIL import Image
from fastapi import HTTPException
from mtc_api_utils.base_model import MLBaseModel
from mtc_api_utils.init_api import download_if_not_exists
from torchvision import transforms
from aesthetics_model.api_types import AestheticsModelResponse
from aesthetics_model.config import AestheticsModelConfig
from aesthetics_model.model.vision_transformer import vit_large_patch16_224_in21k
class AestheticModel(MLBaseModel):
    """Aesthetics-scoring model backed by a ViT-L/16 with a single-unit regression head.

    Lifecycle: `init_model` downloads the checkpoint and loads it onto the best
    available device; `inference` scores a batch of PIL images, returning one
    `AestheticsModelResponse` (score + feature embedding) per image.
    """

    # Populated by init_model(); all None until then (is_ready gates on `model`).
    model = None
    transform = None
    device = None
    model_checkpoint_path = None

    def init_model(self):
        """Download the model checkpoint (if missing) and load it to the device."""
        print("Initializing Aesthetic Model...")

        # Download the checkpoint artifact if it is not already cached locally.
        self.model_checkpoint_path = download_if_not_exists(
            artifact_url=os.path.join(AestheticsModelConfig.model_base_url, AestheticsModelConfig.model_url),
            download_dir=AestheticsModelConfig.model_checkpoint_path,
            polybox_auth=AestheticsModelConfig.polybox_credentials,
            is_tar=False,
        )

        print("Model downloaded, loading to device..")
        self.model, self.transform, self.device = self.load_model()
        print("Model initialized")

    def is_ready(self):
        """Return True once the model weights have been loaded."""
        return self.model is not None

    def inference(self, images: List[Image.Image]) -> List[AestheticsModelResponse]:
        """Score a list of PIL images.

        Args:
            images: PIL images in any mode/size (converted to RGB and
                downscaled to the configured max size before scoring).

        Returns:
            One AestheticsModelResponse per input image, containing the scalar
            aesthetics score and the backbone feature embedding.

        Raises:
            HTTPException: 503 if the model has not finished initializing.
        """
        if not self.is_ready():
            raise HTTPException(status_code=503, detail="Model is not ready")

        results = []
        for image in images:
            im = self.validate_and_resize_image(image, max_size=AestheticsModelConfig.max_image_size)
            image_tensor = self.transform(im).unsqueeze(0).to(self.device)
            with torch.no_grad():
                # Separate calls so the pre-head embedding can be returned alongside the score.
                image_embedding = self.model.forward_features(image_tensor)
                score = self.model.head(image_embedding).squeeze().to("cpu")

            result: AestheticsModelResponse = AestheticsModelResponse(
                aesthetics_score=float(score),
                aesthetics_embedding=image_embedding.to("cpu").tolist(),
                # TODO: success is going to be always true
                success=True
            )
            results.append(result)
        return results

    def load_model(self):
        """Build the ViT, load the downloaded checkpoint, and return (model, transform, device)."""
        # Check if a gpu and cuda is available
        use_cuda = torch.cuda.is_available()
        device = torch.device("cuda:0" if use_cuda else "cpu")
        print(f"Running on {device}")

        model = vit_large_patch16_224_in21k()
        # Single-output regression head for the aesthetics score.
        model.reset_classifier(num_classes=1)
        # NOTE(review): torch.load without weights_only=True unpickles arbitrary
        # objects — safe only because the checkpoint comes from our own artifact store.
        model.load_state_dict(torch.load(self.model_checkpoint_path, map_location=device))
        model.eval()
        model.to(device)

        transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]
        )
        return model, transform, device

    def validate_and_resize_image(self, image: Image.Image, max_size: int) -> Image.Image:
        """Convert to RGB and downscale so the longest side is at most `max_size`.

        Aspect ratio is preserved; images already within bounds are returned
        unchanged (apart from the RGB conversion).

        Args:
            image: PIL image in any mode.
            max_size: maximum allowed length of the longer side, in pixels.

        Returns:
            The RGB image, resized if it exceeded `max_size`.
        """
        if not image.mode == "RGB":
            image = image.convert(mode="RGB")

        # image.size is (width, height); scale so max(w, h) == max_size.
        size_factor = (max_size + 0.0) / max(image.size)
        if size_factor < 1:
            # BUG FIX: width and height were previously swapped
            # (new_h from size[0]=width, new_w from size[1]=height), which
            # transposed the dimensions of any non-square image.
            new_w = round(image.size[0] * size_factor)
            new_h = round(image.size[1] * size_factor)
            image = image.resize((new_w, new_h), resample=Image.HAMMING)
        return image
if __name__ == '__main__':
    model = AestheticModel()
    model.__wait_until_ready__()

    example_image_paths = ["../tests/test_image.png"]  # model.get_examples()

    # BUG FIX: previously each image was passed through `model.model(...)`
    # (the raw ViT called directly on a PIL Image) before inference, which
    # would fail and double-applied the model. `inference` expects PIL images.
    images = [Image.open(path) for path in example_image_paths]
    model_result = model.inference(images=images)
    print(model_result)