require
Browse files
sgm/models/diffusion.py
CHANGED
@@ -5,7 +5,6 @@ import pytorch_lightning as pl
|
|
5 |
import torch
|
6 |
from omegaconf import ListConfig, OmegaConf
|
7 |
from safetensors.torch import load_file as load_safetensors
|
8 |
-
from torch.optim.lr_scheduler import LambdaLR
|
9 |
|
10 |
from ..modules import UNCONDITIONAL_CONFIG
|
11 |
from ..modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
|
|
|
5 |
import torch
|
6 |
from omegaconf import ListConfig, OmegaConf
|
7 |
from safetensors.torch import load_file as load_safetensors
|
|
|
8 |
|
9 |
from ..modules import UNCONDITIONAL_CONFIG
|
10 |
from ..modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
|
sgm/modules/diffusionmodules/loss.py
CHANGED
@@ -4,7 +4,7 @@ import torch
|
|
4 |
import torch.nn as nn
|
5 |
import torch.nn.functional as F
|
6 |
from omegaconf import ListConfig
|
7 |
-
from taming.modules.losses.lpips import LPIPS
|
8 |
from torchvision.utils import save_image
|
9 |
from ...util import append_dims, instantiate_from_config
|
10 |
|
@@ -26,8 +26,8 @@ class StandardDiffusionLoss(nn.Module):
|
|
26 |
self.type = type
|
27 |
self.offset_noise_level = offset_noise_level
|
28 |
|
29 |
-
if type == "lpips":
|
30 |
-
|
31 |
|
32 |
if not batch2model_keys:
|
33 |
batch2model_keys = []
|
|
|
4 |
import torch.nn as nn
|
5 |
import torch.nn.functional as F
|
6 |
from omegaconf import ListConfig
|
7 |
+
# from taming.modules.losses.lpips import LPIPS
|
8 |
from torchvision.utils import save_image
|
9 |
from ...util import append_dims, instantiate_from_config
|
10 |
|
|
|
26 |
self.type = type
|
27 |
self.offset_noise_level = offset_noise_level
|
28 |
|
29 |
+
# if type == "lpips":
|
30 |
+
# self.lpips = LPIPS().eval()
|
31 |
|
32 |
if not batch2model_keys:
|
33 |
batch2model_keys = []
|
sgm/modules/diffusionmodules/openaimodel.py
CHANGED
@@ -1,4 +1,3 @@
|
|
1 |
-
import os
|
2 |
import math
|
3 |
from abc import abstractmethod
|
4 |
from functools import partial
|
|
|
|
|
1 |
import math
|
2 |
from abc import abstractmethod
|
3 |
from functools import partial
|
sgm/modules/diffusionmodules/sampling.py
CHANGED
@@ -7,7 +7,6 @@ from typing import Dict, Union
|
|
7 |
|
8 |
import imageio
|
9 |
import torch
|
10 |
-
import json
|
11 |
import numpy as np
|
12 |
import torch.nn.functional as F
|
13 |
from omegaconf import ListConfig, OmegaConf
|
|
|
7 |
|
8 |
import imageio
|
9 |
import torch
|
|
|
10 |
import numpy as np
|
11 |
import torch.nn.functional as F
|
12 |
from omegaconf import ListConfig, OmegaConf
|