MaykaGR committed on
Commit
b8c0a8e
·
verified ·
1 Parent(s): 28a6791

Upload latent_preview.py

Browse files
Files changed (1) hide show
  1. latent_preview.py +108 -0
latent_preview.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from PIL import Image
3
+ from comfy.cli_args import args, LatentPreviewMethod
4
+ from comfy.taesd.taesd import TAESD
5
+ import comfy.model_management
6
+ import folder_paths
7
+ import comfy.utils
8
+ import logging
9
+
10
# Read once from args.preview_size when this module is imported; returned as
# the third element of every preview tuple built by LatentPreviewer.
+ MAX_PREVIEW_RESOLUTION = args.preview_size
11
+
12
def preview_to_image(latent_image):
    """Convert a decoded preview tensor into a PIL image.

    The tensor is treated as holding values in the -1..1 range: it is shifted
    and scaled to 0..1, clamped, multiplied by 255, cast to ``uint8`` on the
    CPU and finally passed to ``Image.fromarray``.
    """
    unit_range = ((latent_image + 1.0) / 2.0).clamp(0, 1)  # -1..1 -> 0..1
    pixels = unit_range.mul(0xFF)  # 0..1 -> 0..255

    # With DirectML the dtype cast is done as its own step, before the
    # transfer to the CPU below.
    if comfy.model_management.directml_enabled:
        pixels = pixels.to(dtype=torch.uint8)

    allow_async = comfy.model_management.device_supports_non_blocking(latent_image.device)
    pixels = pixels.to(device="cpu", dtype=torch.uint8, non_blocking=allow_async)

    return Image.fromarray(pixels.numpy())
21
+
22
class LatentPreviewer:
    """Base class for objects that turn a latent into a preview image."""

    def decode_latent_to_preview(self, x0):
        """Decode ``x0`` into a preview image.

        The base implementation does nothing and returns ``None``;
        subclasses are expected to override it.
        """
        return None

    def decode_latent_to_preview_image(self, preview_format, x0):
        """Return a ``(format, image, max_resolution)`` tuple for ``x0``.

        ``preview_format`` is accepted but not consulted: the format element
        of the result is always ``"JPEG"``.
        """
        return ("JPEG", self.decode_latent_to_preview(x0), MAX_PREVIEW_RESOLUTION)
29
+
30
class TAESDPreviewerImpl(LatentPreviewer):
    """Previewer that decodes latents with a TAESD model."""

    def __init__(self, taesd):
        # Object exposing ``decode``; used as-is by decode_latent_to_preview.
        self.taesd = taesd

    def decode_latent_to_preview(self, x0):
        """Decode only the first batch entry of ``x0`` and return it as an image."""
        first_only = x0[:1]  # slice (not index) so the batch dimension is kept
        decoded = self.taesd.decode(first_only)[0]
        # Move dim 0 to position 2 before image conversion
        # (presumably channels-first -> channels-last; confirm against TAESD output).
        return preview_to_image(decoded.movedim(0, 2))
37
+
38
+
39
class Latent2RGBPreviewer(LatentPreviewer):
    """Previewer that projects latent channels to an image with a linear map."""

    def __init__(self, latent_rgb_factors, latent_rgb_factors_bias=None):
        # Kept transposed so it can be passed directly as the weight of
        # torch.nn.functional.linear in decode_latent_to_preview.
        factors = torch.tensor(latent_rgb_factors, device="cpu")
        self.latent_rgb_factors = factors.transpose(0, 1)
        self.latent_rgb_factors_bias = (
            None
            if latent_rgb_factors_bias is None
            else torch.tensor(latent_rgb_factors_bias, device="cpu")
        )

    def decode_latent_to_preview(self, x0):
        """Project the first batch entry of ``x0`` through the factors and return an image.

        The stored factors (and bias, when present) are converted to the dtype
        and device of ``x0`` and the converted tensors are kept on the instance.
        """
        target = {"dtype": x0.dtype, "device": x0.device}
        self.latent_rgb_factors = self.latent_rgb_factors.to(**target)
        if self.latent_rgb_factors_bias is not None:
            self.latent_rgb_factors_bias = self.latent_rgb_factors_bias.to(**target)

        # 5-D input: take batch entry 0 and index 0 along dim 2
        # (presumably the first frame of a video latent; confirm with callers).
        # Anything else: just take batch entry 0.
        frame = x0[0, :, 0] if x0.ndim == 5 else x0[0]

        projected = torch.nn.functional.linear(
            frame.movedim(0, -1),
            self.latent_rgb_factors,
            bias=self.latent_rgb_factors_bias,
        )
        return preview_to_image(projected)
60
+
61
+
62
def get_previewer(device, latent_format):
    """Pick a previewer for ``latent_format`` based on ``args.preview_method``.

    Returns ``None`` when previews are disabled, or when no TAESD previewer
    was created and ``latent_format`` provides no RGB factors. A TAESD
    previewer is used only when that method is selected and a matching decoder
    file is found under "vae_approx"; otherwise the Latent2RGB previewer is
    the fallback.
    """
    method = args.preview_method
    if method == LatentPreviewMethod.NoPreviews:
        return None

    # TODO previewer methods
    decoder_name = latent_format.taesd_decoder_name
    taesd_decoder_path = None
    if decoder_name is not None:
        # First "vae_approx" file whose name starts with the decoder name,
        # or "" when there is none.
        matched_file = ""
        for candidate in folder_paths.get_filename_list("vae_approx"):
            if candidate.startswith(decoder_name):
                matched_file = candidate
                break
        taesd_decoder_path = folder_paths.get_full_path("vae_approx", matched_file)

    # Auto currently resolves to the Latent2RGB method.
    if method == LatentPreviewMethod.Auto:
        method = LatentPreviewMethod.Latent2RGB

    previewer = None
    if method == LatentPreviewMethod.TAESD:
        if taesd_decoder_path:
            taesd = TAESD(None, taesd_decoder_path, latent_channels=latent_format.latent_channels).to(device)
            previewer = TAESDPreviewerImpl(taesd)
        else:
            logging.warning("Warning: TAESD previews enabled, but could not find models/vae_approx/{}".format(latent_format.taesd_decoder_name))

    # Fall back to the linear projection when nothing was created above.
    if previewer is None and latent_format.latent_rgb_factors is not None:
        previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors, latent_format.latent_rgb_factors_bias)
    return previewer
90
+
91
def prepare_callback(model, steps, x0_output_dict=None):
    """Build a per-step callback that reports progress and, if possible, a preview.

    Args:
        model: object providing ``load_device`` and ``model.latent_format``,
            which are used to select the previewer.
        steps: total passed to ``comfy.utils.ProgressBar``.
        x0_output_dict: optional dict; when given, each callback invocation
            stores the current ``x0`` under the key ``"x0"``.

    Returns:
        A function ``callback(step, x0, x, total_steps)`` that updates the
        progress bar to ``step + 1`` and attaches the preview tuple produced
        by the previewer, or ``None`` when no previewer is available.
    """
    # The format is a fixed literal, so no validation against the supported
    # formats is needed here.
    preview_format = "JPEG"

    previewer = get_previewer(model.load_device, model.model.latent_format)
    pbar = comfy.utils.ProgressBar(steps)

    def callback(step, x0, x, total_steps):
        if x0_output_dict is not None:
            x0_output_dict["x0"] = x0

        preview_bytes = None
        if previewer:
            preview_bytes = previewer.decode_latent_to_preview_image(preview_format, x0)
        # Steps are reported 1-based.
        pbar.update_absolute(step + 1, total_steps, preview_bytes)

    return callback
108
+