|
from typing import List, Optional

from transformers import pipeline

from internals.util.commons import download_image
|
|
|
|
|
class ImageClassifier:
    """Pick the best-matching label for an image from a fixed set of candidates.

    Wraps a Hugging Face ``zero-shot-image-classification`` pipeline.
    Call :meth:`load` once before calling :meth:`classify`.
    """

    # Labels used when the caller does not supply any. A tuple, so it cannot be
    # mutated and leak changes into other instances.
    DEFAULT_CANDIDATES = ("realistic", "anime", "comic")

    def __init__(self, candidates: Optional[List[str]] = None):
        """Create a classifier whose model is not loaded yet.

        Args:
            candidates: Labels to score each image against. Defaults to
                ``["realistic", "anime", "comic"]``. The sequence is copied, so
                later changes made by the caller do not affect this instance.
        """
        # Copy rather than keep a shared default or caller-owned list
        # (mutable-default-argument pitfall).
        self.__candidates = list(
            self.DEFAULT_CANDIDATES if candidates is None else candidates
        )
        # Set by load(); None means the pipeline has not been built yet.
        self.pipe = None

    def load(self):
        """Build the zero-shot image classification pipeline and store it on ``self.pipe``."""
        self.pipe = pipeline(
            "zero-shot-image-classification",
            model="philschmid/clip-zero-shot-image-classification",
        )

    def classify(self, image_url: str, width: int, height: int) -> str:
        """Return the top-scoring candidate label for the image at *image_url*.

        Args:
            image_url: Location of the image, passed to ``download_image``.
            width: Width the image is resized to before classification.
            height: Height the image is resized to before classification.

        Returns:
            The label with the highest score, or ``""`` when there are no
            candidates or the pipeline returns no results.

        Raises:
            RuntimeError: If :meth:`load` has not been called yet.
        """
        if self.pipe is None:
            raise RuntimeError(
                "ImageClassifier.load() must be called before classify()"
            )
        if not self.__candidates:
            # Nothing to score against; skip the download and model call.
            return ""

        image = download_image(image_url).resize((width, height))
        # A list of images yields one result list per image; we pass exactly one.
        results = self.pipe([image], candidate_labels=self.__candidates)
        per_image = results[0] if results else []
        # The pipeline orders each image's labels by descending score.
        if per_image:
            return per_image[0]["label"]
        return ""
|
|