from typing import List, Optional

from transformers import pipeline

from internals.util.commons import download_image
class ImageClassifier:
    """Zero-shot image classifier backed by a CLIP transformers pipeline.

    The pipeline is created lazily by ``load()`` (called automatically from
    ``classify()``), so constructing an instance does not load the model.
    """

    # Becomes True on the instance once ``load()`` has created ``self.pipe``.
    __loaded = False

    def __init__(self, candidates: Optional[List[str]] = None):
        """Store the candidate labels used for zero-shot classification.

        Args:
            candidates: Labels each image is scored against. Defaults to
                ``["realistic", "anime", "comic"]``. The labels are copied
                into a new list, so neither the caller's list nor other
                instances share state with this instance.
        """
        if candidates is None:
            candidates = ["realistic", "anime", "comic"]
        # Copy instead of aliasing: a shared mutable default (or the caller's
        # list) would otherwise leak mutations between instances.
        self.__candidates = list(candidates)

    def load(self):
        """Create the zero-shot image-classification pipeline on first use.

        Calls after the first successful one return immediately.
        """
        if self.__loaded:
            return
        self.pipe = pipeline(
            "zero-shot-image-classification",
            model="philschmid/clip-zero-shot-image-classification",
        )
        self.__loaded = True

    def classify(self, image_url: str, width: int, height: int) -> str:
        """Return the best-matching candidate label for the image at a URL.

        Args:
            image_url: Location passed to ``download_image``.
            width: Width the downloaded image is resized to before scoring.
            height: Height the downloaded image is resized to before scoring.

        Returns:
            The label of the first (highest-scoring) result, or ``""`` if the
            pipeline returned no results for the image.
        """
        self.load()
        # NOTE(review): assumes download_image returns a PIL-like image with
        # ``.resize((w, h))`` — confirm against internals.util.commons.
        image = download_image(image_url).resize((width, height))
        # A list input yields one result list per image; we send exactly one.
        results = self.pipe([image], candidate_labels=self.__candidates)[0]
        # The pipeline orders each result list by descending score.
        if results:
            return results[0]["label"]
        return ""