|
import sys |
|
from dataclasses import dataclass, field |
|
from pathlib import Path |
|
from typing import Optional, List, Tuple, Union |
|
|
|
import numpy as np |
|
import pyrallis |
|
import torch |
|
from PIL import Image |
|
from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline |
|
from transformers import CLIPTokenizer |
|
|
|
sys.path.append("") |
|
sys.path.append("../src") |
|
|
|
import constants |
|
from models.neti_clip_text_encoder import NeTICLIPTextModel |
|
from models.neti_mapper import NeTIMapper |
|
from prompt_manager import PromptManager |
|
from sd_pipeline_call import sd_pipeline_call |
|
from models.xti_attention_processor import XTIAttenProc |
|
from checkpoint_handler import CheckpointHandler |
|
from utils import vis_utils |
|
|
|
|
|
@dataclass
class InferenceConfig:
    """Configuration for running inference with a trained mapper.

    Exactly one of ``prompts`` and ``prompts_file_path`` must be provided.
    Any of ``inference_dir``, ``mapper_checkpoint_path`` and
    ``learned_embeds_path`` that is left unset is derived from ``input_dir``
    and ``iteration``, so ``input_dir`` is required in that case.

    Constructing the config has a side effect: ``inference_dir`` is created
    on disk (including parents) if it does not exist yet.
    """

    # Step number used to build the default paths derived from `input_dir`.
    iteration: Optional[int] = None

    # Base directory from which any unset path below is derived.
    input_dir: Optional[Path] = None

    # Output directory. Defaults to `input_dir / f"inference_{iteration}"`.
    inference_dir: Optional[Path] = None

    # Mapper checkpoint. Defaults to `input_dir / f"mapper-steps-{iteration}.pt"`.
    mapper_checkpoint_path: Optional[Path] = None

    # Learned embeddings. Defaults to `input_dir / f"learned_embeds-steps-{iteration}.bin"`.
    learned_embeds_path: Optional[Path] = None

    # Prompts given inline. Mutually exclusive with `prompts_file_path`.
    prompts: Optional[List[str]] = None

    # Text file holding one prompt per line. Mutually exclusive with `prompts`.
    prompts_file_path: Optional[Path] = None

    # Random seeds to generate images with.
    seeds: List[int] = field(default_factory=lambda: [42])

    # Truncation index or indices to run with. A bare int is wrapped into a
    # one-element list in `__post_init__`; `None` is left as is.
    truncation_idxs: Optional[Union[int, List[int]]] = None

    # "fp16" selects torch.float16; any other value selects torch.float32.
    # NOTE: `__post_init__` replaces this string with the torch.dtype itself.
    torch_dtype: str = "fp16"

    def __post_init__(self):
        """Validate the prompt source, resolve paths and normalize fields."""
        assert bool(self.prompts) != bool(self.prompts_file_path), \
            "You must provide either prompts or prompts_file_path, but not both!"
        self._set_prompts()
        self._set_input_paths()
        self.inference_dir.mkdir(exist_ok=True, parents=True)
        if isinstance(self.truncation_idxs, int):
            self.truncation_idxs = [self.truncation_idxs]
        self.torch_dtype = torch.float16 if self.torch_dtype == "fp16" else torch.float32

    def _set_input_paths(self):
        """Fill in every unset path from ``input_dir`` and ``iteration``."""
        if self.inference_dir is None:
            assert self.input_dir is not None, "You must pass an input_dir if you do not specify inference_dir"
            self.inference_dir = self.input_dir / f"inference_{self.iteration}"
        if self.mapper_checkpoint_path is None:
            assert self.input_dir is not None, "You must pass an input_dir if you do not specify mapper_checkpoint_path"
            self.mapper_checkpoint_path = self.input_dir / f"mapper-steps-{self.iteration}.pt"
        if self.learned_embeds_path is None:
            assert self.input_dir is not None, "You must pass an input_dir if you do not specify learned_embeds_path"
            self.learned_embeds_path = self.input_dir / f"learned_embeds-steps-{self.iteration}.bin"

    def _set_prompts(self):
        """Load prompts from ``prompts_file_path`` when one was given."""
        if self.prompts_file_path is not None:
            assert self.prompts_file_path.exists(), f"Prompts file {self.prompts_file_path} does not exist!"
            lines = self.prompts_file_path.read_text().splitlines()
            # Skip blank lines: an empty prompt would resolve its output
            # directory to `inference_dir` itself and yield a file named ".png".
            self.prompts = [line for line in lines if line.strip()]
|
|
|
|
|
@pyrallis.wrap()
def main(infer_cfg: InferenceConfig):
    """Generate images for every configured prompt and truncation index.

    Loads the mapper checkpoint and the Stable Diffusion pipeline, then for
    each prompt creates a directory named after the formatted prompt under
    ``infer_cfg.inference_dir`` (passed to ``run_inference`` as its
    ``output_path``) and saves the image returned by ``run_inference``
    directly in ``infer_cfg.inference_dir``.

    Args:
        infer_cfg: Inference configuration, parsed by pyrallis.
    """
    train_cfg, mapper = CheckpointHandler.load_mapper(infer_cfg.mapper_checkpoint_path)
    pipeline, placeholder_token, placeholder_token_id = load_stable_diffusion_model(
        pretrained_model_name_or_path=train_cfg.model.pretrained_model_name_or_path,
        mapper=mapper,
        learned_embeds_path=infer_cfg.learned_embeds_path,
        torch_dtype=infer_cfg.torch_dtype
    )
    prompt_manager = PromptManager(tokenizer=pipeline.tokenizer,
                                   text_encoder=pipeline.text_encoder,
                                   timesteps=pipeline.scheduler.timesteps,
                                   unet_layers=constants.UNET_LAYERS,
                                   placeholder_token=placeholder_token,
                                   placeholder_token_id=placeholder_token_id,
                                   torch_dtype=infer_cfg.torch_dtype)

    # `truncation_idxs` defaults to None, which is not iterable. The loop body
    # already supports a None index, so treat an unset value as a single run
    # with `truncation_idx=None`.
    truncation_idxs = infer_cfg.truncation_idxs
    if truncation_idxs is None:
        truncation_idxs = [None]

    for prompt in infer_cfg.prompts:
        # The formatted prompt names both the per-prompt directory and the grid file.
        prompt_name = prompt.format(placeholder_token)
        output_path = infer_cfg.inference_dir / prompt_name
        output_path.mkdir(exist_ok=True, parents=True)
        for truncation_idx in truncation_idxs:
            print(f"Running with truncation index: {truncation_idx}")
            prompt_image = run_inference(prompt=prompt,
                                         pipeline=pipeline,
                                         prompt_manager=prompt_manager,
                                         seeds=infer_cfg.seeds,
                                         output_path=output_path,
                                         num_images_per_prompt=1,
                                         truncation_idx=truncation_idx)
            if truncation_idx is not None:
                save_name = f"{prompt_name}_truncation_{truncation_idx}.png"
            else:
                save_name = f"{prompt_name}.png"
            prompt_image.save(infer_cfg.inference_dir / save_name)
|
|
|
|
|
def run_inference(prompt: str,
                  pipeline: StableDiffusionPipeline,
                  prompt_manager: PromptManager,
                  seeds: List[int],
                  output_path: Optional[Path] = None,
                  num_images_per_prompt: int = 1,
                  truncation_idx: Optional[int] = None) -> Image.Image:
    """Generate images for one prompt over several seeds and return them as a grid.

    The prompt is embedded once and the embedding is reused for every seed.

    Args:
        prompt: Prompt passed to ``prompt_manager.embed_prompt``.
        pipeline: Pipeline handed to ``sd_pipeline_call``.
        prompt_manager: Object used to embed the prompt.
        seeds: One generation pass is run per seed.
        output_path: If given, each per-seed image is saved in this directory
            as ``{seed}.png`` or ``{seed}_truncation_{truncation_idx}.png``.
        num_images_per_prompt: Forwarded to both the embedding and the
            pipeline call.
        truncation_idx: Forwarded to ``embed_prompt``; also selects the
            per-seed file name.

    Returns:
        The result of ``vis_utils.get_image_grid`` over the per-seed images.
    """
    # Only the embedding step runs under autocast / no_grad.
    with torch.autocast("cuda"), torch.no_grad():
        prompt_embeds = prompt_manager.embed_prompt(
            prompt,
            num_images_per_prompt=num_images_per_prompt,
            truncation_idx=truncation_idx,
        )

    per_seed_images = []
    for seed in seeds:
        rng = torch.Generator(device='cuda').manual_seed(seed)
        result = sd_pipeline_call(
            pipeline,
            prompt_embeds=prompt_embeds,
            generator=rng,
            num_images_per_prompt=num_images_per_prompt,
        )
        # Join this seed's images along axis 1 into a single image.
        stacked = np.concatenate(result.images, axis=1)
        seed_image = Image.fromarray(stacked).convert("RGB")

        if output_path is not None:
            if truncation_idx is not None:
                file_name = f'{seed}_truncation_{truncation_idx}.png'
            else:
                file_name = f'{seed}.png'
            seed_image.save(output_path / file_name)

        per_seed_images.append(seed_image)

    return vis_utils.get_image_grid(per_seed_images)
|
|
|
|
|
def load_stable_diffusion_model(pretrained_model_name_or_path: str,
                                learned_embeds_path: Path,
                                mapper: Optional[NeTIMapper] = None,
                                num_denoising_steps: int = 50,
                                torch_dtype: torch.dtype = torch.float16) -> Tuple[StableDiffusionPipeline, str, int]:
    """Build the Stable Diffusion pipeline used for inference.

    Loads the tokenizer and the NeTI text encoder, optionally attaches the
    mapper to the encoder's embeddings, loads the learned embedding, and
    assembles a pipeline on "cuda" with a DPM-Solver multistep scheduler and
    the XTI attention processor.

    Args:
        pretrained_model_name_or_path: Model identifier or path passed to
            every ``from_pretrained`` call.
        learned_embeds_path: Learned-embedding file passed to
            ``CheckpointHandler.load_learned_embed_in_clip``.
        mapper: Mapper to set on the text encoder's embeddings; skipped if None.
        num_denoising_steps: Number of timesteps configured on the scheduler.
        torch_dtype: dtype used for the text encoder and the pipeline.

    Returns:
        ``(pipeline, placeholder_token, placeholder_token_id)``.
    """
    clip_tokenizer = CLIPTokenizer.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="tokenizer",
    )
    neti_text_encoder = NeTICLIPTextModel.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="text_encoder",
        torch_dtype=torch_dtype,
    )
    if mapper is not None:
        neti_text_encoder.text_model.embeddings.set_mapper(mapper)

    token_info = CheckpointHandler.load_learned_embed_in_clip(
        learned_embeds_path=learned_embeds_path,
        text_encoder=neti_text_encoder,
        tokenizer=clip_tokenizer,
    )
    placeholder_token, placeholder_token_id = token_info

    sd_pipeline = StableDiffusionPipeline.from_pretrained(
        pretrained_model_name_or_path,
        torch_dtype=torch_dtype,
        text_encoder=neti_text_encoder,
        tokenizer=clip_tokenizer,
    )
    sd_pipeline = sd_pipeline.to("cuda")

    # Replace the pipeline's scheduler, reusing the existing scheduler's config.
    sd_pipeline.scheduler = DPMSolverMultistepScheduler.from_config(sd_pipeline.scheduler.config)
    sd_pipeline.scheduler.set_timesteps(num_denoising_steps, device=sd_pipeline.device)
    sd_pipeline.unet.set_attn_processor(XTIAttenProc())

    return sd_pipeline, placeholder_token, placeholder_token_id
|
|
|
|
|
# Script entry point: `main` is wrapped by `pyrallis.wrap()`, so it is called
# without arguments here.
if __name__ == '__main__':
    main()
|
|