Upload 25 files
- .gitignore +6 -0
- README.ko.md +61 -0
- README.md +118 -0
- __pycache__/preload.cpython-310.pyc +0 -0
- docs/screenshot.png +0 -0
- javascript/tagger.js +108 -0
- preload.py +27 -0
- scripts/tagger.py +18 -0
- style.css +7 -0
- tagger/__pycache__/api.cpython-310.pyc +0 -0
- tagger/__pycache__/api_models.cpython-310.pyc +0 -0
- tagger/__pycache__/dbimutils.cpython-310.pyc +0 -0
- tagger/__pycache__/format.cpython-310.pyc +0 -0
- tagger/__pycache__/interrogator.cpython-310.pyc +0 -0
- tagger/__pycache__/preset.cpython-310.pyc +0 -0
- tagger/__pycache__/ui.cpython-310.pyc +0 -0
- tagger/__pycache__/utils.cpython-310.pyc +0 -0
- tagger/api.py +90 -0
- tagger/api_models.py +33 -0
- tagger/dbimutils.py +54 -0
- tagger/format.py +46 -0
- tagger/interrogator.py +320 -0
- tagger/preset.py +108 -0
- tagger/ui.py +464 -0
- tagger/utils.py +103 -0
.gitignore
ADDED
@@ -0,0 +1,6 @@
__pycache__/
.vscode/
.venv/
.env

presets/
README.ko.md
ADDED
@@ -0,0 +1,61 @@
Tagging (labeling) extension for the [Automatic1111 WebUI](https://github.com/AUTOMATIC1111/stable-diffusion-webui)
---
Interrogates booru-style tags from single or multiple images using models such as DeepDanbooru.

[You don't know how to read Korean? Read it in English here!](README.md)

## Disclaimer
I did not make the models, and most of the code was borrowed from [DeepDanbooru](https://github.com/KichangKim/DeepDanbooru) and MrSmilingWolf's tagger.

## Installation
1. *Extensions* -> *Install from URL* -> enter the URL of this repository -> *Install*
   - or clone this repository under the `extensions/` directory:
   ```sh
   $ git clone https://github.com/toriato/stable-diffusion-webui-wd14-tagger.git extensions/tagger
   ```

1. Add a model
   - #### *MrSmilingWolf's model (a.k.a. Waifu Diffusion 1.4 tagger)*
     Downloaded automatically from the [HuggingFace repository](https://huggingface.co/SmilingWolf/wd-v1-4-vit-tagger) the first time you run it.

     For questions about the model or additional training, please ask the original author, MrSmilingWolf#5991.

   - #### *DeepDanbooru*
     1. Various model files can be found below:
        - [DeepDanbooru models](https://github.com/KichangKim/DeepDanbooru/releases)
        - [e621 model by 🐾Zack🐾#1984](https://discord.gg/BDFpq9Yb7K)
          *(contains NSFW content!)*

     1. Move the project folder containing the model and config files to `models/deepdanbooru`.

     1. The file structure looks like this:
        ```
        models/
        └╴deepdanbooru/
          ├╴deepdanbooru-v3-20211112-sgd-e28/
          │ ├╴project.json
          │ └╴...
          │
          ├╴deepdanbooru-v4-20200814-sgd-e30/
          │ ├╴project.json
          │ └╴...
          │
          ├╴e621-v3-20221117-sgd-e32/
          │ ├╴project.json
          │ └╴...
          │
          ...
        ```

1. Start or restart the WebUI.
   - or press the refresh button to the right of the *Interrogator* dropdown.

## Screenshot

Artwork made by [hecattaart](https://vk.com/hecattaart?w=wall-89063929_3767)

## Copyright

Public domain, except for borrowed code (e.g. `dbimutils.py`)
README.md
ADDED
@@ -0,0 +1,118 @@
Tagger for [Automatic1111's WebUI](https://github.com/AUTOMATIC1111/stable-diffusion-webui)
---
Interrogate booru-style tags for single or multiple image files using various models, such as DeepDanbooru.

[한국어를 사용하시나요? 여기에 한국어 설명서가 있습니다!](README.ko.md)

## Disclaimer
I didn't make any of the models, and most of the code was heavily borrowed from [DeepDanbooru](https://github.com/KichangKim/DeepDanbooru) and MrSmilingWolf's tagger.

## Installation
1. *Extensions* -> *Install from URL* -> Enter the URL of this repository -> Press the *Install* button
   - or clone this repository under `extensions/`:
   ```sh
   $ git clone https://github.com/toriato/stable-diffusion-webui-wd14-tagger.git extensions/tagger
   ```

1. Add an interrogate model
   - #### *MrSmilingWolf's model (a.k.a. Waifu Diffusion 1.4 tagger)*
     Downloads automatically from the [HuggingFace repository](https://huggingface.co/SmilingWolf/wd-v1-4-vit-tagger) the first time you run it.

     Please ask the original author, MrSmilingWolf#5991, any questions related to the model or additional training.

     ##### ViT vs Convnext
     > To make it clear: the ViT model is the one used to tag images for WD 1.4. That's why the repo was originally called like that. This one has been trained on the same data and tags, but has got no other relation to WD 1.4, aside from stemming from the same coordination effort. They were trained in parallel, and the best one at the time was selected for WD 1.4

     > This particular model was trained later and might actually be slightly better than the ViT one. Difference is in the noise range tho

     — [SmilingWolf](https://github.com/SmilingWolf) from [this thread](https://discord.com/channels/930499730843250783/1052283314997837955) in the [東方Project AI server](https://discord.com/invite/touhouai)

   - #### *DeepDanbooru*
     1. Various model files can be found below:
        - [DeepDanbooru models](https://github.com/KichangKim/DeepDanbooru/releases)
        - [e621 model by 🐾Zack🐾#1984](https://discord.gg/BDFpq9Yb7K)
          *(link contains NSFW content!)*

     1. Move the project folder containing the model and config to `models/deepdanbooru`.

     1. The file structure should look like:
        ```
        models/
        └╴deepdanbooru/
          ├╴deepdanbooru-v3-20211112-sgd-e28/
          │ ├╴project.json
          │ └╴...
          │
          ├╴deepdanbooru-v4-20200814-sgd-e30/
          │ ├╴project.json
          │ └╴...
          │
          ├╴e621-v3-20221117-sgd-e32/
          │ ├╴project.json
          │ └╴...
          │
          ...
        ```

1. Start or restart the WebUI.
   - or press the refresh button to the right of the *Interrogator* dropdown box.

## Model comparison

* Image used: [hecattaart's artwork](https://vk.com/hecattaart?w=wall-89063929_3767)
* Threshold: `0.5`

### DeepDanbooru
Uses the same image as the one shown in the Screenshot section.

#### [`deepdanbooru-v3-20211112-sgd-e28`](https://github.com/KichangKim/DeepDanbooru/releases/tag/v3-20211112-sgd-e28)
```
1girl, animal ears, cat ears, cat tail, clothes writing, full body, rating:safe, shiba inu, shirt, shoes, simple background, sneakers, socks, solo, standing, t-shirt, tail, white background, white shirt
```

#### [`deepdanbooru-v4-20200814-sgd-e30`](https://github.com/KichangKim/DeepDanbooru/releases/tag/v4-20200814-sgd-e30)
```
1girl, animal, animal ears, bottomless, clothes writing, full body, rating:safe, shirt, shoes, short sleeves, sneakers, solo, standing, t-shirt, tail, white background, white shirt
```

#### `e621-v3-20221117-sgd-e32`
```
anthro, bottomwear, clothing, footwear, fur, hi res, mammal, shirt, shoes, shorts, simple background, sneakers, socks, solo, standing, text on clothing, text on topwear, topwear, white background
```

### Waifu Diffusion Tagger

#### [`wd14-vit`](https://huggingface.co/SmilingWolf/wd-v1-4-vit-tagger)
```
1boy, animal ears, dog, furry, leg hair, male focus, shirt, shoes, simple background, socks, solo, tail, white background
```

#### [`wd14-convnext`](https://huggingface.co/SmilingWolf/wd-v1-4-convnext-tagger)
```
full body, furry, shirt, shoes, simple background, socks, solo, tail, white background
```

#### [`wd14-vit-v2`](https://huggingface.co/SmilingWolf/wd-v1-4-vit-tagger-v2)
```
1boy, animal ears, cat, furry, male focus, shirt, shoes, simple background, socks, solo, tail, white background
```

#### [`wd14-convnext-v2`](https://huggingface.co/SmilingWolf/wd-v1-4-convnext-tagger-v2)
```
animal focus, clothes writing, earrings, full body, meme, shirt, shoes, simple background, socks, solo, sweat, tail, white background, white shirt
```

#### [`wd14-swinv2-v2`](https://huggingface.co/SmilingWolf/wd-v1-4-swinv2-tagger-v2)
```
1boy, arm hair, black footwear, cat, dirty, full body, furry, leg hair, male focus, shirt, shoes, simple background, socks, solo, standing, tail, white background, white shirt
```

## Screenshot

Artwork made by [hecattaart](https://vk.com/hecattaart?w=wall-89063929_3767)

## Copyright

Public domain, except borrowed parts (e.g. `dbimutils.py`)
__pycache__/preload.cpython-310.pyc
ADDED
Binary file (733 Bytes).

docs/screenshot.png
ADDED
Binary file.
javascript/tagger.js
ADDED
@@ -0,0 +1,108 @@
/**
 * waits until the element is loaded, then returns it
 * @param {string} selector
 * @param {number} timeout
 * @param {Element} $rootElement
 * @returns {Promise<HTMLElement>}
 */
function waitQuerySelector(selector, timeout = 5000, $rootElement = gradioApp()) {
    return new Promise((resolve, reject) => {
        const element = $rootElement.querySelector(selector)
        if (element) {
            return resolve(element)
        }

        let timeoutId

        const observer = new MutationObserver(() => {
            const element = $rootElement.querySelector(selector)
            if (!element) {
                return
            }

            if (timeoutId) {
                clearTimeout(timeoutId)
            }

            observer.disconnect()
            resolve(element)
        })

        timeoutId = setTimeout(() => {
            observer.disconnect()
            reject(new Error(`timeout, cannot find element by '${selector}'`))
        }, timeout)

        observer.observe($rootElement, {
            childList: true,
            subtree: true
        })
    })
}

document.addEventListener('DOMContentLoaded', () => {
    Promise.all([
        // option textareas
        // (the id spelling below matches the element id defined in ui.py)
        waitQuerySelector('#additioanl-tags'),
        waitQuerySelector('#exclude-tags'),

        // tag-confident labels
        waitQuerySelector('#rating-confidents'),
        waitQuerySelector('#tag-confidents')
    ]).then(elements => {

        const $additionalTags = elements[0].querySelector('textarea')
        const $excludeTags = elements[1].querySelector('textarea')
        const $ratingConfidents = elements[2]
        const $tagConfidents = elements[3]

        let $selectedTextarea = $additionalTags

        /**
         * @this {HTMLElement}
         * @param {MouseEvent} e
         * @listens document#click
         */
        function onClickTextarea(e) {
            $selectedTextarea = this
        }

        $additionalTags.addEventListener('click', onClickTextarea)
        $excludeTags.addEventListener('click', onClickTextarea)

        /**
         * @this {HTMLElement}
         * @param {MouseEvent} e
         * @listens document#click
         */
        function onClickLabels(e) {
            // find the clicked label item's wrapper element
            const $tag = e.target.closest('.output-label > div:not(:first-child)')
            if (!$tag) {
                return
            }

            /** @type {string} */
            const tag = $tag.querySelector('.leading-snug').textContent

            // ignore if the tag already exists in the textbox
            const escapedTag = tag.replace(/[.*+?^${}()|[\]\\]/g, '\\$&')
            const pattern = new RegExp(`(^|,)\\s{0,}${escapedTag}\\s{0,}($|,)`)
            if (pattern.test($selectedTextarea.value)) {
                return
            }

            if ($selectedTextarea.value !== '') {
                $selectedTextarea.value += ', '
            }

            $selectedTextarea.value += tag
        }

        $ratingConfidents.addEventListener('click', onClickLabels)
        $tagConfidents.addEventListener('click', onClickLabels)

    }).catch(err => {
        console.error(err)
    })
})
preload.py
ADDED
@@ -0,0 +1,27 @@
from pathlib import Path
from argparse import ArgumentParser

from modules.shared import models_path

default_ddp_path = Path(models_path, 'deepdanbooru')
default_onnx_path = Path(models_path, 'TaggerOnnx')


def preload(parser: ArgumentParser):
    # the default deepdanbooru integration uses different paths:
    # models/deepbooru and models/torch_deepdanbooru
    # https://github.com/AUTOMATIC1111/stable-diffusion-webui/commit/c81d440d876dfd2ab3560410f37442ef56fc6632

    parser.add_argument(
        '--deepdanbooru-projects-path',
        type=str,
        help='Path to directory with DeepDanbooru project(s).',
        default=default_ddp_path
    )
    parser.add_argument(
        '--onnxtagger-path',
        type=str,
        help='Path to directory with Onnx tagger model(s).',
        default=default_onnx_path
    )
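The `preload` hook above only registers two flags; the WebUI later parses them into `shared.cmd_opts`. A standalone sketch of that flow, where the parser and the `models` string are stand-ins for the WebUI's own parser and `modules.shared.models_path`:

```python
from argparse import ArgumentParser
from pathlib import Path

models_path = 'models'  # stand-in for modules.shared.models_path

parser = ArgumentParser()
parser.add_argument('--deepdanbooru-projects-path', type=str,
                    default=str(Path(models_path, 'deepdanbooru')))
parser.add_argument('--onnxtagger-path', type=str,
                    default=str(Path(models_path, 'TaggerOnnx')))

# with no flags given, the defaults apply
opts = parser.parse_args([])
print(opts.deepdanbooru_projects_path)

# an explicit flag overrides the default
opts = parser.parse_args(['--deepdanbooru-projects-path', '/data/ddb'])
print(opts.deepdanbooru_projects_path)  # /data/ddb
```

Note that argparse turns the dashed flag names into underscored attribute names, which is how the extension reads them back later.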
scripts/tagger.py
ADDED
@@ -0,0 +1,18 @@
from PIL import Image, ImageFile

from modules import script_callbacks
from tagger.api import on_app_started
from tagger.ui import on_ui_tabs


# if you do not initialize the Image object,
# Image.registered_extensions() returns only PNG
Image.init()

# PIL spits errors when loading a truncated image by default
# https://pillow.readthedocs.io/en/stable/reference/ImageFile.html#PIL.ImageFile.LOAD_TRUNCATED_IMAGES
ImageFile.LOAD_TRUNCATED_IMAGES = True


script_callbacks.on_app_started(on_app_started)
script_callbacks.on_ui_tabs(on_ui_tabs)
style.css
ADDED
@@ -0,0 +1,7 @@
#rating-confidents .output-label>div:not(:first-child) {
  cursor: pointer;
}

#tag-confidents .output-label>div:not(:first-child) {
  cursor: pointer;
}
tagger/__pycache__/api.cpython-310.pyc
ADDED
Binary file (2.99 kB).

tagger/__pycache__/api_models.cpython-310.pyc
ADDED
Binary file (1.32 kB).

tagger/__pycache__/dbimutils.cpython-310.pyc
ADDED
Binary file (1.69 kB).

tagger/__pycache__/format.cpython-310.pyc
ADDED
Binary file (1.68 kB).

tagger/__pycache__/interrogator.cpython-310.pyc
ADDED
Binary file (7.69 kB).

tagger/__pycache__/preset.cpython-310.pyc
ADDED
Binary file (3.07 kB).

tagger/__pycache__/ui.cpython-310.pyc
ADDED
Binary file (9.12 kB).

tagger/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (2.64 kB).
tagger/api.py
ADDED
@@ -0,0 +1,90 @@
from typing import Callable
from threading import Lock
from secrets import compare_digest

from modules import shared
from modules.api.api import decode_base64_to_image
from modules.call_queue import queue_lock
from fastapi import FastAPI, Depends, HTTPException
from fastapi.security import HTTPBasic, HTTPBasicCredentials

from tagger import utils
from tagger import api_models as models


class Api:
    def __init__(self, app: FastAPI, queue_lock: Lock, prefix: str = None) -> None:
        if shared.cmd_opts.api_auth:
            self.credentials = dict()
            for auth in shared.cmd_opts.api_auth.split(","):
                user, password = auth.split(":")
                self.credentials[user] = password

        self.app = app
        self.queue_lock = queue_lock
        self.prefix = prefix

        self.add_api_route(
            'interrogate',
            self.endpoint_interrogate,
            methods=['POST'],
            response_model=models.TaggerInterrogateResponse
        )

        self.add_api_route(
            'interrogators',
            self.endpoint_interrogators,
            methods=['GET'],
            response_model=models.InterrogatorsResponse
        )

    def auth(self, creds: HTTPBasicCredentials = Depends(HTTPBasic())):
        if creds.username in self.credentials:
            if compare_digest(creds.password, self.credentials[creds.username]):
                return True

        raise HTTPException(
            status_code=401,
            detail="Incorrect username or password",
            headers={
                "WWW-Authenticate": "Basic"
            })

    def add_api_route(self, path: str, endpoint: Callable, **kwargs):
        if self.prefix:
            path = f'{self.prefix}/{path}'

        if shared.cmd_opts.api_auth:
            return self.app.add_api_route(path, endpoint, dependencies=[Depends(self.auth)], **kwargs)
        return self.app.add_api_route(path, endpoint, **kwargs)

    def endpoint_interrogate(self, req: models.TaggerInterrogateRequest):
        if req.image is None:
            raise HTTPException(404, 'Image not found')

        if req.model not in utils.interrogators.keys():
            raise HTTPException(404, 'Model not found')

        image = decode_base64_to_image(req.image)
        interrogator = utils.interrogators[req.model]

        with self.queue_lock:
            ratings, tags = interrogator.interrogate(image)

        return models.TaggerInterrogateResponse(
            caption={
                **ratings,
                **interrogator.postprocess_tags(
                    tags,
                    req.threshold
                )
            })

    def endpoint_interrogators(self):
        return models.InterrogatorsResponse(
            models=list(utils.interrogators.keys())
        )


def on_app_started(_, app: FastAPI):
    Api(app, queue_lock, '/tagger/v1')
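With these routes registered under `/tagger/v1`, an external client can POST a base64-encoded image to `interrogate`. A minimal client sketch using only the standard library; the host/port and the `wd14-vit` model name are assumptions, and the WebUI must be launched with `--api` for the server side to exist:

```python
import base64
import json
from urllib import request

def build_interrogate_request(image_bytes, model='wd14-vit',
                              threshold=0.35,
                              base_url='http://127.0.0.1:7860'):
    """Build the POST request for the /tagger/v1/interrogate route."""
    payload = {
        'image': base64.b64encode(image_bytes).decode('ascii'),
        'model': model,
        'threshold': threshold,
    }
    return request.Request(
        base_url + '/tagger/v1/interrogate',
        data=json.dumps(payload).encode('utf-8'),
        headers={'Content-Type': 'application/json'},
        method='POST',
    )

req = build_interrogate_request(b'\x89PNG\r\n...')
# against a running WebUI, the response body is {"caption": {tag: confidence, ...}}:
# caption = json.load(request.urlopen(req))['caption']
print(req.get_full_url())  # http://127.0.0.1:7860/tagger/v1/interrogate
```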
tagger/api_models.py
ADDED
@@ -0,0 +1,33 @@
from typing import List, Dict

from modules.api import models as sd_models
from pydantic import BaseModel, Field


class TaggerInterrogateRequest(sd_models.InterrogateRequest):
    model: str = Field(
        title='Model',
        description='The interrogate model used.'
    )

    threshold: float = Field(
        default=0.35,
        title='Threshold',
        description='',
        ge=0,
        le=1
    )


class TaggerInterrogateResponse(BaseModel):
    caption: Dict[str, float] = Field(
        title='Caption',
        description='The generated caption for the image.'
    )


class InterrogatorsResponse(BaseModel):
    models: List[str] = Field(
        title='Models',
        description=''
    )
tagger/dbimutils.py
ADDED
@@ -0,0 +1,54 @@
# DanBooru IMage Utility functions

import cv2
import numpy as np
from PIL import Image


def smart_imread(img, flag=cv2.IMREAD_UNCHANGED):
    if img.endswith(".gif"):
        img = Image.open(img)
        img = img.convert("RGB")
        img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
    else:
        img = cv2.imread(img, flag)
    return img


def smart_24bit(img):
    if img.dtype is np.dtype(np.uint16):
        img = (img / 257).astype(np.uint8)

    if len(img.shape) == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    elif img.shape[2] == 4:
        trans_mask = img[:, :, 3] == 0
        img[trans_mask] = [255, 255, 255, 255]
        img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)
    return img


def make_square(img, target_size):
    old_size = img.shape[:2]
    desired_size = max(old_size)
    desired_size = max(desired_size, target_size)

    delta_w = desired_size - old_size[1]
    delta_h = desired_size - old_size[0]
    top, bottom = delta_h // 2, delta_h - (delta_h // 2)
    left, right = delta_w // 2, delta_w - (delta_w // 2)

    color = [255, 255, 255]
    new_im = cv2.copyMakeBorder(
        img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color
    )
    return new_im


def smart_resize(img, size):
    # Assumes the image has already gone through make_square
    if img.shape[0] > size:
        img = cv2.resize(img, (size, size), interpolation=cv2.INTER_AREA)
    elif img.shape[0] < size:
        img = cv2.resize(img, (size, size), interpolation=cv2.INTER_CUBIC)
    return img
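The border arithmetic in `make_square` is easy to sanity-check without OpenCV; a pure-Python sketch of the same size computation (the helper name is illustrative, not part of the extension):

```python
def square_padding(height, width, target_size):
    """Reproduce make_square's border computation: grow the image to a
    desired_size x desired_size square, splitting the padding evenly
    (the extra pixel of an odd delta goes to the bottom/right)."""
    desired_size = max(height, width, target_size)
    delta_w = desired_size - width
    delta_h = desired_size - height
    top, bottom = delta_h // 2, delta_h - (delta_h // 2)
    left, right = delta_w // 2, delta_w - (delta_w // 2)
    return top, bottom, left, right

# a 300x200 image padded up to 448: 74 rows top and bottom, 124 columns each side
print(square_padding(300, 200, 448))  # (74, 74, 124, 124)
```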
tagger/format.py
ADDED
@@ -0,0 +1,46 @@
import re
import hashlib

from typing import Dict, Callable, NamedTuple
from pathlib import Path


class Info(NamedTuple):
    path: Path
    output_ext: str


def hash(i: Info, algo='sha1') -> str:
    try:
        hash = hashlib.new(algo)
    except ValueError:
        raise ValueError(f"'{algo}' is an invalid hash algorithm")

    # TODO: is it okay to hash a large image?
    with open(i.path, 'rb') as file:
        hash.update(file.read())

    return hash.hexdigest()


pattern = re.compile(r'\[([\w:]+)\]')

# every function must return a string or raise TypeError or ValueError;
# any other error will crash the extension
available_formats: Dict[str, Callable] = {
    'name': lambda i: i.path.stem,
    'extension': lambda i: i.path.suffix[1:],
    'hash': hash,

    'output_extension': lambda i: i.output_ext
}


def format(match: re.Match, info: Info) -> str:
    matches = match[1].split(':')
    name, args = matches[0], matches[1:]

    if name not in available_formats:
        return match[0]

    return available_formats[name](info, *args)
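A sketch of how this placeholder table is meant to be used: every `[name]`-style token in an output-filename template gets substituted via `pattern.sub`. The `expand` helper below is illustrative (the real caller lives elsewhere in the extension), with a local copy of the table so the snippet is self-contained:

```python
import re
from pathlib import Path
from typing import NamedTuple

class Info(NamedTuple):          # mirrors tagger/format.py
    path: Path
    output_ext: str

pattern = re.compile(r'\[([\w:]+)\]')
available_formats = {
    'name': lambda i: i.path.stem,
    'extension': lambda i: i.path.suffix[1:],
    'output_extension': lambda i: i.output_ext,
}

def expand(template: str, info: Info) -> str:
    def repl(match):
        name, *args = match[1].split(':')
        if name not in available_formats:
            return match[0]      # leave unknown placeholders untouched
        return available_formats[name](info, *args)
    return pattern.sub(repl, template)

info = Info(Path('images/0123.png'), 'txt')
print(expand('[name].[output_extension]', info))  # 0123.txt
print(expand('[unknown]', info))                  # [unknown]
```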
tagger/interrogator.py
ADDED
@@ -0,0 +1,320 @@
import os
import gc
import pandas as pd
import numpy as np

from typing import Tuple, List, Dict
from io import BytesIO
from PIL import Image

from pathlib import Path
from huggingface_hub import hf_hub_download

from modules import shared
from modules.deepbooru import re_special as tag_escape_pattern

# i'm not sure if it's okay to add this file to the repository
from . import dbimutils

# select a device to process on
use_cpu = ('all' in shared.cmd_opts.use_cpu) or (
    'interrogate' in shared.cmd_opts.use_cpu)

if use_cpu:
    tf_device_name = '/cpu:0'
else:
    tf_device_name = '/gpu:0'

    if shared.cmd_opts.device_id is not None:
        try:
            tf_device_name = f'/gpu:{int(shared.cmd_opts.device_id)}'
        except ValueError:
            print('--device-id is not an integer')


class Interrogator:
    @staticmethod
    def postprocess_tags(
        tags: Dict[str, float],

        threshold=0.35,
        additional_tags: List[str] = [],
        exclude_tags: List[str] = [],
        sort_by_alphabetical_order=False,
        add_confident_as_weight=False,
        replace_underscore=False,
        replace_underscore_excludes: List[str] = [],
        escape_tag=False
    ) -> Dict[str, float]:

        tags = {
            **{t: 1.0 for t in additional_tags},
            **tags
        }

        # those lines are totally not "pythonic" but looks better to me
        tags = {
            t: c

            # sort by tag name or confidence
            for t, c in sorted(
                tags.items(),
                key=lambda i: i[0 if sort_by_alphabetical_order else 1],
                reverse=not sort_by_alphabetical_order
            )

            # filter tags
            if (
                c >= threshold
                and t not in exclude_tags
            )
        }

        new_tags = []
        for tag in list(tags):
            new_tag = tag

            if replace_underscore and tag not in replace_underscore_excludes:
                new_tag = new_tag.replace('_', ' ')

            if escape_tag:
                new_tag = tag_escape_pattern.sub(r'\\\1', new_tag)

            if add_confident_as_weight:
                new_tag = f'({new_tag}:{tags[tag]})'

            new_tags.append((new_tag, tags[tag]))
        tags = dict(new_tags)

        return tags

    def __init__(self, name: str) -> None:
        self.name = name

    def load(self):
        raise NotImplementedError()

    def unload(self) -> bool:
        unloaded = False

        if hasattr(self, 'model') and self.model is not None:
            del self.model
            unloaded = True
            print(f'Unloaded {self.name}')

        if hasattr(self, 'tags'):
            del self.tags

        return unloaded

    def interrogate(
        self,
        image: Image
    ) -> Tuple[
        Dict[str, float],  # rating confidences
        Dict[str, float]   # tag confidences
    ]:
        raise NotImplementedError()


class DeepDanbooruInterrogator(Interrogator):
    def __init__(self, name: str, project_path: os.PathLike) -> None:
        super().__init__(name)
        self.project_path = project_path

    def load(self) -> None:
        print(f'Loading {self.name} from {str(self.project_path)}')

        # the deepdanbooru package is not included in the webui anymore
        # https://github.com/AUTOMATIC1111/stable-diffusion-webui/commit/c81d440d876dfd2ab3560410f37442ef56fc663
        from launch import is_installed, run_pip
        if not is_installed('deepdanbooru'):
            package = os.environ.get(
                'DEEPDANBOORU_PACKAGE',
                'git+https://github.com/KichangKim/DeepDanbooru.git@d91a2963bf87c6a770d74894667e9ffa9f6de7ff'
            )

            run_pip(
                f'install {package} tensorflow tensorflow-io', 'deepdanbooru')

        import tensorflow as tf

        # tensorflow maps nearly all vram by default, so we limit this
        # https://www.tensorflow.org/guide/gpu#limiting_gpu_memory_growth
        # TODO: only run on the first run
        for device in tf.config.experimental.list_physical_devices('GPU'):
            tf.config.experimental.set_memory_growth(device, True)

        with tf.device(tf_device_name):
            import deepdanbooru.project as ddp

            self.model = ddp.load_model_from_project(
                project_path=self.project_path,
                compile_model=False
            )

            print(f'Loaded {self.name} model from {str(self.project_path)}')

            self.tags = ddp.load_tags_from_project(
                project_path=self.project_path
            )

    def unload(self) -> bool:
        # unloaded = super().unload()

        # if unloaded:
        #     # tensorflow suck
        #     # https://github.com/keras-team/keras/issues/2102
        #     import tensorflow as tf
        #     tf.keras.backend.clear_session()
        #     gc.collect()

        # return unloaded

        # There is a bug in Keras where it is not possible to release a model
        # that has been loaded into memory. Downgrading to keras==2.1.6 may
        # solve the issue, but it may cause compatibility issues with other
        # packages. Using subprocess to create a new process may also solve
        # the problem, but it can be too complex (like Automatic1111 did).
        # It seems that for now, the best option is to keep the model in
        # memory, as most users use the Waifu Diffusion model with onnx.

        return False

    def interrogate(
        self,
        image: Image
    ) -> Tuple[
        Dict[str, float],  # rating confidences
        Dict[str, float]   # tag confidences
    ]:
        # init model
        if not hasattr(self, 'model') or self.model is None:
            self.load()

        import deepdanbooru.data as ddd

        # convert the image to fit the model
        image_bufs = BytesIO()
        image.save(image_bufs, format='PNG')
        image = ddd.load_image_for_evaluate(
            image_bufs,
            self.model.input_shape[2],
            self.model.input_shape[1]
        )

        image = image.reshape((1, *image.shape[0:3]))

        # evaluate the model
        result = self.model.predict(image)

        confidences = result[0].tolist()
        ratings = {}
        tags = {}

        for i, tag in enumerate(self.tags):
            tags[tag] = confidences[i]

        return ratings, tags


class WaifuDiffusionInterrogator(Interrogator):
    def __init__(
        self,
        name: str,
        model_path='model.onnx',
        tags_path='selected_tags.csv',
        **kwargs
    ) -> None:
        super().__init__(name)
        self.model_path = model_path
        self.tags_path = tags_path
        self.kwargs = kwargs

    def download(self) -> Tuple[os.PathLike, os.PathLike]:
        # if model_path exists, skip the download
        print(self.model_path, self.tags_path)
        if os.path.exists(self.model_path) and os.path.exists(self.tags_path):
            return self.model_path, self.tags_path
        print(f"Loading {self.name} model file from {self.kwargs['repo_id']}")

        model_path = Path(hf_hub_download(
            **self.kwargs, filename=self.model_path))
        tags_path = Path(hf_hub_download(
**self.kwargs, filename=self.tags_path))
|
242 |
+
return model_path, tags_path
|
243 |
+
|
244 |
+
def load(self) -> None:
|
245 |
+
model_path, tags_path = self.download()
|
246 |
+
|
247 |
+
# only one of these packages should be installed at a time in any one environment
|
248 |
+
# https://onnxruntime.ai/docs/get-started/with-python.html#install-onnx-runtime
|
249 |
+
# TODO: remove old package when the environment changes?
|
250 |
+
from launch import is_installed, run_pip
|
251 |
+
if not is_installed('onnxruntime'):
|
252 |
+
package = os.environ.get(
|
253 |
+
'ONNXRUNTIME_PACKAGE',
|
254 |
+
'onnxruntime-gpu'
|
255 |
+
)
|
256 |
+
|
257 |
+
run_pip(f'install {package}', 'onnxruntime')
|
258 |
+
|
259 |
+
from onnxruntime import InferenceSession
|
260 |
+
|
261 |
+
# https://onnxruntime.ai/docs/execution-providers/
|
262 |
+
# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/commit/e4ec460122cf674bbf984df30cdb10b4370c1224#r92654958
|
263 |
+
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
|
264 |
+
if use_cpu:
|
265 |
+
providers.pop(0)
|
266 |
+
|
267 |
+
self.model = InferenceSession(str(model_path), providers=providers)
|
268 |
+
|
269 |
+
print(f'Loaded {self.name} model from {model_path}')
|
270 |
+
|
271 |
+
self.tags = pd.read_csv(tags_path)
|
272 |
+
|
273 |
+
def interrogate(
|
274 |
+
self,
|
275 |
+
image: Image
|
276 |
+
) -> Tuple[
|
277 |
+
Dict[str, float], # rating confidents
|
278 |
+
Dict[str, float] # tag confidents
|
279 |
+
]:
|
280 |
+
# init model
|
281 |
+
if not hasattr(self, 'model') or self.model is None:
|
282 |
+
self.load()
|
283 |
+
|
284 |
+
# code for converting the image and running the model is taken from the link below
|
285 |
+
# thanks, SmilingWolf!
|
286 |
+
# https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags/blob/main/app.py
|
287 |
+
|
288 |
+
# convert an image to fit the model
|
289 |
+
_, height, _, _ = self.model.get_inputs()[0].shape
|
290 |
+
|
291 |
+
# alpha to white
|
292 |
+
image = image.convert('RGBA')
|
293 |
+
new_image = Image.new('RGBA', image.size, 'WHITE')
|
294 |
+
new_image.paste(image, mask=image)
|
295 |
+
image = new_image.convert('RGB')
|
296 |
+
image = np.asarray(image)
|
297 |
+
|
298 |
+
# PIL RGB to OpenCV BGR
|
299 |
+
image = image[:, :, ::-1]
|
300 |
+
|
301 |
+
image = dbimutils.make_square(image, height)
|
302 |
+
image = dbimutils.smart_resize(image, height)
|
303 |
+
image = image.astype(np.float32)
|
304 |
+
image = np.expand_dims(image, 0)
|
305 |
+
|
306 |
+
# evaluate model
|
307 |
+
input_name = self.model.get_inputs()[0].name
|
308 |
+
label_name = self.model.get_outputs()[0].name
|
309 |
+
confidents = self.model.run([label_name], {input_name: image})[0]
|
310 |
+
|
311 |
+
tags = self.tags[:][['name']]
|
312 |
+
tags['confidents'] = confidents[0]
|
313 |
+
|
314 |
+
# first 4 items are for rating (general, sensitive, questionable, explicit)
|
315 |
+
ratings = dict(tags[:4].values)
|
316 |
+
|
317 |
+
# rest are regular tags
|
318 |
+
tags = dict(tags[4:].values)
|
319 |
+
|
320 |
+
return ratings, tags
|
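The WD14 interrogator above pairs every model output score with a tag name, then peels off the first four entries as rating categories and keeps the rest as regular tags. A minimal sketch of that split with plain Python, independent of onnxruntime and pandas (the tag names and scores below are made-up illustration data, not real model output):

```python
# Sketch: split a WD14-style score vector into rating and tag confidences.
# The first 4 outputs are rating categories; the rest are regular tags.

def split_confidences(names, scores):
    # pair each name with its score, preserving model output order
    paired = list(zip(names, scores))
    ratings = dict(paired[:4])
    tags = dict(paired[4:])
    return ratings, tags


names = ['general', 'sensitive', 'questionable', 'explicit', '1girl', 'solo']
scores = [0.9, 0.05, 0.03, 0.02, 0.8, 0.7]

ratings, tags = split_confidences(names, scores)
print(ratings)  # {'general': 0.9, 'sensitive': 0.05, 'questionable': 0.03, 'explicit': 0.02}
print(tags)     # {'1girl': 0.8, 'solo': 0.7}
```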
tagger/preset.py
ADDED
@@ -0,0 +1,108 @@
```python
import os
import json

from typing import Tuple, List, Dict
from pathlib import Path
from modules.images import sanitize_filename_part

PresetDict = Dict[str, Dict[str, any]]


class Preset:
    base_dir: Path
    default_filename: str
    default_values: PresetDict
    components: List[object]

    def __init__(
        self,
        base_dir: os.PathLike,
        default_filename='default.json'
    ) -> None:
        self.base_dir = Path(base_dir)
        self.default_filename = default_filename
        self.default_values = self.load(default_filename)[1]
        self.components = []

    def component(self, component_class: object, **kwargs) -> object:
        # find all the top components from the Gradio context and create a path
        from gradio.context import Context
        parent = Context.block
        paths = [kwargs['label']]

        while parent is not None:
            if hasattr(parent, 'label'):
                paths.insert(0, parent.label)

            parent = parent.parent

        path = '/'.join(paths)

        component = component_class(**{
            **kwargs,
            **self.default_values.get(path, {})
        })

        setattr(component, 'path', path)

        self.components.append(component)
        return component

    def load(self, filename: str) -> Tuple[str, PresetDict]:
        if not filename.endswith('.json'):
            filename += '.json'

        path = self.base_dir.joinpath(sanitize_filename_part(filename))
        configs = {}

        if path.is_file():
            configs = json.loads(path.read_text())

        return path, configs

    def save(self, filename: str, *values) -> Tuple:
        path, configs = self.load(filename)

        for index, component in enumerate(self.components):
            config = configs.get(component.path, {})
            config['value'] = values[index]

            for attr in ['visible', 'min', 'max', 'step']:
                if hasattr(component, attr):
                    config[attr] = config.get(attr, getattr(component, attr))

            configs[component.path] = config

        self.base_dir.mkdir(0o777, True, True)
        path.write_text(
            json.dumps(configs, indent=4)
        )

        return 'successfully saved the preset'

    def apply(self, filename: str) -> Tuple:
        values = self.load(filename)[1]
        outputs = []

        for component in self.components:
            config = values.get(component.path, {})

            if 'value' in config and hasattr(component, 'choices'):
                if config['value'] not in component.choices:
                    config['value'] = None

            outputs.append(component.update(**config))

        return (*outputs, 'successfully loaded the preset')

    def list(self) -> List[str]:
        presets = [
            p.name
            for p in self.base_dir.glob('*.json')
            if p.is_file()
        ]

        if len(presets) < 1:
            presets.append(self.default_filename)

        return presets
```
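`Preset.component` keys each saved setting by a "path": the chain of labels from the outermost labeled Gradio container down to the component itself, joined with `/`, and the preset file is just a JSON object mapping those paths to component configs. A small sketch of that scheme (the labels and values below are illustrative, not necessarily the extension's actual ones):

```python
# Sketch of the preset path scheme: component settings are keyed by the
# chain of container labels joined with '/', then serialized as JSON.
import json


def build_path(parent_labels, label):
    # mirrors the walk up the Gradio context done in Preset.component
    return '/'.join(parent_labels + [label])


path = build_path(['Tagger', 'Batch from directory'], 'Input directory')
print(path)  # Tagger/Batch from directory/Input directory

# a preset file then looks like a dict of such paths to configs:
preset = {path: {'value': '/path/to/images', 'visible': True}}
text = json.dumps(preset, indent=4)

# round-trips losslessly through JSON
restored = json.loads(text)
print(restored[path]['value'])  # /path/to/images
```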
tagger/ui.py
ADDED
@@ -0,0 +1,464 @@
```python
import os
import json
import gradio as gr

from pathlib import Path
from glob import glob
from PIL import Image, UnidentifiedImageError

from webui import wrap_gradio_gpu_call
from modules import ui
from modules import generation_parameters_copypaste as parameters_copypaste

from tagger import format, utils
from tagger.utils import split_str
from tagger.interrogator import Interrogator


def unload_interrogators():
    unloaded_models = 0

    for i in utils.interrogators.values():
        if i.unload():
            unloaded_models = unloaded_models + 1

    return [f'Successfully unloaded {unloaded_models} model(s)']


def on_interrogate(
    image: Image,
    batch_input_glob: str,
    batch_input_recursive: bool,
    batch_output_dir: str,
    batch_output_filename_format: str,
    batch_output_action_on_conflict: str,
    batch_output_save_json: bool,

    interrogator: str,
    threshold: float,
    additional_tags: str,
    exclude_tags: str,
    sort_by_alphabetical_order: bool,
    add_confident_as_weight: bool,
    replace_underscore: bool,
    replace_underscore_excludes: str,
    escape_tag: bool,

    unload_model_after_running: bool
):
    if interrogator not in utils.interrogators:
        return ['', None, None, f"'{interrogator}' is not a valid interrogator"]

    interrogator: Interrogator = utils.interrogators[interrogator]

    postprocess_opts = (
        threshold,
        split_str(additional_tags),
        split_str(exclude_tags),
        sort_by_alphabetical_order,
        add_confident_as_weight,
        replace_underscore,
        split_str(replace_underscore_excludes),
        escape_tag
    )

    # single process
    if image is not None:
        ratings, tags = interrogator.interrogate(image)
        processed_tags = Interrogator.postprocess_tags(
            tags,
            *postprocess_opts
        )

        if unload_model_after_running:
            interrogator.unload()

        return [
            ', '.join(processed_tags),
            ratings,
            tags,
            ''
        ]

    # batch process
    batch_input_glob = batch_input_glob.strip()
    batch_output_dir = batch_output_dir.strip()
    batch_output_filename_format = batch_output_filename_format.strip()

    if batch_input_glob != '':
        # if there is no glob pattern, insert it automatically
        if not batch_input_glob.endswith('*'):
            if not batch_input_glob.endswith('/'):
                batch_input_glob += '/'
            batch_input_glob += '*'

        # get root directory of input glob pattern
        base_dir = batch_input_glob.replace('?', '*')
        base_dir = base_dir.split('/*').pop(0)

        # check the input directory path
        if not os.path.isdir(base_dir):
            return ['', None, None, 'input path is not a directory']

        # this line is moved here for some reason:
        # PIL.Image.registered_extensions() returns only PNG if called too early
        supported_extensions = [
            e
            for e, f in Image.registered_extensions().items()
            if f in Image.OPEN
        ]

        paths = [
            Path(p)
            for p in glob(batch_input_glob, recursive=batch_input_recursive)
            if '.' + p.split('.').pop().lower() in supported_extensions
        ]

        print(f'found {len(paths)} image(s)')

        for path in paths:
            try:
                image = Image.open(path)
            except UnidentifiedImageError:
                # just in case, user has mysterious file...
                print(f'{path} is not a supported image type')
                continue

            # guess the output path
            base_dir_last = Path(base_dir).parts[-1]
            base_dir_last_idx = path.parts.index(base_dir_last)
            output_dir = Path(
                batch_output_dir) if batch_output_dir else Path(base_dir)
            output_dir = output_dir.joinpath(
                *path.parts[base_dir_last_idx + 1:]).parent

            output_dir.mkdir(0o777, True, True)

            # format output filename
            format_info = format.Info(path, 'txt')

            try:
                formatted_output_filename = format.pattern.sub(
                    lambda m: format.format(m, format_info),
                    batch_output_filename_format
                )
            except (TypeError, ValueError) as error:
                return ['', None, None, str(error)]

            output_path = output_dir.joinpath(
                formatted_output_filename
            )

            output = []

            if output_path.is_file():
                output.append(output_path.read_text())

                if batch_output_action_on_conflict == 'ignore':
                    print(f'skipping {path}')
                    continue

            ratings, tags = interrogator.interrogate(image)
            processed_tags = Interrogator.postprocess_tags(
                tags,
                *postprocess_opts
            )

            # TODO: switch for less print
            print(
                f'found {len(processed_tags)} tags out of {len(tags)} from {path}'
            )

            plain_tags = ', '.join(processed_tags)

            if batch_output_action_on_conflict == 'copy':
                output = [plain_tags]
            elif batch_output_action_on_conflict == 'prepend':
                output.insert(0, plain_tags)
            else:
                output.append(plain_tags)

            output_path.write_text(' '.join(output))

            if batch_output_save_json:
                output_path.with_suffix('.json').write_text(
                    json.dumps([ratings, tags])
                )

        print('all done :)')

    if unload_model_after_running:
        interrogator.unload()

    return ['', None, None, '']


def on_ui_tabs():
    with gr.Blocks(analytics_enabled=False) as tagger_interface:
        with gr.Row().style(equal_height=False):
            with gr.Column(variant='panel'):

                # input components
                with gr.Tabs():
                    with gr.TabItem(label='Single process'):
                        image = gr.Image(
                            label='Source',
                            source='upload',
                            interactive=True,
                            type="pil"
                        )

                    with gr.TabItem(label='Batch from directory'):
                        batch_input_glob = utils.preset.component(
                            gr.Textbox,
                            label='Input directory',
                            placeholder='/path/to/images or /path/to/images/**/*'
                        )
                        batch_input_recursive = utils.preset.component(
                            gr.Checkbox,
                            label='Use recursive with glob pattern'
                        )

                        batch_output_dir = utils.preset.component(
                            gr.Textbox,
                            label='Output directory',
                            placeholder='Leave blank to save images to the same path.'
                        )

                        batch_output_filename_format = utils.preset.component(
                            gr.Textbox,
                            label='Output filename format',
                            placeholder='Leave blank to use same filename as original.',
                            value='[name].[output_extension]'
                        )

                        import hashlib
                        with gr.Accordion(
                            label='Output filename formats',
                            open=False
                        ):
                            gr.Markdown(
                                value=f'''
                                ### Related to original file
                                - `[name]`: Original filename without extension
                                - `[extension]`: Original extension
                                - `[hash:<algorithms>]`: Hash of the original file
                                Available algorithms: `{', '.join(hashlib.algorithms_available)}`

                                ### Related to output file
                                - `[output_extension]`: Output extension (has no dot)

                                ## Examples
                                ### Original filename without extension
                                `[name].[output_extension]`

                                ### Original file's hash (good for deleting duplication)
                                `[hash:sha1].[output_extension]`
                                '''
                            )

                        batch_output_action_on_conflict = utils.preset.component(
                            gr.Dropdown,
                            label='Action on existing caption',
                            value='ignore',
                            choices=[
                                'ignore',
                                'copy',
                                'append',
                                'prepend'
                            ]
                        )

                        batch_output_save_json = utils.preset.component(
                            gr.Checkbox,
                            label='Save with JSON'
                        )

                submit = gr.Button(
                    value='Interrogate',
                    variant='primary'
                )

                info = gr.HTML()

                # preset selector
                with gr.Row(variant='compact'):
                    available_presets = utils.preset.list()
                    selected_preset = gr.Dropdown(
                        label='Preset',
                        choices=available_presets,
                        value=available_presets[0]
                    )

                    save_preset_button = gr.Button(
                        value=ui.save_style_symbol
                    )

                    ui.create_refresh_button(
                        selected_preset,
                        lambda: None,
                        lambda: {'choices': utils.preset.list()},
                        'refresh_preset'
                    )

                # option components

                # interrogator selector
                with gr.Column():
                    with gr.Row(variant='compact'):
                        interrogator_names = utils.refresh_interrogators()
                        interrogator = utils.preset.component(
                            gr.Dropdown,
                            label='Interrogator',
                            choices=interrogator_names,
                            value=(
                                None
                                if len(interrogator_names) < 1 else
                                interrogator_names[-1]
                            )
                        )

                        ui.create_refresh_button(
                            interrogator,
                            lambda: None,
                            lambda: {'choices': utils.refresh_interrogators()},
                            'refresh_interrogator'
                        )

                unload_all_models = gr.Button(
                    value='Unload all interrogate models'
                )

                threshold = utils.preset.component(
                    gr.Slider,
                    label='Threshold',
                    minimum=0,
                    maximum=1,
                    value=0.35
                )

                additional_tags = utils.preset.component(
                    gr.Textbox,
                    label='Additional tags (split by comma)',
                    elem_id='additioanl-tags'
                )

                exclude_tags = utils.preset.component(
                    gr.Textbox,
                    label='Exclude tags (split by comma)',
                    elem_id='exclude-tags'
                )

                sort_by_alphabetical_order = utils.preset.component(
                    gr.Checkbox,
                    label='Sort by alphabetical order',
                )
                add_confident_as_weight = utils.preset.component(
                    gr.Checkbox,
                    label='Include confident of tags matches in results'
                )
                replace_underscore = utils.preset.component(
                    gr.Checkbox,
                    label='Use spaces instead of underscore',
                    value=True
                )
                replace_underscore_excludes = utils.preset.component(
                    gr.Textbox,
                    label='Excludes (split by comma)',
                    # kaomoji from WD 1.4 tagger csv. thanks, Meow-San#5400!
                    value='0_0, (o)_(o), +_+, +_-, ._., <o>_<o>, <|>_<|>, =_=, >_<, 3_3, 6_9, >_o, @_@, ^_^, o_o, u_u, x_x, |_|, ||_||'
                )
                escape_tag = utils.preset.component(
                    gr.Checkbox,
                    label='Escape brackets',
                )

                unload_model_after_running = utils.preset.component(
                    gr.Checkbox,
                    label='Unload model after running',
                )

            # output components
            with gr.Column(variant='panel'):
                tags = gr.Textbox(
                    label='Tags',
                    placeholder='Found tags',
                    interactive=False
                )

                with gr.Row():
                    parameters_copypaste.bind_buttons(
                        parameters_copypaste.create_buttons(
                            ["txt2img", "img2img"],
                        ),
                        None,
                        tags
                    )

                rating_confidents = gr.Label(
                    label='Rating confidents',
                    elem_id='rating-confidents'
                )
                tag_confidents = gr.Label(
                    label='Tag confidents',
                    elem_id='tag-confidents'
                )

        # register events
        selected_preset.change(
            fn=utils.preset.apply,
            inputs=[selected_preset],
            outputs=[*utils.preset.components, info]
        )

        save_preset_button.click(
            fn=utils.preset.save,
            inputs=[selected_preset, *utils.preset.components],  # values only
            outputs=[info]
        )

        unload_all_models.click(
            fn=unload_interrogators,
            outputs=[info]
        )

        for func in [image.change, submit.click]:
            func(
                fn=wrap_gradio_gpu_call(on_interrogate),
                inputs=[
                    # single process
                    image,

                    # batch process
                    batch_input_glob,
                    batch_input_recursive,
                    batch_output_dir,
                    batch_output_filename_format,
                    batch_output_action_on_conflict,
                    batch_output_save_json,

                    # options
                    interrogator,
                    threshold,
                    additional_tags,
                    exclude_tags,
                    sort_by_alphabetical_order,
                    add_confident_as_weight,
                    replace_underscore,
                    replace_underscore_excludes,
                    escape_tag,

                    unload_model_after_running
                ],
                outputs=[
                    tags,
                    rating_confidents,
                    tag_confidents,

                    # contains execution time, memory usage and other stuff...
                    # generated from modules.ui.wrap_gradio_call
                    info
                ]
            )

    return [(tagger_interface, "Tagger", "tagger")]
```
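The batch path in `on_interrogate` normalizes the user's input pattern (appending a trailing `*` when missing) and then derives the base directory as everything before the first `/*`. Those two pure-string steps can be sketched in isolation (the paths below are illustrative):

```python
# Sketch of the glob normalization and base-directory derivation used by
# the batch processing path above.

def normalize_glob(pattern: str) -> str:
    # if there is no glob pattern, insert one automatically
    if not pattern.endswith('*'):
        if not pattern.endswith('/'):
            pattern += '/'
        pattern += '*'
    return pattern


def base_dir_of(pattern: str) -> str:
    # treat '?' wildcards like '*', then keep everything before the first '/*'
    return pattern.replace('?', '*').split('/*')[0]


g = normalize_glob('/data/images')
print(g)               # /data/images/*
print(base_dir_of(g))  # /data/images

# recursive patterns reduce to the same base directory
print(base_dir_of('/data/images/**/*'))  # /data/images
```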
tagger/utils.py
ADDED
@@ -0,0 +1,103 @@
```python
import os

from typing import List, Dict
from pathlib import Path

from modules import shared, scripts
from preload import default_ddp_path, default_onnx_path
from tagger.preset import Preset
from tagger.interrogator import Interrogator, DeepDanbooruInterrogator, WaifuDiffusionInterrogator

preset = Preset(Path(scripts.basedir(), 'presets'))

interrogators: Dict[str, Interrogator] = {}


def refresh_interrogators() -> List[str]:
    global interrogators
    interrogators = {
        'wd14-vit-v2': WaifuDiffusionInterrogator(
            'wd14-vit-v2',
            repo_id='SmilingWolf/wd-v1-4-vit-tagger-v2',
            revision='v2.0'
        ),
        'wd14-convnext-v2': WaifuDiffusionInterrogator(
            'wd14-convnext-v2',
            repo_id='SmilingWolf/wd-v1-4-convnext-tagger-v2',
            revision='v2.0'
        ),
        'wd14-swinv2-v2': WaifuDiffusionInterrogator(
            'wd14-swinv2-v2',
            repo_id='SmilingWolf/wd-v1-4-swinv2-tagger-v2',
            revision='v2.0'
        ),
        'wd14-vit-v2-git': WaifuDiffusionInterrogator(
            'wd14-vit-v2-git',
            repo_id='SmilingWolf/wd-v1-4-vit-tagger-v2'
        ),
        'wd14-convnext-v2-git': WaifuDiffusionInterrogator(
            'wd14-convnext-v2-git',
            repo_id='SmilingWolf/wd-v1-4-convnext-tagger-v2'
        ),
        'wd14-swinv2-v2-git': WaifuDiffusionInterrogator(
            'wd14-swinv2-v2-git',
            repo_id='SmilingWolf/wd-v1-4-swinv2-tagger-v2'
        ),
        'wd14-vit': WaifuDiffusionInterrogator(
            'wd14-vit',
            repo_id='SmilingWolf/wd-v1-4-vit-tagger'),
        'wd14-convnext': WaifuDiffusionInterrogator(
            'wd14-convnext',
            repo_id='SmilingWolf/wd-v1-4-convnext-tagger'
        ),
        # 'Z3D-E621-convnext': WaifuDiffusionInterrogator(
        #     'Z3D-E621-convnext',
        #     model_path=r'SmilingWolf/wd-v1-4-convnext-tagger',
        #     tags_path=r''
        # ),
    }

    # load deepdanbooru project
    os.makedirs(
        getattr(shared.cmd_opts, 'deepdanbooru_projects_path', default_ddp_path),
        exist_ok=True
    )
    os.makedirs(
        getattr(shared.cmd_opts, 'onnxtagger_path', default_onnx_path),
        exist_ok=True
    )

    for path in os.scandir(shared.cmd_opts.deepdanbooru_projects_path):
        if not path.is_dir():
            continue

        if not Path(path, 'project.json').is_file():
            continue

        interrogators[path.name] = DeepDanbooruInterrogator(path.name, path)

    # scan for onnx models as well
    for path in os.scandir(shared.cmd_opts.onnxtagger_path):
        if not path.is_dir():
            continue

        # if no file with the extension .onnx is found, skip; if there is
        # more than one file with that name, warn; else use it as model_path
        onnx_files = [x for x in os.scandir(path) if x.name.endswith('.onnx')]
        if len(onnx_files) == 0:
            print(f"Warning: {path} has no model, skipping")
            continue
        elif len(onnx_files) > 1:
            print(f"Warning: {path} has multiple models, skipping")
            continue
        model_path = Path(path, onnx_files[0].name)

        if not Path(path, 'tags-selected.csv').is_file():
            print(f"Warning: {path} has a model but no tags-selected.csv file, skipping")
            continue

        interrogators[path.name] = WaifuDiffusionInterrogator(
            path.name, model_path=model_path,
            tags_path=Path(path, 'tags-selected.csv'))

    return sorted(interrogators.keys())


def split_str(s: str, separator=',') -> List[str]:
    return [x.strip() for x in s.split(separator) if x]
```
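`split_str` above trims whitespace and drops empty segments; note that the `if x` filter runs on the segment before stripping, so a whitespace-only segment survives as an empty string. A quick check of that behavior:

```python
# Behavior check for split_str as defined above.

def split_str(s: str, separator=',') -> list:
    return [x.strip() for x in s.split(separator) if x]


print(split_str('1girl, solo,, highres'))  # ['1girl', 'solo', 'highres']

# the filter happens before stripping, so a whitespace-only segment
# passes the truthiness test and is kept as an empty string:
print(split_str('a, ,b'))  # ['a', '', 'b']
```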