import argparse

from .constants import *
from .modules.models import HUNYUAN_DIT_CONFIG


def _as_height_width(parser, values, option):
    """Normalize a ``nargs='+'`` size option to a two-element ``[height, width]`` list.

    A single value ``v`` is expanded to ``[v, v]``, as the options' help text
    promises. Two values are returned as a new list. Any other count aborts
    through ``parser.error`` (prints usage and exits with status 2).
    """
    if len(values) == 1:
        return [values[0], values[0]]
    if len(values) == 2:
        return list(values)
    parser.error(f"{option} expects 1 or 2 integers (height, width), got {len(values)}.")


def get_args(default_args=None):
    """Build the HunYuan-DiT sampling command-line parser and parse arguments.

    Args:
        default_args: Optional list of argument strings to parse. When ``None``,
            ``argparse`` falls back to ``sys.argv[1:]``.

    Returns:
        argparse.Namespace holding the parsed options. ``image_size`` and
        ``size_cond`` are always two-element ``[height, width]`` lists, even when
        a single value was given on the command line.
    """
    parser = argparse.ArgumentParser()

    # Basic
    # The default prompt "一只小猫" is Chinese for "a kitten".
    parser.add_argument("--prompt", type=str, default="一只小猫", help="The prompt for generating images.")
    parser.add_argument("--model-root", type=str, default="ckpts", help="Model root path.")
    parser.add_argument("--image-size", type=int, nargs='+', default=[1024, 1024],
                        help='Image size (h, w). If a single value is provided, the image will be treated to '
                             '(value, value).')
    parser.add_argument("--infer-mode", type=str, choices=["fa", "torch", "trt"], default="torch",
                        help="Inference mode")

    # HunYuan-DiT
    parser.add_argument("--model", type=str, choices=list(HUNYUAN_DIT_CONFIG.keys()), default='DiT-g/2',
                        help="HunYuan-DiT model configuration (a key of HUNYUAN_DIT_CONFIG).")
    parser.add_argument("--norm", type=str, default="layer", help="Normalization layer type")
    parser.add_argument("--load-key", type=str, choices=["ema", "module"], default="ema",
                        help="Load model key for HunYuanDiT checkpoint.")
    parser.add_argument('--size-cond', type=int, nargs='+', default=[1024, 1024],
                        help="Size condition used in sampling. 2 values are required for height and width. "
                             "If a single value is provided, the image will be treated to (value, value).")
    parser.add_argument("--cfg-scale", type=float, default=6.0, help="Guidance scale for classifier-free.")

    # Prompt enhancement
    # The default is True, so `--enhance` itself is a no-op; only `--no-enhance` changes the value.
    parser.add_argument("--enhance", action="store_true", help="Enhance prompt with dialoggen.")
    parser.add_argument("--no-enhance", dest="enhance", action="store_false")
    parser.set_defaults(enhance=True)

    # Diffusion
    # Same default-True pattern: only `--no-learn-sigma` has an effect.
    parser.add_argument("--learn-sigma", action="store_true", help="Learn extra channels for sigma.")
    parser.add_argument("--no-learn-sigma", dest="learn_sigma", action="store_false")
    parser.set_defaults(learn_sigma=True)
    parser.add_argument("--predict-type", type=str, choices=list(PREDICT_TYPE), default="v_prediction",
                        help="Diffusion predict type")
    parser.add_argument("--noise-schedule", type=str, choices=list(NOISE_SCHEDULES), default="scaled_linear",
                        help="Noise schedule")
    parser.add_argument("--beta-start", type=float, default=0.00085, help="Beta start value")
    parser.add_argument("--beta-end", type=float, default=0.03, help="Beta end value")

    # Text condition
    parser.add_argument("--text-states-dim", type=int, default=1024, help="Hidden size of CLIP text encoder.")
    parser.add_argument("--text-len", type=int, default=77, help="Token length of CLIP text encoder output.")
    parser.add_argument("--text-states-dim-t5", type=int, default=2048, help="Hidden size of T5 text encoder.")
    parser.add_argument("--text-len-t5", type=int, default=256, help="Token length of T5 text encoder output.")
    parser.add_argument("--negative", type=str, default=None, help="Negative prompt.")

    # Acceleration
    # `--use_fp16` is kept first so the dest stays `use_fp16`; `--use-fp16` is a
    # hyphenated alias matching every other flag. FP16 is on by default; `--no-fp16` disables it.
    parser.add_argument("--use_fp16", "--use-fp16", action="store_true", help="Use FP16 precision.")
    parser.add_argument("--no-fp16", dest="use_fp16", action="store_false")
    parser.set_defaults(use_fp16=True)

    # Sampling
    parser.add_argument("--batch-size", type=int, default=1, help="Per-GPU batch size")
    parser.add_argument("--sampler", type=str, choices=list(SAMPLER_FACTORY), default="ddpm",
                        help="Diffusion sampler")
    parser.add_argument("--infer-steps", type=int, default=100, help="Inference steps")
    parser.add_argument('--seed', type=int, default=42, help="A seed for all the prompts.")

    # App
    parser.add_argument("--lang", type=str, default="zh", choices=["zh", "en"], help="Language")

    args = parser.parse_args(default_args)
    # Both size options accept 1 or 2 integers. Store them as explicit [h, w]
    # pairs so callers can unpack `height, width = args.image_size` safely.
    args.image_size = _as_height_width(parser, args.image_size, "--image-size")
    args.size_cond = _as_height_width(parser, args.size_cond, "--size-cond")
    return args