Files changed (6) hide show
  1. .gitignore +1 -0
  2. .pre-commit-config.yaml +35 -0
  3. .style.yapf +5 -0
  4. README.md +1 -0
  5. app.py +122 -0
  6. requirements.txt +3 -0
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ images
.pre-commit-config.yaml ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ repos:
2
+ - repo: https://github.com/pre-commit/pre-commit-hooks
3
+ rev: v4.2.0
4
+ hooks:
5
+ - id: check-executables-have-shebangs
6
+ - id: check-json
7
+ - id: check-merge-conflict
8
+ - id: check-shebang-scripts-are-executable
9
+ - id: check-toml
10
+ - id: check-yaml
11
+ - id: double-quote-string-fixer
12
+ - id: end-of-file-fixer
13
+ - id: mixed-line-ending
14
+ args: ['--fix=lf']
15
+ - id: requirements-txt-fixer
16
+ - id: trailing-whitespace
17
+ - repo: https://github.com/myint/docformatter
18
+ rev: v1.4
19
+ hooks:
20
+ - id: docformatter
21
+ args: ['--in-place']
22
+ - repo: https://github.com/pycqa/isort
23
+ rev: 5.10.1
24
+ hooks:
25
+ - id: isort
26
+ - repo: https://github.com/pre-commit/mirrors-mypy
27
+ rev: v0.812
28
+ hooks:
29
+ - id: mypy
30
+ args: ['--ignore-missing-imports']
31
+ - repo: https://github.com/google/yapf
32
+ rev: v0.32.0
33
+ hooks:
34
+ - id: yapf
35
+ args: ['--parallel', '--in-place']
.style.yapf ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ [style]
2
+ based_on_style = pep8
3
+ blank_line_before_nested_class_or_def = false
4
+ spaces_before_comment = 2
5
+ split_before_logical_operator = true
README.md CHANGED
@@ -4,6 +4,7 @@ emoji: 🏃
4
  colorFrom: gray
5
  colorTo: purple
6
  sdk: gradio
 
7
  app_file: app.py
8
  pinned: false
9
  ---
 
4
  colorFrom: gray
5
  colorTo: purple
6
  sdk: gradio
7
+ sdk_version: 3.6
8
  app_file: app.py
9
  pinned: false
10
  ---
app.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import functools
7
+ import os
8
+ import pathlib
9
+ import tarfile
10
+
11
+ import deepdanbooru as dd
12
+ import gradio as gr
13
+ import huggingface_hub
14
+ import numpy as np
15
+ import PIL.Image
16
+ import tensorflow as tf
17
+
18
+ TITLE = 'KichangKim/DeepDanbooru'
19
+ DESCRIPTION = 'This is an unofficial demo for https://github.com/KichangKim/DeepDanbooru.'
20
+ ARTICLE = '<center><img src="https://visitor-badge.glitch.me/badge?page_id=hysts.deepdanbooru" alt="visitor badge"/></center>'
21
+
22
+ HF_TOKEN = os.environ['HF_TOKEN']
23
+ MODEL_REPO = 'hysts/DeepDanbooru'
24
+ MODEL_FILENAME = 'model-resnet_custom_v3.h5'
25
+ LABEL_FILENAME = 'tags.txt'
26
+
27
+
28
+ def parse_args() -> argparse.Namespace:
29
+ parser = argparse.ArgumentParser()
30
+ parser.add_argument('--score-slider-step', type=float, default=0.05)
31
+ parser.add_argument('--score-threshold', type=float, default=0.5)
32
+ parser.add_argument('--share', action='store_true')
33
+ return parser.parse_args()
34
+
35
+
36
+ def load_sample_image_paths() -> list[pathlib.Path]:
37
+ image_dir = pathlib.Path('images')
38
+ if not image_dir.exists():
39
+ dataset_repo = 'hysts/sample-images-TADNE'
40
+ path = huggingface_hub.hf_hub_download(dataset_repo,
41
+ 'images.tar.gz',
42
+ repo_type='dataset',
43
+ use_auth_token=HF_TOKEN)
44
+ with tarfile.open(path) as f:
45
+ f.extractall()
46
+ return sorted(image_dir.glob('*'))
47
+
48
+
49
+ def load_model() -> tf.keras.Model:
50
+ path = huggingface_hub.hf_hub_download(MODEL_REPO,
51
+ MODEL_FILENAME,
52
+ use_auth_token=HF_TOKEN)
53
+ model = tf.keras.models.load_model(path)
54
+ return model
55
+
56
+
57
+ def load_labels() -> list[str]:
58
+ path = huggingface_hub.hf_hub_download(MODEL_REPO,
59
+ LABEL_FILENAME,
60
+ use_auth_token=HF_TOKEN)
61
+ with open(path) as f:
62
+ labels = [line.strip() for line in f.readlines()]
63
+ return labels
64
+
65
+
66
+ def predict(image: PIL.Image.Image, score_threshold: float,
67
+ model: tf.keras.Model, labels: list[str]) -> dict[str, float]:
68
+ _, height, width, _ = model.input_shape
69
+ image = np.asarray(image)
70
+ image = tf.image.resize(image,
71
+ size=(height, width),
72
+ method=tf.image.ResizeMethod.AREA,
73
+ preserve_aspect_ratio=True)
74
+ image = image.numpy()
75
+ image = dd.image.transform_and_pad_image(image, width, height)
76
+ image = image / 255.
77
+ probs = model.predict(image[None, ...])[0]
78
+ probs = probs.astype(float)
79
+ res = dict()
80
+ for prob, label in zip(probs.tolist(), labels):
81
+ if prob < score_threshold:
82
+ continue
83
+ res[label] = prob
84
+ return res
85
+
86
+
87
+ def main():
88
+ args = parse_args()
89
+
90
+ image_paths = load_sample_image_paths()
91
+ examples = [[path.as_posix(), args.score_threshold]
92
+ for path in image_paths]
93
+
94
+ model = load_model()
95
+ labels = load_labels()
96
+
97
+ func = functools.partial(predict, model=model, labels=labels)
98
+
99
+ gr.Interface(
100
+ func,
101
+ [
102
+ gr.Image(type='pil', label='Input'),
103
+ gr.Slider(0,
104
+ 1,
105
+ step=args.score_slider_step,
106
+ value=args.score_threshold,
107
+ label='Score Threshold'),
108
+ ],
109
+ gr.Label(label='Output'),
110
+ examples=examples,
111
+ title=TITLE,
112
+ description=DESCRIPTION,
113
+ article=ARTICLE,
114
+ allow_flagging='never',
115
+ ).launch(
116
+ enable_queue=True,
117
+ share=args.share,
118
+ )
119
+
120
+
121
+ if __name__ == '__main__':
122
+ main()
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ pillow>=9.0.0
2
+ tensorflow>=2.7.0
3
+ git+https://github.com/KichangKim/DeepDanbooru@v3-20200915-sgd-e30#egg=deepdanbooru