import glob
import json
import logging
import math
import os.path
from typing import Optional, Union
from gchar.games.base import Character
from hbutils.string import plural_word
from hbutils.system import TemporaryDirectory
from hcpdiff.train_ac import Trainer
from hcpdiff.train_ac_single import TrainerSingleCard
from hcpdiff.utils import load_config_with_cli
from .embedding import create_embedding, _DEFAULT_TRAIN_MODEL
from ..dataset import load_dataset_for_character, save_recommended_tags
from ..utils import data_to_cli_args, get_ch_name

_DEFAULT_TRAIN_CFG = 'cfgs/train/examples/lora_anime_character.yaml'


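# Heuristic minimum number of training steps for a dataset of the given size.
# The fitted curve grows from roughly 720 steps for tiny datasets towards
# roughly 4000 steps for large ones, rounded to the nearest multiple of `unit`.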
def _min_training_steps(dataset_size: int, unit: int = 20):
steps = 4000.9 + (720.9319 - 4000.9) / (1 + (dataset_size / 297.2281) ** 0.6543184)
    return int(round(steps / unit)) * unit


def train_plora(
source: Union[str, Character], name: Optional[str] = None,
epochs: int = 13, min_steps: Optional[int] = None,
save_for_times: int = 15, no_min_steps: bool = False,
batch_size: int = 4, pretrained_model: str = _DEFAULT_TRAIN_MODEL,
        workdir: Optional[str] = None, emb_n_words: int = 4, emb_init_text: str = '*[0.017, 1]',
unet_rank: float = 8, text_encoder_rank: float = 4,
cfg_file: str = _DEFAULT_TRAIN_CFG, single_card: bool = True,
dataset_type: str = 'stage3-1200', use_ratio: bool = True,
):
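    """
    Train a character LoRA together with a textual-inversion embedding via hcpdiff.

    :param source: Character object or a custom dataset source.
    :param name: Model name; required for custom sources, otherwise derived from the character.
    :param epochs: Number of passes over the dataset used to estimate the step count.
    :param min_steps: Optional hard lower bound on the number of training steps.
    :param save_for_times: Approximate number of checkpoints to save over the run.
    :param no_min_steps: Disable the heuristic minimum step count.
    :param batch_size: Training batch size.
    :param pretrained_model: Base model to fine-tune.
    :param workdir: Output directory; defaults to ``runs/<name>``.
    :param emb_n_words: Number of tokens in the embedding.
    :param emb_init_text: Initialization text for the embedding.
    :param unet_rank: LoRA rank for the UNet.
    :param text_encoder_rank: LoRA rank for the text encoder.
    :param cfg_file: Base hcpdiff training config.
    :param single_card: Use the single-card trainer.
    :param dataset_type: Dataset variant to load for the character.
    :param use_ratio: Use aspect-ratio bucketing instead of a fixed-size bucket.
    """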
with load_dataset_for_character(source, dataset_type) as (ch, ds_dir):
if ch is None:
if name is None:
raise ValueError(f'Name should be specified when using custom source - {source!r}.')
else:
name = name or get_ch_name(ch)
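
        # The dataset size is measured as the number of PNG images in the
        # prepared dataset directory.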
dataset_size = len(glob.glob(os.path.join(ds_dir, '*.png')))
logging.info(f'{plural_word(dataset_size, "image")} found in dataset.')
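
        # Estimate the total optimization steps: `epochs` passes over the
        # dataset, raised to the heuristic minimum (and to `min_steps` when
        # given), then rounded up so the checkpoint interval divides it evenly.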
actual_steps = epochs * dataset_size
if not no_min_steps:
actual_steps = max(actual_steps, _min_training_steps(dataset_size, 20))
if min_steps is not None:
actual_steps = max(actual_steps, min_steps)
save_per_steps = max(int(math.ceil(actual_steps / save_for_times / 20) * 20), 20)
steps = int(math.ceil(actual_steps / save_per_steps) * save_per_steps)
epochs = int(math.ceil(steps / dataset_size))
logging.info(f'Training for {plural_word(steps, "step")}, {plural_word(epochs, "epoch")}, '
f'save per {plural_word(save_per_steps, "step")} ...')
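
        # Prepare the working directory, then store the recommended tags and
        # the dataset metadata alongside the training outputs.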
workdir = workdir or os.path.join('runs', name)
os.makedirs(workdir, exist_ok=True)
save_recommended_tags(ds_dir, name, workdir)
with open(os.path.join(workdir, 'meta.json'), 'w', encoding='utf-8') as f:
json.dump({
'dataset': {
'size': dataset_size,
'type': dataset_type,
},
}, f, indent=4, sort_keys=True, ensure_ascii=False)
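
        # Create the textual-inversion embedding in a temporary directory;
        # hcpdiff loads it through the `tokenizer_pt.emb_dir` option below.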
with TemporaryDirectory() as embs_dir:
logging.info(f'Creating embeddings {name!r} at {embs_dir!r}, '
f'n_words: {emb_n_words!r}, init_text: {emb_init_text!r}, '
f'pretrained_model: {pretrained_model!r}.')
create_embedding(
name, emb_n_words, emb_init_text,
replace=True,
pretrained_model=pretrained_model,
embs_dir=embs_dir,
)
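
            # Translate the parameters above into hcpdiff CLI-style overrides,
            # applied on top of the base YAML config from `cfg_file`.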
cli_args = data_to_cli_args({
'train': {
'train_steps': steps,
'save_step': save_per_steps,
'scheduler': {
'num_training_steps': steps,
}
},
'model': {
'pretrained_model_name_or_path': pretrained_model,
},
'character_name': name,
'dataset_dir': ds_dir,
'exp_dir': workdir,
'unet_rank': unet_rank,
'text_encoder_rank': text_encoder_rank,
'tokenizer_pt': {
'emb_dir': embs_dir,
},
'data': {
'dataset1': {
'batch_size': batch_size,
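                        # With `use_ratio`, images are grouped into 5 aspect-ratio
                        # buckets of about 512x512 pixel area; otherwise a single
                        # fixed-size bucket is used.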
'bucket': {
'_target_': 'hcpdiff.data.bucket.RatioBucket.from_files',
'target_area': '${times:512,512}',
'num_bucket': 5,
} if use_ratio else {
'_target_': 'hcpdiff.data.bucket.SizeBucket.from_files',
'target_area': '---',
'num_bucket': 1,
}
},
},
})
conf = load_config_with_cli(cfg_file, args_list=cli_args) # skip --cfg
logging.info(f'Training with {cfg_file!r}, args: {cli_args!r} ...')
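
            # hcpdiff ships separate trainer entry points for single-card and
            # multi-card runs; pick the matching one.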
if single_card:
logging.info('Training with single card ...')
trainer = TrainerSingleCard(conf)
else:
                logging.info('Training with multiple cards ...')
trainer = Trainer(conf)
trainer.train()
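
# A minimal usage sketch (the character name below is illustrative):
#
#     train_plora('amiya', epochs=10)
#
# This resolves the character's dataset, trains the LoRA and its embedding
# through hcpdiff, and writes checkpoints plus meta.json under `runs/<name>`.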