import time
from pathlib import Path
from typing import Callable

import torch

from dartrs.v2 import (
    V2Model,
    MixtralModel,
    MistralModel,
    compose_prompt,
    LengthTag,
    AspectRatioTag,
    RatingTag,
    IdentityTag,
)
from dartrs.dartrs import DartTokenizer
from dartrs.utils import get_generation_config

import gradio as gr
from gradio.components import Component

try:
    from output import UpsamplingOutput
except ImportError:
    from .output import UpsamplingOutput


V2_ALL_MODELS = {
    "dart-v2-moe-sft": {
        "repo": "p1atdev/dart-v2-moe-sft",
        "type": "sft",
        "class": MixtralModel,
    },
    "dart-v2-sft": {
        "repo": "p1atdev/dart-v2-sft",
        "type": "sft",
        "class": MistralModel,
    },
}


def prepare_models(model_config: dict):
    model_name = model_config["repo"]
    tokenizer = DartTokenizer.from_pretrained(model_name)
    model = model_config["class"].from_pretrained(model_name)

    return {
        "tokenizer": tokenizer,
        "model": model,
    }


def normalize_tags(tokenizer: DartTokenizer, tags: str):
    """Just remove unk tokens."""
    return ", ".join([tag for tag in tokenizer.tokenize(tags) if tag != "<|unk|>"])


@torch.no_grad()
def generate_tags(
    model: V2Model,
    tokenizer: DartTokenizer,
    prompt: str,
    ban_token_ids: list[int],
):
    output = model.generate(
        get_generation_config(
            prompt,
            tokenizer=tokenizer,
            temperature=1,
            top_p=0.9,
            top_k=100,
            max_new_tokens=256,
            ban_token_ids=ban_token_ids,
        ),
    )

    return output


def _people_tag(noun: str, minimum: int = 1, maximum: int = 5):
    # e.g. "1girl", "2girls", ..., "5girls", "6+girls"
    return (
        [f"1{noun}"]
        + [f"{num}{noun}s" for num in range(minimum + 1, maximum + 1)]
        + [f"{maximum + 1}+{noun}s"]
    )


PEOPLE_TAGS = (
    _people_tag("girl") + _people_tag("boy") + _people_tag("other") + ["no humans"]
)


def gen_prompt_text(output: UpsamplingOutput):
    # separate people tags (e.g. 1girl)
    people_tags = []
    other_general_tags = []

    for tag in output.general_tags.split(","):
        tag = tag.strip()
        if tag in PEOPLE_TAGS:
            people_tags.append(tag)
        else:
            other_general_tags.append(tag)

    return ", ".join(
        [
            part.strip()
            for part in [
                *people_tags,
                output.character_tags,
                output.copyright_tags,
                *other_general_tags,
                output.upsampled_tags,
                output.rating_tag,
            ]
            if part.strip() != ""
        ]
    )


def elapsed_time_format(elapsed_time: float) -> str:
    return f"Elapsed: {elapsed_time:.2f} seconds"


def parse_upsampling_output(
    upsampler: Callable[..., UpsamplingOutput],
):
    def _parse_upsampling_output(*args) -> tuple[str, str, dict, dict]:
        output = upsampler(*args)

        return (
            gen_prompt_text(output),
            elapsed_time_format(output.elapsed_time),
            gr.update(interactive=True),
            gr.update(interactive=True),
        )

    return _parse_upsampling_output


class V2UI:
    model_name: str | None = None
    model: V2Model
    tokenizer: DartTokenizer

    input_components: list[Component] = []
    generate_btn: gr.Button

    def on_generate(
        self,
        model_name: str,
        copyright_tags: str,
        character_tags: str,
        general_tags: str,
        rating_tag: RatingTag,
        aspect_ratio_tag: AspectRatioTag,
        length_tag: LengthTag,
        identity_tag: IdentityTag,
        ban_tags: str,
        *args,
    ) -> UpsamplingOutput:
        # (re)load the model only when a different one is requested
        if self.model_name is None or self.model_name != model_name:
            models = prepare_models(V2_ALL_MODELS[model_name])
            self.model = models["model"]
            self.tokenizer = models["tokenizer"]
            self.model_name = model_name

        # normalize tags
        # copyright_tags = normalize_tags(self.tokenizer, copyright_tags)
        # character_tags = normalize_tags(self.tokenizer, character_tags)
        # general_tags = normalize_tags(self.tokenizer, general_tags)

        ban_token_ids = self.tokenizer.encode(ban_tags.strip())

        prompt = compose_prompt(
            prompt=general_tags,
            copyright=copyright_tags,
            character=character_tags,
            rating=rating_tag,
            aspect_ratio=aspect_ratio_tag,
            length=length_tag,
            identity=identity_tag,
        )

        start = time.time()
        upsampled_tags = generate_tags(
            self.model,
            self.tokenizer,
            prompt,
            ban_token_ids,
        )
        elapsed_time = time.time() - start

        return UpsamplingOutput(
            upsampled_tags=upsampled_tags,
            copyright_tags=copyright_tags,
            character_tags=character_tags,
            general_tags=general_tags,
            rating_tag=rating_tag,
            aspect_ratio_tag=aspect_ratio_tag,
            length_tag=length_tag,
            identity_tag=identity_tag,
            elapsed_time=elapsed_time,
        )


def parse_upsampling_output_simple(upsampler: UpsamplingOutput):
    return gen_prompt_text(upsampler)


v2 = V2UI()


def v2_upsampling_prompt(
    model: str = "dart-v2-moe-sft",
    copyright: str = "",
    character: str = "",
    general_tags: str = "",
    rating: str = "nsfw",
    aspect_ratio: str = "square",
    length: str = "very_long",
    identity: str = "lax",
    ban_tags: str = "censored",
):
    raw_prompt = parse_upsampling_output_simple(
        v2.on_generate(
            model,
            copyright,
            character,
            general_tags,
            rating,
            aspect_ratio,
            length,
            identity,
            ban_tags,
        )
    )
    return raw_prompt


def load_dict_from_csv(filename):
    """Load a two-column CSV into a dict, falling back to the ./tagger/ directory."""
    parsed = {}
    if not Path(filename).exists():
        if Path("./tagger/", filename).exists():
            filename = str(Path("./tagger/", filename))
        else:
            return parsed
    try:
        with open(filename, "r", encoding="utf-8") as f:
            lines = f.readlines()
    except Exception:
        print(f"Failed to open dictionary file: {filename}")
        return parsed
    for line in lines:
        parts = line.strip().split(",")
        if len(parts) < 2:
            # skip blank or malformed lines
            continue
        parsed[parts[0]] = parts[1]
    return parsed


anime_series_dict = load_dict_from_csv("character_series_dict.csv")


def select_random_character(series: str, character: str):
    from random import choice, seed

    seed()
    character_list = list(anime_series_dict.keys())
    if not character_list:
        # dictionary file missing or empty; leave the inputs untouched
        return series, character
    character = choice(character_list)
    series = anime_series_dict.get(character.split(",")[0].strip(), "")
    return series, character


def v2_random_prompt(
    general_tags: str = "",
    copyright: str = "",
    character: str = "",
    rating: str = "nsfw",
    aspect_ratio: str = "square",
    length: str = "very_long",
    identity: str = "lax",
    ban_tags: str = "censored",
    model: str = "dart-v2-moe-sft",
):
    if copyright == "" and character == "":
        copyright, character = select_random_character("", "")
    raw_prompt = v2_upsampling_prompt(
        model,
        copyright,
        character,
        general_tags,
        rating,
        aspect_ratio,
        length,
        identity,
        ban_tags,
    )
    return raw_prompt, copyright, character
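

# Illustrative usage sketch (an assumption, not part of the original UI wiring):
# exercises the prompt helpers directly, assuming the dartrs model weights can be
# downloaded and that character_series_dict.csv sits next to this file.
if __name__ == "__main__":
    # Upsample a fixed tag list using the module defaults.
    prompt = v2_upsampling_prompt(general_tags="1girl, solo, looking at viewer")
    print(prompt)

    # Pick a random character/series pair from the CSV, then upsample around it.
    random_prompt, picked_copyright, picked_character = v2_random_prompt(
        general_tags="outdoors, smile"
    )
    print(picked_copyright, "/", picked_character)
    print(random_prompt)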