from dataclasses import dataclass, field | |
from pathlib import Path | |
from typing import List, Optional, Dict | |
from constants import VALIDATION_PROMPTS | |
from utils.types import PESigmas | |
@dataclass
class LogConfig:
    """ Parameters for logging and saving.

    Must be a dataclass: `exp_name` is a required constructor argument and the
    remaining fields are keyword arguments with the defaults below.
    """
    # Name of experiment. This will be the name of the output folder
    exp_name: str
    # The output directory where the model predictions and checkpoints will be written
    exp_dir: Path = Path("./outputs")
    # Save interval
    save_steps: int = 250
    # Log directory for the reporting integration (e.g. [TensorBoard](https://www.tensorflow.org/tensorboard)).
    # Defaults to `logs`; presumably resolved relative to the experiment output folder by the trainer — TODO confirm
    logging_dir: Path = Path("logs")
    # The integration to report the results to. Supported platforms are "tensorboard"
    # (default), "wandb" and "comet_ml". Use "all" to report to all integrations.
    report_to: str = "tensorboard"
    # Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator`
    checkpoints_total_limit: Optional[int] = None
@dataclass
class DataConfig:
    """ Parameters for data.

    Must be a dataclass: `train_data_dir` and `placeholder_token` are required
    constructor arguments; all other fields have the defaults below.
    """
    # A folder containing the training data
    train_data_dir: Path
    # A token to use as a placeholder for the concept
    placeholder_token: str
    # Super category token to use for normalizing the mapper output
    super_category_token: Optional[str] = "object"
    # Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process
    dataloader_num_workers: int = 8
    # Choose between 'object' and 'style' - used for selecting the prompts for training
    learnable_property: str = "object"
    # How many times to repeat the training data
    repeats: int = 100
    # The resolution for input images, all the images in the train/validation dataset will be resized to this resolution
    resolution: int = 512
    # Whether to center crop images before resizing to resolution
    center_crop: bool = False
@dataclass
class ModelConfig:
    """ Parameters for defining all models.

    Must be a dataclass so that `__post_init__` runs and converts the
    `pe_sigmas` dict into a `PESigmas` instance.
    """
    # Path to pretrained model or model identifier from huggingface.co/models
    pretrained_model_name_or_path: str = "CompVis/stable-diffusion-v1-4"
    # Whether to use our Nested Dropout technique
    use_nested_dropout: bool = True
    # Probability to apply nested dropout during training
    nested_dropout_prob: float = 0.5
    # Whether to normalize the norm of the mapper's output vector
    normalize_mapper_output: bool = True
    # Target norm for the mapper's output vector
    target_norm: Optional[float] = None
    # Whether to use positional encoding over the input to the mapper
    use_positional_encoding: bool = True
    # Sigmas used for computing positional encoding: 'sigma_t' (time) and 'sigma_l' (layers).
    # Given as a dict and converted to `PESigmas` in `__post_init__`.
    pe_sigmas: Dict[str, float] = field(default_factory=lambda: {'sigma_t': 0.03, 'sigma_l': 2.0})
    # Number of time anchors for computing our positional encodings
    num_pe_time_anchors: int = 10
    # Whether to output the textual bypass vector
    output_bypass: bool = True
    # Revision of pretrained model identifier from huggingface.co/models
    revision: Optional[str] = None
    # Optional path to a previously saved mapper checkpoint, used when resuming training
    mapper_checkpoint_path: Optional[Path] = None

    def __post_init__(self):
        """Validate `pe_sigmas` and convert it from a dict to `PESigmas`.

        Only dict values are converted. `None` and values that were already
        converted (e.g. when `dataclasses.replace` re-runs `__post_init__`) are
        left untouched.

        Raises:
            AssertionError: if the dict does not have exactly the keys
                'sigma_t' and 'sigma_l'.
        """
        if isinstance(self.pe_sigmas, dict):
            assert set(self.pe_sigmas) == {'sigma_t', 'sigma_l'}, \
                "Should provide exactly two sigma values: one for time ('sigma_t') and one for layers ('sigma_l')!"
            self.pe_sigmas = PESigmas(sigma_t=self.pe_sigmas['sigma_t'], sigma_l=self.pe_sigmas['sigma_l'])
@dataclass
class EvalConfig:
    """ Parameters for validation.

    Must be a dataclass so that `__post_init__` runs and validates that one
    seed is provided per validation image.
    """
    # A list of prompts that will be used during validation to verify that the model is learning.
    # Copied per instance so mutating one config's prompts does not alter the shared VALIDATION_PROMPTS.
    validation_prompts: List[str] = field(default_factory=lambda: list(VALIDATION_PROMPTS))
    # Number of images that should be generated during validation with `validation_prompts`
    num_validation_images: int = 4
    # Seeds to use for generating the validation images. If None, seeds 0..num_validation_images-1 are used
    validation_seeds: Optional[List[int]] = field(default_factory=lambda: [42, 420, 501, 5456])
    # Run validation every X steps.
    validation_steps: int = 100
    # Number of denoising steps
    num_denoising_steps: int = 50

    def __post_init__(self):
        """Fill in default seeds if none were given and check their count.

        Raises:
            AssertionError: if the number of seeds differs from `num_validation_images`.
        """
        if self.validation_seeds is None:
            self.validation_seeds = list(range(self.num_validation_images))
        assert len(self.validation_seeds) == self.num_validation_images, \
            "Length of validation_seeds should equal num_validation_images"
@dataclass
class OptimConfig:
    """ Parameters for the optimization process.

    Must be a dataclass so the fields below become keyword arguments of the
    constructor with the given defaults.
    """
    # Total number of training steps to perform.
    max_train_steps: Optional[int] = 1_000
    # Learning rate
    learning_rate: float = 1e-3
    # Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size
    scale_lr: bool = True
    # Batch size (per device) for the training dataloader
    train_batch_size: int = 2
    # Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass
    gradient_checkpointing: bool = False
    # Number of updates steps to accumulate before performing a backward/update pass
    gradient_accumulation_steps: int = 4
    # A seed for reproducible training
    seed: Optional[int] = None
    # The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",
    # "constant", "constant_with_warmup"]
    lr_scheduler: str = "constant"
    # Number of steps for the warmup in the lr scheduler
    lr_warmup_steps: int = 0
    # The beta1 parameter for the Adam optimizer
    adam_beta1: float = 0.9
    # The beta2 parameter for the Adam optimizer
    adam_beta2: float = 0.999
    # Weight decay to use
    adam_weight_decay: float = 1e-2
    # Epsilon value for the Adam optimizer
    adam_epsilon: float = 1e-08
    # Whether to use mixed precision. Choose between "no" (default), "fp16" and "bf16" (bfloat16).
    # Bf16 requires PyTorch >= 1.10 and an Nvidia Ampere GPU.
    mixed_precision: str = "no"
    # Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see
    # https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    allow_tf32: bool = False
@dataclass
class RunConfig:
    """ The main configuration for the coach trainer.

    Groups the sub-configurations. Each default is built by calling the
    sub-config class with no arguments, so `log` and `data` (whose classes
    have required fields) must be supplied explicitly.
    """
    log: LogConfig = field(default_factory=LogConfig)
    data: DataConfig = field(default_factory=DataConfig)
    model: ModelConfig = field(default_factory=ModelConfig)
    eval: EvalConfig = field(default_factory=EvalConfig)
    optim: OptimConfig = field(default_factory=OptimConfig)