LittleApple-fp16's picture
Upload 88 files
4f8ad24
raw
history blame
No virus
3.06 kB
from functools import partial
from typing import Iterator, Union, List, Mapping, Literal
from PIL import Image
from imgutils.tagging import get_deepdanbooru_tags, get_wd14_tags, get_mldanbooru_tags
from .base import ProcessAction, BaseAction
from ..model import ImageItem
def _deepdanbooru_tagging(image: Image.Image, use_real_name: bool = False,
general_threshold: float = 0.5, character_threshold: float = 0.5, **kwargs):
_ = kwargs
_, features, characters = get_deepdanbooru_tags(image, use_real_name, general_threshold, character_threshold)
return {**features, **characters}
def _wd14_tagging(image: Image.Image, model_name: str,
general_threshold: float = 0.35, character_threshold: float = 0.85, **kwargs):
_ = kwargs
_, features, characters = get_wd14_tags(image, model_name, general_threshold, character_threshold)
return {**features, **characters}
def _mldanbooru_tagging(image: Image.Image, use_real_name: bool = False, general_threshold: float = 0.7, **kwargs):
_ = kwargs
features = get_mldanbooru_tags(image, use_real_name, general_threshold)
return features
_TAGGING_METHODS = {
'deepdanbooru': _deepdanbooru_tagging,
'wd14_vit': partial(_wd14_tagging, model_name='ViT'),
'wd14_convnext': partial(_wd14_tagging, model_name='ConvNext'),
'wd14_convnextv2': partial(_wd14_tagging, model_name='ConvNextV2'),
'wd14_swinv2': partial(_wd14_tagging, model_name='SwinV2'),
'mldanbooru': _mldanbooru_tagging,
}
TaggingMethodTyping = Literal[
'deepdanbooru', 'wd14_vit', 'wd14_convnext', 'wd14_convnextv2', 'wd14_swinv2', 'mldanbooru']
class TaggingAction(ProcessAction):
def __init__(self, method: TaggingMethodTyping = 'wd14_convnextv2', force: bool = False, **kwargs):
self.method = _TAGGING_METHODS[method]
self.force = force
self.kwargs = kwargs
def process(self, item: ImageItem) -> ImageItem:
if 'tags' in item.meta and not self.force:
return item
else:
tags = self.method(image=item.image, **self.kwargs)
return ImageItem(item.image, {**item.meta, 'tags': tags})
class TagFilterAction(BaseAction):
def __init__(self, tags: Union[List[str], Mapping[str, float]],
method: TaggingMethodTyping = 'wd14_convnextv2', **kwargs):
if isinstance(tags, (list, tuple)):
self.tags = {tag: 1e-6 for tag in tags}
elif isinstance(tags, dict):
self.tags = dict(tags)
else:
raise TypeError(f'Unknown type of tags - {tags!r}.')
self.tagger = TaggingAction(method, force=False, **kwargs)
def iter(self, item: ImageItem) -> Iterator[ImageItem]:
item = self.tagger(item)
tags = item.meta['tags']
valid = True
for tag, min_score in self.tags.items():
if tags[tag] < min_score:
valid = False
break
if valid:
yield item
def reset(self):
self.tagger.reset()