tori29umai commited on
Commit
c9dfb9e
1 Parent(s): 6477f29
Files changed (2) hide show
  1. app.py +166 -0
  2. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import csv
3
+ import os
4
+ from pathlib import Path
5
+
6
+ import cv2
7
+ import numpy as np
8
+ from PIL import Image
9
+ import onnxruntime as ort
10
+ from huggingface_hub import hf_hub_download
11
+ import spaces
12
+
13
+ # 画像のサイズ設定
14
+ IMAGE_SIZE = 448
15
+
16
+ def preprocess_image(image):
17
+ image = np.array(image)
18
+ image = image[:, :, ::-1] # BGRからRGBへ変換
19
+
20
+ # 画像を正方形にするためのパディングを追加
21
+ size = max(image.shape[0:2])
22
+ pad_x = size - image.shape[1]
23
+ pad_y = size - image.shape[0]
24
+ pad_l = pad_x // 2
25
+ pad_t = pad_y // 2
26
+ image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode="constant", constant_values=255)
27
+
28
+ # サイズに合わせた補間方法を選択
29
+ interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4
30
+ image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp)
31
+ image = image.astype(np.float32)
32
+ return image
33
+
34
+ @spaces.GPU
35
+ def process_image(image_path, input_name, ort_sess, rating_tags, character_tags, general_tags):
36
+ thresh = 0.35
37
+ try:
38
+ image = Image.open(image_path)
39
+ image = image.convert("RGB") if image.mode != "RGB" else image
40
+ image = preprocess_image(image)
41
+ except Exception as e:
42
+ print(f"画像を読み込めません: {image_path}, エラー: {e}")
43
+ return
44
+
45
+ img = np.array([image])
46
+ prob = ort_sess.run(None, {input_name: img})[0][0] # ONNXモデルからの出力
47
+
48
+ # NSFW/SFW判定
49
+ tag_confidences = {tag: prob[i] for i, tag in enumerate(rating_tags)}
50
+ max_nsfw_score = max(tag_confidences.get("questionable", 0), tag_confidences.get("explicit", 0))
51
+ max_sfw_score = tag_confidences.get("general", 0)
52
+ NSFW_flag = None
53
+
54
+ if max_nsfw_score > max_sfw_score:
55
+ NSFW_flag = "NSFWの可能性が高いです"
56
+ else:
57
+ NSFW_flag = "SFWの可能性が高いです"
58
+
59
+ # 版権キャラクターの可能性を評価
60
+ character_tags_with_probs = []
61
+ for i, p in enumerate(prob[4:]):
62
+ if p >= thresh and i >= len(general_tags):
63
+ tag_index = i - len(general_tags)
64
+ if tag_index < len(character_tags):
65
+ tag_name = character_tags[tag_index]
66
+ prob_percent = round(p * 100, 2) # 確率をパーセンテージに変換
67
+ character_tags_with_probs.append((tag_name, f"{prob_percent}%"))
68
+
69
+ IP_flag = None
70
+ if character_tags_with_probs:
71
+ IP_flag = f"版権キャラクター: {character_tags_with_probs}の可能性があります"
72
+ else:
73
+ IP_flag = "版権キャラクターの可能性が低いと思われます"
74
+
75
+ # タグを生成
76
+ tag_freq = {}
77
+ undesired_tags = []
78
+ combined_tags = []
79
+ general_tag_text = ""
80
+ character_tag_text = ""
81
+ remove_underscore = True
82
+ caption_separator = ", "
83
+ general_threshold = 0.35
84
+ character_threshold = 0.35
85
+
86
+ for i, p in enumerate(prob[4:]):
87
+ if i < len(general_tags) and p >= general_threshold:
88
+ tag_name = general_tags[i]
89
+ if remove_underscore and len(tag_name) > 3: # ignore emoji tags like >_< and ^_^
90
+ tag_name = tag_name.replace("_", " ")
91
+
92
+ if tag_name not in undesired_tags:
93
+ tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
94
+ general_tag_text += caption_separator + tag_name
95
+ combined_tags.append(tag_name)
96
+ elif i >= len(general_tags) and p >= character_threshold:
97
+ tag_name = character_tags[i - len(general_tags)]
98
+ if remove_underscore and len(tag_name) > 3:
99
+ tag_name = tag_name.replace("_", " ")
100
+
101
+ if tag_name not in undesired_tags:
102
+ tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
103
+ character_tag_text += caption_separator + tag_name
104
+ combined_tags.append(tag_name)
105
+
106
+ # 先頭のカンマを取る
107
+ if len(general_tag_text) > 0:
108
+ general_tag_text = general_tag_text[len(caption_separator) :]
109
+ if len(character_tag_text) > 0:
110
+ character_tag_text = character_tag_text[len(caption_separator) :]
111
+ tag_text = caption_separator.join(combined_tags)
112
+
113
+ return NSFW_flag, IP_flag, tag_text
114
+
115
+
116
+ class webui:
117
+ def __init__(self):
118
+ self.demo = gr.Blocks()
119
+
120
+ @spaces.GPU
121
+ def main(self, image_path, model_id):
122
+ print("Hugging Faceからモデルをダウンロード中")
123
+ onnx_path = hf_hub_download(model_id, "model.onnx")
124
+ csv_path = hf_hub_download(model_id, "selected_tags.csv")
125
+
126
+ print("ONNXモデルを実行中")
127
+ print(f"ONNXモデルのパス: {onnx_path}")
128
+
129
+ ort_sess = ort.InferenceSession(onnx_path)
130
+
131
+ with open(csv_path, "r", encoding="utf-8") as f:
132
+ reader = csv.reader(f)
133
+ header = next(reader)
134
+ rows = list(reader)
135
+ assert header == ["tag_id", "name", "category", "count"], f"CSVフォーマット���期待と異なります: {header}"
136
+
137
+ rating_tags = [row[1] for row in rows if row[2] == "9"]
138
+ character_tags = [row[1] for row in rows if row[2] == "4"]
139
+ general_tags = [row[1] for row in rows[1:] if row[2] == "0"]
140
+
141
+ NSFW_flag, IP_flag, tag_text = process_image(image_path, ort_sess.get_inputs()[0].name, ort_sess, rating_tags, character_tags, general_tags)
142
+ return NSFW_flag, IP_flag, tag_text
143
+
144
+
145
+ def launch(self):
146
+ with self.demo:
147
+ with gr.Row():
148
+ with gr.Column():
149
+ input_image = gr.Image(type='filepath', label="Analysis Image")
150
+ model_id = gr.Textbox(label="NSFW Flag", value="SmilingWolf/wd-vit-tagger-v3")
151
+ output_0 = gr.Textbox(label="NSFW Flag")
152
+ output_1 = gr.Textbox(label="IP Flag")
153
+ output_2 = gr.Textbox(label="Tags")
154
+ submit = gr.Button(value="Start Analysis")
155
+
156
+ submit.click(
157
+ self.main,
158
+ inputs=[input_image, model_id],
159
+ outputs=[output_0, output_1, output_2]
160
+ )
161
+
162
+ self.demo.launch()
163
+
164
+ if __name__ == "__main__":
165
+ ui = webui()
166
+ ui.launch()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ opencv-python
2
+ numpy
3
+ Pillow
4
+ onnxruntime
5
+ onnxruntime-gpu