import glob
import json
import logging
import math
import os.path
from typing import Optional, Union

from gchar.games.base import Character
from hbutils.string import plural_word
from hbutils.system import TemporaryDirectory
from hcpdiff.train_ac import Trainer
from hcpdiff.train_ac_single import TrainerSingleCard
from hcpdiff.utils import load_config_with_cli

from .embedding import create_embedding, _DEFAULT_TRAIN_MODEL
from ..dataset import load_dataset_for_character, save_recommended_tags
from ..utils import data_to_cli_args, get_ch_name

_DEFAULT_TRAIN_CFG = 'cfgs/train/examples/lora_anime_character.yaml'


def _min_training_steps(dataset_size: int, unit: int = 20):
    # Four-parameter logistic fit of the minimum step count against dataset
    # size, snapped to the nearest multiple of ``unit``.
    steps = 4000.9 + (720.9319 - 4000.9) / (1 + (dataset_size / 297.2281) ** 0.6543184)
    return int(round(steps / unit)) * unit
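# Illustrative values for the curve above (a rough sketch, not part of the
# API): the fitted minimum grows with dataset size, from about 720 steps for
# near-empty datasets towards an asymptote of roughly 4000 steps, e.g.:
#
#     _min_training_steps(10)    # -> 1040
#     _min_training_steps(1000)  # -> 2980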

def train_plora(
        source: Union[str, Character],
        name: Optional[str] = None,
        epochs: int = 13,
        min_steps: Optional[int] = None,
        save_for_times: int = 15,
        no_min_steps: bool = False,
        batch_size: int = 4,
        pretrained_model: str = _DEFAULT_TRAIN_MODEL,
        workdir: Optional[str] = None,
        emb_n_words: int = 4,
        emb_init_text: str = '*[0.017, 1]',
        unet_rank: float = 8,
        text_encoder_rank: float = 4,
        cfg_file: str = _DEFAULT_TRAIN_CFG,
        single_card: bool = True,
        dataset_type: str = 'stage3-1200',
        use_ratio: bool = True,
):
    with load_dataset_for_character(source, dataset_type) as (ch, ds_dir):
        # A custom (non-gchar) source carries no character metadata, so an
        # explicit name is required; otherwise derive one from the character.
        if ch is None:
            if name is None:
                raise ValueError(f'Name should be specified when using custom source - {source!r}.')
        else:
            name = name or get_ch_name(ch)

        dataset_size = len(glob.glob(os.path.join(ds_dir, '*.png')))
        logging.info(f'{plural_word(dataset_size, "image")} found in dataset.')

        # Ensure the run is long enough: honour the fitted minimum (unless
        # disabled) and any explicit ``min_steps``, then align the total step
        # count to an exact multiple of the save interval.
        actual_steps = epochs * dataset_size
        if not no_min_steps:
            actual_steps = max(actual_steps, _min_training_steps(dataset_size, 20))
        if min_steps is not None:
            actual_steps = max(actual_steps, min_steps)

        save_per_steps = max(int(math.ceil(actual_steps / save_for_times / 20) * 20), 20)
        steps = int(math.ceil(actual_steps / save_per_steps) * save_per_steps)
        epochs = int(math.ceil(steps / dataset_size))
        logging.info(f'Training for {plural_word(steps, "step")}, {plural_word(epochs, "epoch")}, '
                     f'save per {plural_word(save_per_steps, "step")} ...')

        workdir = workdir or os.path.join('runs', name)
        os.makedirs(workdir, exist_ok=True)
        save_recommended_tags(ds_dir, name, workdir)
        with open(os.path.join(workdir, 'meta.json'), 'w', encoding='utf-8') as f:
            json.dump({
                'dataset': {
                    'size': dataset_size,
                    'type': dataset_type,
                },
            }, f, indent=4, sort_keys=True, ensure_ascii=False)

        with TemporaryDirectory() as embs_dir:
            logging.info(f'Creating embeddings {name!r} at {embs_dir!r}, '
                         f'n_words: {emb_n_words!r}, init_text: {emb_init_text!r}, '
                         f'pretrained_model: {pretrained_model!r}.')
            create_embedding(
                name, emb_n_words, emb_init_text,
                replace=True,
                pretrained_model=pretrained_model,
                embs_dir=embs_dir,
            )

            cli_args = data_to_cli_args({
                'train': {
                    'train_steps': steps,
                    'save_step': save_per_steps,
                    'scheduler': {
                        'num_training_steps': steps,
                    },
                },
                'model': {
                    'pretrained_model_name_or_path': pretrained_model,
                },
                'character_name': name,
                'dataset_dir': ds_dir,
                'exp_dir': workdir,
                'unet_rank': unet_rank,
                'text_encoder_rank': text_encoder_rank,
                'tokenizer_pt': {
                    'emb_dir': embs_dir,
                },
                'data': {
                    'dataset1': {
                        'batch_size': batch_size,
                        'bucket': {
                            '_target_': 'hcpdiff.data.bucket.RatioBucket.from_files',
                            'target_area': '${times:512,512}',
                            'num_bucket': 5,
                        } if use_ratio else {
                            '_target_': 'hcpdiff.data.bucket.SizeBucket.from_files',
                            'target_area': '---',
                            'num_bucket': 1,
                        },
                    },
                },
            })
            conf = load_config_with_cli(cfg_file, args_list=cli_args)  # skip --cfg

            logging.info(f'Training with {cfg_file!r}, args: {cli_args!r} ...')
            if single_card:
                logging.info('Training with single card ...')
                trainer = TrainerSingleCard(conf)
            else:
                logging.info('Training with multiple cards ...')
                trainer = Trainer(conf)

            trainer.train()
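
# A minimal usage sketch (hedged, not a definitive entry point): the source
# string below is a hypothetical placeholder; any gchar-resolvable character
# or custom dataset source accepted by ``load_dataset_for_character`` can be
# used in its place.
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    train_plora(
        'some_character',  # hypothetical character / source identifier
        min_steps=800,
        single_card=True,
    )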