clone3 committed
Commit 188fd89 · verified · 1 Parent(s): f188e91

Upload 4 files

Files changed (4)
  1. cog.yaml +11 -0
  2. image_to_image.py +281 -0
  3. predict.py +136 -0
  4. script/download-weights +18 -0
cog.yaml ADDED
@@ -0,0 +1,11 @@
+build:
+  gpu: true
+  cuda: "11.6.2"
+  python_version: "3.10"
+  python_packages:
+    - "diffusers==0.2.4"
+    - "torch==1.12.1 --extra-index-url=https://download.pytorch.org/whl/cu116"
+    - "ftfy==6.1.1"
+    - "scipy==1.9.0"
+    - "transformers==4.21.1"
+predict: "predict.py:Predictor"
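The pins above fix the exact library versions the predictor was written against. A throwaway check like the sketch below (not part of the commit) can confirm a built environment matches them; it only uses packages listed in cog.yaml and assumes nothing else about the image.

import torch
import diffusers
import transformers

print(torch.__version__, torch.version.cuda)  # expect 1.12.1 built against CUDA 11.6
print(diffusers.__version__)                  # expect 0.2.4
print(transformers.__version__)               # expect 4.21.1
print(torch.cuda.is_available())              # cog.yaml requests a GPU build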
image_to_image.py ADDED
@@ -0,0 +1,281 @@
+import inspect
+from typing import List, Optional, Union, Tuple
+
+import numpy as np
+import torch
+from PIL import Image
+from diffusers import (
+    AutoencoderKL,
+    DDIMScheduler,
+    DiffusionPipeline,
+    PNDMScheduler,
+    LMSDiscreteScheduler,
+    UNet2DConditionModel,
+)
+from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
+from tqdm.auto import tqdm
+from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+
+
+def preprocess_init_image(image: Image, width: int, height: int):
+    image = image.resize((width, height), resample=Image.LANCZOS)
+    image = np.array(image).astype(np.float32) / 255.0
+    image = image[None].transpose(0, 3, 1, 2)
+    image = torch.from_numpy(image)
+    return 2.0 * image - 1.0
+
+
+def preprocess_mask(mask: Image, width: int, height: int):
+    mask = mask.convert("L")
+    mask = mask.resize((width // 8, height // 8), resample=Image.LANCZOS)
+    mask = np.array(mask).astype(np.float32) / 255.0
+    mask = np.tile(mask, (4, 1, 1))
+    mask = mask[None].transpose(0, 1, 2, 3)  # add a batch dimension (the transpose itself is a no-op)
+    mask = torch.from_numpy(mask)
+    return mask
+
+
+class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
+    """
+    From https://github.com/huggingface/diffusers/pull/241
+    """
+
+    def __init__(
+        self,
+        vae: AutoencoderKL,
+        text_encoder: CLIPTextModel,
+        tokenizer: CLIPTokenizer,
+        unet: UNet2DConditionModel,
+        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+        safety_checker: StableDiffusionSafetyChecker,
+        feature_extractor: CLIPFeatureExtractor,
+    ):
+        super().__init__()
+        scheduler = scheduler.set_format("pt")
+        self.register_modules(
+            vae=vae,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            unet=unet,
+            scheduler=scheduler,
+            safety_checker=safety_checker,
+            feature_extractor=feature_extractor,
+        )
+
+    @torch.no_grad()
+    def __call__(
+        self,
+        prompt: Union[str, List[str]],
+        init_image: Optional[torch.FloatTensor],
+        mask: Optional[torch.FloatTensor],
+        width: int,
+        height: int,
+        prompt_strength: float = 0.8,
+        num_inference_steps: int = 50,
+        guidance_scale: float = 7.5,
+        eta: float = 0.0,
+        generator: Optional[torch.Generator] = None,
+    ) -> Image:
+        if isinstance(prompt, str):
+            batch_size = 1
+        elif isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            raise ValueError(
+                f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
+            )
+
+        if prompt_strength < 0 or prompt_strength > 1:
+            raise ValueError(
+                f"The value of prompt_strength should be in [0.0, 1.0] but is {prompt_strength}"
+            )
+
+        if mask is not None and init_image is None:
+            raise ValueError(
+                "If mask is defined, then init_image also needs to be defined"
+            )
+
+        if width % 8 != 0 or height % 8 != 0:
+            raise ValueError("Width and height must both be divisible by 8")
+
+        # set timesteps
+        accepts_offset = "offset" in set(
+            inspect.signature(self.scheduler.set_timesteps).parameters.keys()
+        )
+        extra_set_kwargs = {}
+        offset = 0
+        if accepts_offset:
+            offset = 1
+            extra_set_kwargs["offset"] = 1
+
+        self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
+
+        if init_image is not None:
+            init_latents_orig, latents, init_timestep = self.latents_from_init_image(
+                init_image,
+                prompt_strength,
+                offset,
+                num_inference_steps,
+                batch_size,
+                generator,
+            )
+        else:
+            latents = torch.randn(
+                (batch_size, self.unet.in_channels, height // 8, width // 8),
+                generator=generator,
+                device=self.device,
+            )
+            init_timestep = num_inference_steps
+
+        do_classifier_free_guidance = guidance_scale > 1.0
+        text_embeddings = self.embed_text(
+            prompt, do_classifier_free_guidance, batch_size
+        )
+
+        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+        # eta corresponds to η in the DDIM paper: https://arxiv.org/abs/2010.02502
+        # and should be between [0, 1]
+        accepts_eta = "eta" in set(
+            inspect.signature(self.scheduler.step).parameters.keys()
+        )
+        extra_step_kwargs = {}
+        if accepts_eta:
+            extra_step_kwargs["eta"] = eta
+
+        mask_noise = torch.randn(latents.shape, generator=generator, device=self.device)
+
+        # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
+        if isinstance(self.scheduler, LMSDiscreteScheduler):
+            latents = latents * self.scheduler.sigmas[0]
+
+        t_start = max(num_inference_steps - init_timestep + offset, 0)
+        for i, t in tqdm(enumerate(self.scheduler.timesteps[t_start:])):
+            # expand the latents if we are doing classifier free guidance
+            latent_model_input = (
+                torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+            )
+
+            if isinstance(self.scheduler, LMSDiscreteScheduler):
+                sigma = self.scheduler.sigmas[i]
+                latent_model_input = latent_model_input / ((sigma ** 2 + 1) ** 0.5)
+
+            # predict the noise residual
+            noise_pred = self.unet(
+                latent_model_input, t, encoder_hidden_states=text_embeddings
+            )["sample"]
+
+            # perform guidance
+            if do_classifier_free_guidance:
+                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                noise_pred = noise_pred_uncond + guidance_scale * (
+                    noise_pred_text - noise_pred_uncond
+                )
+
+            # compute the previous noisy sample x_t -> x_t-1
+            if isinstance(self.scheduler, LMSDiscreteScheduler):
+                latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs)[
+                    "prev_sample"
+                ]
+            else:
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)[
+                    "prev_sample"
+                ]
+
+            # replace the unmasked part with original latents, with added noise
+            if mask is not None:
+                timesteps = self.scheduler.timesteps[t_start + i]
+                timesteps = torch.tensor(
+                    [timesteps] * batch_size, dtype=torch.long, device=self.device
+                )
+                noisy_init_latents = self.scheduler.add_noise(init_latents_orig, mask_noise, timesteps)
+                latents = noisy_init_latents * mask + latents * (1 - mask)
+
+        # scale and decode the image latents with vae
+        latents = 1 / 0.18215 * latents
+        image = self.vae.decode(latents)
+
+        image = (image / 2 + 0.5).clamp(0, 1)
+        image = image.cpu().permute(0, 2, 3, 1).numpy()
+
+        # run safety checker
+        safety_checker_input = self.feature_extractor(
+            self.numpy_to_pil(image), return_tensors="pt"
+        ).to(self.device)
+        image, has_nsfw_concept = self.safety_checker(
+            images=image, clip_input=safety_checker_input.pixel_values
+        )
+
+        image = self.numpy_to_pil(image)
+
+        return {"sample": image, "nsfw_content_detected": has_nsfw_concept}
+
+    def latents_from_init_image(
+        self,
+        init_image: torch.FloatTensor,
+        prompt_strength: float,
+        offset: int,
+        num_inference_steps: int,
+        batch_size: int,
+        generator: Optional[torch.Generator],
+    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, int]:
+        # encode the init image into latents and scale the latents
+        init_latents = self.vae.encode(init_image.to(self.device)).sample()
+        init_latents = 0.18215 * init_latents
+        init_latents_orig = init_latents
+
+        # prepare init_latents noise to latents
+        init_latents = torch.cat([init_latents] * batch_size)
+
+        # get the original timestep using init_timestep
+        init_timestep = int(num_inference_steps * prompt_strength) + offset
+        init_timestep = min(init_timestep, num_inference_steps)
+        timesteps = self.scheduler.timesteps[-init_timestep]
+        timesteps = torch.tensor(
+            [timesteps] * batch_size, dtype=torch.long, device=self.device
+        )
+
+        # add noise to latents using the timesteps
+        noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
+        init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
+
+        return init_latents_orig, init_latents, init_timestep
+
+    def embed_text(
+        self,
+        prompt: Union[str, List[str]],
+        do_classifier_free_guidance: bool,
+        batch_size: int,
+    ) -> torch.FloatTensor:
+        # get prompt text embeddings
+        text_input = self.tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=self.tokenizer.model_max_length,
+            truncation=True,
+            return_tensors="pt",
+        )
+        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
+
+        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+        # corresponds to doing no classifier free guidance.
+        # get unconditional embeddings for classifier free guidance
+        if do_classifier_free_guidance:
+            max_length = text_input.input_ids.shape[-1]
+            uncond_input = self.tokenizer(
+                [""] * batch_size,
+                padding="max_length",
+                max_length=max_length,
+                return_tensors="pt",
+            )
+            uncond_embeddings = self.text_encoder(
+                uncond_input.input_ids.to(self.device)
+            )[0]
+
+            # For classifier free guidance, we need to do two forward passes.
+            # Here we concatenate the unconditional and text embeddings into a single batch
+            # to avoid doing two forward passes
+            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+        return text_embeddings
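For reference, a minimal usage sketch of the pipeline defined above (not part of the commit). It mirrors how predict.py drives the class; the file names and prompt are illustrative assumptions, and it assumes the CompVis/stable-diffusion-v1-4 weights are available and a CUDA device is present.

import torch
from PIL import Image
from diffusers import PNDMScheduler

from image_to_image import StableDiffusionImg2ImgPipeline, preprocess_init_image

scheduler = PNDMScheduler(
    beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
)
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    scheduler=scheduler,
    revision="fp16",
    torch_dtype=torch.float16,
).to("cuda")

# preprocess an init image exactly as predict.py does (assumed local file "input.png")
init = preprocess_init_image(Image.open("input.png").convert("RGB"), 512, 512).to("cuda")
generator = torch.Generator("cuda").manual_seed(42)

# run under autocast/inference_mode, matching the decorators on Predictor.predict
with torch.inference_mode(), torch.autocast("cuda"):
    output = pipe(
        prompt=["a watercolor painting of a lighthouse"],
        init_image=init,
        mask=None,
        width=512,
        height=512,
        prompt_strength=0.8,
        num_inference_steps=50,
        guidance_scale=7.5,
        generator=generator,
    )

output["sample"][0].save("out.png")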
predict.py ADDED
@@ -0,0 +1,136 @@
+import os
+from typing import Optional, List
+
+import torch
+import torch.nn as nn
+from torch import autocast
+from diffusers import PNDMScheduler, LMSDiscreteScheduler
+from PIL import Image
+from cog import BasePredictor, Input, Path
+
+from image_to_image import (
+    StableDiffusionImg2ImgPipeline,
+    preprocess_init_image,
+    preprocess_mask,
+)
+
+def patch_conv(**patch):  # monkey-patch torch.nn.Conv2d to force extra constructor kwargs
+    cls = torch.nn.Conv2d
+    init = cls.__init__
+    def __init__(self, *args, **kwargs):
+        return init(self, *args, **kwargs, **patch)
+    cls.__init__ = __init__
+
+patch_conv(padding_mode='circular')  # every Conv2d now uses circular padding, so outputs tile seamlessly
+
+MODEL_CACHE = "diffusers-cache"
+
+
+class Predictor(BasePredictor):
+    def setup(self):
+        """Load the model into memory to make running multiple predictions efficient"""
+        print("Loading pipeline...")
+        scheduler = PNDMScheduler(
+            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
+        )
+        self.pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
+            "CompVis/stable-diffusion-v1-4",
+            scheduler=scheduler,
+            revision="fp16",
+            torch_dtype=torch.float16,
+            cache_dir=MODEL_CACHE,
+            local_files_only=True,
+        ).to("cuda")
+
+    @torch.inference_mode()
+    @torch.cuda.amp.autocast()
+    def predict(
+        self,
+        prompt: str = Input(description="Input prompt", default=""),
+        width: int = Input(
+            description="Width of output image. Maximum size is 1024x768 or 768x1024 because of memory limits",
+            choices=[128, 256, 512, 768, 1024],
+            default=512,
+        ),
+        height: int = Input(
+            description="Height of output image. Maximum size is 1024x768 or 768x1024 because of memory limits",
+            choices=[128, 256, 512, 768, 1024],
+            default=512,
+        ),
+        init_image: Path = Input(
+            description="Initial image to generate variations of. Will be resized to the specified width and height",
+            default=None,
+        ),
+        mask: Path = Input(
+            description="Black and white image to use as mask for inpainting over init_image. Black pixels are inpainted and white pixels are preserved. Experimental feature, tends to work better with prompt strength of 0.5-0.7",
+            default=None,
+        ),
+        prompt_strength: float = Input(
+            description="Prompt strength when using init image. 1.0 corresponds to full destruction of information in init image",
+            default=0.8,
+        ),
+        num_outputs: int = Input(
+            description="Number of images to output", choices=[1, 4], default=1
+        ),
+        num_inference_steps: int = Input(
+            description="Number of denoising steps", ge=1, le=500, default=50
+        ),
+        guidance_scale: float = Input(
+            description="Scale for classifier-free guidance", ge=1, le=20, default=7.5
+        ),
+        seed: int = Input(
+            description="Random seed. Leave blank to randomize the seed", default=None
+        ),
+    ) -> List[Path]:
+        """Run a single prediction on the model"""
+        if seed is None:
+            seed = int.from_bytes(os.urandom(2), "big")
+        print(f"Using seed: {seed}")
+
+        if width == height == 1024:
+            raise ValueError(
+                "Maximum size is 1024x768 or 768x1024 pixels, because of memory limits. Please select a lower width or height."
+            )
+
+        if init_image:
+            init_image = Image.open(init_image).convert("RGB")
+            init_image = preprocess_init_image(init_image, width, height).to("cuda")
+
+            # use PNDM with init images
+            scheduler = PNDMScheduler(
+                beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
+            )
+        else:
+            # use LMS without init images
+            scheduler = LMSDiscreteScheduler(
+                beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
+            )
+
+        self.pipe.scheduler = scheduler
+
+        if mask:
+            mask = Image.open(mask).convert("RGB")
+            mask = preprocess_mask(mask, width, height).to("cuda")
+
+        generator = torch.Generator("cuda").manual_seed(seed)
+        output = self.pipe(
+            prompt=[prompt] * num_outputs if prompt is not None else None,
+            init_image=init_image,
+            mask=mask,
+            width=width,
+            height=height,
+            prompt_strength=prompt_strength,
+            guidance_scale=guidance_scale,
+            generator=generator,
+            num_inference_steps=num_inference_steps,
+        )
+        if any(output["nsfw_content_detected"]):
+            raise Exception("NSFW content detected, please try a different prompt")
+
+        output_paths = []
+        for i, sample in enumerate(output["sample"]):
+            output_path = f"/tmp/out-{i}.png"
+            sample.save(output_path)
+            output_paths.append(Path(output_path))
+
+        return output_paths
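A note on prompt_strength: inside the pipeline it decides where denoising starts, so only part of the schedule is actually run when an init image is given. The sketch below (not part of the commit) just reproduces the arithmetic from latents_from_init_image and the t_start computation; offset is 1 here because the PNDM scheduler used with init images accepts the "offset" kwarg.

num_inference_steps = 50
offset = 1  # PNDMScheduler.set_timesteps accepts "offset", so the pipeline sets it to 1

for prompt_strength in (0.3, 0.5, 0.8, 1.0):
    init_timestep = min(int(num_inference_steps * prompt_strength) + offset, num_inference_steps)
    t_start = max(num_inference_steps - init_timestep + offset, 0)
    print(f"prompt_strength={prompt_strength}: skip {t_start} steps, run {num_inference_steps - t_start}")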
script/download-weights ADDED
@@ -0,0 +1,18 @@
+#!/usr/bin/env python
+
+import os
+import sys
+
+import torch
+from diffusers import StableDiffusionPipeline
+
+os.makedirs("diffusers-cache", exist_ok=True)
+
+
+pipe = StableDiffusionPipeline.from_pretrained(
+    "CompVis/stable-diffusion-v1-4",
+    cache_dir="diffusers-cache",
+    revision="fp16",
+    torch_dtype=torch.float16,
+    use_auth_token=sys.argv[1],
+)