File size: 884 Bytes
19b3da3 b71808f 19b3da3 b71808f 19b3da3 b71808f 19b3da3 b71808f 19b3da3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 |
from typing import List, Optional

from transformers import pipeline

from internals.util.commons import download_image
class ImageClassifier:
    """Zero-shot image classifier that picks the best-matching label from a fixed set.

    The transformers pipeline is created lazily on first use (see ``load``),
    so constructing an instance does not load any model.
    """

    # Labels used when the caller supplies none. A tuple, so the shared default
    # can never be mutated through an instance (unlike a list default argument).
    DEFAULT_CANDIDATES = ("realistic", "anime", "comic")

    # Class-level default; ``load`` sets the per-instance flag once the pipeline exists.
    __loaded = False

    def __init__(self, candidates: Optional[List[str]] = None):
        """Create a classifier.

        Args:
            candidates: Labels to score each image against. Defaults to
                ``DEFAULT_CANDIDATES``. The labels are copied into a new list,
                so later changes to the caller's list do not affect this instance.
        """
        if candidates is None:
            candidates = self.DEFAULT_CANDIDATES
        self.__candidates = list(candidates)

    def load(self):
        """Build the zero-shot classification pipeline on first call; later calls are no-ops."""
        if self.__loaded:
            return
        self.pipe = pipeline(
            "zero-shot-image-classification",
            model="philschmid/clip-zero-shot-image-classification",
        )
        self.__loaded = True

    def classify(self, image_url: str, width: int, height: int) -> str:
        """Return the top candidate label for the image at ``image_url``.

        The image is downloaded and resized to ``(width, height)`` before being
        scored against this instance's candidate labels. Loads the pipeline if needed.

        Args:
            image_url: Location of the image, passed to ``download_image``.
            width: Target width for the resize.
            height: Target height for the resize.

        Returns:
            The label of the first (top) result, or ``""`` if the pipeline
            produced no results.
        """
        self.load()
        image = download_image(image_url).resize((width, height))
        results = self.pipe([image], candidate_labels=self.__candidates)
        if not results:
            return ""
        # A list of images was passed, so take the entry for the single image.
        scores = results[0]
        if not scores:
            return ""
        # NOTE(review): assumes the pipeline orders entries by descending score,
        # so the first entry is the best match — confirm against the transformers version in use.
        return scores[0]["label"]
|