# NeTI / scripts/inference.py
# Uploaded by neural-ti ("Upload 17 files"), commit 3eb1ce9
import sys
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional, List, Tuple, Union
import numpy as np
import pyrallis
import torch
from PIL import Image
from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline
from transformers import CLIPTokenizer
sys.path.append("")
sys.path.append("../src")
import constants
from models.neti_clip_text_encoder import NeTICLIPTextModel
from models.neti_mapper import NeTIMapper
from prompt_manager import PromptManager
from sd_pipeline_call import sd_pipeline_call
from models.xti_attention_processor import XTIAttenProc
from checkpoint_handler import CheckpointHandler
from utils import vis_utils
@dataclass
class InferenceConfig:
    """Inference configuration for NeTI, parsed from the command line by pyrallis.

    ``__post_init__`` normalizes the raw values so that, afterwards:

    * ``prompts`` holds the prompts to run (read from ``prompts_file_path`` if given,
      with blank lines skipped),
    * ``inference_dir``, ``mapper_checkpoint_path`` and ``learned_embeds_path`` are
      derived from ``input_dir`` / ``iteration`` when not given explicitly,
    * ``inference_dir`` exists on disk,
    * ``truncation_idxs`` is always a list (``[None]`` when no dropout is requested),
    * ``torch_dtype`` is a ``torch.dtype`` instead of a string.
    """
    # Specifies which checkpoint iteration we want to load
    iteration: Optional[int] = None
    # The input directory containing the saved models and embeddings
    input_dir: Optional[Path] = None
    # Where to save the inference results
    inference_dir: Optional[Path] = None
    # Specific path to the mapper you want to load, overrides `input_dir`
    mapper_checkpoint_path: Optional[Path] = None
    # Specific path to the embeddings you want to load, overrides `input_dir`
    learned_embeds_path: Optional[Path] = None
    # List of prompts to run inference on
    prompts: Optional[List[str]] = None
    # Text file containing prompts to run inference on (one prompt per line); mutually exclusive with `prompts`
    prompts_file_path: Optional[Path] = None
    # List of random seeds to run on
    seeds: List[int] = field(default_factory=lambda: [42])
    # If you want to run with dropout at inference time, this specifies the truncation indices for applying dropout.
    # None indicates that no dropout will be performed. If a list of indices is provided, will run all indices.
    truncation_idxs: Optional[Union[int, List[int]]] = None
    # Whether to run with torch.float16 ("fp16") or torch.float32 (any other value)
    torch_dtype: str = "fp16"

    def __post_init__(self):
        assert bool(self.prompts) != bool(self.prompts_file_path), \
            "You must provide either prompts or prompts_file_path, but not both!"
        self._set_prompts()
        self._set_input_paths()
        self.inference_dir.mkdir(exist_ok=True, parents=True)
        self._set_truncation_idxs()
        self._set_torch_dtype()

    def _set_input_paths(self):
        """Derive any unspecified output/checkpoint paths from ``input_dir`` and ``iteration``."""
        if self.inference_dir is None:
            assert self.input_dir is not None, "You must pass an input_dir if you do not specify inference_dir"
            self.inference_dir = self.input_dir / f"inference_{self.iteration}"
        if self.mapper_checkpoint_path is None:
            assert self.input_dir is not None, "You must pass an input_dir if you do not specify mapper_checkpoint_path"
            self.mapper_checkpoint_path = self.input_dir / f"mapper-steps-{self.iteration}.pt"
        if self.learned_embeds_path is None:
            assert self.input_dir is not None, "You must pass an input_dir if you do not specify learned_embeds_path"
            self.learned_embeds_path = self.input_dir / f"learned_embeds-steps-{self.iteration}.bin"

    def _set_prompts(self):
        """Load prompts from ``prompts_file_path`` (one per line) when it is given."""
        if self.prompts_file_path is not None:
            assert self.prompts_file_path.exists(), f"Prompts file {self.prompts_file_path} does not exist!"
            lines = self.prompts_file_path.read_text().splitlines()
            # Skip blank lines (separators, extra trailing newlines): an empty prompt would
            # otherwise be run and saved as `inference_dir / ".png"`.
            self.prompts = [line.strip() for line in lines if line.strip()]

    def _set_truncation_idxs(self):
        """Normalize ``truncation_idxs`` to a list so callers can always iterate over it."""
        if self.truncation_idxs is None:
            # No dropout requested: a single run with truncation_idx=None.
            self.truncation_idxs = [None]
        elif isinstance(self.truncation_idxs, int):
            self.truncation_idxs = [self.truncation_idxs]

    def _set_torch_dtype(self):
        """Convert the ``torch_dtype`` string into a ``torch.dtype`` (idempotent)."""
        if isinstance(self.torch_dtype, torch.dtype):
            return  # already a dtype, e.g. when the config is built programmatically
        self.torch_dtype = torch.float16 if self.torch_dtype == "fp16" else torch.float32
@pyrallis.wrap()
def main(infer_cfg: InferenceConfig):
    """Run NeTI inference for every configured prompt and truncation index.

    Loads the mapper checkpoint and learned placeholder embedding, builds the Stable
    Diffusion pipeline, and for each prompt:

    * saves one image per seed into ``inference_dir / <formatted prompt>/`` (via
      ``run_inference``), and
    * saves the joined image returned by ``run_inference`` into ``inference_dir``,
      suffixed with ``_truncation_<idx>`` when a truncation index is used.

    Args:
        infer_cfg: The inference configuration, supplied by ``pyrallis.wrap``.
    """
    train_cfg, mapper = CheckpointHandler.load_mapper(infer_cfg.mapper_checkpoint_path)
    pipeline, placeholder_token, placeholder_token_id = load_stable_diffusion_model(
        pretrained_model_name_or_path=train_cfg.model.pretrained_model_name_or_path,
        mapper=mapper,
        learned_embeds_path=infer_cfg.learned_embeds_path,
        torch_dtype=infer_cfg.torch_dtype,
    )
    prompt_manager = PromptManager(tokenizer=pipeline.tokenizer,
                                   text_encoder=pipeline.text_encoder,
                                   timesteps=pipeline.scheduler.timesteps,
                                   unet_layers=constants.UNET_LAYERS,
                                   placeholder_token=placeholder_token,
                                   placeholder_token_id=placeholder_token_id,
                                   torch_dtype=infer_cfg.torch_dtype)
    # `truncation_idxs=None` means "no dropout": do a single pass with
    # truncation_idx=None rather than trying to iterate over None.
    truncation_idxs = infer_cfg.truncation_idxs or [None]
    for prompt in infer_cfg.prompts:
        # Fill the prompt template's `{}` slot with the learned placeholder token.
        formatted_prompt = prompt.format(placeholder_token)
        output_path = infer_cfg.inference_dir / formatted_prompt
        output_path.mkdir(exist_ok=True, parents=True)
        for truncation_idx in truncation_idxs:
            print(f"Running with truncation index: {truncation_idx}")
            prompt_image = run_inference(prompt=prompt,
                                         pipeline=pipeline,
                                         prompt_manager=prompt_manager,
                                         seeds=infer_cfg.seeds,
                                         output_path=output_path,
                                         num_images_per_prompt=1,
                                         truncation_idx=truncation_idx)
            if truncation_idx is not None:
                save_name = f"{formatted_prompt}_truncation_{truncation_idx}.png"
            else:
                save_name = f"{formatted_prompt}.png"
            prompt_image.save(infer_cfg.inference_dir / save_name)
def run_inference(prompt: str,
                  pipeline: StableDiffusionPipeline,
                  prompt_manager: PromptManager,
                  seeds: List[int],
                  output_path: Optional[Path] = None,
                  num_images_per_prompt: int = 1,
                  truncation_idx: Optional[int] = None) -> Image.Image:
    """Generate images for ``prompt`` once per seed and return them as a single grid.

    Args:
        prompt: Prompt template, passed as-is to ``prompt_manager.embed_prompt``.
        pipeline: The Stable Diffusion pipeline to sample with.
        prompt_manager: Produces the prompt embeddings (with optional truncation).
        seeds: One generation is run per seed, each with its own CUDA generator.
        output_path: If given, each seed's image is also saved there as
            ``<seed>.png`` or ``<seed>_truncation_<idx>.png``.
        num_images_per_prompt: Number of images generated per seed.
        truncation_idx: Forwarded to ``embed_prompt``; ``None`` means no truncation.

    Returns:
        The grid built by ``vis_utils.get_image_grid`` from the per-seed images.
    """
    # The prompt is embedded once and reused for every seed.
    with torch.autocast("cuda"), torch.no_grad():
        prompt_embeds = prompt_manager.embed_prompt(
            prompt,
            num_images_per_prompt=num_images_per_prompt,
            truncation_idx=truncation_idx,
        )

    per_seed_images = []
    for seed in seeds:
        seeded_generator = torch.Generator(device='cuda').manual_seed(seed)
        result = sd_pipeline_call(
            pipeline,
            prompt_embeds=prompt_embeds,
            generator=seeded_generator,
            num_images_per_prompt=num_images_per_prompt,
        )
        # Stack this seed's images along axis 1 (side by side, assuming HWC arrays).
        stacked = np.concatenate(result.images, axis=1)
        row_image = Image.fromarray(stacked).convert("RGB")
        if output_path is not None:
            if truncation_idx is None:
                file_name = f'{seed}.png'
            else:
                file_name = f'{seed}_truncation_{truncation_idx}.png'
            row_image.save(output_path / file_name)
        per_seed_images.append(row_image)

    return vis_utils.get_image_grid(per_seed_images)
def load_stable_diffusion_model(pretrained_model_name_or_path: str,
                                learned_embeds_path: Path,
                                mapper: Optional[NeTIMapper] = None,
                                num_denoising_steps: int = 50,
                                torch_dtype: torch.dtype = torch.float16) -> Tuple[StableDiffusionPipeline, str, int]:
    """Build a CUDA Stable Diffusion pipeline wired up with the NeTI text encoder.

    Steps:
      1. Load the CLIP tokenizer and ``NeTICLIPTextModel`` from the pretrained model.
      2. Attach ``mapper`` to the text encoder's embeddings, if one is given.
      3. Load the learned placeholder embedding into the tokenizer/text encoder.
      4. Build the pipeline around them and move it to ``"cuda"``.
      5. Swap in a ``DPMSolverMultistepScheduler`` with ``num_denoising_steps``
         timesteps and install ``XTIAttenProc`` as the UNet attention processor.

    Returns:
        ``(pipeline, placeholder_token, placeholder_token_id)``.
    """
    clip_tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
    neti_encoder = NeTICLIPTextModel.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="text_encoder",
        torch_dtype=torch_dtype,
    )
    if mapper is not None:
        neti_encoder.text_model.embeddings.set_mapper(mapper)

    token, token_id = CheckpointHandler.load_learned_embed_in_clip(
        learned_embeds_path=learned_embeds_path,
        text_encoder=neti_encoder,
        tokenizer=clip_tokenizer,
    )

    sd_pipe = StableDiffusionPipeline.from_pretrained(
        pretrained_model_name_or_path,
        torch_dtype=torch_dtype,
        text_encoder=neti_encoder,
        tokenizer=clip_tokenizer,
    )
    sd_pipe = sd_pipe.to("cuda")

    sd_pipe.scheduler = DPMSolverMultistepScheduler.from_config(sd_pipe.scheduler.config)
    sd_pipe.scheduler.set_timesteps(num_denoising_steps, device=sd_pipe.device)
    sd_pipe.unet.set_attn_processor(XTIAttenProc())

    return sd_pipe, token, token_id
if __name__ == "__main__":
    # main() takes no arguments here: the @pyrallis.wrap() decorator builds the
    # InferenceConfig (presumably from command-line arguments) and passes it in.
    main()