# Spaces: Sleeping
from functools import lru_cache
from typing import Mapping

from huggingface_hub import hf_hub_download
from imgutils.data import ImageTyping, load_image

from onnx_ import _open_onnx_model
from preprocess import _img_encode
# Class names, in the same order as the model's output scores:
# _gr_classification pairs them positionally with the first output row.
_LABELS = [
    '3d',
    'bangumi',
    'comic',
    'illustration',
]

# Model variants available in the 'deepghs/anime_classification' repository;
# each name is a subdirectory that contains a 'model.onnx' file.
_CLS_MODELS = [
    # caformer_* variants
    'caformer_s36',
    'caformer_s36_plus',
    # mobilenetv3_* variants
    'mobilenetv3',
    'mobilenetv3_dist',
    'mobilenetv3_sce',
    'mobilenetv3_sce_dist',
    # mobilevitv2_* variants
    'mobilevitv2_150',
]

# Default model choice (an entry of _CLS_MODELS). It is not referenced in this
# part of the file; presumably it is the UI's initial selection -- confirm.
_DEFAULT_CLS_MODEL = 'mobilenetv3_sce_dist'
@lru_cache()
def _open_anime_classify_model(model_name):
    """Fetch and open the ONNX classifier stored under ``model_name``.

    The file ``{model_name}/model.onnx`` is obtained from the
    ``deepghs/anime_classification`` repository via ``hf_hub_download``.
    The returned local path is then passed to ``_open_onnx_model``.

    The result is memoized per ``model_name`` with ``lru_cache``, so repeated
    classifications reuse the already-opened model. Without the cache,
    ``hf_hub_download`` and ``_open_onnx_model`` would run again on every call.

    :param model_name: Model subdirectory in the repository (presumably one of
        ``_CLS_MODELS`` -- confirm against callers).
    :return: Whatever ``_open_onnx_model`` returns. Callers invoke ``.run`` on
        it, so it is presumably an ONNX Runtime inference session.
    """
    return _open_onnx_model(hf_hub_download(
        'deepghs/anime_classification',
        f'{model_name}/model.onnx',
    ))
def _gr_classification(image: ImageTyping, model_name: str, size=384) -> Mapping[str, float]:
    """Score ``image`` against every label in ``_LABELS`` with one classifier.

    Steps:
    1. Load the image in RGB mode.
    2. Encode it with ``_img_encode`` at ``(size, size)``.
    3. Add a leading batch axis.
    4. Run the ONNX model returned by ``_open_anime_classify_model(model_name)``.

    :param image: Any input accepted by ``imgutils.data.load_image``.
    :param model_name: Model subdirectory in ``deepghs/anime_classification``.
    :param size: Edge length forwarded to ``_img_encode`` as ``size=(size, size)``.
    :return: Dict mapping each label to the matching entry of the model's first
        output row, converted with ``.item()``. Whether these are
        probabilities or raw scores depends on the model -- not visible here.
    """
    rgb_image = load_image(image, mode='RGB')
    # Prepend a length-1 batch axis (assumes _img_encode returns a NumPy-style array).
    batch = _img_encode(rgb_image, size=(size, size))[None, ...]
    session = _open_anime_classify_model(model_name)
    # Only 'output' is requested; list-unpacking requires exactly one returned array.
    [scores] = session.run(['output'], {'input': batch})
    return {label: score.item() for label, score in zip(_LABELS, scores[0])}