ZYMPKU commited on
Commit
8876f9a
1 Parent(s): 1820a80
sgm/models/diffusion.py CHANGED
@@ -5,7 +5,6 @@ import pytorch_lightning as pl
5
  import torch
6
  from omegaconf import ListConfig, OmegaConf
7
  from safetensors.torch import load_file as load_safetensors
8
- from torch.optim.lr_scheduler import LambdaLR
9
 
10
  from ..modules import UNCONDITIONAL_CONFIG
11
  from ..modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
 
5
  import torch
6
  from omegaconf import ListConfig, OmegaConf
7
  from safetensors.torch import load_file as load_safetensors
 
8
 
9
  from ..modules import UNCONDITIONAL_CONFIG
10
  from ..modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
sgm/modules/diffusionmodules/loss.py CHANGED
@@ -4,7 +4,7 @@ import torch
4
  import torch.nn as nn
5
  import torch.nn.functional as F
6
  from omegaconf import ListConfig
7
- from taming.modules.losses.lpips import LPIPS
8
  from torchvision.utils import save_image
9
  from ...util import append_dims, instantiate_from_config
10
 
@@ -26,8 +26,8 @@ class StandardDiffusionLoss(nn.Module):
26
  self.type = type
27
  self.offset_noise_level = offset_noise_level
28
 
29
- if type == "lpips":
30
- self.lpips = LPIPS().eval()
31
 
32
  if not batch2model_keys:
33
  batch2model_keys = []
 
4
  import torch.nn as nn
5
  import torch.nn.functional as F
6
  from omegaconf import ListConfig
7
+ # from taming.modules.losses.lpips import LPIPS
8
  from torchvision.utils import save_image
9
  from ...util import append_dims, instantiate_from_config
10
 
 
26
  self.type = type
27
  self.offset_noise_level = offset_noise_level
28
 
29
+ # if type == "lpips":
30
+ # self.lpips = LPIPS().eval()
31
 
32
  if not batch2model_keys:
33
  batch2model_keys = []
sgm/modules/diffusionmodules/openaimodel.py CHANGED
@@ -1,4 +1,3 @@
1
- import os
2
  import math
3
  from abc import abstractmethod
4
  from functools import partial
 
 
1
  import math
2
  from abc import abstractmethod
3
  from functools import partial
sgm/modules/diffusionmodules/sampling.py CHANGED
@@ -7,7 +7,6 @@ from typing import Dict, Union
7
 
8
  import imageio
9
  import torch
10
- import json
11
  import numpy as np
12
  import torch.nn.functional as F
13
  from omegaconf import ListConfig, OmegaConf
 
7
 
8
  import imageio
9
  import torch
 
10
  import numpy as np
11
  import torch.nn.functional as F
12
  from omegaconf import ListConfig, OmegaConf