|
from pathlib import Path |
|
|
|
from loguru import logger |
|
|
|
from dialoggen.dialoggen_demo import DialogGen |
|
from hydit.config import get_args |
|
from hydit.inference import End2End |
|
|
|
|
|
def inferencer():
    """Build everything needed to run a sampling session.

    Reads the run configuration with ``get_args()``, checks that the
    configured model root exists on disk, constructs the ``End2End``
    generator and, when ``args.enhance`` is truthy, also loads the
    ``DialogGen`` model used for prompt enhancement.

    Returns:
        A ``(args, gen, enhancer)`` tuple. ``enhancer`` is ``None`` when
        prompt enhancement is disabled.

    Raises:
        ValueError: If the path given by ``args.model_root`` does not exist.
    """
    args = get_args()

    root = Path(args.model_root)
    if not root.exists():
        raise ValueError(f"`models_root` not exists: {root}")

    # The generator is always built, and always before the optional enhancer.
    gen = End2End(args, root)

    # Prompt enhancement is opt-in; stay with None unless it was requested.
    enhancer = None
    if args.enhance:
        logger.info("Loading DialogGen model (for prompt enhancement)...")
        # DialogGen is handed a plain string path to the "dialoggen" subfolder.
        enhancer = DialogGen(str(root / "dialoggen"))
        logger.info("DialogGen model loaded.")

    return args, gen, enhancer
|
|
|
|
|
def next_image_index(save_dir):
    """Return the first unused numeric index for a new ``<index>.png`` in *save_dir*.

    Scans the ``*.png`` files directly inside *save_dir* and returns one more
    than the largest integer file stem found. PNG files whose stem is not an
    integer (for example ``preview.png``) are skipped instead of aborting the
    run, and the result is never below 0.

    Args:
        save_dir: ``pathlib.Path`` of the output directory; it must exist.

    Returns:
        The index to use for the next saved image; 0 when no numbered PNG
        is present.
    """
    highest = -1
    for png in save_dir.glob('*.png'):
        try:
            highest = max(highest, int(png.stem))
        except ValueError:
            # Not one of the numbered outputs; ignore it.
            continue
    return highest + 1


if __name__ == "__main__":
    args, gen, enhancer = inferencer()

    # Optional prompt enhancement: the enhancer returns a (success, text) pair.
    if enhancer:
        logger.info("Prompt Enhancement...")
        success, enhanced_prompt = enhancer(args.prompt)
        if not success:
            logger.info("Sorry, the prompt is not compliant, refuse to draw.")
            # `exit()` is injected by the `site` module and is not guaranteed
            # to exist (e.g. `python -S`); raising SystemExit is equivalent
            # and keeps the same exit status (0).
            raise SystemExit
        logger.info(f"Enhanced prompt: {enhanced_prompt}")
    else:
        enhanced_prompt = None

    # Run inference
    logger.info("Generating images...")
    # args.image_size is unpacked as (height, width).
    height, width = args.image_size
    results = gen.predict(
        args.prompt,
        height=height,
        width=width,
        seed=args.seed,
        enhanced_prompt=enhanced_prompt,
        negative_prompt=args.negative,
        infer_steps=args.infer_steps,
        guidance_scale=args.cfg_scale,
        batch_size=args.batch_size,
        src_size_cond=args.size_cond,
    )
    images = results['images']

    # Save images under ./results, continuing the numbering after any
    # previously saved "<n>.png" files so earlier outputs are not overwritten.
    save_dir = Path('results')
    save_dir.mkdir(exist_ok=True)
    start = next_image_index(save_dir)

    for idx, pil_img in enumerate(images, start=start):
        save_path = save_dir / f"{idx}.png"
        pil_img.save(save_path)
        logger.info(f"Save to {save_path}")
|
|