neural-ti committed on
Commit
ebb9992
·
1 Parent(s): 841c0f5
README.md CHANGED
@@ -1,11 +1,11 @@
1
  ---
2
  title: NeTI
3
  emoji: πŸƒ
4
- colorFrom: gray
5
- colorTo: blue
6
  sdk: gradio
7
  sdk_version: 3.32.0
8
- app_file: app.py
9
  pinned: false
10
  license: mit
11
  ---
 
1
  ---
2
  title: NeTI
3
  emoji: πŸƒ
4
+ colorFrom: indigo
5
+ colorTo: yellow
6
  sdk: gradio
7
  sdk_version: 3.32.0
8
+ app_file: gradio_app.py
9
  pinned: false
10
  license: mit
11
  ---
gradio_app.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ from pathlib import Path
3
+ from typing import List, Optional
4
+
5
+ import gradio as gr
6
+ import torch
7
+ from PIL import Image
8
+ from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
9
+ from huggingface_hub import snapshot_download
10
+ from transformers import CLIPTokenizer
11
+
12
+ from src import constants
13
+ from src.checkpoint_handler import CheckpointHandler
14
+ from src.models.neti_clip_text_encoder import NeTICLIPTextModel
15
+ from src.models.xti_attention_processor import XTIAttenProc
16
+ from src.prompt_manager import PromptManager
17
+ from src.scripts.inference import run_inference
18
+
19
+ sys.path.append(".")
20
+ sys.path.append("..")
21
+
22
+ DESCRIPTION = '''
23
+ # A Neural Space-Time Representation for Text-to-Image Personalization
24
+ <p style="text-align: center;">
25
+ This is a demo for our <a href="https://arxiv.org/abs/2305.15391">paper</a>: ''A Neural Space-Time Representation
26
+ for Text-to-Image Personalization''.
27
+ <br>
28
+ Project page and code is available <a href="https://neuraltextualinversion.github.io/NeTI/">here</a>.
29
+ <br>
30
+ We introduce a new text-conditioning latent space P* that is dependent on both the denoising process timestep and
31
+ the U-Net layers.
32
+ This space-time representation is learned implicitly via a small mapping network.
33
+ <br>
34
+ Here, you can generate images using one of the concepts trained in our paper. Simply select your concept and
35
+ random seed.
36
+ <br>
37
+ You can also choose different truncation values to play with the reconstruction vs. editability of the concept.
38
+ </p>
39
+ '''
40
+
41
+ CONCEPT_TO_PLACEHOLDER = {
42
+ 'barn': '<barn>',
43
+ 'cat': '<cat>',
44
+ 'clock': '<clock>',
45
+ 'colorful_teapot': '<colorful-teapot>',
46
+ 'dangling_child': '<dangling-child>',
47
+ 'dog': '<dog>',
48
+ 'elephant': '<elephant>',
49
+ 'fat_stone_bird': '<stone-bird>',
50
+ 'headless_statue': '<headless-statue>',
51
+ 'lecun': '<lecun>',
52
+ 'maeve': '<maeve-dog>',
53
+ 'metal_bird': '<metal-bird>',
54
+ 'mugs_skulls': '<mug-skulls>',
55
+ 'rainbow_cat': '<rainbow-cat>',
56
+ 'red_bowl': '<red-bowl>',
57
+ 'teddybear': '<teddybear>',
58
+ 'tortoise_plushy': '<tortoise-plushy>',
59
+ 'wooden_pot': '<wooden-pot>'
60
+ }
61
+
62
+ MODELS_PATH = Path('./trained_models')
63
+ MODELS_PATH.mkdir(parents=True, exist_ok=True)
64
+
65
+
66
def load_stable_diffusion_model(pretrained_model_name_or_path: str,
                                num_denoising_steps: int = 50,
                                torch_dtype: torch.dtype = torch.float16) -> StableDiffusionPipeline:
    """
    Build a Stable Diffusion pipeline whose text encoder is the NeTI CLIP variant.

    The tokenizer and the NeTI text encoder are loaded from the "tokenizer" and "text_encoder"
    sub-folders of `pretrained_model_name_or_path` and handed to the pipeline. The pipeline is moved
    to "cuda", its scheduler is swapped for a DPM-Solver multistep scheduler configured with
    `num_denoising_steps` timesteps, and the U-Net attention processor is replaced by `XTIAttenProc`.

    :param pretrained_model_name_or_path: model identifier or path passed to every `from_pretrained` call.
    :param num_denoising_steps: number of timesteps set on the scheduler.
    :param torch_dtype: dtype used for the text encoder and the pipeline weights.
    :return: the configured pipeline.
    """
    model_id = pretrained_model_name_or_path
    clip_tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
    neti_text_encoder = NeTICLIPTextModel.from_pretrained(
        model_id,
        subfolder="text_encoder",
        torch_dtype=torch_dtype,
    )
    sd_pipeline = StableDiffusionPipeline.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
        text_encoder=neti_text_encoder,
        tokenizer=clip_tokenizer,
    )
    sd_pipeline = sd_pipeline.to("cuda")

    # Replace the default scheduler, reusing its config, then fix the inference timesteps on it.
    sd_pipeline.scheduler = DPMSolverMultistepScheduler.from_config(sd_pipeline.scheduler.config)
    sd_pipeline.scheduler.set_timesteps(num_denoising_steps, device=sd_pipeline.device)

    sd_pipeline.unet.set_attn_processor(XTIAttenProc())
    return sd_pipeline
84
+
85
+
86
def get_possible_concepts() -> List[str]:
    """ Return the name of every sub-directory found directly under MODELS_PATH. """
    return [entry.name for entry in MODELS_PATH.iterdir() if entry.is_dir()]
89
+
90
+
91
def load_sd_and_all_tokens():
    """
    Load the Stable Diffusion pipeline and every trained concept found under MODELS_PATH.

    The trained models are downloaded from the "neural-ti/NeTI" Hub repository into MODELS_PATH.
    For each concept directory, the mapper checkpoint ("<concept>-mapper.pt") is loaded and the
    learned placeholder embedding ("<concept>-learned_embeds.bin") is registered in the pipeline's
    tokenizer and text encoder.

    :return: a tuple `(mappers, pipeline)` where `mappers` maps each concept name to a dict with the
        keys "mapper", "placeholder_token" and "placeholder_token_id".
    """
    pipeline = load_stable_diffusion_model(pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4")

    print("Downloading all models from HF Hub...")
    # Download into MODELS_PATH itself so the directory scanned below is always the one populated here.
    snapshot_download(repo_id="neural-ti/NeTI", local_dir=MODELS_PATH)
    print("Done.")

    mappers = {}
    for concept in get_possible_concepts():
        print(f"Loading model for concept: {concept}")
        concept_dir = MODELS_PATH / concept
        # load_mapper also returns the stored run config; it is not used here.
        _, mapper = CheckpointHandler.load_mapper(mapper_path=concept_dir / f"{concept}-mapper.pt")
        placeholder_token, placeholder_token_id = CheckpointHandler.load_learned_embed_in_clip(
            learned_embeds_path=concept_dir / f"{concept}-learned_embeds.bin",
            text_encoder=pipeline.text_encoder,
            tokenizer=pipeline.tokenizer
        )
        mappers[concept] = {
            "mapper": mapper,
            "placeholder_token": placeholder_token,
            "placeholder_token_id": placeholder_token_id
        }
    return mappers, pipeline
114
+
115
+
116
+ mappers, pipeline = load_sd_and_all_tokens()
117
+
118
+
119
def main_pipeline(concept_name: str,
                  prompt_input: str,
                  seed: int,
                  use_truncation: bool = False,
                  truncation_idx: Optional[int] = None) -> List[Image.Image]:
    """
    Generate one image for the selected concept and return it wrapped in a list.

    :param concept_name: key into the module-level `mappers` dict and `CONCEPT_TO_PLACEHOLDER`.
    :param prompt_input: prompt in which every "*" is replaced by the concept's placeholder string.
    :param seed: random seed; converted with `int()` before being passed to `run_inference`.
    :param use_truncation: when False, `truncation_idx` is ignored and None is passed on.
    :param truncation_idx: truncation index forwarded to `run_inference` when `use_truncation` is set.
    :return: a single-element list holding whatever `run_inference` returns (presumably a PIL image —
        the list form is what the `gr.Gallery` output receives).
    :raises KeyError: if `concept_name` is missing from `mappers` or `CONCEPT_TO_PLACEHOLDER`.
    """
    concept_entry = mappers[concept_name]
    # Activate the mapper of the requested concept on the shared text encoder.
    pipeline.text_encoder.text_model.embeddings.set_mapper(concept_entry["mapper"])
    prompt_manager = PromptManager(tokenizer=pipeline.tokenizer,
                                   text_encoder=pipeline.text_encoder,
                                   timesteps=pipeline.scheduler.timesteps,
                                   unet_layers=constants.UNET_LAYERS,
                                   placeholder_token=concept_entry["placeholder_token"],
                                   placeholder_token_id=concept_entry["placeholder_token_id"],
                                   torch_dtype=torch.float16)

    # The value originates from a UI slider and may arrive as a float; an index must be an int.
    if use_truncation and truncation_idx is not None:
        effective_truncation_idx = int(truncation_idx)
    else:
        effective_truncation_idx = None

    image = run_inference(prompt=prompt_input.replace("*", CONCEPT_TO_PLACEHOLDER[concept_name]),
                          pipeline=pipeline,
                          prompt_manager=prompt_manager,
                          seeds=[int(seed)],
                          num_images_per_prompt=1,
                          truncation_idx=effective_truncation_idx)
    return [image]
141
+
142
+
143
# Build the demo UI: controls on the left, result gallery on the right, cached examples below.
with gr.Blocks(css='style.css') as demo:
    gr.Markdown(DESCRIPTION)

    gr.HTML('''<a href="https://huggingface.co/spaces/neural-ti/NeTI?duplicate=true"><img src="https://bit.ly/3gLdBN6"
    alt="Duplicate Space"></a>''')

    with gr.Row():
        with gr.Column():
            concept = gr.Dropdown(
                get_possible_concepts(),
                multiselect=False,
                label="Concept",
                info="Choose your concept",
            )
            prompt = gr.Textbox(
                label="Input prompt",
                info="Input prompt with placeholder for concept. "
                     "Please use * to specify the concept.",
            )
            random_seed = gr.Number(value=42, label="Random seed", precision=0)
            use_truncation = gr.Checkbox(
                label="Use inference-time dropout",
                info="Whether to use our dropout technique when computing the concept "
                     "embeddings.",
            )
            truncation_idx = gr.Slider(
                8, 128,
                label="Truncation index",
                info="If using truncation, which index to truncate from. Lower numbers tend to "
                     "result in more editable images, but at the cost of reconstruction.",
            )
            run_button = gr.Button('Generate')

        with gr.Column():
            result = gr.Gallery(label='Result')
            # Component order must match the positional parameters of main_pipeline.
            inputs = [concept, prompt, random_seed, use_truncation, truncation_idx]
            outputs = [result]
            run_button.click(fn=main_pipeline, inputs=inputs, outputs=outputs)

    with gr.Row():
        # Each row: concept, prompt, seed, use_truncation, truncation_idx.
        examples = [
            ["maeve", "A photo of * swimming in the ocean", 5196, True, 16],
            ["dangling_child", "A photo of * in Times Square", 3552126062741487430, False, 8],
            ["teddybear", "A photo of * at his graduation ceremony after finishing his PhD", 263, True, 32],
            ["red_bowl", "A * vase filled with flowers", 13491504810502930872, False, 8],
            ["metal_bird", "* in a comic book", 1028, True, 24],
            ["fat_stone_bird", "A movie poster of The Rock, featuring * about on Godzilla", 7393181316156044422, True,
             64],
        ]
        gr.Examples(
            examples=examples,
            inputs=[concept, prompt, random_seed, use_truncation, truncation_idx],
            outputs=[result],
            fn=main_pipeline,
            cache_examples=True,
        )

demo.queue(max_size=50).launch(share=False)
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ opencv-python==4.7.0.72
2
+ matplotlib
3
+ pyrallis==0.3.1
4
+ loguru==0.7.0
5
+ torch==1.13.1
6
+ torchvision==0.14.1
7
+ diffusers==0.14.0
8
+ transformers==4.27.4
9
+ accelerate==0.18.0
10
+ gradio
src/__init__.py ADDED
File without changes
src/checkpoint_handler.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from typing import Tuple
3
+
4
+ import pyrallis
5
+ import torch
6
+ from accelerate import Accelerator
7
+ from torch import nn
8
+ from transformers import CLIPTokenizer
9
+
10
+ from src.models.neti_clip_text_encoder import NeTICLIPTextModel
11
+ from src.models.neti_mapper import NeTIMapper
12
+ from src.models.positional_encoding import NeTIPositionalEncoding, BasicEncoder
13
+ from src.config import RunConfig
14
+
15
+
16
class CheckpointHandler:
    """ Saves and loads the NeTI mapper checkpoint and the placeholder token embedding. """

    def __init__(self, cfg: RunConfig, placeholder_token_string: str, placeholder_token_id: int, save_root: Path):
        """
        :param cfg: run configuration; stored alongside the mapper weights by `save_mapper`.
        :param placeholder_token_string: the placeholder token; used as the key of the saved embedding dict.
        :param placeholder_token_id: row of the input-embedding matrix that holds the placeholder embedding.
        :param save_root: directory that all save names are resolved against.
        """
        self.cfg = cfg
        self.placeholder_token_string = placeholder_token_string
        self.placeholder_token_id = placeholder_token_id
        self.save_root = save_root

    def save_model(self, text_encoder: NeTICLIPTextModel,
                   accelerator: Accelerator,
                   embeds_save_name: str,
                   mapper_save_name: str):
        """ Save both the placeholder embedding and the mapper under `save_root`. """
        self.save_learned_embeds(text_encoder, accelerator, embeds_save_name)
        self.save_mapper(text_encoder, mapper_save_name)

    def save_learned_embeds(self, text_encoder: NeTICLIPTextModel, accelerator: Accelerator, save_name: str):
        """
        Save learned embeddings. This embedding isn't really learned, but we'll add it to the tokenizer at inference
        to take the place of our placeholder token.
        """
        learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[self.placeholder_token_id]
        learned_embeds = learned_embeds.detach().cpu()
        learned_embeds_dict = {self.placeholder_token_string: learned_embeds}
        torch.save(learned_embeds_dict, self.save_root / save_name)

    def save_mapper(self, text_encoder: NeTICLIPTextModel, save_name: str):
        """ Save the mapper and config to be used at inference. """
        mapper = text_encoder.text_model.embeddings.mapper
        cfg_ = RunConfig(**self.cfg.__dict__.copy())
        state_dict = {
            "state_dict": mapper.state_dict(),
            "cfg": pyrallis.encode(cfg_),
            # The encoder object itself is pickled into the checkpoint (see load_mapper).
            "encoder": mapper.encoder
        }
        torch.save(state_dict, self.save_root / save_name)

    @staticmethod
    def load_mapper(mapper_path: Path) -> Tuple[RunConfig, NeTIMapper]:
        """
        Rebuild a mapper from a checkpoint written by `save_mapper` and move it to CUDA in eval mode.

        :param mapper_path: path of the checkpoint file.
        :return: the decoded run config and the restored mapper.
        """
        # NOTE(security): torch.load unpickles arbitrary objects (the "encoder" entry is a pickled module).
        # Only load checkpoints that come from a trusted source.
        mapper_ckpt = torch.load(mapper_path, map_location="cpu")
        cfg = pyrallis.decode(RunConfig, mapper_ckpt['cfg'])
        neti_mapper = NeTIMapper(output_dim=768,
                                 use_nested_dropout=cfg.model.use_nested_dropout,
                                 nested_dropout_prob=cfg.model.nested_dropout_prob,
                                 norm_scale=cfg.model.target_norm,
                                 use_positional_encoding=cfg.model.use_positional_encoding,
                                 num_pe_time_anchors=cfg.model.num_pe_time_anchors,
                                 pe_sigmas=cfg.model.pe_sigmas,
                                 output_bypass=cfg.model.output_bypass)
        neti_mapper.load_state_dict(mapper_ckpt['state_dict'], strict=True)

        # The encoder is restored from the pickled object rather than from the state dict; its tensors
        # were loaded on CPU and are moved to CUDA explicitly here.
        encoder = mapper_ckpt['encoder']
        if isinstance(encoder, NeTIPositionalEncoding):
            encoder.w = nn.Parameter(encoder.w.cuda())
        elif isinstance(encoder, BasicEncoder):
            encoder.normalized_timesteps = encoder.normalized_timesteps.cuda()
            encoder.normalized_unet_layers = encoder.normalized_unet_layers.cuda()
        neti_mapper.encoder = encoder.cuda()
        neti_mapper.cuda()
        neti_mapper.eval()
        return cfg, neti_mapper

    @staticmethod
    def load_learned_embed_in_clip(learned_embeds_path: Path,
                                   text_encoder: NeTICLIPTextModel,
                                   tokenizer: CLIPTokenizer) -> Tuple[str, int]:
        """
        Register the saved placeholder token in the tokenizer and copy its embedding into the text encoder.

        :param learned_embeds_path: file written by `save_learned_embeds` (a dict of token -> embedding).
        :param text_encoder: text encoder whose token-embedding matrix is resized and updated in place.
        :param tokenizer: tokenizer that the placeholder token is added to.
        :return: the placeholder token and its id in the tokenizer.
        :raises ValueError: if the file does not hold exactly one token, or the tokenizer already
            contains that token.
        """
        # NOTE(security): torch.load unpickles the file; only load embeddings from a trusted source.
        loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")

        # separate token and the embeds
        trained_tokens = list(loaded_learned_embeds.keys())
        embeds = list(loaded_learned_embeds.values())

        # Validate before touching the tokenizer or encoder so that a bad file leaves both unchanged.
        if len(trained_tokens) != 1:
            raise ValueError("Only one placeholder token is supported")

        # cast to dtype of text_encoder
        dtype = text_encoder.get_input_embeddings().weight.dtype
        embeds = [e.to(dtype) for e in embeds]

        # add the tokens in tokenizer
        num_added_tokens = tokenizer.add_tokens(trained_tokens)
        if num_added_tokens == 0:
            raise ValueError(f"The tokenizer already contains the token {trained_tokens[0]}. "
                             f"Please pass a different `token` that is not already in the tokenizer.")

        # resize the token embeddings
        text_encoder.resize_token_embeddings(len(tokenizer))

        # get the id for the token and assign the embeds
        placeholder_token_ids = [tokenizer.convert_tokens_to_ids(t) for t in trained_tokens]
        for token_id, embed in zip(placeholder_token_ids, embeds):
            text_encoder.get_input_embeddings().weight.data[token_id] = embed

        return trained_tokens[0], placeholder_token_ids[0]
src/config.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+ from pathlib import Path
3
+ from typing import List, Optional, Dict
4
+
5
+ from constants import VALIDATION_PROMPTS
6
+ from utils.types import PESigmas
7
+
8
+
9
@dataclass
class LogConfig:
    """ Parameters for logging and saving """
    # Name of experiment. This will be the name of the output folder. Required: it has no default.
    exp_name: str
    # The output directory where the model predictions and checkpoints will be written
    exp_dir: Path = Path("./outputs")
    # Save interval (presumably in training steps - confirm against the trainer)
    save_steps: int = 250
    # [TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Defaults to "logs".
    logging_dir: Path = Path("logs")
    # The integration to report the results to. Supported platforms are "tensorboard"
    # (default), "wandb" and "comet_ml". Use "all" to report to all integrations.
    report_to: str = "tensorboard"
    # Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator`
    checkpoints_total_limit: Optional[int] = None
26
+
27
+
28
@dataclass
class DataConfig:
    """ Parameters for data """
    # A folder containing the training data. Required: it has no default.
    train_data_dir: Path
    # A token to use as a placeholder for the concept. Required: it has no default.
    placeholder_token: str
    # Super category token to use for normalizing the mapper output
    super_category_token: Optional[str] = "object"
    # Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process
    dataloader_num_workers: int = 8
    # Choose between 'object' and 'style' - used for selecting the prompts for training
    learnable_property: str = "object"
    # How many times to repeat the training data
    repeats: int = 100
    # The resolution for input images, all the images in the train/validation dataset will be resized to this resolution
    resolution: int = 512
    # Whether to center crop images before resizing to resolution
    center_crop: bool = False
47
+
48
+
49
@dataclass
class ModelConfig:
    """ Parameters for defining all models """
    # Path to pretrained model or model identifier from huggingface.co/models
    pretrained_model_name_or_path: str = "CompVis/stable-diffusion-v1-4"
    # Whether to use our Nested Dropout technique
    use_nested_dropout: bool = True
    # Probability to apply nested dropout during training
    nested_dropout_prob: float = 0.5
    # Whether to normalize the norm of the mapper's output vector
    normalize_mapper_output: bool = True
    # Target norm for the mapper's output vector
    target_norm: Optional[float] = None
    # Whether to use positional encoding over the input to the mapper
    use_positional_encoding: bool = True
    # Sigmas used for computing positional encoding. Given as a dict with the keys 'sigma_t' and 'sigma_l';
    # converted to a PESigmas instance in __post_init__.
    pe_sigmas: Dict[str, float] = field(default_factory=lambda: {'sigma_t': 0.03, 'sigma_l': 2.0})
    # Number of time anchors for computing our positional encodings
    num_pe_time_anchors: int = 10
    # Whether to output the textual bypass vector
    output_bypass: bool = True
    # Revision of pretrained model identifier from huggingface.co/models
    revision: Optional[str] = None
    # Path to a mapper checkpoint (presumably one to resume training from - confirm against the trainer).
    # None means no checkpoint is given.
    mapper_checkpoint_path: Optional[Path] = None

    def __post_init__(self):
        """ Validate `pe_sigmas` and replace the dict by a PESigmas instance (skipped when it is None). """
        if self.pe_sigmas is not None:
            assert len(self.pe_sigmas) == 2, "Should provide exactly two sigma values: one for time and one for layers!"
            self.pe_sigmas = PESigmas(sigma_t=self.pe_sigmas['sigma_t'], sigma_l=self.pe_sigmas['sigma_l'])
79
+
80
+
81
@dataclass
class EvalConfig:
    """ Parameters for validation """
    # A list of prompts that will be used during validation to verify that the model is learning.
    # The default is a copy of VALIDATION_PROMPTS, so editing one config never mutates the shared constant.
    validation_prompts: List[str] = field(default_factory=lambda: list(VALIDATION_PROMPTS))
    # Number of images that should be generated during validation with `validation_prompt`
    num_validation_images: int = 4
    # Seeds to use for generating the validation images. None means "use 0 .. num_validation_images - 1".
    validation_seeds: Optional[List[int]] = field(default_factory=lambda: [42, 420, 501, 5456])
    # Run validation every X steps.
    validation_steps: int = 100
    # Number of denoising steps
    num_denoising_steps: int = 50

    def __post_init__(self):
        """ Fill in default seeds when none are given and check there is one seed per validation image. """
        if self.validation_seeds is None:
            self.validation_seeds = list(range(self.num_validation_images))
        assert len(self.validation_seeds) == self.num_validation_images, \
            "Length of validation_seeds should equal num_validation_images"
100
+
101
@dataclass
class OptimConfig:
    """ Parameters for the optimization process """
    # Total number of training steps to perform.
    max_train_steps: Optional[int] = 1_000
    # Learning rate
    learning_rate: float = 1e-3
    # Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size
    scale_lr: bool = True
    # Batch size (per device) for the training dataloader
    train_batch_size: int = 2
    # Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass
    gradient_checkpointing: bool = False
    # Number of update steps to accumulate before performing a backward/update pass
    gradient_accumulation_steps: int = 4
    # A seed for reproducible training
    seed: Optional[int] = None
    # The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",
    # "constant", "constant_with_warmup"]
    lr_scheduler: str = "constant"
    # Number of steps for the warmup in the lr scheduler
    lr_warmup_steps: int = 0
    # The beta1 parameter for the Adam optimizer
    adam_beta1: float = 0.9
    # The beta2 parameter for the Adam optimizer
    adam_beta2: float = 0.999
    # Weight decay to use
    adam_weight_decay: float = 1e-2
    # Epsilon value for the Adam optimizer
    adam_epsilon: float = 1e-08
    # Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10.
    # and an Nvidia Ampere GPU. The default "no" presumably disables mixed precision - confirm against the trainer.
    mixed_precision: str = "no"
    # Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see
    # https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    allow_tf32: bool = False
137
+
138
+
139
@dataclass
class RunConfig:
    """ The main configuration for the coach trainer """
    # Logging and saving parameters. NOTE: LogConfig declares `exp_name` without a default, so this
    # default factory cannot be called without arguments; a value has to be supplied.
    log: LogConfig = field(default_factory=LogConfig)
    # Data parameters. NOTE: DataConfig declares `train_data_dir` and `placeholder_token` without
    # defaults, so a value has to be supplied here as well.
    data: DataConfig = field(default_factory=DataConfig)
    # Model parameters
    model: ModelConfig = field(default_factory=ModelConfig)
    # Validation parameters
    eval: EvalConfig = field(default_factory=EvalConfig)
    # Optimization parameters
    optim: OptimConfig = field(default_factory=OptimConfig)
src/constants.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ UNET_LAYERS = ['IN01', 'IN02', 'IN04', 'IN05', 'IN07', 'IN08', 'MID',
2
+ 'OUT03', 'OUT04', 'OUT05', 'OUT06', 'OUT07', 'OUT08', 'OUT09', 'OUT10', 'OUT11']
3
+
4
+ SD_INFERENCE_TIMESTEPS = [999, 979, 959, 939, 919, 899, 879, 859, 839, 819, 799, 779, 759, 739, 719, 699, 679, 659,
5
+ 639, 619, 599, 579, 559, 539, 519, 500, 480, 460, 440, 420, 400, 380, 360, 340, 320, 300,
6
+ 280, 260, 240, 220, 200, 180, 160, 140, 120, 100, 80, 60, 40, 20]
7
+
8
+ PROMPTS = [
9
+ "A photo of a {}",
10
+ "A photo of {} in the jungle",
11
+ "A photo of {} on a beach",
12
+ "A photo of {} in Times Square",
13
+ "A photo of {} in the moon",
14
+ "A painting of {} in the style of Monet",
15
+ "Oil painting of {}",
16
+ "A Marc Chagall painting of {}",
17
+ "A manga drawing of {}",
18
+ 'A watercolor painting of {}',
19
+ "A statue of {}",
20
+ "App icon of {}",
21
+ "A sand sculpture of {}",
22
+ "Colorful graffiti of {}",
23
+ "A photograph of two {} on a table",
24
+ ]
25
+
26
+ VALIDATION_PROMPTS = [
27
+ "A photo of a {}",
28
+ "A photo of a {} on a beach",
29
+ "App icon of {}",
30
+ "A painting of {} in the style of Monet",
31
+ ]
32
+
33
+ IMAGENET_TEMPLATES_SMALL = [
34
+ "a photo of a {}",
35
+ "a rendering of a {}",
36
+ "a cropped photo of the {}",
37
+ "the photo of a {}",
38
+ "a photo of a clean {}",
39
+ "a photo of a dirty {}",
40
+ "a dark photo of the {}",
41
+ "a photo of my {}",
42
+ "a photo of the cool {}",
43
+ "a close-up photo of a {}",
44
+ "a bright photo of the {}",
45
+ "a cropped photo of a {}",
46
+ "a photo of the {}",
47
+ "a good photo of the {}",
48
+ "a photo of one {}",
49
+ "a close-up photo of the {}",
50
+ "a rendition of the {}",
51
+ "a photo of the clean {}",
52
+ "a rendition of a {}",
53
+ "a photo of a nice {}",
54
+ "a good photo of a {}",
55
+ "a photo of the nice {}",
56
+ "a photo of the small {}",
57
+ "a photo of the weird {}",
58
+ "a photo of the large {}",
59
+ "a photo of a cool {}",
60
+ "a photo of a small {}",
61
+ ]
62
+
63
+ IMAGENET_STYLE_TEMPLATES_SMALL = [
64
+ "a painting in the style of {}",
65
+ "a rendering in the style of {}",
66
+ "a cropped painting in the style of {}",
67
+ "the painting in the style of {}",
68
+ "a clean painting in the style of {}",
69
+ "a dirty painting in the style of {}",
70
+ "a dark painting in the style of {}",
71
+ "a picture in the style of {}",
72
+ "a cool painting in the style of {}",
73
+ "a close-up painting in the style of {}",
74
+ "a bright painting in the style of {}",
75
+ "a cropped painting in the style of {}",
76
+ "a good painting in the style of {}",
77
+ "a close-up painting in the style of {}",
78
+ "a rendition in the style of {}",
79
+ "a nice painting in the style of {}",
80
+ "a small painting in the style of {}",
81
+ "a weird painting in the style of {}",
82
+ "a large painting in the style of {}",
83
+ ]
src/models/__init__.py ADDED
File without changes
src/models/net_clip_text_embedding.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Tuple
2
+
3
+ import torch
4
+ from torch import nn
5
+ from transformers import CLIPTextConfig
6
+
7
+ from src.models.neti_mapper import NeTIMapper
8
+ from src.utils.types import NeTIBatch
9
+
10
+
11
class NeTICLIPTextEmbeddings(nn.Module):
    """ CLIP text embedding layer in which a NeTIMapper output can replace the placeholder token embedding. """

    def __init__(self, config: CLIPTextConfig):
        super().__init__()
        hidden_size = config.hidden_size
        max_positions = config.max_position_embeddings
        self.token_embedding = nn.Embedding(config.vocab_size, hidden_size)
        self.position_embedding = nn.Embedding(max_positions, hidden_size)
        self.register_buffer("position_ids", torch.arange(max_positions).expand((1, -1)))

    def set_mapper(self, mapper: NeTIMapper):
        """ Attach the mapper that `forward` calls whenever a `batch` is given. """
        self.mapper = mapper

    def forward(self, input_ids: Optional[torch.LongTensor] = None,
                position_ids: Optional[torch.LongTensor] = None,
                inputs_embeds: Optional[torch.FloatTensor] = None,
                batch: Optional[NeTIBatch] = None) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Embed the tokens, overwriting the placeholder position with the mapper output when `batch` is given.

        :return: `(embeddings, bypass_outputs)`; `bypass_outputs` is None unless a batch was given and the
            mapper has `output_bypass` set.
        """
        has_batch = batch is not None
        if has_batch:
            # The batch carries its own token ids and takes precedence over `input_ids`.
            input_ids = batch.input_ids

        if input_ids is not None:
            seq_length = input_ids.shape[-1]
        else:
            seq_length = inputs_embeds.shape[-2]

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        if inputs_embeds is None:
            inputs_embeds = self.token_embedding(input_ids)

        ####################################################################
        # NeTI logic - Use mapper to overwrite the learnable token embedding
        ####################################################################
        bypass_outputs = None
        if has_batch:
            mapped = self.mapper(timestep=batch.timesteps.float(),
                                 unet_layer=batch.unet_layers.float(),
                                 truncation_idx=batch.truncation_idx)
            mapped = mapped.to(dtype=inputs_embeds.dtype, device=inputs_embeds.device)
            if self.mapper.output_bypass:
                # First half of the mapper output is the token embedding, second half the bypass vector.
                half = mapped.shape[1] // 2
                bypass_outputs = mapped[:, half:]
                mapped = mapped[:, :half]

            # Overwrite the index of the placeholder token with the mapper output for each entry in the batch
            placeholder_positions = (input_ids == batch.placeholder_token_id).nonzero(as_tuple=True)[1]
            row_indices = torch.arange(input_ids.shape[0])
            inputs_embeds[row_indices, placeholder_positions] = mapped

        embeddings = inputs_embeds + self.position_embedding(position_ids)
        return embeddings, bypass_outputs
src/models/neti_clip_text_encoder.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Tuple, Union
2
+
3
+ import torch
4
+ import torch.utils.checkpoint
5
+ from torch import nn
6
+ from transformers.modeling_outputs import BaseModelOutputWithPooling
7
+ from transformers.models.clip.modeling_clip import CLIPTextConfig, CLIPTextModel, CLIPEncoder
8
+ from transformers.models.clip.modeling_clip import CLIPTextTransformer, _expand_mask
9
+
10
+ from src.models.net_clip_text_embedding import NeTICLIPTextEmbeddings
11
+ from src.utils.types import NeTIBatch
12
+
13
+
14
class NeTICLIPTextModel(CLIPTextModel):
    """ CLIPTextModel whose inner text transformer is swapped for the NeTI-aware NeTICLIPTextTransformer. """

    def __init__(self, config: CLIPTextConfig):
        super().__init__(config)
        # Replace the transformer built by the parent constructor, then re-run post-initialisation.
        self.text_model = NeTICLIPTextTransformer(config)
        self.post_init()

    def forward(self, input_ids: Optional[torch.Tensor] = None,
                attention_mask: Optional[torch.Tensor] = None,
                position_ids: Optional[torch.Tensor] = None,
                output_attentions: Optional[bool] = None,
                output_hidden_states: Optional[bool] = None,
                return_dict: Optional[bool] = None,
                batch: Optional[NeTIBatch] = None) -> Union[Tuple, BaseModelOutputWithPooling]:
        """ Forward every argument unchanged to the inner text transformer and return its result. """
        transformer_kwargs = dict(
            batch=batch,
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        return self.text_model.forward(**transformer_kwargs)
38
+
39
+
40
class NeTICLIPTextTransformer(CLIPTextTransformer):
    """
    Modification of CLIPTextTransformer to use our NeTI mapper for computing the embeddings of the concept.

    `forward` always returns a 2-tuple: the regular output and, when the embedding layer produced a
    bypass vector, a second output whose placeholder-token state had the scaled bypass added before
    the final layer norm. The second element is None otherwise.
    """

    def __init__(self, config: CLIPTextConfig):
        super().__init__(config=config)
        self.config = config
        embed_dim = config.hidden_size
        # Overwrite the sub-modules created by the parent constructor; the embedding layer is the NeTI variant.
        self.embeddings = NeTICLIPTextEmbeddings(config)
        self.encoder = CLIPEncoder(config)
        self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

    def forward(self, input_ids: Optional[torch.Tensor] = None,
                attention_mask: Optional[torch.Tensor] = None,
                position_ids: Optional[torch.Tensor] = None,
                output_attentions: Optional[bool] = None,
                output_hidden_states: Optional[bool] = None,
                return_dict: Optional[bool] = None,
                batch: Optional[NeTIBatch] = None) -> Union[Tuple, BaseModelOutputWithPooling]:
        """
        Encode either plain `input_ids` or a NeTI `batch`; `input_ids` takes precedence when both are given.

        :return: `(output, output_with_bypass)`; `output_with_bypass` is None when no bypass vector exists.
        :raises ValueError: if neither `input_ids` nor `batch` is given.
        """

        # Fall back to the config values for every flag that was not passed explicitly.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        bypass_output = None

        if input_ids is not None:  # Regular embedding logic
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
            # No batch is passed here, so the mapper is not used and the second return value is discarded.
            hidden_states, _ = self.embeddings(input_ids=input_ids, position_ids=position_ids)

        ###########################
        # NeTI logic
        ###########################
        elif batch is not None:
            input_shape = batch.input_ids.size()
            # NOTE: this rebinds `batch.input_ids` on the caller's batch object.
            batch.input_ids = batch.input_ids.view(-1, input_shape[-1])
            hidden_states, bypass_output = self.embeddings(batch=batch, position_ids=position_ids)

        else:
            raise ValueError("You have to specify either batch or input_ids!")

        # The unpacking assumes the ids are 2-D (batch, sequence).
        bsz, seq_len = input_shape
        # CLIP's text model uses causal mask, prepare it here.
        # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
        causal_attention_mask = self._build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to(
            hidden_states.device
        )

        # expand attention_mask
        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            attention_mask = _expand_mask(attention_mask, hidden_states.dtype)

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[0]
        # Work on a copy so the regular output stays free of the bypass.
        last_hidden_state_with_bypass = last_hidden_state.clone()

        ###############################################
        # NeTI logic - compute the scaled bypass output
        ###############################################
        if bypass_output is not None:
            # Column index of the placeholder token in each row.
            learnable_idxs = (batch.input_ids == batch.placeholder_token_id).nonzero(as_tuple=True)[1]
            existing_state = last_hidden_state_with_bypass[torch.arange(last_hidden_state.shape[0]), learnable_idxs]
            # Rescale the bypass vector to the norm of the encoder state it is added to.
            bypass_output = bypass_output / bypass_output.norm(dim=1, keepdim=True) \
                            * existing_state.norm(dim=1, keepdim=True)
            # The rescaled bypass is blended in with a fixed weight of 0.2.
            new_state = existing_state + 0.2 * bypass_output
            new_state = new_state.to(dtype=hidden_states.dtype)
            last_hidden_state_with_bypass[torch.arange(last_hidden_state.shape[0]), learnable_idxs] = new_state

        # The same final layer norm is applied to both variants.
        last_hidden_state = self.final_layer_norm(last_hidden_state)
        last_hidden_state_with_bypass = self.final_layer_norm(last_hidden_state_with_bypass)

        # text_embeds.shape = [batch_size, sequence_length, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
        if input_ids is not None:
            pooled_output = last_hidden_state[
                torch.arange(last_hidden_state.shape[0]), input_ids.to(torch.int).argmax(dim=-1)
            ]
            pooled_output_with_bypass = last_hidden_state_with_bypass[
                torch.arange(last_hidden_state_with_bypass.shape[0]), input_ids.to(torch.int).argmax(dim=-1)
            ]
        elif batch is not None:
            pooled_output = last_hidden_state[
                torch.arange(last_hidden_state.shape[0]), batch.input_ids.to(torch.int).argmax(dim=-1)
            ]
            pooled_output_with_bypass = last_hidden_state_with_bypass[
                torch.arange(last_hidden_state_with_bypass.shape[0]), batch.input_ids.to(torch.int).argmax(dim=-1)
            ]
        else:
            raise ValueError("You have to specify either batch or input_ids!")

        # NOTE(review): `encoder_outputs.hidden_states` / `.attentions` are attribute accesses, which
        # presumably requires the encoder to have returned a dict-style output (return_dict=True) - confirm.
        if bypass_output is not None:
            return BaseModelOutputWithPooling(
                last_hidden_state=last_hidden_state,
                pooler_output=pooled_output,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            ), BaseModelOutputWithPooling(
                last_hidden_state=last_hidden_state_with_bypass,
                pooler_output=pooled_output_with_bypass,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )
        else:
            return BaseModelOutputWithPooling(
                last_hidden_state=last_hidden_state,
                pooler_output=pooled_output,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            ), None
src/models/neti_mapper.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ from typing import Optional, List
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from torch import nn
7
+
8
+ from src.constants import UNET_LAYERS
9
+ from src.models.positional_encoding import NeTIPositionalEncoding, BasicEncoder
10
+ from src.utils.types import PESigmas
11
+
12
+
13
class NeTIMapper(nn.Module):
    """ Main logic of our NeTI mapper.

    Encodes a (timestep, U-Net layer) pair with `self.encoder`, passes the encoding through a small MLP
    (`self.net`), optionally zeroes the tail of the hidden vector (nested dropout / truncation), and projects
    the result with `self.output_layer`.
    """

    def __init__(self, output_dim: int = 768,
                 unet_layers: List[str] = UNET_LAYERS,
                 use_nested_dropout: bool = True,
                 nested_dropout_prob: float = 0.5,
                 norm_scale: Optional[torch.Tensor] = None,
                 use_positional_encoding: bool = True,
                 num_pe_time_anchors: int = 10,
                 pe_sigmas: PESigmas = PESigmas(sigma_t=0.03, sigma_l=2.0),
                 output_bypass: bool = True):
        """
        :param output_dim: Size of a single output vector (doubled internally when `output_bypass` is set).
        :param unet_layers: U-Net layer names; only the number of entries is used here.
        :param use_nested_dropout: Whether `forward` applies `apply_nested_dropout`.
        :param nested_dropout_prob: Probability of applying nested dropout to a training batch.
        :param norm_scale: If given, the output is L2-normalized and multiplied by this value.
        :param use_positional_encoding: Use `NeTIPositionalEncoding` (True) or `BasicEncoder` (False).
        :param num_pe_time_anchors: Number of time anchors used to size/initialize the input layer.
        :param pe_sigmas: Sigmas forwarded to `NeTIPositionalEncoding`.
        :param output_bypass: Whether the output layer emits two vectors concatenated along the last dim.
        """
        super().__init__()
        self.use_nested_dropout = use_nested_dropout
        self.nested_dropout_prob = nested_dropout_prob
        self.norm_scale = norm_scale
        self.output_bypass = output_bypass
        if self.output_bypass:
            output_dim *= 2  # Output two vectors

        self.use_positional_encoding = use_positional_encoding
        if self.use_positional_encoding:
            encoder = NeTIPositionalEncoding(sigma_t=pe_sigmas.sigma_t, sigma_l=pe_sigmas.sigma_l)
            self.input_dim = num_pe_time_anchors * len(unet_layers)
        else:
            encoder = BasicEncoder()
            self.input_dim = 2
        # Same placement as before whenever CUDA exists; the guard only avoids crashing on CPU-only hosts.
        self.encoder = encoder.cuda() if torch.cuda.is_available() else encoder

        self.set_net(num_unet_layers=len(unet_layers),
                     num_time_anchors=num_pe_time_anchors,
                     output_dim=output_dim)

    def set_net(self, num_unet_layers: int, num_time_anchors: int, output_dim: int = 768):
        """ Builds the input layer, the two hidden 128-d blocks and the final linear output layer. """
        # NOTE: the input layer is registered both as `self.input_layer` and as `self.net[0]`; keep it that way
        # so that the state_dict keys stay unchanged.
        self.input_layer = self.set_input_layer(num_unet_layers, num_time_anchors)
        self.net = nn.Sequential(self.input_layer,
                                 nn.Linear(self.input_dim, 128), nn.LayerNorm(128), nn.LeakyReLU(),
                                 nn.Linear(128, 128), nn.LayerNorm(128), nn.LeakyReLU())
        self.output_layer = nn.Sequential(nn.Linear(128, output_dim))

    def set_input_layer(self, num_unet_layers: int, num_time_anchors: int) -> nn.Module:
        """ Returns a linear layer initialized from the encoder's anchors, or an identity without PE. """
        if self.use_positional_encoding:
            input_layer = nn.Linear(self.encoder.num_w * 2, self.input_dim)
            input_layer.weight.data = self.encoder.init_layer(num_time_anchors, num_unet_layers)
        else:
            input_layer = nn.Identity()
        return input_layer

    def forward(self, timestep: torch.Tensor, unet_layer: torch.Tensor,
                truncation_idx: Optional[int] = None) -> torch.Tensor:
        """ Computes the mapper output for the given timestep(s) and U-Net layer(s). """
        embedding = self.extract_hidden_representation(timestep, unet_layer)
        if self.use_nested_dropout:
            embedding = self.apply_nested_dropout(embedding, truncation_idx=truncation_idx)
        embedding = self.get_output(embedding)
        return embedding

    def get_encoded_input(self, timestep: torch.Tensor, unet_layer: torch.Tensor) -> torch.Tensor:
        """ Encodes the (timestep, layer) input with the configured encoder. """
        return self.encoder.encode(timestep, unet_layer)

    def extract_hidden_representation(self, timestep: torch.Tensor, unet_layer: torch.Tensor) -> torch.Tensor:
        """ Returns the hidden representation produced by `self.net` (before dropout and output layer). """
        encoded_input = self.get_encoded_input(timestep, unet_layer)
        embedding = self.net(encoded_input)
        return embedding

    def apply_nested_dropout(self, embedding: torch.Tensor, truncation_idx: Optional[int] = None) -> torch.Tensor:
        """ Zeroes the tail of each row of `embedding` in place and returns it.

        In training mode, with probability `nested_dropout_prob`, every row gets its own random cut index and
        all entries from that index on are zeroed. In eval mode, entries from `truncation_idx` on are zeroed
        when an index is given; otherwise the embedding is returned untouched.
        """
        if self.training:
            if random.random() < self.nested_dropout_prob:
                # Sampled on the CPU (default generator), exactly as before, then moved next to the embedding.
                dropout_idxs = torch.randint(low=0, high=embedding.shape[1], size=(embedding.shape[0],))
                positions = torch.arange(embedding.shape[1], device=embedding.device)
                drop_mask = positions.unsqueeze(0) >= dropout_idxs.to(embedding.device).unsqueeze(1)
                embedding[drop_mask] = 0
        elif truncation_idx is not None:
            embedding[:, truncation_idx:] = 0
        return embedding

    def get_output(self, embedding: torch.Tensor) -> torch.Tensor:
        """ Applies the output layer and, if `norm_scale` is set, rescales the result to that norm. """
        embedding = self.output_layer(embedding)
        if self.norm_scale is not None:
            embedding = F.normalize(embedding, dim=-1) * self.norm_scale
        return embedding
src/models/positional_encoding.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Union
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
+
7
class NeTIPositionalEncoding(nn.Module):
    """ Random Fourier-feature encoding of a (timestep, U-Net layer) pair.

    `self.w` holds `num_w` random 2-d frequencies whose first column is scaled by `sigma_t` and whose second
    column is scaled by `sigma_l`. It is deliberately stored as a plain tensor attribute (not a registered
    parameter or buffer), so it does not appear in `state_dict()`.
    """

    def __init__(self, sigma_t: float, sigma_l: float, num_w: int = 1024):
        super().__init__()
        self.sigma_t = sigma_t
        self.sigma_l = sigma_l
        self.num_w = num_w
        w = torch.randn((num_w, 2))
        w[:, 0] *= sigma_t
        w[:, 1] *= sigma_l
        # With CUDA available this is the same expression as before; without CUDA we keep the CPU tensor
        # instead of crashing. Either way `w` stays an unregistered attribute.
        self.w = nn.Parameter(w).cuda() if torch.cuda.is_available() else w

    def encode(self, t: Union[int, torch.Tensor], l: Union[int, torch.Tensor]):
        """ Maps the given time and layer input into a unit-norm vector of size `2 * num_w`.

        Scalar inputs (Python ints or 0-dim tensors) return a 1-d vector; batched 1-d tensors of length B
        return a `(B, 2 * num_w)` matrix with one unit-norm row per input pair.
        """
        # Decide once whether we are on the scalar path so both branches below agree (a 0-dim tensor used to
        # take the scalar branch here but the batched normalization branch further down).
        is_scalar = isinstance(t, int) or t.ndim == 0
        if is_scalar:
            x = torch.tensor([t, l]).float()
        else:
            x = torch.stack([t, l], dim=1).T
        x = x.to(self.w.device)
        projection = self.w.detach() @ x
        v = torch.cat([torch.sin(projection), torch.cos(projection)])
        if is_scalar:
            return v / v.norm()
        return (v / v.norm(dim=0)).T

    def init_layer(self, num_time_anchors: int, num_layers: int) -> torch.Tensor:
        """ Computes the initial weights for the positional encoding layer.

        Returns a `(num_time_anchors * num_layers, 2 * num_w)` matrix whose rows are the encodings of evenly
        spaced time anchors in [0, 1000) combined with every layer index.
        """
        step = 1000 // num_time_anchors
        anchor_vectors = []
        # Stop after exactly `num_time_anchors` anchors, so the row count matches the layer's output size even
        # when `num_time_anchors` does not divide 1000.
        for t_anchor in range(0, step * num_time_anchors, step):
            for l_anchor in range(num_layers):
                anchor_vectors.append(self.encode(t_anchor, l_anchor).float())
        return torch.stack(anchor_vectors)
42
+
43
+
44
class BasicEncoder(nn.Module):
    """ Simply normalizes the given timestep and unet layer to be between -1 and 1. """

    def __init__(self, num_denoising_timesteps: int = 1000, num_unet_layers: int = 16):
        super().__init__()
        # Lookup tables: index i maps linearly to [-1, 1] over the valid index range.
        normalized_timesteps = (torch.arange(num_denoising_timesteps) / (num_denoising_timesteps - 1)) * 2 - 1
        normalized_unet_layers = (torch.arange(num_unet_layers) / (num_unet_layers - 1)) * 2 - 1
        if torch.cuda.is_available():
            # Same expressions as before. The results are plain tensors, so the tables are not registered
            # as module parameters.
            normalized_timesteps = nn.Parameter(normalized_timesteps).cuda()
            normalized_unet_layers = nn.Parameter(normalized_unet_layers).cuda()
        # Without CUDA the tables simply stay on the CPU instead of raising at construction time.
        self.normalized_timesteps = normalized_timesteps
        self.normalized_unet_layers = normalized_unet_layers

    def encode(self, timestep: torch.Tensor, unet_layer: torch.Tensor) -> torch.Tensor:
        """ Looks up the normalized timestep and layer values and returns them stacked as `(B, 2)`. """
        normalized_input = torch.stack([self.normalized_timesteps[timestep.long()],
                                        self.normalized_unet_layers[unet_layer.long()]]).T
        return normalized_input
src/models/xti_attention_processor.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Optional
2
+
3
+ import torch
4
+ from diffusers.models.cross_attention import CrossAttention
5
+
6
+
7
class XTIAttenProc:
    """ Attention processor that accepts a per-layer dictionary of conditioning tensors.

    When `encoder_hidden_states` is a dict, it is expected to hold a running counter under "this_idx" and
    tensors under "CONTEXT_TENSOR_{idx}" (and optionally "CONTEXT_TENSOR_BYPASS_{idx}"). Each call consumes
    the entry for the current counter value and advances the counter in place.
    """

    def __call__(self, attn: CrossAttention,
                 hidden_states: torch.Tensor,
                 encoder_hidden_states: Optional[Dict[str, torch.Tensor]] = None,
                 attention_mask: Optional[torch.Tensor] = None):
        """ Runs attention: keys come from the context tensor, values from the bypass tensor when present.

        :param attn: The attention module providing the projections and helper methods.
        :param hidden_states: Query-side hidden states.
        :param encoder_hidden_states: None (self-attention), a plain tensor, or the per-layer dict above.
        :param attention_mask: Optional attention mask passed to `attn.prepare_attention_mask`.
        :return: The attended hidden states after the output projection and dropout.
        """
        _ehs = None
        _ehs_bypass = None
        if isinstance(encoder_hidden_states, dict):
            this_idx = encoder_hidden_states["this_idx"]
            _ehs = encoder_hidden_states[f"CONTEXT_TENSOR_{this_idx}"]
            _ehs_bypass = encoder_hidden_states.get(f"CONTEXT_TENSOR_BYPASS_{this_idx}")
            # The dict is mutated on purpose: the counter tells the next call which entry to use.
            # NOTE(review): 16 presumably matches the number of conditioned U-Net layers — confirm.
            encoder_hidden_states["this_idx"] = (this_idx + 1) % 16
        elif encoder_hidden_states is not None:
            _ehs = encoder_hidden_states

        batch_size, sequence_length, _ = (hidden_states.shape if _ehs is None else _ehs.shape)
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
        query = attn.to_q(hidden_states)

        if _ehs is None:
            _ehs = hidden_states
        elif attn.cross_attention_norm:
            _ehs = attn.norm_cross(_ehs)
            # The bypass tensor is optional; normalizing None used to crash here.
            if _ehs_bypass is not None:
                _ehs_bypass = attn.norm_cross(_ehs_bypass)

        key = attn.to_k(_ehs)
        if _ehs_bypass is not None:
            value = attn.to_v(_ehs_bypass)
        else:
            value = attn.to_v(_ehs)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states
src/prompt_manager.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, List, Dict, Any
2
+
3
+ import torch
4
+ from tqdm import tqdm
5
+ from transformers import CLIPTokenizer
6
+
7
+ from src import constants
8
+ from src.models.neti_clip_text_encoder import NeTICLIPTextModel
9
+ from src.utils.types import NeTIBatch
10
+
11
+
12
class PromptManager:
    """ Class for computing all time and space embeddings for a given prompt. """

    def __init__(self, tokenizer: CLIPTokenizer,
                 text_encoder: NeTICLIPTextModel,
                 timesteps: List[int] = constants.SD_INFERENCE_TIMESTEPS,
                 unet_layers: List[str] = constants.UNET_LAYERS,
                 placeholder_token_id: Optional[List] = None,
                 placeholder_token: Optional[List] = None,
                 torch_dtype: torch.dtype = torch.float32):
        """
        :param tokenizer: Tokenizer used to turn the prompt into input ids.
        :param text_encoder: Text encoder called once per (timestep, layer) pair.
        :param timesteps: Denoising timesteps to embed (ints or 0-dim tensors).
        :param unet_layers: U-Net layers to embed; their positions in this list are used as layer indices.
        :param placeholder_token_id: Id of the placeholder token, forwarded to the text encoder.
        :param placeholder_token: Placeholder string substituted for `{}` in prompts.
        :param torch_dtype: Dtype the resulting hidden states are cast to.
        """
        self.tokenizer = tokenizer
        self.text_encoder = text_encoder
        self.timesteps = timesteps
        self.unet_layers = unet_layers
        self.placeholder_token = placeholder_token
        self.placeholder_token_id = placeholder_token_id
        self.dtype = torch_dtype

    def embed_prompt(self, text: str,
                     truncation_idx: Optional[int] = None,
                     num_images_per_prompt: int = 1) -> List[Dict[str, Any]]:
        """
        Compute the conditioning vectors for the given prompt. We assume that the prompt is defined using `{}`
        for indicating where to place the placeholder token string. See constants.VALIDATION_PROMPTS for examples.

        Returns one dict per timestep. Each dict holds "this_idx" (initialized to 0), a "CONTEXT_TENSOR_{i}"
        entry per U-Net layer and, when the encoder returns a bypass output, "CONTEXT_TENSOR_BYPASS_{i}".
        """
        text = text.format(self.placeholder_token)
        ids = self.tokenizer(
            text,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            # Same policy as the negative-prompt tokenization: over-long prompts are truncated to the
            # encoder's maximum length instead of producing an over-long sequence.
            truncation=True,
            return_tensors="pt",
        ).input_ids

        # The device and the token ids do not change across timesteps/layers, so resolve them once.
        device = self.text_encoder.device
        ids = ids.to(device=device)

        # Compute embeddings for each timestep and each U-Net layer
        print(f"Computing embeddings over {len(self.timesteps)} timesteps and {len(self.unet_layers)} U-Net layers.")
        hidden_states_per_timestep = []
        for timestep in tqdm(self.timesteps):
            # `as_tensor` accepts both plain ints (as the `List[int]` annotation allows) and scheduler tensors.
            timestep_tensor = torch.as_tensor(timestep, device=device).unsqueeze(0)
            _hs = {"this_idx": 0}
            for layer_idx, unet_layer in enumerate(self.unet_layers):
                batch = NeTIBatch(input_ids=ids,
                                  timesteps=timestep_tensor,
                                  unet_layers=torch.tensor(layer_idx, device=device).unsqueeze(0),
                                  placeholder_token_id=self.placeholder_token_id,
                                  truncation_idx=truncation_idx)
                layer_hs, layer_hs_bypass = self.text_encoder(batch=batch)
                layer_hs = layer_hs[0].to(dtype=self.dtype)
                _hs[f"CONTEXT_TENSOR_{layer_idx}"] = layer_hs.repeat(num_images_per_prompt, 1, 1)
                if layer_hs_bypass is not None:
                    layer_hs_bypass = layer_hs_bypass[0].to(dtype=self.dtype)
                    _hs[f"CONTEXT_TENSOR_BYPASS_{layer_idx}"] = layer_hs_bypass.repeat(num_images_per_prompt, 1, 1)
            hidden_states_per_timestep.append(_hs)
        print("Done.")
        return hidden_states_per_timestep
src/scripts/__init__.py ADDED
File without changes
src/scripts/inference.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ from dataclasses import dataclass, field
3
+ from pathlib import Path
4
+ from typing import Optional, List, Tuple, Union
5
+
6
+ import numpy as np
7
+ import pyrallis
8
+ import torch
9
+ from PIL import Image
10
+ from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline
11
+ from transformers import CLIPTokenizer
12
+
13
+ sys.path.append(".")
14
+ sys.path.append("..")
15
+
16
+ from src import constants
17
+ from src.models.neti_clip_text_encoder import NeTICLIPTextModel
18
+ from src.models.neti_mapper import NeTIMapper
19
+ from src.prompt_manager import PromptManager
20
+ from src.sd_pipeline_call import sd_pipeline_call
21
+ from src.models.xti_attention_processor import XTIAttenProc
22
+ from src.checkpoint_handler import CheckpointHandler
23
+ from src.utils import vis_utils
24
+
25
+
26
@dataclass
class InferenceConfig:
    """ Configuration for running inference; paths not given explicitly are derived from `input_dir`. """
    # Specifies which checkpoint iteration we want to load
    iteration: Optional[int] = None
    # The input directory containing the saved models and embeddings
    input_dir: Optional[Path] = None
    # Where to save the inference results to
    inference_dir: Optional[Path] = None
    # Specific path to the mapper you want to load, overrides `input_dir`
    mapper_checkpoint_path: Optional[Path] = None
    # Specific path to the embeddings you want to load, overrides `input_dir`
    learned_embeds_path: Optional[Path] = None
    # List of prompts to run inference on
    prompts: Optional[List[str]] = None
    # Text file containing the prompts to run inference on (one prompt per line), overrides `prompts`
    prompts_file_path: Optional[Path] = None
    # List of random seeds to run on
    seeds: List[int] = field(default_factory=lambda: [42])
    # If you want to run with dropout at inference time, this specifies the truncation indices for applying dropout.
    # None indicates that no dropout will be performed. If a list of indices is provided, will run all indices.
    truncation_idxs: Optional[Union[int, List[int]]] = None
    # Whether to run with torch.float16 or torch.float32
    torch_dtype: str = "fp16"

    def __post_init__(self):
        """ Validates the config, resolves derived paths, creates the output dir and normalizes field types.

        :raises ValueError: If not exactly one of `prompts` / `prompts_file_path` is given, or a path that
            must be derived from `input_dir` is needed while `input_dir` is missing.
        :raises FileNotFoundError: If `prompts_file_path` does not exist.
        """
        # Explicit raises instead of asserts, so the checks still run under `python -O`.
        if bool(self.prompts) == bool(self.prompts_file_path):
            raise ValueError("You must provide either prompts or prompts_file_path, but not both!")
        self._set_prompts()
        self._set_input_paths()
        self.inference_dir.mkdir(exist_ok=True, parents=True)
        if isinstance(self.truncation_idxs, int):
            self.truncation_idxs = [self.truncation_idxs]
        # From here on `torch_dtype` holds a torch.dtype rather than the string given by the user.
        self.torch_dtype = torch.float16 if self.torch_dtype == "fp16" else torch.float32

    def _set_input_paths(self):
        """ Fills in every unset path from `input_dir` and `iteration`. """
        if self.inference_dir is None:
            if self.input_dir is None:
                raise ValueError("You must pass an input_dir if you do not specify inference_dir")
            self.inference_dir = self.input_dir / f"inference_{self.iteration}"
        if self.mapper_checkpoint_path is None:
            if self.input_dir is None:
                raise ValueError("You must pass an input_dir if you do not specify mapper_checkpoint_path")
            self.mapper_checkpoint_path = self.input_dir / f"mapper-steps-{self.iteration}.pt"
        if self.learned_embeds_path is None:
            if self.input_dir is None:
                raise ValueError("You must pass an input_dir if you do not specify learned_embeds_path")
            self.learned_embeds_path = self.input_dir / f"learned_embeds-steps-{self.iteration}.bin"

    def _set_prompts(self):
        """ Loads the prompts from `prompts_file_path` (one per line) when a file is given. """
        if self.prompts_file_path is not None:
            if not self.prompts_file_path.exists():
                raise FileNotFoundError(f"Prompts file {self.prompts_file_path} does not exist!")
            self.prompts = self.prompts_file_path.read_text().splitlines()
75
+
76
+
77
@pyrallis.wrap()
def main(infer_cfg: InferenceConfig):
    """ Loads the trained mapper and pipeline, then saves an image grid per prompt and truncation index. """
    train_cfg, mapper = CheckpointHandler.load_mapper(infer_cfg.mapper_checkpoint_path)
    pipeline, placeholder_token, placeholder_token_id = load_stable_diffusion_model(
        pretrained_model_name_or_path=train_cfg.model.pretrained_model_name_or_path,
        mapper=mapper,
        learned_embeds_path=infer_cfg.learned_embeds_path,
        torch_dtype=infer_cfg.torch_dtype
    )
    prompt_manager = PromptManager(tokenizer=pipeline.tokenizer,
                                   text_encoder=pipeline.text_encoder,
                                   timesteps=pipeline.scheduler.timesteps,
                                   unet_layers=constants.UNET_LAYERS,
                                   placeholder_token=placeholder_token,
                                   placeholder_token_id=placeholder_token_id,
                                   torch_dtype=infer_cfg.torch_dtype)

    # `truncation_idxs` defaults to None ("no dropout"); run a single pass with no truncation in that case
    # rather than failing when iterating over None.
    truncation_idxs = infer_cfg.truncation_idxs if infer_cfg.truncation_idxs is not None else [None]

    for prompt in infer_cfg.prompts:
        formatted_prompt = prompt.format(placeholder_token)
        output_path = infer_cfg.inference_dir / formatted_prompt
        output_path.mkdir(exist_ok=True, parents=True)
        for truncation_idx in truncation_idxs:
            print(f"Running with truncation index: {truncation_idx}")
            prompt_image = run_inference(prompt=prompt,
                                         pipeline=pipeline,
                                         prompt_manager=prompt_manager,
                                         seeds=infer_cfg.seeds,
                                         output_path=output_path,
                                         num_images_per_prompt=1,
                                         truncation_idx=truncation_idx)
            if truncation_idx is not None:
                save_name = f"{formatted_prompt}_truncation_{truncation_idx}.png"
            else:
                save_name = f"{formatted_prompt}.png"
            prompt_image.save(infer_cfg.inference_dir / save_name)
110
+
111
+
112
def run_inference(prompt: str,
                  pipeline: StableDiffusionPipeline,
                  prompt_manager: PromptManager,
                  seeds: List[int],
                  output_path: Optional[Path] = None,
                  num_images_per_prompt: int = 1,
                  truncation_idx: Optional[int] = None) -> Image.Image:
    """ Generates images for `prompt` with every seed and returns them arranged in a single grid image.

    The prompt is embedded once and reused for all seeds. When `output_path` is given, each per-seed image
    (all images of that seed concatenated side by side) is also saved there.
    """
    with torch.autocast("cuda"), torch.no_grad():
        prompt_embeds = prompt_manager.embed_prompt(prompt,
                                                    num_images_per_prompt=num_images_per_prompt,
                                                    truncation_idx=truncation_idx)

    per_seed_images = []
    for seed in seeds:
        generator = torch.Generator(device='cuda').manual_seed(seed)
        pipeline_output = sd_pipeline_call(pipeline,
                                           prompt_embeds=prompt_embeds,
                                           generator=generator,
                                           num_images_per_prompt=num_images_per_prompt)
        # Place the images produced for this seed next to each other.
        side_by_side = np.concatenate(pipeline_output.images, axis=1)
        seed_image = Image.fromarray(side_by_side).convert("RGB")
        if output_path is not None:
            if truncation_idx is not None:
                save_name = f'{seed}_truncation_{truncation_idx}.png'
            else:
                save_name = f'{seed}.png'
            seed_image.save(output_path / save_name)
        per_seed_images.append(seed_image)

    return vis_utils.get_image_grid(per_seed_images)
138
+
139
+
140
def load_stable_diffusion_model(pretrained_model_name_or_path: str,
                                learned_embeds_path: Path,
                                mapper: Optional[NeTIMapper] = None,
                                num_denoising_steps: int = 50,
                                torch_dtype: torch.dtype = torch.float16) -> Tuple[StableDiffusionPipeline, str, int]:
    """ Builds the Stable Diffusion pipeline with the NeTI text encoder and the learned placeholder token.

    Returns the pipeline (moved to CUDA, with a DPM-Solver multistep scheduler and the XTI attention
    processor installed) together with the placeholder token string and its id.
    """
    tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
    text_encoder = NeTICLIPTextModel.from_pretrained(pretrained_model_name_or_path,
                                                     subfolder="text_encoder",
                                                     torch_dtype=torch_dtype)
    if mapper is not None:
        text_encoder.text_model.embeddings.set_mapper(mapper)

    learned_token = CheckpointHandler.load_learned_embed_in_clip(learned_embeds_path=learned_embeds_path,
                                                                 text_encoder=text_encoder,
                                                                 tokenizer=tokenizer)
    placeholder_token, placeholder_token_id = learned_token

    pipeline = StableDiffusionPipeline.from_pretrained(pretrained_model_name_or_path,
                                                       torch_dtype=torch_dtype,
                                                       text_encoder=text_encoder,
                                                       tokenizer=tokenizer)
    pipeline = pipeline.to("cuda")

    # Swap in the DPM-Solver scheduler, keeping the configuration of the scheduler the pipeline came with.
    scheduler_config = pipeline.scheduler.config
    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(scheduler_config)
    pipeline.scheduler.set_timesteps(num_denoising_steps, device=pipeline.device)

    pipeline.unet.set_attn_processor(XTIAttenProc())
    return pipeline, placeholder_token, placeholder_token_id
167
+
168
+
169
# Script entry point: arguments are parsed into an InferenceConfig by the pyrallis wrapper around `main`.
if __name__ == "__main__":
    main()
src/sd_pipeline_call.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Callable, Dict, List, Optional, Union
2
+
3
+ import torch
4
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionPipeline
5
+
6
+
7
@torch.no_grad()
def sd_pipeline_call(
        pipeline: StableDiffusionPipeline,
        prompt_embeds: Union[torch.FloatTensor, List[Dict[str, Any]]],
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None):
    """ Modification of the standard SD pipeline call to support NeTI embeddings passed with prompt_embeds argument.

    :param pipeline: The Stable Diffusion pipeline whose components are used.
    :param prompt_embeds: Either one conditioning input used for every step, or a list with one conditioning
        input per denoising step (entry `i` is used at step `i`, so the list must cover all steps).
    :param guidance_scale: Classifier-free guidance weight; values <= 1 disable guidance.
    :param output_type: "latent" returns the raw latents, "pil" returns PIL images, anything else returns
        the decoded image array.
    :return: `StableDiffusionPipelineOutput`, or an `(image, has_nsfw_concept)` tuple if `return_dict` is False.
    """

    # 0. Default height and width to unet
    height = height or pipeline.unet.config.sample_size * pipeline.vae_scale_factor
    width = width or pipeline.unet.config.sample_size * pipeline.vae_scale_factor

    # 2. Define call parameters
    batch_size = 1
    device = pipeline._execution_device

    neg_prompt = get_neg_prompt_input_ids(pipeline, negative_prompt)
    negative_prompt_embeds, _ = pipeline.text_encoder(
        input_ids=neg_prompt.input_ids.to(device),
        attention_mask=None,
    )
    # The unconditional embedding is identical for every step, so expand it once up front.
    negative_prompt_embeds = negative_prompt_embeds[0].repeat(num_images_per_prompt, 1, 1)

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 4. Prepare timesteps
    pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = pipeline.scheduler.timesteps

    # 5. Prepare latent variables
    num_channels_latents = pipeline.unet.config.in_channels
    latents = pipeline.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        pipeline.text_encoder.dtype,
        device,
        generator,
        latents,
    )

    # 6. Prepare extra step kwargs.
    extra_step_kwargs = pipeline.prepare_extra_step_kwargs(generator, eta)

    # 7. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * pipeline.scheduler.order
    with pipeline.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # The scaled input is needed with and without guidance.
            latent_model_input = pipeline.scheduler.scale_model_input(latents, t)

            if do_classifier_free_guidance:
                # predict the unconditional noise residual
                noise_pred_uncond = pipeline.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=negative_prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
                ).sample

            ###############################################################
            # NeTI logic: use the prompt embedding for the current timestep
            ###############################################################
            embed = prompt_embeds[i] if isinstance(prompt_embeds, list) else prompt_embeds
            noise_pred_text = pipeline.unet(
                latent_model_input,
                t,
                encoder_hidden_states=embed,
                cross_attention_kwargs=cross_attention_kwargs,
            ).sample

            # perform guidance; without guidance the text-conditioned prediction is used directly
            if do_classifier_free_guidance:
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            else:
                noise_pred = noise_pred_text

            # compute the previous noisy sample x_t -> x_t-1
            latents = pipeline.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)

    if output_type == "latent":
        image = latents
        has_nsfw_concept = None
    else:
        # 8. Post-processing
        image = pipeline.decode_latents(latents)
        # 9. Run safety checker
        image, has_nsfw_concept = pipeline.run_safety_checker(image, device, pipeline.text_encoder.dtype)
        if output_type == "pil":
            # 10. Convert to PIL
            image = pipeline.numpy_to_pil(image)

    # Offload last model to CPU
    if hasattr(pipeline, "final_offload_hook") and pipeline.final_offload_hook is not None:
        pipeline.final_offload_hook.offload()

    if not return_dict:
        return image, has_nsfw_concept

    return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
132
+
133
+
134
def get_neg_prompt_input_ids(pipeline: StableDiffusionPipeline,
                             negative_prompt: Optional[Union[str, List[str]]] = None):
    """ Tokenizes the negative prompt (an empty string when none is given) with the pipeline's tokenizer.

    Returns the tokenizer output, padded and truncated to the tokenizer's maximum length, as PyTorch tensors.
    """
    prompt = "" if negative_prompt is None else negative_prompt
    if isinstance(prompt, str):
        uncond_tokens = [prompt]
    else:
        uncond_tokens = prompt
    tokenizer = pipeline.tokenizer
    return tokenizer(uncond_tokens,
                     padding="max_length",
                     max_length=tokenizer.model_max_length,
                     truncation=True,
                     return_tensors="pt")
src/utils/__init__.py ADDED
File without changes
src/utils/types.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import enum
2
+ from dataclasses import dataclass
3
+ from typing import Optional
4
+
5
+ import torch
6
+
7
+
8
@dataclass
class NeTIBatch:
    """ Bundle of inputs handed to the NeTI text encoder for one call. """
    # Token ids of the prompt
    input_ids: torch.Tensor
    # Id of the placeholder token
    placeholder_token_id: int
    # Denoising timestep(s) to condition on
    timesteps: torch.Tensor
    # Index/indices of the U-Net layer(s) to condition on
    unet_layers: torch.Tensor
    # Optional truncation index; None means no truncation
    truncation_idx: Optional[int] = None
15
+
16
+
17
@dataclass
class PESigmas:
    """ Pair of sigma values: one for the time input and one for the layer input. """
    # Sigma applied to the time input
    sigma_t: float
    # Sigma applied to the layer input
    sigma_l: float
src/utils/vis_utils.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import List
3
+
4
+ from PIL import Image
5
+
6
+
7
def get_image_grid(images: List[Image.Image]) -> Image.Image:
    """ Arranges the given images, left to right and top to bottom, in a near-square grid.

    The cell size is taken from the first image; the grid has `ceil(sqrt(n))` columns.

    :param images: Non-empty list of images to arrange.
    :return: A new RGB image containing all inputs.
    :raises ValueError: If `images` is empty.
    """
    if not images:
        # Previously an empty list failed with an unhelpful ZeroDivisionError.
        raise ValueError("get_image_grid requires at least one image.")
    num_images = len(images)
    cols = int(math.ceil(math.sqrt(num_images)))
    rows = int(math.ceil(num_images / cols))
    width, height = images[0].size
    grid_image = Image.new('RGB', (cols * width, rows * height))
    for i, img in enumerate(images):
        row, col = divmod(i, cols)
        grid_image.paste(img, (col * width, row * height))
    return grid_image
style.css ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ h1 {
2
+ text-align: center;
3
+ }