Upload 4 files
- cog.yaml +11 -0
- image_to_image.py +281 -0
- predict.py +136 -0
- script/download-weights +18 -0
cog.yaml
ADDED
@@ -0,0 +1,11 @@
build:
  gpu: true
  cuda: "11.6.2"
  python_version: "3.10"
  python_packages:
    - "diffusers==0.2.4"
    - "torch==1.12.1 --extra-index-url=https://download.pytorch.org/whl/cu116"
    - "ftfy==6.1.1"
    - "scipy==1.9.0"
    - "transformers==4.21.1"
predict: "predict.py:Predictor"
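As a quick sanity check after `cog build`, the pinned torch wheel should report the cu116 build and see the GPU. A minimal sketch (the expected version strings are assumptions based on the pins above, not output from this repo):

import torch

print(torch.__version__)          # expected: 1.12.1+cu116, per the pin in cog.yaml
print(torch.version.cuda)         # expected: 11.6
print(torch.cuda.is_available())  # True when `gpu: true` and a CUDA device is attached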
image_to_image.py
ADDED
@@ -0,0 +1,281 @@
import inspect
from typing import List, Optional, Union, Tuple

import numpy as np
import torch
from PIL import Image
from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    DiffusionPipeline,
    PNDMScheduler,
    LMSDiscreteScheduler,
    UNet2DConditionModel,
)
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer


def preprocess_init_image(image: Image, width: int, height: int):
    image = image.resize((width, height), resample=Image.LANCZOS)
    image = np.array(image).astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)  # HWC -> NCHW, with a batch dimension
    image = torch.from_numpy(image)
    return 2.0 * image - 1.0  # rescale [0, 1] -> [-1, 1]


def preprocess_mask(mask: Image, width: int, height: int):
    mask = mask.convert("L")
    # the mask operates on latents, which are 1/8 the pixel resolution
    mask = mask.resize((width // 8, height // 8), resample=Image.LANCZOS)
    mask = np.array(mask).astype(np.float32) / 255.0
    mask = np.tile(mask, (4, 1, 1))  # repeat across the 4 latent channels
    # add a batch dimension; this transpose is the identity permutation, so it changes nothing
    mask = mask[None].transpose(0, 1, 2, 3)
    mask = torch.from_numpy(mask)
    return mask


class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
    """
    From https://github.com/huggingface/diffusers/pull/241
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPFeatureExtractor,
    ):
        super().__init__()
        scheduler = scheduler.set_format("pt")
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        init_image: Optional[torch.FloatTensor],
        mask: Optional[torch.FloatTensor],
        width: int,
        height: int,
        prompt_strength: float = 0.8,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
    ) -> Image:
        if isinstance(prompt, str):
            batch_size = 1
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            raise ValueError(
                f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
            )

        if prompt_strength < 0 or prompt_strength > 1:
            raise ValueError(
                f"The value of prompt_strength should be in [0.0, 1.0] but is {prompt_strength}"
            )

        if mask is not None and init_image is None:
            raise ValueError(
                "If mask is defined, then init_image also needs to be defined"
            )

        if width % 8 != 0 or height % 8 != 0:
            raise ValueError("Width and height must both be divisible by 8")

        # set timesteps
        accepts_offset = "offset" in set(
            inspect.signature(self.scheduler.set_timesteps).parameters.keys()
        )
        extra_set_kwargs = {}
        offset = 0
        if accepts_offset:
            offset = 1
            extra_set_kwargs["offset"] = 1

        self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)

        if init_image is not None:
            init_latents_orig, latents, init_timestep = self.latents_from_init_image(
                init_image,
                prompt_strength,
                offset,
                num_inference_steps,
                batch_size,
                generator,
            )
        else:
            latents = torch.randn(
                (batch_size, self.unet.in_channels, height // 8, width // 8),
                generator=generator,
                device=self.device,
            )
            init_timestep = num_inference_steps

        do_classifier_free_guidance = guidance_scale > 1.0
        text_embeddings = self.embed_text(
            prompt, do_classifier_free_guidance, batch_size
        )

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler; it will be ignored for other schedulers.
        # eta corresponds to η in the DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be in [0, 1]
        accepts_eta = "eta" in set(
            inspect.signature(self.scheduler.step).parameters.keys()
        )
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        mask_noise = torch.randn(latents.shape, generator=generator, device=self.device)

        # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
        if isinstance(self.scheduler, LMSDiscreteScheduler):
            latents = latents * self.scheduler.sigmas[0]

        t_start = max(num_inference_steps - init_timestep + offset, 0)
        for i, t in tqdm(enumerate(self.scheduler.timesteps[t_start:])):
            # expand the latents if we are doing classifier-free guidance
            latent_model_input = (
                torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            )

            if isinstance(self.scheduler, LMSDiscreteScheduler):
                sigma = self.scheduler.sigmas[i]
                latent_model_input = latent_model_input / ((sigma ** 2 + 1) ** 0.5)

            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input, t, encoder_hidden_states=text_embeddings
            )["sample"]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (
                    noise_pred_text - noise_pred_uncond
                )

            # compute the previous noisy sample x_t -> x_t-1
            if isinstance(self.scheduler, LMSDiscreteScheduler):
                latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs)[
                    "prev_sample"
                ]
            else:
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)[
                    "prev_sample"
                ]

            # replace the unmasked part with the original latents, with added noise
            if mask is not None:
                timesteps = self.scheduler.timesteps[t_start + i]
                timesteps = torch.tensor(
                    [timesteps] * batch_size, dtype=torch.long, device=self.device
                )
                noisy_init_latents = self.scheduler.add_noise(init_latents_orig, mask_noise, timesteps)
                latents = noisy_init_latents * mask + latents * (1 - mask)

        # scale and decode the image latents with the vae
        latents = 1 / 0.18215 * latents
        image = self.vae.decode(latents)

        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()

        # run safety checker
        safety_checker_input = self.feature_extractor(
            self.numpy_to_pil(image), return_tensors="pt"
        ).to(self.device)
        image, has_nsfw_concept = self.safety_checker(
            images=image, clip_input=safety_checker_input.pixel_values
        )

        image = self.numpy_to_pil(image)

        return {"sample": image, "nsfw_content_detected": has_nsfw_concept}

    def latents_from_init_image(
        self,
        init_image: torch.FloatTensor,
        prompt_strength: float,
        offset: int,
        num_inference_steps: int,
        batch_size: int,
        generator: Optional[torch.Generator],
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, int]:
        # encode the init image into latents and scale the latents
        init_latents = self.vae.encode(init_image.to(self.device)).sample()
        init_latents = 0.18215 * init_latents
        init_latents_orig = init_latents

        # expand init_latents to match the batch size
        init_latents = torch.cat([init_latents] * batch_size)

        # get the original timestep using init_timestep
        init_timestep = int(num_inference_steps * prompt_strength) + offset
        init_timestep = min(init_timestep, num_inference_steps)
        timesteps = self.scheduler.timesteps[-init_timestep]
        timesteps = torch.tensor(
            [timesteps] * batch_size, dtype=torch.long, device=self.device
        )

        # add noise to latents using the timesteps
        noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
        init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)

        return init_latents_orig, init_latents, init_timestep

    def embed_text(
        self,
        prompt: Union[str, List[str]],
        do_classifier_free_guidance: bool,
        batch_size: int,
    ) -> torch.FloatTensor:
        # get prompt text embeddings
        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]

        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier-free guidance.
        # get unconditional embeddings for classifier-free guidance
        if do_classifier_free_guidance:
            max_length = text_input.input_ids.shape[-1]
            uncond_input = self.tokenizer(
                [""] * batch_size,
                padding="max_length",
                max_length=max_length,
                return_tensors="pt",
            )
            uncond_embeddings = self.text_encoder(
                uncond_input.input_ids.to(self.device)
            )[0]

            # For classifier-free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        return text_embeddings
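To make the relationship between prompt_strength and the number of denoising steps concrete, here is a worked example mirroring the arithmetic in __call__ and latents_from_init_image; the values (50 steps, strength 0.8, offset 1 for PNDM) are assumed for illustration:

num_inference_steps = 50
prompt_strength = 0.8
offset = 1  # PNDM's set_timesteps accepts an offset, so offset becomes 1

init_timestep = min(int(num_inference_steps * prompt_strength) + offset, num_inference_steps)
t_start = max(num_inference_steps - init_timestep + offset, 0)
print(init_timestep, t_start)  # 41 10

So the init image is noised to the level of timesteps[-41], and the loop then denoises over timesteps[10:], i.e. the last 40 of the 50 scheduled steps; higher prompt_strength pushes t_start toward 0 and destroys more of the init image.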
predict.py
ADDED
@@ -0,0 +1,136 @@
import os
from typing import Optional, List

import torch
import torch.nn as nn
from torch import autocast
from diffusers import PNDMScheduler, LMSDiscreteScheduler
from PIL import Image
from cog import BasePredictor, Input, Path

from image_to_image import (
    StableDiffusionImg2ImgPipeline,
    preprocess_init_image,
    preprocess_mask,
)


def patch_conv(**patch):
    # monkey-patch torch.nn.Conv2d so every convolution is constructed with the given kwargs
    cls = torch.nn.Conv2d
    init = cls.__init__

    def __init__(self, *args, **kwargs):
        return init(self, *args, **kwargs, **patch)

    cls.__init__ = __init__


# circular padding makes convolutions wrap around at the image edges, so outputs tile seamlessly
patch_conv(padding_mode="circular")

MODEL_CACHE = "diffusers-cache"


class Predictor(BasePredictor):
    def setup(self):
        """Load the model into memory to make running multiple predictions efficient"""
        print("Loading pipeline...")
        scheduler = PNDMScheduler(
            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
        )
        self.pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
            "CompVis/stable-diffusion-v1-4",
            scheduler=scheduler,
            revision="fp16",
            torch_dtype=torch.float16,
            cache_dir=MODEL_CACHE,
            local_files_only=True,
        ).to("cuda")

    @torch.inference_mode()
    @torch.cuda.amp.autocast()
    def predict(
        self,
        prompt: str = Input(description="Input prompt", default=""),
        width: int = Input(
            description="Width of output image. Maximum size is 1024x768 or 768x1024 because of memory limits",
            choices=[128, 256, 512, 768, 1024],
            default=512,
        ),
        height: int = Input(
            description="Height of output image. Maximum size is 1024x768 or 768x1024 because of memory limits",
            choices=[128, 256, 512, 768, 1024],
            default=512,
        ),
        init_image: Path = Input(
            description="Initial image to generate variations of. Will be resized to the specified width and height",
            default=None,
        ),
        mask: Path = Input(
            description="Black and white image to use as mask for inpainting over init_image. Black pixels are inpainted and white pixels are preserved. Experimental feature, tends to work better with prompt strength of 0.5-0.7",
            default=None,
        ),
        prompt_strength: float = Input(
            description="Prompt strength when using init image. 1.0 corresponds to full destruction of information in init image",
            default=0.8,
        ),
        num_outputs: int = Input(
            description="Number of images to output", choices=[1, 4], default=1
        ),
        num_inference_steps: int = Input(
            description="Number of denoising steps", ge=1, le=500, default=50
        ),
        guidance_scale: float = Input(
            description="Scale for classifier-free guidance", ge=1, le=20, default=7.5
        ),
        seed: int = Input(
            description="Random seed. Leave blank to randomize the seed", default=None
        ),
    ) -> List[Path]:
        """Run a single prediction on the model"""
        if seed is None:
            seed = int.from_bytes(os.urandom(2), "big")
        print(f"Using seed: {seed}")

        if width == height == 1024:
            raise ValueError(
                "Maximum size is 1024x768 or 768x1024 pixels, because of memory limits. Please select a lower width or height."
            )

        if init_image:
            init_image = Image.open(init_image).convert("RGB")
            init_image = preprocess_init_image(init_image, width, height).to("cuda")

            # use PNDM with init images
            scheduler = PNDMScheduler(
                beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
            )
        else:
            # use LMS without init images
            scheduler = LMSDiscreteScheduler(
                beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
            )

        self.pipe.scheduler = scheduler

        if mask:
            mask = Image.open(mask).convert("RGB")
            mask = preprocess_mask(mask, width, height).to("cuda")

        generator = torch.Generator("cuda").manual_seed(seed)
        output = self.pipe(
            prompt=[prompt] * num_outputs if prompt is not None else None,
            init_image=init_image,
            mask=mask,
            width=width,
            height=height,
            prompt_strength=prompt_strength,
            guidance_scale=guidance_scale,
            generator=generator,
            num_inference_steps=num_inference_steps,
        )
        if any(output["nsfw_content_detected"]):
            raise Exception("NSFW content detected, please try a different prompt")

        output_paths = []
        for i, sample in enumerate(output["sample"]):
            output_path = f"/tmp/out-{i}.png"
            sample.save(output_path)
            output_paths.append(Path(output_path))

        return output_paths
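For quick local testing outside of Cog's HTTP server, the predictor can also be driven directly. A minimal sketch, assuming a CUDA device and that script/download-weights has already populated diffusers-cache; the prompt and seed are arbitrary placeholders:

from predict import Predictor

predictor = Predictor()
predictor.setup()
paths = predictor.predict(
    prompt="a pebble texture",  # placeholder prompt
    width=512,
    height=512,
    init_image=None,
    mask=None,
    prompt_strength=0.8,
    num_outputs=1,
    num_inference_steps=50,
    guidance_scale=7.5,
    seed=42,
)
print(paths)  # e.g. [Path('/tmp/out-0.png')]

Passing every argument explicitly sidesteps the cog Input() defaults, which are only resolved when Cog invokes the predictor.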
script/download-weights
ADDED
@@ -0,0 +1,18 @@
#!/usr/bin/env python

import os
import sys

import torch
from diffusers import StableDiffusionPipeline

os.makedirs("diffusers-cache", exist_ok=True)


# download the fp16 weights into the local cache that predict.py later
# loads from with local_files_only=True
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    cache_dir="diffusers-cache",
    revision="fp16",
    torch_dtype=torch.float16,
    use_auth_token=sys.argv[1],
)
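The script takes a Hugging Face access token as its only argument (it is passed straight through as use_auth_token), so a typical invocation, with the token as a placeholder, looks like:

    ./script/download-weights <your-huggingface-token>

Run it once before building or predicting, since Predictor.setup() loads the pipeline with local_files_only=True and will fail on an empty cache.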