import torch
import math
import itertools

from tqdm import trange

from backend import memory_management
from backend.patcher.base import ModelPatcher


@torch.inference_mode()
def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_amount=4, out_channels=3, output_device="cpu"):
    dims = len(tile)
    output = torch.empty([samples.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), samples.shape[2:])), device=output_device)

    for b in trange(samples.shape[0]):
        s = samples[b:b + 1]
        out = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device)
        out_div = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device)

        # Slide a window of size `tile` over every spatial dimension, stepping by (tile - overlap).
        for it in itertools.product(*map(lambda a: range(0, a[0], a[1] - overlap), zip(s.shape[2:], tile))):
            s_in = s
            upscaled = []

            # Crop the current tile (clamped to the tensor bounds) and remember where it
            # lands in the upscaled output.
            for d in range(dims):
                pos = max(0, min(s.shape[d + 2] - overlap, it[d]))
                l = min(tile[d], s.shape[d + 2] - pos)
                s_in = s_in.narrow(d + 2, pos, l)
                upscaled.append(round(pos * upscale_amount))

            ps = function(s_in).to(output_device)

            # Build a feathering mask that ramps up across the overlap region on every edge,
            # so neighbouring tiles blend smoothly instead of leaving hard seams.
            mask = torch.ones_like(ps)
            feather = round(overlap * upscale_amount)
            for t in range(feather):
                for d in range(2, dims + 2):
                    m = mask.narrow(d, t, 1)
                    m *= ((1.0 / feather) * (t + 1))
                    m = mask.narrow(d, mask.shape[d] - 1 - t, 1)
                    m *= ((1.0 / feather) * (t + 1))

            # Accumulate the weighted tile and its weights; dividing at the end normalizes
            # the overlapping regions.
            o = out
            o_d = out_div
            for d in range(dims):
                o = o.narrow(d + 2, upscaled[d], mask.shape[d + 2])
                o_d = o_d.narrow(d + 2, upscaled[d], mask.shape[d + 2])
            o += ps * mask
            o_d += mask

        output[b:b + 1] = out / out_div
    return output


# Number of tiles required to cover a (width x height) plane with the given tile size and overlap.
def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
    return math.ceil((height / (tile_y - overlap))) * math.ceil((width / (tile_x - overlap)))


def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap=8, upscale_amount=4, out_channels=3, output_device="cpu"):
    return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap, upscale_amount, out_channels, output_device)
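

# --- Illustrative smoke test ---
# A minimal sketch of how tiled_scale() can be exercised on its own, using a toy
# nearest-neighbour 4x upscaler in place of a real VAE decode function. The helper name
# `_tiled_scale_smoke_test` and the tensor sizes are arbitrary choices for illustration.
def _tiled_scale_smoke_test():
    import torch.nn.functional as F

    # Stand-in for a decoder: plain nearest-neighbour 4x upsampling.
    toy_upscale = lambda x: F.interpolate(x, scale_factor=4, mode="nearest")
    samples = torch.rand(1, 3, 96, 96)

    # Process in overlapping 64x64 tiles; the feathered masks blend tile seams together.
    result = tiled_scale(samples, toy_upscale, tile_x=64, tile_y=64, overlap=8, upscale_amount=4, out_channels=3)
    assert result.shape == (1, 3, 384, 384)
    return result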


class VAE:
    def __init__(self, model=None, device=None, dtype=None, no_init=False):
        if no_init:
            return

        # Rough estimates (in bytes) of the VRAM needed to encode a pixel batch or
        # decode a latent batch of the given shape.
        self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * memory_management.dtype_size(dtype)
        self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * memory_management.dtype_size(dtype)

        # Spatial ratio between pixel space and latent space (e.g. 8 when the config has four down blocks).
        self.downscale_ratio = int(2 ** (len(model.config.down_block_types) - 1))
        self.latent_channels = int(model.config.latent_channels)

        self.first_stage_model = model.eval()

        if device is None:
            device = memory_management.vae_device()
        self.device = device
        offload_device = memory_management.vae_offload_device()

        if dtype is None:
            dtype = memory_management.vae_dtype()
        self.vae_dtype = dtype

        self.first_stage_model.to(self.vae_dtype)
        self.output_device = memory_management.intermediate_device()

        self.patcher = ModelPatcher(
            self.first_stage_model,
            load_device=self.device,
            offload_device=offload_device
        )

    def clone(self):
        n = VAE(no_init=True)
        n.patcher = self.patcher.clone()
        n.memory_used_encode = self.memory_used_encode
        n.memory_used_decode = self.memory_used_decode
        n.downscale_ratio = self.downscale_ratio
        n.latent_channels = self.latent_channels
        n.first_stage_model = self.first_stage_model
        n.device = self.device
        n.vae_dtype = self.vae_dtype
        n.output_device = self.output_device
        return n

    def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap=16):
        # Tile-count estimate for the three passes below; not used further in this version.
        steps = samples.shape[0] * get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
        steps += samples.shape[0] * get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap)
        steps += samples.shape[0] * get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap)

        decode_fn = lambda a: (self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)) + 1.0).float()

        # Decode with three different tile aspect ratios and average the results so that seams
        # from any single tiling are suppressed, then map the [0, 2] range back to [0, 1].
        output = torch.clamp(
            ((tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount=self.downscale_ratio, output_device=self.output_device) +
              tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount=self.downscale_ratio, output_device=self.output_device) +
              tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount=self.downscale_ratio, output_device=self.output_device))
             / 3.0) / 2.0, min=0.0, max=1.0)
        return output

    def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap=64):
        # Tile-count estimate for the three passes below; not used further in this version.
        steps = pixel_samples.shape[0] * get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
        steps += pixel_samples.shape[0] * get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap)
        steps += pixel_samples.shape[0] * get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)

        # Rescale pixels from [0, 1] to [-1, 1] before encoding.
        encode_fn = lambda a: self.first_stage_model.encode((2. * a - 1.).to(self.vae_dtype).to(self.device)).float()

        # Encode with three different tile aspect ratios and average the latents to suppress seams.
        samples = tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount=(1 / self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device)
        samples += tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount=(1 / self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device)
        samples += tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount=(1 / self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device)
        samples /= 3.0
        return samples

    def decode_inner(self, samples_in):
        if memory_management.VAE_ALWAYS_TILED:
            return self.decode_tiled(samples_in).to(self.output_device)

        try:
            memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
            memory_management.load_models_gpu([self.patcher], memory_required=memory_used)
            free_memory = memory_management.get_free_memory(self.device)

            # Decode as many latents per forward pass as the estimated free memory allows.
            batch_number = int(free_memory / memory_used)
            batch_number = max(1, batch_number)

            pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * self.downscale_ratio), round(samples_in.shape[3] * self.downscale_ratio)), device=self.output_device)
            for x in range(0, samples_in.shape[0], batch_number):
                samples = samples_in[x:x + batch_number].to(self.vae_dtype).to(self.device)
                # Map the decoder output from [-1, 1] back to [0, 1].
                pixel_samples[x:x + batch_number] = torch.clamp((self.first_stage_model.decode(samples).to(self.output_device).float() + 1.0) / 2.0, min=0.0, max=1.0)
        except memory_management.OOM_EXCEPTION:
            print("Warning: ran out of memory during regular VAE decoding; retrying with tiled VAE decoding.")
            pixel_samples = self.decode_tiled_(samples_in)

        pixel_samples = pixel_samples.to(self.output_device).movedim(1, -1)
        return pixel_samples

    def decode(self, samples_in):
        wrapper = self.patcher.model_options.get('model_vae_decode_wrapper', None)
        if wrapper is None:
            return self.decode_inner(samples_in)
        else:
            return wrapper(self.decode_inner, samples_in)

    def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap=16):
        memory_management.load_model_gpu(self.patcher)
        output = self.decode_tiled_(samples, tile_x, tile_y, overlap)
        return output.movedim(1, -1)

    def encode_inner(self, pixel_samples):
        if memory_management.VAE_ALWAYS_TILED:
            return self.encode_tiled(pixel_samples)

        regulation = self.patcher.model_options.get("model_vae_regulation", None)

        # Move channels-last pixels (B, H, W, C) to channels-first (B, C, H, W).
        pixel_samples = pixel_samples.movedim(-1, 1)

        try:
            memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
            memory_management.load_models_gpu([self.patcher], memory_required=memory_used)
            free_memory = memory_management.get_free_memory(self.device)

            # Encode as many images per forward pass as the estimated free memory allows.
            batch_number = int(free_memory / memory_used)
            batch_number = max(1, batch_number)

            samples = torch.empty((pixel_samples.shape[0], self.latent_channels, round(pixel_samples.shape[2] // self.downscale_ratio), round(pixel_samples.shape[3] // self.downscale_ratio)), device=self.output_device)
            for x in range(0, pixel_samples.shape[0], batch_number):
                # Rescale pixels from [0, 1] to [-1, 1] before encoding.
                pixels_in = (2. * pixel_samples[x:x + batch_number] - 1.).to(self.vae_dtype).to(self.device)
                samples[x:x + batch_number] = self.first_stage_model.encode(pixels_in, regulation).to(self.output_device).float()
        except memory_management.OOM_EXCEPTION:
            print("Warning: ran out of memory during regular VAE encoding; retrying with tiled VAE encoding.")
            samples = self.encode_tiled_(pixel_samples)

        return samples

    def encode(self, pixel_samples):
        wrapper = self.patcher.model_options.get('model_vae_encode_wrapper', None)
        if wrapper is None:
            return self.encode_inner(pixel_samples)
        else:
            return wrapper(self.encode_inner, pixel_samples)

    def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap=64):
        memory_management.load_model_gpu(self.patcher)
        pixel_samples = pixel_samples.movedim(-1, 1)
        samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap)
        return samples
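

# --- Usage sketch (illustrative) ---
# A minimal, hedged sketch of how this wrapper might be driven. `load_autoencoder` is a
# hypothetical placeholder for whatever project-specific loader supplies a diffusers-style
# autoencoder whose `config` exposes `down_block_types` and `latent_channels`, and whose
# `encode(x, regulation)` / `decode(z)` return plain tensors, which is what __init__,
# encode_inner and decode_inner rely on.
#
#     model = load_autoencoder("path/to/vae")   # hypothetical loader
#     vae = VAE(model=model)
#
#     pixels = torch.rand(1, 512, 512, 3)       # (batch, height, width, channels) in [0, 1]
#     latent = vae.encode(pixels)               # (1, latent_channels, 64, 64) for an 8x VAE
#     restored = vae.decode(latent)             # back to (1, 512, 512, 3) in [0, 1]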