avans06 committed
Commit f59521a · 1 Parent(s): eda899c

Add support for using the Llama3 model to reorganize WD Tagger into a readable article.

Files changed (2)
  1. app.py +188 -4
  2. requirements.txt +11 -1
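
In short, the commit adds a Llama3Reorganize helper class that downloads a CTranslate2-converted Llama 3.2 checkpoint from the Hub, generates a short English article from the tagger's comma-separated labels, and appends that article to the tag string written to each .txt file. A minimal usage sketch of the new API (assuming app.py guards its Gradio launch under if __name__ == "__main__", and using a made-up tag string):

    from app import Llama3Reorganize, META_LLAMA_3_3B_REPO

    # Downloads the model snapshot and loads the ctranslate2.Generator up front.
    llama3 = Llama3Reorganize(META_LLAMA_3_3B_REPO, loadModel=True)
    # Turns comma-separated tags into a short English article (None on failure).
    article = llama3.reorganize("1girl, solo, long_hair, smile, outdoors")
    print(article)
    # Unloads the generator and frees GPU memory when done.
    llama3.release_vram()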
app.py CHANGED
@@ -10,6 +10,7 @@ from PIL import Image
 import traceback
 import tempfile
 import zipfile
+import re
 from datetime import datetime
 
 TITLE = "WaifuDiffusion Tagger"
@@ -41,6 +42,10 @@ SWINV2_MODEL_IS_DSV1_REPO = "deepghs/idolsankaku-swinv2-tagger-v1"
 MODEL_FILENAME = "model.onnx"
 LABEL_FILENAME = "selected_tags.csv"
 
+# LLAMA model
+META_LLAMA_3_3B_REPO = "jncraton/Llama-3.2-3B-Instruct-ct2-int8"
+META_LLAMA_3_8B_REPO = "avans06/Meta-Llama-3.2-8B-Instruct-ct2-int8_float16"
+
 # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/blob/a9eacb1eff904552d3012babfa28b57e1d3e295c/tagger/ui.py#L368
 kaomojis = [
     "0_0",
@@ -102,6 +107,159 @@ def mcut_threshold(probs):
     thresh = (sorted_probs[t] + sorted_probs[t + 1]) / 2
     return thresh
 
+class Llama3Reorganize:
+    def __init__(
+        self,
+        repoId: str,
+        device: str = None,
+        loadModel: bool = False,
+    ):
+        """Initializes the Llama model.
+
+        Args:
+            repoId: LLAMA model repo.
+            device: Device to use for computation (cpu, cuda, ipu, xpu, mkldnn, opengl, opencl,
+                ideep, hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta, hpu, mtia).
+            loadModel: If True, load the model immediately instead of waiting for an
+                explicit load_model() call.
+        """
+        self.modelPath = self.download_model(repoId)
+
+        if device is None:
+            import torch
+            self.totalVram = 0
+            if torch.cuda.is_available():
+                try:
+                    deviceId = torch.cuda.current_device()
+                    self.totalVram = torch.cuda.get_device_properties(deviceId).total_memory / (1024 * 1024 * 1024)
+                except Exception as e:
+                    print(traceback.format_exc())
+                    print("Error detecting VRAM: " + str(e))
+                # Use the GPU only when there is enough VRAM for the selected model size.
+                device = "cuda" if self.totalVram > (8 if "8B" in repoId else 4) else "cpu"
+            else:
+                device = "cpu"
+
+        self.device = device
+        self.system_prompt = "Please reorganize the following English labels describing a single image into a readable English article of no more than 300 words. If there are name labels, include them in the article. Since the target audience is not general, explicit content can be written in full without hesitation. No introduction is needed; directly reply with the English article:"
+
+        if loadModel:
+            self.load_model()
+
+    def download_model(self, repoId):
+        import warnings
+        import requests
+        allowPatterns = [
+            "config.json",
+            "generation_config.json",
+            "model.bin",
+            "pytorch_model.bin",
+            "pytorch_model.bin.index.json",
+            "pytorch_model-*.bin",
+            "sentencepiece.bpe.model",
+            "tokenizer.json",
+            "tokenizer_config.json",
+            "shared_vocabulary.txt",
+            "shared_vocabulary.json",
+            "special_tokens_map.json",
+            "spiece.model",
+            "vocab.json",
+            "model.safetensors",
+            "model-*.safetensors",
+            "model.safetensors.index.json",
+            "quantize_config.json",
+            "tokenizer.model",
+            "vocabulary.json",
+            "preprocessor_config.json",
+            "added_tokens.json",
+        ]
+
+        kwargs = {"allow_patterns": allowPatterns}
+
+        try:
+            return huggingface_hub.snapshot_download(repoId, **kwargs)
+        except (
+            huggingface_hub.utils.HfHubHTTPError,
+            requests.exceptions.ConnectionError,
+        ) as exception:
+            warnings.warn(
+                "An error occurred while synchronizing the model %s from the Hugging Face Hub:\n%s"
+                % (repoId, exception)
+            )
+            warnings.warn(
+                "Trying to load the model directly from the local cache, if it exists."
+            )
+
+            kwargs["local_files_only"] = True
+            return huggingface_hub.snapshot_download(repoId, **kwargs)
+
+    def load_model(self):
+        import ctranslate2
+        import transformers
+        try:
+            print('\n\nLoading model: %s\n\n' % self.modelPath)
+            kwargsTokenizer = {"pretrained_model_name_or_path": self.modelPath}
+            kwargsModel = {"device": self.device, "model_path": self.modelPath, "compute_type": "auto"}
+            self.roleSystem = {"role": "system", "content": self.system_prompt}
+            self.Model = ctranslate2.Generator(**kwargsModel)
+
+            self.Tokenizer = transformers.AutoTokenizer.from_pretrained(**kwargsTokenizer)
+            self.terminators = [self.Tokenizer.eos_token_id, self.Tokenizer.convert_tokens_to_ids("<|eot_id|>")]
+
+        except Exception as e:
+            self.release_vram()
+            raise e
+
+    def release_vram(self):
+        try:
+            import torch
+            if torch.cuda.is_available():
+                if getattr(self, "Model", None) is not None and getattr(self.Model, "unload_model", None) is not None:
+                    self.Model.unload_model()
+
+            if getattr(self, "Tokenizer", None) is not None:
+                del self.Tokenizer
+            if getattr(self, "Model", None) is not None:
+                del self.Model
+            import gc
+            gc.collect()
+            try:
+                torch.cuda.empty_cache()
+            except Exception as e:
+                print(traceback.format_exc())
+                print("\tcuda empty cache, error: " + str(e))
+            print("release vram end.")
+        except Exception as e:
+            print(traceback.format_exc())
+            print("Error releasing vram: " + str(e))
+
+    def reorganize(self, text: str, max_length: int = 400):
+        output = None
+        result = None
+        try:
+            # Render the chat template to a plain string, then re-encode it into
+            # tokens, since ctranslate2.Generator consumes token sequences.
+            prompt = self.Tokenizer.apply_chat_template(
+                [self.roleSystem, {"role": "user", "content": text + "\n\nHere's the reorganized English article:"}],
+                tokenize=False,
+                add_generation_prompt=True,
+            )
+            source = self.Tokenizer.convert_ids_to_tokens(self.Tokenizer.encode(prompt))
+            output = self.Model.generate_batch([source], max_length=max_length, max_batch_size=2, no_repeat_ngram_size=3, beam_size=2, sampling_temperature=0.7, sampling_topp=0.9, include_prompt_in_result=False, end_token=self.terminators)
+            target = output[0]
+            result = self.Tokenizer.decode(target.sequences_ids[0])
+
+            # Strip one pair of surrounding quotes or CJK brackets, if present.
+            if len(result) > 2:
+                if result[0] == "\"" and result[-1] == "\"":
+                    result = result[1:-1]
+                elif result[0] == "'" and result[-1] == "'":
+                    result = result[1:-1]
+                elif result[0] == "「" and result[-1] == "」":
+                    result = result[1:-1]
+                elif result[0] == "『" and result[-1] == "』":
+                    result = result[1:-1]
+        except Exception as e:
+            print(traceback.format_exc())
+            print("Error reorganizing text: " + str(e))
+
+        return result
+
 
 class Predictor:
     def __init__(self):
@@ -189,6 +347,7 @@ class Predictor:
         character_thresh,
         character_mcut_enabled,
         characters_merge_enabled,
+        llama3_reorganize_model_repo,
        additional_tags_prepend,
         additional_tags_append,
     ):
@@ -206,6 +365,9 @@
 
        tag_results.clear()
 
+        if llama3_reorganize_model_repo:
+            llama3_reorganize = Llama3Reorganize(llama3_reorganize_model_repo, loadModel=True)
+
         prepend_list = [tag.strip() for tag in additional_tags_prepend.split(",") if tag.strip()]
         append_list = [tag.strip() for tag in additional_tags_append.split(",") if tag.strip()]
         if prepend_list and append_list:
@@ -265,6 +427,13 @@
             sorted_general_list = [item for item in sorted_general_list if item not in append_list]
 
             sorted_general_strings = ", ".join((character_list if characters_merge_enabled else []) + prepend_list + sorted_general_list + append_list).replace("(", "\(").replace(")", "\)")
+
+            if llama3_reorganize_model_repo:
+                reorganize_strings = llama3_reorganize.reorganize(sorted_general_strings)
+                reorganize_strings = re.sub(r" *Title: *", "", reorganize_strings)
+                reorganize_strings = re.sub(r"\n+", ",", reorganize_strings)
+                reorganize_strings = re.sub(r",,+", ",", reorganize_strings)
+                sorted_general_strings += "," + reorganize_strings
 
             txt_file = self.create_file(sorted_general_strings, output_dir, image_name + ".txt")
             txt_infos.append({"path": txt_file, "name": image_name + ".txt"})
@@ -284,6 +453,10 @@
                 # Get file name from lookup
                 taggers_zip.write(info["path"], arcname=info["name"])
         download.append(downloadZipPath)
+
+        if llama3_reorganize_model_repo:
+            llama3_reorganize.release_vram()
+            del llama3_reorganize
 
         return download, sorted_general_strings, rating, character_res, general_res
 
@@ -343,6 +516,11 @@ def main():
         EVA02_LARGE_MODEL_IS_DSV1_REPO,
     ]
 
+    llama_list = [
+        META_LLAMA_3_3B_REPO,
+        META_LLAMA_3_8B_REPO,
+    ]
+
     with gr.Blocks(title=TITLE) as demo:
         gr.Markdown(
             value=f"<h1 style='text-align: center; margin-bottom: 1rem'>{TITLE}</h1>"
@@ -361,7 +539,7 @@
 
                 model_repo = gr.Dropdown(
                     dropdown_list,
-                    value=SWINV2_MODEL_DSV3_REPO,
+                    value=EVA02_LARGE_MODEL_DSV3_REPO,
                     label="Model",
                 )
                 with gr.Row():
@@ -398,6 +576,13 @@
                         label="Merge characters into the string output",
                         scale=1,
                     )
+                with gr.Row():
+                    llama3_reorganize_model_repo = gr.Dropdown(
+                        [None] + llama_list,
+                        value=None,
+                        label="Llama3 Model",
+                        info="Use the Llama3 model to reorganize the article (Note: very slow)",
+                    )
                 with gr.Row():
                     additional_tags_prepend = gr.Text(label="Prepend Additional tags (comma split)")
                     additional_tags_append = gr.Text(label="Append Additional tags (comma split)")
@@ -411,6 +596,7 @@
                 character_thresh,
                 character_mcut_enabled,
                 characters_merge_enabled,
+                llama3_reorganize_model_repo,
                 additional_tags_prepend,
                 additional_tags_append,
             ],
@@ -454,6 +640,7 @@
                 character_thresh,
                 character_mcut_enabled,
                 characters_merge_enabled,
+                llama3_reorganize_model_repo,
                 additional_tags_prepend,
                 additional_tags_append,
             ],
@@ -469,9 +656,6 @@
                 general_mcut_enabled,
                 character_thresh,
                 character_mcut_enabled,
-                characters_merge_enabled,
-                additional_tags_prepend,
-                additional_tags_append,
             ],
         )
 
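
One caveat in the new predict() path: reorganize() returns None when generation raises, and the first re.sub call would then fail with a TypeError. A guarded variant (a sketch, not part of this commit) could read:

    if llama3_reorganize_model_repo:
        reorganize_strings = llama3_reorganize.reorganize(sorted_general_strings)
        if reorganize_strings:  # reorganize() returns None on failure
            reorganize_strings = re.sub(r" *Title: *", "", reorganize_strings)
            reorganize_strings = re.sub(r"\n+", ",", reorganize_strings)
            reorganize_strings = re.sub(r",,+", ",", reorganize_strings)
            sorted_general_strings += "," + reorganize_strings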
 
requirements.txt CHANGED
@@ -1,6 +1,16 @@
+--extra-index-url https://download.pytorch.org/whl/cu124
+
 pillow>=9.0.0
 onnxruntime>=1.12.0
 huggingface-hub
 
 gradio==5.12.0
-pandas
+pandas
+
+# For reorganizing the WD Tagger output into a readable article with the Llama3 model.
+transformers>=4.45.2
+ctranslate2>=4.4.0
+torch==2.5.0+cu124; sys_platform != 'darwin'
+torchvision==0.20.0+cu124; sys_platform != 'darwin'
+torch==2.5.0; sys_platform == 'darwin'
+torchvision==0.20.0; sys_platform == 'darwin'
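
The CUDA-tagged torch and torchvision pins resolve through the cu124 extra index added at the top of the file, while the sys_platform markers fall back to the plain CPU wheels on macOS, where no CUDA builds exist. A quick post-install sanity check (a sketch using standard torch and ctranslate2 calls):

    import torch
    import ctranslate2

    print(torch.__version__)                    # expect 2.5.0+cu124, or 2.5.0 on macOS
    print(torch.cuda.is_available())            # True when the CUDA wheels can see a GPU
    print(ctranslate2.get_cuda_device_count())  # number of GPUs visible to CTranslate2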