Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
097567a
1
Parent(s):
90e5afe
Delete sgm
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- sgm/__init__.py +0 -4
- sgm/data/__init__.py +0 -1
- sgm/data/cifar10.py +0 -67
- sgm/data/dataset.py +0 -80
- sgm/data/mnist.py +0 -85
- sgm/inference/api.py +0 -385
- sgm/inference/helpers.py +0 -305
- sgm/lr_scheduler.py +0 -135
- sgm/models/__init__.py +0 -2
- sgm/models/autoencoder.py +0 -615
- sgm/models/diffusion.py +0 -341
- sgm/modules/__init__.py +0 -6
- sgm/modules/attention.py +0 -759
- sgm/modules/autoencoding/__init__.py +0 -0
- sgm/modules/autoencoding/losses/__init__.py +0 -7
- sgm/modules/autoencoding/losses/discriminator_loss.py +0 -306
- sgm/modules/autoencoding/losses/lpips.py +0 -73
- sgm/modules/autoencoding/lpips/__init__.py +0 -0
- sgm/modules/autoencoding/lpips/loss/.gitignore +0 -1
- sgm/modules/autoencoding/lpips/loss/LICENSE +0 -23
- sgm/modules/autoencoding/lpips/loss/__init__.py +0 -0
- sgm/modules/autoencoding/lpips/loss/lpips.py +0 -147
- sgm/modules/autoencoding/lpips/model/LICENSE +0 -58
- sgm/modules/autoencoding/lpips/model/__init__.py +0 -0
- sgm/modules/autoencoding/lpips/model/model.py +0 -88
- sgm/modules/autoencoding/lpips/util.py +0 -128
- sgm/modules/autoencoding/lpips/vqperceptual.py +0 -17
- sgm/modules/autoencoding/regularizers/__init__.py +0 -31
- sgm/modules/autoencoding/regularizers/base.py +0 -40
- sgm/modules/autoencoding/regularizers/quantize.py +0 -487
- sgm/modules/autoencoding/temporal_ae.py +0 -349
- sgm/modules/diffusionmodules/__init__.py +0 -0
- sgm/modules/diffusionmodules/denoiser.py +0 -75
- sgm/modules/diffusionmodules/denoiser_scaling.py +0 -59
- sgm/modules/diffusionmodules/denoiser_weighting.py +0 -24
- sgm/modules/diffusionmodules/discretizer.py +0 -69
- sgm/modules/diffusionmodules/guiders.py +0 -99
- sgm/modules/diffusionmodules/loss.py +0 -105
- sgm/modules/diffusionmodules/loss_weighting.py +0 -32
- sgm/modules/diffusionmodules/model.py +0 -748
- sgm/modules/diffusionmodules/openaimodel.py +0 -853
- sgm/modules/diffusionmodules/sampling.py +0 -362
- sgm/modules/diffusionmodules/sampling_utils.py +0 -43
- sgm/modules/diffusionmodules/sigma_sampling.py +0 -31
- sgm/modules/diffusionmodules/util.py +0 -369
- sgm/modules/diffusionmodules/video_model.py +0 -493
- sgm/modules/diffusionmodules/wrappers.py +0 -34
- sgm/modules/distributions/__init__.py +0 -0
- sgm/modules/distributions/distributions.py +0 -102
- sgm/modules/ema.py +0 -86
sgm/__init__.py
DELETED
@@ -1,4 +0,0 @@
|
|
1 |
-
from .models import AutoencodingEngine, DiffusionEngine
|
2 |
-
from .util import get_configs_path, instantiate_from_config
|
3 |
-
|
4 |
-
__version__ = "0.1.0"
|
|
|
|
|
|
|
|
|
|
sgm/data/__init__.py
DELETED
@@ -1 +0,0 @@
|
|
1 |
-
from .dataset import StableDataModuleFromConfig
|
|
|
|
sgm/data/cifar10.py
DELETED
@@ -1,67 +0,0 @@
|
|
1 |
-
import pytorch_lightning as pl
|
2 |
-
import torchvision
|
3 |
-
from torch.utils.data import DataLoader, Dataset
|
4 |
-
from torchvision import transforms
|
5 |
-
|
6 |
-
|
7 |
-
class CIFAR10DataDictWrapper(Dataset):
|
8 |
-
def __init__(self, dset):
|
9 |
-
super().__init__()
|
10 |
-
self.dset = dset
|
11 |
-
|
12 |
-
def __getitem__(self, i):
|
13 |
-
x, y = self.dset[i]
|
14 |
-
return {"jpg": x, "cls": y}
|
15 |
-
|
16 |
-
def __len__(self):
|
17 |
-
return len(self.dset)
|
18 |
-
|
19 |
-
|
20 |
-
class CIFAR10Loader(pl.LightningDataModule):
|
21 |
-
def __init__(self, batch_size, num_workers=0, shuffle=True):
|
22 |
-
super().__init__()
|
23 |
-
|
24 |
-
transform = transforms.Compose(
|
25 |
-
[transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)]
|
26 |
-
)
|
27 |
-
|
28 |
-
self.batch_size = batch_size
|
29 |
-
self.num_workers = num_workers
|
30 |
-
self.shuffle = shuffle
|
31 |
-
self.train_dataset = CIFAR10DataDictWrapper(
|
32 |
-
torchvision.datasets.CIFAR10(
|
33 |
-
root=".data/", train=True, download=True, transform=transform
|
34 |
-
)
|
35 |
-
)
|
36 |
-
self.test_dataset = CIFAR10DataDictWrapper(
|
37 |
-
torchvision.datasets.CIFAR10(
|
38 |
-
root=".data/", train=False, download=True, transform=transform
|
39 |
-
)
|
40 |
-
)
|
41 |
-
|
42 |
-
def prepare_data(self):
|
43 |
-
pass
|
44 |
-
|
45 |
-
def train_dataloader(self):
|
46 |
-
return DataLoader(
|
47 |
-
self.train_dataset,
|
48 |
-
batch_size=self.batch_size,
|
49 |
-
shuffle=self.shuffle,
|
50 |
-
num_workers=self.num_workers,
|
51 |
-
)
|
52 |
-
|
53 |
-
def test_dataloader(self):
|
54 |
-
return DataLoader(
|
55 |
-
self.test_dataset,
|
56 |
-
batch_size=self.batch_size,
|
57 |
-
shuffle=self.shuffle,
|
58 |
-
num_workers=self.num_workers,
|
59 |
-
)
|
60 |
-
|
61 |
-
def val_dataloader(self):
|
62 |
-
return DataLoader(
|
63 |
-
self.test_dataset,
|
64 |
-
batch_size=self.batch_size,
|
65 |
-
shuffle=self.shuffle,
|
66 |
-
num_workers=self.num_workers,
|
67 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sgm/data/dataset.py
DELETED
@@ -1,80 +0,0 @@
|
|
1 |
-
from typing import Optional
|
2 |
-
|
3 |
-
import torchdata.datapipes.iter
|
4 |
-
import webdataset as wds
|
5 |
-
from omegaconf import DictConfig
|
6 |
-
from pytorch_lightning import LightningDataModule
|
7 |
-
|
8 |
-
try:
|
9 |
-
from sdata import create_dataset, create_dummy_dataset, create_loader
|
10 |
-
except ImportError as e:
|
11 |
-
print("#" * 100)
|
12 |
-
print("Datasets not yet available")
|
13 |
-
print("to enable, we need to add stable-datasets as a submodule")
|
14 |
-
print("please use ``git submodule update --init --recursive``")
|
15 |
-
print("and do ``pip install -e stable-datasets/`` from the root of this repo")
|
16 |
-
print("#" * 100)
|
17 |
-
exit(1)
|
18 |
-
|
19 |
-
|
20 |
-
class StableDataModuleFromConfig(LightningDataModule):
|
21 |
-
def __init__(
|
22 |
-
self,
|
23 |
-
train: DictConfig,
|
24 |
-
validation: Optional[DictConfig] = None,
|
25 |
-
test: Optional[DictConfig] = None,
|
26 |
-
skip_val_loader: bool = False,
|
27 |
-
dummy: bool = False,
|
28 |
-
):
|
29 |
-
super().__init__()
|
30 |
-
self.train_config = train
|
31 |
-
assert (
|
32 |
-
"datapipeline" in self.train_config and "loader" in self.train_config
|
33 |
-
), "train config requires the fields `datapipeline` and `loader`"
|
34 |
-
|
35 |
-
self.val_config = validation
|
36 |
-
if not skip_val_loader:
|
37 |
-
if self.val_config is not None:
|
38 |
-
assert (
|
39 |
-
"datapipeline" in self.val_config and "loader" in self.val_config
|
40 |
-
), "validation config requires the fields `datapipeline` and `loader`"
|
41 |
-
else:
|
42 |
-
print(
|
43 |
-
"Warning: No Validation datapipeline defined, using that one from training"
|
44 |
-
)
|
45 |
-
self.val_config = train
|
46 |
-
|
47 |
-
self.test_config = test
|
48 |
-
if self.test_config is not None:
|
49 |
-
assert (
|
50 |
-
"datapipeline" in self.test_config and "loader" in self.test_config
|
51 |
-
), "test config requires the fields `datapipeline` and `loader`"
|
52 |
-
|
53 |
-
self.dummy = dummy
|
54 |
-
if self.dummy:
|
55 |
-
print("#" * 100)
|
56 |
-
print("USING DUMMY DATASET: HOPE YOU'RE DEBUGGING ;)")
|
57 |
-
print("#" * 100)
|
58 |
-
|
59 |
-
def setup(self, stage: str) -> None:
|
60 |
-
print("Preparing datasets")
|
61 |
-
if self.dummy:
|
62 |
-
data_fn = create_dummy_dataset
|
63 |
-
else:
|
64 |
-
data_fn = create_dataset
|
65 |
-
|
66 |
-
self.train_datapipeline = data_fn(**self.train_config.datapipeline)
|
67 |
-
if self.val_config:
|
68 |
-
self.val_datapipeline = data_fn(**self.val_config.datapipeline)
|
69 |
-
if self.test_config:
|
70 |
-
self.test_datapipeline = data_fn(**self.test_config.datapipeline)
|
71 |
-
|
72 |
-
def train_dataloader(self) -> torchdata.datapipes.iter.IterDataPipe:
|
73 |
-
loader = create_loader(self.train_datapipeline, **self.train_config.loader)
|
74 |
-
return loader
|
75 |
-
|
76 |
-
def val_dataloader(self) -> wds.DataPipeline:
|
77 |
-
return create_loader(self.val_datapipeline, **self.val_config.loader)
|
78 |
-
|
79 |
-
def test_dataloader(self) -> wds.DataPipeline:
|
80 |
-
return create_loader(self.test_datapipeline, **self.test_config.loader)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sgm/data/mnist.py
DELETED
@@ -1,85 +0,0 @@
|
|
1 |
-
import pytorch_lightning as pl
|
2 |
-
import torchvision
|
3 |
-
from torch.utils.data import DataLoader, Dataset
|
4 |
-
from torchvision import transforms
|
5 |
-
|
6 |
-
|
7 |
-
class MNISTDataDictWrapper(Dataset):
|
8 |
-
def __init__(self, dset):
|
9 |
-
super().__init__()
|
10 |
-
self.dset = dset
|
11 |
-
|
12 |
-
def __getitem__(self, i):
|
13 |
-
x, y = self.dset[i]
|
14 |
-
return {"jpg": x, "cls": y}
|
15 |
-
|
16 |
-
def __len__(self):
|
17 |
-
return len(self.dset)
|
18 |
-
|
19 |
-
|
20 |
-
class MNISTLoader(pl.LightningDataModule):
|
21 |
-
def __init__(self, batch_size, num_workers=0, prefetch_factor=2, shuffle=True):
|
22 |
-
super().__init__()
|
23 |
-
|
24 |
-
transform = transforms.Compose(
|
25 |
-
[transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)]
|
26 |
-
)
|
27 |
-
|
28 |
-
self.batch_size = batch_size
|
29 |
-
self.num_workers = num_workers
|
30 |
-
self.prefetch_factor = prefetch_factor if num_workers > 0 else 0
|
31 |
-
self.shuffle = shuffle
|
32 |
-
self.train_dataset = MNISTDataDictWrapper(
|
33 |
-
torchvision.datasets.MNIST(
|
34 |
-
root=".data/", train=True, download=True, transform=transform
|
35 |
-
)
|
36 |
-
)
|
37 |
-
self.test_dataset = MNISTDataDictWrapper(
|
38 |
-
torchvision.datasets.MNIST(
|
39 |
-
root=".data/", train=False, download=True, transform=transform
|
40 |
-
)
|
41 |
-
)
|
42 |
-
|
43 |
-
def prepare_data(self):
|
44 |
-
pass
|
45 |
-
|
46 |
-
def train_dataloader(self):
|
47 |
-
return DataLoader(
|
48 |
-
self.train_dataset,
|
49 |
-
batch_size=self.batch_size,
|
50 |
-
shuffle=self.shuffle,
|
51 |
-
num_workers=self.num_workers,
|
52 |
-
prefetch_factor=self.prefetch_factor,
|
53 |
-
)
|
54 |
-
|
55 |
-
def test_dataloader(self):
|
56 |
-
return DataLoader(
|
57 |
-
self.test_dataset,
|
58 |
-
batch_size=self.batch_size,
|
59 |
-
shuffle=self.shuffle,
|
60 |
-
num_workers=self.num_workers,
|
61 |
-
prefetch_factor=self.prefetch_factor,
|
62 |
-
)
|
63 |
-
|
64 |
-
def val_dataloader(self):
|
65 |
-
return DataLoader(
|
66 |
-
self.test_dataset,
|
67 |
-
batch_size=self.batch_size,
|
68 |
-
shuffle=self.shuffle,
|
69 |
-
num_workers=self.num_workers,
|
70 |
-
prefetch_factor=self.prefetch_factor,
|
71 |
-
)
|
72 |
-
|
73 |
-
|
74 |
-
if __name__ == "__main__":
|
75 |
-
dset = MNISTDataDictWrapper(
|
76 |
-
torchvision.datasets.MNIST(
|
77 |
-
root=".data/",
|
78 |
-
train=False,
|
79 |
-
download=True,
|
80 |
-
transform=transforms.Compose(
|
81 |
-
[transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)]
|
82 |
-
),
|
83 |
-
)
|
84 |
-
)
|
85 |
-
ex = dset[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sgm/inference/api.py
DELETED
@@ -1,385 +0,0 @@
|
|
1 |
-
import pathlib
|
2 |
-
from dataclasses import asdict, dataclass
|
3 |
-
from enum import Enum
|
4 |
-
from typing import Optional
|
5 |
-
|
6 |
-
from omegaconf import OmegaConf
|
7 |
-
|
8 |
-
from sgm.inference.helpers import (Img2ImgDiscretizationWrapper, do_img2img,
|
9 |
-
do_sample)
|
10 |
-
from sgm.modules.diffusionmodules.sampling import (DPMPP2MSampler,
|
11 |
-
DPMPP2SAncestralSampler,
|
12 |
-
EulerAncestralSampler,
|
13 |
-
EulerEDMSampler,
|
14 |
-
HeunEDMSampler,
|
15 |
-
LinearMultistepSampler)
|
16 |
-
from sgm.util import load_model_from_config
|
17 |
-
|
18 |
-
|
19 |
-
class ModelArchitecture(str, Enum):
|
20 |
-
SD_2_1 = "stable-diffusion-v2-1"
|
21 |
-
SD_2_1_768 = "stable-diffusion-v2-1-768"
|
22 |
-
SDXL_V0_9_BASE = "stable-diffusion-xl-v0-9-base"
|
23 |
-
SDXL_V0_9_REFINER = "stable-diffusion-xl-v0-9-refiner"
|
24 |
-
SDXL_V1_BASE = "stable-diffusion-xl-v1-base"
|
25 |
-
SDXL_V1_REFINER = "stable-diffusion-xl-v1-refiner"
|
26 |
-
|
27 |
-
|
28 |
-
class Sampler(str, Enum):
|
29 |
-
EULER_EDM = "EulerEDMSampler"
|
30 |
-
HEUN_EDM = "HeunEDMSampler"
|
31 |
-
EULER_ANCESTRAL = "EulerAncestralSampler"
|
32 |
-
DPMPP2S_ANCESTRAL = "DPMPP2SAncestralSampler"
|
33 |
-
DPMPP2M = "DPMPP2MSampler"
|
34 |
-
LINEAR_MULTISTEP = "LinearMultistepSampler"
|
35 |
-
|
36 |
-
|
37 |
-
class Discretization(str, Enum):
|
38 |
-
LEGACY_DDPM = "LegacyDDPMDiscretization"
|
39 |
-
EDM = "EDMDiscretization"
|
40 |
-
|
41 |
-
|
42 |
-
class Guider(str, Enum):
|
43 |
-
VANILLA = "VanillaCFG"
|
44 |
-
IDENTITY = "IdentityGuider"
|
45 |
-
|
46 |
-
|
47 |
-
class Thresholder(str, Enum):
|
48 |
-
NONE = "None"
|
49 |
-
|
50 |
-
|
51 |
-
@dataclass
|
52 |
-
class SamplingParams:
|
53 |
-
width: int = 1024
|
54 |
-
height: int = 1024
|
55 |
-
steps: int = 50
|
56 |
-
sampler: Sampler = Sampler.DPMPP2M
|
57 |
-
discretization: Discretization = Discretization.LEGACY_DDPM
|
58 |
-
guider: Guider = Guider.VANILLA
|
59 |
-
thresholder: Thresholder = Thresholder.NONE
|
60 |
-
scale: float = 6.0
|
61 |
-
aesthetic_score: float = 5.0
|
62 |
-
negative_aesthetic_score: float = 5.0
|
63 |
-
img2img_strength: float = 1.0
|
64 |
-
orig_width: int = 1024
|
65 |
-
orig_height: int = 1024
|
66 |
-
crop_coords_top: int = 0
|
67 |
-
crop_coords_left: int = 0
|
68 |
-
sigma_min: float = 0.0292
|
69 |
-
sigma_max: float = 14.6146
|
70 |
-
rho: float = 3.0
|
71 |
-
s_churn: float = 0.0
|
72 |
-
s_tmin: float = 0.0
|
73 |
-
s_tmax: float = 999.0
|
74 |
-
s_noise: float = 1.0
|
75 |
-
eta: float = 1.0
|
76 |
-
order: int = 4
|
77 |
-
|
78 |
-
|
79 |
-
@dataclass
|
80 |
-
class SamplingSpec:
|
81 |
-
width: int
|
82 |
-
height: int
|
83 |
-
channels: int
|
84 |
-
factor: int
|
85 |
-
is_legacy: bool
|
86 |
-
config: str
|
87 |
-
ckpt: str
|
88 |
-
is_guided: bool
|
89 |
-
|
90 |
-
|
91 |
-
model_specs = {
|
92 |
-
ModelArchitecture.SD_2_1: SamplingSpec(
|
93 |
-
height=512,
|
94 |
-
width=512,
|
95 |
-
channels=4,
|
96 |
-
factor=8,
|
97 |
-
is_legacy=True,
|
98 |
-
config="sd_2_1.yaml",
|
99 |
-
ckpt="v2-1_512-ema-pruned.safetensors",
|
100 |
-
is_guided=True,
|
101 |
-
),
|
102 |
-
ModelArchitecture.SD_2_1_768: SamplingSpec(
|
103 |
-
height=768,
|
104 |
-
width=768,
|
105 |
-
channels=4,
|
106 |
-
factor=8,
|
107 |
-
is_legacy=True,
|
108 |
-
config="sd_2_1_768.yaml",
|
109 |
-
ckpt="v2-1_768-ema-pruned.safetensors",
|
110 |
-
is_guided=True,
|
111 |
-
),
|
112 |
-
ModelArchitecture.SDXL_V0_9_BASE: SamplingSpec(
|
113 |
-
height=1024,
|
114 |
-
width=1024,
|
115 |
-
channels=4,
|
116 |
-
factor=8,
|
117 |
-
is_legacy=False,
|
118 |
-
config="sd_xl_base.yaml",
|
119 |
-
ckpt="sd_xl_base_0.9.safetensors",
|
120 |
-
is_guided=True,
|
121 |
-
),
|
122 |
-
ModelArchitecture.SDXL_V0_9_REFINER: SamplingSpec(
|
123 |
-
height=1024,
|
124 |
-
width=1024,
|
125 |
-
channels=4,
|
126 |
-
factor=8,
|
127 |
-
is_legacy=True,
|
128 |
-
config="sd_xl_refiner.yaml",
|
129 |
-
ckpt="sd_xl_refiner_0.9.safetensors",
|
130 |
-
is_guided=True,
|
131 |
-
),
|
132 |
-
ModelArchitecture.SDXL_V1_BASE: SamplingSpec(
|
133 |
-
height=1024,
|
134 |
-
width=1024,
|
135 |
-
channels=4,
|
136 |
-
factor=8,
|
137 |
-
is_legacy=False,
|
138 |
-
config="sd_xl_base.yaml",
|
139 |
-
ckpt="sd_xl_base_1.0.safetensors",
|
140 |
-
is_guided=True,
|
141 |
-
),
|
142 |
-
ModelArchitecture.SDXL_V1_REFINER: SamplingSpec(
|
143 |
-
height=1024,
|
144 |
-
width=1024,
|
145 |
-
channels=4,
|
146 |
-
factor=8,
|
147 |
-
is_legacy=True,
|
148 |
-
config="sd_xl_refiner.yaml",
|
149 |
-
ckpt="sd_xl_refiner_1.0.safetensors",
|
150 |
-
is_guided=True,
|
151 |
-
),
|
152 |
-
}
|
153 |
-
|
154 |
-
|
155 |
-
class SamplingPipeline:
|
156 |
-
def __init__(
|
157 |
-
self,
|
158 |
-
model_id: ModelArchitecture,
|
159 |
-
model_path="checkpoints",
|
160 |
-
config_path="configs/inference",
|
161 |
-
device="cuda",
|
162 |
-
use_fp16=True,
|
163 |
-
) -> None:
|
164 |
-
if model_id not in model_specs:
|
165 |
-
raise ValueError(f"Model {model_id} not supported")
|
166 |
-
self.model_id = model_id
|
167 |
-
self.specs = model_specs[self.model_id]
|
168 |
-
self.config = str(pathlib.Path(config_path, self.specs.config))
|
169 |
-
self.ckpt = str(pathlib.Path(model_path, self.specs.ckpt))
|
170 |
-
self.device = device
|
171 |
-
self.model = self._load_model(device=device, use_fp16=use_fp16)
|
172 |
-
|
173 |
-
def _load_model(self, device="cuda", use_fp16=True):
|
174 |
-
config = OmegaConf.load(self.config)
|
175 |
-
model = load_model_from_config(config, self.ckpt)
|
176 |
-
if model is None:
|
177 |
-
raise ValueError(f"Model {self.model_id} could not be loaded")
|
178 |
-
model.to(device)
|
179 |
-
if use_fp16:
|
180 |
-
model.conditioner.half()
|
181 |
-
model.model.half()
|
182 |
-
return model
|
183 |
-
|
184 |
-
def text_to_image(
|
185 |
-
self,
|
186 |
-
params: SamplingParams,
|
187 |
-
prompt: str,
|
188 |
-
negative_prompt: str = "",
|
189 |
-
samples: int = 1,
|
190 |
-
return_latents: bool = False,
|
191 |
-
):
|
192 |
-
sampler = get_sampler_config(params)
|
193 |
-
value_dict = asdict(params)
|
194 |
-
value_dict["prompt"] = prompt
|
195 |
-
value_dict["negative_prompt"] = negative_prompt
|
196 |
-
value_dict["target_width"] = params.width
|
197 |
-
value_dict["target_height"] = params.height
|
198 |
-
return do_sample(
|
199 |
-
self.model,
|
200 |
-
sampler,
|
201 |
-
value_dict,
|
202 |
-
samples,
|
203 |
-
params.height,
|
204 |
-
params.width,
|
205 |
-
self.specs.channels,
|
206 |
-
self.specs.factor,
|
207 |
-
force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [],
|
208 |
-
return_latents=return_latents,
|
209 |
-
filter=None,
|
210 |
-
)
|
211 |
-
|
212 |
-
def image_to_image(
|
213 |
-
self,
|
214 |
-
params: SamplingParams,
|
215 |
-
image,
|
216 |
-
prompt: str,
|
217 |
-
negative_prompt: str = "",
|
218 |
-
samples: int = 1,
|
219 |
-
return_latents: bool = False,
|
220 |
-
):
|
221 |
-
sampler = get_sampler_config(params)
|
222 |
-
|
223 |
-
if params.img2img_strength < 1.0:
|
224 |
-
sampler.discretization = Img2ImgDiscretizationWrapper(
|
225 |
-
sampler.discretization,
|
226 |
-
strength=params.img2img_strength,
|
227 |
-
)
|
228 |
-
height, width = image.shape[2], image.shape[3]
|
229 |
-
value_dict = asdict(params)
|
230 |
-
value_dict["prompt"] = prompt
|
231 |
-
value_dict["negative_prompt"] = negative_prompt
|
232 |
-
value_dict["target_width"] = width
|
233 |
-
value_dict["target_height"] = height
|
234 |
-
return do_img2img(
|
235 |
-
image,
|
236 |
-
self.model,
|
237 |
-
sampler,
|
238 |
-
value_dict,
|
239 |
-
samples,
|
240 |
-
force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [],
|
241 |
-
return_latents=return_latents,
|
242 |
-
filter=None,
|
243 |
-
)
|
244 |
-
|
245 |
-
def refiner(
|
246 |
-
self,
|
247 |
-
params: SamplingParams,
|
248 |
-
image,
|
249 |
-
prompt: str,
|
250 |
-
negative_prompt: Optional[str] = None,
|
251 |
-
samples: int = 1,
|
252 |
-
return_latents: bool = False,
|
253 |
-
):
|
254 |
-
sampler = get_sampler_config(params)
|
255 |
-
value_dict = {
|
256 |
-
"orig_width": image.shape[3] * 8,
|
257 |
-
"orig_height": image.shape[2] * 8,
|
258 |
-
"target_width": image.shape[3] * 8,
|
259 |
-
"target_height": image.shape[2] * 8,
|
260 |
-
"prompt": prompt,
|
261 |
-
"negative_prompt": negative_prompt,
|
262 |
-
"crop_coords_top": 0,
|
263 |
-
"crop_coords_left": 0,
|
264 |
-
"aesthetic_score": 6.0,
|
265 |
-
"negative_aesthetic_score": 2.5,
|
266 |
-
}
|
267 |
-
|
268 |
-
return do_img2img(
|
269 |
-
image,
|
270 |
-
self.model,
|
271 |
-
sampler,
|
272 |
-
value_dict,
|
273 |
-
samples,
|
274 |
-
skip_encode=True,
|
275 |
-
return_latents=return_latents,
|
276 |
-
filter=None,
|
277 |
-
)
|
278 |
-
|
279 |
-
|
280 |
-
def get_guider_config(params: SamplingParams):
|
281 |
-
if params.guider == Guider.IDENTITY:
|
282 |
-
guider_config = {
|
283 |
-
"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"
|
284 |
-
}
|
285 |
-
elif params.guider == Guider.VANILLA:
|
286 |
-
scale = params.scale
|
287 |
-
|
288 |
-
thresholder = params.thresholder
|
289 |
-
|
290 |
-
if thresholder == Thresholder.NONE:
|
291 |
-
dyn_thresh_config = {
|
292 |
-
"target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"
|
293 |
-
}
|
294 |
-
else:
|
295 |
-
raise NotImplementedError
|
296 |
-
|
297 |
-
guider_config = {
|
298 |
-
"target": "sgm.modules.diffusionmodules.guiders.VanillaCFG",
|
299 |
-
"params": {"scale": scale, "dyn_thresh_config": dyn_thresh_config},
|
300 |
-
}
|
301 |
-
else:
|
302 |
-
raise NotImplementedError
|
303 |
-
return guider_config
|
304 |
-
|
305 |
-
|
306 |
-
def get_discretization_config(params: SamplingParams):
|
307 |
-
if params.discretization == Discretization.LEGACY_DDPM:
|
308 |
-
discretization_config = {
|
309 |
-
"target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
|
310 |
-
}
|
311 |
-
elif params.discretization == Discretization.EDM:
|
312 |
-
discretization_config = {
|
313 |
-
"target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization",
|
314 |
-
"params": {
|
315 |
-
"sigma_min": params.sigma_min,
|
316 |
-
"sigma_max": params.sigma_max,
|
317 |
-
"rho": params.rho,
|
318 |
-
},
|
319 |
-
}
|
320 |
-
else:
|
321 |
-
raise ValueError(f"unknown discretization {params.discretization}")
|
322 |
-
return discretization_config
|
323 |
-
|
324 |
-
|
325 |
-
def get_sampler_config(params: SamplingParams):
|
326 |
-
discretization_config = get_discretization_config(params)
|
327 |
-
guider_config = get_guider_config(params)
|
328 |
-
sampler = None
|
329 |
-
if params.sampler == Sampler.EULER_EDM:
|
330 |
-
return EulerEDMSampler(
|
331 |
-
num_steps=params.steps,
|
332 |
-
discretization_config=discretization_config,
|
333 |
-
guider_config=guider_config,
|
334 |
-
s_churn=params.s_churn,
|
335 |
-
s_tmin=params.s_tmin,
|
336 |
-
s_tmax=params.s_tmax,
|
337 |
-
s_noise=params.s_noise,
|
338 |
-
verbose=True,
|
339 |
-
)
|
340 |
-
if params.sampler == Sampler.HEUN_EDM:
|
341 |
-
return HeunEDMSampler(
|
342 |
-
num_steps=params.steps,
|
343 |
-
discretization_config=discretization_config,
|
344 |
-
guider_config=guider_config,
|
345 |
-
s_churn=params.s_churn,
|
346 |
-
s_tmin=params.s_tmin,
|
347 |
-
s_tmax=params.s_tmax,
|
348 |
-
s_noise=params.s_noise,
|
349 |
-
verbose=True,
|
350 |
-
)
|
351 |
-
if params.sampler == Sampler.EULER_ANCESTRAL:
|
352 |
-
return EulerAncestralSampler(
|
353 |
-
num_steps=params.steps,
|
354 |
-
discretization_config=discretization_config,
|
355 |
-
guider_config=guider_config,
|
356 |
-
eta=params.eta,
|
357 |
-
s_noise=params.s_noise,
|
358 |
-
verbose=True,
|
359 |
-
)
|
360 |
-
if params.sampler == Sampler.DPMPP2S_ANCESTRAL:
|
361 |
-
return DPMPP2SAncestralSampler(
|
362 |
-
num_steps=params.steps,
|
363 |
-
discretization_config=discretization_config,
|
364 |
-
guider_config=guider_config,
|
365 |
-
eta=params.eta,
|
366 |
-
s_noise=params.s_noise,
|
367 |
-
verbose=True,
|
368 |
-
)
|
369 |
-
if params.sampler == Sampler.DPMPP2M:
|
370 |
-
return DPMPP2MSampler(
|
371 |
-
num_steps=params.steps,
|
372 |
-
discretization_config=discretization_config,
|
373 |
-
guider_config=guider_config,
|
374 |
-
verbose=True,
|
375 |
-
)
|
376 |
-
if params.sampler == Sampler.LINEAR_MULTISTEP:
|
377 |
-
return LinearMultistepSampler(
|
378 |
-
num_steps=params.steps,
|
379 |
-
discretization_config=discretization_config,
|
380 |
-
guider_config=guider_config,
|
381 |
-
order=params.order,
|
382 |
-
verbose=True,
|
383 |
-
)
|
384 |
-
|
385 |
-
raise ValueError(f"unknown sampler {params.sampler}!")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sgm/inference/helpers.py
DELETED
@@ -1,305 +0,0 @@
|
|
1 |
-
import math
|
2 |
-
import os
|
3 |
-
from typing import List, Optional, Union
|
4 |
-
|
5 |
-
import numpy as np
|
6 |
-
import torch
|
7 |
-
from einops import rearrange
|
8 |
-
from imwatermark import WatermarkEncoder
|
9 |
-
from omegaconf import ListConfig
|
10 |
-
from PIL import Image
|
11 |
-
from torch import autocast
|
12 |
-
|
13 |
-
from sgm.util import append_dims
|
14 |
-
|
15 |
-
|
16 |
-
class WatermarkEmbedder:
|
17 |
-
def __init__(self, watermark):
|
18 |
-
self.watermark = watermark
|
19 |
-
self.num_bits = len(WATERMARK_BITS)
|
20 |
-
self.encoder = WatermarkEncoder()
|
21 |
-
self.encoder.set_watermark("bits", self.watermark)
|
22 |
-
|
23 |
-
def __call__(self, image: torch.Tensor) -> torch.Tensor:
|
24 |
-
"""
|
25 |
-
Adds a predefined watermark to the input image
|
26 |
-
|
27 |
-
Args:
|
28 |
-
image: ([N,] B, RGB, H, W) in range [0, 1]
|
29 |
-
|
30 |
-
Returns:
|
31 |
-
same as input but watermarked
|
32 |
-
"""
|
33 |
-
squeeze = len(image.shape) == 4
|
34 |
-
if squeeze:
|
35 |
-
image = image[None, ...]
|
36 |
-
n = image.shape[0]
|
37 |
-
image_np = rearrange(
|
38 |
-
(255 * image).detach().cpu(), "n b c h w -> (n b) h w c"
|
39 |
-
).numpy()[:, :, :, ::-1]
|
40 |
-
# torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
|
41 |
-
# watermarking libary expects input as cv2 BGR format
|
42 |
-
for k in range(image_np.shape[0]):
|
43 |
-
image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
|
44 |
-
image = torch.from_numpy(
|
45 |
-
rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)
|
46 |
-
).to(image.device)
|
47 |
-
image = torch.clamp(image / 255, min=0.0, max=1.0)
|
48 |
-
if squeeze:
|
49 |
-
image = image[0]
|
50 |
-
return image
|
51 |
-
|
52 |
-
|
53 |
-
# A fixed 48-bit message that was choosen at random
|
54 |
-
# WATERMARK_MESSAGE = 0xB3EC907BB19E
|
55 |
-
WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110
|
56 |
-
# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
|
57 |
-
WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
|
58 |
-
embed_watermark = WatermarkEmbedder(WATERMARK_BITS)
|
59 |
-
|
60 |
-
|
61 |
-
def get_unique_embedder_keys_from_conditioner(conditioner):
|
62 |
-
return list({x.input_key for x in conditioner.embedders})
|
63 |
-
|
64 |
-
|
65 |
-
def perform_save_locally(save_path, samples):
|
66 |
-
os.makedirs(os.path.join(save_path), exist_ok=True)
|
67 |
-
base_count = len(os.listdir(os.path.join(save_path)))
|
68 |
-
samples = embed_watermark(samples)
|
69 |
-
for sample in samples:
|
70 |
-
sample = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c")
|
71 |
-
Image.fromarray(sample.astype(np.uint8)).save(
|
72 |
-
os.path.join(save_path, f"{base_count:09}.png")
|
73 |
-
)
|
74 |
-
base_count += 1
|
75 |
-
|
76 |
-
|
77 |
-
class Img2ImgDiscretizationWrapper:
|
78 |
-
"""
|
79 |
-
wraps a discretizer, and prunes the sigmas
|
80 |
-
params:
|
81 |
-
strength: float between 0.0 and 1.0. 1.0 means full sampling (all sigmas are returned)
|
82 |
-
"""
|
83 |
-
|
84 |
-
def __init__(self, discretization, strength: float = 1.0):
|
85 |
-
self.discretization = discretization
|
86 |
-
self.strength = strength
|
87 |
-
assert 0.0 <= self.strength <= 1.0
|
88 |
-
|
89 |
-
def __call__(self, *args, **kwargs):
|
90 |
-
# sigmas start large first, and decrease then
|
91 |
-
sigmas = self.discretization(*args, **kwargs)
|
92 |
-
print(f"sigmas after discretization, before pruning img2img: ", sigmas)
|
93 |
-
sigmas = torch.flip(sigmas, (0,))
|
94 |
-
sigmas = sigmas[: max(int(self.strength * len(sigmas)), 1)]
|
95 |
-
print("prune index:", max(int(self.strength * len(sigmas)), 1))
|
96 |
-
sigmas = torch.flip(sigmas, (0,))
|
97 |
-
print(f"sigmas after pruning: ", sigmas)
|
98 |
-
return sigmas
|
99 |
-
|
100 |
-
|
101 |
-
def do_sample(
|
102 |
-
model,
|
103 |
-
sampler,
|
104 |
-
value_dict,
|
105 |
-
num_samples,
|
106 |
-
H,
|
107 |
-
W,
|
108 |
-
C,
|
109 |
-
F,
|
110 |
-
force_uc_zero_embeddings: Optional[List] = None,
|
111 |
-
batch2model_input: Optional[List] = None,
|
112 |
-
return_latents=False,
|
113 |
-
filter=None,
|
114 |
-
device="cuda",
|
115 |
-
):
|
116 |
-
if force_uc_zero_embeddings is None:
|
117 |
-
force_uc_zero_embeddings = []
|
118 |
-
if batch2model_input is None:
|
119 |
-
batch2model_input = []
|
120 |
-
|
121 |
-
with torch.no_grad():
|
122 |
-
with autocast(device) as precision_scope:
|
123 |
-
with model.ema_scope():
|
124 |
-
num_samples = [num_samples]
|
125 |
-
batch, batch_uc = get_batch(
|
126 |
-
get_unique_embedder_keys_from_conditioner(model.conditioner),
|
127 |
-
value_dict,
|
128 |
-
num_samples,
|
129 |
-
)
|
130 |
-
for key in batch:
|
131 |
-
if isinstance(batch[key], torch.Tensor):
|
132 |
-
print(key, batch[key].shape)
|
133 |
-
elif isinstance(batch[key], list):
|
134 |
-
print(key, [len(l) for l in batch[key]])
|
135 |
-
else:
|
136 |
-
print(key, batch[key])
|
137 |
-
c, uc = model.conditioner.get_unconditional_conditioning(
|
138 |
-
batch,
|
139 |
-
batch_uc=batch_uc,
|
140 |
-
force_uc_zero_embeddings=force_uc_zero_embeddings,
|
141 |
-
)
|
142 |
-
|
143 |
-
for k in c:
|
144 |
-
if not k == "crossattn":
|
145 |
-
c[k], uc[k] = map(
|
146 |
-
lambda y: y[k][: math.prod(num_samples)].to(device), (c, uc)
|
147 |
-
)
|
148 |
-
|
149 |
-
additional_model_inputs = {}
|
150 |
-
for k in batch2model_input:
|
151 |
-
additional_model_inputs[k] = batch[k]
|
152 |
-
|
153 |
-
shape = (math.prod(num_samples), C, H // F, W // F)
|
154 |
-
randn = torch.randn(shape).to(device)
|
155 |
-
|
156 |
-
def denoiser(input, sigma, c):
|
157 |
-
return model.denoiser(
|
158 |
-
model.model, input, sigma, c, **additional_model_inputs
|
159 |
-
)
|
160 |
-
|
161 |
-
samples_z = sampler(denoiser, randn, cond=c, uc=uc)
|
162 |
-
samples_x = model.decode_first_stage(samples_z)
|
163 |
-
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
|
164 |
-
|
165 |
-
if filter is not None:
|
166 |
-
samples = filter(samples)
|
167 |
-
|
168 |
-
if return_latents:
|
169 |
-
return samples, samples_z
|
170 |
-
return samples
|
171 |
-
|
172 |
-
|
173 |
-
def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
|
174 |
-
# Hardcoded demo setups; might undergo some changes in the future
|
175 |
-
|
176 |
-
batch = {}
|
177 |
-
batch_uc = {}
|
178 |
-
|
179 |
-
for key in keys:
|
180 |
-
if key == "txt":
|
181 |
-
batch["txt"] = (
|
182 |
-
np.repeat([value_dict["prompt"]], repeats=math.prod(N))
|
183 |
-
.reshape(N)
|
184 |
-
.tolist()
|
185 |
-
)
|
186 |
-
batch_uc["txt"] = (
|
187 |
-
np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N))
|
188 |
-
.reshape(N)
|
189 |
-
.tolist()
|
190 |
-
)
|
191 |
-
elif key == "original_size_as_tuple":
|
192 |
-
batch["original_size_as_tuple"] = (
|
193 |
-
torch.tensor([value_dict["orig_height"], value_dict["orig_width"]])
|
194 |
-
.to(device)
|
195 |
-
.repeat(*N, 1)
|
196 |
-
)
|
197 |
-
elif key == "crop_coords_top_left":
|
198 |
-
batch["crop_coords_top_left"] = (
|
199 |
-
torch.tensor(
|
200 |
-
[value_dict["crop_coords_top"], value_dict["crop_coords_left"]]
|
201 |
-
)
|
202 |
-
.to(device)
|
203 |
-
.repeat(*N, 1)
|
204 |
-
)
|
205 |
-
elif key == "aesthetic_score":
|
206 |
-
batch["aesthetic_score"] = (
|
207 |
-
torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
|
208 |
-
)
|
209 |
-
batch_uc["aesthetic_score"] = (
|
210 |
-
torch.tensor([value_dict["negative_aesthetic_score"]])
|
211 |
-
.to(device)
|
212 |
-
.repeat(*N, 1)
|
213 |
-
)
|
214 |
-
|
215 |
-
elif key == "target_size_as_tuple":
|
216 |
-
batch["target_size_as_tuple"] = (
|
217 |
-
torch.tensor([value_dict["target_height"], value_dict["target_width"]])
|
218 |
-
.to(device)
|
219 |
-
.repeat(*N, 1)
|
220 |
-
)
|
221 |
-
else:
|
222 |
-
batch[key] = value_dict[key]
|
223 |
-
|
224 |
-
for key in batch.keys():
|
225 |
-
if key not in batch_uc and isinstance(batch[key], torch.Tensor):
|
226 |
-
batch_uc[key] = torch.clone(batch[key])
|
227 |
-
return batch, batch_uc
|
228 |
-
|
229 |
-
|
230 |
-
def get_input_image_tensor(image: Image.Image, device="cuda"):
|
231 |
-
w, h = image.size
|
232 |
-
print(f"loaded input image of size ({w}, {h})")
|
233 |
-
width, height = map(
|
234 |
-
lambda x: x - x % 64, (w, h)
|
235 |
-
) # resize to integer multiple of 64
|
236 |
-
image = image.resize((width, height))
|
237 |
-
image_array = np.array(image.convert("RGB"))
|
238 |
-
image_array = image_array[None].transpose(0, 3, 1, 2)
|
239 |
-
image_tensor = torch.from_numpy(image_array).to(dtype=torch.float32) / 127.5 - 1.0
|
240 |
-
return image_tensor.to(device)
|
241 |
-
|
242 |
-
|
243 |
-
def do_img2img(
|
244 |
-
img,
|
245 |
-
model,
|
246 |
-
sampler,
|
247 |
-
value_dict,
|
248 |
-
num_samples,
|
249 |
-
force_uc_zero_embeddings=[],
|
250 |
-
additional_kwargs={},
|
251 |
-
offset_noise_level: float = 0.0,
|
252 |
-
return_latents=False,
|
253 |
-
skip_encode=False,
|
254 |
-
filter=None,
|
255 |
-
device="cuda",
|
256 |
-
):
|
257 |
-
with torch.no_grad():
|
258 |
-
with autocast(device) as precision_scope:
|
259 |
-
with model.ema_scope():
|
260 |
-
batch, batch_uc = get_batch(
|
261 |
-
get_unique_embedder_keys_from_conditioner(model.conditioner),
|
262 |
-
value_dict,
|
263 |
-
[num_samples],
|
264 |
-
)
|
265 |
-
c, uc = model.conditioner.get_unconditional_conditioning(
|
266 |
-
batch,
|
267 |
-
batch_uc=batch_uc,
|
268 |
-
force_uc_zero_embeddings=force_uc_zero_embeddings,
|
269 |
-
)
|
270 |
-
|
271 |
-
for k in c:
|
272 |
-
c[k], uc[k] = map(lambda y: y[k][:num_samples].to(device), (c, uc))
|
273 |
-
|
274 |
-
for k in additional_kwargs:
|
275 |
-
c[k] = uc[k] = additional_kwargs[k]
|
276 |
-
if skip_encode:
|
277 |
-
z = img
|
278 |
-
else:
|
279 |
-
z = model.encode_first_stage(img)
|
280 |
-
noise = torch.randn_like(z)
|
281 |
-
sigmas = sampler.discretization(sampler.num_steps)
|
282 |
-
sigma = sigmas[0].to(z.device)
|
283 |
-
|
284 |
-
if offset_noise_level > 0.0:
|
285 |
-
noise = noise + offset_noise_level * append_dims(
|
286 |
-
torch.randn(z.shape[0], device=z.device), z.ndim
|
287 |
-
)
|
288 |
-
noised_z = z + noise * append_dims(sigma, z.ndim)
|
289 |
-
noised_z = noised_z / torch.sqrt(
|
290 |
-
1.0 + sigmas[0] ** 2.0
|
291 |
-
) # Note: hardcoded to DDPM-like scaling. need to generalize later.
|
292 |
-
|
293 |
-
def denoiser(x, sigma, c):
|
294 |
-
return model.denoiser(model.model, x, sigma, c)
|
295 |
-
|
296 |
-
samples_z = sampler(denoiser, noised_z, cond=c, uc=uc)
|
297 |
-
samples_x = model.decode_first_stage(samples_z)
|
298 |
-
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
|
299 |
-
|
300 |
-
if filter is not None:
|
301 |
-
samples = filter(samples)
|
302 |
-
|
303 |
-
if return_latents:
|
304 |
-
return samples, samples_z
|
305 |
-
return samples
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sgm/lr_scheduler.py
DELETED
@@ -1,135 +0,0 @@
|
|
1 |
-
import numpy as np
|
2 |
-
|
3 |
-
|
4 |
-
class LambdaWarmUpCosineScheduler:
|
5 |
-
"""
|
6 |
-
note: use with a base_lr of 1.0
|
7 |
-
"""
|
8 |
-
|
9 |
-
def __init__(
|
10 |
-
self,
|
11 |
-
warm_up_steps,
|
12 |
-
lr_min,
|
13 |
-
lr_max,
|
14 |
-
lr_start,
|
15 |
-
max_decay_steps,
|
16 |
-
verbosity_interval=0,
|
17 |
-
):
|
18 |
-
self.lr_warm_up_steps = warm_up_steps
|
19 |
-
self.lr_start = lr_start
|
20 |
-
self.lr_min = lr_min
|
21 |
-
self.lr_max = lr_max
|
22 |
-
self.lr_max_decay_steps = max_decay_steps
|
23 |
-
self.last_lr = 0.0
|
24 |
-
self.verbosity_interval = verbosity_interval
|
25 |
-
|
26 |
-
def schedule(self, n, **kwargs):
|
27 |
-
if self.verbosity_interval > 0:
|
28 |
-
if n % self.verbosity_interval == 0:
|
29 |
-
print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
|
30 |
-
if n < self.lr_warm_up_steps:
|
31 |
-
lr = (
|
32 |
-
self.lr_max - self.lr_start
|
33 |
-
) / self.lr_warm_up_steps * n + self.lr_start
|
34 |
-
self.last_lr = lr
|
35 |
-
return lr
|
36 |
-
else:
|
37 |
-
t = (n - self.lr_warm_up_steps) / (
|
38 |
-
self.lr_max_decay_steps - self.lr_warm_up_steps
|
39 |
-
)
|
40 |
-
t = min(t, 1.0)
|
41 |
-
lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
|
42 |
-
1 + np.cos(t * np.pi)
|
43 |
-
)
|
44 |
-
self.last_lr = lr
|
45 |
-
return lr
|
46 |
-
|
47 |
-
def __call__(self, n, **kwargs):
|
48 |
-
return self.schedule(n, **kwargs)
|
49 |
-
|
50 |
-
|
51 |
-
class LambdaWarmUpCosineScheduler2:
|
52 |
-
"""
|
53 |
-
supports repeated iterations, configurable via lists
|
54 |
-
note: use with a base_lr of 1.0.
|
55 |
-
"""
|
56 |
-
|
57 |
-
def __init__(
|
58 |
-
self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0
|
59 |
-
):
|
60 |
-
assert (
|
61 |
-
len(warm_up_steps)
|
62 |
-
== len(f_min)
|
63 |
-
== len(f_max)
|
64 |
-
== len(f_start)
|
65 |
-
== len(cycle_lengths)
|
66 |
-
)
|
67 |
-
self.lr_warm_up_steps = warm_up_steps
|
68 |
-
self.f_start = f_start
|
69 |
-
self.f_min = f_min
|
70 |
-
self.f_max = f_max
|
71 |
-
self.cycle_lengths = cycle_lengths
|
72 |
-
self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
|
73 |
-
self.last_f = 0.0
|
74 |
-
self.verbosity_interval = verbosity_interval
|
75 |
-
|
76 |
-
def find_in_interval(self, n):
|
77 |
-
interval = 0
|
78 |
-
for cl in self.cum_cycles[1:]:
|
79 |
-
if n <= cl:
|
80 |
-
return interval
|
81 |
-
interval += 1
|
82 |
-
|
83 |
-
def schedule(self, n, **kwargs):
|
84 |
-
cycle = self.find_in_interval(n)
|
85 |
-
n = n - self.cum_cycles[cycle]
|
86 |
-
if self.verbosity_interval > 0:
|
87 |
-
if n % self.verbosity_interval == 0:
|
88 |
-
print(
|
89 |
-
f"current step: {n}, recent lr-multiplier: {self.last_f}, "
|
90 |
-
f"current cycle {cycle}"
|
91 |
-
)
|
92 |
-
if n < self.lr_warm_up_steps[cycle]:
|
93 |
-
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
|
94 |
-
cycle
|
95 |
-
] * n + self.f_start[cycle]
|
96 |
-
self.last_f = f
|
97 |
-
return f
|
98 |
-
else:
|
99 |
-
t = (n - self.lr_warm_up_steps[cycle]) / (
|
100 |
-
self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]
|
101 |
-
)
|
102 |
-
t = min(t, 1.0)
|
103 |
-
f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
|
104 |
-
1 + np.cos(t * np.pi)
|
105 |
-
)
|
106 |
-
self.last_f = f
|
107 |
-
return f
|
108 |
-
|
109 |
-
def __call__(self, n, **kwargs):
|
110 |
-
return self.schedule(n, **kwargs)
|
111 |
-
|
112 |
-
|
113 |
-
class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
|
114 |
-
def schedule(self, n, **kwargs):
|
115 |
-
cycle = self.find_in_interval(n)
|
116 |
-
n = n - self.cum_cycles[cycle]
|
117 |
-
if self.verbosity_interval > 0:
|
118 |
-
if n % self.verbosity_interval == 0:
|
119 |
-
print(
|
120 |
-
f"current step: {n}, recent lr-multiplier: {self.last_f}, "
|
121 |
-
f"current cycle {cycle}"
|
122 |
-
)
|
123 |
-
|
124 |
-
if n < self.lr_warm_up_steps[cycle]:
|
125 |
-
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
|
126 |
-
cycle
|
127 |
-
] * n + self.f_start[cycle]
|
128 |
-
self.last_f = f
|
129 |
-
return f
|
130 |
-
else:
|
131 |
-
f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (
|
132 |
-
self.cycle_lengths[cycle] - n
|
133 |
-
) / (self.cycle_lengths[cycle])
|
134 |
-
self.last_f = f
|
135 |
-
return f
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sgm/models/__init__.py
DELETED
@@ -1,2 +0,0 @@
|
|
1 |
-
from .autoencoder import AutoencodingEngine
|
2 |
-
from .diffusion import DiffusionEngine
|
|
|
|
|
|
sgm/models/autoencoder.py
DELETED
@@ -1,615 +0,0 @@
|
|
1 |
-
import logging
|
2 |
-
import math
|
3 |
-
import re
|
4 |
-
from abc import abstractmethod
|
5 |
-
from contextlib import contextmanager
|
6 |
-
from typing import Any, Dict, List, Optional, Tuple, Union
|
7 |
-
|
8 |
-
import pytorch_lightning as pl
|
9 |
-
import torch
|
10 |
-
import torch.nn as nn
|
11 |
-
from einops import rearrange
|
12 |
-
from packaging import version
|
13 |
-
|
14 |
-
from ..modules.autoencoding.regularizers import AbstractRegularizer
|
15 |
-
from ..modules.ema import LitEma
|
16 |
-
from ..util import (default, get_nested_attribute, get_obj_from_str,
|
17 |
-
instantiate_from_config)
|
18 |
-
|
19 |
-
logpy = logging.getLogger(__name__)
|
20 |
-
|
21 |
-
|
22 |
-
class AbstractAutoencoder(pl.LightningModule):
|
23 |
-
"""
|
24 |
-
This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators,
|
25 |
-
unCLIP models, etc. Hence, it is fairly general, and specific features
|
26 |
-
(e.g. discriminator training, encoding, decoding) must be implemented in subclasses.
|
27 |
-
"""
|
28 |
-
|
29 |
-
def __init__(
|
30 |
-
self,
|
31 |
-
ema_decay: Union[None, float] = None,
|
32 |
-
monitor: Union[None, str] = None,
|
33 |
-
input_key: str = "jpg",
|
34 |
-
):
|
35 |
-
super().__init__()
|
36 |
-
|
37 |
-
self.input_key = input_key
|
38 |
-
self.use_ema = ema_decay is not None
|
39 |
-
if monitor is not None:
|
40 |
-
self.monitor = monitor
|
41 |
-
|
42 |
-
if self.use_ema:
|
43 |
-
self.model_ema = LitEma(self, decay=ema_decay)
|
44 |
-
logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
45 |
-
|
46 |
-
if version.parse(torch.__version__) >= version.parse("2.0.0"):
|
47 |
-
self.automatic_optimization = False
|
48 |
-
|
49 |
-
def apply_ckpt(self, ckpt: Union[None, str, dict]):
|
50 |
-
if ckpt is None:
|
51 |
-
return
|
52 |
-
if isinstance(ckpt, str):
|
53 |
-
ckpt = {
|
54 |
-
"target": "sgm.modules.checkpoint.CheckpointEngine",
|
55 |
-
"params": {"ckpt_path": ckpt},
|
56 |
-
}
|
57 |
-
engine = instantiate_from_config(ckpt)
|
58 |
-
engine(self)
|
59 |
-
|
60 |
-
@abstractmethod
|
61 |
-
def get_input(self, batch) -> Any:
|
62 |
-
raise NotImplementedError()
|
63 |
-
|
64 |
-
def on_train_batch_end(self, *args, **kwargs):
|
65 |
-
# for EMA computation
|
66 |
-
if self.use_ema:
|
67 |
-
self.model_ema(self)
|
68 |
-
|
69 |
-
@contextmanager
|
70 |
-
def ema_scope(self, context=None):
|
71 |
-
if self.use_ema:
|
72 |
-
self.model_ema.store(self.parameters())
|
73 |
-
self.model_ema.copy_to(self)
|
74 |
-
if context is not None:
|
75 |
-
logpy.info(f"{context}: Switched to EMA weights")
|
76 |
-
try:
|
77 |
-
yield None
|
78 |
-
finally:
|
79 |
-
if self.use_ema:
|
80 |
-
self.model_ema.restore(self.parameters())
|
81 |
-
if context is not None:
|
82 |
-
logpy.info(f"{context}: Restored training weights")
|
83 |
-
|
84 |
-
@abstractmethod
|
85 |
-
def encode(self, *args, **kwargs) -> torch.Tensor:
|
86 |
-
raise NotImplementedError("encode()-method of abstract base class called")
|
87 |
-
|
88 |
-
@abstractmethod
|
89 |
-
def decode(self, *args, **kwargs) -> torch.Tensor:
|
90 |
-
raise NotImplementedError("decode()-method of abstract base class called")
|
91 |
-
|
92 |
-
def instantiate_optimizer_from_config(self, params, lr, cfg):
|
93 |
-
logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
|
94 |
-
return get_obj_from_str(cfg["target"])(
|
95 |
-
params, lr=lr, **cfg.get("params", dict())
|
96 |
-
)
|
97 |
-
|
98 |
-
def configure_optimizers(self) -> Any:
|
99 |
-
raise NotImplementedError()
|
100 |
-
|
101 |
-
|
102 |
-
class AutoencodingEngine(AbstractAutoencoder):
|
103 |
-
"""
|
104 |
-
Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL
|
105 |
-
(we also restore them explicitly as special cases for legacy reasons).
|
106 |
-
Regularizations such as KL or VQ are moved to the regularizer class.
|
107 |
-
"""
|
108 |
-
|
109 |
-
def __init__(
|
110 |
-
self,
|
111 |
-
*args,
|
112 |
-
encoder_config: Dict,
|
113 |
-
decoder_config: Dict,
|
114 |
-
loss_config: Dict,
|
115 |
-
regularizer_config: Dict,
|
116 |
-
optimizer_config: Union[Dict, None] = None,
|
117 |
-
lr_g_factor: float = 1.0,
|
118 |
-
trainable_ae_params: Optional[List[List[str]]] = None,
|
119 |
-
ae_optimizer_args: Optional[List[dict]] = None,
|
120 |
-
trainable_disc_params: Optional[List[List[str]]] = None,
|
121 |
-
disc_optimizer_args: Optional[List[dict]] = None,
|
122 |
-
disc_start_iter: int = 0,
|
123 |
-
diff_boost_factor: float = 3.0,
|
124 |
-
ckpt_engine: Union[None, str, dict] = None,
|
125 |
-
ckpt_path: Optional[str] = None,
|
126 |
-
additional_decode_keys: Optional[List[str]] = None,
|
127 |
-
**kwargs,
|
128 |
-
):
|
129 |
-
super().__init__(*args, **kwargs)
|
130 |
-
self.automatic_optimization = False # pytorch lightning
|
131 |
-
|
132 |
-
self.encoder: torch.nn.Module = instantiate_from_config(encoder_config)
|
133 |
-
self.decoder: torch.nn.Module = instantiate_from_config(decoder_config)
|
134 |
-
self.loss: torch.nn.Module = instantiate_from_config(loss_config)
|
135 |
-
self.regularization: AbstractRegularizer = instantiate_from_config(
|
136 |
-
regularizer_config
|
137 |
-
)
|
138 |
-
self.optimizer_config = default(
|
139 |
-
optimizer_config, {"target": "torch.optim.Adam"}
|
140 |
-
)
|
141 |
-
self.diff_boost_factor = diff_boost_factor
|
142 |
-
self.disc_start_iter = disc_start_iter
|
143 |
-
self.lr_g_factor = lr_g_factor
|
144 |
-
self.trainable_ae_params = trainable_ae_params
|
145 |
-
if self.trainable_ae_params is not None:
|
146 |
-
self.ae_optimizer_args = default(
|
147 |
-
ae_optimizer_args,
|
148 |
-
[{} for _ in range(len(self.trainable_ae_params))],
|
149 |
-
)
|
150 |
-
assert len(self.ae_optimizer_args) == len(self.trainable_ae_params)
|
151 |
-
else:
|
152 |
-
self.ae_optimizer_args = [{}] # makes type consitent
|
153 |
-
|
154 |
-
self.trainable_disc_params = trainable_disc_params
|
155 |
-
if self.trainable_disc_params is not None:
|
156 |
-
self.disc_optimizer_args = default(
|
157 |
-
disc_optimizer_args,
|
158 |
-
[{} for _ in range(len(self.trainable_disc_params))],
|
159 |
-
)
|
160 |
-
assert len(self.disc_optimizer_args) == len(self.trainable_disc_params)
|
161 |
-
else:
|
162 |
-
self.disc_optimizer_args = [{}] # makes type consitent
|
163 |
-
|
164 |
-
if ckpt_path is not None:
|
165 |
-
assert ckpt_engine is None, "Can't set ckpt_engine and ckpt_path"
|
166 |
-
logpy.warn("Checkpoint path is deprecated, use `checkpoint_egnine` instead")
|
167 |
-
self.apply_ckpt(default(ckpt_path, ckpt_engine))
|
168 |
-
self.additional_decode_keys = set(default(additional_decode_keys, []))
|
169 |
-
|
170 |
-
def get_input(self, batch: Dict) -> torch.Tensor:
|
171 |
-
# assuming unified data format, dataloader returns a dict.
|
172 |
-
# image tensors should be scaled to -1 ... 1 and in channels-first
|
173 |
-
# format (e.g., bchw instead if bhwc)
|
174 |
-
return batch[self.input_key]
|
175 |
-
|
176 |
-
def get_autoencoder_params(self) -> list:
|
177 |
-
params = []
|
178 |
-
if hasattr(self.loss, "get_trainable_autoencoder_parameters"):
|
179 |
-
params += list(self.loss.get_trainable_autoencoder_parameters())
|
180 |
-
if hasattr(self.regularization, "get_trainable_parameters"):
|
181 |
-
params += list(self.regularization.get_trainable_parameters())
|
182 |
-
params = params + list(self.encoder.parameters())
|
183 |
-
params = params + list(self.decoder.parameters())
|
184 |
-
return params
|
185 |
-
|
186 |
-
def get_discriminator_params(self) -> list:
|
187 |
-
if hasattr(self.loss, "get_trainable_parameters"):
|
188 |
-
params = list(self.loss.get_trainable_parameters()) # e.g., discriminator
|
189 |
-
else:
|
190 |
-
params = []
|
191 |
-
return params
|
192 |
-
|
193 |
-
def get_last_layer(self):
|
194 |
-
return self.decoder.get_last_layer()
|
195 |
-
|
196 |
-
def encode(
|
197 |
-
self,
|
198 |
-
x: torch.Tensor,
|
199 |
-
return_reg_log: bool = False,
|
200 |
-
unregularized: bool = False,
|
201 |
-
) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
|
202 |
-
z = self.encoder(x)
|
203 |
-
if unregularized:
|
204 |
-
return z, dict()
|
205 |
-
z, reg_log = self.regularization(z)
|
206 |
-
if return_reg_log:
|
207 |
-
return z, reg_log
|
208 |
-
return z
|
209 |
-
|
210 |
-
def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor:
|
211 |
-
x = self.decoder(z, **kwargs)
|
212 |
-
return x
|
213 |
-
|
214 |
-
def forward(
|
215 |
-
self, x: torch.Tensor, **additional_decode_kwargs
|
216 |
-
) -> Tuple[torch.Tensor, torch.Tensor, dict]:
|
217 |
-
z, reg_log = self.encode(x, return_reg_log=True)
|
218 |
-
dec = self.decode(z, **additional_decode_kwargs)
|
219 |
-
return z, dec, reg_log
|
220 |
-
|
221 |
-
def inner_training_step(
|
222 |
-
self, batch: dict, batch_idx: int, optimizer_idx: int = 0
|
223 |
-
) -> torch.Tensor:
|
224 |
-
x = self.get_input(batch)
|
225 |
-
additional_decode_kwargs = {
|
226 |
-
key: batch[key] for key in self.additional_decode_keys.intersection(batch)
|
227 |
-
}
|
228 |
-
z, xrec, regularization_log = self(x, **additional_decode_kwargs)
|
229 |
-
if hasattr(self.loss, "forward_keys"):
|
230 |
-
extra_info = {
|
231 |
-
"z": z,
|
232 |
-
"optimizer_idx": optimizer_idx,
|
233 |
-
"global_step": self.global_step,
|
234 |
-
"last_layer": self.get_last_layer(),
|
235 |
-
"split": "train",
|
236 |
-
"regularization_log": regularization_log,
|
237 |
-
"autoencoder": self,
|
238 |
-
}
|
239 |
-
extra_info = {k: extra_info[k] for k in self.loss.forward_keys}
|
240 |
-
else:
|
241 |
-
extra_info = dict()
|
242 |
-
|
243 |
-
if optimizer_idx == 0:
|
244 |
-
# autoencode
|
245 |
-
out_loss = self.loss(x, xrec, **extra_info)
|
246 |
-
if isinstance(out_loss, tuple):
|
247 |
-
aeloss, log_dict_ae = out_loss
|
248 |
-
else:
|
249 |
-
# simple loss function
|
250 |
-
aeloss = out_loss
|
251 |
-
log_dict_ae = {"train/loss/rec": aeloss.detach()}
|
252 |
-
|
253 |
-
self.log_dict(
|
254 |
-
log_dict_ae,
|
255 |
-
prog_bar=False,
|
256 |
-
logger=True,
|
257 |
-
on_step=True,
|
258 |
-
on_epoch=True,
|
259 |
-
sync_dist=False,
|
260 |
-
)
|
261 |
-
self.log(
|
262 |
-
"loss",
|
263 |
-
aeloss.mean().detach(),
|
264 |
-
prog_bar=True,
|
265 |
-
logger=False,
|
266 |
-
on_epoch=False,
|
267 |
-
on_step=True,
|
268 |
-
)
|
269 |
-
return aeloss
|
270 |
-
elif optimizer_idx == 1:
|
271 |
-
# discriminator
|
272 |
-
discloss, log_dict_disc = self.loss(x, xrec, **extra_info)
|
273 |
-
# -> discriminator always needs to return a tuple
|
274 |
-
self.log_dict(
|
275 |
-
log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True
|
276 |
-
)
|
277 |
-
return discloss
|
278 |
-
else:
|
279 |
-
raise NotImplementedError(f"Unknown optimizer {optimizer_idx}")
|
280 |
-
|
281 |
-
def training_step(self, batch: dict, batch_idx: int):
|
282 |
-
opts = self.optimizers()
|
283 |
-
if not isinstance(opts, list):
|
284 |
-
# Non-adversarial case
|
285 |
-
opts = [opts]
|
286 |
-
optimizer_idx = batch_idx % len(opts)
|
287 |
-
if self.global_step < self.disc_start_iter:
|
288 |
-
optimizer_idx = 0
|
289 |
-
opt = opts[optimizer_idx]
|
290 |
-
opt.zero_grad()
|
291 |
-
with opt.toggle_model():
|
292 |
-
loss = self.inner_training_step(
|
293 |
-
batch, batch_idx, optimizer_idx=optimizer_idx
|
294 |
-
)
|
295 |
-
self.manual_backward(loss)
|
296 |
-
opt.step()
|
297 |
-
|
298 |
-
def validation_step(self, batch: dict, batch_idx: int) -> Dict:
|
299 |
-
log_dict = self._validation_step(batch, batch_idx)
|
300 |
-
with self.ema_scope():
|
301 |
-
log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
|
302 |
-
log_dict.update(log_dict_ema)
|
303 |
-
return log_dict
|
304 |
-
|
305 |
-
def _validation_step(self, batch: dict, batch_idx: int, postfix: str = "") -> Dict:
|
306 |
-
x = self.get_input(batch)
|
307 |
-
|
308 |
-
z, xrec, regularization_log = self(x)
|
309 |
-
if hasattr(self.loss, "forward_keys"):
|
310 |
-
extra_info = {
|
311 |
-
"z": z,
|
312 |
-
"optimizer_idx": 0,
|
313 |
-
"global_step": self.global_step,
|
314 |
-
"last_layer": self.get_last_layer(),
|
315 |
-
"split": "val" + postfix,
|
316 |
-
"regularization_log": regularization_log,
|
317 |
-
"autoencoder": self,
|
318 |
-
}
|
319 |
-
extra_info = {k: extra_info[k] for k in self.loss.forward_keys}
|
320 |
-
else:
|
321 |
-
extra_info = dict()
|
322 |
-
out_loss = self.loss(x, xrec, **extra_info)
|
323 |
-
if isinstance(out_loss, tuple):
|
324 |
-
aeloss, log_dict_ae = out_loss
|
325 |
-
else:
|
326 |
-
# simple loss function
|
327 |
-
aeloss = out_loss
|
328 |
-
log_dict_ae = {f"val{postfix}/loss/rec": aeloss.detach()}
|
329 |
-
full_log_dict = log_dict_ae
|
330 |
-
|
331 |
-
if "optimizer_idx" in extra_info:
|
332 |
-
extra_info["optimizer_idx"] = 1
|
333 |
-
discloss, log_dict_disc = self.loss(x, xrec, **extra_info)
|
334 |
-
full_log_dict.update(log_dict_disc)
|
335 |
-
self.log(
|
336 |
-
f"val{postfix}/loss/rec",
|
337 |
-
log_dict_ae[f"val{postfix}/loss/rec"],
|
338 |
-
sync_dist=True,
|
339 |
-
)
|
340 |
-
self.log_dict(full_log_dict, sync_dist=True)
|
341 |
-
return full_log_dict
|
342 |
-
|
343 |
-
def get_param_groups(
|
344 |
-
self, parameter_names: List[List[str]], optimizer_args: List[dict]
|
345 |
-
) -> Tuple[List[Dict[str, Any]], int]:
|
346 |
-
groups = []
|
347 |
-
num_params = 0
|
348 |
-
for names, args in zip(parameter_names, optimizer_args):
|
349 |
-
params = []
|
350 |
-
for pattern_ in names:
|
351 |
-
pattern_params = []
|
352 |
-
pattern = re.compile(pattern_)
|
353 |
-
for p_name, param in self.named_parameters():
|
354 |
-
if re.match(pattern, p_name):
|
355 |
-
pattern_params.append(param)
|
356 |
-
num_params += param.numel()
|
357 |
-
if len(pattern_params) == 0:
|
358 |
-
logpy.warn(f"Did not find parameters for pattern {pattern_}")
|
359 |
-
params.extend(pattern_params)
|
360 |
-
groups.append({"params": params, **args})
|
361 |
-
return groups, num_params
|
362 |
-
|
363 |
-
def configure_optimizers(self) -> List[torch.optim.Optimizer]:
|
364 |
-
if self.trainable_ae_params is None:
|
365 |
-
ae_params = self.get_autoencoder_params()
|
366 |
-
else:
|
367 |
-
ae_params, num_ae_params = self.get_param_groups(
|
368 |
-
self.trainable_ae_params, self.ae_optimizer_args
|
369 |
-
)
|
370 |
-
logpy.info(f"Number of trainable autoencoder parameters: {num_ae_params:,}")
|
371 |
-
if self.trainable_disc_params is None:
|
372 |
-
disc_params = self.get_discriminator_params()
|
373 |
-
else:
|
374 |
-
disc_params, num_disc_params = self.get_param_groups(
|
375 |
-
self.trainable_disc_params, self.disc_optimizer_args
|
376 |
-
)
|
377 |
-
logpy.info(
|
378 |
-
f"Number of trainable discriminator parameters: {num_disc_params:,}"
|
379 |
-
)
|
380 |
-
opt_ae = self.instantiate_optimizer_from_config(
|
381 |
-
ae_params,
|
382 |
-
default(self.lr_g_factor, 1.0) * self.learning_rate,
|
383 |
-
self.optimizer_config,
|
384 |
-
)
|
385 |
-
opts = [opt_ae]
|
386 |
-
if len(disc_params) > 0:
|
387 |
-
opt_disc = self.instantiate_optimizer_from_config(
|
388 |
-
disc_params, self.learning_rate, self.optimizer_config
|
389 |
-
)
|
390 |
-
opts.append(opt_disc)
|
391 |
-
|
392 |
-
return opts
|
393 |
-
|
394 |
-
@torch.no_grad()
|
395 |
-
def log_images(
|
396 |
-
self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs
|
397 |
-
) -> dict:
|
398 |
-
log = dict()
|
399 |
-
additional_decode_kwargs = {}
|
400 |
-
x = self.get_input(batch)
|
401 |
-
additional_decode_kwargs.update(
|
402 |
-
{key: batch[key] for key in self.additional_decode_keys.intersection(batch)}
|
403 |
-
)
|
404 |
-
|
405 |
-
_, xrec, _ = self(x, **additional_decode_kwargs)
|
406 |
-
log["inputs"] = x
|
407 |
-
log["reconstructions"] = xrec
|
408 |
-
diff = 0.5 * torch.abs(torch.clamp(xrec, -1.0, 1.0) - x)
|
409 |
-
diff.clamp_(0, 1.0)
|
410 |
-
log["diff"] = 2.0 * diff - 1.0
|
411 |
-
# diff_boost shows location of small errors, by boosting their
|
412 |
-
# brightness.
|
413 |
-
log["diff_boost"] = (
|
414 |
-
2.0 * torch.clamp(self.diff_boost_factor * diff, 0.0, 1.0) - 1
|
415 |
-
)
|
416 |
-
if hasattr(self.loss, "log_images"):
|
417 |
-
log.update(self.loss.log_images(x, xrec))
|
418 |
-
with self.ema_scope():
|
419 |
-
_, xrec_ema, _ = self(x, **additional_decode_kwargs)
|
420 |
-
log["reconstructions_ema"] = xrec_ema
|
421 |
-
diff_ema = 0.5 * torch.abs(torch.clamp(xrec_ema, -1.0, 1.0) - x)
|
422 |
-
diff_ema.clamp_(0, 1.0)
|
423 |
-
log["diff_ema"] = 2.0 * diff_ema - 1.0
|
424 |
-
log["diff_boost_ema"] = (
|
425 |
-
2.0 * torch.clamp(self.diff_boost_factor * diff_ema, 0.0, 1.0) - 1
|
426 |
-
)
|
427 |
-
if additional_log_kwargs:
|
428 |
-
additional_decode_kwargs.update(additional_log_kwargs)
|
429 |
-
_, xrec_add, _ = self(x, **additional_decode_kwargs)
|
430 |
-
log_str = "reconstructions-" + "-".join(
|
431 |
-
[f"{key}={additional_log_kwargs[key]}" for key in additional_log_kwargs]
|
432 |
-
)
|
433 |
-
log[log_str] = xrec_add
|
434 |
-
return log
|
435 |
-
|
436 |
-
|
437 |
-
class AutoencodingEngineLegacy(AutoencodingEngine):
|
438 |
-
def __init__(self, embed_dim: int, **kwargs):
|
439 |
-
self.max_batch_size = kwargs.pop("max_batch_size", None)
|
440 |
-
ddconfig = kwargs.pop("ddconfig")
|
441 |
-
ckpt_path = kwargs.pop("ckpt_path", None)
|
442 |
-
ckpt_engine = kwargs.pop("ckpt_engine", None)
|
443 |
-
super().__init__(
|
444 |
-
encoder_config={
|
445 |
-
"target": "sgm.modules.diffusionmodules.model.Encoder",
|
446 |
-
"params": ddconfig,
|
447 |
-
},
|
448 |
-
decoder_config={
|
449 |
-
"target": "sgm.modules.diffusionmodules.model.Decoder",
|
450 |
-
"params": ddconfig,
|
451 |
-
},
|
452 |
-
**kwargs,
|
453 |
-
)
|
454 |
-
self.quant_conv = torch.nn.Conv2d(
|
455 |
-
(1 + ddconfig["double_z"]) * ddconfig["z_channels"],
|
456 |
-
(1 + ddconfig["double_z"]) * embed_dim,
|
457 |
-
1,
|
458 |
-
)
|
459 |
-
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
|
460 |
-
self.embed_dim = embed_dim
|
461 |
-
|
462 |
-
self.apply_ckpt(default(ckpt_path, ckpt_engine))
|
463 |
-
|
464 |
-
def get_autoencoder_params(self) -> list:
|
465 |
-
params = super().get_autoencoder_params()
|
466 |
-
return params
|
467 |
-
|
468 |
-
def encode(
|
469 |
-
self, x: torch.Tensor, return_reg_log: bool = False
|
470 |
-
) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
|
471 |
-
if self.max_batch_size is None:
|
472 |
-
z = self.encoder(x)
|
473 |
-
z = self.quant_conv(z)
|
474 |
-
else:
|
475 |
-
N = x.shape[0]
|
476 |
-
bs = self.max_batch_size
|
477 |
-
n_batches = int(math.ceil(N / bs))
|
478 |
-
z = list()
|
479 |
-
for i_batch in range(n_batches):
|
480 |
-
z_batch = self.encoder(x[i_batch * bs : (i_batch + 1) * bs])
|
481 |
-
z_batch = self.quant_conv(z_batch)
|
482 |
-
z.append(z_batch)
|
483 |
-
z = torch.cat(z, 0)
|
484 |
-
|
485 |
-
z, reg_log = self.regularization(z)
|
486 |
-
if return_reg_log:
|
487 |
-
return z, reg_log
|
488 |
-
return z
|
489 |
-
|
490 |
-
def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor:
|
491 |
-
if self.max_batch_size is None:
|
492 |
-
dec = self.post_quant_conv(z)
|
493 |
-
dec = self.decoder(dec, **decoder_kwargs)
|
494 |
-
else:
|
495 |
-
N = z.shape[0]
|
496 |
-
bs = self.max_batch_size
|
497 |
-
n_batches = int(math.ceil(N / bs))
|
498 |
-
dec = list()
|
499 |
-
for i_batch in range(n_batches):
|
500 |
-
dec_batch = self.post_quant_conv(z[i_batch * bs : (i_batch + 1) * bs])
|
501 |
-
dec_batch = self.decoder(dec_batch, **decoder_kwargs)
|
502 |
-
dec.append(dec_batch)
|
503 |
-
dec = torch.cat(dec, 0)
|
504 |
-
|
505 |
-
return dec
|
506 |
-
|
507 |
-
|
508 |
-
class AutoencoderKL(AutoencodingEngineLegacy):
|
509 |
-
def __init__(self, **kwargs):
|
510 |
-
if "lossconfig" in kwargs:
|
511 |
-
kwargs["loss_config"] = kwargs.pop("lossconfig")
|
512 |
-
super().__init__(
|
513 |
-
regularizer_config={
|
514 |
-
"target": (
|
515 |
-
"sgm.modules.autoencoding.regularizers"
|
516 |
-
".DiagonalGaussianRegularizer"
|
517 |
-
)
|
518 |
-
},
|
519 |
-
**kwargs,
|
520 |
-
)
|
521 |
-
|
522 |
-
|
523 |
-
class AutoencoderLegacyVQ(AutoencodingEngineLegacy):
|
524 |
-
def __init__(
|
525 |
-
self,
|
526 |
-
embed_dim: int,
|
527 |
-
n_embed: int,
|
528 |
-
sane_index_shape: bool = False,
|
529 |
-
**kwargs,
|
530 |
-
):
|
531 |
-
if "lossconfig" in kwargs:
|
532 |
-
logpy.warn(f"Parameter `lossconfig` is deprecated, use `loss_config`.")
|
533 |
-
kwargs["loss_config"] = kwargs.pop("lossconfig")
|
534 |
-
super().__init__(
|
535 |
-
regularizer_config={
|
536 |
-
"target": (
|
537 |
-
"sgm.modules.autoencoding.regularizers.quantize" ".VectorQuantizer"
|
538 |
-
),
|
539 |
-
"params": {
|
540 |
-
"n_e": n_embed,
|
541 |
-
"e_dim": embed_dim,
|
542 |
-
"sane_index_shape": sane_index_shape,
|
543 |
-
},
|
544 |
-
},
|
545 |
-
**kwargs,
|
546 |
-
)
|
547 |
-
|
548 |
-
|
549 |
-
class IdentityFirstStage(AbstractAutoencoder):
|
550 |
-
def __init__(self, *args, **kwargs):
|
551 |
-
super().__init__(*args, **kwargs)
|
552 |
-
|
553 |
-
def get_input(self, x: Any) -> Any:
|
554 |
-
return x
|
555 |
-
|
556 |
-
def encode(self, x: Any, *args, **kwargs) -> Any:
|
557 |
-
return x
|
558 |
-
|
559 |
-
def decode(self, x: Any, *args, **kwargs) -> Any:
|
560 |
-
return x
|
561 |
-
|
562 |
-
|
563 |
-
class AEIntegerWrapper(nn.Module):
|
564 |
-
def __init__(
|
565 |
-
self,
|
566 |
-
model: nn.Module,
|
567 |
-
shape: Union[None, Tuple[int, int], List[int]] = (16, 16),
|
568 |
-
regularization_key: str = "regularization",
|
569 |
-
encoder_kwargs: Optional[Dict[str, Any]] = None,
|
570 |
-
):
|
571 |
-
super().__init__()
|
572 |
-
self.model = model
|
573 |
-
assert hasattr(model, "encode") and hasattr(
|
574 |
-
model, "decode"
|
575 |
-
), "Need AE interface"
|
576 |
-
self.regularization = get_nested_attribute(model, regularization_key)
|
577 |
-
self.shape = shape
|
578 |
-
self.encoder_kwargs = default(encoder_kwargs, {"return_reg_log": True})
|
579 |
-
|
580 |
-
def encode(self, x) -> torch.Tensor:
|
581 |
-
assert (
|
582 |
-
not self.training
|
583 |
-
), f"{self.__class__.__name__} only supports inference currently"
|
584 |
-
_, log = self.model.encode(x, **self.encoder_kwargs)
|
585 |
-
assert isinstance(log, dict)
|
586 |
-
inds = log["min_encoding_indices"]
|
587 |
-
return rearrange(inds, "b ... -> b (...)")
|
588 |
-
|
589 |
-
def decode(
|
590 |
-
self, inds: torch.Tensor, shape: Union[None, tuple, list] = None
|
591 |
-
) -> torch.Tensor:
|
592 |
-
# expect inds shape (b, s) with s = h*w
|
593 |
-
shape = default(shape, self.shape) # Optional[(h, w)]
|
594 |
-
if shape is not None:
|
595 |
-
assert len(shape) == 2, f"Unhandeled shape {shape}"
|
596 |
-
inds = rearrange(inds, "b (h w) -> b h w", h=shape[0], w=shape[1])
|
597 |
-
h = self.regularization.get_codebook_entry(inds) # (b, h, w, c)
|
598 |
-
h = rearrange(h, "b h w c -> b c h w")
|
599 |
-
return self.model.decode(h)
|
600 |
-
|
601 |
-
|
602 |
-
class AutoencoderKLModeOnly(AutoencodingEngineLegacy):
|
603 |
-
def __init__(self, **kwargs):
|
604 |
-
if "lossconfig" in kwargs:
|
605 |
-
kwargs["loss_config"] = kwargs.pop("lossconfig")
|
606 |
-
super().__init__(
|
607 |
-
regularizer_config={
|
608 |
-
"target": (
|
609 |
-
"sgm.modules.autoencoding.regularizers"
|
610 |
-
".DiagonalGaussianRegularizer"
|
611 |
-
),
|
612 |
-
"params": {"sample": False},
|
613 |
-
},
|
614 |
-
**kwargs,
|
615 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sgm/models/diffusion.py
DELETED
@@ -1,341 +0,0 @@
|
|
1 |
-
import math
|
2 |
-
from contextlib import contextmanager
|
3 |
-
from typing import Any, Dict, List, Optional, Tuple, Union
|
4 |
-
|
5 |
-
import pytorch_lightning as pl
|
6 |
-
import torch
|
7 |
-
from omegaconf import ListConfig, OmegaConf
|
8 |
-
from safetensors.torch import load_file as load_safetensors
|
9 |
-
from torch.optim.lr_scheduler import LambdaLR
|
10 |
-
|
11 |
-
from ..modules import UNCONDITIONAL_CONFIG
|
12 |
-
from ..modules.autoencoding.temporal_ae import VideoDecoder
|
13 |
-
from ..modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
|
14 |
-
from ..modules.ema import LitEma
|
15 |
-
from ..util import (default, disabled_train, get_obj_from_str,
|
16 |
-
instantiate_from_config, log_txt_as_img)
|
17 |
-
|
18 |
-
|
19 |
-
class DiffusionEngine(pl.LightningModule):
|
20 |
-
def __init__(
|
21 |
-
self,
|
22 |
-
network_config,
|
23 |
-
denoiser_config,
|
24 |
-
first_stage_config,
|
25 |
-
conditioner_config: Union[None, Dict, ListConfig, OmegaConf] = None,
|
26 |
-
sampler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
|
27 |
-
optimizer_config: Union[None, Dict, ListConfig, OmegaConf] = None,
|
28 |
-
scheduler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
|
29 |
-
loss_fn_config: Union[None, Dict, ListConfig, OmegaConf] = None,
|
30 |
-
network_wrapper: Union[None, str] = None,
|
31 |
-
ckpt_path: Union[None, str] = None,
|
32 |
-
use_ema: bool = False,
|
33 |
-
ema_decay_rate: float = 0.9999,
|
34 |
-
scale_factor: float = 1.0,
|
35 |
-
disable_first_stage_autocast=False,
|
36 |
-
input_key: str = "jpg",
|
37 |
-
log_keys: Union[List, None] = None,
|
38 |
-
no_cond_log: bool = False,
|
39 |
-
compile_model: bool = False,
|
40 |
-
en_and_decode_n_samples_a_time: Optional[int] = None,
|
41 |
-
):
|
42 |
-
super().__init__()
|
43 |
-
self.log_keys = log_keys
|
44 |
-
self.input_key = input_key
|
45 |
-
self.optimizer_config = default(
|
46 |
-
optimizer_config, {"target": "torch.optim.AdamW"}
|
47 |
-
)
|
48 |
-
model = instantiate_from_config(network_config)
|
49 |
-
self.model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))(
|
50 |
-
model, compile_model=compile_model
|
51 |
-
)
|
52 |
-
|
53 |
-
self.denoiser = instantiate_from_config(denoiser_config)
|
54 |
-
self.sampler = (
|
55 |
-
instantiate_from_config(sampler_config)
|
56 |
-
if sampler_config is not None
|
57 |
-
else None
|
58 |
-
)
|
59 |
-
self.conditioner = instantiate_from_config(
|
60 |
-
default(conditioner_config, UNCONDITIONAL_CONFIG)
|
61 |
-
)
|
62 |
-
self.scheduler_config = scheduler_config
|
63 |
-
self._init_first_stage(first_stage_config)
|
64 |
-
|
65 |
-
self.loss_fn = (
|
66 |
-
instantiate_from_config(loss_fn_config)
|
67 |
-
if loss_fn_config is not None
|
68 |
-
else None
|
69 |
-
)
|
70 |
-
|
71 |
-
self.use_ema = use_ema
|
72 |
-
if self.use_ema:
|
73 |
-
self.model_ema = LitEma(self.model, decay=ema_decay_rate)
|
74 |
-
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
75 |
-
|
76 |
-
self.scale_factor = scale_factor
|
77 |
-
self.disable_first_stage_autocast = disable_first_stage_autocast
|
78 |
-
self.no_cond_log = no_cond_log
|
79 |
-
|
80 |
-
if ckpt_path is not None:
|
81 |
-
self.init_from_ckpt(ckpt_path)
|
82 |
-
|
83 |
-
self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time
|
84 |
-
|
85 |
-
def init_from_ckpt(
|
86 |
-
self,
|
87 |
-
path: str,
|
88 |
-
) -> None:
|
89 |
-
if path.endswith("ckpt"):
|
90 |
-
sd = torch.load(path, map_location="cpu")["state_dict"]
|
91 |
-
elif path.endswith("safetensors"):
|
92 |
-
sd = load_safetensors(path)
|
93 |
-
else:
|
94 |
-
raise NotImplementedError
|
95 |
-
|
96 |
-
missing, unexpected = self.load_state_dict(sd, strict=False)
|
97 |
-
print(
|
98 |
-
f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
|
99 |
-
)
|
100 |
-
if len(missing) > 0:
|
101 |
-
print(f"Missing Keys: {missing}")
|
102 |
-
if len(unexpected) > 0:
|
103 |
-
print(f"Unexpected Keys: {unexpected}")
|
104 |
-
|
105 |
-
def _init_first_stage(self, config):
|
106 |
-
model = instantiate_from_config(config).eval()
|
107 |
-
model.train = disabled_train
|
108 |
-
for param in model.parameters():
|
109 |
-
param.requires_grad = False
|
110 |
-
self.first_stage_model = model
|
111 |
-
|
112 |
-
def get_input(self, batch):
|
113 |
-
# assuming unified data format, dataloader returns a dict.
|
114 |
-
# image tensors should be scaled to -1 ... 1 and in bchw format
|
115 |
-
return batch[self.input_key]
|
116 |
-
|
117 |
-
@torch.no_grad()
|
118 |
-
def decode_first_stage(self, z):
|
119 |
-
z = 1.0 / self.scale_factor * z
|
120 |
-
n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0])
|
121 |
-
|
122 |
-
n_rounds = math.ceil(z.shape[0] / n_samples)
|
123 |
-
all_out = []
|
124 |
-
with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
|
125 |
-
for n in range(n_rounds):
|
126 |
-
if isinstance(self.first_stage_model.decoder, VideoDecoder):
|
127 |
-
kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])}
|
128 |
-
else:
|
129 |
-
kwargs = {}
|
130 |
-
out = self.first_stage_model.decode(
|
131 |
-
z[n * n_samples : (n + 1) * n_samples], **kwargs
|
132 |
-
)
|
133 |
-
all_out.append(out)
|
134 |
-
out = torch.cat(all_out, dim=0)
|
135 |
-
return out
|
136 |
-
|
137 |
-
@torch.no_grad()
|
138 |
-
def encode_first_stage(self, x):
|
139 |
-
n_samples = default(self.en_and_decode_n_samples_a_time, x.shape[0])
|
140 |
-
n_rounds = math.ceil(x.shape[0] / n_samples)
|
141 |
-
all_out = []
|
142 |
-
with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
|
143 |
-
for n in range(n_rounds):
|
144 |
-
out = self.first_stage_model.encode(
|
145 |
-
x[n * n_samples : (n + 1) * n_samples]
|
146 |
-
)
|
147 |
-
all_out.append(out)
|
148 |
-
z = torch.cat(all_out, dim=0)
|
149 |
-
z = self.scale_factor * z
|
150 |
-
return z
|
151 |
-
|
152 |
-
def forward(self, x, batch):
|
153 |
-
loss = self.loss_fn(self.model, self.denoiser, self.conditioner, x, batch)
|
154 |
-
loss_mean = loss.mean()
|
155 |
-
loss_dict = {"loss": loss_mean}
|
156 |
-
return loss_mean, loss_dict
|
157 |
-
|
158 |
-
def shared_step(self, batch: Dict) -> Any:
|
159 |
-
x = self.get_input(batch)
|
160 |
-
x = self.encode_first_stage(x)
|
161 |
-
batch["global_step"] = self.global_step
|
162 |
-
loss, loss_dict = self(x, batch)
|
163 |
-
return loss, loss_dict
|
164 |
-
|
165 |
-
def training_step(self, batch, batch_idx):
|
166 |
-
loss, loss_dict = self.shared_step(batch)
|
167 |
-
|
168 |
-
self.log_dict(
|
169 |
-
loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=False
|
170 |
-
)
|
171 |
-
|
172 |
-
self.log(
|
173 |
-
"global_step",
|
174 |
-
self.global_step,
|
175 |
-
prog_bar=True,
|
176 |
-
logger=True,
|
177 |
-
on_step=True,
|
178 |
-
on_epoch=False,
|
179 |
-
)
|
180 |
-
|
181 |
-
if self.scheduler_config is not None:
|
182 |
-
lr = self.optimizers().param_groups[0]["lr"]
|
183 |
-
self.log(
|
184 |
-
"lr_abs", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False
|
185 |
-
)
|
186 |
-
|
187 |
-
return loss
|
188 |
-
|
189 |
-
def on_train_start(self, *args, **kwargs):
|
190 |
-
if self.sampler is None or self.loss_fn is None:
|
191 |
-
raise ValueError("Sampler and loss function need to be set for training.")
|
192 |
-
|
193 |
-
def on_train_batch_end(self, *args, **kwargs):
|
194 |
-
if self.use_ema:
|
195 |
-
self.model_ema(self.model)
|
196 |
-
|
197 |
-
@contextmanager
|
198 |
-
def ema_scope(self, context=None):
|
199 |
-
if self.use_ema:
|
200 |
-
self.model_ema.store(self.model.parameters())
|
201 |
-
self.model_ema.copy_to(self.model)
|
202 |
-
if context is not None:
|
203 |
-
print(f"{context}: Switched to EMA weights")
|
204 |
-
try:
|
205 |
-
yield None
|
206 |
-
finally:
|
207 |
-
if self.use_ema:
|
208 |
-
self.model_ema.restore(self.model.parameters())
|
209 |
-
if context is not None:
|
210 |
-
print(f"{context}: Restored training weights")
|
211 |
-
|
212 |
-
def instantiate_optimizer_from_config(self, params, lr, cfg):
|
213 |
-
return get_obj_from_str(cfg["target"])(
|
214 |
-
params, lr=lr, **cfg.get("params", dict())
|
215 |
-
)
|
216 |
-
|
217 |
-
def configure_optimizers(self):
|
218 |
-
lr = self.learning_rate
|
219 |
-
params = list(self.model.parameters())
|
220 |
-
for embedder in self.conditioner.embedders:
|
221 |
-
if embedder.is_trainable:
|
222 |
-
params = params + list(embedder.parameters())
|
223 |
-
opt = self.instantiate_optimizer_from_config(params, lr, self.optimizer_config)
|
224 |
-
if self.scheduler_config is not None:
|
225 |
-
scheduler = instantiate_from_config(self.scheduler_config)
|
226 |
-
print("Setting up LambdaLR scheduler...")
|
227 |
-
scheduler = [
|
228 |
-
{
|
229 |
-
"scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule),
|
230 |
-
"interval": "step",
|
231 |
-
"frequency": 1,
|
232 |
-
}
|
233 |
-
]
|
234 |
-
return [opt], scheduler
|
235 |
-
return opt
|
236 |
-
|
237 |
-
@torch.no_grad()
|
238 |
-
def sample(
|
239 |
-
self,
|
240 |
-
cond: Dict,
|
241 |
-
uc: Union[Dict, None] = None,
|
242 |
-
batch_size: int = 16,
|
243 |
-
shape: Union[None, Tuple, List] = None,
|
244 |
-
**kwargs,
|
245 |
-
):
|
246 |
-
randn = torch.randn(batch_size, *shape).to(self.device)
|
247 |
-
|
248 |
-
denoiser = lambda input, sigma, c: self.denoiser(
|
249 |
-
self.model, input, sigma, c, **kwargs
|
250 |
-
)
|
251 |
-
samples = self.sampler(denoiser, randn, cond, uc=uc)
|
252 |
-
return samples
|
253 |
-
|
254 |
-
@torch.no_grad()
|
255 |
-
def log_conditionings(self, batch: Dict, n: int) -> Dict:
|
256 |
-
"""
|
257 |
-
Defines heuristics to log different conditionings.
|
258 |
-
These can be lists of strings (text-to-image), tensors, ints, ...
|
259 |
-
"""
|
260 |
-
image_h, image_w = batch[self.input_key].shape[2:]
|
261 |
-
log = dict()
|
262 |
-
|
263 |
-
for embedder in self.conditioner.embedders:
|
264 |
-
if (
|
265 |
-
(self.log_keys is None) or (embedder.input_key in self.log_keys)
|
266 |
-
) and not self.no_cond_log:
|
267 |
-
x = batch[embedder.input_key][:n]
|
268 |
-
if isinstance(x, torch.Tensor):
|
269 |
-
if x.dim() == 1:
|
270 |
-
# class-conditional, convert integer to string
|
271 |
-
x = [str(x[i].item()) for i in range(x.shape[0])]
|
272 |
-
xc = log_txt_as_img((image_h, image_w), x, size=image_h // 4)
|
273 |
-
elif x.dim() == 2:
|
274 |
-
# size and crop cond and the like
|
275 |
-
x = [
|
276 |
-
"x".join([str(xx) for xx in x[i].tolist()])
|
277 |
-
for i in range(x.shape[0])
|
278 |
-
]
|
279 |
-
xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
|
280 |
-
else:
|
281 |
-
raise NotImplementedError()
|
282 |
-
elif isinstance(x, (List, ListConfig)):
|
283 |
-
if isinstance(x[0], str):
|
284 |
-
# strings
|
285 |
-
xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
|
286 |
-
else:
|
287 |
-
raise NotImplementedError()
|
288 |
-
else:
|
289 |
-
raise NotImplementedError()
|
290 |
-
log[embedder.input_key] = xc
|
291 |
-
return log
|
292 |
-
|
293 |
-
@torch.no_grad()
|
294 |
-
def log_images(
|
295 |
-
self,
|
296 |
-
batch: Dict,
|
297 |
-
N: int = 8,
|
298 |
-
sample: bool = True,
|
299 |
-
ucg_keys: List[str] = None,
|
300 |
-
**kwargs,
|
301 |
-
) -> Dict:
|
302 |
-
conditioner_input_keys = [e.input_key for e in self.conditioner.embedders]
|
303 |
-
if ucg_keys:
|
304 |
-
assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), (
|
305 |
-
"Each defined ucg key for sampling must be in the provided conditioner input keys,"
|
306 |
-
f"but we have {ucg_keys} vs. {conditioner_input_keys}"
|
307 |
-
)
|
308 |
-
else:
|
309 |
-
ucg_keys = conditioner_input_keys
|
310 |
-
log = dict()
|
311 |
-
|
312 |
-
x = self.get_input(batch)
|
313 |
-
|
314 |
-
c, uc = self.conditioner.get_unconditional_conditioning(
|
315 |
-
batch,
|
316 |
-
force_uc_zero_embeddings=ucg_keys
|
317 |
-
if len(self.conditioner.embedders) > 0
|
318 |
-
else [],
|
319 |
-
)
|
320 |
-
|
321 |
-
sampling_kwargs = {}
|
322 |
-
|
323 |
-
N = min(x.shape[0], N)
|
324 |
-
x = x.to(self.device)[:N]
|
325 |
-
log["inputs"] = x
|
326 |
-
z = self.encode_first_stage(x)
|
327 |
-
log["reconstructions"] = self.decode_first_stage(z)
|
328 |
-
log.update(self.log_conditionings(batch, N))
|
329 |
-
|
330 |
-
for k in c:
|
331 |
-
if isinstance(c[k], torch.Tensor):
|
332 |
-
c[k], uc[k] = map(lambda y: y[k][:N].to(self.device), (c, uc))
|
333 |
-
|
334 |
-
if sample:
|
335 |
-
with self.ema_scope("Plotting"):
|
336 |
-
samples = self.sample(
|
337 |
-
c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs
|
338 |
-
)
|
339 |
-
samples = self.decode_first_stage(samples)
|
340 |
-
log["samples"] = samples
|
341 |
-
return log
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sgm/modules/__init__.py
DELETED
@@ -1,6 +0,0 @@
|
|
1 |
-
from .encoders.modules import GeneralConditioner
|
2 |
-
|
3 |
-
UNCONDITIONAL_CONFIG = {
|
4 |
-
"target": "sgm.modules.GeneralConditioner",
|
5 |
-
"params": {"emb_models": []},
|
6 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sgm/modules/attention.py
DELETED
@@ -1,759 +0,0 @@
|
|
1 |
-
import logging
|
2 |
-
import math
|
3 |
-
from inspect import isfunction
|
4 |
-
from typing import Any, Optional
|
5 |
-
|
6 |
-
import torch
|
7 |
-
import torch.nn.functional as F
|
8 |
-
from einops import rearrange, repeat
|
9 |
-
from packaging import version
|
10 |
-
from torch import nn
|
11 |
-
from torch.utils.checkpoint import checkpoint
|
12 |
-
|
13 |
-
logpy = logging.getLogger(__name__)
|
14 |
-
|
15 |
-
if version.parse(torch.__version__) >= version.parse("2.0.0"):
|
16 |
-
SDP_IS_AVAILABLE = True
|
17 |
-
from torch.backends.cuda import SDPBackend, sdp_kernel
|
18 |
-
|
19 |
-
BACKEND_MAP = {
|
20 |
-
SDPBackend.MATH: {
|
21 |
-
"enable_math": True,
|
22 |
-
"enable_flash": False,
|
23 |
-
"enable_mem_efficient": False,
|
24 |
-
},
|
25 |
-
SDPBackend.FLASH_ATTENTION: {
|
26 |
-
"enable_math": False,
|
27 |
-
"enable_flash": True,
|
28 |
-
"enable_mem_efficient": False,
|
29 |
-
},
|
30 |
-
SDPBackend.EFFICIENT_ATTENTION: {
|
31 |
-
"enable_math": False,
|
32 |
-
"enable_flash": False,
|
33 |
-
"enable_mem_efficient": True,
|
34 |
-
},
|
35 |
-
None: {"enable_math": True, "enable_flash": True, "enable_mem_efficient": True},
|
36 |
-
}
|
37 |
-
else:
|
38 |
-
from contextlib import nullcontext
|
39 |
-
|
40 |
-
SDP_IS_AVAILABLE = False
|
41 |
-
sdp_kernel = nullcontext
|
42 |
-
BACKEND_MAP = {}
|
43 |
-
logpy.warn(
|
44 |
-
f"No SDP backend available, likely because you are running in pytorch "
|
45 |
-
f"versions < 2.0. In fact, you are using PyTorch {torch.__version__}. "
|
46 |
-
f"You might want to consider upgrading."
|
47 |
-
)
|
48 |
-
|
49 |
-
try:
|
50 |
-
import xformers
|
51 |
-
import xformers.ops
|
52 |
-
|
53 |
-
XFORMERS_IS_AVAILABLE = True
|
54 |
-
except:
|
55 |
-
XFORMERS_IS_AVAILABLE = False
|
56 |
-
logpy.warn("no module 'xformers'. Processing without...")
|
57 |
-
|
58 |
-
# from .diffusionmodules.util import mixed_checkpoint as checkpoint
|
59 |
-
|
60 |
-
|
61 |
-
def exists(val):
|
62 |
-
return val is not None
|
63 |
-
|
64 |
-
|
65 |
-
def uniq(arr):
|
66 |
-
return {el: True for el in arr}.keys()
|
67 |
-
|
68 |
-
|
69 |
-
def default(val, d):
|
70 |
-
if exists(val):
|
71 |
-
return val
|
72 |
-
return d() if isfunction(d) else d
|
73 |
-
|
74 |
-
|
75 |
-
def max_neg_value(t):
|
76 |
-
return -torch.finfo(t.dtype).max
|
77 |
-
|
78 |
-
|
79 |
-
def init_(tensor):
|
80 |
-
dim = tensor.shape[-1]
|
81 |
-
std = 1 / math.sqrt(dim)
|
82 |
-
tensor.uniform_(-std, std)
|
83 |
-
return tensor
|
84 |
-
|
85 |
-
|
86 |
-
# feedforward
|
87 |
-
class GEGLU(nn.Module):
|
88 |
-
def __init__(self, dim_in, dim_out):
|
89 |
-
super().__init__()
|
90 |
-
self.proj = nn.Linear(dim_in, dim_out * 2)
|
91 |
-
|
92 |
-
def forward(self, x):
|
93 |
-
x, gate = self.proj(x).chunk(2, dim=-1)
|
94 |
-
return x * F.gelu(gate)
|
95 |
-
|
96 |
-
|
97 |
-
class FeedForward(nn.Module):
|
98 |
-
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
|
99 |
-
super().__init__()
|
100 |
-
inner_dim = int(dim * mult)
|
101 |
-
dim_out = default(dim_out, dim)
|
102 |
-
project_in = (
|
103 |
-
nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
|
104 |
-
if not glu
|
105 |
-
else GEGLU(dim, inner_dim)
|
106 |
-
)
|
107 |
-
|
108 |
-
self.net = nn.Sequential(
|
109 |
-
project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
|
110 |
-
)
|
111 |
-
|
112 |
-
def forward(self, x):
|
113 |
-
return self.net(x)
|
114 |
-
|
115 |
-
|
116 |
-
def zero_module(module):
|
117 |
-
"""
|
118 |
-
Zero out the parameters of a module and return it.
|
119 |
-
"""
|
120 |
-
for p in module.parameters():
|
121 |
-
p.detach().zero_()
|
122 |
-
return module
|
123 |
-
|
124 |
-
|
125 |
-
def Normalize(in_channels):
|
126 |
-
return torch.nn.GroupNorm(
|
127 |
-
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
|
128 |
-
)
|
129 |
-
|
130 |
-
|
131 |
-
class LinearAttention(nn.Module):
|
132 |
-
def __init__(self, dim, heads=4, dim_head=32):
|
133 |
-
super().__init__()
|
134 |
-
self.heads = heads
|
135 |
-
hidden_dim = dim_head * heads
|
136 |
-
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
|
137 |
-
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
|
138 |
-
|
139 |
-
def forward(self, x):
|
140 |
-
b, c, h, w = x.shape
|
141 |
-
qkv = self.to_qkv(x)
|
142 |
-
q, k, v = rearrange(
|
143 |
-
qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
|
144 |
-
)
|
145 |
-
k = k.softmax(dim=-1)
|
146 |
-
context = torch.einsum("bhdn,bhen->bhde", k, v)
|
147 |
-
out = torch.einsum("bhde,bhdn->bhen", context, q)
|
148 |
-
out = rearrange(
|
149 |
-
out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
|
150 |
-
)
|
151 |
-
return self.to_out(out)
|
152 |
-
|
153 |
-
|
154 |
-
class SelfAttention(nn.Module):
|
155 |
-
ATTENTION_MODES = ("xformers", "torch", "math")
|
156 |
-
|
157 |
-
def __init__(
|
158 |
-
self,
|
159 |
-
dim: int,
|
160 |
-
num_heads: int = 8,
|
161 |
-
qkv_bias: bool = False,
|
162 |
-
qk_scale: Optional[float] = None,
|
163 |
-
attn_drop: float = 0.0,
|
164 |
-
proj_drop: float = 0.0,
|
165 |
-
attn_mode: str = "xformers",
|
166 |
-
):
|
167 |
-
super().__init__()
|
168 |
-
self.num_heads = num_heads
|
169 |
-
head_dim = dim // num_heads
|
170 |
-
self.scale = qk_scale or head_dim**-0.5
|
171 |
-
|
172 |
-
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
173 |
-
self.attn_drop = nn.Dropout(attn_drop)
|
174 |
-
self.proj = nn.Linear(dim, dim)
|
175 |
-
self.proj_drop = nn.Dropout(proj_drop)
|
176 |
-
assert attn_mode in self.ATTENTION_MODES
|
177 |
-
self.attn_mode = attn_mode
|
178 |
-
|
179 |
-
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
180 |
-
B, L, C = x.shape
|
181 |
-
|
182 |
-
qkv = self.qkv(x)
|
183 |
-
if self.attn_mode == "torch":
|
184 |
-
qkv = rearrange(
|
185 |
-
qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
|
186 |
-
).float()
|
187 |
-
q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
|
188 |
-
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
189 |
-
x = rearrange(x, "B H L D -> B L (H D)")
|
190 |
-
elif self.attn_mode == "xformers":
|
191 |
-
qkv = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
|
192 |
-
q, k, v = qkv[0], qkv[1], qkv[2] # B L H D
|
193 |
-
x = xformers.ops.memory_efficient_attention(q, k, v)
|
194 |
-
x = rearrange(x, "B L H D -> B L (H D)", H=self.num_heads)
|
195 |
-
elif self.attn_mode == "math":
|
196 |
-
qkv = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
197 |
-
q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
|
198 |
-
attn = (q @ k.transpose(-2, -1)) * self.scale
|
199 |
-
attn = attn.softmax(dim=-1)
|
200 |
-
attn = self.attn_drop(attn)
|
201 |
-
x = (attn @ v).transpose(1, 2).reshape(B, L, C)
|
202 |
-
else:
|
203 |
-
raise NotImplemented
|
204 |
-
|
205 |
-
x = self.proj(x)
|
206 |
-
x = self.proj_drop(x)
|
207 |
-
return x
|
208 |
-
|
209 |
-
|
210 |
-
class SpatialSelfAttention(nn.Module):
|
211 |
-
def __init__(self, in_channels):
|
212 |
-
super().__init__()
|
213 |
-
self.in_channels = in_channels
|
214 |
-
|
215 |
-
self.norm = Normalize(in_channels)
|
216 |
-
self.q = torch.nn.Conv2d(
|
217 |
-
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
218 |
-
)
|
219 |
-
self.k = torch.nn.Conv2d(
|
220 |
-
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
221 |
-
)
|
222 |
-
self.v = torch.nn.Conv2d(
|
223 |
-
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
224 |
-
)
|
225 |
-
self.proj_out = torch.nn.Conv2d(
|
226 |
-
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
227 |
-
)
|
228 |
-
|
229 |
-
def forward(self, x):
|
230 |
-
h_ = x
|
231 |
-
h_ = self.norm(h_)
|
232 |
-
q = self.q(h_)
|
233 |
-
k = self.k(h_)
|
234 |
-
v = self.v(h_)
|
235 |
-
|
236 |
-
# compute attention
|
237 |
-
b, c, h, w = q.shape
|
238 |
-
q = rearrange(q, "b c h w -> b (h w) c")
|
239 |
-
k = rearrange(k, "b c h w -> b c (h w)")
|
240 |
-
w_ = torch.einsum("bij,bjk->bik", q, k)
|
241 |
-
|
242 |
-
w_ = w_ * (int(c) ** (-0.5))
|
243 |
-
w_ = torch.nn.functional.softmax(w_, dim=2)
|
244 |
-
|
245 |
-
# attend to values
|
246 |
-
v = rearrange(v, "b c h w -> b c (h w)")
|
247 |
-
w_ = rearrange(w_, "b i j -> b j i")
|
248 |
-
h_ = torch.einsum("bij,bjk->bik", v, w_)
|
249 |
-
h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
|
250 |
-
h_ = self.proj_out(h_)
|
251 |
-
|
252 |
-
return x + h_
|
253 |
-
|
254 |
-
|
255 |
-
class CrossAttention(nn.Module):
|
256 |
-
def __init__(
|
257 |
-
self,
|
258 |
-
query_dim,
|
259 |
-
context_dim=None,
|
260 |
-
heads=8,
|
261 |
-
dim_head=64,
|
262 |
-
dropout=0.0,
|
263 |
-
backend=None,
|
264 |
-
):
|
265 |
-
super().__init__()
|
266 |
-
inner_dim = dim_head * heads
|
267 |
-
context_dim = default(context_dim, query_dim)
|
268 |
-
|
269 |
-
self.scale = dim_head**-0.5
|
270 |
-
self.heads = heads
|
271 |
-
|
272 |
-
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
273 |
-
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
274 |
-
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
275 |
-
|
276 |
-
self.to_out = nn.Sequential(
|
277 |
-
nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
|
278 |
-
)
|
279 |
-
self.backend = backend
|
280 |
-
|
281 |
-
def forward(
|
282 |
-
self,
|
283 |
-
x,
|
284 |
-
context=None,
|
285 |
-
mask=None,
|
286 |
-
additional_tokens=None,
|
287 |
-
n_times_crossframe_attn_in_self=0,
|
288 |
-
):
|
289 |
-
h = self.heads
|
290 |
-
|
291 |
-
if additional_tokens is not None:
|
292 |
-
# get the number of masked tokens at the beginning of the output sequence
|
293 |
-
n_tokens_to_mask = additional_tokens.shape[1]
|
294 |
-
# add additional token
|
295 |
-
x = torch.cat([additional_tokens, x], dim=1)
|
296 |
-
|
297 |
-
q = self.to_q(x)
|
298 |
-
context = default(context, x)
|
299 |
-
k = self.to_k(context)
|
300 |
-
v = self.to_v(context)
|
301 |
-
|
302 |
-
if n_times_crossframe_attn_in_self:
|
303 |
-
# reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
|
304 |
-
assert x.shape[0] % n_times_crossframe_attn_in_self == 0
|
305 |
-
n_cp = x.shape[0] // n_times_crossframe_attn_in_self
|
306 |
-
k = repeat(
|
307 |
-
k[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp
|
308 |
-
)
|
309 |
-
v = repeat(
|
310 |
-
v[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp
|
311 |
-
)
|
312 |
-
|
313 |
-
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
|
314 |
-
|
315 |
-
## old
|
316 |
-
"""
|
317 |
-
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
318 |
-
del q, k
|
319 |
-
|
320 |
-
if exists(mask):
|
321 |
-
mask = rearrange(mask, 'b ... -> b (...)')
|
322 |
-
max_neg_value = -torch.finfo(sim.dtype).max
|
323 |
-
mask = repeat(mask, 'b j -> (b h) () j', h=h)
|
324 |
-
sim.masked_fill_(~mask, max_neg_value)
|
325 |
-
|
326 |
-
# attention, what we cannot get enough of
|
327 |
-
sim = sim.softmax(dim=-1)
|
328 |
-
|
329 |
-
out = einsum('b i j, b j d -> b i d', sim, v)
|
330 |
-
"""
|
331 |
-
## new
|
332 |
-
with sdp_kernel(**BACKEND_MAP[self.backend]):
|
333 |
-
# print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
|
334 |
-
out = F.scaled_dot_product_attention(
|
335 |
-
q, k, v, attn_mask=mask
|
336 |
-
) # scale is dim_head ** -0.5 per default
|
337 |
-
|
338 |
-
del q, k, v
|
339 |
-
out = rearrange(out, "b h n d -> b n (h d)", h=h)
|
340 |
-
|
341 |
-
if additional_tokens is not None:
|
342 |
-
# remove additional token
|
343 |
-
out = out[:, n_tokens_to_mask:]
|
344 |
-
return self.to_out(out)
|
345 |
-
|
346 |
-
|
347 |
-
class MemoryEfficientCrossAttention(nn.Module):
|
348 |
-
# https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
|
349 |
-
def __init__(
|
350 |
-
self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs
|
351 |
-
):
|
352 |
-
super().__init__()
|
353 |
-
logpy.debug(
|
354 |
-
f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, "
|
355 |
-
f"context_dim is {context_dim} and using {heads} heads with a "
|
356 |
-
f"dimension of {dim_head}."
|
357 |
-
)
|
358 |
-
inner_dim = dim_head * heads
|
359 |
-
context_dim = default(context_dim, query_dim)
|
360 |
-
|
361 |
-
self.heads = heads
|
362 |
-
self.dim_head = dim_head
|
363 |
-
|
364 |
-
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
365 |
-
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
366 |
-
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
367 |
-
|
368 |
-
self.to_out = nn.Sequential(
|
369 |
-
nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
|
370 |
-
)
|
371 |
-
self.attention_op: Optional[Any] = None
|
372 |
-
|
373 |
-
def forward(
|
374 |
-
self,
|
375 |
-
x,
|
376 |
-
context=None,
|
377 |
-
mask=None,
|
378 |
-
additional_tokens=None,
|
379 |
-
n_times_crossframe_attn_in_self=0,
|
380 |
-
):
|
381 |
-
if additional_tokens is not None:
|
382 |
-
# get the number of masked tokens at the beginning of the output sequence
|
383 |
-
n_tokens_to_mask = additional_tokens.shape[1]
|
384 |
-
# add additional token
|
385 |
-
x = torch.cat([additional_tokens, x], dim=1)
|
386 |
-
q = self.to_q(x)
|
387 |
-
context = default(context, x)
|
388 |
-
k = self.to_k(context)
|
389 |
-
v = self.to_v(context)
|
390 |
-
|
391 |
-
if n_times_crossframe_attn_in_self:
|
392 |
-
# reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
|
393 |
-
assert x.shape[0] % n_times_crossframe_attn_in_self == 0
|
394 |
-
# n_cp = x.shape[0]//n_times_crossframe_attn_in_self
|
395 |
-
k = repeat(
|
396 |
-
k[::n_times_crossframe_attn_in_self],
|
397 |
-
"b ... -> (b n) ...",
|
398 |
-
n=n_times_crossframe_attn_in_self,
|
399 |
-
)
|
400 |
-
v = repeat(
|
401 |
-
v[::n_times_crossframe_attn_in_self],
|
402 |
-
"b ... -> (b n) ...",
|
403 |
-
n=n_times_crossframe_attn_in_self,
|
404 |
-
)
|
405 |
-
|
406 |
-
b, _, _ = q.shape
|
407 |
-
q, k, v = map(
|
408 |
-
lambda t: t.unsqueeze(3)
|
409 |
-
.reshape(b, t.shape[1], self.heads, self.dim_head)
|
410 |
-
.permute(0, 2, 1, 3)
|
411 |
-
.reshape(b * self.heads, t.shape[1], self.dim_head)
|
412 |
-
.contiguous(),
|
413 |
-
(q, k, v),
|
414 |
-
)
|
415 |
-
|
416 |
-
# actually compute the attention, what we cannot get enough of
|
417 |
-
if version.parse(xformers.__version__) >= version.parse("0.0.21"):
|
418 |
-
# NOTE: workaround for
|
419 |
-
# https://github.com/facebookresearch/xformers/issues/845
|
420 |
-
max_bs = 32768
|
421 |
-
N = q.shape[0]
|
422 |
-
n_batches = math.ceil(N / max_bs)
|
423 |
-
out = list()
|
424 |
-
for i_batch in range(n_batches):
|
425 |
-
batch = slice(i_batch * max_bs, (i_batch + 1) * max_bs)
|
426 |
-
out.append(
|
427 |
-
xformers.ops.memory_efficient_attention(
|
428 |
-
q[batch],
|
429 |
-
k[batch],
|
430 |
-
v[batch],
|
431 |
-
attn_bias=None,
|
432 |
-
op=self.attention_op,
|
433 |
-
)
|
434 |
-
)
|
435 |
-
out = torch.cat(out, 0)
|
436 |
-
else:
|
437 |
-
out = xformers.ops.memory_efficient_attention(
|
438 |
-
q, k, v, attn_bias=None, op=self.attention_op
|
439 |
-
)
|
440 |
-
|
441 |
-
# TODO: Use this directly in the attention operation, as a bias
|
442 |
-
if exists(mask):
|
443 |
-
raise NotImplementedError
|
444 |
-
out = (
|
445 |
-
out.unsqueeze(0)
|
446 |
-
.reshape(b, self.heads, out.shape[1], self.dim_head)
|
447 |
-
.permute(0, 2, 1, 3)
|
448 |
-
.reshape(b, out.shape[1], self.heads * self.dim_head)
|
449 |
-
)
|
450 |
-
if additional_tokens is not None:
|
451 |
-
# remove additional token
|
452 |
-
out = out[:, n_tokens_to_mask:]
|
453 |
-
return self.to_out(out)
|
454 |
-
|
455 |
-
|
456 |
-
class BasicTransformerBlock(nn.Module):
|
457 |
-
ATTENTION_MODES = {
|
458 |
-
"softmax": CrossAttention, # vanilla attention
|
459 |
-
"softmax-xformers": MemoryEfficientCrossAttention, # ampere
|
460 |
-
}
|
461 |
-
|
462 |
-
def __init__(
|
463 |
-
self,
|
464 |
-
dim,
|
465 |
-
n_heads,
|
466 |
-
d_head,
|
467 |
-
dropout=0.0,
|
468 |
-
context_dim=None,
|
469 |
-
gated_ff=True,
|
470 |
-
checkpoint=True,
|
471 |
-
disable_self_attn=False,
|
472 |
-
attn_mode="softmax",
|
473 |
-
sdp_backend=None,
|
474 |
-
):
|
475 |
-
super().__init__()
|
476 |
-
assert attn_mode in self.ATTENTION_MODES
|
477 |
-
if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE:
|
478 |
-
logpy.warn(
|
479 |
-
f"Attention mode '{attn_mode}' is not available. Falling "
|
480 |
-
f"back to native attention. This is not a problem in "
|
481 |
-
f"Pytorch >= 2.0. FYI, you are running with PyTorch "
|
482 |
-
f"version {torch.__version__}."
|
483 |
-
)
|
484 |
-
attn_mode = "softmax"
|
485 |
-
elif attn_mode == "softmax" and not SDP_IS_AVAILABLE:
|
486 |
-
logpy.warn(
|
487 |
-
"We do not support vanilla attention anymore, as it is too "
|
488 |
-
"expensive. Sorry."
|
489 |
-
)
|
490 |
-
if not XFORMERS_IS_AVAILABLE:
|
491 |
-
assert (
|
492 |
-
False
|
493 |
-
), "Please install xformers via e.g. 'pip install xformers==0.0.16'"
|
494 |
-
else:
|
495 |
-
logpy.info("Falling back to xformers efficient attention.")
|
496 |
-
attn_mode = "softmax-xformers"
|
497 |
-
attn_cls = self.ATTENTION_MODES[attn_mode]
|
498 |
-
if version.parse(torch.__version__) >= version.parse("2.0.0"):
|
499 |
-
assert sdp_backend is None or isinstance(sdp_backend, SDPBackend)
|
500 |
-
else:
|
501 |
-
assert sdp_backend is None
|
502 |
-
self.disable_self_attn = disable_self_attn
|
503 |
-
self.attn1 = attn_cls(
|
504 |
-
query_dim=dim,
|
505 |
-
heads=n_heads,
|
506 |
-
dim_head=d_head,
|
507 |
-
dropout=dropout,
|
508 |
-
context_dim=context_dim if self.disable_self_attn else None,
|
509 |
-
backend=sdp_backend,
|
510 |
-
) # is a self-attention if not self.disable_self_attn
|
511 |
-
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
512 |
-
self.attn2 = attn_cls(
|
513 |
-
query_dim=dim,
|
514 |
-
context_dim=context_dim,
|
515 |
-
heads=n_heads,
|
516 |
-
dim_head=d_head,
|
517 |
-
dropout=dropout,
|
518 |
-
backend=sdp_backend,
|
519 |
-
) # is self-attn if context is none
|
520 |
-
self.norm1 = nn.LayerNorm(dim)
|
521 |
-
self.norm2 = nn.LayerNorm(dim)
|
522 |
-
self.norm3 = nn.LayerNorm(dim)
|
523 |
-
self.checkpoint = checkpoint
|
524 |
-
if self.checkpoint:
|
525 |
-
logpy.debug(f"{self.__class__.__name__} is using checkpointing")
|
526 |
-
|
527 |
-
def forward(
|
528 |
-
self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
|
529 |
-
):
|
530 |
-
kwargs = {"x": x}
|
531 |
-
|
532 |
-
if context is not None:
|
533 |
-
kwargs.update({"context": context})
|
534 |
-
|
535 |
-
if additional_tokens is not None:
|
536 |
-
kwargs.update({"additional_tokens": additional_tokens})
|
537 |
-
|
538 |
-
if n_times_crossframe_attn_in_self:
|
539 |
-
kwargs.update(
|
540 |
-
{"n_times_crossframe_attn_in_self": n_times_crossframe_attn_in_self}
|
541 |
-
)
|
542 |
-
|
543 |
-
# return mixed_checkpoint(self._forward, kwargs, self.parameters(), self.checkpoint)
|
544 |
-
if self.checkpoint:
|
545 |
-
# inputs = {"x": x, "context": context}
|
546 |
-
return checkpoint(self._forward, x, context)
|
547 |
-
# return checkpoint(self._forward, inputs, self.parameters(), self.checkpoint)
|
548 |
-
else:
|
549 |
-
return self._forward(**kwargs)
|
550 |
-
|
551 |
-
def _forward(
|
552 |
-
self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
|
553 |
-
):
|
554 |
-
x = (
|
555 |
-
self.attn1(
|
556 |
-
self.norm1(x),
|
557 |
-
context=context if self.disable_self_attn else None,
|
558 |
-
additional_tokens=additional_tokens,
|
559 |
-
n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self
|
560 |
-
if not self.disable_self_attn
|
561 |
-
else 0,
|
562 |
-
)
|
563 |
-
+ x
|
564 |
-
)
|
565 |
-
x = (
|
566 |
-
self.attn2(
|
567 |
-
self.norm2(x), context=context, additional_tokens=additional_tokens
|
568 |
-
)
|
569 |
-
+ x
|
570 |
-
)
|
571 |
-
x = self.ff(self.norm3(x)) + x
|
572 |
-
return x
|
573 |
-
|
574 |
-
|
575 |
-
class BasicTransformerSingleLayerBlock(nn.Module):
|
576 |
-
ATTENTION_MODES = {
|
577 |
-
"softmax": CrossAttention, # vanilla attention
|
578 |
-
"softmax-xformers": MemoryEfficientCrossAttention # on the A100s not quite as fast as the above version
|
579 |
-
# (todo might depend on head_dim, check, falls back to semi-optimized kernels for dim!=[16,32,64,128])
|
580 |
-
}
|
581 |
-
|
582 |
-
def __init__(
|
583 |
-
self,
|
584 |
-
dim,
|
585 |
-
n_heads,
|
586 |
-
d_head,
|
587 |
-
dropout=0.0,
|
588 |
-
context_dim=None,
|
589 |
-
gated_ff=True,
|
590 |
-
checkpoint=True,
|
591 |
-
attn_mode="softmax",
|
592 |
-
):
|
593 |
-
super().__init__()
|
594 |
-
assert attn_mode in self.ATTENTION_MODES
|
595 |
-
attn_cls = self.ATTENTION_MODES[attn_mode]
|
596 |
-
self.attn1 = attn_cls(
|
597 |
-
query_dim=dim,
|
598 |
-
heads=n_heads,
|
599 |
-
dim_head=d_head,
|
600 |
-
dropout=dropout,
|
601 |
-
context_dim=context_dim,
|
602 |
-
)
|
603 |
-
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
604 |
-
self.norm1 = nn.LayerNorm(dim)
|
605 |
-
self.norm2 = nn.LayerNorm(dim)
|
606 |
-
self.checkpoint = checkpoint
|
607 |
-
|
608 |
-
def forward(self, x, context=None):
|
609 |
-
# inputs = {"x": x, "context": context}
|
610 |
-
# return checkpoint(self._forward, inputs, self.parameters(), self.checkpoint)
|
611 |
-
return checkpoint(self._forward, x, context)
|
612 |
-
|
613 |
-
def _forward(self, x, context=None):
|
614 |
-
x = self.attn1(self.norm1(x), context=context) + x
|
615 |
-
x = self.ff(self.norm2(x)) + x
|
616 |
-
return x
|
617 |
-
|
618 |
-
|
619 |
-
class SpatialTransformer(nn.Module):
|
620 |
-
"""
|
621 |
-
Transformer block for image-like data.
|
622 |
-
First, project the input (aka embedding)
|
623 |
-
and reshape to b, t, d.
|
624 |
-
Then apply standard transformer action.
|
625 |
-
Finally, reshape to image
|
626 |
-
NEW: use_linear for more efficiency instead of the 1x1 convs
|
627 |
-
"""
|
628 |
-
|
629 |
-
def __init__(
|
630 |
-
self,
|
631 |
-
in_channels,
|
632 |
-
n_heads,
|
633 |
-
d_head,
|
634 |
-
depth=1,
|
635 |
-
dropout=0.0,
|
636 |
-
context_dim=None,
|
637 |
-
disable_self_attn=False,
|
638 |
-
use_linear=False,
|
639 |
-
attn_type="softmax",
|
640 |
-
use_checkpoint=True,
|
641 |
-
# sdp_backend=SDPBackend.FLASH_ATTENTION
|
642 |
-
sdp_backend=None,
|
643 |
-
):
|
644 |
-
super().__init__()
|
645 |
-
logpy.debug(
|
646 |
-
f"constructing {self.__class__.__name__} of depth {depth} w/ "
|
647 |
-
f"{in_channels} channels and {n_heads} heads."
|
648 |
-
)
|
649 |
-
|
650 |
-
if exists(context_dim) and not isinstance(context_dim, list):
|
651 |
-
context_dim = [context_dim]
|
652 |
-
if exists(context_dim) and isinstance(context_dim, list):
|
653 |
-
if depth != len(context_dim):
|
654 |
-
logpy.warn(
|
655 |
-
f"{self.__class__.__name__}: Found context dims "
|
656 |
-
f"{context_dim} of depth {len(context_dim)}, which does not "
|
657 |
-
f"match the specified 'depth' of {depth}. Setting context_dim "
|
658 |
-
f"to {depth * [context_dim[0]]} now."
|
659 |
-
)
|
660 |
-
# depth does not match context dims.
|
661 |
-
assert all(
|
662 |
-
map(lambda x: x == context_dim[0], context_dim)
|
663 |
-
), "need homogenous context_dim to match depth automatically"
|
664 |
-
context_dim = depth * [context_dim[0]]
|
665 |
-
elif context_dim is None:
|
666 |
-
context_dim = [None] * depth
|
667 |
-
self.in_channels = in_channels
|
668 |
-
inner_dim = n_heads * d_head
|
669 |
-
self.norm = Normalize(in_channels)
|
670 |
-
if not use_linear:
|
671 |
-
self.proj_in = nn.Conv2d(
|
672 |
-
in_channels, inner_dim, kernel_size=1, stride=1, padding=0
|
673 |
-
)
|
674 |
-
else:
|
675 |
-
self.proj_in = nn.Linear(in_channels, inner_dim)
|
676 |
-
|
677 |
-
self.transformer_blocks = nn.ModuleList(
|
678 |
-
[
|
679 |
-
BasicTransformerBlock(
|
680 |
-
inner_dim,
|
681 |
-
n_heads,
|
682 |
-
d_head,
|
683 |
-
dropout=dropout,
|
684 |
-
context_dim=context_dim[d],
|
685 |
-
disable_self_attn=disable_self_attn,
|
686 |
-
attn_mode=attn_type,
|
687 |
-
checkpoint=use_checkpoint,
|
688 |
-
sdp_backend=sdp_backend,
|
689 |
-
)
|
690 |
-
for d in range(depth)
|
691 |
-
]
|
692 |
-
)
|
693 |
-
if not use_linear:
|
694 |
-
self.proj_out = zero_module(
|
695 |
-
nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
|
696 |
-
)
|
697 |
-
else:
|
698 |
-
# self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
|
699 |
-
self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
|
700 |
-
self.use_linear = use_linear
|
701 |
-
|
702 |
-
def forward(self, x, context=None):
|
703 |
-
# note: if no context is given, cross-attention defaults to self-attention
|
704 |
-
if not isinstance(context, list):
|
705 |
-
context = [context]
|
706 |
-
b, c, h, w = x.shape
|
707 |
-
x_in = x
|
708 |
-
x = self.norm(x)
|
709 |
-
if not self.use_linear:
|
710 |
-
x = self.proj_in(x)
|
711 |
-
x = rearrange(x, "b c h w -> b (h w) c").contiguous()
|
712 |
-
if self.use_linear:
|
713 |
-
x = self.proj_in(x)
|
714 |
-
for i, block in enumerate(self.transformer_blocks):
|
715 |
-
if i > 0 and len(context) == 1:
|
716 |
-
i = 0 # use same context for each block
|
717 |
-
x = block(x, context=context[i])
|
718 |
-
if self.use_linear:
|
719 |
-
x = self.proj_out(x)
|
720 |
-
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
|
721 |
-
if not self.use_linear:
|
722 |
-
x = self.proj_out(x)
|
723 |
-
return x + x_in
|
724 |
-
|
725 |
-
|
726 |
-
class SimpleTransformer(nn.Module):
|
727 |
-
def __init__(
|
728 |
-
self,
|
729 |
-
dim: int,
|
730 |
-
depth: int,
|
731 |
-
heads: int,
|
732 |
-
dim_head: int,
|
733 |
-
context_dim: Optional[int] = None,
|
734 |
-
dropout: float = 0.0,
|
735 |
-
checkpoint: bool = True,
|
736 |
-
):
|
737 |
-
super().__init__()
|
738 |
-
self.layers = nn.ModuleList([])
|
739 |
-
for _ in range(depth):
|
740 |
-
self.layers.append(
|
741 |
-
BasicTransformerBlock(
|
742 |
-
dim,
|
743 |
-
heads,
|
744 |
-
dim_head,
|
745 |
-
dropout=dropout,
|
746 |
-
context_dim=context_dim,
|
747 |
-
attn_mode="softmax-xformers",
|
748 |
-
checkpoint=checkpoint,
|
749 |
-
)
|
750 |
-
)
|
751 |
-
|
752 |
-
def forward(
|
753 |
-
self,
|
754 |
-
x: torch.Tensor,
|
755 |
-
context: Optional[torch.Tensor] = None,
|
756 |
-
) -> torch.Tensor:
|
757 |
-
for layer in self.layers:
|
758 |
-
x = layer(x, context)
|
759 |
-
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sgm/modules/autoencoding/__init__.py
DELETED
File without changes
|
sgm/modules/autoencoding/losses/__init__.py
DELETED
@@ -1,7 +0,0 @@
|
|
1 |
-
__all__ = [
|
2 |
-
"GeneralLPIPSWithDiscriminator",
|
3 |
-
"LatentLPIPS",
|
4 |
-
]
|
5 |
-
|
6 |
-
from .discriminator_loss import GeneralLPIPSWithDiscriminator
|
7 |
-
from .lpips import LatentLPIPS
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sgm/modules/autoencoding/losses/discriminator_loss.py
DELETED
@@ -1,306 +0,0 @@
|
|
1 |
-
from typing import Dict, Iterator, List, Optional, Tuple, Union
|
2 |
-
|
3 |
-
import numpy as np
|
4 |
-
import torch
|
5 |
-
import torch.nn as nn
|
6 |
-
import torchvision
|
7 |
-
from einops import rearrange
|
8 |
-
from matplotlib import colormaps
|
9 |
-
from matplotlib import pyplot as plt
|
10 |
-
|
11 |
-
from ....util import default, instantiate_from_config
|
12 |
-
from ..lpips.loss.lpips import LPIPS
|
13 |
-
from ..lpips.model.model import weights_init
|
14 |
-
from ..lpips.vqperceptual import hinge_d_loss, vanilla_d_loss
|
15 |
-
|
16 |
-
|
17 |
-
class GeneralLPIPSWithDiscriminator(nn.Module):
|
18 |
-
def __init__(
|
19 |
-
self,
|
20 |
-
disc_start: int,
|
21 |
-
logvar_init: float = 0.0,
|
22 |
-
disc_num_layers: int = 3,
|
23 |
-
disc_in_channels: int = 3,
|
24 |
-
disc_factor: float = 1.0,
|
25 |
-
disc_weight: float = 1.0,
|
26 |
-
perceptual_weight: float = 1.0,
|
27 |
-
disc_loss: str = "hinge",
|
28 |
-
scale_input_to_tgt_size: bool = False,
|
29 |
-
dims: int = 2,
|
30 |
-
learn_logvar: bool = False,
|
31 |
-
regularization_weights: Union[None, Dict[str, float]] = None,
|
32 |
-
additional_log_keys: Optional[List[str]] = None,
|
33 |
-
discriminator_config: Optional[Dict] = None,
|
34 |
-
):
|
35 |
-
super().__init__()
|
36 |
-
self.dims = dims
|
37 |
-
if self.dims > 2:
|
38 |
-
print(
|
39 |
-
f"running with dims={dims}. This means that for perceptual loss "
|
40 |
-
f"calculation, the LPIPS loss will be applied to each frame "
|
41 |
-
f"independently."
|
42 |
-
)
|
43 |
-
self.scale_input_to_tgt_size = scale_input_to_tgt_size
|
44 |
-
assert disc_loss in ["hinge", "vanilla"]
|
45 |
-
self.perceptual_loss = LPIPS().eval()
|
46 |
-
self.perceptual_weight = perceptual_weight
|
47 |
-
# output log variance
|
48 |
-
self.logvar = nn.Parameter(
|
49 |
-
torch.full((), logvar_init), requires_grad=learn_logvar
|
50 |
-
)
|
51 |
-
self.learn_logvar = learn_logvar
|
52 |
-
|
53 |
-
discriminator_config = default(
|
54 |
-
discriminator_config,
|
55 |
-
{
|
56 |
-
"target": "sgm.modules.autoencoding.lpips.model.model.NLayerDiscriminator",
|
57 |
-
"params": {
|
58 |
-
"input_nc": disc_in_channels,
|
59 |
-
"n_layers": disc_num_layers,
|
60 |
-
"use_actnorm": False,
|
61 |
-
},
|
62 |
-
},
|
63 |
-
)
|
64 |
-
|
65 |
-
self.discriminator = instantiate_from_config(discriminator_config).apply(
|
66 |
-
weights_init
|
67 |
-
)
|
68 |
-
self.discriminator_iter_start = disc_start
|
69 |
-
self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
|
70 |
-
self.disc_factor = disc_factor
|
71 |
-
self.discriminator_weight = disc_weight
|
72 |
-
self.regularization_weights = default(regularization_weights, {})
|
73 |
-
|
74 |
-
self.forward_keys = [
|
75 |
-
"optimizer_idx",
|
76 |
-
"global_step",
|
77 |
-
"last_layer",
|
78 |
-
"split",
|
79 |
-
"regularization_log",
|
80 |
-
]
|
81 |
-
|
82 |
-
self.additional_log_keys = set(default(additional_log_keys, []))
|
83 |
-
self.additional_log_keys.update(set(self.regularization_weights.keys()))
|
84 |
-
|
85 |
-
def get_trainable_parameters(self) -> Iterator[nn.Parameter]:
|
86 |
-
return self.discriminator.parameters()
|
87 |
-
|
88 |
-
def get_trainable_autoencoder_parameters(self) -> Iterator[nn.Parameter]:
|
89 |
-
if self.learn_logvar:
|
90 |
-
yield self.logvar
|
91 |
-
yield from ()
|
92 |
-
|
93 |
-
@torch.no_grad()
|
94 |
-
def log_images(
|
95 |
-
self, inputs: torch.Tensor, reconstructions: torch.Tensor
|
96 |
-
) -> Dict[str, torch.Tensor]:
|
97 |
-
# calc logits of real/fake
|
98 |
-
logits_real = self.discriminator(inputs.contiguous().detach())
|
99 |
-
if len(logits_real.shape) < 4:
|
100 |
-
# Non patch-discriminator
|
101 |
-
return dict()
|
102 |
-
logits_fake = self.discriminator(reconstructions.contiguous().detach())
|
103 |
-
# -> (b, 1, h, w)
|
104 |
-
|
105 |
-
# parameters for colormapping
|
106 |
-
high = max(logits_fake.abs().max(), logits_real.abs().max()).item()
|
107 |
-
cmap = colormaps["PiYG"] # diverging colormap
|
108 |
-
|
109 |
-
def to_colormap(logits: torch.Tensor) -> torch.Tensor:
|
110 |
-
"""(b, 1, ...) -> (b, 3, ...)"""
|
111 |
-
logits = (logits + high) / (2 * high)
|
112 |
-
logits_np = cmap(logits.cpu().numpy())[..., :3] # truncate alpha channel
|
113 |
-
# -> (b, 1, ..., 3)
|
114 |
-
logits = torch.from_numpy(logits_np).to(logits.device)
|
115 |
-
return rearrange(logits, "b 1 ... c -> b c ...")
|
116 |
-
|
117 |
-
logits_real = torch.nn.functional.interpolate(
|
118 |
-
logits_real,
|
119 |
-
size=inputs.shape[-2:],
|
120 |
-
mode="nearest",
|
121 |
-
antialias=False,
|
122 |
-
)
|
123 |
-
logits_fake = torch.nn.functional.interpolate(
|
124 |
-
logits_fake,
|
125 |
-
size=reconstructions.shape[-2:],
|
126 |
-
mode="nearest",
|
127 |
-
antialias=False,
|
128 |
-
)
|
129 |
-
|
130 |
-
# alpha value of logits for overlay
|
131 |
-
alpha_real = torch.abs(logits_real) / high
|
132 |
-
alpha_fake = torch.abs(logits_fake) / high
|
133 |
-
# -> (b, 1, h, w) in range [0, 0.5]
|
134 |
-
# alpha value of lines don't really matter, since the values are the same
|
135 |
-
# for both images and logits anyway
|
136 |
-
grid_alpha_real = torchvision.utils.make_grid(alpha_real, nrow=4)
|
137 |
-
grid_alpha_fake = torchvision.utils.make_grid(alpha_fake, nrow=4)
|
138 |
-
grid_alpha = 0.8 * torch.cat((grid_alpha_real, grid_alpha_fake), dim=1)
|
139 |
-
# -> (1, h, w)
|
140 |
-
# blend logits and images together
|
141 |
-
|
142 |
-
# prepare logits for plotting
|
143 |
-
logits_real = to_colormap(logits_real)
|
144 |
-
logits_fake = to_colormap(logits_fake)
|
145 |
-
# resize logits
|
146 |
-
# -> (b, 3, h, w)
|
147 |
-
|
148 |
-
# make some grids
|
149 |
-
# add all logits to one plot
|
150 |
-
logits_real = torchvision.utils.make_grid(logits_real, nrow=4)
|
151 |
-
logits_fake = torchvision.utils.make_grid(logits_fake, nrow=4)
|
152 |
-
# I just love how torchvision calls the number of columns `nrow`
|
153 |
-
grid_logits = torch.cat((logits_real, logits_fake), dim=1)
|
154 |
-
# -> (3, h, w)
|
155 |
-
|
156 |
-
grid_images_real = torchvision.utils.make_grid(0.5 * inputs + 0.5, nrow=4)
|
157 |
-
grid_images_fake = torchvision.utils.make_grid(
|
158 |
-
0.5 * reconstructions + 0.5, nrow=4
|
159 |
-
)
|
160 |
-
grid_images = torch.cat((grid_images_real, grid_images_fake), dim=1)
|
161 |
-
# -> (3, h, w) in range [0, 1]
|
162 |
-
|
163 |
-
grid_blend = grid_alpha * grid_logits + (1 - grid_alpha) * grid_images
|
164 |
-
|
165 |
-
# Create labeled colorbar
|
166 |
-
dpi = 100
|
167 |
-
height = 128 / dpi
|
168 |
-
width = grid_logits.shape[2] / dpi
|
169 |
-
fig, ax = plt.subplots(figsize=(width, height), dpi=dpi)
|
170 |
-
img = ax.imshow(np.array([[-high, high]]), cmap=cmap)
|
171 |
-
plt.colorbar(
|
172 |
-
img,
|
173 |
-
cax=ax,
|
174 |
-
orientation="horizontal",
|
175 |
-
fraction=0.9,
|
176 |
-
aspect=width / height,
|
177 |
-
pad=0.0,
|
178 |
-
)
|
179 |
-
img.set_visible(False)
|
180 |
-
fig.tight_layout()
|
181 |
-
fig.canvas.draw()
|
182 |
-
# manually convert figure to numpy
|
183 |
-
cbar_np = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
|
184 |
-
cbar_np = cbar_np.reshape(fig.canvas.get_width_height()[::-1] + (3,))
|
185 |
-
cbar = torch.from_numpy(cbar_np.copy()).to(grid_logits.dtype) / 255.0
|
186 |
-
cbar = rearrange(cbar, "h w c -> c h w").to(grid_logits.device)
|
187 |
-
|
188 |
-
# Add colorbar to plot
|
189 |
-
annotated_grid = torch.cat((grid_logits, cbar), dim=1)
|
190 |
-
blended_grid = torch.cat((grid_blend, cbar), dim=1)
|
191 |
-
return {
|
192 |
-
"vis_logits": 2 * annotated_grid[None, ...] - 1,
|
193 |
-
"vis_logits_blended": 2 * blended_grid[None, ...] - 1,
|
194 |
-
}
|
195 |
-
|
196 |
-
def calculate_adaptive_weight(
|
197 |
-
self, nll_loss: torch.Tensor, g_loss: torch.Tensor, last_layer: torch.Tensor
|
198 |
-
) -> torch.Tensor:
|
199 |
-
nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
|
200 |
-
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
|
201 |
-
|
202 |
-
d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
|
203 |
-
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
|
204 |
-
d_weight = d_weight * self.discriminator_weight
|
205 |
-
return d_weight
|
206 |
-
|
207 |
-
def forward(
|
208 |
-
self,
|
209 |
-
inputs: torch.Tensor,
|
210 |
-
reconstructions: torch.Tensor,
|
211 |
-
*, # added because I changed the order here
|
212 |
-
regularization_log: Dict[str, torch.Tensor],
|
213 |
-
optimizer_idx: int,
|
214 |
-
global_step: int,
|
215 |
-
last_layer: torch.Tensor,
|
216 |
-
split: str = "train",
|
217 |
-
weights: Union[None, float, torch.Tensor] = None,
|
218 |
-
) -> Tuple[torch.Tensor, dict]:
|
219 |
-
if self.scale_input_to_tgt_size:
|
220 |
-
inputs = torch.nn.functional.interpolate(
|
221 |
-
inputs, reconstructions.shape[2:], mode="bicubic", antialias=True
|
222 |
-
)
|
223 |
-
|
224 |
-
if self.dims > 2:
|
225 |
-
inputs, reconstructions = map(
|
226 |
-
lambda x: rearrange(x, "b c t h w -> (b t) c h w"),
|
227 |
-
(inputs, reconstructions),
|
228 |
-
)
|
229 |
-
|
230 |
-
rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
|
231 |
-
if self.perceptual_weight > 0:
|
232 |
-
p_loss = self.perceptual_loss(
|
233 |
-
inputs.contiguous(), reconstructions.contiguous()
|
234 |
-
)
|
235 |
-
rec_loss = rec_loss + self.perceptual_weight * p_loss
|
236 |
-
|
237 |
-
nll_loss, weighted_nll_loss = self.get_nll_loss(rec_loss, weights)
|
238 |
-
|
239 |
-
# now the GAN part
|
240 |
-
if optimizer_idx == 0:
|
241 |
-
# generator update
|
242 |
-
if global_step >= self.discriminator_iter_start or not self.training:
|
243 |
-
logits_fake = self.discriminator(reconstructions.contiguous())
|
244 |
-
g_loss = -torch.mean(logits_fake)
|
245 |
-
if self.training:
|
246 |
-
d_weight = self.calculate_adaptive_weight(
|
247 |
-
nll_loss, g_loss, last_layer=last_layer
|
248 |
-
)
|
249 |
-
else:
|
250 |
-
d_weight = torch.tensor(1.0)
|
251 |
-
else:
|
252 |
-
d_weight = torch.tensor(0.0)
|
253 |
-
g_loss = torch.tensor(0.0, requires_grad=True)
|
254 |
-
|
255 |
-
loss = weighted_nll_loss + d_weight * self.disc_factor * g_loss
|
256 |
-
log = dict()
|
257 |
-
for k in regularization_log:
|
258 |
-
if k in self.regularization_weights:
|
259 |
-
loss = loss + self.regularization_weights[k] * regularization_log[k]
|
260 |
-
if k in self.additional_log_keys:
|
261 |
-
log[f"{split}/{k}"] = regularization_log[k].detach().float().mean()
|
262 |
-
|
263 |
-
log.update(
|
264 |
-
{
|
265 |
-
f"{split}/loss/total": loss.clone().detach().mean(),
|
266 |
-
f"{split}/loss/nll": nll_loss.detach().mean(),
|
267 |
-
f"{split}/loss/rec": rec_loss.detach().mean(),
|
268 |
-
f"{split}/loss/g": g_loss.detach().mean(),
|
269 |
-
f"{split}/scalars/logvar": self.logvar.detach(),
|
270 |
-
f"{split}/scalars/d_weight": d_weight.detach(),
|
271 |
-
}
|
272 |
-
)
|
273 |
-
|
274 |
-
return loss, log
|
275 |
-
elif optimizer_idx == 1:
|
276 |
-
# second pass for discriminator update
|
277 |
-
logits_real = self.discriminator(inputs.contiguous().detach())
|
278 |
-
logits_fake = self.discriminator(reconstructions.contiguous().detach())
|
279 |
-
|
280 |
-
if global_step >= self.discriminator_iter_start or not self.training:
|
281 |
-
d_loss = self.disc_factor * self.disc_loss(logits_real, logits_fake)
|
282 |
-
else:
|
283 |
-
d_loss = torch.tensor(0.0, requires_grad=True)
|
284 |
-
|
285 |
-
log = {
|
286 |
-
f"{split}/loss/disc": d_loss.clone().detach().mean(),
|
287 |
-
f"{split}/logits/real": logits_real.detach().mean(),
|
288 |
-
f"{split}/logits/fake": logits_fake.detach().mean(),
|
289 |
-
}
|
290 |
-
return d_loss, log
|
291 |
-
else:
|
292 |
-
raise NotImplementedError(f"Unknown optimizer_idx {optimizer_idx}")
|
293 |
-
|
294 |
-
def get_nll_loss(
|
295 |
-
self,
|
296 |
-
rec_loss: torch.Tensor,
|
297 |
-
weights: Optional[Union[float, torch.Tensor]] = None,
|
298 |
-
) -> Tuple[torch.Tensor, torch.Tensor]:
|
299 |
-
nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
|
300 |
-
weighted_nll_loss = nll_loss
|
301 |
-
if weights is not None:
|
302 |
-
weighted_nll_loss = weights * nll_loss
|
303 |
-
weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
|
304 |
-
nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
|
305 |
-
|
306 |
-
return nll_loss, weighted_nll_loss
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sgm/modules/autoencoding/losses/lpips.py
DELETED
@@ -1,73 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
import torch.nn as nn
|
3 |
-
|
4 |
-
from ....util import default, instantiate_from_config
|
5 |
-
from ..lpips.loss.lpips import LPIPS
|
6 |
-
|
7 |
-
|
8 |
-
class LatentLPIPS(nn.Module):
|
9 |
-
def __init__(
|
10 |
-
self,
|
11 |
-
decoder_config,
|
12 |
-
perceptual_weight=1.0,
|
13 |
-
latent_weight=1.0,
|
14 |
-
scale_input_to_tgt_size=False,
|
15 |
-
scale_tgt_to_input_size=False,
|
16 |
-
perceptual_weight_on_inputs=0.0,
|
17 |
-
):
|
18 |
-
super().__init__()
|
19 |
-
self.scale_input_to_tgt_size = scale_input_to_tgt_size
|
20 |
-
self.scale_tgt_to_input_size = scale_tgt_to_input_size
|
21 |
-
self.init_decoder(decoder_config)
|
22 |
-
self.perceptual_loss = LPIPS().eval()
|
23 |
-
self.perceptual_weight = perceptual_weight
|
24 |
-
self.latent_weight = latent_weight
|
25 |
-
self.perceptual_weight_on_inputs = perceptual_weight_on_inputs
|
26 |
-
|
27 |
-
def init_decoder(self, config):
|
28 |
-
self.decoder = instantiate_from_config(config)
|
29 |
-
if hasattr(self.decoder, "encoder"):
|
30 |
-
del self.decoder.encoder
|
31 |
-
|
32 |
-
def forward(self, latent_inputs, latent_predictions, image_inputs, split="train"):
|
33 |
-
log = dict()
|
34 |
-
loss = (latent_inputs - latent_predictions) ** 2
|
35 |
-
log[f"{split}/latent_l2_loss"] = loss.mean().detach()
|
36 |
-
image_reconstructions = None
|
37 |
-
if self.perceptual_weight > 0.0:
|
38 |
-
image_reconstructions = self.decoder.decode(latent_predictions)
|
39 |
-
image_targets = self.decoder.decode(latent_inputs)
|
40 |
-
perceptual_loss = self.perceptual_loss(
|
41 |
-
image_targets.contiguous(), image_reconstructions.contiguous()
|
42 |
-
)
|
43 |
-
loss = (
|
44 |
-
self.latent_weight * loss.mean()
|
45 |
-
+ self.perceptual_weight * perceptual_loss.mean()
|
46 |
-
)
|
47 |
-
log[f"{split}/perceptual_loss"] = perceptual_loss.mean().detach()
|
48 |
-
|
49 |
-
if self.perceptual_weight_on_inputs > 0.0:
|
50 |
-
image_reconstructions = default(
|
51 |
-
image_reconstructions, self.decoder.decode(latent_predictions)
|
52 |
-
)
|
53 |
-
if self.scale_input_to_tgt_size:
|
54 |
-
image_inputs = torch.nn.functional.interpolate(
|
55 |
-
image_inputs,
|
56 |
-
image_reconstructions.shape[2:],
|
57 |
-
mode="bicubic",
|
58 |
-
antialias=True,
|
59 |
-
)
|
60 |
-
elif self.scale_tgt_to_input_size:
|
61 |
-
image_reconstructions = torch.nn.functional.interpolate(
|
62 |
-
image_reconstructions,
|
63 |
-
image_inputs.shape[2:],
|
64 |
-
mode="bicubic",
|
65 |
-
antialias=True,
|
66 |
-
)
|
67 |
-
|
68 |
-
perceptual_loss2 = self.perceptual_loss(
|
69 |
-
image_inputs.contiguous(), image_reconstructions.contiguous()
|
70 |
-
)
|
71 |
-
loss = loss + self.perceptual_weight_on_inputs * perceptual_loss2.mean()
|
72 |
-
log[f"{split}/perceptual_loss_on_inputs"] = perceptual_loss2.mean().detach()
|
73 |
-
return loss, log
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sgm/modules/autoencoding/lpips/__init__.py
DELETED
File without changes
|
sgm/modules/autoencoding/lpips/loss/.gitignore
DELETED
@@ -1 +0,0 @@
|
|
1 |
-
vgg.pth
|
|
|
|
sgm/modules/autoencoding/lpips/loss/LICENSE
DELETED
@@ -1,23 +0,0 @@
|
|
1 |
-
Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang
|
2 |
-
All rights reserved.
|
3 |
-
|
4 |
-
Redistribution and use in source and binary forms, with or without
|
5 |
-
modification, are permitted provided that the following conditions are met:
|
6 |
-
|
7 |
-
* Redistributions of source code must retain the above copyright notice, this
|
8 |
-
list of conditions and the following disclaimer.
|
9 |
-
|
10 |
-
* Redistributions in binary form must reproduce the above copyright notice,
|
11 |
-
this list of conditions and the following disclaimer in the documentation
|
12 |
-
and/or other materials provided with the distribution.
|
13 |
-
|
14 |
-
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
15 |
-
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
16 |
-
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
17 |
-
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
18 |
-
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
19 |
-
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
20 |
-
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
21 |
-
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
22 |
-
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
23 |
-
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sgm/modules/autoencoding/lpips/loss/__init__.py
DELETED
File without changes
|
sgm/modules/autoencoding/lpips/loss/lpips.py
DELETED
@@ -1,147 +0,0 @@
|
|
1 |
-
"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
|
2 |
-
|
3 |
-
from collections import namedtuple
|
4 |
-
|
5 |
-
import torch
|
6 |
-
import torch.nn as nn
|
7 |
-
from torchvision import models
|
8 |
-
|
9 |
-
from ..util import get_ckpt_path
|
10 |
-
|
11 |
-
|
12 |
-
class LPIPS(nn.Module):
|
13 |
-
# Learned perceptual metric
|
14 |
-
def __init__(self, use_dropout=True):
|
15 |
-
super().__init__()
|
16 |
-
self.scaling_layer = ScalingLayer()
|
17 |
-
self.chns = [64, 128, 256, 512, 512] # vg16 features
|
18 |
-
self.net = vgg16(pretrained=True, requires_grad=False)
|
19 |
-
self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
|
20 |
-
self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
|
21 |
-
self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
|
22 |
-
self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
|
23 |
-
self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
|
24 |
-
self.load_from_pretrained()
|
25 |
-
for param in self.parameters():
|
26 |
-
param.requires_grad = False
|
27 |
-
|
28 |
-
def load_from_pretrained(self, name="vgg_lpips"):
|
29 |
-
ckpt = get_ckpt_path(name, "sgm/modules/autoencoding/lpips/loss")
|
30 |
-
self.load_state_dict(
|
31 |
-
torch.load(ckpt, map_location=torch.device("cpu")), strict=False
|
32 |
-
)
|
33 |
-
print("loaded pretrained LPIPS loss from {}".format(ckpt))
|
34 |
-
|
35 |
-
@classmethod
|
36 |
-
def from_pretrained(cls, name="vgg_lpips"):
|
37 |
-
if name != "vgg_lpips":
|
38 |
-
raise NotImplementedError
|
39 |
-
model = cls()
|
40 |
-
ckpt = get_ckpt_path(name)
|
41 |
-
model.load_state_dict(
|
42 |
-
torch.load(ckpt, map_location=torch.device("cpu")), strict=False
|
43 |
-
)
|
44 |
-
return model
|
45 |
-
|
46 |
-
def forward(self, input, target):
|
47 |
-
in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
|
48 |
-
outs0, outs1 = self.net(in0_input), self.net(in1_input)
|
49 |
-
feats0, feats1, diffs = {}, {}, {}
|
50 |
-
lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
|
51 |
-
for kk in range(len(self.chns)):
|
52 |
-
feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(
|
53 |
-
outs1[kk]
|
54 |
-
)
|
55 |
-
diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
|
56 |
-
|
57 |
-
res = [
|
58 |
-
spatial_average(lins[kk].model(diffs[kk]), keepdim=True)
|
59 |
-
for kk in range(len(self.chns))
|
60 |
-
]
|
61 |
-
val = res[0]
|
62 |
-
for l in range(1, len(self.chns)):
|
63 |
-
val += res[l]
|
64 |
-
return val
|
65 |
-
|
66 |
-
|
67 |
-
class ScalingLayer(nn.Module):
|
68 |
-
def __init__(self):
|
69 |
-
super(ScalingLayer, self).__init__()
|
70 |
-
self.register_buffer(
|
71 |
-
"shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None]
|
72 |
-
)
|
73 |
-
self.register_buffer(
|
74 |
-
"scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None]
|
75 |
-
)
|
76 |
-
|
77 |
-
def forward(self, inp):
|
78 |
-
return (inp - self.shift) / self.scale
|
79 |
-
|
80 |
-
|
81 |
-
class NetLinLayer(nn.Module):
|
82 |
-
"""A single linear layer which does a 1x1 conv"""
|
83 |
-
|
84 |
-
def __init__(self, chn_in, chn_out=1, use_dropout=False):
|
85 |
-
super(NetLinLayer, self).__init__()
|
86 |
-
layers = (
|
87 |
-
[
|
88 |
-
nn.Dropout(),
|
89 |
-
]
|
90 |
-
if (use_dropout)
|
91 |
-
else []
|
92 |
-
)
|
93 |
-
layers += [
|
94 |
-
nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),
|
95 |
-
]
|
96 |
-
self.model = nn.Sequential(*layers)
|
97 |
-
|
98 |
-
|
99 |
-
class vgg16(torch.nn.Module):
|
100 |
-
def __init__(self, requires_grad=False, pretrained=True):
|
101 |
-
super(vgg16, self).__init__()
|
102 |
-
vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
|
103 |
-
self.slice1 = torch.nn.Sequential()
|
104 |
-
self.slice2 = torch.nn.Sequential()
|
105 |
-
self.slice3 = torch.nn.Sequential()
|
106 |
-
self.slice4 = torch.nn.Sequential()
|
107 |
-
self.slice5 = torch.nn.Sequential()
|
108 |
-
self.N_slices = 5
|
109 |
-
for x in range(4):
|
110 |
-
self.slice1.add_module(str(x), vgg_pretrained_features[x])
|
111 |
-
for x in range(4, 9):
|
112 |
-
self.slice2.add_module(str(x), vgg_pretrained_features[x])
|
113 |
-
for x in range(9, 16):
|
114 |
-
self.slice3.add_module(str(x), vgg_pretrained_features[x])
|
115 |
-
for x in range(16, 23):
|
116 |
-
self.slice4.add_module(str(x), vgg_pretrained_features[x])
|
117 |
-
for x in range(23, 30):
|
118 |
-
self.slice5.add_module(str(x), vgg_pretrained_features[x])
|
119 |
-
if not requires_grad:
|
120 |
-
for param in self.parameters():
|
121 |
-
param.requires_grad = False
|
122 |
-
|
123 |
-
def forward(self, X):
|
124 |
-
h = self.slice1(X)
|
125 |
-
h_relu1_2 = h
|
126 |
-
h = self.slice2(h)
|
127 |
-
h_relu2_2 = h
|
128 |
-
h = self.slice3(h)
|
129 |
-
h_relu3_3 = h
|
130 |
-
h = self.slice4(h)
|
131 |
-
h_relu4_3 = h
|
132 |
-
h = self.slice5(h)
|
133 |
-
h_relu5_3 = h
|
134 |
-
vgg_outputs = namedtuple(
|
135 |
-
"VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"]
|
136 |
-
)
|
137 |
-
out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
|
138 |
-
return out
|
139 |
-
|
140 |
-
|
141 |
-
def normalize_tensor(x, eps=1e-10):
|
142 |
-
norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
|
143 |
-
return x / (norm_factor + eps)
|
144 |
-
|
145 |
-
|
146 |
-
def spatial_average(x, keepdim=True):
|
147 |
-
return x.mean([2, 3], keepdim=keepdim)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sgm/modules/autoencoding/lpips/model/LICENSE
DELETED
@@ -1,58 +0,0 @@
|
|
1 |
-
Copyright (c) 2017, Jun-Yan Zhu and Taesung Park
|
2 |
-
All rights reserved.
|
3 |
-
|
4 |
-
Redistribution and use in source and binary forms, with or without
|
5 |
-
modification, are permitted provided that the following conditions are met:
|
6 |
-
|
7 |
-
* Redistributions of source code must retain the above copyright notice, this
|
8 |
-
list of conditions and the following disclaimer.
|
9 |
-
|
10 |
-
* Redistributions in binary form must reproduce the above copyright notice,
|
11 |
-
this list of conditions and the following disclaimer in the documentation
|
12 |
-
and/or other materials provided with the distribution.
|
13 |
-
|
14 |
-
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
15 |
-
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
16 |
-
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
17 |
-
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
18 |
-
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
19 |
-
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
20 |
-
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
21 |
-
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
22 |
-
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
23 |
-
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
24 |
-
|
25 |
-
|
26 |
-
--------------------------- LICENSE FOR pix2pix --------------------------------
|
27 |
-
BSD License
|
28 |
-
|
29 |
-
For pix2pix software
|
30 |
-
Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu
|
31 |
-
All rights reserved.
|
32 |
-
|
33 |
-
Redistribution and use in source and binary forms, with or without
|
34 |
-
modification, are permitted provided that the following conditions are met:
|
35 |
-
|
36 |
-
* Redistributions of source code must retain the above copyright notice, this
|
37 |
-
list of conditions and the following disclaimer.
|
38 |
-
|
39 |
-
* Redistributions in binary form must reproduce the above copyright notice,
|
40 |
-
this list of conditions and the following disclaimer in the documentation
|
41 |
-
and/or other materials provided with the distribution.
|
42 |
-
|
43 |
-
----------------------------- LICENSE FOR DCGAN --------------------------------
|
44 |
-
BSD License
|
45 |
-
|
46 |
-
For dcgan.torch software
|
47 |
-
|
48 |
-
Copyright (c) 2015, Facebook, Inc. All rights reserved.
|
49 |
-
|
50 |
-
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
|
51 |
-
|
52 |
-
Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
|
53 |
-
|
54 |
-
Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
|
55 |
-
|
56 |
-
Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
|
57 |
-
|
58 |
-
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sgm/modules/autoencoding/lpips/model/__init__.py
DELETED
File without changes
|
sgm/modules/autoencoding/lpips/model/model.py
DELETED
@@ -1,88 +0,0 @@
|
|
1 |
-
import functools
|
2 |
-
|
3 |
-
import torch.nn as nn
|
4 |
-
|
5 |
-
from ..util import ActNorm
|
6 |
-
|
7 |
-
|
8 |
-
def weights_init(m):
|
9 |
-
classname = m.__class__.__name__
|
10 |
-
if classname.find("Conv") != -1:
|
11 |
-
nn.init.normal_(m.weight.data, 0.0, 0.02)
|
12 |
-
elif classname.find("BatchNorm") != -1:
|
13 |
-
nn.init.normal_(m.weight.data, 1.0, 0.02)
|
14 |
-
nn.init.constant_(m.bias.data, 0)
|
15 |
-
|
16 |
-
|
17 |
-
class NLayerDiscriminator(nn.Module):
|
18 |
-
"""Defines a PatchGAN discriminator as in Pix2Pix
|
19 |
-
--> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
|
20 |
-
"""
|
21 |
-
|
22 |
-
def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
|
23 |
-
"""Construct a PatchGAN discriminator
|
24 |
-
Parameters:
|
25 |
-
input_nc (int) -- the number of channels in input images
|
26 |
-
ndf (int) -- the number of filters in the last conv layer
|
27 |
-
n_layers (int) -- the number of conv layers in the discriminator
|
28 |
-
norm_layer -- normalization layer
|
29 |
-
"""
|
30 |
-
super(NLayerDiscriminator, self).__init__()
|
31 |
-
if not use_actnorm:
|
32 |
-
norm_layer = nn.BatchNorm2d
|
33 |
-
else:
|
34 |
-
norm_layer = ActNorm
|
35 |
-
if (
|
36 |
-
type(norm_layer) == functools.partial
|
37 |
-
): # no need to use bias as BatchNorm2d has affine parameters
|
38 |
-
use_bias = norm_layer.func != nn.BatchNorm2d
|
39 |
-
else:
|
40 |
-
use_bias = norm_layer != nn.BatchNorm2d
|
41 |
-
|
42 |
-
kw = 4
|
43 |
-
padw = 1
|
44 |
-
sequence = [
|
45 |
-
nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
|
46 |
-
nn.LeakyReLU(0.2, True),
|
47 |
-
]
|
48 |
-
nf_mult = 1
|
49 |
-
nf_mult_prev = 1
|
50 |
-
for n in range(1, n_layers): # gradually increase the number of filters
|
51 |
-
nf_mult_prev = nf_mult
|
52 |
-
nf_mult = min(2**n, 8)
|
53 |
-
sequence += [
|
54 |
-
nn.Conv2d(
|
55 |
-
ndf * nf_mult_prev,
|
56 |
-
ndf * nf_mult,
|
57 |
-
kernel_size=kw,
|
58 |
-
stride=2,
|
59 |
-
padding=padw,
|
60 |
-
bias=use_bias,
|
61 |
-
),
|
62 |
-
norm_layer(ndf * nf_mult),
|
63 |
-
nn.LeakyReLU(0.2, True),
|
64 |
-
]
|
65 |
-
|
66 |
-
nf_mult_prev = nf_mult
|
67 |
-
nf_mult = min(2**n_layers, 8)
|
68 |
-
sequence += [
|
69 |
-
nn.Conv2d(
|
70 |
-
ndf * nf_mult_prev,
|
71 |
-
ndf * nf_mult,
|
72 |
-
kernel_size=kw,
|
73 |
-
stride=1,
|
74 |
-
padding=padw,
|
75 |
-
bias=use_bias,
|
76 |
-
),
|
77 |
-
norm_layer(ndf * nf_mult),
|
78 |
-
nn.LeakyReLU(0.2, True),
|
79 |
-
]
|
80 |
-
|
81 |
-
sequence += [
|
82 |
-
nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
|
83 |
-
] # output 1 channel prediction map
|
84 |
-
self.main = nn.Sequential(*sequence)
|
85 |
-
|
86 |
-
def forward(self, input):
|
87 |
-
"""Standard forward."""
|
88 |
-
return self.main(input)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sgm/modules/autoencoding/lpips/util.py
DELETED
@@ -1,128 +0,0 @@
|
|
1 |
-
import hashlib
|
2 |
-
import os
|
3 |
-
|
4 |
-
import requests
|
5 |
-
import torch
|
6 |
-
import torch.nn as nn
|
7 |
-
from tqdm import tqdm
|
8 |
-
|
9 |
-
URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"}
|
10 |
-
|
11 |
-
CKPT_MAP = {"vgg_lpips": "vgg.pth"}
|
12 |
-
|
13 |
-
MD5_MAP = {"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"}
|
14 |
-
|
15 |
-
|
16 |
-
def download(url, local_path, chunk_size=1024):
|
17 |
-
os.makedirs(os.path.split(local_path)[0], exist_ok=True)
|
18 |
-
with requests.get(url, stream=True) as r:
|
19 |
-
total_size = int(r.headers.get("content-length", 0))
|
20 |
-
with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
|
21 |
-
with open(local_path, "wb") as f:
|
22 |
-
for data in r.iter_content(chunk_size=chunk_size):
|
23 |
-
if data:
|
24 |
-
f.write(data)
|
25 |
-
pbar.update(chunk_size)
|
26 |
-
|
27 |
-
|
28 |
-
def md5_hash(path):
|
29 |
-
with open(path, "rb") as f:
|
30 |
-
content = f.read()
|
31 |
-
return hashlib.md5(content).hexdigest()
|
32 |
-
|
33 |
-
|
34 |
-
def get_ckpt_path(name, root, check=False):
|
35 |
-
assert name in URL_MAP
|
36 |
-
path = os.path.join(root, CKPT_MAP[name])
|
37 |
-
if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
|
38 |
-
print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
|
39 |
-
download(URL_MAP[name], path)
|
40 |
-
md5 = md5_hash(path)
|
41 |
-
assert md5 == MD5_MAP[name], md5
|
42 |
-
return path
|
43 |
-
|
44 |
-
|
45 |
-
class ActNorm(nn.Module):
|
46 |
-
def __init__(
|
47 |
-
self, num_features, logdet=False, affine=True, allow_reverse_init=False
|
48 |
-
):
|
49 |
-
assert affine
|
50 |
-
super().__init__()
|
51 |
-
self.logdet = logdet
|
52 |
-
self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
|
53 |
-
self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
|
54 |
-
self.allow_reverse_init = allow_reverse_init
|
55 |
-
|
56 |
-
self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8))
|
57 |
-
|
58 |
-
def initialize(self, input):
|
59 |
-
with torch.no_grad():
|
60 |
-
flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
|
61 |
-
mean = (
|
62 |
-
flatten.mean(1)
|
63 |
-
.unsqueeze(1)
|
64 |
-
.unsqueeze(2)
|
65 |
-
.unsqueeze(3)
|
66 |
-
.permute(1, 0, 2, 3)
|
67 |
-
)
|
68 |
-
std = (
|
69 |
-
flatten.std(1)
|
70 |
-
.unsqueeze(1)
|
71 |
-
.unsqueeze(2)
|
72 |
-
.unsqueeze(3)
|
73 |
-
.permute(1, 0, 2, 3)
|
74 |
-
)
|
75 |
-
|
76 |
-
self.loc.data.copy_(-mean)
|
77 |
-
self.scale.data.copy_(1 / (std + 1e-6))
|
78 |
-
|
79 |
-
def forward(self, input, reverse=False):
|
80 |
-
if reverse:
|
81 |
-
return self.reverse(input)
|
82 |
-
if len(input.shape) == 2:
|
83 |
-
input = input[:, :, None, None]
|
84 |
-
squeeze = True
|
85 |
-
else:
|
86 |
-
squeeze = False
|
87 |
-
|
88 |
-
_, _, height, width = input.shape
|
89 |
-
|
90 |
-
if self.training and self.initialized.item() == 0:
|
91 |
-
self.initialize(input)
|
92 |
-
self.initialized.fill_(1)
|
93 |
-
|
94 |
-
h = self.scale * (input + self.loc)
|
95 |
-
|
96 |
-
if squeeze:
|
97 |
-
h = h.squeeze(-1).squeeze(-1)
|
98 |
-
|
99 |
-
if self.logdet:
|
100 |
-
log_abs = torch.log(torch.abs(self.scale))
|
101 |
-
logdet = height * width * torch.sum(log_abs)
|
102 |
-
logdet = logdet * torch.ones(input.shape[0]).to(input)
|
103 |
-
return h, logdet
|
104 |
-
|
105 |
-
return h
|
106 |
-
|
107 |
-
def reverse(self, output):
|
108 |
-
if self.training and self.initialized.item() == 0:
|
109 |
-
if not self.allow_reverse_init:
|
110 |
-
raise RuntimeError(
|
111 |
-
"Initializing ActNorm in reverse direction is "
|
112 |
-
"disabled by default. Use allow_reverse_init=True to enable."
|
113 |
-
)
|
114 |
-
else:
|
115 |
-
self.initialize(output)
|
116 |
-
self.initialized.fill_(1)
|
117 |
-
|
118 |
-
if len(output.shape) == 2:
|
119 |
-
output = output[:, :, None, None]
|
120 |
-
squeeze = True
|
121 |
-
else:
|
122 |
-
squeeze = False
|
123 |
-
|
124 |
-
h = output / self.scale - self.loc
|
125 |
-
|
126 |
-
if squeeze:
|
127 |
-
h = h.squeeze(-1).squeeze(-1)
|
128 |
-
return h
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sgm/modules/autoencoding/lpips/vqperceptual.py
DELETED
@@ -1,17 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
import torch.nn.functional as F
|
3 |
-
|
4 |
-
|
5 |
-
def hinge_d_loss(logits_real, logits_fake):
|
6 |
-
loss_real = torch.mean(F.relu(1.0 - logits_real))
|
7 |
-
loss_fake = torch.mean(F.relu(1.0 + logits_fake))
|
8 |
-
d_loss = 0.5 * (loss_real + loss_fake)
|
9 |
-
return d_loss
|
10 |
-
|
11 |
-
|
12 |
-
def vanilla_d_loss(logits_real, logits_fake):
|
13 |
-
d_loss = 0.5 * (
|
14 |
-
torch.mean(torch.nn.functional.softplus(-logits_real))
|
15 |
-
+ torch.mean(torch.nn.functional.softplus(logits_fake))
|
16 |
-
)
|
17 |
-
return d_loss
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sgm/modules/autoencoding/regularizers/__init__.py
DELETED
@@ -1,31 +0,0 @@
|
|
1 |
-
from abc import abstractmethod
|
2 |
-
from typing import Any, Tuple
|
3 |
-
|
4 |
-
import torch
|
5 |
-
import torch.nn as nn
|
6 |
-
import torch.nn.functional as F
|
7 |
-
|
8 |
-
from ....modules.distributions.distributions import \
|
9 |
-
DiagonalGaussianDistribution
|
10 |
-
from .base import AbstractRegularizer
|
11 |
-
|
12 |
-
|
13 |
-
class DiagonalGaussianRegularizer(AbstractRegularizer):
|
14 |
-
def __init__(self, sample: bool = True):
|
15 |
-
super().__init__()
|
16 |
-
self.sample = sample
|
17 |
-
|
18 |
-
def get_trainable_parameters(self) -> Any:
|
19 |
-
yield from ()
|
20 |
-
|
21 |
-
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
|
22 |
-
log = dict()
|
23 |
-
posterior = DiagonalGaussianDistribution(z)
|
24 |
-
if self.sample:
|
25 |
-
z = posterior.sample()
|
26 |
-
else:
|
27 |
-
z = posterior.mode()
|
28 |
-
kl_loss = posterior.kl()
|
29 |
-
kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
|
30 |
-
log["kl_loss"] = kl_loss
|
31 |
-
return z, log
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sgm/modules/autoencoding/regularizers/base.py
DELETED
@@ -1,40 +0,0 @@
|
|
1 |
-
from abc import abstractmethod
|
2 |
-
from typing import Any, Tuple
|
3 |
-
|
4 |
-
import torch
|
5 |
-
import torch.nn.functional as F
|
6 |
-
from torch import nn
|
7 |
-
|
8 |
-
|
9 |
-
class AbstractRegularizer(nn.Module):
|
10 |
-
def __init__(self):
|
11 |
-
super().__init__()
|
12 |
-
|
13 |
-
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
|
14 |
-
raise NotImplementedError()
|
15 |
-
|
16 |
-
@abstractmethod
|
17 |
-
def get_trainable_parameters(self) -> Any:
|
18 |
-
raise NotImplementedError()
|
19 |
-
|
20 |
-
|
21 |
-
class IdentityRegularizer(AbstractRegularizer):
|
22 |
-
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
|
23 |
-
return z, dict()
|
24 |
-
|
25 |
-
def get_trainable_parameters(self) -> Any:
|
26 |
-
yield from ()
|
27 |
-
|
28 |
-
|
29 |
-
def measure_perplexity(
|
30 |
-
predicted_indices: torch.Tensor, num_centroids: int
|
31 |
-
) -> Tuple[torch.Tensor, torch.Tensor]:
|
32 |
-
# src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
|
33 |
-
# eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
|
34 |
-
encodings = (
|
35 |
-
F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids)
|
36 |
-
)
|
37 |
-
avg_probs = encodings.mean(0)
|
38 |
-
perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
|
39 |
-
cluster_use = torch.sum(avg_probs > 0)
|
40 |
-
return perplexity, cluster_use
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sgm/modules/autoencoding/regularizers/quantize.py
DELETED
@@ -1,487 +0,0 @@
|
|
1 |
-
import logging
|
2 |
-
from abc import abstractmethod
|
3 |
-
from typing import Dict, Iterator, Literal, Optional, Tuple, Union
|
4 |
-
|
5 |
-
import numpy as np
|
6 |
-
import torch
|
7 |
-
import torch.nn as nn
|
8 |
-
import torch.nn.functional as F
|
9 |
-
from einops import rearrange
|
10 |
-
from torch import einsum
|
11 |
-
|
12 |
-
from .base import AbstractRegularizer, measure_perplexity
|
13 |
-
|
14 |
-
logpy = logging.getLogger(__name__)
|
15 |
-
|
16 |
-
|
17 |
-
class AbstractQuantizer(AbstractRegularizer):
|
18 |
-
def __init__(self):
|
19 |
-
super().__init__()
|
20 |
-
# Define these in your init
|
21 |
-
# shape (N,)
|
22 |
-
self.used: Optional[torch.Tensor]
|
23 |
-
self.re_embed: int
|
24 |
-
self.unknown_index: Union[Literal["random"], int]
|
25 |
-
|
26 |
-
def remap_to_used(self, inds: torch.Tensor) -> torch.Tensor:
|
27 |
-
assert self.used is not None, "You need to define used indices for remap"
|
28 |
-
ishape = inds.shape
|
29 |
-
assert len(ishape) > 1
|
30 |
-
inds = inds.reshape(ishape[0], -1)
|
31 |
-
used = self.used.to(inds)
|
32 |
-
match = (inds[:, :, None] == used[None, None, ...]).long()
|
33 |
-
new = match.argmax(-1)
|
34 |
-
unknown = match.sum(2) < 1
|
35 |
-
if self.unknown_index == "random":
|
36 |
-
new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(
|
37 |
-
device=new.device
|
38 |
-
)
|
39 |
-
else:
|
40 |
-
new[unknown] = self.unknown_index
|
41 |
-
return new.reshape(ishape)
|
42 |
-
|
43 |
-
def unmap_to_all(self, inds: torch.Tensor) -> torch.Tensor:
|
44 |
-
assert self.used is not None, "You need to define used indices for remap"
|
45 |
-
ishape = inds.shape
|
46 |
-
assert len(ishape) > 1
|
47 |
-
inds = inds.reshape(ishape[0], -1)
|
48 |
-
used = self.used.to(inds)
|
49 |
-
if self.re_embed > self.used.shape[0]: # extra token
|
50 |
-
inds[inds >= self.used.shape[0]] = 0 # simply set to zero
|
51 |
-
back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
|
52 |
-
return back.reshape(ishape)
|
53 |
-
|
54 |
-
@abstractmethod
|
55 |
-
def get_codebook_entry(
|
56 |
-
self, indices: torch.Tensor, shape: Optional[Tuple[int, ...]] = None
|
57 |
-
) -> torch.Tensor:
|
58 |
-
raise NotImplementedError()
|
59 |
-
|
60 |
-
def get_trainable_parameters(self) -> Iterator[torch.nn.Parameter]:
|
61 |
-
yield from self.parameters()
|
62 |
-
|
63 |
-
|
64 |
-
class GumbelQuantizer(AbstractQuantizer):
|
65 |
-
"""
|
66 |
-
credit to @karpathy:
|
67 |
-
https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!)
|
68 |
-
Gumbel Softmax trick quantizer
|
69 |
-
Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016
|
70 |
-
https://arxiv.org/abs/1611.01144
|
71 |
-
"""
|
72 |
-
|
73 |
-
def __init__(
|
74 |
-
self,
|
75 |
-
num_hiddens: int,
|
76 |
-
embedding_dim: int,
|
77 |
-
n_embed: int,
|
78 |
-
straight_through: bool = True,
|
79 |
-
kl_weight: float = 5e-4,
|
80 |
-
temp_init: float = 1.0,
|
81 |
-
remap: Optional[str] = None,
|
82 |
-
unknown_index: str = "random",
|
83 |
-
loss_key: str = "loss/vq",
|
84 |
-
) -> None:
|
85 |
-
super().__init__()
|
86 |
-
|
87 |
-
self.loss_key = loss_key
|
88 |
-
self.embedding_dim = embedding_dim
|
89 |
-
self.n_embed = n_embed
|
90 |
-
|
91 |
-
self.straight_through = straight_through
|
92 |
-
self.temperature = temp_init
|
93 |
-
self.kl_weight = kl_weight
|
94 |
-
|
95 |
-
self.proj = nn.Conv2d(num_hiddens, n_embed, 1)
|
96 |
-
self.embed = nn.Embedding(n_embed, embedding_dim)
|
97 |
-
|
98 |
-
self.remap = remap
|
99 |
-
if self.remap is not None:
|
100 |
-
self.register_buffer("used", torch.tensor(np.load(self.remap)))
|
101 |
-
self.re_embed = self.used.shape[0]
|
102 |
-
else:
|
103 |
-
self.used = None
|
104 |
-
self.re_embed = n_embed
|
105 |
-
if unknown_index == "extra":
|
106 |
-
self.unknown_index = self.re_embed
|
107 |
-
self.re_embed = self.re_embed + 1
|
108 |
-
else:
|
109 |
-
assert unknown_index == "random" or isinstance(
|
110 |
-
unknown_index, int
|
111 |
-
), "unknown index needs to be 'random', 'extra' or any integer"
|
112 |
-
self.unknown_index = unknown_index # "random" or "extra" or integer
|
113 |
-
if self.remap is not None:
|
114 |
-
logpy.info(
|
115 |
-
f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
|
116 |
-
f"Using {self.unknown_index} for unknown indices."
|
117 |
-
)
|
118 |
-
|
119 |
-
def forward(
|
120 |
-
self, z: torch.Tensor, temp: Optional[float] = None, return_logits: bool = False
|
121 |
-
) -> Tuple[torch.Tensor, Dict]:
|
122 |
-
# force hard = True when we are in eval mode, as we must quantize.
|
123 |
-
# actually, always true seems to work
|
124 |
-
hard = self.straight_through if self.training else True
|
125 |
-
temp = self.temperature if temp is None else temp
|
126 |
-
out_dict = {}
|
127 |
-
logits = self.proj(z)
|
128 |
-
if self.remap is not None:
|
129 |
-
# continue only with used logits
|
130 |
-
full_zeros = torch.zeros_like(logits)
|
131 |
-
logits = logits[:, self.used, ...]
|
132 |
-
|
133 |
-
soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard)
|
134 |
-
if self.remap is not None:
|
135 |
-
# go back to all entries but unused set to zero
|
136 |
-
full_zeros[:, self.used, ...] = soft_one_hot
|
137 |
-
soft_one_hot = full_zeros
|
138 |
-
z_q = einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)
|
139 |
-
|
140 |
-
# + kl divergence to the prior loss
|
141 |
-
qy = F.softmax(logits, dim=1)
|
142 |
-
diff = (
|
143 |
-
self.kl_weight
|
144 |
-
* torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean()
|
145 |
-
)
|
146 |
-
out_dict[self.loss_key] = diff
|
147 |
-
|
148 |
-
ind = soft_one_hot.argmax(dim=1)
|
149 |
-
out_dict["indices"] = ind
|
150 |
-
if self.remap is not None:
|
151 |
-
ind = self.remap_to_used(ind)
|
152 |
-
|
153 |
-
if return_logits:
|
154 |
-
out_dict["logits"] = logits
|
155 |
-
|
156 |
-
return z_q, out_dict
|
157 |
-
|
158 |
-
def get_codebook_entry(self, indices, shape):
|
159 |
-
# TODO: shape not yet optional
|
160 |
-
b, h, w, c = shape
|
161 |
-
assert b * h * w == indices.shape[0]
|
162 |
-
indices = rearrange(indices, "(b h w) -> b h w", b=b, h=h, w=w)
|
163 |
-
if self.remap is not None:
|
164 |
-
indices = self.unmap_to_all(indices)
|
165 |
-
one_hot = (
|
166 |
-
F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float()
|
167 |
-
)
|
168 |
-
z_q = einsum("b n h w, n d -> b d h w", one_hot, self.embed.weight)
|
169 |
-
return z_q
|
170 |
-
|
171 |
-
|
172 |
-
class VectorQuantizer(AbstractQuantizer):
|
173 |
-
"""
|
174 |
-
____________________________________________
|
175 |
-
Discretization bottleneck part of the VQ-VAE.
|
176 |
-
Inputs:
|
177 |
-
- n_e : number of embeddings
|
178 |
-
- e_dim : dimension of embedding
|
179 |
-
- beta : commitment cost used in loss term,
|
180 |
-
beta * ||z_e(x)-sg[e]||^2
|
181 |
-
_____________________________________________
|
182 |
-
"""
|
183 |
-
|
184 |
-
def __init__(
|
185 |
-
self,
|
186 |
-
n_e: int,
|
187 |
-
e_dim: int,
|
188 |
-
beta: float = 0.25,
|
189 |
-
remap: Optional[str] = None,
|
190 |
-
unknown_index: str = "random",
|
191 |
-
sane_index_shape: bool = False,
|
192 |
-
log_perplexity: bool = False,
|
193 |
-
embedding_weight_norm: bool = False,
|
194 |
-
loss_key: str = "loss/vq",
|
195 |
-
):
|
196 |
-
super().__init__()
|
197 |
-
self.n_e = n_e
|
198 |
-
self.e_dim = e_dim
|
199 |
-
self.beta = beta
|
200 |
-
self.loss_key = loss_key
|
201 |
-
|
202 |
-
if not embedding_weight_norm:
|
203 |
-
self.embedding = nn.Embedding(self.n_e, self.e_dim)
|
204 |
-
self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
|
205 |
-
else:
|
206 |
-
self.embedding = torch.nn.utils.weight_norm(
|
207 |
-
nn.Embedding(self.n_e, self.e_dim), dim=1
|
208 |
-
)
|
209 |
-
|
210 |
-
self.remap = remap
|
211 |
-
if self.remap is not None:
|
212 |
-
self.register_buffer("used", torch.tensor(np.load(self.remap)))
|
213 |
-
self.re_embed = self.used.shape[0]
|
214 |
-
else:
|
215 |
-
self.used = None
|
216 |
-
self.re_embed = n_e
|
217 |
-
if unknown_index == "extra":
|
218 |
-
self.unknown_index = self.re_embed
|
219 |
-
self.re_embed = self.re_embed + 1
|
220 |
-
else:
|
221 |
-
assert unknown_index == "random" or isinstance(
|
222 |
-
unknown_index, int
|
223 |
-
), "unknown index needs to be 'random', 'extra' or any integer"
|
224 |
-
self.unknown_index = unknown_index # "random" or "extra" or integer
|
225 |
-
if self.remap is not None:
|
226 |
-
logpy.info(
|
227 |
-
f"Remapping {self.n_e} indices to {self.re_embed} indices. "
|
228 |
-
f"Using {self.unknown_index} for unknown indices."
|
229 |
-
)
|
230 |
-
|
231 |
-
self.sane_index_shape = sane_index_shape
|
232 |
-
self.log_perplexity = log_perplexity
|
233 |
-
|
234 |
-
def forward(
|
235 |
-
self,
|
236 |
-
z: torch.Tensor,
|
237 |
-
) -> Tuple[torch.Tensor, Dict]:
|
238 |
-
do_reshape = z.ndim == 4
|
239 |
-
if do_reshape:
|
240 |
-
# # reshape z -> (batch, height, width, channel) and flatten
|
241 |
-
z = rearrange(z, "b c h w -> b h w c").contiguous()
|
242 |
-
|
243 |
-
else:
|
244 |
-
assert z.ndim < 4, "No reshaping strategy for inputs > 4 dimensions defined"
|
245 |
-
z = z.contiguous()
|
246 |
-
|
247 |
-
z_flattened = z.view(-1, self.e_dim)
|
248 |
-
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
|
249 |
-
|
250 |
-
d = (
|
251 |
-
torch.sum(z_flattened**2, dim=1, keepdim=True)
|
252 |
-
+ torch.sum(self.embedding.weight**2, dim=1)
|
253 |
-
- 2
|
254 |
-
* torch.einsum(
|
255 |
-
"bd,dn->bn", z_flattened, rearrange(self.embedding.weight, "n d -> d n")
|
256 |
-
)
|
257 |
-
)
|
258 |
-
|
259 |
-
min_encoding_indices = torch.argmin(d, dim=1)
|
260 |
-
z_q = self.embedding(min_encoding_indices).view(z.shape)
|
261 |
-
loss_dict = {}
|
262 |
-
if self.log_perplexity:
|
263 |
-
perplexity, cluster_usage = measure_perplexity(
|
264 |
-
min_encoding_indices.detach(), self.n_e
|
265 |
-
)
|
266 |
-
loss_dict.update({"perplexity": perplexity, "cluster_usage": cluster_usage})
|
267 |
-
|
268 |
-
# compute loss for embedding
|
269 |
-
loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean(
|
270 |
-
(z_q - z.detach()) ** 2
|
271 |
-
)
|
272 |
-
loss_dict[self.loss_key] = loss
|
273 |
-
|
274 |
-
# preserve gradients
|
275 |
-
z_q = z + (z_q - z).detach()
|
276 |
-
|
277 |
-
# reshape back to match original input shape
|
278 |
-
if do_reshape:
|
279 |
-
z_q = rearrange(z_q, "b h w c -> b c h w").contiguous()
|
280 |
-
|
281 |
-
if self.remap is not None:
|
282 |
-
min_encoding_indices = min_encoding_indices.reshape(
|
283 |
-
z.shape[0], -1
|
284 |
-
) # add batch axis
|
285 |
-
min_encoding_indices = self.remap_to_used(min_encoding_indices)
|
286 |
-
min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
|
287 |
-
|
288 |
-
if self.sane_index_shape:
|
289 |
-
if do_reshape:
|
290 |
-
min_encoding_indices = min_encoding_indices.reshape(
|
291 |
-
z_q.shape[0], z_q.shape[2], z_q.shape[3]
|
292 |
-
)
|
293 |
-
else:
|
294 |
-
min_encoding_indices = rearrange(
|
295 |
-
min_encoding_indices, "(b s) 1 -> b s", b=z_q.shape[0]
|
296 |
-
)
|
297 |
-
|
298 |
-
loss_dict["min_encoding_indices"] = min_encoding_indices
|
299 |
-
|
300 |
-
return z_q, loss_dict
|
301 |
-
|
302 |
-
def get_codebook_entry(
|
303 |
-
self, indices: torch.Tensor, shape: Optional[Tuple[int, ...]] = None
|
304 |
-
) -> torch.Tensor:
|
305 |
-
# shape specifying (batch, height, width, channel)
|
306 |
-
if self.remap is not None:
|
307 |
-
assert shape is not None, "Need to give shape for remap"
|
308 |
-
indices = indices.reshape(shape[0], -1) # add batch axis
|
309 |
-
indices = self.unmap_to_all(indices)
|
310 |
-
indices = indices.reshape(-1) # flatten again
|
311 |
-
|
312 |
-
# get quantized latent vectors
|
313 |
-
z_q = self.embedding(indices)
|
314 |
-
|
315 |
-
if shape is not None:
|
316 |
-
z_q = z_q.view(shape)
|
317 |
-
# reshape back to match original input shape
|
318 |
-
z_q = z_q.permute(0, 3, 1, 2).contiguous()
|
319 |
-
|
320 |
-
return z_q
|
321 |
-
|
322 |
-
|
323 |
-
class EmbeddingEMA(nn.Module):
|
324 |
-
def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5):
|
325 |
-
super().__init__()
|
326 |
-
self.decay = decay
|
327 |
-
self.eps = eps
|
328 |
-
weight = torch.randn(num_tokens, codebook_dim)
|
329 |
-
self.weight = nn.Parameter(weight, requires_grad=False)
|
330 |
-
self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad=False)
|
331 |
-
self.embed_avg = nn.Parameter(weight.clone(), requires_grad=False)
|
332 |
-
self.update = True
|
333 |
-
|
334 |
-
def forward(self, embed_id):
|
335 |
-
return F.embedding(embed_id, self.weight)
|
336 |
-
|
337 |
-
def cluster_size_ema_update(self, new_cluster_size):
|
338 |
-
self.cluster_size.data.mul_(self.decay).add_(
|
339 |
-
new_cluster_size, alpha=1 - self.decay
|
340 |
-
)
|
341 |
-
|
342 |
-
def embed_avg_ema_update(self, new_embed_avg):
|
343 |
-
self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay)
|
344 |
-
|
345 |
-
def weight_update(self, num_tokens):
|
346 |
-
n = self.cluster_size.sum()
|
347 |
-
smoothed_cluster_size = (
|
348 |
-
(self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n
|
349 |
-
)
|
350 |
-
# normalize embedding average with smoothed cluster size
|
351 |
-
embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1)
|
352 |
-
self.weight.data.copy_(embed_normalized)
|
353 |
-
|
354 |
-
|
355 |
-
class EMAVectorQuantizer(AbstractQuantizer):
|
356 |
-
def __init__(
|
357 |
-
self,
|
358 |
-
n_embed: int,
|
359 |
-
embedding_dim: int,
|
360 |
-
beta: float,
|
361 |
-
decay: float = 0.99,
|
362 |
-
eps: float = 1e-5,
|
363 |
-
remap: Optional[str] = None,
|
364 |
-
unknown_index: str = "random",
|
365 |
-
loss_key: str = "loss/vq",
|
366 |
-
):
|
367 |
-
super().__init__()
|
368 |
-
self.codebook_dim = embedding_dim
|
369 |
-
self.num_tokens = n_embed
|
370 |
-
self.beta = beta
|
371 |
-
self.loss_key = loss_key
|
372 |
-
|
373 |
-
self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps)
|
374 |
-
|
375 |
-
self.remap = remap
|
376 |
-
if self.remap is not None:
|
377 |
-
self.register_buffer("used", torch.tensor(np.load(self.remap)))
|
378 |
-
self.re_embed = self.used.shape[0]
|
379 |
-
else:
|
380 |
-
self.used = None
|
381 |
-
self.re_embed = n_embed
|
382 |
-
if unknown_index == "extra":
|
383 |
-
self.unknown_index = self.re_embed
|
384 |
-
self.re_embed = self.re_embed + 1
|
385 |
-
else:
|
386 |
-
assert unknown_index == "random" or isinstance(
|
387 |
-
unknown_index, int
|
388 |
-
), "unknown index needs to be 'random', 'extra' or any integer"
|
389 |
-
self.unknown_index = unknown_index # "random" or "extra" or integer
|
390 |
-
if self.remap is not None:
|
391 |
-
logpy.info(
|
392 |
-
f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
|
393 |
-
f"Using {self.unknown_index} for unknown indices."
|
394 |
-
)
|
395 |
-
|
396 |
-
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Dict]:
|
397 |
-
# reshape z -> (batch, height, width, channel) and flatten
|
398 |
-
# z, 'b c h w -> b h w c'
|
399 |
-
z = rearrange(z, "b c h w -> b h w c")
|
400 |
-
z_flattened = z.reshape(-1, self.codebook_dim)
|
401 |
-
|
402 |
-
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
|
403 |
-
d = (
|
404 |
-
z_flattened.pow(2).sum(dim=1, keepdim=True)
|
405 |
-
+ self.embedding.weight.pow(2).sum(dim=1)
|
406 |
-
- 2 * torch.einsum("bd,nd->bn", z_flattened, self.embedding.weight)
|
407 |
-
) # 'n d -> d n'
|
408 |
-
|
409 |
-
encoding_indices = torch.argmin(d, dim=1)
|
410 |
-
|
411 |
-
z_q = self.embedding(encoding_indices).view(z.shape)
|
412 |
-
encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)
|
413 |
-
avg_probs = torch.mean(encodings, dim=0)
|
414 |
-
perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
|
415 |
-
|
416 |
-
if self.training and self.embedding.update:
|
417 |
-
# EMA cluster size
|
418 |
-
encodings_sum = encodings.sum(0)
|
419 |
-
self.embedding.cluster_size_ema_update(encodings_sum)
|
420 |
-
# EMA embedding average
|
421 |
-
embed_sum = encodings.transpose(0, 1) @ z_flattened
|
422 |
-
self.embedding.embed_avg_ema_update(embed_sum)
|
423 |
-
# normalize embed_avg and update weight
|
424 |
-
self.embedding.weight_update(self.num_tokens)
|
425 |
-
|
426 |
-
# compute loss for embedding
|
427 |
-
loss = self.beta * F.mse_loss(z_q.detach(), z)
|
428 |
-
|
429 |
-
# preserve gradients
|
430 |
-
z_q = z + (z_q - z).detach()
|
431 |
-
|
432 |
-
# reshape back to match original input shape
|
433 |
-
# z_q, 'b h w c -> b c h w'
|
434 |
-
z_q = rearrange(z_q, "b h w c -> b c h w")
|
435 |
-
|
436 |
-
out_dict = {
|
437 |
-
self.loss_key: loss,
|
438 |
-
"encodings": encodings,
|
439 |
-
"encoding_indices": encoding_indices,
|
440 |
-
"perplexity": perplexity,
|
441 |
-
}
|
442 |
-
|
443 |
-
return z_q, out_dict
|
444 |
-
|
445 |
-
|
446 |
-
class VectorQuantizerWithInputProjection(VectorQuantizer):
|
447 |
-
def __init__(
|
448 |
-
self,
|
449 |
-
input_dim: int,
|
450 |
-
n_codes: int,
|
451 |
-
codebook_dim: int,
|
452 |
-
beta: float = 1.0,
|
453 |
-
output_dim: Optional[int] = None,
|
454 |
-
**kwargs,
|
455 |
-
):
|
456 |
-
super().__init__(n_codes, codebook_dim, beta, **kwargs)
|
457 |
-
self.proj_in = nn.Linear(input_dim, codebook_dim)
|
458 |
-
self.output_dim = output_dim
|
459 |
-
if output_dim is not None:
|
460 |
-
self.proj_out = nn.Linear(codebook_dim, output_dim)
|
461 |
-
else:
|
462 |
-
self.proj_out = nn.Identity()
|
463 |
-
|
464 |
-
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Dict]:
|
465 |
-
rearr = False
|
466 |
-
in_shape = z.shape
|
467 |
-
|
468 |
-
if z.ndim > 3:
|
469 |
-
rearr = self.output_dim is not None
|
470 |
-
z = rearrange(z, "b c ... -> b (...) c")
|
471 |
-
z = self.proj_in(z)
|
472 |
-
z_q, loss_dict = super().forward(z)
|
473 |
-
|
474 |
-
z_q = self.proj_out(z_q)
|
475 |
-
if rearr:
|
476 |
-
if len(in_shape) == 4:
|
477 |
-
z_q = rearrange(z_q, "b (h w) c -> b c h w ", w=in_shape[-1])
|
478 |
-
elif len(in_shape) == 5:
|
479 |
-
z_q = rearrange(
|
480 |
-
z_q, "b (t h w) c -> b c t h w ", w=in_shape[-1], h=in_shape[-2]
|
481 |
-
)
|
482 |
-
else:
|
483 |
-
raise NotImplementedError(
|
484 |
-
f"rearranging not available for {len(in_shape)}-dimensional input."
|
485 |
-
)
|
486 |
-
|
487 |
-
return z_q, loss_dict
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sgm/modules/autoencoding/temporal_ae.py
DELETED
@@ -1,349 +0,0 @@
|
|
1 |
-
from typing import Callable, Iterable, Union
|
2 |
-
|
3 |
-
import torch
|
4 |
-
from einops import rearrange, repeat
|
5 |
-
|
6 |
-
from sgm.modules.diffusionmodules.model import (
|
7 |
-
XFORMERS_IS_AVAILABLE,
|
8 |
-
AttnBlock,
|
9 |
-
Decoder,
|
10 |
-
MemoryEfficientAttnBlock,
|
11 |
-
ResnetBlock,
|
12 |
-
)
|
13 |
-
from sgm.modules.diffusionmodules.openaimodel import ResBlock, timestep_embedding
|
14 |
-
from sgm.modules.video_attention import VideoTransformerBlock
|
15 |
-
from sgm.util import partialclass
|
16 |
-
|
17 |
-
|
18 |
-
class VideoResBlock(ResnetBlock):
|
19 |
-
def __init__(
|
20 |
-
self,
|
21 |
-
out_channels,
|
22 |
-
*args,
|
23 |
-
dropout=0.0,
|
24 |
-
video_kernel_size=3,
|
25 |
-
alpha=0.0,
|
26 |
-
merge_strategy="learned",
|
27 |
-
**kwargs,
|
28 |
-
):
|
29 |
-
super().__init__(out_channels=out_channels, dropout=dropout, *args, **kwargs)
|
30 |
-
if video_kernel_size is None:
|
31 |
-
video_kernel_size = [3, 1, 1]
|
32 |
-
self.time_stack = ResBlock(
|
33 |
-
channels=out_channels,
|
34 |
-
emb_channels=0,
|
35 |
-
dropout=dropout,
|
36 |
-
dims=3,
|
37 |
-
use_scale_shift_norm=False,
|
38 |
-
use_conv=False,
|
39 |
-
up=False,
|
40 |
-
down=False,
|
41 |
-
kernel_size=video_kernel_size,
|
42 |
-
use_checkpoint=False,
|
43 |
-
skip_t_emb=True,
|
44 |
-
)
|
45 |
-
|
46 |
-
self.merge_strategy = merge_strategy
|
47 |
-
if self.merge_strategy == "fixed":
|
48 |
-
self.register_buffer("mix_factor", torch.Tensor([alpha]))
|
49 |
-
elif self.merge_strategy == "learned":
|
50 |
-
self.register_parameter(
|
51 |
-
"mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
|
52 |
-
)
|
53 |
-
else:
|
54 |
-
raise ValueError(f"unknown merge strategy {self.merge_strategy}")
|
55 |
-
|
56 |
-
def get_alpha(self, bs):
|
57 |
-
if self.merge_strategy == "fixed":
|
58 |
-
return self.mix_factor
|
59 |
-
elif self.merge_strategy == "learned":
|
60 |
-
return torch.sigmoid(self.mix_factor)
|
61 |
-
else:
|
62 |
-
raise NotImplementedError()
|
63 |
-
|
64 |
-
def forward(self, x, temb, skip_video=False, timesteps=None):
|
65 |
-
if timesteps is None:
|
66 |
-
timesteps = self.timesteps
|
67 |
-
|
68 |
-
b, c, h, w = x.shape
|
69 |
-
|
70 |
-
x = super().forward(x, temb)
|
71 |
-
|
72 |
-
if not skip_video:
|
73 |
-
x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
|
74 |
-
|
75 |
-
x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
|
76 |
-
|
77 |
-
x = self.time_stack(x, temb)
|
78 |
-
|
79 |
-
alpha = self.get_alpha(bs=b // timesteps)
|
80 |
-
x = alpha * x + (1.0 - alpha) * x_mix
|
81 |
-
|
82 |
-
x = rearrange(x, "b c t h w -> (b t) c h w")
|
83 |
-
return x
|
84 |
-
|
85 |
-
|
86 |
-
class AE3DConv(torch.nn.Conv2d):
|
87 |
-
def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs):
|
88 |
-
super().__init__(in_channels, out_channels, *args, **kwargs)
|
89 |
-
if isinstance(video_kernel_size, Iterable):
|
90 |
-
padding = [int(k // 2) for k in video_kernel_size]
|
91 |
-
else:
|
92 |
-
padding = int(video_kernel_size // 2)
|
93 |
-
|
94 |
-
self.time_mix_conv = torch.nn.Conv3d(
|
95 |
-
in_channels=out_channels,
|
96 |
-
out_channels=out_channels,
|
97 |
-
kernel_size=video_kernel_size,
|
98 |
-
padding=padding,
|
99 |
-
)
|
100 |
-
|
101 |
-
def forward(self, input, timesteps, skip_video=False):
|
102 |
-
x = super().forward(input)
|
103 |
-
if skip_video:
|
104 |
-
return x
|
105 |
-
x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
|
106 |
-
x = self.time_mix_conv(x)
|
107 |
-
return rearrange(x, "b c t h w -> (b t) c h w")
|
108 |
-
|
109 |
-
|
110 |
-
class VideoBlock(AttnBlock):
|
111 |
-
def __init__(
|
112 |
-
self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned"
|
113 |
-
):
|
114 |
-
super().__init__(in_channels)
|
115 |
-
# no context, single headed, as in base class
|
116 |
-
self.time_mix_block = VideoTransformerBlock(
|
117 |
-
dim=in_channels,
|
118 |
-
n_heads=1,
|
119 |
-
d_head=in_channels,
|
120 |
-
checkpoint=False,
|
121 |
-
ff_in=True,
|
122 |
-
attn_mode="softmax",
|
123 |
-
)
|
124 |
-
|
125 |
-
time_embed_dim = self.in_channels * 4
|
126 |
-
self.video_time_embed = torch.nn.Sequential(
|
127 |
-
torch.nn.Linear(self.in_channels, time_embed_dim),
|
128 |
-
torch.nn.SiLU(),
|
129 |
-
torch.nn.Linear(time_embed_dim, self.in_channels),
|
130 |
-
)
|
131 |
-
|
132 |
-
self.merge_strategy = merge_strategy
|
133 |
-
if self.merge_strategy == "fixed":
|
134 |
-
self.register_buffer("mix_factor", torch.Tensor([alpha]))
|
135 |
-
elif self.merge_strategy == "learned":
|
136 |
-
self.register_parameter(
|
137 |
-
"mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
|
138 |
-
)
|
139 |
-
else:
|
140 |
-
raise ValueError(f"unknown merge strategy {self.merge_strategy}")
|
141 |
-
|
142 |
-
def forward(self, x, timesteps, skip_video=False):
|
143 |
-
if skip_video:
|
144 |
-
return super().forward(x)
|
145 |
-
|
146 |
-
x_in = x
|
147 |
-
x = self.attention(x)
|
148 |
-
h, w = x.shape[2:]
|
149 |
-
x = rearrange(x, "b c h w -> b (h w) c")
|
150 |
-
|
151 |
-
x_mix = x
|
152 |
-
num_frames = torch.arange(timesteps, device=x.device)
|
153 |
-
num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
|
154 |
-
num_frames = rearrange(num_frames, "b t -> (b t)")
|
155 |
-
t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False)
|
156 |
-
emb = self.video_time_embed(t_emb) # b, n_channels
|
157 |
-
emb = emb[:, None, :]
|
158 |
-
x_mix = x_mix + emb
|
159 |
-
|
160 |
-
alpha = self.get_alpha()
|
161 |
-
x_mix = self.time_mix_block(x_mix, timesteps=timesteps)
|
162 |
-
x = alpha * x + (1.0 - alpha) * x_mix # alpha merge
|
163 |
-
|
164 |
-
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
|
165 |
-
x = self.proj_out(x)
|
166 |
-
|
167 |
-
return x_in + x
|
168 |
-
|
169 |
-
def get_alpha(
|
170 |
-
self,
|
171 |
-
):
|
172 |
-
if self.merge_strategy == "fixed":
|
173 |
-
return self.mix_factor
|
174 |
-
elif self.merge_strategy == "learned":
|
175 |
-
return torch.sigmoid(self.mix_factor)
|
176 |
-
else:
|
177 |
-
raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}")
|
178 |
-
|
179 |
-
|
180 |
-
class MemoryEfficientVideoBlock(MemoryEfficientAttnBlock):
|
181 |
-
def __init__(
|
182 |
-
self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned"
|
183 |
-
):
|
184 |
-
super().__init__(in_channels)
|
185 |
-
# no context, single headed, as in base class
|
186 |
-
self.time_mix_block = VideoTransformerBlock(
|
187 |
-
dim=in_channels,
|
188 |
-
n_heads=1,
|
189 |
-
d_head=in_channels,
|
190 |
-
checkpoint=False,
|
191 |
-
ff_in=True,
|
192 |
-
attn_mode="softmax-xformers",
|
193 |
-
)
|
194 |
-
|
195 |
-
time_embed_dim = self.in_channels * 4
|
196 |
-
self.video_time_embed = torch.nn.Sequential(
|
197 |
-
torch.nn.Linear(self.in_channels, time_embed_dim),
|
198 |
-
torch.nn.SiLU(),
|
199 |
-
torch.nn.Linear(time_embed_dim, self.in_channels),
|
200 |
-
)
|
201 |
-
|
202 |
-
self.merge_strategy = merge_strategy
|
203 |
-
if self.merge_strategy == "fixed":
|
204 |
-
self.register_buffer("mix_factor", torch.Tensor([alpha]))
|
205 |
-
elif self.merge_strategy == "learned":
|
206 |
-
self.register_parameter(
|
207 |
-
"mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
|
208 |
-
)
|
209 |
-
else:
|
210 |
-
raise ValueError(f"unknown merge strategy {self.merge_strategy}")
|
211 |
-
|
212 |
-
def forward(self, x, timesteps, skip_time_block=False):
|
213 |
-
if skip_time_block:
|
214 |
-
return super().forward(x)
|
215 |
-
|
216 |
-
x_in = x
|
217 |
-
x = self.attention(x)
|
218 |
-
h, w = x.shape[2:]
|
219 |
-
x = rearrange(x, "b c h w -> b (h w) c")
|
220 |
-
|
221 |
-
x_mix = x
|
222 |
-
num_frames = torch.arange(timesteps, device=x.device)
|
223 |
-
num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
|
224 |
-
num_frames = rearrange(num_frames, "b t -> (b t)")
|
225 |
-
t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False)
|
226 |
-
emb = self.video_time_embed(t_emb) # b, n_channels
|
227 |
-
emb = emb[:, None, :]
|
228 |
-
x_mix = x_mix + emb
|
229 |
-
|
230 |
-
alpha = self.get_alpha()
|
231 |
-
x_mix = self.time_mix_block(x_mix, timesteps=timesteps)
|
232 |
-
x = alpha * x + (1.0 - alpha) * x_mix # alpha merge
|
233 |
-
|
234 |
-
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
|
235 |
-
x = self.proj_out(x)
|
236 |
-
|
237 |
-
return x_in + x
|
238 |
-
|
239 |
-
def get_alpha(
|
240 |
-
self,
|
241 |
-
):
|
242 |
-
if self.merge_strategy == "fixed":
|
243 |
-
return self.mix_factor
|
244 |
-
elif self.merge_strategy == "learned":
|
245 |
-
return torch.sigmoid(self.mix_factor)
|
246 |
-
else:
|
247 |
-
raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}")
|
248 |
-
|
249 |
-
|
250 |
-
def make_time_attn(
|
251 |
-
in_channels,
|
252 |
-
attn_type="vanilla",
|
253 |
-
attn_kwargs=None,
|
254 |
-
alpha: float = 0,
|
255 |
-
merge_strategy: str = "learned",
|
256 |
-
):
|
257 |
-
assert attn_type in [
|
258 |
-
"vanilla",
|
259 |
-
"vanilla-xformers",
|
260 |
-
], f"attn_type {attn_type} not supported for spatio-temporal attention"
|
261 |
-
print(
|
262 |
-
f"making spatial and temporal attention of type '{attn_type}' with {in_channels} in_channels"
|
263 |
-
)
|
264 |
-
if not XFORMERS_IS_AVAILABLE and attn_type == "vanilla-xformers":
|
265 |
-
print(
|
266 |
-
f"Attention mode '{attn_type}' is not available. Falling back to vanilla attention. "
|
267 |
-
f"This is not a problem in Pytorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}"
|
268 |
-
)
|
269 |
-
attn_type = "vanilla"
|
270 |
-
|
271 |
-
if attn_type == "vanilla":
|
272 |
-
assert attn_kwargs is None
|
273 |
-
return partialclass(
|
274 |
-
VideoBlock, in_channels, alpha=alpha, merge_strategy=merge_strategy
|
275 |
-
)
|
276 |
-
elif attn_type == "vanilla-xformers":
|
277 |
-
print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
|
278 |
-
return partialclass(
|
279 |
-
MemoryEfficientVideoBlock,
|
280 |
-
in_channels,
|
281 |
-
alpha=alpha,
|
282 |
-
merge_strategy=merge_strategy,
|
283 |
-
)
|
284 |
-
else:
|
285 |
-
return NotImplementedError()
|
286 |
-
|
287 |
-
|
288 |
-
class Conv2DWrapper(torch.nn.Conv2d):
|
289 |
-
def forward(self, input: torch.Tensor, **kwargs) -> torch.Tensor:
|
290 |
-
return super().forward(input)
|
291 |
-
|
292 |
-
|
293 |
-
class VideoDecoder(Decoder):
|
294 |
-
available_time_modes = ["all", "conv-only", "attn-only"]
|
295 |
-
|
296 |
-
def __init__(
|
297 |
-
self,
|
298 |
-
*args,
|
299 |
-
video_kernel_size: Union[int, list] = 3,
|
300 |
-
alpha: float = 0.0,
|
301 |
-
merge_strategy: str = "learned",
|
302 |
-
time_mode: str = "conv-only",
|
303 |
-
**kwargs,
|
304 |
-
):
|
305 |
-
self.video_kernel_size = video_kernel_size
|
306 |
-
self.alpha = alpha
|
307 |
-
self.merge_strategy = merge_strategy
|
308 |
-
self.time_mode = time_mode
|
309 |
-
assert (
|
310 |
-
self.time_mode in self.available_time_modes
|
311 |
-
), f"time_mode parameter has to be in {self.available_time_modes}"
|
312 |
-
super().__init__(*args, **kwargs)
|
313 |
-
|
314 |
-
def get_last_layer(self, skip_time_mix=False, **kwargs):
|
315 |
-
if self.time_mode == "attn-only":
|
316 |
-
raise NotImplementedError("TODO")
|
317 |
-
else:
|
318 |
-
return (
|
319 |
-
self.conv_out.time_mix_conv.weight
|
320 |
-
if not skip_time_mix
|
321 |
-
else self.conv_out.weight
|
322 |
-
)
|
323 |
-
|
324 |
-
def _make_attn(self) -> Callable:
|
325 |
-
if self.time_mode not in ["conv-only", "only-last-conv"]:
|
326 |
-
return partialclass(
|
327 |
-
make_time_attn,
|
328 |
-
alpha=self.alpha,
|
329 |
-
merge_strategy=self.merge_strategy,
|
330 |
-
)
|
331 |
-
else:
|
332 |
-
return super()._make_attn()
|
333 |
-
|
334 |
-
def _make_conv(self) -> Callable:
|
335 |
-
if self.time_mode != "attn-only":
|
336 |
-
return partialclass(AE3DConv, video_kernel_size=self.video_kernel_size)
|
337 |
-
else:
|
338 |
-
return Conv2DWrapper
|
339 |
-
|
340 |
-
def _make_resblock(self) -> Callable:
|
341 |
-
if self.time_mode not in ["attn-only", "only-last-conv"]:
|
342 |
-
return partialclass(
|
343 |
-
VideoResBlock,
|
344 |
-
video_kernel_size=self.video_kernel_size,
|
345 |
-
alpha=self.alpha,
|
346 |
-
merge_strategy=self.merge_strategy,
|
347 |
-
)
|
348 |
-
else:
|
349 |
-
return super()._make_resblock()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sgm/modules/diffusionmodules/__init__.py
DELETED
File without changes
|
sgm/modules/diffusionmodules/denoiser.py
DELETED
@@ -1,75 +0,0 @@
|
|
1 |
-
from typing import Dict, Union
|
2 |
-
|
3 |
-
import torch
|
4 |
-
import torch.nn as nn
|
5 |
-
|
6 |
-
from ...util import append_dims, instantiate_from_config
|
7 |
-
from .denoiser_scaling import DenoiserScaling
|
8 |
-
from .discretizer import Discretization
|
9 |
-
|
10 |
-
|
11 |
-
class Denoiser(nn.Module):
|
12 |
-
def __init__(self, scaling_config: Dict):
|
13 |
-
super().__init__()
|
14 |
-
|
15 |
-
self.scaling: DenoiserScaling = instantiate_from_config(scaling_config)
|
16 |
-
|
17 |
-
def possibly_quantize_sigma(self, sigma: torch.Tensor) -> torch.Tensor:
|
18 |
-
return sigma
|
19 |
-
|
20 |
-
def possibly_quantize_c_noise(self, c_noise: torch.Tensor) -> torch.Tensor:
|
21 |
-
return c_noise
|
22 |
-
|
23 |
-
def forward(
|
24 |
-
self,
|
25 |
-
network: nn.Module,
|
26 |
-
input: torch.Tensor,
|
27 |
-
sigma: torch.Tensor,
|
28 |
-
cond: Dict,
|
29 |
-
**additional_model_inputs,
|
30 |
-
) -> torch.Tensor:
|
31 |
-
sigma = self.possibly_quantize_sigma(sigma)
|
32 |
-
sigma_shape = sigma.shape
|
33 |
-
sigma = append_dims(sigma, input.ndim)
|
34 |
-
c_skip, c_out, c_in, c_noise = self.scaling(sigma)
|
35 |
-
c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape))
|
36 |
-
return (
|
37 |
-
network(input * c_in, c_noise, cond, **additional_model_inputs) * c_out
|
38 |
-
+ input * c_skip
|
39 |
-
)
|
40 |
-
|
41 |
-
|
42 |
-
class DiscreteDenoiser(Denoiser):
|
43 |
-
def __init__(
|
44 |
-
self,
|
45 |
-
scaling_config: Dict,
|
46 |
-
num_idx: int,
|
47 |
-
discretization_config: Dict,
|
48 |
-
do_append_zero: bool = False,
|
49 |
-
quantize_c_noise: bool = True,
|
50 |
-
flip: bool = True,
|
51 |
-
):
|
52 |
-
super().__init__(scaling_config)
|
53 |
-
self.discretization: Discretization = instantiate_from_config(
|
54 |
-
discretization_config
|
55 |
-
)
|
56 |
-
sigmas = self.discretization(num_idx, do_append_zero=do_append_zero, flip=flip)
|
57 |
-
self.register_buffer("sigmas", sigmas)
|
58 |
-
self.quantize_c_noise = quantize_c_noise
|
59 |
-
self.num_idx = num_idx
|
60 |
-
|
61 |
-
def sigma_to_idx(self, sigma: torch.Tensor) -> torch.Tensor:
|
62 |
-
dists = sigma - self.sigmas[:, None]
|
63 |
-
return dists.abs().argmin(dim=0).view(sigma.shape)
|
64 |
-
|
65 |
-
def idx_to_sigma(self, idx: Union[torch.Tensor, int]) -> torch.Tensor:
|
66 |
-
return self.sigmas[idx]
|
67 |
-
|
68 |
-
def possibly_quantize_sigma(self, sigma: torch.Tensor) -> torch.Tensor:
|
69 |
-
return self.idx_to_sigma(self.sigma_to_idx(sigma))
|
70 |
-
|
71 |
-
def possibly_quantize_c_noise(self, c_noise: torch.Tensor) -> torch.Tensor:
|
72 |
-
if self.quantize_c_noise:
|
73 |
-
return self.sigma_to_idx(c_noise)
|
74 |
-
else:
|
75 |
-
return c_noise
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sgm/modules/diffusionmodules/denoiser_scaling.py
DELETED
@@ -1,59 +0,0 @@
|
|
1 |
-
from abc import ABC, abstractmethod
|
2 |
-
from typing import Tuple
|
3 |
-
|
4 |
-
import torch
|
5 |
-
|
6 |
-
|
7 |
-
class DenoiserScaling(ABC):
|
8 |
-
@abstractmethod
|
9 |
-
def __call__(
|
10 |
-
self, sigma: torch.Tensor
|
11 |
-
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
12 |
-
pass
|
13 |
-
|
14 |
-
|
15 |
-
class EDMScaling:
|
16 |
-
def __init__(self, sigma_data: float = 0.5):
|
17 |
-
self.sigma_data = sigma_data
|
18 |
-
|
19 |
-
def __call__(
|
20 |
-
self, sigma: torch.Tensor
|
21 |
-
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
22 |
-
c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
|
23 |
-
c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5
|
24 |
-
c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5
|
25 |
-
c_noise = 0.25 * sigma.log()
|
26 |
-
return c_skip, c_out, c_in, c_noise
|
27 |
-
|
28 |
-
|
29 |
-
class EpsScaling:
|
30 |
-
def __call__(
|
31 |
-
self, sigma: torch.Tensor
|
32 |
-
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
33 |
-
c_skip = torch.ones_like(sigma, device=sigma.device)
|
34 |
-
c_out = -sigma
|
35 |
-
c_in = 1 / (sigma**2 + 1.0) ** 0.5
|
36 |
-
c_noise = sigma.clone()
|
37 |
-
return c_skip, c_out, c_in, c_noise
|
38 |
-
|
39 |
-
|
40 |
-
class VScaling:
|
41 |
-
def __call__(
|
42 |
-
self, sigma: torch.Tensor
|
43 |
-
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
44 |
-
c_skip = 1.0 / (sigma**2 + 1.0)
|
45 |
-
c_out = -sigma / (sigma**2 + 1.0) ** 0.5
|
46 |
-
c_in = 1.0 / (sigma**2 + 1.0) ** 0.5
|
47 |
-
c_noise = sigma.clone()
|
48 |
-
return c_skip, c_out, c_in, c_noise
|
49 |
-
|
50 |
-
|
51 |
-
class VScalingWithEDMcNoise(DenoiserScaling):
|
52 |
-
def __call__(
|
53 |
-
self, sigma: torch.Tensor
|
54 |
-
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
55 |
-
c_skip = 1.0 / (sigma**2 + 1.0)
|
56 |
-
c_out = -sigma / (sigma**2 + 1.0) ** 0.5
|
57 |
-
c_in = 1.0 / (sigma**2 + 1.0) ** 0.5
|
58 |
-
c_noise = 0.25 * sigma.log()
|
59 |
-
return c_skip, c_out, c_in, c_noise
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sgm/modules/diffusionmodules/denoiser_weighting.py
DELETED
@@ -1,24 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
|
3 |
-
|
4 |
-
class UnitWeighting:
|
5 |
-
def __call__(self, sigma):
|
6 |
-
return torch.ones_like(sigma, device=sigma.device)
|
7 |
-
|
8 |
-
|
9 |
-
class EDMWeighting:
|
10 |
-
def __init__(self, sigma_data=0.5):
|
11 |
-
self.sigma_data = sigma_data
|
12 |
-
|
13 |
-
def __call__(self, sigma):
|
14 |
-
return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2
|
15 |
-
|
16 |
-
|
17 |
-
class VWeighting(EDMWeighting):
|
18 |
-
def __init__(self):
|
19 |
-
super().__init__(sigma_data=1.0)
|
20 |
-
|
21 |
-
|
22 |
-
class EpsWeighting:
|
23 |
-
def __call__(self, sigma):
|
24 |
-
return sigma**-2.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sgm/modules/diffusionmodules/discretizer.py
DELETED
@@ -1,69 +0,0 @@
|
|
1 |
-
from abc import abstractmethod
|
2 |
-
from functools import partial
|
3 |
-
|
4 |
-
import numpy as np
|
5 |
-
import torch
|
6 |
-
|
7 |
-
from ...modules.diffusionmodules.util import make_beta_schedule
|
8 |
-
from ...util import append_zero
|
9 |
-
|
10 |
-
|
11 |
-
def generate_roughly_equally_spaced_steps(
|
12 |
-
num_substeps: int, max_step: int
|
13 |
-
) -> np.ndarray:
|
14 |
-
return np.linspace(max_step - 1, 0, num_substeps, endpoint=False).astype(int)[::-1]
|
15 |
-
|
16 |
-
|
17 |
-
class Discretization:
|
18 |
-
def __call__(self, n, do_append_zero=True, device="cpu", flip=False):
|
19 |
-
sigmas = self.get_sigmas(n, device=device)
|
20 |
-
sigmas = append_zero(sigmas) if do_append_zero else sigmas
|
21 |
-
return sigmas if not flip else torch.flip(sigmas, (0,))
|
22 |
-
|
23 |
-
@abstractmethod
|
24 |
-
def get_sigmas(self, n, device):
|
25 |
-
pass
|
26 |
-
|
27 |
-
|
28 |
-
class EDMDiscretization(Discretization):
|
29 |
-
def __init__(self, sigma_min=0.002, sigma_max=80.0, rho=7.0):
|
30 |
-
self.sigma_min = sigma_min
|
31 |
-
self.sigma_max = sigma_max
|
32 |
-
self.rho = rho
|
33 |
-
|
34 |
-
def get_sigmas(self, n, device="cpu"):
|
35 |
-
ramp = torch.linspace(0, 1, n, device=device)
|
36 |
-
min_inv_rho = self.sigma_min ** (1 / self.rho)
|
37 |
-
max_inv_rho = self.sigma_max ** (1 / self.rho)
|
38 |
-
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** self.rho
|
39 |
-
return sigmas
|
40 |
-
|
41 |
-
|
42 |
-
class LegacyDDPMDiscretization(Discretization):
|
43 |
-
def __init__(
|
44 |
-
self,
|
45 |
-
linear_start=0.00085,
|
46 |
-
linear_end=0.0120,
|
47 |
-
num_timesteps=1000,
|
48 |
-
):
|
49 |
-
super().__init__()
|
50 |
-
self.num_timesteps = num_timesteps
|
51 |
-
betas = make_beta_schedule(
|
52 |
-
"linear", num_timesteps, linear_start=linear_start, linear_end=linear_end
|
53 |
-
)
|
54 |
-
alphas = 1.0 - betas
|
55 |
-
self.alphas_cumprod = np.cumprod(alphas, axis=0)
|
56 |
-
self.to_torch = partial(torch.tensor, dtype=torch.float32)
|
57 |
-
|
58 |
-
def get_sigmas(self, n, device="cpu"):
|
59 |
-
if n < self.num_timesteps:
|
60 |
-
timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps)
|
61 |
-
alphas_cumprod = self.alphas_cumprod[timesteps]
|
62 |
-
elif n == self.num_timesteps:
|
63 |
-
alphas_cumprod = self.alphas_cumprod
|
64 |
-
else:
|
65 |
-
raise ValueError
|
66 |
-
|
67 |
-
to_torch = partial(torch.tensor, dtype=torch.float32, device=device)
|
68 |
-
sigmas = to_torch((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
|
69 |
-
return torch.flip(sigmas, (0,))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sgm/modules/diffusionmodules/guiders.py
DELETED
@@ -1,99 +0,0 @@
|
|
1 |
-
import logging
|
2 |
-
from abc import ABC, abstractmethod
|
3 |
-
from typing import Dict, List, Optional, Tuple, Union
|
4 |
-
|
5 |
-
import torch
|
6 |
-
from einops import rearrange, repeat
|
7 |
-
|
8 |
-
from ...util import append_dims, default
|
9 |
-
|
10 |
-
logpy = logging.getLogger(__name__)
|
11 |
-
|
12 |
-
|
13 |
-
class Guider(ABC):
|
14 |
-
@abstractmethod
|
15 |
-
def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor:
|
16 |
-
pass
|
17 |
-
|
18 |
-
def prepare_inputs(
|
19 |
-
self, x: torch.Tensor, s: float, c: Dict, uc: Dict
|
20 |
-
) -> Tuple[torch.Tensor, float, Dict]:
|
21 |
-
pass
|
22 |
-
|
23 |
-
|
24 |
-
class VanillaCFG(Guider):
|
25 |
-
def __init__(self, scale: float):
|
26 |
-
self.scale = scale
|
27 |
-
|
28 |
-
def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
|
29 |
-
x_u, x_c = x.chunk(2)
|
30 |
-
x_pred = x_u + self.scale * (x_c - x_u)
|
31 |
-
return x_pred
|
32 |
-
|
33 |
-
def prepare_inputs(self, x, s, c, uc):
|
34 |
-
c_out = dict()
|
35 |
-
|
36 |
-
for k in c:
|
37 |
-
if k in ["vector", "crossattn", "concat"]:
|
38 |
-
c_out[k] = torch.cat((uc[k], c[k]), 0)
|
39 |
-
else:
|
40 |
-
assert c[k] == uc[k]
|
41 |
-
c_out[k] = c[k]
|
42 |
-
return torch.cat([x] * 2), torch.cat([s] * 2), c_out
|
43 |
-
|
44 |
-
|
45 |
-
class IdentityGuider(Guider):
|
46 |
-
def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor:
|
47 |
-
return x
|
48 |
-
|
49 |
-
def prepare_inputs(
|
50 |
-
self, x: torch.Tensor, s: float, c: Dict, uc: Dict
|
51 |
-
) -> Tuple[torch.Tensor, float, Dict]:
|
52 |
-
c_out = dict()
|
53 |
-
|
54 |
-
for k in c:
|
55 |
-
c_out[k] = c[k]
|
56 |
-
|
57 |
-
return x, s, c_out
|
58 |
-
|
59 |
-
|
60 |
-
class LinearPredictionGuider(Guider):
|
61 |
-
def __init__(
|
62 |
-
self,
|
63 |
-
max_scale: float,
|
64 |
-
num_frames: int,
|
65 |
-
min_scale: float = 1.0,
|
66 |
-
additional_cond_keys: Optional[Union[List[str], str]] = None,
|
67 |
-
):
|
68 |
-
self.min_scale = min_scale
|
69 |
-
self.max_scale = max_scale
|
70 |
-
self.num_frames = num_frames
|
71 |
-
self.scale = torch.linspace(min_scale, max_scale, num_frames).unsqueeze(0)
|
72 |
-
|
73 |
-
additional_cond_keys = default(additional_cond_keys, [])
|
74 |
-
if isinstance(additional_cond_keys, str):
|
75 |
-
additional_cond_keys = [additional_cond_keys]
|
76 |
-
self.additional_cond_keys = additional_cond_keys
|
77 |
-
|
78 |
-
def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
|
79 |
-
x_u, x_c = x.chunk(2)
|
80 |
-
|
81 |
-
x_u = rearrange(x_u, "(b t) ... -> b t ...", t=self.num_frames)
|
82 |
-
x_c = rearrange(x_c, "(b t) ... -> b t ...", t=self.num_frames)
|
83 |
-
scale = repeat(self.scale, "1 t -> b t", b=x_u.shape[0])
|
84 |
-
scale = append_dims(scale, x_u.ndim).to(x_u.device)
|
85 |
-
|
86 |
-
return rearrange(x_u + scale * (x_c - x_u), "b t ... -> (b t) ...")
|
87 |
-
|
88 |
-
def prepare_inputs(
|
89 |
-
self, x: torch.Tensor, s: torch.Tensor, c: dict, uc: dict
|
90 |
-
) -> Tuple[torch.Tensor, torch.Tensor, dict]:
|
91 |
-
c_out = dict()
|
92 |
-
|
93 |
-
for k in c:
|
94 |
-
if k in ["vector", "crossattn", "concat"] + self.additional_cond_keys:
|
95 |
-
c_out[k] = torch.cat((uc[k], c[k]), 0)
|
96 |
-
else:
|
97 |
-
assert c[k] == uc[k]
|
98 |
-
c_out[k] = c[k]
|
99 |
-
return torch.cat([x] * 2), torch.cat([s] * 2), c_out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sgm/modules/diffusionmodules/loss.py
DELETED
@@ -1,105 +0,0 @@
|
|
1 |
-
from typing import Dict, List, Optional, Tuple, Union
|
2 |
-
|
3 |
-
import torch
|
4 |
-
import torch.nn as nn
|
5 |
-
|
6 |
-
from ...modules.autoencoding.lpips.loss.lpips import LPIPS
|
7 |
-
from ...modules.encoders.modules import GeneralConditioner
|
8 |
-
from ...util import append_dims, instantiate_from_config
|
9 |
-
from .denoiser import Denoiser
|
10 |
-
|
11 |
-
|
12 |
-
class StandardDiffusionLoss(nn.Module):
|
13 |
-
def __init__(
|
14 |
-
self,
|
15 |
-
sigma_sampler_config: dict,
|
16 |
-
loss_weighting_config: dict,
|
17 |
-
loss_type: str = "l2",
|
18 |
-
offset_noise_level: float = 0.0,
|
19 |
-
batch2model_keys: Optional[Union[str, List[str]]] = None,
|
20 |
-
):
|
21 |
-
super().__init__()
|
22 |
-
|
23 |
-
assert loss_type in ["l2", "l1", "lpips"]
|
24 |
-
|
25 |
-
self.sigma_sampler = instantiate_from_config(sigma_sampler_config)
|
26 |
-
self.loss_weighting = instantiate_from_config(loss_weighting_config)
|
27 |
-
|
28 |
-
self.loss_type = loss_type
|
29 |
-
self.offset_noise_level = offset_noise_level
|
30 |
-
|
31 |
-
if loss_type == "lpips":
|
32 |
-
self.lpips = LPIPS().eval()
|
33 |
-
|
34 |
-
if not batch2model_keys:
|
35 |
-
batch2model_keys = []
|
36 |
-
|
37 |
-
if isinstance(batch2model_keys, str):
|
38 |
-
batch2model_keys = [batch2model_keys]
|
39 |
-
|
40 |
-
self.batch2model_keys = set(batch2model_keys)
|
41 |
-
|
42 |
-
def get_noised_input(
|
43 |
-
self, sigmas_bc: torch.Tensor, noise: torch.Tensor, input: torch.Tensor
|
44 |
-
) -> torch.Tensor:
|
45 |
-
noised_input = input + noise * sigmas_bc
|
46 |
-
return noised_input
|
47 |
-
|
48 |
-
def forward(
|
49 |
-
self,
|
50 |
-
network: nn.Module,
|
51 |
-
denoiser: Denoiser,
|
52 |
-
conditioner: GeneralConditioner,
|
53 |
-
input: torch.Tensor,
|
54 |
-
batch: Dict,
|
55 |
-
) -> torch.Tensor:
|
56 |
-
cond = conditioner(batch)
|
57 |
-
return self._forward(network, denoiser, cond, input, batch)
|
58 |
-
|
59 |
-
def _forward(
|
60 |
-
self,
|
61 |
-
network: nn.Module,
|
62 |
-
denoiser: Denoiser,
|
63 |
-
cond: Dict,
|
64 |
-
input: torch.Tensor,
|
65 |
-
batch: Dict,
|
66 |
-
) -> Tuple[torch.Tensor, Dict]:
|
67 |
-
additional_model_inputs = {
|
68 |
-
key: batch[key] for key in self.batch2model_keys.intersection(batch)
|
69 |
-
}
|
70 |
-
sigmas = self.sigma_sampler(input.shape[0]).to(input)
|
71 |
-
|
72 |
-
noise = torch.randn_like(input)
|
73 |
-
if self.offset_noise_level > 0.0:
|
74 |
-
offset_shape = (
|
75 |
-
(input.shape[0], 1, input.shape[2])
|
76 |
-
if self.n_frames is not None
|
77 |
-
else (input.shape[0], input.shape[1])
|
78 |
-
)
|
79 |
-
noise = noise + self.offset_noise_level * append_dims(
|
80 |
-
torch.randn(offset_shape, device=input.device),
|
81 |
-
input.ndim,
|
82 |
-
)
|
83 |
-
sigmas_bc = append_dims(sigmas, input.ndim)
|
84 |
-
noised_input = self.get_noised_input(sigmas_bc, noise, input)
|
85 |
-
|
86 |
-
model_output = denoiser(
|
87 |
-
network, noised_input, sigmas, cond, **additional_model_inputs
|
88 |
-
)
|
89 |
-
w = append_dims(self.loss_weighting(sigmas), input.ndim)
|
90 |
-
return self.get_loss(model_output, input, w)
|
91 |
-
|
92 |
-
def get_loss(self, model_output, target, w):
|
93 |
-
if self.loss_type == "l2":
|
94 |
-
return torch.mean(
|
95 |
-
(w * (model_output - target) ** 2).reshape(target.shape[0], -1), 1
|
96 |
-
)
|
97 |
-
elif self.loss_type == "l1":
|
98 |
-
return torch.mean(
|
99 |
-
(w * (model_output - target).abs()).reshape(target.shape[0], -1), 1
|
100 |
-
)
|
101 |
-
elif self.loss_type == "lpips":
|
102 |
-
loss = self.lpips(model_output, target).reshape(-1)
|
103 |
-
return loss
|
104 |
-
else:
|
105 |
-
raise NotImplementedError(f"Unknown loss type {self.loss_type}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sgm/modules/diffusionmodules/loss_weighting.py
DELETED
@@ -1,32 +0,0 @@
|
|
1 |
-
from abc import ABC, abstractmethod
|
2 |
-
|
3 |
-
import torch
|
4 |
-
|
5 |
-
|
6 |
-
class DiffusionLossWeighting(ABC):
|
7 |
-
@abstractmethod
|
8 |
-
def __call__(self, sigma: torch.Tensor) -> torch.Tensor:
|
9 |
-
pass
|
10 |
-
|
11 |
-
|
12 |
-
class UnitWeighting(DiffusionLossWeighting):
|
13 |
-
def __call__(self, sigma: torch.Tensor) -> torch.Tensor:
|
14 |
-
return torch.ones_like(sigma, device=sigma.device)
|
15 |
-
|
16 |
-
|
17 |
-
class EDMWeighting(DiffusionLossWeighting):
|
18 |
-
def __init__(self, sigma_data: float = 0.5):
|
19 |
-
self.sigma_data = sigma_data
|
20 |
-
|
21 |
-
def __call__(self, sigma: torch.Tensor) -> torch.Tensor:
|
22 |
-
return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2
|
23 |
-
|
24 |
-
|
25 |
-
class VWeighting(EDMWeighting):
|
26 |
-
def __init__(self):
|
27 |
-
super().__init__(sigma_data=1.0)
|
28 |
-
|
29 |
-
|
30 |
-
class EpsWeighting(DiffusionLossWeighting):
|
31 |
-
def __call__(self, sigma: torch.Tensor) -> torch.Tensor:
|
32 |
-
return sigma**-2.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sgm/modules/diffusionmodules/model.py
DELETED
@@ -1,748 +0,0 @@
|
|
1 |
-
# pytorch_diffusion + derived encoder decoder
|
2 |
-
import logging
|
3 |
-
import math
|
4 |
-
from typing import Any, Callable, Optional
|
5 |
-
|
6 |
-
import numpy as np
|
7 |
-
import torch
|
8 |
-
import torch.nn as nn
|
9 |
-
from einops import rearrange
|
10 |
-
from packaging import version
|
11 |
-
|
12 |
-
logpy = logging.getLogger(__name__)
|
13 |
-
|
14 |
-
try:
|
15 |
-
import xformers
|
16 |
-
import xformers.ops
|
17 |
-
|
18 |
-
XFORMERS_IS_AVAILABLE = True
|
19 |
-
except:
|
20 |
-
XFORMERS_IS_AVAILABLE = False
|
21 |
-
logpy.warning("no module 'xformers'. Processing without...")
|
22 |
-
|
23 |
-
from ...modules.attention import LinearAttention, MemoryEfficientCrossAttention
|
24 |
-
|
25 |
-
|
26 |
-
def get_timestep_embedding(timesteps, embedding_dim):
|
27 |
-
"""
|
28 |
-
This matches the implementation in Denoising Diffusion Probabilistic Models:
|
29 |
-
From Fairseq.
|
30 |
-
Build sinusoidal embeddings.
|
31 |
-
This matches the implementation in tensor2tensor, but differs slightly
|
32 |
-
from the description in Section 3.5 of "Attention Is All You Need".
|
33 |
-
"""
|
34 |
-
assert len(timesteps.shape) == 1
|
35 |
-
|
36 |
-
half_dim = embedding_dim // 2
|
37 |
-
emb = math.log(10000) / (half_dim - 1)
|
38 |
-
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
|
39 |
-
emb = emb.to(device=timesteps.device)
|
40 |
-
emb = timesteps.float()[:, None] * emb[None, :]
|
41 |
-
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
42 |
-
if embedding_dim % 2 == 1: # zero pad
|
43 |
-
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
|
44 |
-
return emb
|
45 |
-
|
46 |
-
|
47 |
-
def nonlinearity(x):
|
48 |
-
# swish
|
49 |
-
return x * torch.sigmoid(x)
|
50 |
-
|
51 |
-
|
52 |
-
def Normalize(in_channels, num_groups=32):
|
53 |
-
return torch.nn.GroupNorm(
|
54 |
-
num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
|
55 |
-
)
|
56 |
-
|
57 |
-
|
58 |
-
class Upsample(nn.Module):
|
59 |
-
def __init__(self, in_channels, with_conv):
|
60 |
-
super().__init__()
|
61 |
-
self.with_conv = with_conv
|
62 |
-
if self.with_conv:
|
63 |
-
self.conv = torch.nn.Conv2d(
|
64 |
-
in_channels, in_channels, kernel_size=3, stride=1, padding=1
|
65 |
-
)
|
66 |
-
|
67 |
-
def forward(self, x):
|
68 |
-
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
69 |
-
if self.with_conv:
|
70 |
-
x = self.conv(x)
|
71 |
-
return x
|
72 |
-
|
73 |
-
|
74 |
-
class Downsample(nn.Module):
|
75 |
-
def __init__(self, in_channels, with_conv):
|
76 |
-
super().__init__()
|
77 |
-
self.with_conv = with_conv
|
78 |
-
if self.with_conv:
|
79 |
-
# no asymmetric padding in torch conv, must do it ourselves
|
80 |
-
self.conv = torch.nn.Conv2d(
|
81 |
-
in_channels, in_channels, kernel_size=3, stride=2, padding=0
|
82 |
-
)
|
83 |
-
|
84 |
-
def forward(self, x):
|
85 |
-
if self.with_conv:
|
86 |
-
pad = (0, 1, 0, 1)
|
87 |
-
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
88 |
-
x = self.conv(x)
|
89 |
-
else:
|
90 |
-
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
|
91 |
-
return x
|
92 |
-
|
93 |
-
|
94 |
-
class ResnetBlock(nn.Module):
|
95 |
-
def __init__(
|
96 |
-
self,
|
97 |
-
*,
|
98 |
-
in_channels,
|
99 |
-
out_channels=None,
|
100 |
-
conv_shortcut=False,
|
101 |
-
dropout,
|
102 |
-
temb_channels=512,
|
103 |
-
):
|
104 |
-
super().__init__()
|
105 |
-
self.in_channels = in_channels
|
106 |
-
out_channels = in_channels if out_channels is None else out_channels
|
107 |
-
self.out_channels = out_channels
|
108 |
-
self.use_conv_shortcut = conv_shortcut
|
109 |
-
|
110 |
-
self.norm1 = Normalize(in_channels)
|
111 |
-
self.conv1 = torch.nn.Conv2d(
|
112 |
-
in_channels, out_channels, kernel_size=3, stride=1, padding=1
|
113 |
-
)
|
114 |
-
if temb_channels > 0:
|
115 |
-
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
|
116 |
-
self.norm2 = Normalize(out_channels)
|
117 |
-
self.dropout = torch.nn.Dropout(dropout)
|
118 |
-
self.conv2 = torch.nn.Conv2d(
|
119 |
-
out_channels, out_channels, kernel_size=3, stride=1, padding=1
|
120 |
-
)
|
121 |
-
if self.in_channels != self.out_channels:
|
122 |
-
if self.use_conv_shortcut:
|
123 |
-
self.conv_shortcut = torch.nn.Conv2d(
|
124 |
-
in_channels, out_channels, kernel_size=3, stride=1, padding=1
|
125 |
-
)
|
126 |
-
else:
|
127 |
-
self.nin_shortcut = torch.nn.Conv2d(
|
128 |
-
in_channels, out_channels, kernel_size=1, stride=1, padding=0
|
129 |
-
)
|
130 |
-
|
131 |
-
def forward(self, x, temb):
|
132 |
-
h = x
|
133 |
-
h = self.norm1(h)
|
134 |
-
h = nonlinearity(h)
|
135 |
-
h = self.conv1(h)
|
136 |
-
|
137 |
-
if temb is not None:
|
138 |
-
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
|
139 |
-
|
140 |
-
h = self.norm2(h)
|
141 |
-
h = nonlinearity(h)
|
142 |
-
h = self.dropout(h)
|
143 |
-
h = self.conv2(h)
|
144 |
-
|
145 |
-
if self.in_channels != self.out_channels:
|
146 |
-
if self.use_conv_shortcut:
|
147 |
-
x = self.conv_shortcut(x)
|
148 |
-
else:
|
149 |
-
x = self.nin_shortcut(x)
|
150 |
-
|
151 |
-
return x + h
|
152 |
-
|
153 |
-
|
154 |
-
class LinAttnBlock(LinearAttention):
|
155 |
-
"""to match AttnBlock usage"""
|
156 |
-
|
157 |
-
def __init__(self, in_channels):
|
158 |
-
super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
|
159 |
-
|
160 |
-
|
161 |
-
class AttnBlock(nn.Module):
|
162 |
-
def __init__(self, in_channels):
|
163 |
-
super().__init__()
|
164 |
-
self.in_channels = in_channels
|
165 |
-
|
166 |
-
self.norm = Normalize(in_channels)
|
167 |
-
self.q = torch.nn.Conv2d(
|
168 |
-
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
169 |
-
)
|
170 |
-
self.k = torch.nn.Conv2d(
|
171 |
-
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
172 |
-
)
|
173 |
-
self.v = torch.nn.Conv2d(
|
174 |
-
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
175 |
-
)
|
176 |
-
self.proj_out = torch.nn.Conv2d(
|
177 |
-
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
178 |
-
)
|
179 |
-
|
180 |
-
def attention(self, h_: torch.Tensor) -> torch.Tensor:
|
181 |
-
h_ = self.norm(h_)
|
182 |
-
q = self.q(h_)
|
183 |
-
k = self.k(h_)
|
184 |
-
v = self.v(h_)
|
185 |
-
|
186 |
-
b, c, h, w = q.shape
|
187 |
-
q, k, v = map(
|
188 |
-
lambda x: rearrange(x, "b c h w -> b 1 (h w) c").contiguous(), (q, k, v)
|
189 |
-
)
|
190 |
-
h_ = torch.nn.functional.scaled_dot_product_attention(
|
191 |
-
q, k, v
|
192 |
-
) # scale is dim ** -0.5 per default
|
193 |
-
# compute attention
|
194 |
-
|
195 |
-
return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
|
196 |
-
|
197 |
-
def forward(self, x, **kwargs):
|
198 |
-
h_ = x
|
199 |
-
h_ = self.attention(h_)
|
200 |
-
h_ = self.proj_out(h_)
|
201 |
-
return x + h_
|
202 |
-
|
203 |
-
|
204 |
-
class MemoryEfficientAttnBlock(nn.Module):
|
205 |
-
"""
|
206 |
-
Uses xformers efficient implementation,
|
207 |
-
see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
|
208 |
-
Note: this is a single-head self-attention operation
|
209 |
-
"""
|
210 |
-
|
211 |
-
#
|
212 |
-
def __init__(self, in_channels):
|
213 |
-
super().__init__()
|
214 |
-
self.in_channels = in_channels
|
215 |
-
|
216 |
-
self.norm = Normalize(in_channels)
|
217 |
-
self.q = torch.nn.Conv2d(
|
218 |
-
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
219 |
-
)
|
220 |
-
self.k = torch.nn.Conv2d(
|
221 |
-
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
222 |
-
)
|
223 |
-
self.v = torch.nn.Conv2d(
|
224 |
-
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
225 |
-
)
|
226 |
-
self.proj_out = torch.nn.Conv2d(
|
227 |
-
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
228 |
-
)
|
229 |
-
self.attention_op: Optional[Any] = None
|
230 |
-
|
231 |
-
def attention(self, h_: torch.Tensor) -> torch.Tensor:
|
232 |
-
h_ = self.norm(h_)
|
233 |
-
q = self.q(h_)
|
234 |
-
k = self.k(h_)
|
235 |
-
v = self.v(h_)
|
236 |
-
|
237 |
-
# compute attention
|
238 |
-
B, C, H, W = q.shape
|
239 |
-
q, k, v = map(lambda x: rearrange(x, "b c h w -> b (h w) c"), (q, k, v))
|
240 |
-
|
241 |
-
q, k, v = map(
|
242 |
-
lambda t: t.unsqueeze(3)
|
243 |
-
.reshape(B, t.shape[1], 1, C)
|
244 |
-
.permute(0, 2, 1, 3)
|
245 |
-
.reshape(B * 1, t.shape[1], C)
|
246 |
-
.contiguous(),
|
247 |
-
(q, k, v),
|
248 |
-
)
|
249 |
-
out = xformers.ops.memory_efficient_attention(
|
250 |
-
q, k, v, attn_bias=None, op=self.attention_op
|
251 |
-
)
|
252 |
-
|
253 |
-
out = (
|
254 |
-
out.unsqueeze(0)
|
255 |
-
.reshape(B, 1, out.shape[1], C)
|
256 |
-
.permute(0, 2, 1, 3)
|
257 |
-
.reshape(B, out.shape[1], C)
|
258 |
-
)
|
259 |
-
return rearrange(out, "b (h w) c -> b c h w", b=B, h=H, w=W, c=C)
|
260 |
-
|
261 |
-
def forward(self, x, **kwargs):
|
262 |
-
h_ = x
|
263 |
-
h_ = self.attention(h_)
|
264 |
-
h_ = self.proj_out(h_)
|
265 |
-
return x + h_
|
266 |
-
|
267 |
-
|
268 |
-
class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
|
269 |
-
def forward(self, x, context=None, mask=None, **unused_kwargs):
|
270 |
-
b, c, h, w = x.shape
|
271 |
-
x = rearrange(x, "b c h w -> b (h w) c")
|
272 |
-
out = super().forward(x, context=context, mask=mask)
|
273 |
-
out = rearrange(out, "b (h w) c -> b c h w", h=h, w=w, c=c)
|
274 |
-
return x + out
|
275 |
-
|
276 |
-
|
277 |
-
def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
|
278 |
-
assert attn_type in [
|
279 |
-
"vanilla",
|
280 |
-
"vanilla-xformers",
|
281 |
-
"memory-efficient-cross-attn",
|
282 |
-
"linear",
|
283 |
-
"none",
|
284 |
-
], f"attn_type {attn_type} unknown"
|
285 |
-
if (
|
286 |
-
version.parse(torch.__version__) < version.parse("2.0.0")
|
287 |
-
and attn_type != "none"
|
288 |
-
):
|
289 |
-
assert XFORMERS_IS_AVAILABLE, (
|
290 |
-
f"We do not support vanilla attention in {torch.__version__} anymore, "
|
291 |
-
f"as it is too expensive. Please install xformers via e.g. 'pip install xformers==0.0.16'"
|
292 |
-
)
|
293 |
-
attn_type = "vanilla-xformers"
|
294 |
-
logpy.info(f"making attention of type '{attn_type}' with {in_channels} in_channels")
|
295 |
-
if attn_type == "vanilla":
|
296 |
-
assert attn_kwargs is None
|
297 |
-
return AttnBlock(in_channels)
|
298 |
-
elif attn_type == "vanilla-xformers":
|
299 |
-
logpy.info(
|
300 |
-
f"building MemoryEfficientAttnBlock with {in_channels} in_channels..."
|
301 |
-
)
|
302 |
-
return MemoryEfficientAttnBlock(in_channels)
|
303 |
-
elif type == "memory-efficient-cross-attn":
|
304 |
-
attn_kwargs["query_dim"] = in_channels
|
305 |
-
return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)
|
306 |
-
elif attn_type == "none":
|
307 |
-
return nn.Identity(in_channels)
|
308 |
-
else:
|
309 |
-
return LinAttnBlock(in_channels)
|
310 |
-
|
311 |
-
|
312 |
-
class Model(nn.Module):
|
313 |
-
def __init__(
|
314 |
-
self,
|
315 |
-
*,
|
316 |
-
ch,
|
317 |
-
out_ch,
|
318 |
-
ch_mult=(1, 2, 4, 8),
|
319 |
-
num_res_blocks,
|
320 |
-
attn_resolutions,
|
321 |
-
dropout=0.0,
|
322 |
-
resamp_with_conv=True,
|
323 |
-
in_channels,
|
324 |
-
resolution,
|
325 |
-
use_timestep=True,
|
326 |
-
use_linear_attn=False,
|
327 |
-
attn_type="vanilla",
|
328 |
-
):
|
329 |
-
super().__init__()
|
330 |
-
if use_linear_attn:
|
331 |
-
attn_type = "linear"
|
332 |
-
self.ch = ch
|
333 |
-
self.temb_ch = self.ch * 4
|
334 |
-
self.num_resolutions = len(ch_mult)
|
335 |
-
self.num_res_blocks = num_res_blocks
|
336 |
-
self.resolution = resolution
|
337 |
-
self.in_channels = in_channels
|
338 |
-
|
339 |
-
self.use_timestep = use_timestep
|
340 |
-
if self.use_timestep:
|
341 |
-
# timestep embedding
|
342 |
-
self.temb = nn.Module()
|
343 |
-
self.temb.dense = nn.ModuleList(
|
344 |
-
[
|
345 |
-
torch.nn.Linear(self.ch, self.temb_ch),
|
346 |
-
torch.nn.Linear(self.temb_ch, self.temb_ch),
|
347 |
-
]
|
348 |
-
)
|
349 |
-
|
350 |
-
# downsampling
|
351 |
-
self.conv_in = torch.nn.Conv2d(
|
352 |
-
in_channels, self.ch, kernel_size=3, stride=1, padding=1
|
353 |
-
)
|
354 |
-
|
355 |
-
curr_res = resolution
|
356 |
-
in_ch_mult = (1,) + tuple(ch_mult)
|
357 |
-
self.down = nn.ModuleList()
|
358 |
-
for i_level in range(self.num_resolutions):
|
359 |
-
block = nn.ModuleList()
|
360 |
-
attn = nn.ModuleList()
|
361 |
-
block_in = ch * in_ch_mult[i_level]
|
362 |
-
block_out = ch * ch_mult[i_level]
|
363 |
-
for i_block in range(self.num_res_blocks):
|
364 |
-
block.append(
|
365 |
-
ResnetBlock(
|
366 |
-
in_channels=block_in,
|
367 |
-
out_channels=block_out,
|
368 |
-
temb_channels=self.temb_ch,
|
369 |
-
dropout=dropout,
|
370 |
-
)
|
371 |
-
)
|
372 |
-
block_in = block_out
|
373 |
-
if curr_res in attn_resolutions:
|
374 |
-
attn.append(make_attn(block_in, attn_type=attn_type))
|
375 |
-
down = nn.Module()
|
376 |
-
down.block = block
|
377 |
-
down.attn = attn
|
378 |
-
if i_level != self.num_resolutions - 1:
|
379 |
-
down.downsample = Downsample(block_in, resamp_with_conv)
|
380 |
-
curr_res = curr_res // 2
|
381 |
-
self.down.append(down)
|
382 |
-
|
383 |
-
# middle
|
384 |
-
self.mid = nn.Module()
|
385 |
-
self.mid.block_1 = ResnetBlock(
|
386 |
-
in_channels=block_in,
|
387 |
-
out_channels=block_in,
|
388 |
-
temb_channels=self.temb_ch,
|
389 |
-
dropout=dropout,
|
390 |
-
)
|
391 |
-
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
|
392 |
-
self.mid.block_2 = ResnetBlock(
|
393 |
-
in_channels=block_in,
|
394 |
-
out_channels=block_in,
|
395 |
-
temb_channels=self.temb_ch,
|
396 |
-
dropout=dropout,
|
397 |
-
)
|
398 |
-
|
399 |
-
# upsampling
|
400 |
-
self.up = nn.ModuleList()
|
401 |
-
for i_level in reversed(range(self.num_resolutions)):
|
402 |
-
block = nn.ModuleList()
|
403 |
-
attn = nn.ModuleList()
|
404 |
-
block_out = ch * ch_mult[i_level]
|
405 |
-
skip_in = ch * ch_mult[i_level]
|
406 |
-
for i_block in range(self.num_res_blocks + 1):
|
407 |
-
if i_block == self.num_res_blocks:
|
408 |
-
skip_in = ch * in_ch_mult[i_level]
|
409 |
-
block.append(
|
410 |
-
ResnetBlock(
|
411 |
-
in_channels=block_in + skip_in,
|
412 |
-
out_channels=block_out,
|
413 |
-
temb_channels=self.temb_ch,
|
414 |
-
dropout=dropout,
|
415 |
-
)
|
416 |
-
)
|
417 |
-
block_in = block_out
|
418 |
-
if curr_res in attn_resolutions:
|
419 |
-
attn.append(make_attn(block_in, attn_type=attn_type))
|
420 |
-
up = nn.Module()
|
421 |
-
up.block = block
|
422 |
-
up.attn = attn
|
423 |
-
if i_level != 0:
|
424 |
-
up.upsample = Upsample(block_in, resamp_with_conv)
|
425 |
-
curr_res = curr_res * 2
|
426 |
-
self.up.insert(0, up) # prepend to get consistent order
|
427 |
-
|
428 |
-
# end
|
429 |
-
self.norm_out = Normalize(block_in)
|
430 |
-
self.conv_out = torch.nn.Conv2d(
|
431 |
-
block_in, out_ch, kernel_size=3, stride=1, padding=1
|
432 |
-
)
|
433 |
-
|
434 |
-
def forward(self, x, t=None, context=None):
|
435 |
-
# assert x.shape[2] == x.shape[3] == self.resolution
|
436 |
-
if context is not None:
|
437 |
-
# assume aligned context, cat along channel axis
|
438 |
-
x = torch.cat((x, context), dim=1)
|
439 |
-
if self.use_timestep:
|
440 |
-
# timestep embedding
|
441 |
-
assert t is not None
|
442 |
-
temb = get_timestep_embedding(t, self.ch)
|
443 |
-
temb = self.temb.dense[0](temb)
|
444 |
-
temb = nonlinearity(temb)
|
445 |
-
temb = self.temb.dense[1](temb)
|
446 |
-
else:
|
447 |
-
temb = None
|
448 |
-
|
449 |
-
# downsampling
|
450 |
-
hs = [self.conv_in(x)]
|
451 |
-
for i_level in range(self.num_resolutions):
|
452 |
-
for i_block in range(self.num_res_blocks):
|
453 |
-
h = self.down[i_level].block[i_block](hs[-1], temb)
|
454 |
-
if len(self.down[i_level].attn) > 0:
|
455 |
-
h = self.down[i_level].attn[i_block](h)
|
456 |
-
hs.append(h)
|
457 |
-
if i_level != self.num_resolutions - 1:
|
458 |
-
hs.append(self.down[i_level].downsample(hs[-1]))
|
459 |
-
|
460 |
-
# middle
|
461 |
-
h = hs[-1]
|
462 |
-
h = self.mid.block_1(h, temb)
|
463 |
-
h = self.mid.attn_1(h)
|
464 |
-
h = self.mid.block_2(h, temb)
|
465 |
-
|
466 |
-
# upsampling
|
467 |
-
for i_level in reversed(range(self.num_resolutions)):
|
468 |
-
for i_block in range(self.num_res_blocks + 1):
|
469 |
-
h = self.up[i_level].block[i_block](
|
470 |
-
torch.cat([h, hs.pop()], dim=1), temb
|
471 |
-
)
|
472 |
-
if len(self.up[i_level].attn) > 0:
|
473 |
-
h = self.up[i_level].attn[i_block](h)
|
474 |
-
if i_level != 0:
|
475 |
-
h = self.up[i_level].upsample(h)
|
476 |
-
|
477 |
-
# end
|
478 |
-
h = self.norm_out(h)
|
479 |
-
h = nonlinearity(h)
|
480 |
-
h = self.conv_out(h)
|
481 |
-
return h
|
482 |
-
|
483 |
-
def get_last_layer(self):
|
484 |
-
return self.conv_out.weight
|
485 |
-
|
486 |
-
|
487 |
-
class Encoder(nn.Module):
|
488 |
-
def __init__(
|
489 |
-
self,
|
490 |
-
*,
|
491 |
-
ch,
|
492 |
-
out_ch,
|
493 |
-
ch_mult=(1, 2, 4, 8),
|
494 |
-
num_res_blocks,
|
495 |
-
attn_resolutions,
|
496 |
-
dropout=0.0,
|
497 |
-
resamp_with_conv=True,
|
498 |
-
in_channels,
|
499 |
-
resolution,
|
500 |
-
z_channels,
|
501 |
-
double_z=True,
|
502 |
-
use_linear_attn=False,
|
503 |
-
attn_type="vanilla",
|
504 |
-
**ignore_kwargs,
|
505 |
-
):
|
506 |
-
super().__init__()
|
507 |
-
if use_linear_attn:
|
508 |
-
attn_type = "linear"
|
509 |
-
self.ch = ch
|
510 |
-
self.temb_ch = 0
|
511 |
-
self.num_resolutions = len(ch_mult)
|
512 |
-
self.num_res_blocks = num_res_blocks
|
513 |
-
self.resolution = resolution
|
514 |
-
self.in_channels = in_channels
|
515 |
-
|
516 |
-
# downsampling
|
517 |
-
self.conv_in = torch.nn.Conv2d(
|
518 |
-
in_channels, self.ch, kernel_size=3, stride=1, padding=1
|
519 |
-
)
|
520 |
-
|
521 |
-
curr_res = resolution
|
522 |
-
in_ch_mult = (1,) + tuple(ch_mult)
|
523 |
-
self.in_ch_mult = in_ch_mult
|
524 |
-
self.down = nn.ModuleList()
|
525 |
-
for i_level in range(self.num_resolutions):
|
526 |
-
block = nn.ModuleList()
|
527 |
-
attn = nn.ModuleList()
|
528 |
-
block_in = ch * in_ch_mult[i_level]
|
529 |
-
block_out = ch * ch_mult[i_level]
|
530 |
-
for i_block in range(self.num_res_blocks):
|
531 |
-
block.append(
|
532 |
-
ResnetBlock(
|
533 |
-
in_channels=block_in,
|
534 |
-
out_channels=block_out,
|
535 |
-
temb_channels=self.temb_ch,
|
536 |
-
dropout=dropout,
|
537 |
-
)
|
538 |
-
)
|
539 |
-
block_in = block_out
|
540 |
-
if curr_res in attn_resolutions:
|
541 |
-
attn.append(make_attn(block_in, attn_type=attn_type))
|
542 |
-
down = nn.Module()
|
543 |
-
down.block = block
|
544 |
-
down.attn = attn
|
545 |
-
if i_level != self.num_resolutions - 1:
|
546 |
-
down.downsample = Downsample(block_in, resamp_with_conv)
|
547 |
-
curr_res = curr_res // 2
|
548 |
-
self.down.append(down)
|
549 |
-
|
550 |
-
# middle
|
551 |
-
self.mid = nn.Module()
|
552 |
-
self.mid.block_1 = ResnetBlock(
|
553 |
-
in_channels=block_in,
|
554 |
-
out_channels=block_in,
|
555 |
-
temb_channels=self.temb_ch,
|
556 |
-
dropout=dropout,
|
557 |
-
)
|
558 |
-
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
|
559 |
-
self.mid.block_2 = ResnetBlock(
|
560 |
-
in_channels=block_in,
|
561 |
-
out_channels=block_in,
|
562 |
-
temb_channels=self.temb_ch,
|
563 |
-
dropout=dropout,
|
564 |
-
)
|
565 |
-
|
566 |
-
# end
|
567 |
-
self.norm_out = Normalize(block_in)
|
568 |
-
self.conv_out = torch.nn.Conv2d(
|
569 |
-
block_in,
|
570 |
-
2 * z_channels if double_z else z_channels,
|
571 |
-
kernel_size=3,
|
572 |
-
stride=1,
|
573 |
-
padding=1,
|
574 |
-
)
|
575 |
-
|
576 |
-
def forward(self, x):
|
577 |
-
# timestep embedding
|
578 |
-
temb = None
|
579 |
-
|
580 |
-
# downsampling
|
581 |
-
hs = [self.conv_in(x)]
|
582 |
-
for i_level in range(self.num_resolutions):
|
583 |
-
for i_block in range(self.num_res_blocks):
|
584 |
-
h = self.down[i_level].block[i_block](hs[-1], temb)
|
585 |
-
if len(self.down[i_level].attn) > 0:
|
586 |
-
h = self.down[i_level].attn[i_block](h)
|
587 |
-
hs.append(h)
|
588 |
-
if i_level != self.num_resolutions - 1:
|
589 |
-
hs.append(self.down[i_level].downsample(hs[-1]))
|
590 |
-
|
591 |
-
# middle
|
592 |
-
h = hs[-1]
|
593 |
-
h = self.mid.block_1(h, temb)
|
594 |
-
h = self.mid.attn_1(h)
|
595 |
-
h = self.mid.block_2(h, temb)
|
596 |
-
|
597 |
-
# end
|
598 |
-
h = self.norm_out(h)
|
599 |
-
h = nonlinearity(h)
|
600 |
-
h = self.conv_out(h)
|
601 |
-
return h
|
602 |
-
|
603 |
-
|
604 |
-
class Decoder(nn.Module):
|
605 |
-
def __init__(
|
606 |
-
self,
|
607 |
-
*,
|
608 |
-
ch,
|
609 |
-
out_ch,
|
610 |
-
ch_mult=(1, 2, 4, 8),
|
611 |
-
num_res_blocks,
|
612 |
-
attn_resolutions,
|
613 |
-
dropout=0.0,
|
614 |
-
resamp_with_conv=True,
|
615 |
-
in_channels,
|
616 |
-
resolution,
|
617 |
-
z_channels,
|
618 |
-
give_pre_end=False,
|
619 |
-
tanh_out=False,
|
620 |
-
use_linear_attn=False,
|
621 |
-
attn_type="vanilla",
|
622 |
-
**ignorekwargs,
|
623 |
-
):
|
624 |
-
super().__init__()
|
625 |
-
if use_linear_attn:
|
626 |
-
attn_type = "linear"
|
627 |
-
self.ch = ch
|
628 |
-
self.temb_ch = 0
|
629 |
-
self.num_resolutions = len(ch_mult)
|
630 |
-
self.num_res_blocks = num_res_blocks
|
631 |
-
self.resolution = resolution
|
632 |
-
self.in_channels = in_channels
|
633 |
-
self.give_pre_end = give_pre_end
|
634 |
-
self.tanh_out = tanh_out
|
635 |
-
|
636 |
-
# compute in_ch_mult, block_in and curr_res at lowest res
|
637 |
-
in_ch_mult = (1,) + tuple(ch_mult)
|
638 |
-
block_in = ch * ch_mult[self.num_resolutions - 1]
|
639 |
-
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
640 |
-
self.z_shape = (1, z_channels, curr_res, curr_res)
|
641 |
-
logpy.info(
|
642 |
-
"Working with z of shape {} = {} dimensions.".format(
|
643 |
-
self.z_shape, np.prod(self.z_shape)
|
644 |
-
)
|
645 |
-
)
|
646 |
-
|
647 |
-
make_attn_cls = self._make_attn()
|
648 |
-
make_resblock_cls = self._make_resblock()
|
649 |
-
make_conv_cls = self._make_conv()
|
650 |
-
# z to block_in
|
651 |
-
self.conv_in = torch.nn.Conv2d(
|
652 |
-
z_channels, block_in, kernel_size=3, stride=1, padding=1
|
653 |
-
)
|
654 |
-
|
655 |
-
# middle
|
656 |
-
self.mid = nn.Module()
|
657 |
-
self.mid.block_1 = make_resblock_cls(
|
658 |
-
in_channels=block_in,
|
659 |
-
out_channels=block_in,
|
660 |
-
temb_channels=self.temb_ch,
|
661 |
-
dropout=dropout,
|
662 |
-
)
|
663 |
-
self.mid.attn_1 = make_attn_cls(block_in, attn_type=attn_type)
|
664 |
-
self.mid.block_2 = make_resblock_cls(
|
665 |
-
in_channels=block_in,
|
666 |
-
out_channels=block_in,
|
667 |
-
temb_channels=self.temb_ch,
|
668 |
-
dropout=dropout,
|
669 |
-
)
|
670 |
-
|
671 |
-
# upsampling
|
672 |
-
self.up = nn.ModuleList()
|
673 |
-
for i_level in reversed(range(self.num_resolutions)):
|
674 |
-
block = nn.ModuleList()
|
675 |
-
attn = nn.ModuleList()
|
676 |
-
block_out = ch * ch_mult[i_level]
|
677 |
-
for i_block in range(self.num_res_blocks + 1):
|
678 |
-
block.append(
|
679 |
-
make_resblock_cls(
|
680 |
-
in_channels=block_in,
|
681 |
-
out_channels=block_out,
|
682 |
-
temb_channels=self.temb_ch,
|
683 |
-
dropout=dropout,
|
684 |
-
)
|
685 |
-
)
|
686 |
-
block_in = block_out
|
687 |
-
if curr_res in attn_resolutions:
|
688 |
-
attn.append(make_attn_cls(block_in, attn_type=attn_type))
|
689 |
-
up = nn.Module()
|
690 |
-
up.block = block
|
691 |
-
up.attn = attn
|
692 |
-
if i_level != 0:
|
693 |
-
up.upsample = Upsample(block_in, resamp_with_conv)
|
694 |
-
curr_res = curr_res * 2
|
695 |
-
self.up.insert(0, up) # prepend to get consistent order
|
696 |
-
|
697 |
-
# end
|
698 |
-
self.norm_out = Normalize(block_in)
|
699 |
-
self.conv_out = make_conv_cls(
|
700 |
-
block_in, out_ch, kernel_size=3, stride=1, padding=1
|
701 |
-
)
|
702 |
-
|
703 |
-
def _make_attn(self) -> Callable:
|
704 |
-
return make_attn
|
705 |
-
|
706 |
-
def _make_resblock(self) -> Callable:
|
707 |
-
return ResnetBlock
|
708 |
-
|
709 |
-
def _make_conv(self) -> Callable:
|
710 |
-
return torch.nn.Conv2d
|
711 |
-
|
712 |
-
def get_last_layer(self, **kwargs):
|
713 |
-
return self.conv_out.weight
|
714 |
-
|
715 |
-
def forward(self, z, **kwargs):
|
716 |
-
# assert z.shape[1:] == self.z_shape[1:]
|
717 |
-
self.last_z_shape = z.shape
|
718 |
-
|
719 |
-
# timestep embedding
|
720 |
-
temb = None
|
721 |
-
|
722 |
-
# z to block_in
|
723 |
-
h = self.conv_in(z)
|
724 |
-
|
725 |
-
# middle
|
726 |
-
h = self.mid.block_1(h, temb, **kwargs)
|
727 |
-
h = self.mid.attn_1(h, **kwargs)
|
728 |
-
h = self.mid.block_2(h, temb, **kwargs)
|
729 |
-
|
730 |
-
# upsampling
|
731 |
-
for i_level in reversed(range(self.num_resolutions)):
|
732 |
-
for i_block in range(self.num_res_blocks + 1):
|
733 |
-
h = self.up[i_level].block[i_block](h, temb, **kwargs)
|
734 |
-
if len(self.up[i_level].attn) > 0:
|
735 |
-
h = self.up[i_level].attn[i_block](h, **kwargs)
|
736 |
-
if i_level != 0:
|
737 |
-
h = self.up[i_level].upsample(h)
|
738 |
-
|
739 |
-
# end
|
740 |
-
if self.give_pre_end:
|
741 |
-
return h
|
742 |
-
|
743 |
-
h = self.norm_out(h)
|
744 |
-
h = nonlinearity(h)
|
745 |
-
h = self.conv_out(h, **kwargs)
|
746 |
-
if self.tanh_out:
|
747 |
-
h = torch.tanh(h)
|
748 |
-
return h
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sgm/modules/diffusionmodules/openaimodel.py
DELETED
@@ -1,853 +0,0 @@
|
|
1 |
-
import logging
|
2 |
-
import math
|
3 |
-
from abc import abstractmethod
|
4 |
-
from typing import Iterable, List, Optional, Tuple, Union
|
5 |
-
|
6 |
-
import torch as th
|
7 |
-
import torch.nn as nn
|
8 |
-
import torch.nn.functional as F
|
9 |
-
from einops import rearrange
|
10 |
-
from torch.utils.checkpoint import checkpoint
|
11 |
-
|
12 |
-
from ...modules.attention import SpatialTransformer
|
13 |
-
from ...modules.diffusionmodules.util import (avg_pool_nd, conv_nd, linear,
|
14 |
-
normalization,
|
15 |
-
timestep_embedding, zero_module)
|
16 |
-
from ...modules.video_attention import SpatialVideoTransformer
|
17 |
-
from ...util import exists
|
18 |
-
|
19 |
-
logpy = logging.getLogger(__name__)
|
20 |
-
|
21 |
-
|
22 |
-
class AttentionPool2d(nn.Module):
|
23 |
-
"""
|
24 |
-
Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
|
25 |
-
"""
|
26 |
-
|
27 |
-
def __init__(
|
28 |
-
self,
|
29 |
-
spacial_dim: int,
|
30 |
-
embed_dim: int,
|
31 |
-
num_heads_channels: int,
|
32 |
-
output_dim: Optional[int] = None,
|
33 |
-
):
|
34 |
-
super().__init__()
|
35 |
-
self.positional_embedding = nn.Parameter(
|
36 |
-
th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5
|
37 |
-
)
|
38 |
-
self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
|
39 |
-
self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
|
40 |
-
self.num_heads = embed_dim // num_heads_channels
|
41 |
-
self.attention = QKVAttention(self.num_heads)
|
42 |
-
|
43 |
-
def forward(self, x: th.Tensor) -> th.Tensor:
|
44 |
-
b, c, _ = x.shape
|
45 |
-
x = x.reshape(b, c, -1)
|
46 |
-
x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1)
|
47 |
-
x = x + self.positional_embedding[None, :, :].to(x.dtype)
|
48 |
-
x = self.qkv_proj(x)
|
49 |
-
x = self.attention(x)
|
50 |
-
x = self.c_proj(x)
|
51 |
-
return x[:, :, 0]
|
52 |
-
|
53 |
-
|
54 |
-
class TimestepBlock(nn.Module):
|
55 |
-
"""
|
56 |
-
Any module where forward() takes timestep embeddings as a second argument.
|
57 |
-
"""
|
58 |
-
|
59 |
-
@abstractmethod
|
60 |
-
def forward(self, x: th.Tensor, emb: th.Tensor):
|
61 |
-
"""
|
62 |
-
Apply the module to `x` given `emb` timestep embeddings.
|
63 |
-
"""
|
64 |
-
|
65 |
-
|
66 |
-
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
|
67 |
-
"""
|
68 |
-
A sequential module that passes timestep embeddings to the children that
|
69 |
-
support it as an extra input.
|
70 |
-
"""
|
71 |
-
|
72 |
-
def forward(
|
73 |
-
self,
|
74 |
-
x: th.Tensor,
|
75 |
-
emb: th.Tensor,
|
76 |
-
context: Optional[th.Tensor] = None,
|
77 |
-
image_only_indicator: Optional[th.Tensor] = None,
|
78 |
-
time_context: Optional[int] = None,
|
79 |
-
num_video_frames: Optional[int] = None,
|
80 |
-
):
|
81 |
-
from ...modules.diffusionmodules.video_model import VideoResBlock
|
82 |
-
|
83 |
-
for layer in self:
|
84 |
-
module = layer
|
85 |
-
|
86 |
-
if isinstance(module, TimestepBlock) and not isinstance(
|
87 |
-
module, VideoResBlock
|
88 |
-
):
|
89 |
-
x = layer(x, emb)
|
90 |
-
elif isinstance(module, VideoResBlock):
|
91 |
-
x = layer(x, emb, num_video_frames, image_only_indicator)
|
92 |
-
elif isinstance(module, SpatialVideoTransformer):
|
93 |
-
x = layer(
|
94 |
-
x,
|
95 |
-
context,
|
96 |
-
time_context,
|
97 |
-
num_video_frames,
|
98 |
-
image_only_indicator,
|
99 |
-
)
|
100 |
-
elif isinstance(module, SpatialTransformer):
|
101 |
-
x = layer(x, context)
|
102 |
-
else:
|
103 |
-
x = layer(x)
|
104 |
-
return x
|
105 |
-
|
106 |
-
|
107 |
-
class Upsample(nn.Module):
|
108 |
-
"""
|
109 |
-
An upsampling layer with an optional convolution.
|
110 |
-
:param channels: channels in the inputs and outputs.
|
111 |
-
:param use_conv: a bool determining if a convolution is applied.
|
112 |
-
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
|
113 |
-
upsampling occurs in the inner-two dimensions.
|
114 |
-
"""
|
115 |
-
|
116 |
-
def __init__(
|
117 |
-
self,
|
118 |
-
channels: int,
|
119 |
-
use_conv: bool,
|
120 |
-
dims: int = 2,
|
121 |
-
out_channels: Optional[int] = None,
|
122 |
-
padding: int = 1,
|
123 |
-
third_up: bool = False,
|
124 |
-
kernel_size: int = 3,
|
125 |
-
scale_factor: int = 2,
|
126 |
-
):
|
127 |
-
super().__init__()
|
128 |
-
self.channels = channels
|
129 |
-
self.out_channels = out_channels or channels
|
130 |
-
self.use_conv = use_conv
|
131 |
-
self.dims = dims
|
132 |
-
self.third_up = third_up
|
133 |
-
self.scale_factor = scale_factor
|
134 |
-
if use_conv:
|
135 |
-
self.conv = conv_nd(
|
136 |
-
dims, self.channels, self.out_channels, kernel_size, padding=padding
|
137 |
-
)
|
138 |
-
|
139 |
-
def forward(self, x: th.Tensor) -> th.Tensor:
|
140 |
-
assert x.shape[1] == self.channels
|
141 |
-
|
142 |
-
if self.dims == 3:
|
143 |
-
t_factor = 1 if not self.third_up else self.scale_factor
|
144 |
-
x = F.interpolate(
|
145 |
-
x,
|
146 |
-
(
|
147 |
-
t_factor * x.shape[2],
|
148 |
-
x.shape[3] * self.scale_factor,
|
149 |
-
x.shape[4] * self.scale_factor,
|
150 |
-
),
|
151 |
-
mode="nearest",
|
152 |
-
)
|
153 |
-
else:
|
154 |
-
x = F.interpolate(x, scale_factor=self.scale_factor, mode="nearest")
|
155 |
-
if self.use_conv:
|
156 |
-
x = self.conv(x)
|
157 |
-
return x
|
158 |
-
|
159 |
-
|
160 |
-
class Downsample(nn.Module):
|
161 |
-
"""
|
162 |
-
A downsampling layer with an optional convolution.
|
163 |
-
:param channels: channels in the inputs and outputs.
|
164 |
-
:param use_conv: a bool determining if a convolution is applied.
|
165 |
-
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
|
166 |
-
downsampling occurs in the inner-two dimensions.
|
167 |
-
"""
|
168 |
-
|
169 |
-
def __init__(
|
170 |
-
self,
|
171 |
-
channels: int,
|
172 |
-
use_conv: bool,
|
173 |
-
dims: int = 2,
|
174 |
-
out_channels: Optional[int] = None,
|
175 |
-
padding: int = 1,
|
176 |
-
third_down: bool = False,
|
177 |
-
):
|
178 |
-
super().__init__()
|
179 |
-
self.channels = channels
|
180 |
-
self.out_channels = out_channels or channels
|
181 |
-
self.use_conv = use_conv
|
182 |
-
self.dims = dims
|
183 |
-
stride = 2 if dims != 3 else ((1, 2, 2) if not third_down else (2, 2, 2))
|
184 |
-
if use_conv:
|
185 |
-
logpy.info(f"Building a Downsample layer with {dims} dims.")
|
186 |
-
logpy.info(
|
187 |
-
f" --> settings are: \n in-chn: {self.channels}, out-chn: {self.out_channels}, "
|
188 |
-
f"kernel-size: 3, stride: {stride}, padding: {padding}"
|
189 |
-
)
|
190 |
-
if dims == 3:
|
191 |
-
logpy.info(f" --> Downsampling third axis (time): {third_down}")
|
192 |
-
self.op = conv_nd(
|
193 |
-
dims,
|
194 |
-
self.channels,
|
195 |
-
self.out_channels,
|
196 |
-
3,
|
197 |
-
stride=stride,
|
198 |
-
padding=padding,
|
199 |
-
)
|
200 |
-
else:
|
201 |
-
assert self.channels == self.out_channels
|
202 |
-
self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
|
203 |
-
|
204 |
-
def forward(self, x: th.Tensor) -> th.Tensor:
|
205 |
-
assert x.shape[1] == self.channels
|
206 |
-
|
207 |
-
return self.op(x)
|
208 |
-
|
209 |
-
|
210 |
-
class ResBlock(TimestepBlock):
|
211 |
-
"""
|
212 |
-
A residual block that can optionally change the number of channels.
|
213 |
-
:param channels: the number of input channels.
|
214 |
-
:param emb_channels: the number of timestep embedding channels.
|
215 |
-
:param dropout: the rate of dropout.
|
216 |
-
:param out_channels: if specified, the number of out channels.
|
217 |
-
:param use_conv: if True and out_channels is specified, use a spatial
|
218 |
-
convolution instead of a smaller 1x1 convolution to change the
|
219 |
-
channels in the skip connection.
|
220 |
-
:param dims: determines if the signal is 1D, 2D, or 3D.
|
221 |
-
:param use_checkpoint: if True, use gradient checkpointing on this module.
|
222 |
-
:param up: if True, use this block for upsampling.
|
223 |
-
:param down: if True, use this block for downsampling.
|
224 |
-
"""
|
225 |
-
|
226 |
-
def __init__(
|
227 |
-
self,
|
228 |
-
channels: int,
|
229 |
-
emb_channels: int,
|
230 |
-
dropout: float,
|
231 |
-
out_channels: Optional[int] = None,
|
232 |
-
use_conv: bool = False,
|
233 |
-
use_scale_shift_norm: bool = False,
|
234 |
-
dims: int = 2,
|
235 |
-
use_checkpoint: bool = False,
|
236 |
-
up: bool = False,
|
237 |
-
down: bool = False,
|
238 |
-
kernel_size: int = 3,
|
239 |
-
exchange_temb_dims: bool = False,
|
240 |
-
skip_t_emb: bool = False,
|
241 |
-
):
|
242 |
-
super().__init__()
|
243 |
-
self.channels = channels
|
244 |
-
self.emb_channels = emb_channels
|
245 |
-
self.dropout = dropout
|
246 |
-
self.out_channels = out_channels or channels
|
247 |
-
self.use_conv = use_conv
|
248 |
-
self.use_checkpoint = use_checkpoint
|
249 |
-
self.use_scale_shift_norm = use_scale_shift_norm
|
250 |
-
self.exchange_temb_dims = exchange_temb_dims
|
251 |
-
|
252 |
-
if isinstance(kernel_size, Iterable):
|
253 |
-
padding = [k // 2 for k in kernel_size]
|
254 |
-
else:
|
255 |
-
padding = kernel_size // 2
|
256 |
-
|
257 |
-
self.in_layers = nn.Sequential(
|
258 |
-
normalization(channels),
|
259 |
-
nn.SiLU(),
|
260 |
-
conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding),
|
261 |
-
)
|
262 |
-
|
263 |
-
self.updown = up or down
|
264 |
-
|
265 |
-
if up:
|
266 |
-
self.h_upd = Upsample(channels, False, dims)
|
267 |
-
self.x_upd = Upsample(channels, False, dims)
|
268 |
-
elif down:
|
269 |
-
self.h_upd = Downsample(channels, False, dims)
|
270 |
-
self.x_upd = Downsample(channels, False, dims)
|
271 |
-
else:
|
272 |
-
self.h_upd = self.x_upd = nn.Identity()
|
273 |
-
|
274 |
-
self.skip_t_emb = skip_t_emb
|
275 |
-
self.emb_out_channels = (
|
276 |
-
2 * self.out_channels if use_scale_shift_norm else self.out_channels
|
277 |
-
)
|
278 |
-
if self.skip_t_emb:
|
279 |
-
logpy.info(f"Skipping timestep embedding in {self.__class__.__name__}")
|
280 |
-
assert not self.use_scale_shift_norm
|
281 |
-
self.emb_layers = None
|
282 |
-
self.exchange_temb_dims = False
|
283 |
-
else:
|
284 |
-
self.emb_layers = nn.Sequential(
|
285 |
-
nn.SiLU(),
|
286 |
-
linear(
|
287 |
-
emb_channels,
|
288 |
-
self.emb_out_channels,
|
289 |
-
),
|
290 |
-
)
|
291 |
-
|
292 |
-
self.out_layers = nn.Sequential(
|
293 |
-
normalization(self.out_channels),
|
294 |
-
nn.SiLU(),
|
295 |
-
nn.Dropout(p=dropout),
|
296 |
-
zero_module(
|
297 |
-
conv_nd(
|
298 |
-
dims,
|
299 |
-
self.out_channels,
|
300 |
-
self.out_channels,
|
301 |
-
kernel_size,
|
302 |
-
padding=padding,
|
303 |
-
)
|
304 |
-
),
|
305 |
-
)
|
306 |
-
|
307 |
-
if self.out_channels == channels:
|
308 |
-
self.skip_connection = nn.Identity()
|
309 |
-
elif use_conv:
|
310 |
-
self.skip_connection = conv_nd(
|
311 |
-
dims, channels, self.out_channels, kernel_size, padding=padding
|
312 |
-
)
|
313 |
-
else:
|
314 |
-
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
|
315 |
-
|
316 |
-
def forward(self, x: th.Tensor, emb: th.Tensor) -> th.Tensor:
|
317 |
-
"""
|
318 |
-
Apply the block to a Tensor, conditioned on a timestep embedding.
|
319 |
-
:param x: an [N x C x ...] Tensor of features.
|
320 |
-
:param emb: an [N x emb_channels] Tensor of timestep embeddings.
|
321 |
-
:return: an [N x C x ...] Tensor of outputs.
|
322 |
-
"""
|
323 |
-
if self.use_checkpoint:
|
324 |
-
return checkpoint(self._forward, x, emb)
|
325 |
-
else:
|
326 |
-
return self._forward(x, emb)
|
327 |
-
|
328 |
-
def _forward(self, x: th.Tensor, emb: th.Tensor) -> th.Tensor:
|
329 |
-
if self.updown:
|
330 |
-
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
|
331 |
-
h = in_rest(x)
|
332 |
-
h = self.h_upd(h)
|
333 |
-
x = self.x_upd(x)
|
334 |
-
h = in_conv(h)
|
335 |
-
else:
|
336 |
-
h = self.in_layers(x)
|
337 |
-
|
338 |
-
if self.skip_t_emb:
|
339 |
-
emb_out = th.zeros_like(h)
|
340 |
-
else:
|
341 |
-
emb_out = self.emb_layers(emb).type(h.dtype)
|
342 |
-
while len(emb_out.shape) < len(h.shape):
|
343 |
-
emb_out = emb_out[..., None]
|
344 |
-
if self.use_scale_shift_norm:
|
345 |
-
out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
|
346 |
-
scale, shift = th.chunk(emb_out, 2, dim=1)
|
347 |
-
h = out_norm(h) * (1 + scale) + shift
|
348 |
-
h = out_rest(h)
|
349 |
-
else:
|
350 |
-
if self.exchange_temb_dims:
|
351 |
-
emb_out = rearrange(emb_out, "b t c ... -> b c t ...")
|
352 |
-
h = h + emb_out
|
353 |
-
h = self.out_layers(h)
|
354 |
-
return self.skip_connection(x) + h
|
355 |
-
|
356 |
-
|
357 |
-
class AttentionBlock(nn.Module):
|
358 |
-
"""
|
359 |
-
An attention block that allows spatial positions to attend to each other.
|
360 |
-
Originally ported from here, but adapted to the N-d case.
|
361 |
-
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
|
362 |
-
"""
|
363 |
-
|
364 |
-
def __init__(
|
365 |
-
self,
|
366 |
-
channels: int,
|
367 |
-
num_heads: int = 1,
|
368 |
-
num_head_channels: int = -1,
|
369 |
-
use_checkpoint: bool = False,
|
370 |
-
use_new_attention_order: bool = False,
|
371 |
-
):
|
372 |
-
super().__init__()
|
373 |
-
self.channels = channels
|
374 |
-
if num_head_channels == -1:
|
375 |
-
self.num_heads = num_heads
|
376 |
-
else:
|
377 |
-
assert (
|
378 |
-
channels % num_head_channels == 0
|
379 |
-
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
|
380 |
-
self.num_heads = channels // num_head_channels
|
381 |
-
self.use_checkpoint = use_checkpoint
|
382 |
-
self.norm = normalization(channels)
|
383 |
-
self.qkv = conv_nd(1, channels, channels * 3, 1)
|
384 |
-
if use_new_attention_order:
|
385 |
-
# split qkv before split heads
|
386 |
-
self.attention = QKVAttention(self.num_heads)
|
387 |
-
else:
|
388 |
-
# split heads before split qkv
|
389 |
-
self.attention = QKVAttentionLegacy(self.num_heads)
|
390 |
-
|
391 |
-
self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
|
392 |
-
|
393 |
-
def forward(self, x: th.Tensor, **kwargs) -> th.Tensor:
|
394 |
-
return checkpoint(self._forward, x)
|
395 |
-
|
396 |
-
def _forward(self, x: th.Tensor) -> th.Tensor:
|
397 |
-
b, c, *spatial = x.shape
|
398 |
-
x = x.reshape(b, c, -1)
|
399 |
-
qkv = self.qkv(self.norm(x))
|
400 |
-
h = self.attention(qkv)
|
401 |
-
h = self.proj_out(h)
|
402 |
-
return (x + h).reshape(b, c, *spatial)
|
403 |
-
|
404 |
-
|
405 |
-
class QKVAttentionLegacy(nn.Module):
|
406 |
-
"""
|
407 |
-
A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
|
408 |
-
"""
|
409 |
-
|
410 |
-
def __init__(self, n_heads: int):
|
411 |
-
super().__init__()
|
412 |
-
self.n_heads = n_heads
|
413 |
-
|
414 |
-
def forward(self, qkv: th.Tensor) -> th.Tensor:
|
415 |
-
"""
|
416 |
-
Apply QKV attention.
|
417 |
-
:param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
|
418 |
-
:return: an [N x (H * C) x T] tensor after attention.
|
419 |
-
"""
|
420 |
-
bs, width, length = qkv.shape
|
421 |
-
assert width % (3 * self.n_heads) == 0
|
422 |
-
ch = width // (3 * self.n_heads)
|
423 |
-
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
|
424 |
-
scale = 1 / math.sqrt(math.sqrt(ch))
|
425 |
-
weight = th.einsum(
|
426 |
-
"bct,bcs->bts", q * scale, k * scale
|
427 |
-
) # More stable with f16 than dividing afterwards
|
428 |
-
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
|
429 |
-
a = th.einsum("bts,bcs->bct", weight, v)
|
430 |
-
return a.reshape(bs, -1, length)
|
431 |
-
|
432 |
-
|
433 |
-
class QKVAttention(nn.Module):
|
434 |
-
"""
|
435 |
-
A module which performs QKV attention and splits in a different order.
|
436 |
-
"""
|
437 |
-
|
438 |
-
def __init__(self, n_heads: int):
|
439 |
-
super().__init__()
|
440 |
-
self.n_heads = n_heads
|
441 |
-
|
442 |
-
def forward(self, qkv: th.Tensor) -> th.Tensor:
|
443 |
-
"""
|
444 |
-
Apply QKV attention.
|
445 |
-
:param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
|
446 |
-
:return: an [N x (H * C) x T] tensor after attention.
|
447 |
-
"""
|
448 |
-
bs, width, length = qkv.shape
|
449 |
-
assert width % (3 * self.n_heads) == 0
|
450 |
-
ch = width // (3 * self.n_heads)
|
451 |
-
q, k, v = qkv.chunk(3, dim=1)
|
452 |
-
scale = 1 / math.sqrt(math.sqrt(ch))
|
453 |
-
weight = th.einsum(
|
454 |
-
"bct,bcs->bts",
|
455 |
-
(q * scale).view(bs * self.n_heads, ch, length),
|
456 |
-
(k * scale).view(bs * self.n_heads, ch, length),
|
457 |
-
) # More stable with f16 than dividing afterwards
|
458 |
-
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
|
459 |
-
a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
|
460 |
-
return a.reshape(bs, -1, length)
|
461 |
-
|
462 |
-
|
463 |
-
class Timestep(nn.Module):
|
464 |
-
def __init__(self, dim: int):
|
465 |
-
super().__init__()
|
466 |
-
self.dim = dim
|
467 |
-
|
468 |
-
def forward(self, t: th.Tensor) -> th.Tensor:
|
469 |
-
return timestep_embedding(t, self.dim)
|
470 |
-
|
471 |
-
|
472 |
-
class UNetModel(nn.Module):
|
473 |
-
"""
|
474 |
-
The full UNet model with attention and timestep embedding.
|
475 |
-
:param in_channels: channels in the input Tensor.
|
476 |
-
:param model_channels: base channel count for the model.
|
477 |
-
:param out_channels: channels in the output Tensor.
|
478 |
-
:param num_res_blocks: number of residual blocks per downsample.
|
479 |
-
:param attention_resolutions: a collection of downsample rates at which
|
480 |
-
attention will take place. May be a set, list, or tuple.
|
481 |
-
For example, if this contains 4, then at 4x downsampling, attention
|
482 |
-
will be used.
|
483 |
-
:param dropout: the dropout probability.
|
484 |
-
:param channel_mult: channel multiplier for each level of the UNet.
|
485 |
-
:param conv_resample: if True, use learned convolutions for upsampling and
|
486 |
-
downsampling.
|
487 |
-
:param dims: determines if the signal is 1D, 2D, or 3D.
|
488 |
-
:param num_classes: if specified (as an int), then this model will be
|
489 |
-
class-conditional with `num_classes` classes.
|
490 |
-
:param use_checkpoint: use gradient checkpointing to reduce memory usage.
|
491 |
-
:param num_heads: the number of attention heads in each attention layer.
|
492 |
-
:param num_heads_channels: if specified, ignore num_heads and instead use
|
493 |
-
a fixed channel width per attention head.
|
494 |
-
:param num_heads_upsample: works with num_heads to set a different number
|
495 |
-
of heads for upsampling. Deprecated.
|
496 |
-
:param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
|
497 |
-
:param resblock_updown: use residual blocks for up/downsampling.
|
498 |
-
:param use_new_attention_order: use a different attention pattern for potentially
|
499 |
-
increased efficiency.
|
500 |
-
"""
|
501 |
-
|
502 |
-
def __init__(
|
503 |
-
self,
|
504 |
-
in_channels: int,
|
505 |
-
model_channels: int,
|
506 |
-
out_channels: int,
|
507 |
-
num_res_blocks: int,
|
508 |
-
attention_resolutions: int,
|
509 |
-
dropout: float = 0.0,
|
510 |
-
channel_mult: Union[List, Tuple] = (1, 2, 4, 8),
|
511 |
-
conv_resample: bool = True,
|
512 |
-
dims: int = 2,
|
513 |
-
num_classes: Optional[Union[int, str]] = None,
|
514 |
-
use_checkpoint: bool = False,
|
515 |
-
num_heads: int = -1,
|
516 |
-
num_head_channels: int = -1,
|
517 |
-
num_heads_upsample: int = -1,
|
518 |
-
use_scale_shift_norm: bool = False,
|
519 |
-
resblock_updown: bool = False,
|
520 |
-
transformer_depth: int = 1,
|
521 |
-
context_dim: Optional[int] = None,
|
522 |
-
disable_self_attentions: Optional[List[bool]] = None,
|
523 |
-
num_attention_blocks: Optional[List[int]] = None,
|
524 |
-
disable_middle_self_attn: bool = False,
|
525 |
-
disable_middle_transformer: bool = False,
|
526 |
-
use_linear_in_transformer: bool = False,
|
527 |
-
spatial_transformer_attn_type: str = "softmax",
|
528 |
-
adm_in_channels: Optional[int] = None,
|
529 |
-
):
|
530 |
-
super().__init__()
|
531 |
-
|
532 |
-
if num_heads_upsample == -1:
|
533 |
-
num_heads_upsample = num_heads
|
534 |
-
|
535 |
-
if num_heads == -1:
|
536 |
-
assert (
|
537 |
-
num_head_channels != -1
|
538 |
-
), "Either num_heads or num_head_channels has to be set"
|
539 |
-
|
540 |
-
if num_head_channels == -1:
|
541 |
-
assert (
|
542 |
-
num_heads != -1
|
543 |
-
), "Either num_heads or num_head_channels has to be set"
|
544 |
-
|
545 |
-
self.in_channels = in_channels
|
546 |
-
self.model_channels = model_channels
|
547 |
-
self.out_channels = out_channels
|
548 |
-
if isinstance(transformer_depth, int):
|
549 |
-
transformer_depth = len(channel_mult) * [transformer_depth]
|
550 |
-
transformer_depth_middle = transformer_depth[-1]
|
551 |
-
|
552 |
-
if isinstance(num_res_blocks, int):
|
553 |
-
self.num_res_blocks = len(channel_mult) * [num_res_blocks]
|
554 |
-
else:
|
555 |
-
if len(num_res_blocks) != len(channel_mult):
|
556 |
-
raise ValueError(
|
557 |
-
"provide num_res_blocks either as an int (globally constant) or "
|
558 |
-
"as a list/tuple (per-level) with the same length as channel_mult"
|
559 |
-
)
|
560 |
-
self.num_res_blocks = num_res_blocks
|
561 |
-
|
562 |
-
if disable_self_attentions is not None:
|
563 |
-
assert len(disable_self_attentions) == len(channel_mult)
|
564 |
-
if num_attention_blocks is not None:
|
565 |
-
assert len(num_attention_blocks) == len(self.num_res_blocks)
|
566 |
-
assert all(
|
567 |
-
map(
|
568 |
-
lambda i: self.num_res_blocks[i] >= num_attention_blocks[i],
|
569 |
-
range(len(num_attention_blocks)),
|
570 |
-
)
|
571 |
-
)
|
572 |
-
logpy.info(
|
573 |
-
f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
|
574 |
-
f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
|
575 |
-
f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
|
576 |
-
f"attention will still not be set."
|
577 |
-
)
|
578 |
-
|
579 |
-
self.attention_resolutions = attention_resolutions
|
580 |
-
self.dropout = dropout
|
581 |
-
self.channel_mult = channel_mult
|
582 |
-
self.conv_resample = conv_resample
|
583 |
-
self.num_classes = num_classes
|
584 |
-
self.use_checkpoint = use_checkpoint
|
585 |
-
self.num_heads = num_heads
|
586 |
-
self.num_head_channels = num_head_channels
|
587 |
-
self.num_heads_upsample = num_heads_upsample
|
588 |
-
|
589 |
-
time_embed_dim = model_channels * 4
|
590 |
-
self.time_embed = nn.Sequential(
|
591 |
-
linear(model_channels, time_embed_dim),
|
592 |
-
nn.SiLU(),
|
593 |
-
linear(time_embed_dim, time_embed_dim),
|
594 |
-
)
|
595 |
-
|
596 |
-
if self.num_classes is not None:
|
597 |
-
if isinstance(self.num_classes, int):
|
598 |
-
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
|
599 |
-
elif self.num_classes == "continuous":
|
600 |
-
logpy.info("setting up linear c_adm embedding layer")
|
601 |
-
self.label_emb = nn.Linear(1, time_embed_dim)
|
602 |
-
elif self.num_classes == "timestep":
|
603 |
-
self.label_emb = nn.Sequential(
|
604 |
-
Timestep(model_channels),
|
605 |
-
nn.Sequential(
|
606 |
-
linear(model_channels, time_embed_dim),
|
607 |
-
nn.SiLU(),
|
608 |
-
linear(time_embed_dim, time_embed_dim),
|
609 |
-
),
|
610 |
-
)
|
611 |
-
elif self.num_classes == "sequential":
|
612 |
-
assert adm_in_channels is not None
|
613 |
-
self.label_emb = nn.Sequential(
|
614 |
-
nn.Sequential(
|
615 |
-
linear(adm_in_channels, time_embed_dim),
|
616 |
-
nn.SiLU(),
|
617 |
-
linear(time_embed_dim, time_embed_dim),
|
618 |
-
)
|
619 |
-
)
|
620 |
-
else:
|
621 |
-
raise ValueError
|
622 |
-
|
623 |
-
self.input_blocks = nn.ModuleList(
|
624 |
-
[
|
625 |
-
TimestepEmbedSequential(
|
626 |
-
conv_nd(dims, in_channels, model_channels, 3, padding=1)
|
627 |
-
)
|
628 |
-
]
|
629 |
-
)
|
630 |
-
self._feature_size = model_channels
|
631 |
-
input_block_chans = [model_channels]
|
632 |
-
ch = model_channels
|
633 |
-
ds = 1
|
634 |
-
for level, mult in enumerate(channel_mult):
|
635 |
-
for nr in range(self.num_res_blocks[level]):
|
636 |
-
layers = [
|
637 |
-
ResBlock(
|
638 |
-
ch,
|
639 |
-
time_embed_dim,
|
640 |
-
dropout,
|
641 |
-
out_channels=mult * model_channels,
|
642 |
-
dims=dims,
|
643 |
-
use_checkpoint=use_checkpoint,
|
644 |
-
use_scale_shift_norm=use_scale_shift_norm,
|
645 |
-
)
|
646 |
-
]
|
647 |
-
ch = mult * model_channels
|
648 |
-
if ds in attention_resolutions:
|
649 |
-
if num_head_channels == -1:
|
650 |
-
dim_head = ch // num_heads
|
651 |
-
else:
|
652 |
-
num_heads = ch // num_head_channels
|
653 |
-
dim_head = num_head_channels
|
654 |
-
|
655 |
-
if context_dim is not None and exists(disable_self_attentions):
|
656 |
-
disabled_sa = disable_self_attentions[level]
|
657 |
-
else:
|
658 |
-
disabled_sa = False
|
659 |
-
|
660 |
-
if (
|
661 |
-
not exists(num_attention_blocks)
|
662 |
-
or nr < num_attention_blocks[level]
|
663 |
-
):
|
664 |
-
layers.append(
|
665 |
-
SpatialTransformer(
|
666 |
-
ch,
|
667 |
-
num_heads,
|
668 |
-
dim_head,
|
669 |
-
depth=transformer_depth[level],
|
670 |
-
context_dim=context_dim,
|
671 |
-
disable_self_attn=disabled_sa,
|
672 |
-
use_linear=use_linear_in_transformer,
|
673 |
-
attn_type=spatial_transformer_attn_type,
|
674 |
-
use_checkpoint=use_checkpoint,
|
675 |
-
)
|
676 |
-
)
|
677 |
-
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
678 |
-
self._feature_size += ch
|
679 |
-
input_block_chans.append(ch)
|
680 |
-
if level != len(channel_mult) - 1:
|
681 |
-
out_ch = ch
|
682 |
-
self.input_blocks.append(
|
683 |
-
TimestepEmbedSequential(
|
684 |
-
ResBlock(
|
685 |
-
ch,
|
686 |
-
time_embed_dim,
|
687 |
-
dropout,
|
688 |
-
out_channels=out_ch,
|
689 |
-
dims=dims,
|
690 |
-
use_checkpoint=use_checkpoint,
|
691 |
-
use_scale_shift_norm=use_scale_shift_norm,
|
692 |
-
down=True,
|
693 |
-
)
|
694 |
-
if resblock_updown
|
695 |
-
else Downsample(
|
696 |
-
ch, conv_resample, dims=dims, out_channels=out_ch
|
697 |
-
)
|
698 |
-
)
|
699 |
-
)
|
700 |
-
ch = out_ch
|
701 |
-
input_block_chans.append(ch)
|
702 |
-
ds *= 2
|
703 |
-
self._feature_size += ch
|
704 |
-
|
705 |
-
if num_head_channels == -1:
|
706 |
-
dim_head = ch // num_heads
|
707 |
-
else:
|
708 |
-
num_heads = ch // num_head_channels
|
709 |
-
dim_head = num_head_channels
|
710 |
-
|
711 |
-
self.middle_block = TimestepEmbedSequential(
|
712 |
-
ResBlock(
|
713 |
-
ch,
|
714 |
-
time_embed_dim,
|
715 |
-
dropout,
|
716 |
-
out_channels=ch,
|
717 |
-
dims=dims,
|
718 |
-
use_checkpoint=use_checkpoint,
|
719 |
-
use_scale_shift_norm=use_scale_shift_norm,
|
720 |
-
),
|
721 |
-
SpatialTransformer(
|
722 |
-
ch,
|
723 |
-
num_heads,
|
724 |
-
dim_head,
|
725 |
-
depth=transformer_depth_middle,
|
726 |
-
context_dim=context_dim,
|
727 |
-
disable_self_attn=disable_middle_self_attn,
|
728 |
-
use_linear=use_linear_in_transformer,
|
729 |
-
attn_type=spatial_transformer_attn_type,
|
730 |
-
use_checkpoint=use_checkpoint,
|
731 |
-
)
|
732 |
-
if not disable_middle_transformer
|
733 |
-
else th.nn.Identity(),
|
734 |
-
ResBlock(
|
735 |
-
ch,
|
736 |
-
time_embed_dim,
|
737 |
-
dropout,
|
738 |
-
dims=dims,
|
739 |
-
use_checkpoint=use_checkpoint,
|
740 |
-
use_scale_shift_norm=use_scale_shift_norm,
|
741 |
-
),
|
742 |
-
)
|
743 |
-
self._feature_size += ch
|
744 |
-
|
745 |
-
self.output_blocks = nn.ModuleList([])
|
746 |
-
for level, mult in list(enumerate(channel_mult))[::-1]:
|
747 |
-
for i in range(self.num_res_blocks[level] + 1):
|
748 |
-
ich = input_block_chans.pop()
|
749 |
-
layers = [
|
750 |
-
ResBlock(
|
751 |
-
ch + ich,
|
752 |
-
time_embed_dim,
|
753 |
-
dropout,
|
754 |
-
out_channels=model_channels * mult,
|
755 |
-
dims=dims,
|
756 |
-
use_checkpoint=use_checkpoint,
|
757 |
-
use_scale_shift_norm=use_scale_shift_norm,
|
758 |
-
)
|
759 |
-
]
|
760 |
-
ch = model_channels * mult
|
761 |
-
if ds in attention_resolutions:
|
762 |
-
if num_head_channels == -1:
|
763 |
-
dim_head = ch // num_heads
|
764 |
-
else:
|
765 |
-
num_heads = ch // num_head_channels
|
766 |
-
dim_head = num_head_channels
|
767 |
-
|
768 |
-
if exists(disable_self_attentions):
|
769 |
-
disabled_sa = disable_self_attentions[level]
|
770 |
-
else:
|
771 |
-
disabled_sa = False
|
772 |
-
|
773 |
-
if (
|
774 |
-
not exists(num_attention_blocks)
|
775 |
-
or i < num_attention_blocks[level]
|
776 |
-
):
|
777 |
-
layers.append(
|
778 |
-
SpatialTransformer(
|
779 |
-
ch,
|
780 |
-
num_heads,
|
781 |
-
dim_head,
|
782 |
-
depth=transformer_depth[level],
|
783 |
-
context_dim=context_dim,
|
784 |
-
disable_self_attn=disabled_sa,
|
785 |
-
use_linear=use_linear_in_transformer,
|
786 |
-
attn_type=spatial_transformer_attn_type,
|
787 |
-
use_checkpoint=use_checkpoint,
|
788 |
-
)
|
789 |
-
)
|
790 |
-
if level and i == self.num_res_blocks[level]:
|
791 |
-
out_ch = ch
|
792 |
-
layers.append(
|
793 |
-
ResBlock(
|
794 |
-
ch,
|
795 |
-
time_embed_dim,
|
796 |
-
dropout,
|
797 |
-
out_channels=out_ch,
|
798 |
-
dims=dims,
|
799 |
-
use_checkpoint=use_checkpoint,
|
800 |
-
use_scale_shift_norm=use_scale_shift_norm,
|
801 |
-
up=True,
|
802 |
-
)
|
803 |
-
if resblock_updown
|
804 |
-
else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
|
805 |
-
)
|
806 |
-
ds //= 2
|
807 |
-
self.output_blocks.append(TimestepEmbedSequential(*layers))
|
808 |
-
self._feature_size += ch
|
809 |
-
|
810 |
-
self.out = nn.Sequential(
|
811 |
-
normalization(ch),
|
812 |
-
nn.SiLU(),
|
813 |
-
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
|
814 |
-
)
|
815 |
-
|
816 |
-
def forward(
|
817 |
-
self,
|
818 |
-
x: th.Tensor,
|
819 |
-
timesteps: Optional[th.Tensor] = None,
|
820 |
-
context: Optional[th.Tensor] = None,
|
821 |
-
y: Optional[th.Tensor] = None,
|
822 |
-
**kwargs,
|
823 |
-
) -> th.Tensor:
|
824 |
-
"""
|
825 |
-
Apply the model to an input batch.
|
826 |
-
:param x: an [N x C x ...] Tensor of inputs.
|
827 |
-
:param timesteps: a 1-D batch of timesteps.
|
828 |
-
:param context: conditioning plugged in via crossattn
|
829 |
-
:param y: an [N] Tensor of labels, if class-conditional.
|
830 |
-
:return: an [N x C x ...] Tensor of outputs.
|
831 |
-
"""
|
832 |
-
assert (y is not None) == (
|
833 |
-
self.num_classes is not None
|
834 |
-
), "must specify y if and only if the model is class-conditional"
|
835 |
-
hs = []
|
836 |
-
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
|
837 |
-
emb = self.time_embed(t_emb)
|
838 |
-
|
839 |
-
if self.num_classes is not None:
|
840 |
-
assert y.shape[0] == x.shape[0]
|
841 |
-
emb = emb + self.label_emb(y)
|
842 |
-
|
843 |
-
h = x
|
844 |
-
for module in self.input_blocks:
|
845 |
-
h = module(h, emb, context)
|
846 |
-
hs.append(h)
|
847 |
-
h = self.middle_block(h, emb, context)
|
848 |
-
for module in self.output_blocks:
|
849 |
-
h = th.cat([h, hs.pop()], dim=1)
|
850 |
-
h = module(h, emb, context)
|
851 |
-
h = h.type(x.dtype)
|
852 |
-
|
853 |
-
return self.out(h)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sgm/modules/diffusionmodules/sampling.py
DELETED
@@ -1,362 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
Partially ported from https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py
|
3 |
-
"""
|
4 |
-
|
5 |
-
|
6 |
-
from typing import Dict, Union
|
7 |
-
|
8 |
-
import torch
|
9 |
-
from omegaconf import ListConfig, OmegaConf
|
10 |
-
from tqdm import tqdm
|
11 |
-
|
12 |
-
from ...modules.diffusionmodules.sampling_utils import (get_ancestral_step,
|
13 |
-
linear_multistep_coeff,
|
14 |
-
to_d, to_neg_log_sigma,
|
15 |
-
to_sigma)
|
16 |
-
from ...util import append_dims, default, instantiate_from_config
|
17 |
-
|
18 |
-
DEFAULT_GUIDER = {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"}
|
19 |
-
|
20 |
-
|
21 |
-
class BaseDiffusionSampler:
|
22 |
-
def __init__(
|
23 |
-
self,
|
24 |
-
discretization_config: Union[Dict, ListConfig, OmegaConf],
|
25 |
-
num_steps: Union[int, None] = None,
|
26 |
-
guider_config: Union[Dict, ListConfig, OmegaConf, None] = None,
|
27 |
-
verbose: bool = False,
|
28 |
-
device: str = "cuda",
|
29 |
-
):
|
30 |
-
self.num_steps = num_steps
|
31 |
-
self.discretization = instantiate_from_config(discretization_config)
|
32 |
-
self.guider = instantiate_from_config(
|
33 |
-
default(
|
34 |
-
guider_config,
|
35 |
-
DEFAULT_GUIDER,
|
36 |
-
)
|
37 |
-
)
|
38 |
-
self.verbose = verbose
|
39 |
-
self.device = device
|
40 |
-
|
41 |
-
def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None):
|
42 |
-
sigmas = self.discretization(
|
43 |
-
self.num_steps if num_steps is None else num_steps, device=self.device
|
44 |
-
)
|
45 |
-
uc = default(uc, cond)
|
46 |
-
|
47 |
-
x *= torch.sqrt(1.0 + sigmas[0] ** 2.0)
|
48 |
-
num_sigmas = len(sigmas)
|
49 |
-
|
50 |
-
s_in = x.new_ones([x.shape[0]])
|
51 |
-
|
52 |
-
return x, s_in, sigmas, num_sigmas, cond, uc
|
53 |
-
|
54 |
-
def denoise(self, x, denoiser, sigma, cond, uc):
|
55 |
-
denoised = denoiser(*self.guider.prepare_inputs(x, sigma, cond, uc))
|
56 |
-
denoised = self.guider(denoised, sigma)
|
57 |
-
return denoised
|
58 |
-
|
59 |
-
def get_sigma_gen(self, num_sigmas):
|
60 |
-
sigma_generator = range(num_sigmas - 1)
|
61 |
-
if self.verbose:
|
62 |
-
print("#" * 30, " Sampling setting ", "#" * 30)
|
63 |
-
print(f"Sampler: {self.__class__.__name__}")
|
64 |
-
print(f"Discretization: {self.discretization.__class__.__name__}")
|
65 |
-
print(f"Guider: {self.guider.__class__.__name__}")
|
66 |
-
sigma_generator = tqdm(
|
67 |
-
sigma_generator,
|
68 |
-
total=num_sigmas,
|
69 |
-
desc=f"Sampling with {self.__class__.__name__} for {num_sigmas} steps",
|
70 |
-
)
|
71 |
-
return sigma_generator
|
72 |
-
|
73 |
-
|
74 |
-
class SingleStepDiffusionSampler(BaseDiffusionSampler):
|
75 |
-
def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc, *args, **kwargs):
|
76 |
-
raise NotImplementedError
|
77 |
-
|
78 |
-
def euler_step(self, x, d, dt):
|
79 |
-
return x + dt * d
|
80 |
-
|
81 |
-
|
82 |
-
class EDMSampler(SingleStepDiffusionSampler):
|
83 |
-
def __init__(
|
84 |
-
self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, *args, **kwargs
|
85 |
-
):
|
86 |
-
super().__init__(*args, **kwargs)
|
87 |
-
|
88 |
-
self.s_churn = s_churn
|
89 |
-
self.s_tmin = s_tmin
|
90 |
-
self.s_tmax = s_tmax
|
91 |
-
self.s_noise = s_noise
|
92 |
-
|
93 |
-
def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, gamma=0.0):
|
94 |
-
sigma_hat = sigma * (gamma + 1.0)
|
95 |
-
if gamma > 0:
|
96 |
-
eps = torch.randn_like(x) * self.s_noise
|
97 |
-
x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5
|
98 |
-
|
99 |
-
denoised = self.denoise(x, denoiser, sigma_hat, cond, uc)
|
100 |
-
d = to_d(x, sigma_hat, denoised)
|
101 |
-
dt = append_dims(next_sigma - sigma_hat, x.ndim)
|
102 |
-
|
103 |
-
euler_step = self.euler_step(x, d, dt)
|
104 |
-
x = self.possible_correction_step(
|
105 |
-
euler_step, x, d, dt, next_sigma, denoiser, cond, uc
|
106 |
-
)
|
107 |
-
return x
|
108 |
-
|
109 |
-
def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
|
110 |
-
x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
|
111 |
-
x, cond, uc, num_steps
|
112 |
-
)
|
113 |
-
|
114 |
-
for i in self.get_sigma_gen(num_sigmas):
|
115 |
-
gamma = (
|
116 |
-
min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1)
|
117 |
-
if self.s_tmin <= sigmas[i] <= self.s_tmax
|
118 |
-
else 0.0
|
119 |
-
)
|
120 |
-
x = self.sampler_step(
|
121 |
-
s_in * sigmas[i],
|
122 |
-
s_in * sigmas[i + 1],
|
123 |
-
denoiser,
|
124 |
-
x,
|
125 |
-
cond,
|
126 |
-
uc,
|
127 |
-
gamma,
|
128 |
-
)
|
129 |
-
|
130 |
-
return x
|
131 |
-
|
132 |
-
|
133 |
-
class AncestralSampler(SingleStepDiffusionSampler):
|
134 |
-
def __init__(self, eta=1.0, s_noise=1.0, *args, **kwargs):
|
135 |
-
super().__init__(*args, **kwargs)
|
136 |
-
|
137 |
-
self.eta = eta
|
138 |
-
self.s_noise = s_noise
|
139 |
-
self.noise_sampler = lambda x: torch.randn_like(x)
|
140 |
-
|
141 |
-
def ancestral_euler_step(self, x, denoised, sigma, sigma_down):
|
142 |
-
d = to_d(x, sigma, denoised)
|
143 |
-
dt = append_dims(sigma_down - sigma, x.ndim)
|
144 |
-
|
145 |
-
return self.euler_step(x, d, dt)
|
146 |
-
|
147 |
-
def ancestral_step(self, x, sigma, next_sigma, sigma_up):
|
148 |
-
x = torch.where(
|
149 |
-
append_dims(next_sigma, x.ndim) > 0.0,
|
150 |
-
x + self.noise_sampler(x) * self.s_noise * append_dims(sigma_up, x.ndim),
|
151 |
-
x,
|
152 |
-
)
|
153 |
-
return x
|
154 |
-
|
155 |
-
def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
|
156 |
-
x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
|
157 |
-
x, cond, uc, num_steps
|
158 |
-
)
|
159 |
-
|
160 |
-
for i in self.get_sigma_gen(num_sigmas):
|
161 |
-
x = self.sampler_step(
|
162 |
-
s_in * sigmas[i],
|
163 |
-
s_in * sigmas[i + 1],
|
164 |
-
denoiser,
|
165 |
-
x,
|
166 |
-
cond,
|
167 |
-
uc,
|
168 |
-
)
|
169 |
-
|
170 |
-
return x
|
171 |
-
|
172 |
-
|
173 |
-
class LinearMultistepSampler(BaseDiffusionSampler):
|
174 |
-
def __init__(
|
175 |
-
self,
|
176 |
-
order=4,
|
177 |
-
*args,
|
178 |
-
**kwargs,
|
179 |
-
):
|
180 |
-
super().__init__(*args, **kwargs)
|
181 |
-
|
182 |
-
self.order = order
|
183 |
-
|
184 |
-
def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs):
|
185 |
-
x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
|
186 |
-
x, cond, uc, num_steps
|
187 |
-
)
|
188 |
-
|
189 |
-
ds = []
|
190 |
-
sigmas_cpu = sigmas.detach().cpu().numpy()
|
191 |
-
for i in self.get_sigma_gen(num_sigmas):
|
192 |
-
sigma = s_in * sigmas[i]
|
193 |
-
denoised = denoiser(
|
194 |
-
*self.guider.prepare_inputs(x, sigma, cond, uc), **kwargs
|
195 |
-
)
|
196 |
-
denoised = self.guider(denoised, sigma)
|
197 |
-
d = to_d(x, sigma, denoised)
|
198 |
-
ds.append(d)
|
199 |
-
if len(ds) > self.order:
|
200 |
-
ds.pop(0)
|
201 |
-
cur_order = min(i + 1, self.order)
|
202 |
-
coeffs = [
|
203 |
-
linear_multistep_coeff(cur_order, sigmas_cpu, i, j)
|
204 |
-
for j in range(cur_order)
|
205 |
-
]
|
206 |
-
x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
|
207 |
-
|
208 |
-
return x
|
209 |
-
|
210 |
-
|
211 |
-
class EulerEDMSampler(EDMSampler):
|
212 |
-
def possible_correction_step(
|
213 |
-
self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc
|
214 |
-
):
|
215 |
-
return euler_step
|
216 |
-
|
217 |
-
|
218 |
-
class HeunEDMSampler(EDMSampler):
|
219 |
-
def possible_correction_step(
|
220 |
-
self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc
|
221 |
-
):
|
222 |
-
if torch.sum(next_sigma) < 1e-14:
|
223 |
-
# Save a network evaluation if all noise levels are 0
|
224 |
-
return euler_step
|
225 |
-
else:
|
226 |
-
denoised = self.denoise(euler_step, denoiser, next_sigma, cond, uc)
|
227 |
-
d_new = to_d(euler_step, next_sigma, denoised)
|
228 |
-
d_prime = (d + d_new) / 2.0
|
229 |
-
|
230 |
-
# apply correction if noise level is not 0
|
231 |
-
x = torch.where(
|
232 |
-
append_dims(next_sigma, x.ndim) > 0.0, x + d_prime * dt, euler_step
|
233 |
-
)
|
234 |
-
return x
|
235 |
-
|
236 |
-
|
237 |
-
class EulerAncestralSampler(AncestralSampler):
|
238 |
-
def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc):
|
239 |
-
sigma_down, sigma_up = get_ancestral_step(sigma, next_sigma, eta=self.eta)
|
240 |
-
denoised = self.denoise(x, denoiser, sigma, cond, uc)
|
241 |
-
x = self.ancestral_euler_step(x, denoised, sigma, sigma_down)
|
242 |
-
x = self.ancestral_step(x, sigma, next_sigma, sigma_up)
|
243 |
-
|
244 |
-
return x
|
245 |
-
|
246 |
-
|
247 |
-
class DPMPP2SAncestralSampler(AncestralSampler):
|
248 |
-
def get_variables(self, sigma, sigma_down):
|
249 |
-
t, t_next = [to_neg_log_sigma(s) for s in (sigma, sigma_down)]
|
250 |
-
h = t_next - t
|
251 |
-
s = t + 0.5 * h
|
252 |
-
return h, s, t, t_next
|
253 |
-
|
254 |
-
def get_mult(self, h, s, t, t_next):
|
255 |
-
mult1 = to_sigma(s) / to_sigma(t)
|
256 |
-
mult2 = (-0.5 * h).expm1()
|
257 |
-
mult3 = to_sigma(t_next) / to_sigma(t)
|
258 |
-
mult4 = (-h).expm1()
|
259 |
-
|
260 |
-
return mult1, mult2, mult3, mult4
|
261 |
-
|
262 |
-
def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, **kwargs):
|
263 |
-
sigma_down, sigma_up = get_ancestral_step(sigma, next_sigma, eta=self.eta)
|
264 |
-
denoised = self.denoise(x, denoiser, sigma, cond, uc)
|
265 |
-
x_euler = self.ancestral_euler_step(x, denoised, sigma, sigma_down)
|
266 |
-
|
267 |
-
if torch.sum(sigma_down) < 1e-14:
|
268 |
-
# Save a network evaluation if all noise levels are 0
|
269 |
-
x = x_euler
|
270 |
-
else:
|
271 |
-
h, s, t, t_next = self.get_variables(sigma, sigma_down)
|
272 |
-
mult = [
|
273 |
-
append_dims(mult, x.ndim) for mult in self.get_mult(h, s, t, t_next)
|
274 |
-
]
|
275 |
-
|
276 |
-
x2 = mult[0] * x - mult[1] * denoised
|
277 |
-
denoised2 = self.denoise(x2, denoiser, to_sigma(s), cond, uc)
|
278 |
-
x_dpmpp2s = mult[2] * x - mult[3] * denoised2
|
279 |
-
|
280 |
-
# apply correction if noise level is not 0
|
281 |
-
x = torch.where(append_dims(sigma_down, x.ndim) > 0.0, x_dpmpp2s, x_euler)
|
282 |
-
|
283 |
-
x = self.ancestral_step(x, sigma, next_sigma, sigma_up)
|
284 |
-
return x
|
285 |
-
|
286 |
-
|
287 |
-
class DPMPP2MSampler(BaseDiffusionSampler):
|
288 |
-
def get_variables(self, sigma, next_sigma, previous_sigma=None):
|
289 |
-
t, t_next = [to_neg_log_sigma(s) for s in (sigma, next_sigma)]
|
290 |
-
h = t_next - t
|
291 |
-
|
292 |
-
if previous_sigma is not None:
|
293 |
-
h_last = t - to_neg_log_sigma(previous_sigma)
|
294 |
-
r = h_last / h
|
295 |
-
return h, r, t, t_next
|
296 |
-
else:
|
297 |
-
return h, None, t, t_next
|
298 |
-
|
299 |
-
def get_mult(self, h, r, t, t_next, previous_sigma):
|
300 |
-
mult1 = to_sigma(t_next) / to_sigma(t)
|
301 |
-
mult2 = (-h).expm1()
|
302 |
-
|
303 |
-
if previous_sigma is not None:
|
304 |
-
mult3 = 1 + 1 / (2 * r)
|
305 |
-
mult4 = 1 / (2 * r)
|
306 |
-
return mult1, mult2, mult3, mult4
|
307 |
-
else:
|
308 |
-
return mult1, mult2
|
309 |
-
|
310 |
-
def sampler_step(
|
311 |
-
self,
|
312 |
-
old_denoised,
|
313 |
-
previous_sigma,
|
314 |
-
sigma,
|
315 |
-
next_sigma,
|
316 |
-
denoiser,
|
317 |
-
x,
|
318 |
-
cond,
|
319 |
-
uc=None,
|
320 |
-
):
|
321 |
-
denoised = self.denoise(x, denoiser, sigma, cond, uc)
|
322 |
-
|
323 |
-
h, r, t, t_next = self.get_variables(sigma, next_sigma, previous_sigma)
|
324 |
-
mult = [
|
325 |
-
append_dims(mult, x.ndim)
|
326 |
-
for mult in self.get_mult(h, r, t, t_next, previous_sigma)
|
327 |
-
]
|
328 |
-
|
329 |
-
x_standard = mult[0] * x - mult[1] * denoised
|
330 |
-
if old_denoised is None or torch.sum(next_sigma) < 1e-14:
|
331 |
-
# Save a network evaluation if all noise levels are 0 or on the first step
|
332 |
-
return x_standard, denoised
|
333 |
-
else:
|
334 |
-
denoised_d = mult[2] * denoised - mult[3] * old_denoised
|
335 |
-
x_advanced = mult[0] * x - mult[1] * denoised_d
|
336 |
-
|
337 |
-
# apply correction if noise level is not 0 and not first step
|
338 |
-
x = torch.where(
|
339 |
-
append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard
|
340 |
-
)
|
341 |
-
|
342 |
-
return x, denoised
|
343 |
-
|
344 |
-
def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs):
|
345 |
-
x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
|
346 |
-
x, cond, uc, num_steps
|
347 |
-
)
|
348 |
-
|
349 |
-
old_denoised = None
|
350 |
-
for i in self.get_sigma_gen(num_sigmas):
|
351 |
-
x, old_denoised = self.sampler_step(
|
352 |
-
old_denoised,
|
353 |
-
None if i == 0 else s_in * sigmas[i - 1],
|
354 |
-
s_in * sigmas[i],
|
355 |
-
s_in * sigmas[i + 1],
|
356 |
-
denoiser,
|
357 |
-
x,
|
358 |
-
cond,
|
359 |
-
uc=uc,
|
360 |
-
)
|
361 |
-
|
362 |
-
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sgm/modules/diffusionmodules/sampling_utils.py
DELETED
@@ -1,43 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
from scipy import integrate
|
3 |
-
|
4 |
-
from ...util import append_dims
|
5 |
-
|
6 |
-
|
7 |
-
def linear_multistep_coeff(order, t, i, j, epsrel=1e-4):
|
8 |
-
if order - 1 > i:
|
9 |
-
raise ValueError(f"Order {order} too high for step {i}")
|
10 |
-
|
11 |
-
def fn(tau):
|
12 |
-
prod = 1.0
|
13 |
-
for k in range(order):
|
14 |
-
if j == k:
|
15 |
-
continue
|
16 |
-
prod *= (tau - t[i - k]) / (t[i - j] - t[i - k])
|
17 |
-
return prod
|
18 |
-
|
19 |
-
return integrate.quad(fn, t[i], t[i + 1], epsrel=epsrel)[0]
|
20 |
-
|
21 |
-
|
22 |
-
def get_ancestral_step(sigma_from, sigma_to, eta=1.0):
|
23 |
-
if not eta:
|
24 |
-
return sigma_to, 0.0
|
25 |
-
sigma_up = torch.minimum(
|
26 |
-
sigma_to,
|
27 |
-
eta
|
28 |
-
* (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5,
|
29 |
-
)
|
30 |
-
sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
|
31 |
-
return sigma_down, sigma_up
|
32 |
-
|
33 |
-
|
34 |
-
def to_d(x, sigma, denoised):
|
35 |
-
return (x - denoised) / append_dims(sigma, x.ndim)
|
36 |
-
|
37 |
-
|
38 |
-
def to_neg_log_sigma(sigma):
|
39 |
-
return sigma.log().neg()
|
40 |
-
|
41 |
-
|
42 |
-
def to_sigma(neg_log_sigma):
|
43 |
-
return neg_log_sigma.neg().exp()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sgm/modules/diffusionmodules/sigma_sampling.py
DELETED
@@ -1,31 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
|
3 |
-
from ...util import default, instantiate_from_config
|
4 |
-
|
5 |
-
|
6 |
-
class EDMSampling:
|
7 |
-
def __init__(self, p_mean=-1.2, p_std=1.2):
|
8 |
-
self.p_mean = p_mean
|
9 |
-
self.p_std = p_std
|
10 |
-
|
11 |
-
def __call__(self, n_samples, rand=None):
|
12 |
-
log_sigma = self.p_mean + self.p_std * default(rand, torch.randn((n_samples,)))
|
13 |
-
return log_sigma.exp()
|
14 |
-
|
15 |
-
|
16 |
-
class DiscreteSampling:
|
17 |
-
def __init__(self, discretization_config, num_idx, do_append_zero=False, flip=True):
|
18 |
-
self.num_idx = num_idx
|
19 |
-
self.sigmas = instantiate_from_config(discretization_config)(
|
20 |
-
num_idx, do_append_zero=do_append_zero, flip=flip
|
21 |
-
)
|
22 |
-
|
23 |
-
def idx_to_sigma(self, idx):
|
24 |
-
return self.sigmas[idx]
|
25 |
-
|
26 |
-
def __call__(self, n_samples, rand=None):
|
27 |
-
idx = default(
|
28 |
-
rand,
|
29 |
-
torch.randint(0, self.num_idx, (n_samples,)),
|
30 |
-
)
|
31 |
-
return self.idx_to_sigma(idx)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sgm/modules/diffusionmodules/util.py
DELETED
@@ -1,369 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
partially adopted from
|
3 |
-
https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
|
4 |
-
and
|
5 |
-
https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
|
6 |
-
and
|
7 |
-
https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
|
8 |
-
|
9 |
-
thanks!
|
10 |
-
"""
|
11 |
-
|
12 |
-
import math
|
13 |
-
from typing import Optional
|
14 |
-
|
15 |
-
import torch
|
16 |
-
import torch.nn as nn
|
17 |
-
from einops import rearrange, repeat
|
18 |
-
|
19 |
-
|
20 |
-
def make_beta_schedule(
|
21 |
-
schedule,
|
22 |
-
n_timestep,
|
23 |
-
linear_start=1e-4,
|
24 |
-
linear_end=2e-2,
|
25 |
-
):
|
26 |
-
if schedule == "linear":
|
27 |
-
betas = (
|
28 |
-
torch.linspace(
|
29 |
-
linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64
|
30 |
-
)
|
31 |
-
** 2
|
32 |
-
)
|
33 |
-
return betas.numpy()
|
34 |
-
|
35 |
-
|
36 |
-
def extract_into_tensor(a, t, x_shape):
|
37 |
-
b, *_ = t.shape
|
38 |
-
out = a.gather(-1, t)
|
39 |
-
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
|
40 |
-
|
41 |
-
|
42 |
-
def mixed_checkpoint(func, inputs: dict, params, flag):
|
43 |
-
"""
|
44 |
-
Evaluate a function without caching intermediate activations, allowing for
|
45 |
-
reduced memory at the expense of extra compute in the backward pass. This differs from the original checkpoint function
|
46 |
-
borrowed from https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py in that
|
47 |
-
it also works with non-tensor inputs
|
48 |
-
:param func: the function to evaluate.
|
49 |
-
:param inputs: the argument dictionary to pass to `func`.
|
50 |
-
:param params: a sequence of parameters `func` depends on but does not
|
51 |
-
explicitly take as arguments.
|
52 |
-
:param flag: if False, disable gradient checkpointing.
|
53 |
-
"""
|
54 |
-
if flag:
|
55 |
-
tensor_keys = [key for key in inputs if isinstance(inputs[key], torch.Tensor)]
|
56 |
-
tensor_inputs = [
|
57 |
-
inputs[key] for key in inputs if isinstance(inputs[key], torch.Tensor)
|
58 |
-
]
|
59 |
-
non_tensor_keys = [
|
60 |
-
key for key in inputs if not isinstance(inputs[key], torch.Tensor)
|
61 |
-
]
|
62 |
-
non_tensor_inputs = [
|
63 |
-
inputs[key] for key in inputs if not isinstance(inputs[key], torch.Tensor)
|
64 |
-
]
|
65 |
-
args = tuple(tensor_inputs) + tuple(non_tensor_inputs) + tuple(params)
|
66 |
-
return MixedCheckpointFunction.apply(
|
67 |
-
func,
|
68 |
-
len(tensor_inputs),
|
69 |
-
len(non_tensor_inputs),
|
70 |
-
tensor_keys,
|
71 |
-
non_tensor_keys,
|
72 |
-
*args,
|
73 |
-
)
|
74 |
-
else:
|
75 |
-
return func(**inputs)
|
76 |
-
|
77 |
-
|
78 |
-
class MixedCheckpointFunction(torch.autograd.Function):
|
79 |
-
@staticmethod
|
80 |
-
def forward(
|
81 |
-
ctx,
|
82 |
-
run_function,
|
83 |
-
length_tensors,
|
84 |
-
length_non_tensors,
|
85 |
-
tensor_keys,
|
86 |
-
non_tensor_keys,
|
87 |
-
*args,
|
88 |
-
):
|
89 |
-
ctx.end_tensors = length_tensors
|
90 |
-
ctx.end_non_tensors = length_tensors + length_non_tensors
|
91 |
-
ctx.gpu_autocast_kwargs = {
|
92 |
-
"enabled": torch.is_autocast_enabled(),
|
93 |
-
"dtype": torch.get_autocast_gpu_dtype(),
|
94 |
-
"cache_enabled": torch.is_autocast_cache_enabled(),
|
95 |
-
}
|
96 |
-
assert (
|
97 |
-
len(tensor_keys) == length_tensors
|
98 |
-
and len(non_tensor_keys) == length_non_tensors
|
99 |
-
)
|
100 |
-
|
101 |
-
ctx.input_tensors = {
|
102 |
-
key: val for (key, val) in zip(tensor_keys, list(args[: ctx.end_tensors]))
|
103 |
-
}
|
104 |
-
ctx.input_non_tensors = {
|
105 |
-
key: val
|
106 |
-
for (key, val) in zip(
|
107 |
-
non_tensor_keys, list(args[ctx.end_tensors : ctx.end_non_tensors])
|
108 |
-
)
|
109 |
-
}
|
110 |
-
ctx.run_function = run_function
|
111 |
-
ctx.input_params = list(args[ctx.end_non_tensors :])
|
112 |
-
|
113 |
-
with torch.no_grad():
|
114 |
-
output_tensors = ctx.run_function(
|
115 |
-
**ctx.input_tensors, **ctx.input_non_tensors
|
116 |
-
)
|
117 |
-
return output_tensors
|
118 |
-
|
119 |
-
@staticmethod
|
120 |
-
def backward(ctx, *output_grads):
|
121 |
-
# additional_args = {key: ctx.input_tensors[key] for key in ctx.input_tensors if not isinstance(ctx.input_tensors[key],torch.Tensor)}
|
122 |
-
ctx.input_tensors = {
|
123 |
-
key: ctx.input_tensors[key].detach().requires_grad_(True)
|
124 |
-
for key in ctx.input_tensors
|
125 |
-
}
|
126 |
-
|
127 |
-
with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
|
128 |
-
# Fixes a bug where the first op in run_function modifies the
|
129 |
-
# Tensor storage in place, which is not allowed for detach()'d
|
130 |
-
# Tensors.
|
131 |
-
shallow_copies = {
|
132 |
-
key: ctx.input_tensors[key].view_as(ctx.input_tensors[key])
|
133 |
-
for key in ctx.input_tensors
|
134 |
-
}
|
135 |
-
# shallow_copies.update(additional_args)
|
136 |
-
output_tensors = ctx.run_function(**shallow_copies, **ctx.input_non_tensors)
|
137 |
-
input_grads = torch.autograd.grad(
|
138 |
-
output_tensors,
|
139 |
-
list(ctx.input_tensors.values()) + ctx.input_params,
|
140 |
-
output_grads,
|
141 |
-
allow_unused=True,
|
142 |
-
)
|
143 |
-
del ctx.input_tensors
|
144 |
-
del ctx.input_params
|
145 |
-
del output_tensors
|
146 |
-
return (
|
147 |
-
(None, None, None, None, None)
|
148 |
-
+ input_grads[: ctx.end_tensors]
|
149 |
-
+ (None,) * (ctx.end_non_tensors - ctx.end_tensors)
|
150 |
-
+ input_grads[ctx.end_tensors :]
|
151 |
-
)
|
152 |
-
|
153 |
-
|
154 |
-
def checkpoint(func, inputs, params, flag):
|
155 |
-
"""
|
156 |
-
Evaluate a function without caching intermediate activations, allowing for
|
157 |
-
reduced memory at the expense of extra compute in the backward pass.
|
158 |
-
:param func: the function to evaluate.
|
159 |
-
:param inputs: the argument sequence to pass to `func`.
|
160 |
-
:param params: a sequence of parameters `func` depends on but does not
|
161 |
-
explicitly take as arguments.
|
162 |
-
:param flag: if False, disable gradient checkpointing.
|
163 |
-
"""
|
164 |
-
if flag:
|
165 |
-
args = tuple(inputs) + tuple(params)
|
166 |
-
return CheckpointFunction.apply(func, len(inputs), *args)
|
167 |
-
else:
|
168 |
-
return func(*inputs)
|
169 |
-
|
170 |
-
|
171 |
-
class CheckpointFunction(torch.autograd.Function):
|
172 |
-
@staticmethod
|
173 |
-
def forward(ctx, run_function, length, *args):
|
174 |
-
ctx.run_function = run_function
|
175 |
-
ctx.input_tensors = list(args[:length])
|
176 |
-
ctx.input_params = list(args[length:])
|
177 |
-
ctx.gpu_autocast_kwargs = {
|
178 |
-
"enabled": torch.is_autocast_enabled(),
|
179 |
-
"dtype": torch.get_autocast_gpu_dtype(),
|
180 |
-
"cache_enabled": torch.is_autocast_cache_enabled(),
|
181 |
-
}
|
182 |
-
with torch.no_grad():
|
183 |
-
output_tensors = ctx.run_function(*ctx.input_tensors)
|
184 |
-
return output_tensors
|
185 |
-
|
186 |
-
@staticmethod
|
187 |
-
def backward(ctx, *output_grads):
|
188 |
-
ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
|
189 |
-
with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
|
190 |
-
# Fixes a bug where the first op in run_function modifies the
|
191 |
-
# Tensor storage in place, which is not allowed for detach()'d
|
192 |
-
# Tensors.
|
193 |
-
shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
|
194 |
-
output_tensors = ctx.run_function(*shallow_copies)
|
195 |
-
input_grads = torch.autograd.grad(
|
196 |
-
output_tensors,
|
197 |
-
ctx.input_tensors + ctx.input_params,
|
198 |
-
output_grads,
|
199 |
-
allow_unused=True,
|
200 |
-
)
|
201 |
-
del ctx.input_tensors
|
202 |
-
del ctx.input_params
|
203 |
-
del output_tensors
|
204 |
-
return (None, None) + input_grads
|
205 |
-
|
206 |
-
|
207 |
-
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
|
208 |
-
"""
|
209 |
-
Create sinusoidal timestep embeddings.
|
210 |
-
:param timesteps: a 1-D Tensor of N indices, one per batch element.
|
211 |
-
These may be fractional.
|
212 |
-
:param dim: the dimension of the output.
|
213 |
-
:param max_period: controls the minimum frequency of the embeddings.
|
214 |
-
:return: an [N x dim] Tensor of positional embeddings.
|
215 |
-
"""
|
216 |
-
if not repeat_only:
|
217 |
-
half = dim // 2
|
218 |
-
freqs = torch.exp(
|
219 |
-
-math.log(max_period)
|
220 |
-
* torch.arange(start=0, end=half, dtype=torch.float32)
|
221 |
-
/ half
|
222 |
-
).to(device=timesteps.device)
|
223 |
-
args = timesteps[:, None].float() * freqs[None]
|
224 |
-
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
225 |
-
if dim % 2:
|
226 |
-
embedding = torch.cat(
|
227 |
-
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
|
228 |
-
)
|
229 |
-
else:
|
230 |
-
embedding = repeat(timesteps, "b -> b d", d=dim)
|
231 |
-
return embedding
|
232 |
-
|
233 |
-
|
234 |
-
def zero_module(module):
|
235 |
-
"""
|
236 |
-
Zero out the parameters of a module and return it.
|
237 |
-
"""
|
238 |
-
for p in module.parameters():
|
239 |
-
p.detach().zero_()
|
240 |
-
return module
|
241 |
-
|
242 |
-
|
243 |
-
def scale_module(module, scale):
|
244 |
-
"""
|
245 |
-
Scale the parameters of a module and return it.
|
246 |
-
"""
|
247 |
-
for p in module.parameters():
|
248 |
-
p.detach().mul_(scale)
|
249 |
-
return module
|
250 |
-
|
251 |
-
|
252 |
-
def mean_flat(tensor):
|
253 |
-
"""
|
254 |
-
Take the mean over all non-batch dimensions.
|
255 |
-
"""
|
256 |
-
return tensor.mean(dim=list(range(1, len(tensor.shape))))
|
257 |
-
|
258 |
-
|
259 |
-
def normalization(channels):
|
260 |
-
"""
|
261 |
-
Make a standard normalization layer.
|
262 |
-
:param channels: number of input channels.
|
263 |
-
:return: an nn.Module for normalization.
|
264 |
-
"""
|
265 |
-
return GroupNorm32(32, channels)
|
266 |
-
|
267 |
-
|
268 |
-
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
|
269 |
-
class SiLU(nn.Module):
|
270 |
-
def forward(self, x):
|
271 |
-
return x * torch.sigmoid(x)
|
272 |
-
|
273 |
-
|
274 |
-
class GroupNorm32(nn.GroupNorm):
|
275 |
-
def forward(self, x):
|
276 |
-
return super().forward(x.float()).type(x.dtype)
|
277 |
-
|
278 |
-
|
279 |
-
def conv_nd(dims, *args, **kwargs):
|
280 |
-
"""
|
281 |
-
Create a 1D, 2D, or 3D convolution module.
|
282 |
-
"""
|
283 |
-
if dims == 1:
|
284 |
-
return nn.Conv1d(*args, **kwargs)
|
285 |
-
elif dims == 2:
|
286 |
-
return nn.Conv2d(*args, **kwargs)
|
287 |
-
elif dims == 3:
|
288 |
-
return nn.Conv3d(*args, **kwargs)
|
289 |
-
raise ValueError(f"unsupported dimensions: {dims}")
|
290 |
-
|
291 |
-
|
292 |
-
def linear(*args, **kwargs):
|
293 |
-
"""
|
294 |
-
Create a linear module.
|
295 |
-
"""
|
296 |
-
return nn.Linear(*args, **kwargs)
|
297 |
-
|
298 |
-
|
299 |
-
def avg_pool_nd(dims, *args, **kwargs):
|
300 |
-
"""
|
301 |
-
Create a 1D, 2D, or 3D average pooling module.
|
302 |
-
"""
|
303 |
-
if dims == 1:
|
304 |
-
return nn.AvgPool1d(*args, **kwargs)
|
305 |
-
elif dims == 2:
|
306 |
-
return nn.AvgPool2d(*args, **kwargs)
|
307 |
-
elif dims == 3:
|
308 |
-
return nn.AvgPool3d(*args, **kwargs)
|
309 |
-
raise ValueError(f"unsupported dimensions: {dims}")
|
310 |
-
|
311 |
-
|
312 |
-
class AlphaBlender(nn.Module):
|
313 |
-
strategies = ["learned", "fixed", "learned_with_images"]
|
314 |
-
|
315 |
-
def __init__(
|
316 |
-
self,
|
317 |
-
alpha: float,
|
318 |
-
merge_strategy: str = "learned_with_images",
|
319 |
-
rearrange_pattern: str = "b t -> (b t) 1 1",
|
320 |
-
):
|
321 |
-
super().__init__()
|
322 |
-
self.merge_strategy = merge_strategy
|
323 |
-
self.rearrange_pattern = rearrange_pattern
|
324 |
-
|
325 |
-
assert (
|
326 |
-
merge_strategy in self.strategies
|
327 |
-
), f"merge_strategy needs to be in {self.strategies}"
|
328 |
-
|
329 |
-
if self.merge_strategy == "fixed":
|
330 |
-
self.register_buffer("mix_factor", torch.Tensor([alpha]))
|
331 |
-
elif (
|
332 |
-
self.merge_strategy == "learned"
|
333 |
-
or self.merge_strategy == "learned_with_images"
|
334 |
-
):
|
335 |
-
self.register_parameter(
|
336 |
-
"mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
|
337 |
-
)
|
338 |
-
else:
|
339 |
-
raise ValueError(f"unknown merge strategy {self.merge_strategy}")
|
340 |
-
|
341 |
-
def get_alpha(self, image_only_indicator: torch.Tensor) -> torch.Tensor:
|
342 |
-
if self.merge_strategy == "fixed":
|
343 |
-
alpha = self.mix_factor
|
344 |
-
elif self.merge_strategy == "learned":
|
345 |
-
alpha = torch.sigmoid(self.mix_factor)
|
346 |
-
elif self.merge_strategy == "learned_with_images":
|
347 |
-
assert image_only_indicator is not None, "need image_only_indicator ..."
|
348 |
-
alpha = torch.where(
|
349 |
-
image_only_indicator.bool(),
|
350 |
-
torch.ones(1, 1, device=image_only_indicator.device),
|
351 |
-
rearrange(torch.sigmoid(self.mix_factor), "... -> ... 1"),
|
352 |
-
)
|
353 |
-
alpha = rearrange(alpha, self.rearrange_pattern)
|
354 |
-
else:
|
355 |
-
raise NotImplementedError
|
356 |
-
return alpha
|
357 |
-
|
358 |
-
def forward(
|
359 |
-
self,
|
360 |
-
x_spatial: torch.Tensor,
|
361 |
-
x_temporal: torch.Tensor,
|
362 |
-
image_only_indicator: Optional[torch.Tensor] = None,
|
363 |
-
) -> torch.Tensor:
|
364 |
-
alpha = self.get_alpha(image_only_indicator)
|
365 |
-
x = (
|
366 |
-
alpha.to(x_spatial.dtype) * x_spatial
|
367 |
-
+ (1.0 - alpha).to(x_spatial.dtype) * x_temporal
|
368 |
-
)
|
369 |
-
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sgm/modules/diffusionmodules/video_model.py
DELETED
@@ -1,493 +0,0 @@
|
|
1 |
-
from functools import partial
|
2 |
-
from typing import List, Optional, Union
|
3 |
-
|
4 |
-
from einops import rearrange
|
5 |
-
|
6 |
-
from ...modules.diffusionmodules.openaimodel import *
|
7 |
-
from ...modules.video_attention import SpatialVideoTransformer
|
8 |
-
from ...util import default
|
9 |
-
from .util import AlphaBlender
|
10 |
-
|
11 |
-
|
12 |
-
class VideoResBlock(ResBlock):
|
13 |
-
def __init__(
|
14 |
-
self,
|
15 |
-
channels: int,
|
16 |
-
emb_channels: int,
|
17 |
-
dropout: float,
|
18 |
-
video_kernel_size: Union[int, List[int]] = 3,
|
19 |
-
merge_strategy: str = "fixed",
|
20 |
-
merge_factor: float = 0.5,
|
21 |
-
out_channels: Optional[int] = None,
|
22 |
-
use_conv: bool = False,
|
23 |
-
use_scale_shift_norm: bool = False,
|
24 |
-
dims: int = 2,
|
25 |
-
use_checkpoint: bool = False,
|
26 |
-
up: bool = False,
|
27 |
-
down: bool = False,
|
28 |
-
):
|
29 |
-
super().__init__(
|
30 |
-
channels,
|
31 |
-
emb_channels,
|
32 |
-
dropout,
|
33 |
-
out_channels=out_channels,
|
34 |
-
use_conv=use_conv,
|
35 |
-
use_scale_shift_norm=use_scale_shift_norm,
|
36 |
-
dims=dims,
|
37 |
-
use_checkpoint=use_checkpoint,
|
38 |
-
up=up,
|
39 |
-
down=down,
|
40 |
-
)
|
41 |
-
|
42 |
-
self.time_stack = ResBlock(
|
43 |
-
default(out_channels, channels),
|
44 |
-
emb_channels,
|
45 |
-
dropout=dropout,
|
46 |
-
dims=3,
|
47 |
-
out_channels=default(out_channels, channels),
|
48 |
-
use_scale_shift_norm=False,
|
49 |
-
use_conv=False,
|
50 |
-
up=False,
|
51 |
-
down=False,
|
52 |
-
kernel_size=video_kernel_size,
|
53 |
-
use_checkpoint=use_checkpoint,
|
54 |
-
exchange_temb_dims=True,
|
55 |
-
)
|
56 |
-
self.time_mixer = AlphaBlender(
|
57 |
-
alpha=merge_factor,
|
58 |
-
merge_strategy=merge_strategy,
|
59 |
-
rearrange_pattern="b t -> b 1 t 1 1",
|
60 |
-
)
|
61 |
-
|
62 |
-
def forward(
|
63 |
-
self,
|
64 |
-
x: th.Tensor,
|
65 |
-
emb: th.Tensor,
|
66 |
-
num_video_frames: int,
|
67 |
-
image_only_indicator: Optional[th.Tensor] = None,
|
68 |
-
) -> th.Tensor:
|
69 |
-
x = super().forward(x, emb)
|
70 |
-
|
71 |
-
x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames)
|
72 |
-
x = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames)
|
73 |
-
|
74 |
-
x = self.time_stack(
|
75 |
-
x, rearrange(emb, "(b t) ... -> b t ...", t=num_video_frames)
|
76 |
-
)
|
77 |
-
x = self.time_mixer(
|
78 |
-
x_spatial=x_mix, x_temporal=x, image_only_indicator=image_only_indicator
|
79 |
-
)
|
80 |
-
x = rearrange(x, "b c t h w -> (b t) c h w")
|
81 |
-
return x
|
82 |
-
|
83 |
-
|
84 |
-
class VideoUNet(nn.Module):
|
85 |
-
def __init__(
|
86 |
-
self,
|
87 |
-
in_channels: int,
|
88 |
-
model_channels: int,
|
89 |
-
out_channels: int,
|
90 |
-
num_res_blocks: int,
|
91 |
-
attention_resolutions: int,
|
92 |
-
dropout: float = 0.0,
|
93 |
-
channel_mult: List[int] = (1, 2, 4, 8),
|
94 |
-
conv_resample: bool = True,
|
95 |
-
dims: int = 2,
|
96 |
-
num_classes: Optional[int] = None,
|
97 |
-
use_checkpoint: bool = False,
|
98 |
-
num_heads: int = -1,
|
99 |
-
num_head_channels: int = -1,
|
100 |
-
num_heads_upsample: int = -1,
|
101 |
-
use_scale_shift_norm: bool = False,
|
102 |
-
resblock_updown: bool = False,
|
103 |
-
transformer_depth: Union[List[int], int] = 1,
|
104 |
-
transformer_depth_middle: Optional[int] = None,
|
105 |
-
context_dim: Optional[int] = None,
|
106 |
-
time_downup: bool = False,
|
107 |
-
time_context_dim: Optional[int] = None,
|
108 |
-
extra_ff_mix_layer: bool = False,
|
109 |
-
use_spatial_context: bool = False,
|
110 |
-
merge_strategy: str = "fixed",
|
111 |
-
merge_factor: float = 0.5,
|
112 |
-
spatial_transformer_attn_type: str = "softmax",
|
113 |
-
video_kernel_size: Union[int, List[int]] = 3,
|
114 |
-
use_linear_in_transformer: bool = False,
|
115 |
-
adm_in_channels: Optional[int] = None,
|
116 |
-
disable_temporal_crossattention: bool = False,
|
117 |
-
max_ddpm_temb_period: int = 10000,
|
118 |
-
):
|
119 |
-
super().__init__()
|
120 |
-
assert context_dim is not None
|
121 |
-
|
122 |
-
if num_heads_upsample == -1:
|
123 |
-
num_heads_upsample = num_heads
|
124 |
-
|
125 |
-
if num_heads == -1:
|
126 |
-
assert num_head_channels != -1
|
127 |
-
|
128 |
-
if num_head_channels == -1:
|
129 |
-
assert num_heads != -1
|
130 |
-
|
131 |
-
self.in_channels = in_channels
|
132 |
-
self.model_channels = model_channels
|
133 |
-
self.out_channels = out_channels
|
134 |
-
if isinstance(transformer_depth, int):
|
135 |
-
transformer_depth = len(channel_mult) * [transformer_depth]
|
136 |
-
transformer_depth_middle = default(
|
137 |
-
transformer_depth_middle, transformer_depth[-1]
|
138 |
-
)
|
139 |
-
|
140 |
-
self.num_res_blocks = num_res_blocks
|
141 |
-
self.attention_resolutions = attention_resolutions
|
142 |
-
self.dropout = dropout
|
143 |
-
self.channel_mult = channel_mult
|
144 |
-
self.conv_resample = conv_resample
|
145 |
-
self.num_classes = num_classes
|
146 |
-
self.use_checkpoint = use_checkpoint
|
147 |
-
self.num_heads = num_heads
|
148 |
-
self.num_head_channels = num_head_channels
|
149 |
-
self.num_heads_upsample = num_heads_upsample
|
150 |
-
|
151 |
-
time_embed_dim = model_channels * 4
|
152 |
-
self.time_embed = nn.Sequential(
|
153 |
-
linear(model_channels, time_embed_dim),
|
154 |
-
nn.SiLU(),
|
155 |
-
linear(time_embed_dim, time_embed_dim),
|
156 |
-
)
|
157 |
-
|
158 |
-
if self.num_classes is not None:
|
159 |
-
if isinstance(self.num_classes, int):
|
160 |
-
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
|
161 |
-
elif self.num_classes == "continuous":
|
162 |
-
print("setting up linear c_adm embedding layer")
|
163 |
-
self.label_emb = nn.Linear(1, time_embed_dim)
|
164 |
-
elif self.num_classes == "timestep":
|
165 |
-
self.label_emb = nn.Sequential(
|
166 |
-
Timestep(model_channels),
|
167 |
-
nn.Sequential(
|
168 |
-
linear(model_channels, time_embed_dim),
|
169 |
-
nn.SiLU(),
|
170 |
-
linear(time_embed_dim, time_embed_dim),
|
171 |
-
),
|
172 |
-
)
|
173 |
-
|
174 |
-
elif self.num_classes == "sequential":
|
175 |
-
assert adm_in_channels is not None
|
176 |
-
self.label_emb = nn.Sequential(
|
177 |
-
nn.Sequential(
|
178 |
-
linear(adm_in_channels, time_embed_dim),
|
179 |
-
nn.SiLU(),
|
180 |
-
linear(time_embed_dim, time_embed_dim),
|
181 |
-
)
|
182 |
-
)
|
183 |
-
else:
|
184 |
-
raise ValueError()
|
185 |
-
|
186 |
-
self.input_blocks = nn.ModuleList(
|
187 |
-
[
|
188 |
-
TimestepEmbedSequential(
|
189 |
-
conv_nd(dims, in_channels, model_channels, 3, padding=1)
|
190 |
-
)
|
191 |
-
]
|
192 |
-
)
|
193 |
-
self._feature_size = model_channels
|
194 |
-
input_block_chans = [model_channels]
|
195 |
-
ch = model_channels
|
196 |
-
ds = 1
|
197 |
-
|
198 |
-
def get_attention_layer(
|
199 |
-
ch,
|
200 |
-
num_heads,
|
201 |
-
dim_head,
|
202 |
-
depth=1,
|
203 |
-
context_dim=None,
|
204 |
-
use_checkpoint=False,
|
205 |
-
disabled_sa=False,
|
206 |
-
):
|
207 |
-
return SpatialVideoTransformer(
|
208 |
-
ch,
|
209 |
-
num_heads,
|
210 |
-
dim_head,
|
211 |
-
depth=depth,
|
212 |
-
context_dim=context_dim,
|
213 |
-
time_context_dim=time_context_dim,
|
214 |
-
dropout=dropout,
|
215 |
-
ff_in=extra_ff_mix_layer,
|
216 |
-
use_spatial_context=use_spatial_context,
|
217 |
-
merge_strategy=merge_strategy,
|
218 |
-
merge_factor=merge_factor,
|
219 |
-
checkpoint=use_checkpoint,
|
220 |
-
use_linear=use_linear_in_transformer,
|
221 |
-
attn_mode=spatial_transformer_attn_type,
|
222 |
-
disable_self_attn=disabled_sa,
|
223 |
-
disable_temporal_crossattention=disable_temporal_crossattention,
|
224 |
-
max_time_embed_period=max_ddpm_temb_period,
|
225 |
-
)
|
226 |
-
|
227 |
-
def get_resblock(
|
228 |
-
merge_factor,
|
229 |
-
merge_strategy,
|
230 |
-
video_kernel_size,
|
231 |
-
ch,
|
232 |
-
time_embed_dim,
|
233 |
-
dropout,
|
234 |
-
out_ch,
|
235 |
-
dims,
|
236 |
-
use_checkpoint,
|
237 |
-
use_scale_shift_norm,
|
238 |
-
down=False,
|
239 |
-
up=False,
|
240 |
-
):
|
241 |
-
return VideoResBlock(
|
242 |
-
merge_factor=merge_factor,
|
243 |
-
merge_strategy=merge_strategy,
|
244 |
-
video_kernel_size=video_kernel_size,
|
245 |
-
channels=ch,
|
246 |
-
emb_channels=time_embed_dim,
|
247 |
-
dropout=dropout,
|
248 |
-
out_channels=out_ch,
|
249 |
-
dims=dims,
|
250 |
-
use_checkpoint=use_checkpoint,
|
251 |
-
use_scale_shift_norm=use_scale_shift_norm,
|
252 |
-
down=down,
|
253 |
-
up=up,
|
254 |
-
)
|
255 |
-
|
256 |
-
for level, mult in enumerate(channel_mult):
|
257 |
-
for _ in range(num_res_blocks):
|
258 |
-
layers = [
|
259 |
-
get_resblock(
|
260 |
-
merge_factor=merge_factor,
|
261 |
-
merge_strategy=merge_strategy,
|
262 |
-
video_kernel_size=video_kernel_size,
|
263 |
-
ch=ch,
|
264 |
-
time_embed_dim=time_embed_dim,
|
265 |
-
dropout=dropout,
|
266 |
-
out_ch=mult * model_channels,
|
267 |
-
dims=dims,
|
268 |
-
use_checkpoint=use_checkpoint,
|
269 |
-
use_scale_shift_norm=use_scale_shift_norm,
|
270 |
-
)
|
271 |
-
]
|
272 |
-
ch = mult * model_channels
|
273 |
-
if ds in attention_resolutions:
|
274 |
-
if num_head_channels == -1:
|
275 |
-
dim_head = ch // num_heads
|
276 |
-
else:
|
277 |
-
num_heads = ch // num_head_channels
|
278 |
-
dim_head = num_head_channels
|
279 |
-
|
280 |
-
layers.append(
|
281 |
-
get_attention_layer(
|
282 |
-
ch,
|
283 |
-
num_heads,
|
284 |
-
dim_head,
|
285 |
-
depth=transformer_depth[level],
|
286 |
-
context_dim=context_dim,
|
287 |
-
use_checkpoint=use_checkpoint,
|
288 |
-
disabled_sa=False,
|
289 |
-
)
|
290 |
-
)
|
291 |
-
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
292 |
-
self._feature_size += ch
|
293 |
-
input_block_chans.append(ch)
|
294 |
-
if level != len(channel_mult) - 1:
|
295 |
-
ds *= 2
|
296 |
-
out_ch = ch
|
297 |
-
self.input_blocks.append(
|
298 |
-
TimestepEmbedSequential(
|
299 |
-
get_resblock(
|
300 |
-
merge_factor=merge_factor,
|
301 |
-
merge_strategy=merge_strategy,
|
302 |
-
video_kernel_size=video_kernel_size,
|
303 |
-
ch=ch,
|
304 |
-
time_embed_dim=time_embed_dim,
|
305 |
-
dropout=dropout,
|
306 |
-
out_ch=out_ch,
|
307 |
-
dims=dims,
|
308 |
-
use_checkpoint=use_checkpoint,
|
309 |
-
use_scale_shift_norm=use_scale_shift_norm,
|
310 |
-
down=True,
|
311 |
-
)
|
312 |
-
if resblock_updown
|
313 |
-
else Downsample(
|
314 |
-
ch,
|
315 |
-
conv_resample,
|
316 |
-
dims=dims,
|
317 |
-
out_channels=out_ch,
|
318 |
-
third_down=time_downup,
|
319 |
-
)
|
320 |
-
)
|
321 |
-
)
|
322 |
-
ch = out_ch
|
323 |
-
input_block_chans.append(ch)
|
324 |
-
|
325 |
-
self._feature_size += ch
|
326 |
-
|
327 |
-
if num_head_channels == -1:
|
328 |
-
dim_head = ch // num_heads
|
329 |
-
else:
|
330 |
-
num_heads = ch // num_head_channels
|
331 |
-
dim_head = num_head_channels
|
332 |
-
|
333 |
-
self.middle_block = TimestepEmbedSequential(
|
334 |
-
get_resblock(
|
335 |
-
merge_factor=merge_factor,
|
336 |
-
merge_strategy=merge_strategy,
|
337 |
-
video_kernel_size=video_kernel_size,
|
338 |
-
ch=ch,
|
339 |
-
time_embed_dim=time_embed_dim,
|
340 |
-
out_ch=None,
|
341 |
-
dropout=dropout,
|
342 |
-
dims=dims,
|
343 |
-
use_checkpoint=use_checkpoint,
|
344 |
-
use_scale_shift_norm=use_scale_shift_norm,
|
345 |
-
),
|
346 |
-
get_attention_layer(
|
347 |
-
ch,
|
348 |
-
num_heads,
|
349 |
-
dim_head,
|
350 |
-
depth=transformer_depth_middle,
|
351 |
-
context_dim=context_dim,
|
352 |
-
use_checkpoint=use_checkpoint,
|
353 |
-
),
|
354 |
-
get_resblock(
|
355 |
-
merge_factor=merge_factor,
|
356 |
-
merge_strategy=merge_strategy,
|
357 |
-
video_kernel_size=video_kernel_size,
|
358 |
-
ch=ch,
|
359 |
-
out_ch=None,
|
360 |
-
time_embed_dim=time_embed_dim,
|
361 |
-
dropout=dropout,
|
362 |
-
dims=dims,
|
363 |
-
use_checkpoint=use_checkpoint,
|
364 |
-
use_scale_shift_norm=use_scale_shift_norm,
|
365 |
-
),
|
366 |
-
)
|
367 |
-
self._feature_size += ch
|
368 |
-
|
369 |
-
self.output_blocks = nn.ModuleList([])
|
370 |
-
for level, mult in list(enumerate(channel_mult))[::-1]:
|
371 |
-
for i in range(num_res_blocks + 1):
|
372 |
-
ich = input_block_chans.pop()
|
373 |
-
layers = [
|
374 |
-
get_resblock(
|
375 |
-
merge_factor=merge_factor,
|
376 |
-
merge_strategy=merge_strategy,
|
377 |
-
video_kernel_size=video_kernel_size,
|
378 |
-
ch=ch + ich,
|
379 |
-
time_embed_dim=time_embed_dim,
|
380 |
-
dropout=dropout,
|
381 |
-
out_ch=model_channels * mult,
|
382 |
-
dims=dims,
|
383 |
-
use_checkpoint=use_checkpoint,
|
384 |
-
use_scale_shift_norm=use_scale_shift_norm,
|
385 |
-
)
|
386 |
-
]
|
387 |
-
ch = model_channels * mult
|
388 |
-
if ds in attention_resolutions:
|
389 |
-
if num_head_channels == -1:
|
390 |
-
dim_head = ch // num_heads
|
391 |
-
else:
|
392 |
-
num_heads = ch // num_head_channels
|
393 |
-
dim_head = num_head_channels
|
394 |
-
|
395 |
-
layers.append(
|
396 |
-
get_attention_layer(
|
397 |
-
ch,
|
398 |
-
num_heads,
|
399 |
-
dim_head,
|
400 |
-
depth=transformer_depth[level],
|
401 |
-
context_dim=context_dim,
|
402 |
-
use_checkpoint=use_checkpoint,
|
403 |
-
disabled_sa=False,
|
404 |
-
)
|
405 |
-
)
|
406 |
-
if level and i == num_res_blocks:
|
407 |
-
out_ch = ch
|
408 |
-
ds //= 2
|
409 |
-
layers.append(
|
410 |
-
get_resblock(
|
411 |
-
merge_factor=merge_factor,
|
412 |
-
merge_strategy=merge_strategy,
|
413 |
-
video_kernel_size=video_kernel_size,
|
414 |
-
ch=ch,
|
415 |
-
time_embed_dim=time_embed_dim,
|
416 |
-
dropout=dropout,
|
417 |
-
out_ch=out_ch,
|
418 |
-
dims=dims,
|
419 |
-
use_checkpoint=use_checkpoint,
|
420 |
-
use_scale_shift_norm=use_scale_shift_norm,
|
421 |
-
up=True,
|
422 |
-
)
|
423 |
-
if resblock_updown
|
424 |
-
else Upsample(
|
425 |
-
ch,
|
426 |
-
conv_resample,
|
427 |
-
dims=dims,
|
428 |
-
out_channels=out_ch,
|
429 |
-
third_up=time_downup,
|
430 |
-
)
|
431 |
-
)
|
432 |
-
|
433 |
-
self.output_blocks.append(TimestepEmbedSequential(*layers))
|
434 |
-
self._feature_size += ch
|
435 |
-
|
436 |
-
self.out = nn.Sequential(
|
437 |
-
normalization(ch),
|
438 |
-
nn.SiLU(),
|
439 |
-
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
|
440 |
-
)
|
441 |
-
|
442 |
-
def forward(
|
443 |
-
self,
|
444 |
-
x: th.Tensor,
|
445 |
-
timesteps: th.Tensor,
|
446 |
-
context: Optional[th.Tensor] = None,
|
447 |
-
y: Optional[th.Tensor] = None,
|
448 |
-
time_context: Optional[th.Tensor] = None,
|
449 |
-
num_video_frames: Optional[int] = None,
|
450 |
-
image_only_indicator: Optional[th.Tensor] = None,
|
451 |
-
):
|
452 |
-
assert (y is not None) == (
|
453 |
-
self.num_classes is not None
|
454 |
-
), "must specify y if and only if the model is class-conditional -> no, relax this TODO"
|
455 |
-
hs = []
|
456 |
-
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
|
457 |
-
emb = self.time_embed(t_emb)
|
458 |
-
|
459 |
-
if self.num_classes is not None:
|
460 |
-
assert y.shape[0] == x.shape[0]
|
461 |
-
emb = emb + self.label_emb(y)
|
462 |
-
|
463 |
-
h = x
|
464 |
-
for module in self.input_blocks:
|
465 |
-
h = module(
|
466 |
-
h,
|
467 |
-
emb,
|
468 |
-
context=context,
|
469 |
-
image_only_indicator=image_only_indicator,
|
470 |
-
time_context=time_context,
|
471 |
-
num_video_frames=num_video_frames,
|
472 |
-
)
|
473 |
-
hs.append(h)
|
474 |
-
h = self.middle_block(
|
475 |
-
h,
|
476 |
-
emb,
|
477 |
-
context=context,
|
478 |
-
image_only_indicator=image_only_indicator,
|
479 |
-
time_context=time_context,
|
480 |
-
num_video_frames=num_video_frames,
|
481 |
-
)
|
482 |
-
for module in self.output_blocks:
|
483 |
-
h = th.cat([h, hs.pop()], dim=1)
|
484 |
-
h = module(
|
485 |
-
h,
|
486 |
-
emb,
|
487 |
-
context=context,
|
488 |
-
image_only_indicator=image_only_indicator,
|
489 |
-
time_context=time_context,
|
490 |
-
num_video_frames=num_video_frames,
|
491 |
-
)
|
492 |
-
h = h.type(x.dtype)
|
493 |
-
return self.out(h)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sgm/modules/diffusionmodules/wrappers.py
DELETED
@@ -1,34 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
import torch.nn as nn
|
3 |
-
from packaging import version
|
4 |
-
|
5 |
-
OPENAIUNETWRAPPER = "sgm.modules.diffusionmodules.wrappers.OpenAIWrapper"
|
6 |
-
|
7 |
-
|
8 |
-
class IdentityWrapper(nn.Module):
|
9 |
-
def __init__(self, diffusion_model, compile_model: bool = False):
|
10 |
-
super().__init__()
|
11 |
-
compile = (
|
12 |
-
torch.compile
|
13 |
-
if (version.parse(torch.__version__) >= version.parse("2.0.0"))
|
14 |
-
and compile_model
|
15 |
-
else lambda x: x
|
16 |
-
)
|
17 |
-
self.diffusion_model = compile(diffusion_model)
|
18 |
-
|
19 |
-
def forward(self, *args, **kwargs):
|
20 |
-
return self.diffusion_model(*args, **kwargs)
|
21 |
-
|
22 |
-
|
23 |
-
class OpenAIWrapper(IdentityWrapper):
|
24 |
-
def forward(
|
25 |
-
self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs
|
26 |
-
) -> torch.Tensor:
|
27 |
-
x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1)
|
28 |
-
return self.diffusion_model(
|
29 |
-
x,
|
30 |
-
timesteps=t,
|
31 |
-
context=c.get("crossattn", None),
|
32 |
-
y=c.get("vector", None),
|
33 |
-
**kwargs,
|
34 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sgm/modules/distributions/__init__.py
DELETED
File without changes
|
sgm/modules/distributions/distributions.py
DELETED
@@ -1,102 +0,0 @@
|
|
1 |
-
import numpy as np
|
2 |
-
import torch
|
3 |
-
|
4 |
-
|
5 |
-
class AbstractDistribution:
|
6 |
-
def sample(self):
|
7 |
-
raise NotImplementedError()
|
8 |
-
|
9 |
-
def mode(self):
|
10 |
-
raise NotImplementedError()
|
11 |
-
|
12 |
-
|
13 |
-
class DiracDistribution(AbstractDistribution):
|
14 |
-
def __init__(self, value):
|
15 |
-
self.value = value
|
16 |
-
|
17 |
-
def sample(self):
|
18 |
-
return self.value
|
19 |
-
|
20 |
-
def mode(self):
|
21 |
-
return self.value
|
22 |
-
|
23 |
-
|
24 |
-
class DiagonalGaussianDistribution(object):
|
25 |
-
def __init__(self, parameters, deterministic=False):
|
26 |
-
self.parameters = parameters
|
27 |
-
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
|
28 |
-
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
|
29 |
-
self.deterministic = deterministic
|
30 |
-
self.std = torch.exp(0.5 * self.logvar)
|
31 |
-
self.var = torch.exp(self.logvar)
|
32 |
-
if self.deterministic:
|
33 |
-
self.var = self.std = torch.zeros_like(self.mean).to(
|
34 |
-
device=self.parameters.device
|
35 |
-
)
|
36 |
-
|
37 |
-
def sample(self):
|
38 |
-
x = self.mean + self.std * torch.randn(self.mean.shape).to(
|
39 |
-
device=self.parameters.device
|
40 |
-
)
|
41 |
-
return x
|
42 |
-
|
43 |
-
def kl(self, other=None):
|
44 |
-
if self.deterministic:
|
45 |
-
return torch.Tensor([0.0])
|
46 |
-
else:
|
47 |
-
if other is None:
|
48 |
-
return 0.5 * torch.sum(
|
49 |
-
torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
|
50 |
-
dim=[1, 2, 3],
|
51 |
-
)
|
52 |
-
else:
|
53 |
-
return 0.5 * torch.sum(
|
54 |
-
torch.pow(self.mean - other.mean, 2) / other.var
|
55 |
-
+ self.var / other.var
|
56 |
-
- 1.0
|
57 |
-
- self.logvar
|
58 |
-
+ other.logvar,
|
59 |
-
dim=[1, 2, 3],
|
60 |
-
)
|
61 |
-
|
62 |
-
def nll(self, sample, dims=[1, 2, 3]):
|
63 |
-
if self.deterministic:
|
64 |
-
return torch.Tensor([0.0])
|
65 |
-
logtwopi = np.log(2.0 * np.pi)
|
66 |
-
return 0.5 * torch.sum(
|
67 |
-
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
|
68 |
-
dim=dims,
|
69 |
-
)
|
70 |
-
|
71 |
-
def mode(self):
|
72 |
-
return self.mean
|
73 |
-
|
74 |
-
|
75 |
-
def normal_kl(mean1, logvar1, mean2, logvar2):
|
76 |
-
"""
|
77 |
-
source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
|
78 |
-
Compute the KL divergence between two gaussians.
|
79 |
-
Shapes are automatically broadcasted, so batches can be compared to
|
80 |
-
scalars, among other use cases.
|
81 |
-
"""
|
82 |
-
tensor = None
|
83 |
-
for obj in (mean1, logvar1, mean2, logvar2):
|
84 |
-
if isinstance(obj, torch.Tensor):
|
85 |
-
tensor = obj
|
86 |
-
break
|
87 |
-
assert tensor is not None, "at least one argument must be a Tensor"
|
88 |
-
|
89 |
-
# Force variances to be Tensors. Broadcasting helps convert scalars to
|
90 |
-
# Tensors, but it does not work for torch.exp().
|
91 |
-
logvar1, logvar2 = [
|
92 |
-
x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
|
93 |
-
for x in (logvar1, logvar2)
|
94 |
-
]
|
95 |
-
|
96 |
-
return 0.5 * (
|
97 |
-
-1.0
|
98 |
-
+ logvar2
|
99 |
-
- logvar1
|
100 |
-
+ torch.exp(logvar1 - logvar2)
|
101 |
-
+ ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
|
102 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sgm/modules/ema.py
DELETED
@@ -1,86 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
from torch import nn
|
3 |
-
|
4 |
-
|
5 |
-
class LitEma(nn.Module):
|
6 |
-
def __init__(self, model, decay=0.9999, use_num_upates=True):
|
7 |
-
super().__init__()
|
8 |
-
if decay < 0.0 or decay > 1.0:
|
9 |
-
raise ValueError("Decay must be between 0 and 1")
|
10 |
-
|
11 |
-
self.m_name2s_name = {}
|
12 |
-
self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
|
13 |
-
self.register_buffer(
|
14 |
-
"num_updates",
|
15 |
-
torch.tensor(0, dtype=torch.int)
|
16 |
-
if use_num_upates
|
17 |
-
else torch.tensor(-1, dtype=torch.int),
|
18 |
-
)
|
19 |
-
|
20 |
-
for name, p in model.named_parameters():
|
21 |
-
if p.requires_grad:
|
22 |
-
# remove as '.'-character is not allowed in buffers
|
23 |
-
s_name = name.replace(".", "")
|
24 |
-
self.m_name2s_name.update({name: s_name})
|
25 |
-
self.register_buffer(s_name, p.clone().detach().data)
|
26 |
-
|
27 |
-
self.collected_params = []
|
28 |
-
|
29 |
-
def reset_num_updates(self):
|
30 |
-
del self.num_updates
|
31 |
-
self.register_buffer("num_updates", torch.tensor(0, dtype=torch.int))
|
32 |
-
|
33 |
-
def forward(self, model):
|
34 |
-
decay = self.decay
|
35 |
-
|
36 |
-
if self.num_updates >= 0:
|
37 |
-
self.num_updates += 1
|
38 |
-
decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
|
39 |
-
|
40 |
-
one_minus_decay = 1.0 - decay
|
41 |
-
|
42 |
-
with torch.no_grad():
|
43 |
-
m_param = dict(model.named_parameters())
|
44 |
-
shadow_params = dict(self.named_buffers())
|
45 |
-
|
46 |
-
for key in m_param:
|
47 |
-
if m_param[key].requires_grad:
|
48 |
-
sname = self.m_name2s_name[key]
|
49 |
-
shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
|
50 |
-
shadow_params[sname].sub_(
|
51 |
-
one_minus_decay * (shadow_params[sname] - m_param[key])
|
52 |
-
)
|
53 |
-
else:
|
54 |
-
assert not key in self.m_name2s_name
|
55 |
-
|
56 |
-
def copy_to(self, model):
|
57 |
-
m_param = dict(model.named_parameters())
|
58 |
-
shadow_params = dict(self.named_buffers())
|
59 |
-
for key in m_param:
|
60 |
-
if m_param[key].requires_grad:
|
61 |
-
m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
|
62 |
-
else:
|
63 |
-
assert not key in self.m_name2s_name
|
64 |
-
|
65 |
-
def store(self, parameters):
|
66 |
-
"""
|
67 |
-
Save the current parameters for restoring later.
|
68 |
-
Args:
|
69 |
-
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
|
70 |
-
temporarily stored.
|
71 |
-
"""
|
72 |
-
self.collected_params = [param.clone() for param in parameters]
|
73 |
-
|
74 |
-
def restore(self, parameters):
|
75 |
-
"""
|
76 |
-
Restore the parameters stored with the `store` method.
|
77 |
-
Useful to validate the model with EMA parameters without affecting the
|
78 |
-
original optimization process. Store the parameters before the
|
79 |
-
`copy_to` method. After validation (or model saving), use this to
|
80 |
-
restore the former parameters.
|
81 |
-
Args:
|
82 |
-
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
|
83 |
-
updated with the stored parameters.
|
84 |
-
"""
|
85 |
-
for c_param, param in zip(self.collected_params, parameters):
|
86 |
-
param.data.copy_(c_param.data)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|