Commit f292456 by CosVersin
Parent(s): a1704f6

Upload 25 files
.gitignore ADDED
@@ -0,0 +1,6 @@
+ __pycache__/
+ .vscode/
+ .venv/
+ .env
+
+ presets/
README.ko.md ADDED
@@ -0,0 +1,61 @@
+ Tagging (labeling) extension for the [Automatic1111 WebUI](https://github.com/AUTOMATIC1111/stable-diffusion-webui)
+ ---
+ Interrogates booru-style tags from single or multiple images using models such as DeepDanbooru.
+
+ [You don't know how to read Korean? Read it in English here!](README.md)
+
+ ## Before You Start
+ I did not make the models or most of the code; they were borrowed from [DeepDanbooru](https://github.com/KichangKim/DeepDanbooru) and MrSmilingWolf's tagger.
+
+ ## Installation
+ 1. *Extensions* -> *Install from URL* -> enter this repository's URL -> *Install*
+    - or clone this repository under the `extensions/` directory:
+    ```sh
+    $ git clone https://github.com/toriato/stable-diffusion-webui-wd14-tagger.git extensions/tagger
+    ```
+
+ 1. Add a model
+    - #### *MrSmilingWolf's model (a.k.a. Waifu Diffusion 1.4 tagger)*
+      Downloaded automatically from the [HuggingFace repository](https://huggingface.co/SmilingWolf/wd-v1-4-vit-tagger) on first run.
+
+      For questions about the model or additional training, please ask the original author, MrSmilingWolf#5991.
+
+    - #### *DeepDanbooru*
+      1. Various model files can be found at the links below.
+         - [DeepDanbooru model](https://github.com/KichangKim/DeepDanbooru/releases)
+         - [e621 model by 🐾Zack🐾#1984](https://discord.gg/BDFpq9Yb7K)
+           *(NSFW warning!)*
+
+      1. Move the project folder containing the model and config files to `models/deepdanbooru`.
+
+      1. The file structure looks like this:
+         ```
+         models/
+         └╴deepdanbooru/
+           ├╴deepdanbooru-v3-20211112-sgd-e28/
+           │ ├╴project.json
+           │ └╴...
+
+           ├╴deepdanbooru-v4-20200814-sgd-e30/
+           │ ├╴project.json
+           │ └╴...
+
+           ├╴e621-v3-20221117-sgd-e32/
+           │ ├╴project.json
+           │ └╴...
+
+           ...
+         ```
+
+ 1. Start or restart the WebUI.
+    - or press the refresh button to the right of the *Interrogator* dropdown.
+
+
+ ## Screenshot
+ ![Screenshot](docs/screenshot.png)
+
+ Artwork made by [hecattaart](https://vk.com/hecattaart?w=wall-89063929_3767)
+
+ ## Copyright
+
+ Public domain, except for borrowed code (e.g. `dbimutils.py`)
README.md ADDED
@@ -0,0 +1,118 @@
+ Tagger for [Automatic1111's WebUI](https://github.com/AUTOMATIC1111/stable-diffusion-webui)
+ ---
+ Interrogate booru style tags for single or multiple image files using various models, such as DeepDanbooru.
+
+ [Do you use Korean? A Korean guide is available here!](README.ko.md)
+
+ ## Disclaimer
+ I didn't make any of the models, and most of the code was heavily borrowed from [DeepDanbooru](https://github.com/KichangKim/DeepDanbooru) and MrSmilingWolf's tagger.
+
+ ## Installation
+ 1. *Extensions* -> *Install from URL* -> Enter the URL of this repository -> Press the *Install* button
+    - or clone this repository under `extensions/`:
+    ```sh
+    $ git clone https://github.com/toriato/stable-diffusion-webui-wd14-tagger.git extensions/tagger
+    ```
+
+ 1. Add an interrogate model
+    - #### *MrSmilingWolf's model (a.k.a. Waifu Diffusion 1.4 tagger)*
+      Downloads automatically from the [HuggingFace repository](https://huggingface.co/SmilingWolf/wd-v1-4-vit-tagger) the first time you run it.
+
+      Please ask the original author, MrSmilingWolf#5991, any questions related to the model or additional training.
+
+      ##### ViT vs Convnext
+      > To make it clear: the ViT model is the one used to tag images for WD 1.4. That's why the repo was originally called like that. This one has been trained on the same data and tags, but has got no other relation to WD 1.4, aside from stemming from the same coordination effort. They were trained in parallel, and the best one at the time was selected for WD 1.4
+
+      > This particular model was trained later and might actually be slightly better than the ViT one. Difference is in the noise range tho
+
+      — [SmilingWolf](https://github.com/SmilingWolf) from [this thread](https://discord.com/channels/930499730843250783/1052283314997837955) in the [東方Project AI server](https://discord.com/invite/touhouai)
+
+    - #### *DeepDanbooru*
+      1. Various model files can be found below.
+         - [DeepDanbooru models](https://github.com/KichangKim/DeepDanbooru/releases)
+         - [e621 model by 🐾Zack🐾#1984](https://discord.gg/BDFpq9Yb7K)
+           *(link contains NSFW content!)*
+
+      1. Move the project folder containing the model and config to `models/deepdanbooru`.
+
+      1. The file structure should look like:
+         ```
+         models/
+         └╴deepdanbooru/
+           ├╴deepdanbooru-v3-20211112-sgd-e28/
+           │ ├╴project.json
+           │ └╴...
+
+           ├╴deepdanbooru-v4-20200814-sgd-e30/
+           │ ├╴project.json
+           │ └╴...
+
+           ├╴e621-v3-20221117-sgd-e32/
+           │ ├╴project.json
+           │ └╴...
+
+           ...
+         ```
+
+ 1. Start or restart the WebUI.
+    - or press the refresh button to the right of the *Interrogator* dropdown box.
+
+
+ ## Model comparison
+
+ * Used image: [hecattaart's artwork](https://vk.com/hecattaart?w=wall-89063929_3767)
+ * Threshold: `0.5`
+
+ ### DeepDanbooru
+ Uses the same image as the one shown in the Screenshot section.
+
+ #### [`deepdanbooru-v3-20211112-sgd-e28`](https://github.com/KichangKim/DeepDanbooru/releases/tag/v3-20211112-sgd-e28)
+ ```
+ 1girl, animal ears, cat ears, cat tail, clothes writing, full body, rating:safe, shiba inu, shirt, shoes, simple background, sneakers, socks, solo, standing, t-shirt, tail, white background, white shirt
+ ```
+
+ #### [`deepdanbooru-v4-20200814-sgd-e30`](https://github.com/KichangKim/DeepDanbooru/releases/tag/v4-20200814-sgd-e30)
+ ```
+ 1girl, animal, animal ears, bottomless, clothes writing, full body, rating:safe, shirt, shoes, short sleeves, sneakers, solo, standing, t-shirt, tail, white background, white shirt
+ ```
+
+ #### `e621-v3-20221117-sgd-e32`
+ ```
+ anthro, bottomwear, clothing, footwear, fur, hi res, mammal, shirt, shoes, shorts, simple background, sneakers, socks, solo, standing, text on clothing, text on topwear, topwear, white background
+ ```
+
+ ### Waifu Diffusion Tagger
+
+ #### [`wd14-vit`](https://huggingface.co/SmilingWolf/wd-v1-4-vit-tagger)
+ ```
+ 1boy, animal ears, dog, furry, leg hair, male focus, shirt, shoes, simple background, socks, solo, tail, white background
+ ```
+
+ #### [`wd14-convnext`](https://huggingface.co/SmilingWolf/wd-v1-4-convnext-tagger)
+ ```
+ full body, furry, shirt, shoes, simple background, socks, solo, tail, white background
+ ```
+
+ #### [`wd14-vit-v2`](https://huggingface.co/SmilingWolf/wd-v1-4-vit-tagger-v2)
+ ```
+ 1boy, animal ears, cat, furry, male focus, shirt, shoes, simple background, socks, solo, tail, white background
+ ```
+
+ #### [`wd14-convnext-v2`](https://huggingface.co/SmilingWolf/wd-v1-4-convnext-tagger-v2)
+ ```
+ animal focus, clothes writing, earrings, full body, meme, shirt, shoes, simple background, socks, solo, sweat, tail, white background, white shirt
+ ```
+
+ #### [`wd14-swinv2-v2`](https://huggingface.co/SmilingWolf/wd-v1-4-swinv2-tagger-v2)
+ ```
+ 1boy, arm hair, black footwear, cat, dirty, full body, furry, leg hair, male focus, shirt, shoes, simple background, socks, solo, standing, tail, white background, white shirt
+ ```
+
+ ## Screenshot
+ ![Screenshot](docs/screenshot.png)
+
+ Artwork made by [hecattaart](https://vk.com/hecattaart?w=wall-89063929_3767)
+
+ ## Copyright
+
+ Public domain, except borrowed parts (e.g. `dbimutils.py`)
__pycache__/preload.cpython-310.pyc ADDED
Binary file (733 Bytes).
 
docs/screenshot.png ADDED
javascript/tagger.js ADDED
@@ -0,0 +1,108 @@
+ /**
+  * Waits until the element matching the selector is loaded, then resolves with it.
+  * @param {string} selector
+  * @param {number} timeout
+  * @param {Element} $rootElement
+  * @returns {Promise<HTMLElement>}
+  */
+ function waitQuerySelector(selector, timeout = 5000, $rootElement = gradioApp()) {
+     return new Promise((resolve, reject) => {
+         const element = $rootElement.querySelector(selector)
+         if (element) {
+             return resolve(element)
+         }
+
+         let timeoutId
+
+         const observer = new MutationObserver(() => {
+             const element = $rootElement.querySelector(selector)
+             if (!element) {
+                 return
+             }
+
+             if (timeoutId) {
+                 clearTimeout(timeoutId)
+             }
+
+             observer.disconnect()
+             resolve(element)
+         })
+
+         timeoutId = setTimeout(() => {
+             observer.disconnect()
+             reject(new Error(`timeout, cannot find element by '${selector}'`))
+         }, timeout)
+
+         observer.observe($rootElement, {
+             childList: true,
+             subtree: true
+         })
+     })
+ }
+
+ document.addEventListener('DOMContentLoaded', () => {
+     Promise.all([
+         // option texts
+         waitQuerySelector('#additioanl-tags'),
+         waitQuerySelector('#exclude-tags'),
+
+         // tag-confident labels
+         waitQuerySelector('#rating-confidents'),
+         waitQuerySelector('#tag-confidents')
+     ]).then(elements => {
+
+         const $additionalTags = elements[0].querySelector('textarea')
+         const $excludeTags = elements[1].querySelector('textarea')
+         const $ratingConfidents = elements[2]
+         const $tagConfidents = elements[3]
+
+         let $selectedTextarea = $additionalTags
+
+         /**
+          * @this {HTMLElement}
+          * @param {MouseEvent} e
+          * @listens document#click
+          */
+         function onClickTextarea(e) {
+             $selectedTextarea = this
+         }
+
+         $additionalTags.addEventListener('click', onClickTextarea)
+         $excludeTags.addEventListener('click', onClickTextarea)
+
+         /**
+          * @this {HTMLElement}
+          * @param {MouseEvent} e
+          * @listens document#click
+          */
+         function onClickLabels(e) {
+             // find the clicked label item's wrapper element
+             const $tag = e.target.closest('.output-label > div:not(:first-child)')
+             if (!$tag) {
+                 return
+             }
+
+             /** @type {string} */
+             const tag = $tag.querySelector('.leading-snug').textContent
+
+             // ignore if the tag already exists in the textbox
+             const escapedTag = tag.replace(/[.*+?^${}()|[\]\\]/g, '\\$&')
+             const pattern = new RegExp(`(^|,)\\s{0,}${escapedTag}\\s{0,}($|,)`)
+             if (pattern.test($selectedTextarea.value)) {
+                 return
+             }
+
+             if ($selectedTextarea.value !== '') {
+                 $selectedTextarea.value += ', '
+             }
+
+             $selectedTextarea.value += tag
+         }
+
+         $ratingConfidents.addEventListener('click', onClickLabels)
+         $tagConfidents.addEventListener('click', onClickLabels)
+
+     }).catch(err => {
+         console.error(err)
+     })
+ })
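The duplicate-tag guard in `onClickLabels` escapes regex metacharacters in the tag, then checks whether the tag already sits between commas in the textbox. The same check can be sketched in Python (a standalone re-implementation for illustration, not part of the extension):

```python
import re

# Python sketch of the duplicate-tag guard in javascript/tagger.js:
# escape regex metacharacters, then look for the tag delimited by
# commas (or the start/end of the string), allowing surrounding spaces.
def has_tag(text: str, tag: str) -> bool:
    escaped = re.escape(tag)
    return re.search(rf'(^|,)\s*{escaped}\s*($|,)', text) is not None

print(has_tag('1girl, cat ears', 'cat ears'))  # True
print(has_tag('1girl, cat ears', 'cat'))       # False (no partial matches)
```

Without the escaping step, tags containing characters like `(` or `.` would be interpreted as regex syntax and could match the wrong thing.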
preload.py ADDED
@@ -0,0 +1,27 @@
+ from pathlib import Path
+ from argparse import ArgumentParser
+
+ from modules.shared import models_path
+
+ default_ddp_path = Path(models_path, 'deepdanbooru')
+ default_onnx_path = Path(models_path, 'TaggerOnnx')
+
+
+ def preload(parser: ArgumentParser):
+     # the default deepdanbooru integration uses different paths:
+     # models/deepbooru and models/torch_deepdanbooru
+     # https://github.com/AUTOMATIC1111/stable-diffusion-webui/commit/c81d440d876dfd2ab3560410f37442ef56fc6632
+
+     parser.add_argument(
+         '--deepdanbooru-projects-path',
+         type=str,
+         help='Path to directory with DeepDanbooru project(s).',
+         default=default_ddp_path
+     )
+     parser.add_argument(
+         '--onnxtagger-path',
+         type=str,
+         help='Path to directory with Onnx tagger model(s).',
+         default=default_onnx_path
+     )
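The preload hook registers extra command-line flags before the webui parses its arguments. A minimal standalone sketch of the same argparse pattern (the flag name comes from the file above; the bare `ArgumentParser` here stands in for the webui's own parser):

```python
from argparse import ArgumentParser
from pathlib import Path

# The webui calls preload(parser) on each extension's preload.py before
# parsing argv, so extension flags appear alongside the built-in ones.
parser = ArgumentParser()
parser.add_argument(
    '--deepdanbooru-projects-path',
    type=str,
    default=str(Path('models', 'deepdanbooru')),  # stand-in for models_path
)

opts = parser.parse_args(['--deepdanbooru-projects-path', '/tmp/dd'])
print(opts.deepdanbooru_projects_path)  # /tmp/dd
```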
scripts/tagger.py ADDED
@@ -0,0 +1,18 @@
+ from PIL import Image, ImageFile
+
+ from modules import script_callbacks
+ from tagger.api import on_app_started
+ from tagger.ui import on_ui_tabs
+
+
+ # if you do not initialize the Image object
+ # Image.registered_extensions() returns only PNG
+ Image.init()
+
+ # PIL spits errors when loading a truncated image by default
+ # https://pillow.readthedocs.io/en/stable/reference/ImageFile.html#PIL.ImageFile.LOAD_TRUNCATED_IMAGES
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
+
+
+ script_callbacks.on_app_started(on_app_started)
+ script_callbacks.on_ui_tabs(on_ui_tabs)
style.css ADDED
@@ -0,0 +1,7 @@
+ #rating-confidents .output-label>div:not(:first-child) {
+     cursor: pointer;
+ }
+
+ #tag-confidents .output-label>div:not(:first-child) {
+     cursor: pointer;
+ }
tagger/__pycache__/api.cpython-310.pyc ADDED
Binary file (2.99 kB).
 
tagger/__pycache__/api_models.cpython-310.pyc ADDED
Binary file (1.32 kB).
 
tagger/__pycache__/dbimutils.cpython-310.pyc ADDED
Binary file (1.69 kB).
 
tagger/__pycache__/format.cpython-310.pyc ADDED
Binary file (1.68 kB).
 
tagger/__pycache__/interrogator.cpython-310.pyc ADDED
Binary file (7.69 kB).
 
tagger/__pycache__/preset.cpython-310.pyc ADDED
Binary file (3.07 kB).
 
tagger/__pycache__/ui.cpython-310.pyc ADDED
Binary file (9.12 kB).
 
tagger/__pycache__/utils.cpython-310.pyc ADDED
Binary file (2.64 kB).
 
tagger/api.py ADDED
@@ -0,0 +1,90 @@
+ from typing import Callable
+ from threading import Lock
+ from secrets import compare_digest
+
+ from modules import shared
+ from modules.api.api import decode_base64_to_image
+ from modules.call_queue import queue_lock
+ from fastapi import FastAPI, Depends, HTTPException
+ from fastapi.security import HTTPBasic, HTTPBasicCredentials
+
+ from tagger import utils
+ from tagger import api_models as models
+
+
+ class Api:
+     def __init__(self, app: FastAPI, queue_lock: Lock, prefix: str = None) -> None:
+         if shared.cmd_opts.api_auth:
+             self.credentials = dict()
+             for auth in shared.cmd_opts.api_auth.split(","):
+                 user, password = auth.split(":")
+                 self.credentials[user] = password
+
+         self.app = app
+         self.queue_lock = queue_lock
+         self.prefix = prefix
+
+         self.add_api_route(
+             'interrogate',
+             self.endpoint_interrogate,
+             methods=['POST'],
+             response_model=models.TaggerInterrogateResponse
+         )
+
+         self.add_api_route(
+             'interrogators',
+             self.endpoint_interrogators,
+             methods=['GET'],
+             response_model=models.InterrogatorsResponse
+         )
+
+     def auth(self, creds: HTTPBasicCredentials = Depends(HTTPBasic())):
+         if creds.username in self.credentials:
+             if compare_digest(creds.password, self.credentials[creds.username]):
+                 return True
+
+         raise HTTPException(
+             status_code=401,
+             detail="Incorrect username or password",
+             headers={
+                 "WWW-Authenticate": "Basic"
+             })
+
+     def add_api_route(self, path: str, endpoint: Callable, **kwargs):
+         if self.prefix:
+             path = f'{self.prefix}/{path}'
+
+         if shared.cmd_opts.api_auth:
+             return self.app.add_api_route(path, endpoint, dependencies=[Depends(self.auth)], **kwargs)
+         return self.app.add_api_route(path, endpoint, **kwargs)
+
+     def endpoint_interrogate(self, req: models.TaggerInterrogateRequest):
+         if req.image is None:
+             raise HTTPException(404, 'Image not found')
+
+         if req.model not in utils.interrogators.keys():
+             raise HTTPException(404, 'Model not found')
+
+         image = decode_base64_to_image(req.image)
+         interrogator = utils.interrogators[req.model]
+
+         with self.queue_lock:
+             ratings, tags = interrogator.interrogate(image)
+
+         return models.TaggerInterrogateResponse(
+             caption={
+                 **ratings,
+                 **interrogator.postprocess_tags(
+                     tags,
+                     req.threshold
+                 )
+             })
+
+     def endpoint_interrogators(self):
+         return models.InterrogatorsResponse(
+             models=list(utils.interrogators.keys())
+         )
+
+
+ def on_app_started(_, app: FastAPI):
+     Api(app, queue_lock, '/tagger/v1')
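The routes above expose `POST /tagger/v1/interrogate`, which expects a base64-encoded image plus a model name and threshold. A sketch of building that request body from the client side (the payload fields come from `api_models.py`; the helper name, the placeholder bytes, and the host/port in the comment are assumptions):

```python
import base64

def build_interrogate_payload(image_bytes: bytes, model: str,
                              threshold: float = 0.35) -> dict:
    # The API decodes the image server-side with decode_base64_to_image,
    # so the client sends it as a base64 string.
    return {
        'image': base64.b64encode(image_bytes).decode('utf-8'),
        'model': model,
        'threshold': threshold,
    }

payload = build_interrogate_payload(b'fake image bytes', 'wd14-vit-v2')
# POST it with any HTTP client, e.g. (assuming the default webui port):
#   requests.post('http://127.0.0.1:7860/tagger/v1/interrogate', json=payload)
print(sorted(payload.keys()))  # ['image', 'model', 'threshold']
```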
tagger/api_models.py ADDED
@@ -0,0 +1,33 @@
+ from typing import List, Dict
+
+ from modules.api import models as sd_models
+ from pydantic import BaseModel, Field
+
+
+ class TaggerInterrogateRequest(sd_models.InterrogateRequest):
+     model: str = Field(
+         title='Model',
+         description='The interrogate model used.'
+     )
+
+     threshold: float = Field(
+         default=0.35,
+         title='Threshold',
+         description='',
+         ge=0,
+         le=1
+     )
+
+
+ class TaggerInterrogateResponse(BaseModel):
+     caption: Dict[str, float] = Field(
+         title='Caption',
+         description='The generated caption for the image.'
+     )
+
+
+ class InterrogatorsResponse(BaseModel):
+     models: List[str] = Field(
+         title='Models',
+         description=''
+     )
tagger/dbimutils.py ADDED
@@ -0,0 +1,54 @@
+ # DanBooru IMage Utility functions
+
+ import cv2
+ import numpy as np
+ from PIL import Image
+
+
+ def smart_imread(img, flag=cv2.IMREAD_UNCHANGED):
+     if img.endswith(".gif"):
+         img = Image.open(img)
+         img = img.convert("RGB")
+         img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
+     else:
+         img = cv2.imread(img, flag)
+     return img
+
+
+ def smart_24bit(img):
+     if img.dtype is np.dtype(np.uint16):
+         img = (img / 257).astype(np.uint8)
+
+     if len(img.shape) == 2:
+         img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+     elif img.shape[2] == 4:
+         trans_mask = img[:, :, 3] == 0
+         img[trans_mask] = [255, 255, 255, 255]
+         img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)
+     return img
+
+
+ def make_square(img, target_size):
+     old_size = img.shape[:2]
+     desired_size = max(old_size)
+     desired_size = max(desired_size, target_size)
+
+     delta_w = desired_size - old_size[1]
+     delta_h = desired_size - old_size[0]
+     top, bottom = delta_h // 2, delta_h - (delta_h // 2)
+     left, right = delta_w // 2, delta_w - (delta_w // 2)
+
+     color = [255, 255, 255]
+     new_im = cv2.copyMakeBorder(
+         img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color
+     )
+     return new_im
+
+
+ def smart_resize(img, size):
+     # Assumes the image has already gone through make_square
+     if img.shape[0] > size:
+         img = cv2.resize(img, (size, size), interpolation=cv2.INTER_AREA)
+     elif img.shape[0] < size:
+         img = cv2.resize(img, (size, size), interpolation=cv2.INTER_CUBIC)
+     return img
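`make_square` centers the image on a white square canvas whose side is the larger of the image's longest side and the target size. The padding arithmetic can be sketched without OpenCV (a cv2-free re-implementation for illustration only, assuming an HxWx3 array):

```python
import numpy as np

# Pure-numpy sketch of make_square's padding arithmetic: compute the
# vertical/horizontal deltas, then paste the image into a white canvas.
def make_square_np(img: np.ndarray, target_size: int) -> np.ndarray:
    desired = max(max(img.shape[:2]), target_size)
    top = (desired - img.shape[0]) // 2
    left = (desired - img.shape[1]) // 2
    out = np.full((desired, desired, 3), 255, dtype=img.dtype)  # white
    out[top:top + img.shape[0], left:left + img.shape[1]] = img
    return out

img = np.zeros((2, 4, 3), dtype=np.uint8)      # a 2x4 black image
print(make_square_np(img, 4).shape)            # (4, 4, 3)
```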
tagger/format.py ADDED
@@ -0,0 +1,46 @@
+ import re
+ import hashlib
+
+ from typing import Dict, Callable, NamedTuple
+ from pathlib import Path
+
+
+ class Info(NamedTuple):
+     path: Path
+     output_ext: str
+
+
+ def hash(i: Info, algo='sha1') -> str:
+     try:
+         # hashlib.new raises ValueError for unknown algorithms
+         hash = hashlib.new(algo)
+     except ValueError:
+         raise ValueError(f"'{algo}' is an invalid hash algorithm")
+
+     # TODO: is it okay to hash large images?
+     with open(i.path, 'rb') as file:
+         hash.update(file.read())
+
+     return hash.hexdigest()
+
+
+ pattern = re.compile(r'\[([\w:]+)\]')
+
+ # every formatter must return a string or raise TypeError or ValueError;
+ # any other error will be reported as an extension error
+ available_formats: Dict[str, Callable] = {
+     'name': lambda i: i.path.stem,
+     'extension': lambda i: i.path.suffix[1:],
+     'hash': hash,
+
+     'output_extension': lambda i: i.output_ext
+ }
+
+
+ def format(match: re.Match, info: Info) -> str:
+     matches = match[1].split(':')
+     name, args = matches[0], matches[1:]
+
+     if name not in available_formats:
+         return match[0]
+
+     return available_formats[name](info, *args)
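`format` resolves `[token]` or `[token:arg]` placeholders through the `available_formats` table, leaving unknown tokens untouched. A self-contained sketch of the same substitution (names mirror the file above, but this is a standalone re-implementation with the `hash` formatter omitted):

```python
import re
from pathlib import Path
from typing import NamedTuple

class Info(NamedTuple):
    path: Path
    output_ext: str

# [token] or [token:arg1:arg2] placeholders
pattern = re.compile(r'\[([\w:]+)\]')

available_formats = {
    'name': lambda i: i.path.stem,
    'extension': lambda i: i.path.suffix[1:],
    'output_extension': lambda i: i.output_ext,
}

def format_token(match: re.Match, info: Info) -> str:
    name, *args = match[1].split(':')
    if name not in available_formats:
        return match[0]  # leave unknown tokens untouched
    return available_formats[name](info, *args)

info = Info(Path('images/cat.png'), 'txt')
result = pattern.sub(lambda m: format_token(m, info), '[name].[output_extension]')
print(result)  # cat.txt
```

Returning `match[0]` for unrecognized names means a template like `[unknown]` passes through verbatim instead of raising.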
tagger/interrogator.py ADDED
@@ -0,0 +1,320 @@
+ import os
+ import gc
+ import pandas as pd
+ import numpy as np
+
+ from typing import Tuple, List, Dict
+ from io import BytesIO
+ from PIL import Image
+
+ from pathlib import Path
+ from huggingface_hub import hf_hub_download
+
+ from modules import shared
+ from modules.deepbooru import re_special as tag_escape_pattern
+
+ # i'm not sure if it's okay to add this file to the repository
+ from . import dbimutils
+
+ # select a device to process
+ use_cpu = ('all' in shared.cmd_opts.use_cpu) or (
+     'interrogate' in shared.cmd_opts.use_cpu)
+
+ if use_cpu:
+     tf_device_name = '/cpu:0'
+ else:
+     tf_device_name = '/gpu:0'
+
+     if shared.cmd_opts.device_id is not None:
+         try:
+             tf_device_name = f'/gpu:{int(shared.cmd_opts.device_id)}'
+         except ValueError:
+             print('--device-id is not an integer')
+
+
+ class Interrogator:
+     @staticmethod
+     def postprocess_tags(
+         tags: Dict[str, float],
+
+         threshold=0.35,
+         additional_tags: List[str] = [],
+         exclude_tags: List[str] = [],
+         sort_by_alphabetical_order=False,
+         add_confident_as_weight=False,
+         replace_underscore=False,
+         replace_underscore_excludes: List[str] = [],
+         escape_tag=False
+     ) -> Dict[str, float]:
+
+         tags = {
+             **{t: 1.0 for t in additional_tags},
+             **tags
+         }
+
+         # those lines are totally not "pythonic" but look better to me
+         tags = {
+             t: c
+
+             # sort by tag name or confidence
+             for t, c in sorted(
+                 tags.items(),
+                 key=lambda i: i[0 if sort_by_alphabetical_order else 1],
+                 reverse=not sort_by_alphabetical_order
+             )
+
+             # filter tags
+             if (
+                 c >= threshold
+                 and t not in exclude_tags
+             )
+         }
+
+         new_tags = []
+         for tag in list(tags):
+             new_tag = tag
+
+             if replace_underscore and tag not in replace_underscore_excludes:
+                 new_tag = new_tag.replace('_', ' ')
+
+             if escape_tag:
+                 new_tag = tag_escape_pattern.sub(r'\\\1', new_tag)
+
+             if add_confident_as_weight:
+                 new_tag = f'({new_tag}:{tags[tag]})'
+
+             new_tags.append((new_tag, tags[tag]))
+         tags = dict(new_tags)
+
+         return tags
+
+     def __init__(self, name: str) -> None:
+         self.name = name
+
+     def load(self):
+         raise NotImplementedError()
+
+     def unload(self) -> bool:
+         unloaded = False
+
+         if hasattr(self, 'model') and self.model is not None:
+             del self.model
+             unloaded = True
+             print(f'Unloaded {self.name}')
+
+         if hasattr(self, 'tags'):
+             del self.tags
+
+         return unloaded
+
+     def interrogate(
+         self,
+         image: Image
+     ) -> Tuple[
+         Dict[str, float],  # rating confidences
+         Dict[str, float]   # tag confidences
+     ]:
+         raise NotImplementedError()
+
+
+ class DeepDanbooruInterrogator(Interrogator):
+     def __init__(self, name: str, project_path: os.PathLike) -> None:
+         super().__init__(name)
+         self.project_path = project_path
+
+     def load(self) -> None:
+         print(f'Loading {self.name} from {str(self.project_path)}')
+
+         # the deepdanbooru package is not included in the webui anymore
+         # https://github.com/AUTOMATIC1111/stable-diffusion-webui/commit/c81d440d876dfd2ab3560410f37442ef56fc663
+         from launch import is_installed, run_pip
+         if not is_installed('deepdanbooru'):
+             package = os.environ.get(
+                 'DEEPDANBOORU_PACKAGE',
+                 'git+https://github.com/KichangKim/DeepDanbooru.git@d91a2963bf87c6a770d74894667e9ffa9f6de7ff'
+             )
+
+             run_pip(
+                 f'install {package} tensorflow tensorflow-io', 'deepdanbooru')
+
+         import tensorflow as tf
+
+         # tensorflow maps nearly all vram by default, so we limit this
+         # https://www.tensorflow.org/guide/gpu#limiting_gpu_memory_growth
+         # TODO: only run on the first run
+         for device in tf.config.experimental.list_physical_devices('GPU'):
+             tf.config.experimental.set_memory_growth(device, True)
+
+         with tf.device(tf_device_name):
+             import deepdanbooru.project as ddp
+
+             self.model = ddp.load_model_from_project(
+                 project_path=self.project_path,
+                 compile_model=False
+             )
+
+             print(f'Loaded {self.name} model from {str(self.project_path)}')
+
+             self.tags = ddp.load_tags_from_project(
+                 project_path=self.project_path
+             )
+
+     def unload(self) -> bool:
+         # unloaded = super().unload()
+
+         # if unloaded:
+         #     # tensorflow suck
+         #     # https://github.com/keras-team/keras/issues/2102
+         #     import tensorflow as tf
+         #     tf.keras.backend.clear_session()
+         #     gc.collect()
+
+         # return unloaded
+
+         # There is a bug in Keras where it is not possible to release a model that has been loaded into memory.
+         # Downgrading to keras==2.1.6 may solve the issue, but it may cause compatibility issues with other packages.
+         # Using a subprocess to create a new process may also solve the problem, but it can be too complex (like Automatic1111 did).
+         # It seems that for now, the best option is to keep the model in memory, as most users use the Waifu Diffusion model with onnx.
+
+         return False
+
+     def interrogate(
+         self,
+         image: Image
+     ) -> Tuple[
+         Dict[str, float],  # rating confidences
+         Dict[str, float]   # tag confidences
+     ]:
+         # init model
+         if not hasattr(self, 'model') or self.model is None:
+             self.load()
+
+         import deepdanbooru.data as ddd
+
+         # convert the image to fit the model
+         image_bufs = BytesIO()
+         image.save(image_bufs, format='PNG')
+         image = ddd.load_image_for_evaluate(
+             image_bufs,
+             self.model.input_shape[2],
+             self.model.input_shape[1]
+         )
+
+         image = image.reshape((1, *image.shape[0:3]))
+
+         # evaluate model
+         result = self.model.predict(image)
+
+         confidents = result[0].tolist()
+         ratings = {}
+         tags = {}
+
+         for i, tag in enumerate(self.tags):
+             tags[tag] = confidents[i]
+
+         return ratings, tags
+
+
+ class WaifuDiffusionInterrogator(Interrogator):
+     def __init__(
+         self,
+         name: str,
+         model_path='model.onnx',
+         tags_path='selected_tags.csv',
+         **kwargs
+     ) -> None:
+         super().__init__(name)
+         self.model_path = model_path
+         self.tags_path = tags_path
+         self.kwargs = kwargs
+
+     def download(self) -> Tuple[os.PathLike, os.PathLike]:
+         # if model_path exists, skip the download
+         print(self.model_path, self.tags_path)
+         if os.path.exists(self.model_path) and os.path.exists(self.tags_path):
+             return self.model_path, self.tags_path
+         print(f"Loading {self.name} model file from {self.kwargs['repo_id']}")
+
+         model_path = Path(hf_hub_download(
+             **self.kwargs, filename=self.model_path))
+         tags_path = Path(hf_hub_download(
+             **self.kwargs, filename=self.tags_path))
+         return model_path, tags_path
+
+     def load(self) -> None:
+         model_path, tags_path = self.download()
+
+         # only one of these packages should be installed at a time in any one environment
+         # https://onnxruntime.ai/docs/get-started/with-python.html#install-onnx-runtime
+         # TODO: remove the old package when the environment changes?
+         from launch import is_installed, run_pip
+         if not is_installed('onnxruntime'):
+             package = os.environ.get(
+                 'ONNXRUNTIME_PACKAGE',
+                 'onnxruntime-gpu'
+             )
+
+             run_pip(f'install {package}', 'onnxruntime')
+
+         from onnxruntime import InferenceSession
+
+         # https://onnxruntime.ai/docs/execution-providers/
+         # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/commit/e4ec460122cf674bbf984df30cdb10b4370c1224#r92654958
+         providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
+         if use_cpu:
+             providers.pop(0)
+
+         self.model = InferenceSession(str(model_path), providers=providers)
+
+         print(f'Loaded {self.name} model from {model_path}')
+
+         self.tags = pd.read_csv(tags_path)
+
+     def interrogate(
+         self,
+         image: Image
+     ) -> Tuple[
+         Dict[str, float],  # rating confidences
+         Dict[str, float]   # tag confidences
+     ]:
+         # init model
+         if not hasattr(self, 'model') or self.model is None:
+             self.load()
+
+         # code for converting the image and running the model is taken from the link below
+         # thanks, SmilingWolf!
+         # https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags/blob/main/app.py
+
+         # convert the image to fit the model
+         _, height, _, _ = self.model.get_inputs()[0].shape
+
+         # alpha to white
+         image = image.convert('RGBA')
+         new_image = Image.new('RGBA', image.size, 'WHITE')
+         new_image.paste(image, mask=image)
+         image = new_image.convert('RGB')
+         image = np.asarray(image)
+
+         # PIL RGB to OpenCV BGR
+         image = image[:, :, ::-1]
+
+         image = dbimutils.make_square(image, height)
+         image = dbimutils.smart_resize(image, height)
+         image = image.astype(np.float32)
+         image = np.expand_dims(image, 0)
+
+         # evaluate model
+         input_name = self.model.get_inputs()[0].name
+         label_name = self.model.get_outputs()[0].name
+         confidents = self.model.run([label_name], {input_name: image})[0]
+
+         tags = self.tags[:][['name']]
+         tags['confidents'] = confidents[0]
+
+         # the first 4 rows are ratings (general, sensitive, questionable, explicit)
+         ratings = dict(tags[:4].values)
+
+         # the rest are regular tags
+         tags = dict(tags[4:].values)
+
+         return ratings, tags
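`postprocess_tags` merges additional tags, filters by threshold, sorts by confidence (or name), and optionally rewrites tag names. A simplified standalone sketch of that pipeline, covering only the threshold, sort, underscore, and weight steps (the full method also handles excludes and escaping):

```python
# Standalone sketch of the tag post-processing: threshold filter,
# confidence-descending sort, underscore replacement, and the optional
# (tag:weight) prompt syntax.
def postprocess_tags(tags, threshold=0.35, replace_underscore=True,
                     add_confident_as_weight=False):
    kept = sorted(
        ((t, c) for t, c in tags.items() if c >= threshold),
        key=lambda i: i[1], reverse=True,
    )
    out = {}
    for tag, conf in kept:
        name = tag.replace('_', ' ') if replace_underscore else tag
        if add_confident_as_weight:
            name = f'({name}:{conf})'
        out[name] = conf
    return out

raw = {'cat_ears': 0.92, '1girl': 0.88, 'dog': 0.10}
print(postprocess_tags(raw))  # {'cat ears': 0.92, '1girl': 0.88}
```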
tagger/preset.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ import os
+ import json
+
+ from typing import Tuple, List, Dict, Any
+ from pathlib import Path
+ from modules.images import sanitize_filename_part
+
+ PresetDict = Dict[str, Dict[str, Any]]
+
+
+ class Preset:
+ base_dir: Path
+ default_filename: str
+ default_values: PresetDict
+ components: List[object]
+
+ def __init__(
+ self,
+ base_dir: os.PathLike,
+ default_filename='default.json'
+ ) -> None:
+ self.base_dir = Path(base_dir)
+ self.default_filename = default_filename
+ self.default_values = self.load(default_filename)[1]
+ self.components = []
+
+ def component(self, component_class: object, **kwargs) -> object:
+ # walk up the Gradio context to the top component and build a path
+ from gradio.context import Context
+ parent = Context.block
+ paths = [kwargs['label']]
+
+ while parent is not None:
+ if hasattr(parent, 'label'):
+ paths.insert(0, parent.label)
+
+ parent = parent.parent
+
+ path = '/'.join(paths)
+
+ component = component_class(**{
+ **kwargs,
+ **self.default_values.get(path, {})
+ })
+
+ setattr(component, 'path', path)
+
+ self.components.append(component)
+ return component
+
+ def load(self, filename: str) -> Tuple[Path, PresetDict]:
+ if not filename.endswith('.json'):
+ filename += '.json'
+
+ path = self.base_dir.joinpath(sanitize_filename_part(filename))
+ configs = {}
+
+ if path.is_file():
+ configs = json.loads(path.read_text())
+
+ return path, configs
+
+ def save(self, filename: str, *values) -> str:
+ path, configs = self.load(filename)
+
+ for index, component in enumerate(self.components):
+ config = configs.get(component.path, {})
+ config['value'] = values[index]
+
+ for attr in ['visible', 'min', 'max', 'step']:
+ if hasattr(component, attr):
+ config[attr] = config.get(attr, getattr(component, attr))
+
+ configs[component.path] = config
+
+ self.base_dir.mkdir(0o777, True, True)
+ path.write_text(
+ json.dumps(configs, indent=4)
+ )
+
+ return 'successfully saved the preset'
+
+ def apply(self, filename: str) -> Tuple:
+ values = self.load(filename)[1]
+ outputs = []
+
+ for component in self.components:
+ config = values.get(component.path, {})
+
+ if 'value' in config and hasattr(component, 'choices'):
+ if config['value'] not in component.choices:
+ config['value'] = None
+
+ outputs.append(component.update(**config))
+
+ return (*outputs, 'successfully loaded the preset')
+
+ def list(self) -> List[str]:
+ presets = [
+ p.name
+ for p in self.base_dir.glob('*.json')
+ if p.is_file()
+ ]
+
+ if len(presets) < 1:
+ presets.append(self.default_filename)
+
+ return presets
tagger/ui.py ADDED
@@ -0,0 +1,464 @@
+ import os
+ import json
+ import gradio as gr
+
+ from pathlib import Path
+ from glob import glob
+ from PIL import Image, UnidentifiedImageError
+
+ from webui import wrap_gradio_gpu_call
+ from modules import ui
+ from modules import generation_parameters_copypaste as parameters_copypaste
+
+ from tagger import format, utils
+ from tagger.utils import split_str
+ from tagger.interrogator import Interrogator
+
+
+ def unload_interrogators():
+ unloaded_models = 0
+
+ for i in utils.interrogators.values():
+ if i.unload():
+ unloaded_models += 1
+
+ return [f'Successfully unloaded {unloaded_models} model(s)']
+
+
+ def on_interrogate(
+ image: Image,
+ batch_input_glob: str,
+ batch_input_recursive: bool,
+ batch_output_dir: str,
+ batch_output_filename_format: str,
+ batch_output_action_on_conflict: str,
+ batch_output_save_json: bool,
+
+ interrogator: str,
+ threshold: float,
+ additional_tags: str,
+ exclude_tags: str,
+ sort_by_alphabetical_order: bool,
+ add_confident_as_weight: bool,
+ replace_underscore: bool,
+ replace_underscore_excludes: str,
+ escape_tag: bool,
+
+ unload_model_after_running: bool
+ ):
+ if interrogator not in utils.interrogators:
+ return ['', None, None, f"'{interrogator}' is not a valid interrogator"]
+
+ interrogator: Interrogator = utils.interrogators[interrogator]
+
+ postprocess_opts = (
+ threshold,
+ split_str(additional_tags),
+ split_str(exclude_tags),
+ sort_by_alphabetical_order,
+ add_confident_as_weight,
+ replace_underscore,
+ split_str(replace_underscore_excludes),
+ escape_tag
+ )
+
+ # single process
+ if image is not None:
+ ratings, tags = interrogator.interrogate(image)
+ processed_tags = Interrogator.postprocess_tags(
+ tags,
+ *postprocess_opts
+ )
+
+ if unload_model_after_running:
+ interrogator.unload()
+
+ return [
+ ', '.join(processed_tags),
+ ratings,
+ tags,
+ ''
+ ]
+
+ # batch process
+ batch_input_glob = batch_input_glob.strip()
+ batch_output_dir = batch_output_dir.strip()
+ batch_output_filename_format = batch_output_filename_format.strip()
+
+ if batch_input_glob != '':
+ # if there is no glob pattern, insert it automatically
+ if not batch_input_glob.endswith('*'):
+ if not batch_input_glob.endswith('/'):
+ batch_input_glob += '/'
+ batch_input_glob += '*'
+
+ # get root directory of input glob pattern
+ base_dir = batch_input_glob.replace('?', '*')
+ base_dir = base_dir.split('/*').pop(0)
+
+ # check the input directory path
+ if not os.path.isdir(base_dir):
+ return ['', None, None, 'input path is not a directory']
+
+ # this is done here because, for some reason,
+ # PIL.Image.registered_extensions() returns only PNG if called too early
+ supported_extensions = [
+ e
+ for e, f in Image.registered_extensions().items()
+ if f in Image.OPEN
+ ]
+
+ paths = [
+ Path(p)
+ for p in glob(batch_input_glob, recursive=batch_input_recursive)
+ if '.' + p.split('.').pop().lower() in supported_extensions
+ ]
+
+ print(f'found {len(paths)} image(s)')
+
+ for path in paths:
+ try:
+ image = Image.open(path)
+ except UnidentifiedImageError:
+ # just in case the user has a mysterious file...
+ print(f'{path} is not a supported image type')
+ continue
+
+ # guess the output path
+ base_dir_last = Path(base_dir).parts[-1]
+ base_dir_last_idx = path.parts.index(base_dir_last)
+ output_dir = Path(
+ batch_output_dir) if batch_output_dir else Path(base_dir)
+ output_dir = output_dir.joinpath(
+ *path.parts[base_dir_last_idx + 1:]).parent
+
+ output_dir.mkdir(0o777, True, True)
+
+ # format output filename
+ format_info = format.Info(path, 'txt')
+
+ try:
+ formatted_output_filename = format.pattern.sub(
+ lambda m: format.format(m, format_info),
+ batch_output_filename_format
+ )
+ except (TypeError, ValueError) as error:
+ return ['', None, None, str(error)]
+
+ output_path = output_dir.joinpath(
+ formatted_output_filename
+ )
+
+ output = []
+
+ if output_path.is_file():
+ output.append(output_path.read_text())
+
+ if batch_output_action_on_conflict == 'ignore':
+ print(f'skipping {path}')
+ continue
+
+ ratings, tags = interrogator.interrogate(image)
+ processed_tags = Interrogator.postprocess_tags(
+ tags,
+ *postprocess_opts
+ )
+
+ # TODO: switch for less print
+ print(
+ f'found {len(processed_tags)} tags out of {len(tags)} from {path}'
+ )
+
+ plain_tags = ', '.join(processed_tags)
+
+ if batch_output_action_on_conflict == 'copy':
+ output = [plain_tags]
+ elif batch_output_action_on_conflict == 'prepend':
+ output.insert(0, plain_tags)
+ else:
+ output.append(plain_tags)
+
+ output_path.write_text(' '.join(output))
+
+ if batch_output_save_json:
+ output_path.with_suffix('.json').write_text(
+ json.dumps([ratings, tags])
+ )
+
+ print('all done :)')
+
+ if unload_model_after_running:
+ interrogator.unload()
+
+ return ['', None, None, '']
+
+
+ def on_ui_tabs():
+ with gr.Blocks(analytics_enabled=False) as tagger_interface:
+ with gr.Row().style(equal_height=False):
+ with gr.Column(variant='panel'):
+
+ # input components
+ with gr.Tabs():
+ with gr.TabItem(label='Single process'):
+ image = gr.Image(
+ label='Source',
+ source='upload',
+ interactive=True,
+ type="pil"
+ )
+
+ with gr.TabItem(label='Batch from directory'):
+ batch_input_glob = utils.preset.component(
+ gr.Textbox,
+ label='Input directory',
+ placeholder='/path/to/images or /path/to/images/**/*'
+ )
+ batch_input_recursive = utils.preset.component(
+ gr.Checkbox,
+ label='Use recursive with glob pattern'
+ )
+
+ batch_output_dir = utils.preset.component(
+ gr.Textbox,
+ label='Output directory',
+ placeholder='Leave blank to save images to the same path.'
+ )
+
+ batch_output_filename_format = utils.preset.component(
+ gr.Textbox,
+ label='Output filename format',
+ placeholder='Leave blank to use the same filename as the original.',
+ value='[name].[output_extension]'
+ )
+
+ import hashlib
+ with gr.Accordion(
+ label='Output filename formats',
+ open=False
+ ):
+ gr.Markdown(
+ value=f'''
+ ### Related to original file
+ - `[name]`: Original filename without extension
+ - `[extension]`: Original extension
+ - `[hash:<algorithm>]`: Hash of the original file
+ Available algorithms: `{', '.join(hashlib.algorithms_available)}`
+
+ ### Related to output file
+ - `[output_extension]`: Output extension (without dot)
+
+ ## Examples
+ ### Original filename without extension
+ `[name].[output_extension]`
+
+ ### Original file's hash (useful for removing duplicates)
+ `[hash:sha1].[output_extension]`
+ '''
+ )
+
+ batch_output_action_on_conflict = utils.preset.component(
+ gr.Dropdown,
+ label='Action on existing caption',
+ value='ignore',
+ choices=[
+ 'ignore',
+ 'copy',
+ 'append',
+ 'prepend'
+ ]
+ )
+
+ batch_output_save_json = utils.preset.component(
+ gr.Checkbox,
+ label='Save with JSON'
+ )
+
+ submit = gr.Button(
+ value='Interrogate',
+ variant='primary'
+ )
+
+ info = gr.HTML()
+
+ # preset selector
+ with gr.Row(variant='compact'):
+ available_presets = utils.preset.list()
+ selected_preset = gr.Dropdown(
+ label='Preset',
+ choices=available_presets,
+ value=available_presets[0]
+ )
+
+ save_preset_button = gr.Button(
+ value=ui.save_style_symbol
+ )
+
+ ui.create_refresh_button(
+ selected_preset,
+ lambda: None,
+ lambda: {'choices': utils.preset.list()},
+ 'refresh_preset'
+ )
+
+ # option components
+
+ # interrogator selector
+ with gr.Column():
+ with gr.Row(variant='compact'):
+ interrogator_names = utils.refresh_interrogators()
+ interrogator = utils.preset.component(
+ gr.Dropdown,
+ label='Interrogator',
+ choices=interrogator_names,
+ value=(
+ None
+ if len(interrogator_names) < 1 else
+ interrogator_names[-1]
+ )
+ )
+
+ ui.create_refresh_button(
+ interrogator,
+ lambda: None,
+ lambda: {'choices': utils.refresh_interrogators()},
+ 'refresh_interrogator'
+ )
+
+ unload_all_models = gr.Button(
+ value='Unload all interrogate models'
+ )
+
+ threshold = utils.preset.component(
+ gr.Slider,
+ label='Threshold',
+ minimum=0,
+ maximum=1,
+ value=0.35
+ )
+
+ additional_tags = utils.preset.component(
+ gr.Textbox,
+ label='Additional tags (split by comma)',
+ elem_id='additioanl-tags'
+ )
+
+ exclude_tags = utils.preset.component(
+ gr.Textbox,
+ label='Exclude tags (split by comma)',
+ elem_id='exclude-tags'
+ )
+
+ sort_by_alphabetical_order = utils.preset.component(
+ gr.Checkbox,
+ label='Sort by alphabetical order',
+ )
+ add_confident_as_weight = utils.preset.component(
+ gr.Checkbox,
+ label='Include confidence of tag matches in results'
+ )
+ replace_underscore = utils.preset.component(
+ gr.Checkbox,
+ label='Use spaces instead of underscore',
+ value=True
+ )
+ replace_underscore_excludes = utils.preset.component(
+ gr.Textbox,
+ label='Excludes (split by comma)',
+ # kaomoji from WD 1.4 tagger csv. thanks, Meow-San#5400!
+ value='0_0, (o)_(o), +_+, +_-, ._., <o>_<o>, <|>_<|>, =_=, >_<, 3_3, 6_9, >_o, @_@, ^_^, o_o, u_u, x_x, |_|, ||_||'
+ )
+ escape_tag = utils.preset.component(
+ gr.Checkbox,
+ label='Escape brackets',
+ )
+
+ unload_model_after_running = utils.preset.component(
+ gr.Checkbox,
+ label='Unload model after running',
+ )
+
+ # output components
+ with gr.Column(variant='panel'):
+ tags = gr.Textbox(
+ label='Tags',
+ placeholder='Found tags',
+ interactive=False
+ )
+
+ with gr.Row():
+ parameters_copypaste.bind_buttons(
+ parameters_copypaste.create_buttons(
+ ["txt2img", "img2img"],
+ ),
+ None,
+ tags
+ )
+
+ rating_confidents = gr.Label(
+ label='Rating confidences',
+ elem_id='rating-confidents'
+ )
+ tag_confidents = gr.Label(
+ label='Tag confidences',
+ elem_id='tag-confidents'
+ )
+
+ # register events
+ selected_preset.change(
+ fn=utils.preset.apply,
+ inputs=[selected_preset],
+ outputs=[*utils.preset.components, info]
+ )
+
+ save_preset_button.click(
+ fn=utils.preset.save,
+ inputs=[selected_preset, *utils.preset.components], # values only
+ outputs=[info]
+ )
+
+ unload_all_models.click(
+ fn=unload_interrogators,
+ outputs=[info]
+ )
+
+ for func in [image.change, submit.click]:
+ func(
+ fn=wrap_gradio_gpu_call(on_interrogate),
+ inputs=[
+ # single process
+ image,
+
+ # batch process
+ batch_input_glob,
+ batch_input_recursive,
+ batch_output_dir,
+ batch_output_filename_format,
+ batch_output_action_on_conflict,
+ batch_output_save_json,
+
+ # options
+ interrogator,
+ threshold,
+ additional_tags,
+ exclude_tags,
+ sort_by_alphabetical_order,
+ add_confident_as_weight,
+ replace_underscore,
+ replace_underscore_excludes,
+ escape_tag,
+
+ unload_model_after_running
+ ],
+ outputs=[
+ tags,
+ rating_confidents,
+ tag_confidents,
+
+ # contains execution time, memory usage and other stuff,
+ # generated from modules.ui.wrap_gradio_call
+ info
+ ]
+ )
+
+ return [(tagger_interface, "Tagger", "tagger")]
tagger/utils.py ADDED
@@ -0,0 +1,103 @@
+ import os
+
+ from typing import List, Dict
+ from pathlib import Path
+
+ from modules import shared, scripts
+ from preload import default_ddp_path, default_onnx_path
+ from tagger.preset import Preset
+ from tagger.interrogator import Interrogator, DeepDanbooruInterrogator, WaifuDiffusionInterrogator
+
+ preset = Preset(Path(scripts.basedir(), 'presets'))
+
+ interrogators: Dict[str, Interrogator] = {}
+
+
+ def refresh_interrogators() -> List[str]:
+ global interrogators
+ interrogators = {
+ 'wd14-vit-v2': WaifuDiffusionInterrogator(
+ 'wd14-vit-v2',
+ repo_id='SmilingWolf/wd-v1-4-vit-tagger-v2',
+ revision='v2.0'
+ ),
+ 'wd14-convnext-v2': WaifuDiffusionInterrogator(
+ 'wd14-convnext-v2',
+ repo_id='SmilingWolf/wd-v1-4-convnext-tagger-v2',
+ revision='v2.0'
+ ),
+ 'wd14-swinv2-v2': WaifuDiffusionInterrogator(
+ 'wd14-swinv2-v2',
+ repo_id='SmilingWolf/wd-v1-4-swinv2-tagger-v2',
+ revision='v2.0'
+ ),
+ 'wd14-vit-v2-git': WaifuDiffusionInterrogator(
+ 'wd14-vit-v2-git',
+ repo_id='SmilingWolf/wd-v1-4-vit-tagger-v2'
+ ),
+ 'wd14-convnext-v2-git': WaifuDiffusionInterrogator(
+ 'wd14-convnext-v2-git',
+ repo_id='SmilingWolf/wd-v1-4-convnext-tagger-v2'
+ ),
+ 'wd14-swinv2-v2-git': WaifuDiffusionInterrogator(
+ 'wd14-swinv2-v2-git',
+ repo_id='SmilingWolf/wd-v1-4-swinv2-tagger-v2'
+ ),
+ 'wd14-vit': WaifuDiffusionInterrogator(
+ 'wd14-vit',
+ repo_id='SmilingWolf/wd-v1-4-vit-tagger'
+ ),
+ 'wd14-convnext': WaifuDiffusionInterrogator(
+ 'wd14-convnext',
+ repo_id='SmilingWolf/wd-v1-4-convnext-tagger'
+ ),
+ # 'Z3D-E621-convnext': WaifuDiffusionInterrogator(
+ #     'Z3D-E621-convnext',
+ #     model_path=r'SmilingWolf/wd-v1-4-convnext-tagger',
+ #     tags_path=r''
+ # ),
+ }
+
+ # load deepdanbooru projects
+ ddp_path = getattr(shared.cmd_opts, 'deepdanbooru_projects_path', default_ddp_path)
+ onnx_path = getattr(shared.cmd_opts, 'onnxtagger_path', default_onnx_path)
+ os.makedirs(ddp_path, exist_ok=True)
+ os.makedirs(onnx_path, exist_ok=True)
+
+ for path in os.scandir(ddp_path):
+ if not path.is_dir():
+ continue
+
+ if not Path(path, 'project.json').is_file():
+ continue
+
+ interrogators[path.name] = DeepDanbooruInterrogator(path.name, path)
+
+ # scan for ONNX models as well
+ for path in os.scandir(onnx_path):
+ if not path.is_dir():
+ continue
+
+ # skip if there is no .onnx file; warn and skip if there is more than one
+ onnx_files = [x for x in os.scandir(path) if x.name.endswith('.onnx')]
+ if len(onnx_files) == 0:
+ print(f"Warning: {path} has no model, skipping")
+ continue
+ elif len(onnx_files) > 1:
+ print(f"Warning: {path} has multiple models, skipping")
+ continue
+ model_path = Path(path, onnx_files[0].name)
+
+ if not Path(path, 'tags-selected.csv').is_file():
+ print(f"Warning: {path} has a model but no tags-selected.csv file, skipping")
+ continue
+
+ interrogators[path.name] = WaifuDiffusionInterrogator(
+ path.name,
+ model_path=model_path,
+ tags_path=Path(path, 'tags-selected.csv')
+ )
+
+ return sorted(interrogators.keys())
+
+
+ def split_str(s: str, separator=',') -> List[str]:
+ return [x.strip() for x in s.split(separator) if x]