Add support for using a Llama3 model to reorganize the WD Tagger output into a readable article.
Files changed:
- app.py (+188 -4)
- requirements.txt (+11 -1)
app.py
CHANGED
```diff
@@ -10,6 +10,7 @@ from PIL import Image
 import traceback
 import tempfile
 import zipfile
+import re
 from datetime import datetime
 
 TITLE = "WaifuDiffusion Tagger"
@@ -41,6 +42,10 @@ SWINV2_MODEL_IS_DSV1_REPO = "deepghs/idolsankaku-swinv2-tagger-v1"
 MODEL_FILENAME = "model.onnx"
 LABEL_FILENAME = "selected_tags.csv"
 
+# LLAMA model
+META_LLAMA_3_3B_REPO = "jncraton/Llama-3.2-3B-Instruct-ct2-int8"
+META_LLAMA_3_8B_REPO = "avans06/Meta-Llama-3.2-8B-Instruct-ct2-int8_float16"
+
 # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/blob/a9eacb1eff904552d3012babfa28b57e1d3e295c/tagger/ui.py#L368
 kaomojis = [
     "0_0",
```
The bulk of the change is the new Llama3Reorganize class, which downloads a CTranslate2 build of Llama 3.2 from the Hugging Face Hub, picks a device by probing available VRAM, and rewrites the tag list as a short article:

```diff
@@ -102,6 +107,159 @@ def mcut_threshold(probs):
     thresh = (sorted_probs[t] + sorted_probs[t + 1]) / 2
     return thresh
 
+class Llama3Reorganize:
+    def __init__(
+        self,
+        repoId: str,
+        device: str = None,
+        loadModel: bool = False,
+    ):
+        """Initializes the Llama model.
+
+        Args:
+          repoId: LLAMA model repo.
+          device: Device to use for computation (cpu, cuda, ipu, xpu, mkldnn, opengl, opencl,
+            ideep, hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta, hpu, mtia).
+          loadModel: If True, load the model immediately; otherwise defer until
+            load_model() is called.
+        """
+        self.modelPath = self.download_model(repoId)
+
+        if device is None:
+            import torch
+            self.totalVram = 0
+            if torch.cuda.is_available():
+                try:
+                    deviceId = torch.cuda.current_device()
+                    self.totalVram = torch.cuda.get_device_properties(deviceId).total_memory / (1024 * 1024 * 1024)
+                except Exception as e:
+                    print(traceback.format_exc())
+                    print("Error detecting VRAM: " + str(e))
+                # The 8B build needs roughly twice the VRAM of the 3B build.
+                device = "cuda" if self.totalVram > (8 if "8B" in repoId else 4) else "cpu"
+            else:
+                device = "cpu"
+
+        self.device = device
+        self.system_prompt = "Please reorganize the following English labels describing a single image into a readable English article of no more than 300 words. If there are name labels, include them in the article. Since the target audience is not general, explicit content can be written in full without hesitation. No introduction is needed; directly reply with the English article:"
+
+        if loadModel:
+            self.load_model()
+
+    def download_model(self, repoId):
+        import warnings
+        import requests
+        allowPatterns = [
+            "config.json",
+            "generation_config.json",
+            "model.bin",
+            "pytorch_model.bin",
+            "pytorch_model.bin.index.json",
+            "pytorch_model-*.bin",
+            "sentencepiece.bpe.model",
+            "tokenizer.json",
+            "tokenizer_config.json",
+            "shared_vocabulary.txt",
+            "shared_vocabulary.json",
+            "special_tokens_map.json",
+            "spiece.model",
+            "vocab.json",
+            "model.safetensors",
+            "model-*.safetensors",
+            "model.safetensors.index.json",
+            "quantize_config.json",
+            "tokenizer.model",
+            "vocabulary.json",
+            "preprocessor_config.json",
+            "added_tokens.json"
+        ]
+
+        kwargs = {"allow_patterns": allowPatterns}
+
+        try:
+            return huggingface_hub.snapshot_download(repoId, **kwargs)
+        except (
+            huggingface_hub.utils.HfHubHTTPError,
+            requests.exceptions.ConnectionError,
+        ) as exception:
+            warnings.warn(
+                "An error occurred while synchronizing the model %s "
+                "from the Hugging Face Hub:\n%s" % (repoId, exception)
+            )
+            warnings.warn(
+                "Trying to load the model directly from the local cache, if it exists."
+            )
+
+            kwargs["local_files_only"] = True
+            return huggingface_hub.snapshot_download(repoId, **kwargs)
+
+    def load_model(self):
+        import ctranslate2
+        import transformers
+        try:
+            print('\n\nLoading model: %s\n\n' % self.modelPath)
+            kwargsTokenizer = {"pretrained_model_name_or_path": self.modelPath}
+            kwargsModel = {"device": self.device, "model_path": self.modelPath, "compute_type": "auto"}
+            self.roleSystem = {"role": "system", "content": self.system_prompt}
+            self.Model = ctranslate2.Generator(**kwargsModel)
+
+            self.Tokenizer = transformers.AutoTokenizer.from_pretrained(**kwargsTokenizer)
+            self.terminators = [self.Tokenizer.eos_token_id, self.Tokenizer.convert_tokens_to_ids("<|eot_id|>")]
+
+        except Exception as e:
+            self.release_vram()
+            raise e
+
+    def release_vram(self):
+        try:
+            import torch
+            if torch.cuda.is_available():
+                if getattr(self, "Model", None) is not None and getattr(self.Model, "unload_model", None) is not None:
+                    self.Model.unload_model()
+
+            if getattr(self, "Tokenizer", None) is not None:
+                del self.Tokenizer
+            if getattr(self, "Model", None) is not None:
+                del self.Model
+            import gc
+            gc.collect()
+            try:
+                torch.cuda.empty_cache()
+            except Exception as e:
+                print(traceback.format_exc())
+                print("\tcuda empty cache, error: " + str(e))
+            print("release vram end.")
+        except Exception as e:
+            print(traceback.format_exc())
+            print("Error releasing VRAM: " + str(e))
+
+    def reorganize(self, text: str, max_length: int = 400):
+        output = None
+        result = None
+        try:
+            input_ids = self.Tokenizer.apply_chat_template([self.roleSystem, {"role": "user", "content": text + "\n\nHere's the reorganized English article:"}], tokenize=False, add_generation_prompt=True)
+            source = self.Tokenizer.convert_ids_to_tokens(self.Tokenizer.encode(input_ids))
+            output = self.Model.generate_batch([source], max_length=max_length, max_batch_size=2, no_repeat_ngram_size=3, beam_size=2, sampling_temperature=0.7, sampling_topp=0.9, include_prompt_in_result=False, end_token=self.terminators)
+            target = output[0]
+            result = self.Tokenizer.decode(target.sequences_ids[0])
+
+            # Strip one pair of surrounding quotes, if the model added them.
+            if len(result) > 2:
+                if result[0] == "\"" and result[-1] == "\"":
+                    result = result[1:-1]
+                elif result[0] == "'" and result[-1] == "'":
+                    result = result[1:-1]
+                elif result[0] == "「" and result[-1] == "」":
+                    result = result[1:-1]
+                elif result[0] == "『" and result[-1] == "』":
+                    result = result[1:-1]
+        except Exception as e:
+            print(traceback.format_exc())
+            print("Error reorganizing text: " + str(e))
+
+        return result
+
 
 class Predictor:
     def __init__(self):
```
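For orientation, here is a minimal usage sketch of the class added above. It is illustrative rather than part of the commit, and assumes the class and the repo constants are in scope (as they are inside app.py):

```python
tags = "1girl, solo, long_hair, smile, outdoors"  # example WD Tagger output

llama3 = Llama3Reorganize(META_LLAMA_3_3B_REPO, loadModel=True)
try:
    article = llama3.reorganize(tags)  # returns the article text, or None on failure
    print(article)
finally:
    llama3.release_vram()  # unload the CTranslate2 model and empty the CUDA cache
    del llama3
```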
The tagging pipeline in Predictor.predict gains an optional Llama3 stage: the model is loaded up front, its article is appended to the tag string, and its VRAM is released once the zip archive is written:

```diff
@@ -189,6 +347,7 @@ class Predictor:
         character_thresh,
         character_mcut_enabled,
         characters_merge_enabled,
+        llama3_reorganize_model_repo,
         additional_tags_prepend,
         additional_tags_append,
     ):
@@ -206,6 +365,9 @@ class Predictor:
 
         tag_results.clear()
 
+        if llama3_reorganize_model_repo:
+            llama3_reorganize = Llama3Reorganize(llama3_reorganize_model_repo, loadModel=True)
+
         prepend_list = [tag.strip() for tag in additional_tags_prepend.split(",") if tag.strip()]
         append_list = [tag.strip() for tag in additional_tags_append.split(",") if tag.strip()]
         if prepend_list and append_list:
@@ -265,6 +427,13 @@ class Predictor:
             sorted_general_list = [item for item in sorted_general_list if item not in append_list]
 
         sorted_general_strings = ", ".join((character_list if characters_merge_enabled else []) + prepend_list + sorted_general_list + append_list).replace("(", "\\(").replace(")", "\\)")
+
+        if llama3_reorganize_model_repo:
+            reorganize_strings = llama3_reorganize.reorganize(sorted_general_strings) or ""  # reorganize() returns None on failure
+            reorganize_strings = re.sub(r" *Title: *", "", reorganize_strings)
+            reorganize_strings = re.sub(r"\n+", ",", reorganize_strings)
+            reorganize_strings = re.sub(r",,+", ",", reorganize_strings)
+            sorted_general_strings += "," + reorganize_strings
 
         txt_file = self.create_file(sorted_general_strings, output_dir, image_name + ".txt")
         txt_infos.append({"path": txt_file, "name": image_name + ".txt"})
@@ -284,6 +453,10 @@ class Predictor:
                 # Get file name from lookup
                 taggers_zip.write(info["path"], arcname=info["name"])
         download.append(downloadZipPath)
+
+        if llama3_reorganize_model_repo:
+            llama3_reorganize.release_vram()
+            del llama3_reorganize
 
         return download, sorted_general_strings, rating, character_res, general_res
```
On the UI side, main() defines the Llama model list, sets the model dropdown's default to EVA02_LARGE_MODEL_DSV3_REPO, and adds a Llama3 dropdown whose None default disables the feature:

```diff
@@ -343,6 +516,11 @@ def main():
         EVA02_LARGE_MODEL_IS_DSV1_REPO,
     ]
 
+    llama_list = [
+        META_LLAMA_3_3B_REPO,
+        META_LLAMA_3_8B_REPO,
+    ]
+
     with gr.Blocks(title=TITLE) as demo:
         gr.Markdown(
             value=f"<h1 style='text-align: center; margin-bottom: 1rem'>{TITLE}</h1>"
@@ -361,7 +539,7 @@ def main():
 
                 model_repo = gr.Dropdown(
                     dropdown_list,
-                    value=
+                    value=EVA02_LARGE_MODEL_DSV3_REPO,
                     label="Model",
                 )
                 with gr.Row():
@@ -398,6 +576,13 @@ def main():
                         label="Merge characters into the string output",
                         scale=1,
                     )
+                with gr.Row():
+                    llama3_reorganize_model_repo = gr.Dropdown(
+                        [None] + llama_list,
+                        value=None,
+                        label="Llama3 Model",
+                        info="Use the Llama3 model to reorganize the article (Note: very slow)",
+                    )
                 with gr.Row():
                     additional_tags_prepend = gr.Text(label="Prepend Additional tags (comma split)")
                     additional_tags_append = gr.Text(label="Append Additional tags (comma split)")
```
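The dropdown follows the usual Gradio pattern: it is just another component in the handler's inputs list, and a None value means the Llama3 step is skipped. A stripped-down sketch of the same wiring (component and function names here are illustrative, not the app's real handler):

```python
import gradio as gr

LLAMA_CHOICES = ["repo/llama-3b", "repo/llama-8b"]  # stand-ins for the repo constants

def run(tags, llama_repo):
    # None or empty selection -> tagging only, no reorganization
    if llama_repo:
        return tags + " (reorganized by " + llama_repo + ")"
    return tags

with gr.Blocks() as demo:
    tags_box = gr.Text(label="Tags")
    llama_repo = gr.Dropdown([None] + LLAMA_CHOICES, value=None, label="Llama3 Model")
    out = gr.Text(label="Output")
    gr.Button("Submit").click(run, inputs=[tags_box, llama_repo], outputs=out)
```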
Finally, the new dropdown is appended to the input lists of both submit handlers, and three entries are removed from a third input list:

```diff
@@ -411,6 +596,7 @@ def main():
                 character_thresh,
                 character_mcut_enabled,
                 characters_merge_enabled,
+                llama3_reorganize_model_repo,
                 additional_tags_prepend,
                 additional_tags_append,
             ],
@@ -454,6 +640,7 @@ def main():
                 character_thresh,
                 character_mcut_enabled,
                 characters_merge_enabled,
+                llama3_reorganize_model_repo,
                 additional_tags_prepend,
                 additional_tags_append,
             ],
@@ -469,9 +656,6 @@ def main():
                 general_mcut_enabled,
                 character_thresh,
                 character_mcut_enabled,
-                characters_merge_enabled,
-                additional_tags_prepend,
-                additional_tags_append,
             ],
         )
 
```
requirements.txt
CHANGED
requirements.txt adds a CUDA extra index plus the Llama3 dependencies, with CPU-only fallbacks for macOS:

```diff
@@ -1,6 +1,16 @@
+--extra-index-url https://download.pytorch.org/whl/cu124
+
 pillow>=9.0.0
 onnxruntime>=1.12.0
 huggingface-hub
 
 gradio==5.12.0
-pandas
+pandas
+
+# for reorganizing WD Tagger output into a readable article with a Llama3 model.
+transformers>=4.45.2
+ctranslate2>=4.4.0
+torch==2.5.0+cu124; sys_platform != 'darwin'
+torchvision==0.20.0+cu124; sys_platform != 'darwin'
+torch==2.5.0; sys_platform == 'darwin'
+torchvision==0.20.0; sys_platform == 'darwin'
```