Robert001 commited on
Commit
b094b4a
·
1 Parent(s): 45781bf

first commit

Browse files
Files changed (5) hide show
  1. lib/attention.py +1 -1
  2. lib/ddpm_multi.py +1 -1
  3. lib/openaimodel.py +1 -1
  4. lib/util.py +2 -10
  5. lib/utils.py +117 -0
lib/attention.py CHANGED
@@ -18,7 +18,7 @@ from torch import nn, einsum
18
  from einops import rearrange, repeat
19
  from typing import Optional, Any
20
 
21
- from ..utils import checkpoint
22
 
23
  try:
24
  import xformers
 
18
  from einops import rearrange, repeat
19
  from typing import Optional, Any
20
 
21
+ from utils import checkpoint
22
 
23
  try:
24
  import xformers
lib/ddpm_multi.py CHANGED
@@ -30,7 +30,7 @@ from torchvision.utils import make_grid
30
  from pytorch_lightning.utilities.distributed import rank_zero_only
31
  from omegaconf import ListConfig
32
 
33
- from ..utils import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
34
  from lib.distributions import normal_kl, DiagonalGaussianDistribution
35
  from lib.autoencoder import IdentityFirstStage, AutoencoderKL
36
  from lib.util import make_beta_schedule, extract_into_tensor, noise_like
 
30
  from pytorch_lightning.utilities.distributed import rank_zero_only
31
  from omegaconf import ListConfig
32
 
33
+ from utils import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
34
  from lib.distributions import normal_kl, DiagonalGaussianDistribution
35
  from lib.autoencoder import IdentityFirstStage, AutoencoderKL
36
  from lib.util import make_beta_schedule, extract_into_tensor, noise_like
lib/openaimodel.py CHANGED
@@ -26,7 +26,7 @@ from lib.util import (
26
  timestep_embedding,
27
  )
28
  from attention import SpatialTransformer
29
- from ..utils import exists
30
 
31
 
32
  # dummy replace
 
26
  timestep_embedding,
27
  )
28
  from attention import SpatialTransformer
29
+ from utils import exists
30
 
31
 
32
  # dummy replace
lib/util.py CHANGED
@@ -25,16 +25,8 @@ import torch.nn as nn
25
  import numpy as np
26
  from einops import repeat
27
 
28
- #from ..utils import instantiate_from_config
29
-
30
- def instantiate_from_config(config):
31
- if not "target" in config:
32
- if config == '__is_first_stage__':
33
- return None
34
- elif config == "__is_unconditional__":
35
- return None
36
- raise KeyError("Expected key `target` to instantiate.")
37
- return get_obj_from_str(config["target"])(**config.get("params", dict()))
38
 
39
  def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
40
  if schedule == "linear":
 
25
  import numpy as np
26
  from einops import repeat
27
 
28
+ from utils import instantiate_from_config
29
+
 
 
 
 
 
 
 
 
30
 
31
  def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
32
  if schedule == "linear":
lib/utils.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ * Copyright (c) 2023 Salesforce, Inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: Apache License 2.0
5
+ * For full license text, see LICENSE.txt file in the repo root or http://www.apache.org/licenses/
6
+ * By Can Qin
7
+ * Modified from ControlNet repo: https://github.com/lllyasviel/ControlNet
8
+ * Copyright (c) 2023 Lvmin Zhang and Maneesh Agrawala
9
+ '''
10
+
11
+ import os
12
+ import torch
13
+ from omegaconf import OmegaConf
14
+ import importlib
15
+ import numpy as np
16
+
17
+
18
+ from inspect import isfunction
19
+ from PIL import Image, ImageDraw, ImageFont
20
+
21
+
22
+ def log_txt_as_img(wh, xc, size=10):
23
+ # wh a tuple of (width, height)
24
+ # xc a list of captions to plot
25
+ b = len(xc)
26
+ txts = list()
27
+ for bi in range(b):
28
+ txt = Image.new("RGB", wh, color="white")
29
+ draw = ImageDraw.Draw(txt)
30
+ font = ImageFont.truetype('font/DejaVuSans.ttf', size=size)
31
+ nc = int(40 * (wh[0] / 256))
32
+ lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
33
+
34
+ try:
35
+ draw.text((0, 0), lines, fill="black", font=font)
36
+ except UnicodeEncodeError:
37
+ print("Cant encode string for logging. Skipping.")
38
+
39
+ txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
40
+ txts.append(txt)
41
+ txts = np.stack(txts)
42
+ txts = torch.tensor(txts)
43
+ return txts
44
+
45
+
46
+ def ismap(x):
47
+ if not isinstance(x, torch.Tensor):
48
+ return False
49
+ return (len(x.shape) == 4) and (x.shape[1] > 3)
50
+
51
+
52
+ def isimage(x):
53
+ if not isinstance(x,torch.Tensor):
54
+ return False
55
+ return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
56
+
57
+
58
+ def exists(x):
59
+ return x is not None
60
+
61
+
62
+ def default(val, d):
63
+ if exists(val):
64
+ return val
65
+ return d() if isfunction(d) else d
66
+
67
+
68
+ def mean_flat(tensor):
69
+ """
70
+ https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
71
+ Take the mean over all non-batch dimensions.
72
+ """
73
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
74
+
75
+ def count_params(model, verbose=False):
76
+ total_params = sum(p.numel() for p in model.parameters())
77
+ if verbose:
78
+ print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
79
+ return total_params
80
+
81
+
82
+ def get_state_dict(d):
83
+ return d.get('state_dict', d)
84
+
85
+
86
+ def load_state_dict(ckpt_path, location='cpu'):
87
+ _, extension = os.path.splitext(ckpt_path)
88
+ if extension.lower() == ".safetensors":
89
+ import safetensors.torch
90
+ state_dict = safetensors.torch.load_file(ckpt_path, device=location)
91
+ else:
92
+ state_dict = get_state_dict(torch.load(ckpt_path, map_location=torch.device(location)))
93
+ state_dict = get_state_dict(state_dict)
94
+ print(f'Loaded state_dict from [{ckpt_path}]')
95
+ return state_dict
96
+
97
+ def get_obj_from_str(string, reload=False):
98
+ module, cls = string.rsplit(".", 1)
99
+ if reload:
100
+ module_imp = importlib.import_module(module)
101
+ importlib.reload(module_imp)
102
+ return getattr(importlib.import_module(module, package=None), cls)
103
+
104
+ def instantiate_from_config(config):
105
+ if not "target" in config:
106
+ if config == '__is_first_stage__':
107
+ return None
108
+ elif config == "__is_unconditional__":
109
+ return None
110
+ raise KeyError("Expected key `target` to instantiate.")
111
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
112
+
113
+ def create_model(config_path):
114
+ config = OmegaConf.load(config_path)
115
+ model = instantiate_from_config(config.model).cpu()
116
+ print(f'Loaded model config from [{config_path}]')
117
+ return model