jayparmr committed
Commit 86248f3
Parent(s): 0daeeb0

Upload folder using huggingface_hub

inference.py CHANGED
@@ -15,7 +15,8 @@ from internals.util.args import apply_style_args
 from internals.util.avatar import Avatar
 from internals.util.cache import auto_clear_cuda_and_gc, clear_cuda
 from internals.util.commons import pickPoses, upload_image, upload_images
-from internals.util.config import set_configs_from_task, set_root_dir
+from internals.util.config import (num_return_sequences, set_configs_from_task,
+                                   set_root_dir)
 from internals.util.failure_hander import FailureHandler
 from internals.util.lora_style import LoraStyle
 from internals.util.slack import Slack
@@ -23,7 +24,6 @@ from internals.util.slack import Slack
 torch.backends.cudnn.benchmark = True
 torch.backends.cuda.matmul.allow_tf32 = True
 
-num_return_sequences = 4  # the number of results to generate
 auto_mode = False
 
 prompt_modifier = PromptModifier(num_of_sequences=num_return_sequences)
@@ -57,10 +57,6 @@ def get_patched_prompt(task: Task):
     ori_prompt = [task.get_prompt()] * num_return_sequences
 
     class_name = None
-    # if task.get_imageUrl():
-    #     class_name = img_classifier.classify(
-    #         task.get_imageUrl(), task.get_width(), task.get_height()
-    #     )
     add_style_and_character(ori_prompt, class_name)
     add_style_and_character(prompt, class_name)
 
@@ -69,6 +65,54 @@ def get_patched_prompt(task: Task):
     return (prompt, ori_prompt)
 
 
+def get_patched_prompt_text2img(task: Task) -> Text2Img.Params:
+    def add_style_and_character(prompt: str, prepend: str = ""):
+        prompt = avatar.add_code_names(prompt)
+        prompt = lora_style.prepend_style_to_prompt(prompt, task.get_style())
+        prompt = prepend + prompt
+        return prompt
+
+    if task.get_prompt_left() and task.get_prompt_right():
+        # prepend = "2characters, "
+        prepend = ""
+        if task.is_prompt_engineering():
+            mod_prompt = prompt_modifier.modify(task.get_prompt())
+        else:
+            mod_prompt = [task.get_prompt()] * num_return_sequences
+
+        prompt, prompt_left, prompt_right = [], [], []
+        for i in range(len(mod_prompt)):
+            mp = mod_prompt[i].replace(task.get_prompt(), "")
+            prompt.append(add_style_and_character(task.get_prompt(), prepend) + mp)
+            prompt_left.append(
+                add_style_and_character(task.get_prompt_left(), prepend) + mp
+            )
+            prompt_right.append(
+                add_style_and_character(task.get_prompt_right(), prepend) + mp
+            )
+
+        params = Text2Img.Params(
+            prompt=prompt,
+            prompt_left=prompt_left,
+            prompt_right=prompt_right,
+        )
+    else:
+        if task.is_prompt_engineering():
+            mod_prompt = prompt_modifier.modify(task.get_prompt())
+        else:
+            mod_prompt = [task.get_prompt()] * num_return_sequences
+        mod_prompt = [add_style_and_character(mp) for mp in mod_prompt]
+
+        params = Text2Img.Params(
+            prompt=[add_style_and_character(task.get_prompt())] * num_return_sequences,
+            modified_prompt=mod_prompt,
+        )
+
+    print(params)
+
+    return params
+
+
 def get_patched_prompt_tile_upscale(task: Task):
     if task.get_prompt():
         prompt = task.get_prompt()
@@ -164,6 +208,72 @@ def tile_upscale(task: Task):
     }
 
 
+@update_db
+@auto_clear_cuda_and_gc(controlnet)
+@slack.auto_send_alert
+def scribble(task: Task):
+    prompt, _ = get_patched_prompt(task)
+
+    controlnet.load_scribble()
+
+    lora_patcher = lora_style.get_patcher(controlnet.pipe2, task.get_style())
+    lora_patcher.patch()
+
+    images, has_nsfw = controlnet.process_scribble(
+        imageUrl=task.get_imageUrl(),
+        seed=task.get_seed(),
+        steps=task.get_steps(),
+        width=task.get_width(),
+        height=task.get_height(),
+        prompt=prompt,
+        negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
+    )
+
+    generated_image_urls = upload_images(images, "_scribble", task.get_taskId())
+
+    lora_patcher.cleanup()
+    controlnet.cleanup()
+
+    return {
+        "modified_prompts": prompt,
+        "generated_image_urls": generated_image_urls,
+        "has_nsfw": has_nsfw,
+    }
+
+
+@update_db
+@auto_clear_cuda_and_gc(controlnet)
+@slack.auto_send_alert
+def linearart(task: Task):
+    prompt, _ = get_patched_prompt(task)
+
+    controlnet.load_linearart()
+
+    lora_patcher = lora_style.get_patcher(controlnet.pipe2, task.get_style())
+    lora_patcher.patch()
+
+    images, has_nsfw = controlnet.process_linearart(
+        imageUrl=task.get_imageUrl(),
+        seed=task.get_seed(),
+        steps=task.get_steps(),
+        width=task.get_width(),
+        height=task.get_height(),
+        prompt=prompt,
+        negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
+    )
+
+    generated_image_urls = upload_images(images, "_linearart", task.get_taskId())
+
+    lora_patcher.cleanup()
+    controlnet.cleanup()
+
+    return {
+        "modified_prompts": prompt,
+        "generated_image_urls": generated_image_urls,
+        "has_nsfw": has_nsfw,
+    }
+
+
 @update_db
 @auto_clear_cuda_and_gc(controlnet)
 @slack.auto_send_alert
@@ -207,7 +317,7 @@ def pose(task: Task, s3_outkey: str = "_pose", poses: Optional[list] = None):
 @auto_clear_cuda_and_gc(controlnet)
 @slack.auto_send_alert
 def text2img(task: Task):
-    prompt, ori_prompt = get_patched_prompt(task)
+    params = get_patched_prompt_text2img(task)
 
     lora_patcher = lora_style.get_patcher(text2img_pipe.pipe, task.get_style())
    lora_patcher.patch()
@@ -215,13 +325,12 @@ def text2img(task: Task):
     torch.manual_seed(task.get_seed())
 
     images, has_nsfw = text2img_pipe.process(
-        prompt=ori_prompt,
-        modified_prompts=prompt,
+        params=params,
         num_inference_steps=task.get_steps(),
         guidance_scale=7.5,
         height=task.get_height(),
         width=task.get_width(),
-        negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
+        negative_prompt=task.get_negative_prompt(),
         iteration=task.get_iteration(),
         **lora_patcher.kwargs(),
     )
@@ -231,7 +340,7 @@ def text2img(task: Task):
     lora_patcher.cleanup()
 
     return {
-        "modified_prompts": prompt,
+        **params.__dict__,
         "generated_image_urls": generated_image_urls,
         "has_nsfw": has_nsfw,
     }
@@ -361,6 +470,10 @@ def predict_fn(data, pipe):
             return tile_upscale(task)
         elif task_type == TaskType.INPAINT:
             return inpaint(task)
+        elif task_type == TaskType.SCRIBBLE:
+            return scribble(task)
+        elif task_type == TaskType.LINEARART:
+            return linearart(task)
         else:
             raise Exception("Invalid task type")
    except Exception as e:
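The new SCRIBBLE and LINEARART branches in predict_fn dispatch exactly like the existing task types. Below is a minimal, illustrative request sketch; the real payload schema is not part of this diff, so the dict keys are hypothetical and only mirror the Task getters that scribble()/linearart() call.

```python
# Hypothetical payload sketch for the new SCRIBBLE branch of predict_fn.
# Keys mirror the Task getters used by scribble() (get_prompt, get_imageUrl,
# get_seed, get_steps, ...) and are NOT a documented request schema.
from internals.data.task import Task, TaskType

data = {
    "task_type": TaskType.SCRIBBLE.value,  # or TaskType.LINEARART.value
    "prompt": "a cozy cabin in the woods, watercolor",
    "imageUrl": "https://example.com/rough_sketch.png",  # placeholder URL
    "seed": 42,
    "steps": 30,
    "width": 512,
    "height": 512,
}
task = Task(data)  # assumption: Task wraps a plain dict, as its getters suggest
# predict_fn(data, pipe) would route this to scribble(task), which loads the
# scribble ControlNet, applies the LoRA style patcher and uploads the results.
```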
inference2.py CHANGED
@@ -12,16 +12,18 @@ from internals.pipelines.safety_checker import SafetyChecker
 from internals.pipelines.upscaler import Upscaler
 from internals.util.avatar import Avatar
 from internals.util.cache import clear_cuda
-from internals.util.commons import (construct_default_s3_url, upload_image,
-                                    upload_images)
-from internals.util.config import set_configs_from_task, set_root_dir
+from internals.util.commons import construct_default_s3_url, upload_image, upload_images
+from internals.util.config import (
+    num_return_sequences,
+    set_configs_from_task,
+    set_root_dir,
+)
 from internals.util.failure_hander import FailureHandler
 from internals.util.slack import Slack
 
 torch.backends.cudnn.benchmark = True
 torch.backends.cuda.matmul.allow_tf32 = True
 
-num_return_sequences = 4
 auto_mode = False
 
 slack = Slack()
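Both entry points now import num_return_sequences from internals.util.config instead of defining a local constant, so the batch size is configured in one place. The config module itself is not part of this commit; the sketch below is only an assumption about its shape, inferred from the names imported here.

```python
# Hypothetical shape of internals/util/config.py (NOT included in this commit),
# sketched only from the names imported by inference.py and inference2.py.
num_return_sequences = 4  # number of results generated per request


def set_root_dir(path: str) -> None:
    # assumption: stores the model/asset root directory for the process
    ...


def set_configs_from_task(task) -> None:
    # assumption: copies per-task settings (model, style, ...) into module state
    ...
```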
internals/data/task.py CHANGED
@@ -14,6 +14,8 @@ class TaskType(Enum):
     UPSCALE_IMAGE = "UPSCALE_IMAGE"
     TILE_UPSCALE = "TILE_UPSCALE"
     OBJECT_REMOVAL = "OBJECT_REMOVAL"
+    SCRIBBLE = "SCRIBBLE"
+    LINEARART = "LINEARART"
 
 
 class ModelType(Enum):
@@ -45,6 +47,12 @@ class Task:
     def get_prompt(self) -> str:
         return self.__data.get("prompt", "")
 
+    def get_prompt_left(self) -> str:
+        return self.__data.get("prompt_left", "")
+
+    def get_prompt_right(self) -> str:
+        return self.__data.get("prompt_right", "")
+
     def get_userId(self) -> str:
         return self.__data.get("userId", "")
 
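The new getters fall back to empty strings, so payloads without the two-character fields keep working unchanged. A small illustration of the intended payload shape, assuming Task is built from a plain dict as its getters suggest (the constructor is not shown in this diff):

```python
# Illustrative only: the dict keys simply mirror the .get(...) calls in
# get_prompt_left()/get_prompt_right(); the real request schema may differ.
payload = {
    "prompt": "two friends at a campfire",
    "prompt_left": "a knight in silver armor",
    "prompt_right": "a wizard with a staff",
}
# When "prompt_left"/"prompt_right" are absent, both getters return "",
# so get_patched_prompt_text2img() takes the ordinary single-prompt path.
```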
internals/pipelines/commons.py CHANGED
@@ -1,3 +1,4 @@
+from dataclasses import dataclass
 from typing import Any, Callable, Dict, List, Optional, Union
 
 import torch
@@ -6,6 +7,7 @@ from diffusers import StableDiffusionImg2ImgPipeline
 from internals.data.result import Result
 from internals.pipelines.twoStepPipeline import two_step_pipeline
 from internals.util.commons import disable_safety_checker, download_image
+from internals.util.config import num_return_sequences
 
 
 class AbstractPipeline:
@@ -17,6 +19,13 @@ class AbstractPipeline:
 
 
 class Text2Img(AbstractPipeline):
+    @dataclass
+    class Params:
+        prompt: List[str] = None
+        modified_prompt: List[str] = None
+        prompt_left: List[str] = None
+        prompt_right: List[str] = None
+
     def load(self, model_dir: str):
         self.pipe = two_step_pipeline.from_pretrained(
             model_dir, torch_dtype=torch.float16
@@ -33,14 +42,13 @@ class Text2Img(AbstractPipeline):
     @torch.inference_mode()
     def process(
         self,
-        prompt: Union[str, List[str]] = None,
-        modified_prompts: Union[str, List[str]] = None,
+        params: Params,
         height: Optional[int] = None,
         width: Optional[int] = None,
         num_inference_steps: int = 50,
         guidance_scale: float = 7.5,
-        negative_prompt: Optional[Union[str, List[str]]] = None,
-        num_images_per_prompt: Optional[int] = 1,
+        negative_prompt: Optional[str] = None,
+        num_images_per_prompt: int = 1,
         eta: float = 0.0,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.FloatTensor] = None,
@@ -53,27 +61,54 @@ class Text2Img(AbstractPipeline):
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         iteration: float = 3.0,
     ):
-        result = self.pipe.two_step_pipeline(
-            prompt=prompt,
-            modified_prompts=modified_prompts,
-            height=height,
-            width=width,
-            num_inference_steps=num_inference_steps,
-            guidance_scale=guidance_scale,
-            negative_prompt=negative_prompt,
-            num_images_per_prompt=num_images_per_prompt,
-            eta=eta,
-            generator=generator,
-            latents=latents,
-            prompt_embeds=prompt_embeds,
-            negative_prompt_embeds=negative_prompt_embeds,
-            output_type=output_type,
-            return_dict=return_dict,
-            callback=callback,
-            callback_steps=callback_steps,
-            cross_attention_kwargs=cross_attention_kwargs,
-            iteration=iteration,
-        )
+        prompt = params.prompt
+
+        if params.prompt_left and params.prompt_right:
+            # multi-character pipelines
+            prompt = [params.prompt[0], params.prompt_left[0], params.prompt_right[0]]
+            result = self.pipe.multi_character_diffusion(
+                prompt=prompt,
+                pos=["1:1-0:0", "1:2-0:0", "1:2-0:1"],
+                mix_val=[0.2, 0.8, 0.8],
+                height=height,
+                width=width,
+                num_inference_steps=num_inference_steps,
+                guidance_scale=guidance_scale,
+                negative_prompt=[negative_prompt or ""] * len(prompt),
+                num_images_per_prompt=num_return_sequences,
+                eta=eta,
+                # generator=generator,
+                output_type=output_type,
+                return_dict=return_dict,
+                callback=callback,
+                callback_steps=callback_steps,
+            )
+        else:
+            # two step pipeline
+            modified_prompt = params.modified_prompt
+
+            result = self.pipe.two_step_pipeline(
+                prompt=prompt,
+                modified_prompts=modified_prompt,
+                height=height,
+                width=width,
+                num_inference_steps=num_inference_steps,
+                guidance_scale=guidance_scale,
+                negative_prompt=[negative_prompt or ""] * num_return_sequences,
+                num_images_per_prompt=num_images_per_prompt,
+                eta=eta,
+                generator=generator,
+                latents=latents,
+                prompt_embeds=prompt_embeds,
+                negative_prompt_embeds=negative_prompt_embeds,
+                output_type=output_type,
+                return_dict=return_dict,
+                callback=callback,
+                callback_steps=callback_steps,
+                cross_attention_kwargs=cross_attention_kwargs,
+                iteration=iteration,
+            )
+
         return Result.from_result(result)
 
 
internals/pipelines/controlnets.py CHANGED
@@ -1,20 +1,20 @@
-from typing import List
+from typing import List, Union
 
 import cv2
 import numpy as np
 import torch
-from controlnet_aux import OpenposeDetector
-from diffusers import (
-    ControlNetModel,
-    DiffusionPipeline,
-    StableDiffusionControlNetPipeline,
-    UniPCMultistepScheduler,
-)
+from controlnet_aux import HEDdetector, LineartDetector, OpenposeDetector
+from diffusers import (ControlNetModel, DiffusionPipeline,
+                       StableDiffusionControlNetPipeline,
+                       UniPCMultistepScheduler)
 from PIL import Image
+from torch.nn import Linear
 from tqdm import gui
 
 from internals.data.result import Result
 from internals.pipelines.commons import AbstractPipeline
+from internals.pipelines.tileUpscalePipeline import \
+    StableDiffusionControlNetImg2ImgPipeline
 from internals.util.cache import clear_cuda_and_gc
 from internals.util.commons import download_image
 
@@ -27,11 +27,10 @@ class ControlNet(AbstractPipeline):
         self.load_canny()
 
         # controlnet pipeline for canny and pose
-        pipe = DiffusionPipeline.from_pretrained(
+        pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
             model_dir,
             controlnet=self.controlnet,
             torch_dtype=torch.float16,
-            custom_pipeline="stable_diffusion_controlnet_img2img",
         ).to("cuda")
         pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
         pipe.enable_model_cpu_offload()
@@ -62,7 +61,7 @@ class ControlNet(AbstractPipeline):
         if self.__current_task_name == "pose":
             return
         pose = ControlNetModel.from_pretrained(
-            "lllyasviel/control_v11p_sd15_openpose", torch_dtype=torch.float16
+            "lllyasviel/sd-controlnet-openpose", torch_dtype=torch.float16
         ).to("cuda")
         self.__current_task_name = "pose"
         self.controlnet = pose
@@ -86,6 +85,35 @@ class ControlNet(AbstractPipeline):
         self.pipe2.controlnet = tile_upscaler
         clear_cuda_and_gc()
 
+    def load_scribble(self):
+        if self.__current_task_name == "scribble":
+            return
+        scribble = ControlNetModel.from_pretrained(
+            "lllyasviel/control_v11p_sd15_scribble", torch_dtype=torch.float16
+        ).to("cuda")
+        self.__current_task_name = "scribble"
+        self.controlnet = scribble
+        if hasattr(self, "pipe"):
+            self.pipe.controlnet = scribble
+        if hasattr(self, "pipe2"):
+            self.pipe2.controlnet = scribble
+        clear_cuda_and_gc()
+
+    def load_linearart(self):
+        if self.__current_task_name == "linearart":
+            return
+        linearart = ControlNetModel.from_pretrained(
+            "ControlNet-1-1-preview/control_v11p_sd15_lineart",
+            torch_dtype=torch.float16,
+        ).to("cuda")
+        self.__current_task_name = "linearart"
+        self.controlnet = linearart
+        if hasattr(self, "pipe"):
+            self.pipe.controlnet = linearart
+        if hasattr(self, "pipe2"):
+            self.pipe2.controlnet = linearart
+        clear_cuda_and_gc()
+
     def cleanup(self):
         self.pipe.controlnet = None
         self.pipe2.controlnet = None
@@ -191,12 +219,84 @@ class ControlNet(AbstractPipeline):
         )
         return Result.from_result(result)
 
+    @torch.inference_mode()
+    def process_scribble(
+        self,
+        imageUrl: str,
+        prompt: Union[str, List[str]],
+        negative_prompt: Union[str, List[str]],
+        steps: int,
+        seed: int,
+        height: int,
+        width: int,
+        guidance_scale: float = 7.5,
+    ):
+        if self.__current_task_name != "scribble":
+            raise Exception("ControlNet is not loaded with scribble model")
+
+        torch.manual_seed(seed)
+
+        init_image = download_image(imageUrl).resize((width, height))
+        condition_image = self.__scribble_condition_image(init_image)
+
+        result = self.pipe2.__call__(
+            image=condition_image,
+            prompt=prompt,
+            num_inference_steps=steps,
+            negative_prompt=negative_prompt,
+            height=height,
+            width=width,
+            guidance_scale=guidance_scale,
+        )
+        return Result.from_result(result)
+
+    @torch.inference_mode()
+    def process_linearart(
+        self,
+        imageUrl: str,
+        prompt: Union[str, List[str]],
+        negative_prompt: Union[str, List[str]],
+        steps: int,
+        seed: int,
+        height: int,
+        width: int,
+        guidance_scale: float = 7.5,
+    ):
+        if self.__current_task_name != "linearart":
+            raise Exception("ControlNet is not loaded with linearart model")
+
+        torch.manual_seed(seed)
+
+        init_image = download_image(imageUrl).resize((width, height))
+        condition_image = self.__linearart_condition_image(init_image)
+
+        result = self.pipe2.__call__(
+            image=condition_image,
+            prompt=prompt,
+            num_inference_steps=steps,
+            negative_prompt=negative_prompt,
+            height=height,
+            width=width,
+            guidance_scale=guidance_scale,
+        )
+        return Result.from_result(result)
+
     def detect_pose(self, imageUrl: str) -> Image.Image:
         detector = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
         image = download_image(imageUrl)
         image = detector.__call__(image, hand_and_face=True)
         return image
 
+    def __scribble_condition_image(self, image: Image.Image) -> Image.Image:
+        processor = HEDdetector.from_pretrained("lllyasviel/Annotators")
+        image = processor.__call__(input_image=image, scribble=True)
+        return image
+
+    def __linearart_condition_image(self, image: Image.Image) -> Image.Image:
+        processor = LineartDetector.from_pretrained("lllyasviel/Annotators")
+        image = processor.__call__(input_image=image)
+        return image
+
     def __canny_detect_edge(self, image: Image.Image) -> Image.Image:
         image_array = np.array(image)
 
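process_scribble and process_linearart follow the same pattern: load the matching ControlNet weights, derive a condition image with a controlnet_aux detector (HED for scribble, Lineart for linearart), then run pipe2. A minimal standalone sketch of the scribble path is below; it assumes ControlNet exposes the same load(model_dir) entry point as the other pipelines (that method sits outside this diff's context lines), and the URL and model directory are placeholders.

```python
# Sketch of the new scribble path in isolation (placeholder paths/URLs).
from internals.pipelines.controlnets import ControlNet

controlnet = ControlNet()
controlnet.load("path/to/model")  # assumption: AbstractPipeline-style load(model_dir)
controlnet.load_scribble()        # swaps in lllyasviel/control_v11p_sd15_scribble

images, has_nsfw = controlnet.process_scribble(
    imageUrl="https://example.com/rough_sketch.png",  # placeholder input image
    prompt=["a red sports car, studio lighting"] * 4,
    negative_prompt=["low quality"] * 4,
    steps=30,
    seed=42,
    width=512,
    height=512,
)
controlnet.cleanup()
```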
internals/pipelines/prompt_modifier.py CHANGED
@@ -18,18 +18,20 @@ class PromptModifier:
         self.prompter_tokenizer.pad_token = self.prompter_tokenizer.eos_token
         self.prompter_tokenizer.padding_side = "left"
 
-    def modify(self, text: str) -> List[str]:
+    def modify(self, text: str, num_of_sequences: Optional[int] = None) -> List[str]:
         eos_id = self.prompter_tokenizer.eos_token_id
         # restricted_words_list = ["octane", "cyber"]
         # restricted_words_token_ids = prompter_tokenizer(
         #     restricted_words_list, add_special_tokens=False
         # ).input_ids
 
+        num_of_sequences = num_of_sequences or self.__num_of_sequences
+
         generation_config = GenerationConfig(
             do_sample=False,
             max_new_tokens=75,
             num_beams=4,
-            num_return_sequences=self.__num_of_sequences,
+            num_return_sequences=num_of_sequences,
             eos_token_id=eos_id,
             pad_token_id=eos_id,
             length_penalty=-1.0,
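modify() now accepts an optional num_of_sequences that overrides the value given at construction, falling back to the constructor default when omitted. A small usage sketch, assuming the underlying prompter model has already been initialized by the class (model loading is outside this hunk):

```python
# Sketch: the new per-call override for the number of prompt variants.
from internals.pipelines.prompt_modifier import PromptModifier

prompt_modifier = PromptModifier(num_of_sequences=4)
# Default: returns 4 expanded prompts.
variants = prompt_modifier.modify("a castle on a hill")
# Override: returns 2 expanded prompts for this call only.
short_variants = prompt_modifier.modify("a castle on a hill", num_of_sequences=2)
```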
internals/pipelines/tileUpscalePipeline.py ADDED
@@ -0,0 +1,1106 @@
1
+ import inspect
2
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
3
+
4
+ import numpy as np
5
+ import PIL.Image
6
+ import torch
7
+ from diffusers import (AutoencoderKL, ControlNetModel, DiffusionPipeline,
8
+ UNet2DConditionModel)
9
+ from diffusers.loaders import LoraLoaderMixin
10
+ from diffusers.pipelines.stable_diffusion import (
11
+ StableDiffusionPipelineOutput, StableDiffusionSafetyChecker)
12
+ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import \
13
+ MultiControlNetModel
14
+ from diffusers.schedulers import KarrasDiffusionSchedulers
15
+ from diffusers.utils import (PIL_INTERPOLATION, is_accelerate_available,
16
+ is_accelerate_version, randn_tensor,
17
+ replace_example_docstring)
18
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
19
+
20
+ EXAMPLE_DOC_STRING = """
21
+ Examples:
22
+ ```py
23
+ >>> import numpy as np
24
+ >>> import torch
25
+ >>> from PIL import Image
26
+ >>> from diffusers import ControlNetModel, UniPCMultistepScheduler
27
+ >>> from diffusers.utils import load_image
28
+
29
+ >>> input_image = load_image("https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png")
30
+
31
+ >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
32
+
33
+ >>> pipe_controlnet = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
34
+ "runwayml/stable-diffusion-v1-5",
35
+ controlnet=controlnet,
36
+ safety_checker=None,
37
+ torch_dtype=torch.float16
38
+ )
39
+
40
+ >>> pipe_controlnet.scheduler = UniPCMultistepScheduler.from_config(pipe_controlnet.scheduler.config)
41
+ >>> pipe_controlnet.enable_xformers_memory_efficient_attention()
42
+ >>> pipe_controlnet.enable_model_cpu_offload()
43
+
44
+ # using image with edges for our canny controlnet
45
+ >>> control_image = load_image(
46
+ "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/vermeer_canny_edged.png")
47
+
48
+
49
+ >>> result_img = pipe_controlnet(controlnet_conditioning_image=control_image,
50
+ image=input_image,
51
+ prompt="an android robot, cyberpank, digitl art masterpiece",
52
+ num_inference_steps=20).images[0]
53
+
54
+ >>> result_img.show()
55
+ ```
56
+ """
57
+
58
+
59
+ def prepare_image(image):
60
+ if isinstance(image, torch.Tensor):
61
+ # Batch single image
62
+ if image.ndim == 3:
63
+ image = image.unsqueeze(0)
64
+
65
+ image = image.to(dtype=torch.float32)
66
+ else:
67
+ # preprocess image
68
+ if isinstance(image, (PIL.Image.Image, np.ndarray)):
69
+ image = [image]
70
+
71
+ if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
72
+ image = [np.array(i.convert("RGB"))[None, :] for i in image]
73
+ image = np.concatenate(image, axis=0)
74
+ elif isinstance(image, list) and isinstance(image[0], np.ndarray):
75
+ image = np.concatenate([i[None, :] for i in image], axis=0)
76
+
77
+ image = image.transpose(0, 3, 1, 2)
78
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
79
+
80
+ return image
81
+
82
+
83
+ def prepare_controlnet_conditioning_image(
84
+ controlnet_conditioning_image,
85
+ width,
86
+ height,
87
+ batch_size,
88
+ num_images_per_prompt,
89
+ device,
90
+ dtype,
91
+ do_classifier_free_guidance,
92
+ ):
93
+ if not isinstance(controlnet_conditioning_image, torch.Tensor):
94
+ if isinstance(controlnet_conditioning_image, PIL.Image.Image):
95
+ controlnet_conditioning_image = [controlnet_conditioning_image]
96
+
97
+ if isinstance(controlnet_conditioning_image[0], PIL.Image.Image):
98
+ controlnet_conditioning_image = [
99
+ np.array(
100
+ i.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
101
+ )[None, :]
102
+ for i in controlnet_conditioning_image
103
+ ]
104
+ controlnet_conditioning_image = np.concatenate(
105
+ controlnet_conditioning_image, axis=0
106
+ )
107
+ controlnet_conditioning_image = (
108
+ np.array(controlnet_conditioning_image).astype(np.float32) / 255.0
109
+ )
110
+ controlnet_conditioning_image = controlnet_conditioning_image.transpose(
111
+ 0, 3, 1, 2
112
+ )
113
+ controlnet_conditioning_image = torch.from_numpy(
114
+ controlnet_conditioning_image
115
+ )
116
+ elif isinstance(controlnet_conditioning_image[0], torch.Tensor):
117
+ controlnet_conditioning_image = torch.cat(
118
+ controlnet_conditioning_image, dim=0
119
+ )
120
+
121
+ image_batch_size = controlnet_conditioning_image.shape[0]
122
+
123
+ if image_batch_size == 1:
124
+ repeat_by = batch_size
125
+ else:
126
+ # image batch size is the same as prompt batch size
127
+ repeat_by = num_images_per_prompt
128
+
129
+ controlnet_conditioning_image = controlnet_conditioning_image.repeat_interleave(
130
+ repeat_by, dim=0
131
+ )
132
+
133
+ controlnet_conditioning_image = controlnet_conditioning_image.to(
134
+ device=device, dtype=dtype
135
+ )
136
+
137
+ if do_classifier_free_guidance:
138
+ controlnet_conditioning_image = torch.cat([controlnet_conditioning_image] * 2)
139
+
140
+ return controlnet_conditioning_image
141
+
142
+
143
+ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, LoraLoaderMixin):
144
+ """
145
+ Inspired by: https://github.com/haofanwang/ControlNet-for-Diffusers/
146
+ """
147
+
148
+ _optional_components = ["safety_checker", "feature_extractor"]
149
+
150
+ def __init__(
151
+ self,
152
+ vae: AutoencoderKL,
153
+ text_encoder: CLIPTextModel,
154
+ tokenizer: CLIPTokenizer,
155
+ unet: UNet2DConditionModel,
156
+ controlnet: Union[
157
+ ControlNetModel,
158
+ List[ControlNetModel],
159
+ Tuple[ControlNetModel],
160
+ MultiControlNetModel,
161
+ ],
162
+ scheduler: KarrasDiffusionSchedulers,
163
+ safety_checker: StableDiffusionSafetyChecker,
164
+ feature_extractor: CLIPImageProcessor,
165
+ requires_safety_checker: bool = True,
166
+ ):
167
+ super().__init__()
168
+
169
+ if safety_checker is None and requires_safety_checker:
170
+ print(
171
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
172
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
173
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
174
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
175
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
176
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
177
+ )
178
+
179
+ if safety_checker is not None and feature_extractor is None:
180
+ raise ValueError(
181
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
182
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
183
+ )
184
+
185
+ if isinstance(controlnet, (list, tuple)):
186
+ controlnet = MultiControlNetModel(controlnet)
187
+
188
+ self.register_modules(
189
+ vae=vae,
190
+ text_encoder=text_encoder,
191
+ tokenizer=tokenizer,
192
+ unet=unet,
193
+ controlnet=controlnet,
194
+ scheduler=scheduler,
195
+ safety_checker=safety_checker,
196
+ feature_extractor=feature_extractor,
197
+ )
198
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
199
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
200
+
201
+ def enable_vae_slicing(self):
202
+ r"""
203
+ Enable sliced VAE decoding.
204
+
205
+ When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
206
+ steps. This is useful to save some memory and allow larger batch sizes.
207
+ """
208
+ self.vae.enable_slicing()
209
+
210
+ def disable_vae_slicing(self):
211
+ r"""
212
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
213
+ computing decoding in one step.
214
+ """
215
+ self.vae.disable_slicing()
216
+
217
+ def enable_sequential_cpu_offload(self, gpu_id=0):
218
+ r"""
219
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
220
+ text_encoder, vae, controlnet, and safety checker have their state dicts saved to CPU and then are moved to a
221
+ `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
222
+ Note that offloading happens on a submodule basis. Memory savings are higher than with
223
+ `enable_model_cpu_offload`, but performance is lower.
224
+ """
225
+ if is_accelerate_available():
226
+ from accelerate import cpu_offload
227
+ else:
228
+ raise ImportError("Please install accelerate via `pip install accelerate`")
229
+
230
+ device = torch.device(f"cuda:{gpu_id}")
231
+
232
+ for cpu_offloaded_model in [
233
+ self.unet,
234
+ self.text_encoder,
235
+ self.vae,
236
+ self.controlnet,
237
+ ]:
238
+ cpu_offload(cpu_offloaded_model, device)
239
+
240
+ if self.safety_checker is not None:
241
+ cpu_offload(
242
+ self.safety_checker, execution_device=device, offload_buffers=True
243
+ )
244
+
245
+ def enable_model_cpu_offload(self, gpu_id=0):
246
+ r"""
247
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
248
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
249
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
250
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
251
+ """
252
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
253
+ from accelerate import cpu_offload_with_hook
254
+ else:
255
+ raise ImportError(
256
+ "`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher."
257
+ )
258
+
259
+ device = torch.device(f"cuda:{gpu_id}")
260
+
261
+ hook = None
262
+ for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
263
+ _, hook = cpu_offload_with_hook(
264
+ cpu_offloaded_model, device, prev_module_hook=hook
265
+ )
266
+
267
+ if self.safety_checker is not None:
268
+ # the safety checker can offload the vae again
269
+ _, hook = cpu_offload_with_hook(
270
+ self.safety_checker, device, prev_module_hook=hook
271
+ )
272
+
273
+ # control net hook has be manually offloaded as it alternates with unet
274
+ cpu_offload_with_hook(self.controlnet, device)
275
+
276
+ # We'll offload the last model manually.
277
+ self.final_offload_hook = hook
278
+
279
+ @property
280
+ def _execution_device(self):
281
+ r"""
282
+ Returns the device on which the pipeline's models will be executed. After calling
283
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
284
+ hooks.
285
+ """
286
+ if not hasattr(self.unet, "_hf_hook"):
287
+ return self.device
288
+ for module in self.unet.modules():
289
+ if (
290
+ hasattr(module, "_hf_hook")
291
+ and hasattr(module._hf_hook, "execution_device")
292
+ and module._hf_hook.execution_device is not None
293
+ ):
294
+ return torch.device(module._hf_hook.execution_device)
295
+ return self.device
296
+
297
+ def _encode_prompt(
298
+ self,
299
+ prompt,
300
+ device,
301
+ num_images_per_prompt,
302
+ do_classifier_free_guidance,
303
+ negative_prompt=None,
304
+ prompt_embeds: Optional[torch.FloatTensor] = None,
305
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
306
+ ):
307
+ r"""
308
+ Encodes the prompt into text encoder hidden states.
309
+
310
+ Args:
311
+ prompt (`str` or `List[str]`, *optional*):
312
+ prompt to be encoded
313
+ device: (`torch.device`):
314
+ torch device
315
+ num_images_per_prompt (`int`):
316
+ number of images that should be generated per prompt
317
+ do_classifier_free_guidance (`bool`):
318
+ whether to use classifier free guidance or not
319
+ negative_prompt (`str` or `List[str]`, *optional*):
320
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
321
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
322
+ prompt_embeds (`torch.FloatTensor`, *optional*):
323
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
324
+ provided, text embeddings will be generated from `prompt` input argument.
325
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
326
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
327
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
328
+ argument.
329
+ """
330
+ if prompt is not None and isinstance(prompt, str):
331
+ batch_size = 1
332
+ elif prompt is not None and isinstance(prompt, list):
333
+ batch_size = len(prompt)
334
+ else:
335
+ batch_size = prompt_embeds.shape[0]
336
+
337
+ if prompt_embeds is None:
338
+ text_inputs = self.tokenizer(
339
+ prompt,
340
+ padding="max_length",
341
+ max_length=self.tokenizer.model_max_length,
342
+ truncation=True,
343
+ return_tensors="pt",
344
+ )
345
+ text_input_ids = text_inputs.input_ids
346
+ untruncated_ids = self.tokenizer(
347
+ prompt, padding="longest", return_tensors="pt"
348
+ ).input_ids
349
+
350
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[
351
+ -1
352
+ ] and not torch.equal(text_input_ids, untruncated_ids):
353
+ removed_text = self.tokenizer.batch_decode(
354
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
355
+ )
356
+ print(
357
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
358
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
359
+ )
360
+
361
+ if (
362
+ hasattr(self.text_encoder.config, "use_attention_mask")
363
+ and self.text_encoder.config.use_attention_mask
364
+ ):
365
+ attention_mask = text_inputs.attention_mask.to(device)
366
+ else:
367
+ attention_mask = None
368
+
369
+ prompt_embeds = self.text_encoder(
370
+ text_input_ids.to(device),
371
+ attention_mask=attention_mask,
372
+ )
373
+ prompt_embeds = prompt_embeds[0]
374
+
375
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
376
+
377
+ bs_embed, seq_len, _ = prompt_embeds.shape
378
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
379
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
380
+ prompt_embeds = prompt_embeds.view(
381
+ bs_embed * num_images_per_prompt, seq_len, -1
382
+ )
383
+
384
+ # get unconditional embeddings for classifier free guidance
385
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
386
+ uncond_tokens: List[str]
387
+ if negative_prompt is None:
388
+ uncond_tokens = [""] * batch_size
389
+ elif type(prompt) is not type(negative_prompt):
390
+ raise TypeError(
391
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
392
+ f" {type(prompt)}."
393
+ )
394
+ elif isinstance(negative_prompt, str):
395
+ uncond_tokens = [negative_prompt]
396
+ elif batch_size != len(negative_prompt):
397
+ raise ValueError(
398
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
399
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
400
+ " the batch size of `prompt`."
401
+ )
402
+ else:
403
+ uncond_tokens = negative_prompt
404
+
405
+ max_length = prompt_embeds.shape[1]
406
+ uncond_input = self.tokenizer(
407
+ uncond_tokens,
408
+ padding="max_length",
409
+ max_length=max_length,
410
+ truncation=True,
411
+ return_tensors="pt",
412
+ )
413
+
414
+ if (
415
+ hasattr(self.text_encoder.config, "use_attention_mask")
416
+ and self.text_encoder.config.use_attention_mask
417
+ ):
418
+ attention_mask = uncond_input.attention_mask.to(device)
419
+ else:
420
+ attention_mask = None
421
+
422
+ negative_prompt_embeds = self.text_encoder(
423
+ uncond_input.input_ids.to(device),
424
+ attention_mask=attention_mask,
425
+ )
426
+ negative_prompt_embeds = negative_prompt_embeds[0]
427
+
428
+ if do_classifier_free_guidance:
429
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
430
+ seq_len = negative_prompt_embeds.shape[1]
431
+
432
+ negative_prompt_embeds = negative_prompt_embeds.to(
433
+ dtype=self.text_encoder.dtype, device=device
434
+ )
435
+
436
+ negative_prompt_embeds = negative_prompt_embeds.repeat(
437
+ 1, num_images_per_prompt, 1
438
+ )
439
+ negative_prompt_embeds = negative_prompt_embeds.view(
440
+ batch_size * num_images_per_prompt, seq_len, -1
441
+ )
442
+
443
+ # For classifier free guidance, we need to do two forward passes.
444
+ # Here we concatenate the unconditional and text embeddings into a single batch
445
+ # to avoid doing two forward passes
446
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
447
+
448
+ return prompt_embeds
449
+
450
+ def run_safety_checker(self, image, device, dtype):
451
+ if self.safety_checker is not None:
452
+ safety_checker_input = self.feature_extractor(
453
+ self.numpy_to_pil(image), return_tensors="pt"
454
+ ).to(device)
455
+ image, has_nsfw_concept = self.safety_checker(
456
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
457
+ )
458
+ else:
459
+ has_nsfw_concept = None
460
+ return image, has_nsfw_concept
461
+
462
+ def decode_latents(self, latents):
463
+ latents = 1 / self.vae.config.scaling_factor * latents
464
+ image = self.vae.decode(latents).sample
465
+ image = (image / 2 + 0.5).clamp(0, 1)
466
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
467
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
468
+ return image
469
+
470
+ def prepare_extra_step_kwargs(self, generator, eta):
471
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
472
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
473
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
474
+ # and should be between [0, 1]
475
+
476
+ accepts_eta = "eta" in set(
477
+ inspect.signature(self.scheduler.step).parameters.keys()
478
+ )
479
+ extra_step_kwargs = {}
480
+ if accepts_eta:
481
+ extra_step_kwargs["eta"] = eta
482
+
483
+ # check if the scheduler accepts generator
484
+ accepts_generator = "generator" in set(
485
+ inspect.signature(self.scheduler.step).parameters.keys()
486
+ )
487
+ if accepts_generator:
488
+ extra_step_kwargs["generator"] = generator
489
+ return extra_step_kwargs
490
+
491
+ def check_controlnet_conditioning_image(self, image, prompt, prompt_embeds):
492
+ image_is_pil = isinstance(image, PIL.Image.Image)
493
+ image_is_tensor = isinstance(image, torch.Tensor)
494
+ image_is_pil_list = isinstance(image, list) and isinstance(
495
+ image[0], PIL.Image.Image
496
+ )
497
+ image_is_tensor_list = isinstance(image, list) and isinstance(
498
+ image[0], torch.Tensor
499
+ )
500
+
501
+ if (
502
+ not image_is_pil
503
+ and not image_is_tensor
504
+ and not image_is_pil_list
505
+ and not image_is_tensor_list
506
+ ):
507
+ raise TypeError(
508
+ "image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors"
509
+ )
510
+
511
+ if image_is_pil:
512
+ image_batch_size = 1
513
+ elif image_is_tensor:
514
+ image_batch_size = image.shape[0]
515
+ elif image_is_pil_list:
516
+ image_batch_size = len(image)
517
+ elif image_is_tensor_list:
518
+ image_batch_size = len(image)
519
+ else:
520
+ raise ValueError("controlnet condition image is not valid")
521
+
522
+ if prompt is not None and isinstance(prompt, str):
523
+ prompt_batch_size = 1
524
+ elif prompt is not None and isinstance(prompt, list):
525
+ prompt_batch_size = len(prompt)
526
+ elif prompt_embeds is not None:
527
+ prompt_batch_size = prompt_embeds.shape[0]
528
+ else:
529
+ raise ValueError("prompt or prompt_embeds are not valid")
530
+
531
+ if image_batch_size != 1 and image_batch_size != prompt_batch_size:
532
+ raise ValueError(
533
+ f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
534
+ )
535
+
536
+ def check_inputs(
537
+ self,
538
+ prompt,
539
+ image,
540
+ controlnet_conditioning_image,
541
+ height,
542
+ width,
543
+ callback_steps,
544
+ negative_prompt=None,
545
+ prompt_embeds=None,
546
+ negative_prompt_embeds=None,
547
+ strength=None,
548
+ controlnet_guidance_start=None,
549
+ controlnet_guidance_end=None,
550
+ controlnet_conditioning_scale=None,
551
+ ):
552
+ if height % 8 != 0 or width % 8 != 0:
553
+ raise ValueError(
554
+ f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
555
+ )
556
+
557
+ if (callback_steps is None) or (
558
+ callback_steps is not None
559
+ and (not isinstance(callback_steps, int) or callback_steps <= 0)
560
+ ):
561
+ raise ValueError(
562
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
563
+ f" {type(callback_steps)}."
564
+ )
565
+
566
+ if prompt is not None and prompt_embeds is not None:
567
+ raise ValueError(
568
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
569
+ " only forward one of the two."
570
+ )
571
+ elif prompt is None and prompt_embeds is None:
572
+ raise ValueError(
573
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
574
+ )
575
+ elif prompt is not None and (
576
+ not isinstance(prompt, str) and not isinstance(prompt, list)
577
+ ):
578
+ raise ValueError(
579
+ f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
580
+ )
581
+
582
+ if negative_prompt is not None and negative_prompt_embeds is not None:
583
+ raise ValueError(
584
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
585
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
586
+ )
587
+
588
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
589
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
590
+ raise ValueError(
591
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
592
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
593
+ f" {negative_prompt_embeds.shape}."
594
+ )
595
+
596
+ # check controlnet condition image
597
+
598
+ if isinstance(self.controlnet, ControlNetModel):
599
+ self.check_controlnet_conditioning_image(
600
+ controlnet_conditioning_image, prompt, prompt_embeds
601
+ )
602
+ elif isinstance(self.controlnet, MultiControlNetModel):
603
+ if not isinstance(controlnet_conditioning_image, list):
604
+ raise TypeError("For multiple controlnets: `image` must be type `list`")
605
+
606
+ if len(controlnet_conditioning_image) != len(self.controlnet.nets):
607
+ raise ValueError(
608
+ "For multiple controlnets: `image` must have the same length as the number of controlnets."
609
+ )
610
+
611
+ for image_ in controlnet_conditioning_image:
612
+ self.check_controlnet_conditioning_image(image_, prompt, prompt_embeds)
613
+ else:
614
+ assert False
615
+
616
+ # Check `controlnet_conditioning_scale`
617
+
618
+ if isinstance(self.controlnet, ControlNetModel):
619
+ if not isinstance(controlnet_conditioning_scale, float):
620
+ raise TypeError(
621
+ "For single controlnet: `controlnet_conditioning_scale` must be type `float`."
622
+ )
623
+ elif isinstance(self.controlnet, MultiControlNetModel):
624
+ if isinstance(controlnet_conditioning_scale, list) and len(
625
+ controlnet_conditioning_scale
626
+ ) != len(self.controlnet.nets):
627
+ raise ValueError(
628
+ "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
629
+ " the same length as the number of controlnets"
630
+ )
631
+ else:
632
+ assert False
633
+
634
+ if isinstance(image, torch.Tensor):
635
+ if image.ndim != 3 and image.ndim != 4:
636
+ raise ValueError("`image` must have 3 or 4 dimensions")
637
+
638
+ if image.ndim == 3:
639
+ image_batch_size = 1
640
+ image_channels, image_height, image_width = image.shape
641
+ elif image.ndim == 4:
642
+ (
643
+ image_batch_size,
644
+ image_channels,
645
+ image_height,
646
+ image_width,
647
+ ) = image.shape
648
+ else:
649
+ assert False
650
+
651
+ if image_channels != 3:
652
+ raise ValueError("`image` must have 3 channels")
653
+
654
+ if image.min() < -1 or image.max() > 1:
655
+ raise ValueError("`image` should be in range [-1, 1]")
656
+
657
+ if self.vae.config.latent_channels != self.unet.config.in_channels:
658
+ raise ValueError(
659
+ f"The config of `pipeline.unet` expects {self.unet.config.in_channels} but received"
660
+ f" latent channels: {self.vae.config.latent_channels},"
661
+ f" Please verify the config of `pipeline.unet` and the `pipeline.vae`"
662
+ )
663
+
664
+ if strength < 0 or strength > 1:
665
+ raise ValueError(
666
+ f"The value of `strength` should in [0.0, 1.0] but is {strength}"
667
+ )
668
+
669
+ if controlnet_guidance_start < 0 or controlnet_guidance_start > 1:
670
+ raise ValueError(
671
+ f"The value of `controlnet_guidance_start` should in [0.0, 1.0] but is {controlnet_guidance_start}"
672
+ )
673
+
674
+ if controlnet_guidance_end < 0 or controlnet_guidance_end > 1:
675
+ raise ValueError(
676
+ f"The value of `controlnet_guidance_end` should in [0.0, 1.0] but is {controlnet_guidance_end}"
677
+ )
678
+
679
+ if controlnet_guidance_start > controlnet_guidance_end:
680
+ raise ValueError(
681
+ "The value of `controlnet_guidance_start` should be less than `controlnet_guidance_end`, but got"
682
+ f" `controlnet_guidance_start` {controlnet_guidance_start} >= `controlnet_guidance_end` {controlnet_guidance_end}"
683
+ )
684
+
685
+ def get_timesteps(self, num_inference_steps, strength, device):
686
+ # get the original timestep using init_timestep
687
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
688
+
689
+ t_start = max(num_inference_steps - init_timestep, 0)
690
+ timesteps = self.scheduler.timesteps[t_start:]
691
+
692
+ return timesteps, num_inference_steps - t_start
693
+
694
+ def prepare_latents(
695
+ self,
696
+ image,
697
+ timestep,
698
+ batch_size,
699
+ num_images_per_prompt,
700
+ dtype,
701
+ device,
702
+ generator=None,
703
+ ):
704
+ if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
705
+ raise ValueError(
706
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
707
+ )
708
+
709
+ image = image.to(device=device, dtype=dtype)
710
+
711
+ batch_size = batch_size * num_images_per_prompt
712
+ if isinstance(generator, list) and len(generator) != batch_size:
713
+ raise ValueError(
714
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
715
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
716
+ )
717
+
718
+ if isinstance(generator, list):
719
+ init_latents = [
720
+ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i])
721
+ for i in range(batch_size)
722
+ ]
723
+ init_latents = torch.cat(init_latents, dim=0)
724
+ else:
725
+ init_latents = self.vae.encode(image).latent_dist.sample(generator)
726
+
727
+ init_latents = self.vae.config.scaling_factor * init_latents
728
+
729
+ if (
730
+ batch_size > init_latents.shape[0]
731
+ and batch_size % init_latents.shape[0] == 0
732
+ ):
733
+ raise ValueError(
734
+ f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
735
+ )
736
+ else:
737
+ init_latents = torch.cat([init_latents], dim=0)
738
+
739
+ shape = init_latents.shape
740
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
741
+
742
+ # get latents
743
+ init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
744
+ latents = init_latents
745
+
746
+ return latents
747
+
748
+ def _default_height_width(self, height, width, image):
749
+ if isinstance(image, list):
750
+ image = image[0]
751
+
752
+ if height is None:
753
+ if isinstance(image, PIL.Image.Image):
754
+ height = image.height
755
+ elif isinstance(image, torch.Tensor):
756
+ height = image.shape[3]
757
+
758
+ height = (height // 8) * 8 # round down to nearest multiple of 8
759
+
760
+ if width is None:
761
+ if isinstance(image, PIL.Image.Image):
762
+ width = image.width
763
+ elif isinstance(image, torch.Tensor):
764
+ width = image.shape[2]
765
+
766
+ width = (width // 8) * 8 # round down to nearest multiple of 8
767
+
768
+ return height, width
769
+
770
+ @torch.no_grad()
771
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
772
+ def __call__(
773
+ self,
774
+ prompt: Union[str, List[str]] = None,
775
+ image: Union[torch.Tensor, PIL.Image.Image] = None,
776
+ controlnet_conditioning_image: Union[
777
+ torch.FloatTensor,
778
+ PIL.Image.Image,
779
+ List[torch.FloatTensor],
780
+ List[PIL.Image.Image],
781
+ ] = None,
782
+ strength: float = 0.8,
783
+ height: Optional[int] = None,
784
+ width: Optional[int] = None,
785
+ num_inference_steps: int = 50,
786
+ guidance_scale: float = 7.5,
787
+ negative_prompt: Optional[Union[str, List[str]]] = None,
788
+ num_images_per_prompt: Optional[int] = 1,
789
+ eta: float = 0.0,
790
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
791
+ latents: Optional[torch.FloatTensor] = None,
792
+ prompt_embeds: Optional[torch.FloatTensor] = None,
793
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
794
+ output_type: Optional[str] = "pil",
795
+ return_dict: bool = True,
796
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
797
+ callback_steps: int = 1,
798
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
799
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
800
+ controlnet_guidance_start: float = 0.0,
801
+ controlnet_guidance_end: float = 1.0,
802
+ ):
803
+ r"""
804
+ Function invoked when calling the pipeline for generation.
805
+
806
+ Args:
807
+ prompt (`str` or `List[str]`, *optional*):
808
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
809
+ instead.
810
+ image (`torch.Tensor` or `PIL.Image.Image`):
811
+ `Image`, or tensor representing an image batch, to be used as the starting point for the
812
+ image-to-image generation; it is noised according to `strength` and then denoised following `prompt`.
813
+ controlnet_conditioning_image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]` or `List[PIL.Image.Image]`):
814
+ The ControlNet input condition. ControlNet uses this input condition to generate guidance for the UNet. If
815
+ the type is specified as `torch.FloatTensor`, it is passed to ControlNet as is. A `PIL.Image.Image` can
816
+ also be accepted as an image. The control image is automatically resized to fit the output image.
817
+ strength (`float`, *optional*):
818
+ Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
819
+ will be used as a starting point, adding more noise to it the larger the `strength`. The number of
820
+ denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will
821
+ be maximum and the denoising process will run for the full number of iterations specified in
822
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
823
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
824
+ The height in pixels of the generated image.
825
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
826
+ The width in pixels of the generated image.
827
+ num_inference_steps (`int`, *optional*, defaults to 50):
828
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
829
+ expense of slower inference.
830
+ guidance_scale (`float`, *optional*, defaults to 7.5):
831
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
832
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
833
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
834
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
835
+ usually at the expense of lower image quality.
836
+ negative_prompt (`str` or `List[str]`, *optional*):
837
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
838
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
839
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
840
+ The number of images to generate per prompt.
841
+ eta (`float`, *optional*, defaults to 0.0):
842
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
843
+ [`schedulers.DDIMScheduler`], will be ignored for others.
844
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
845
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
846
+ to make generation deterministic.
847
+ latents (`torch.FloatTensor`, *optional*):
848
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
849
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
850
+ tensor will be generated by sampling using the supplied random `generator`.
851
+ prompt_embeds (`torch.FloatTensor`, *optional*):
852
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
853
+ provided, text embeddings will be generated from `prompt` input argument.
854
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
855
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
856
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
857
+ argument.
858
+ output_type (`str`, *optional*, defaults to `"pil"`):
859
+ The output format of the generated image. Choose between
860
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
861
+ return_dict (`bool`, *optional*, defaults to `True`):
862
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
863
+ plain tuple.
864
+ callback (`Callable`, *optional*):
865
+ A function that will be called every `callback_steps` steps during inference. The function will be
866
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
867
+ callback_steps (`int`, *optional*, defaults to 1):
868
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
869
+ called at every step.
870
+ cross_attention_kwargs (`dict`, *optional*):
871
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
872
+ `self.processor` in
873
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
874
+ controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
875
+ The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
876
+ to the residual in the original unet.
877
+ controlnet_guidance_start (`float`, *optional*, defaults to 0.0):
878
+ The percentage of total steps at which the ControlNet starts being applied. Must be between 0 and 1.
879
+ controlnet_guidance_end (`float`, *optional*, defaults to 1.0):
880
+ The percentage of total steps at which the ControlNet stops being applied. Must be between 0 and 1 and
881
+ greater than `controlnet_guidance_start`.
882
+
883
+ Examples:
884
+
885
+ Returns:
886
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
887
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
888
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
889
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
890
+ (nsfw) content, according to the `safety_checker`.
891
+ """
892
+ # 0. Default height and width to unet
893
+ height, width = self._default_height_width(
894
+ height, width, controlnet_conditioning_image
895
+ )
896
+
897
+ # 1. Check inputs. Raise error if not correct
898
+ self.check_inputs(
899
+ prompt,
900
+ image,
901
+ controlnet_conditioning_image,
902
+ height,
903
+ width,
904
+ callback_steps,
905
+ negative_prompt,
906
+ prompt_embeds,
907
+ negative_prompt_embeds,
908
+ strength,
909
+ controlnet_guidance_start,
910
+ controlnet_guidance_end,
911
+ controlnet_conditioning_scale,
912
+ )
913
+
914
+ # 2. Define call parameters
915
+ if prompt is not None and isinstance(prompt, str):
916
+ batch_size = 1
917
+ elif prompt is not None and isinstance(prompt, list):
918
+ batch_size = len(prompt)
919
+ else:
920
+ batch_size = prompt_embeds.shape[0]
921
+
922
+ device = self._execution_device
923
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
924
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
925
+ # corresponds to doing no classifier free guidance.
926
+ do_classifier_free_guidance = guidance_scale > 1.0
927
+
928
+ if isinstance(self.controlnet, MultiControlNetModel) and isinstance(
929
+ controlnet_conditioning_scale, float
930
+ ):
931
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(
932
+ self.controlnet.nets
933
+ )
934
+
935
+ # 3. Encode input prompt
936
+ prompt_embeds = self._encode_prompt(
937
+ prompt,
938
+ device,
939
+ num_images_per_prompt,
940
+ do_classifier_free_guidance,
941
+ negative_prompt,
942
+ prompt_embeds=prompt_embeds,
943
+ negative_prompt_embeds=negative_prompt_embeds,
944
+ )
945
+
946
+ # 4. Prepare image, and controlnet_conditioning_image
947
+ image = prepare_image(image)
948
+
949
+ # condition image(s)
950
+ if isinstance(self.controlnet, ControlNetModel):
951
+ controlnet_conditioning_image = prepare_controlnet_conditioning_image(
952
+ controlnet_conditioning_image=controlnet_conditioning_image,
953
+ width=width,
954
+ height=height,
955
+ batch_size=batch_size * num_images_per_prompt,
956
+ num_images_per_prompt=num_images_per_prompt,
957
+ device=device,
958
+ dtype=self.controlnet.dtype,
959
+ do_classifier_free_guidance=do_classifier_free_guidance,
960
+ )
961
+ elif isinstance(self.controlnet, MultiControlNetModel):
962
+ controlnet_conditioning_images = []
963
+
964
+ for image_ in controlnet_conditioning_image:
965
+ image_ = prepare_controlnet_conditioning_image(
966
+ controlnet_conditioning_image=image_,
967
+ width=width,
968
+ height=height,
969
+ batch_size=batch_size * num_images_per_prompt,
970
+ num_images_per_prompt=num_images_per_prompt,
971
+ device=device,
972
+ dtype=self.controlnet.dtype,
973
+ do_classifier_free_guidance=do_classifier_free_guidance,
974
+ )
975
+
976
+ controlnet_conditioning_images.append(image_)
977
+
978
+ controlnet_conditioning_image = controlnet_conditioning_images
979
+ else:
980
+ assert False
981
+
982
+ # 5. Prepare timesteps
983
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
984
+ timesteps, num_inference_steps = self.get_timesteps(
985
+ num_inference_steps, strength, device
986
+ )
987
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
988
+
989
+ # 6. Prepare latent variables
990
+ latents = self.prepare_latents(
991
+ image,
992
+ latent_timestep,
993
+ batch_size,
994
+ num_images_per_prompt,
995
+ prompt_embeds.dtype,
996
+ device,
997
+ generator,
998
+ )
999
+
1000
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1001
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1002
+
1003
+ # 8. Denoising loop
1004
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
1005
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1006
+ for i, t in enumerate(timesteps):
1007
+ # expand the latents if we are doing classifier free guidance
1008
+ latent_model_input = (
1009
+ torch.cat([latents] * 2) if do_classifier_free_guidance else latents
1010
+ )
1011
+
1012
+ latent_model_input = self.scheduler.scale_model_input(
1013
+ latent_model_input, t
1014
+ )
1015
+
1016
+ # compute the percentage of total steps we are at
1017
+ current_sampling_percent = i / len(timesteps)
1018
+
1019
+ if (
1020
+ current_sampling_percent < controlnet_guidance_start
1021
+ or current_sampling_percent > controlnet_guidance_end
1022
+ ):
1023
+ # do not apply the controlnet
1024
+ down_block_res_samples = None
1025
+ mid_block_res_sample = None
1026
+ else:
1027
+ # apply the controlnet
1028
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
1029
+ latent_model_input,
1030
+ t,
1031
+ encoder_hidden_states=prompt_embeds,
1032
+ controlnet_cond=controlnet_conditioning_image,
1033
+ conditioning_scale=controlnet_conditioning_scale,
1034
+ return_dict=False,
1035
+ )
1036
+
1037
+ # predict the noise residual
1038
+ noise_pred = self.unet(
1039
+ latent_model_input,
1040
+ t,
1041
+ encoder_hidden_states=prompt_embeds,
1042
+ cross_attention_kwargs=cross_attention_kwargs,
1043
+ down_block_additional_residuals=down_block_res_samples,
1044
+ mid_block_additional_residual=mid_block_res_sample,
1045
+ ).sample
1046
+
1047
+ # perform guidance
1048
+ if do_classifier_free_guidance:
1049
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1050
+ noise_pred = noise_pred_uncond + guidance_scale * (
1051
+ noise_pred_text - noise_pred_uncond
1052
+ )
1053
+
1054
+ # compute the previous noisy sample x_t -> x_t-1
1055
+ latents = self.scheduler.step(
1056
+ noise_pred, t, latents, **extra_step_kwargs
1057
+ ).prev_sample
1058
+
1059
+ # call the callback, if provided
1060
+ if i == len(timesteps) - 1 or (
1061
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
1062
+ ):
1063
+ progress_bar.update()
1064
+ if callback is not None and i % callback_steps == 0:
1065
+ callback(i, t, latents)
1066
+
1067
+ # If we do sequential model offloading, let's offload unet and controlnet
1068
+ # manually for max memory savings
1069
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
1070
+ self.unet.to("cpu")
1071
+ self.controlnet.to("cpu")
1072
+ torch.cuda.empty_cache()
1073
+
1074
+ if output_type == "latent":
1075
+ image = latents
1076
+ has_nsfw_concept = None
1077
+ elif output_type == "pil":
1078
+ # 8. Post-processing
1079
+ image = self.decode_latents(latents)
1080
+
1081
+ # 9. Run safety checker
1082
+ image, has_nsfw_concept = self.run_safety_checker(
1083
+ image, device, prompt_embeds.dtype
1084
+ )
1085
+
1086
+ # 10. Convert to PIL
1087
+ image = self.numpy_to_pil(image)
1088
+ else:
1089
+ # 8. Post-processing
1090
+ image = self.decode_latents(latents)
1091
+
1092
+ # 9. Run safety checker
1093
+ image, has_nsfw_concept = self.run_safety_checker(
1094
+ image, device, prompt_embeds.dtype
1095
+ )
1096
+
1097
+ # Offload last model to CPU
1098
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
1099
+ self.final_offload_hook.offload()
1100
+
1101
+ if not return_dict:
1102
+ return (image, has_nsfw_concept)
1103
+
1104
+ return StableDiffusionPipelineOutput(
1105
+ images=image, nsfw_content_detected=has_nsfw_concept
1106
+ )
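The `__call__` above exposes the img2img ControlNet parameters this commit relies on: `strength`, `controlnet_conditioning_image`, and the `controlnet_guidance_start`/`controlnet_guidance_end` window. The following is a minimal usage sketch, assuming `pipe` is an already-loaded instance of this pipeline class; the image paths, prompt, and seed are placeholders.

import torch
from PIL import Image

init_image = Image.open("input.png").convert("RGB")      # starting image for img2img (placeholder path)
condition = Image.open("canny_map.png").convert("RGB")   # ControlNet conditioning image (placeholder path)

result = pipe(
    prompt="a portrait photo, highly detailed",
    image=init_image,
    controlnet_conditioning_image=condition,
    strength=0.75,                    # with 30 steps, roughly the last 22 denoising steps run on the noised init image
    num_inference_steps=30,
    guidance_scale=7.5,
    controlnet_conditioning_scale=1.0,
    controlnet_guidance_start=0.0,    # ControlNet active from the first step...
    controlnet_guidance_end=0.8,      # ...and released for the final 20% of steps
    generator=torch.Generator("cuda").manual_seed(42),
)
result.images[0].save("output.png")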
internals/pipelines/twoStepPipeline.py CHANGED
@@ -250,3 +250,292 @@ class two_step_pipeline(StableDiffusionPipeline):
250
  return StableDiffusionPipelineOutput(
251
  images=image, nsfw_content_detected=has_nsfw_concept
252
  )
253
+
254
+ @torch.no_grad()
255
+ def multi_character_diffusion(
256
+ self,
257
+ prompt: Union[str, List[str]],
258
+ pos: List[str],
259
+ mix_val: Union[float, List[float]] = 0.5,
260
+ height: Optional[int] = None,
261
+ width: Optional[int] = None,
262
+ num_inference_steps: int = 50,
263
+ guidance_scale: float = 7.5,
264
+ negative_prompt: Optional[Union[str, List[str]]] = None,
265
+ num_images_per_prompt: Optional[int] = 1,
266
+ eta: float = 0.0,
267
+ generator: Optional[torch.Generator] = None,
268
+ latents: Optional[torch.FloatTensor] = None,
269
+ output_type: Optional[str] = "pil",
270
+ return_dict: bool = True,
271
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
272
+ callback_steps: Optional[int] = 1,
273
+ ):
274
+ r"""
275
+ Function invoked when calling the pipeline for generation.
276
+
277
+ Args:
278
+ prompt (`str` or `List[str]`):
279
+ The prompt or prompts to guide the image generation.
280
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
281
+ The height in pixels of the generated image.
282
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
283
+ The width in pixels of the generated image.
284
+ num_inference_steps (`int`, *optional*, defaults to 50):
285
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
286
+ expense of slower inference.
287
+ guidance_scale (`float`, *optional*, defaults to 7.5):
288
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
289
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
290
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
291
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
292
+ usually at the expense of lower image quality.
293
+ negative_prompt (`str` or `List[str]`, *optional*):
294
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
295
+ if `guidance_scale` is less than `1`).
296
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
297
+ The number of images to generate per prompt.
298
+ eta (`float`, *optional*, defaults to 0.0):
299
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
300
+ [`schedulers.DDIMScheduler`], will be ignored for others.
301
+ generator (`torch.Generator`, *optional*):
302
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
303
+ deterministic.
304
+ latents (`torch.FloatTensor`, *optional*):
305
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
306
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
307
+ tensor will be generated by sampling using the supplied random `generator`.
308
+ output_type (`str`, *optional*, defaults to `"pil"`):
309
+ The output format of the generated image. Choose between
310
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
311
+ return_dict (`bool`, *optional*, defaults to `True`):
312
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
313
+ plain tuple.
314
+ callback (`Callable`, *optional*):
315
+ A function that will be called every `callback_steps` steps during inference. The function will be
316
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
317
+ callback_steps (`int`, *optional*, defaults to 1):
318
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
319
+ called at every step.
320
+
321
+ Returns:
322
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
323
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
324
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
325
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
326
+ (nsfw) content, according to the `safety_checker`.
327
+ """
328
+ # The generated image size must be divisible by 8
329
+ height = height - height % 8
330
+ width = width - width % 8
331
+
332
+ # 0. Default height and width to unet
333
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
334
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
335
+
336
+ # 1. Check inputs. Raise error if not correct
337
+ self.check_inputs(prompt[0], height, width, callback_steps)
338
+
339
+ # 2. Define call parameters
340
+ batch_size = 1 if isinstance(prompt[0], str) else len(prompt[0])
341
+ device = self._execution_device
342
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
343
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
344
+ # corresponds to doing no classifier free guidance.
345
+ do_classifier_free_guidance = guidance_scale > 1.0
346
+
347
+ # 3. Encode input prompt
348
+ text_embeddings = []
349
+ for i in range(len(prompt)):
350
+ one_text_embeddings = self._encode_prompt(
351
+ prompt[i],
352
+ device,
353
+ num_images_per_prompt,
354
+ do_classifier_free_guidance,
355
+ negative_prompt[i],
356
+ )
357
+ text_embeddings.append(one_text_embeddings)
358
+
359
+ # 4. Prepare timesteps
360
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
361
+ timesteps = self.scheduler.timesteps
362
+
363
+ # 5. Prepare latent variables
364
+ num_channels_latents = self.unet.in_channels
365
+ latents = self.prepare_latents(
366
+ batch_size * num_images_per_prompt,
367
+ num_channels_latents,
368
+ height,
369
+ width,
370
+ text_embeddings[0].dtype,
371
+ device,
372
+ generator,
373
+ latents,
374
+ )
375
+
376
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
377
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
378
+
379
+ # 7. Denoising loop
380
+ # num_warmup_steps = len(timesteps) - num_inference_steps# * self.scheduler.order
381
+ for i, t in enumerate(self.progress_bar(timesteps)):
382
+ # expand the latents if we are doing classifier free guidance
383
+ latent_model_input = (
384
+ torch.cat([latents] * 2) if do_classifier_free_guidance else latents
385
+ )
386
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
387
+
388
+ # predict the noise residual
389
+ noise_preds = []
390
+ for i in range(len(prompt)):
391
+ noise_pred = self.unet(
392
+ latent_model_input, t, encoder_hidden_states=text_embeddings[i]
393
+ ).sample
394
+ noise_preds.append(noise_pred)
395
+ # perform guidance
396
+ if do_classifier_free_guidance:
397
+ noise_pred_unconds = []
398
+ noise_pred_texts = []
399
+ for i in range(len(prompt)):
400
+ noise_pred_uncond, noise_pred_text = noise_preds[i].chunk(2)
401
+ noise_pred_unconds.append(noise_pred_uncond)
402
+ noise_pred_texts.append(noise_pred_text)
403
+ # TODO: build the region mask filters based on pos
404
+ mask_list = []
405
+ for i in range(len(prompt)):
406
+ pos_base = pos[i].split("-")
407
+ pos_dev = pos_base[0].split(":") # 1:2
408
+ pos_pos = pos_base[1].split(":") # 0:0
409
+ one_filter = None
410
+ zero_f = False
411
+ for y in range(int(pos_dev[0])):
412
+ one_line = None
413
+ zero = False
414
+ for x in range(int(pos_dev[1])):
415
+ if y == int(pos_pos[0]) and x == int(pos_pos[1]):
416
+ # print("same", zero, (height//8) // int(pos_dev[0]), (width//8) // int(pos_dev[1]))
417
+ if zero:
418
+ one_block = (
419
+ torch.ones(
420
+ batch_size,
421
+ 4,
422
+ (height // 8) // int(pos_dev[0]),
423
+ (width // 8) // int(pos_dev[1]),
424
+ )
425
+ .to(device)
426
+ .to(torch.float16)
427
+ * mix_val[i]
428
+ )
429
+ one_line = torch.cat((one_line, one_block), 3)
430
+ else:
431
+ zero = True
432
+ one_block = (
433
+ torch.ones(
434
+ batch_size,
435
+ 4,
436
+ (height // 8) // int(pos_dev[0]),
437
+ (width // 8) // int(pos_dev[1]),
438
+ )
439
+ .to(device)
440
+ .to(torch.float16)
441
+ * mix_val[i]
442
+ )
443
+ one_line = one_block
444
+ else:
445
+ # print("else", zero, (height//8) // int(pos_dev[0]), (width//8) // int(pos_dev[1]))
446
+ if zero:
447
+ one_block = (
448
+ torch.zeros(
449
+ batch_size,
450
+ 4,
451
+ (height // 8) // int(pos_dev[0]),
452
+ (width // 8) // int(pos_dev[1]),
453
+ )
454
+ .to(device)
455
+ .to(torch.float16)
456
+ )
457
+ one_line = torch.cat((one_line, one_block), 3)
458
+ else:
459
+ zero = True
460
+ one_block = (
461
+ torch.zeros(
462
+ batch_size,
463
+ 4,
464
+ (height // 8) // int(pos_dev[0]),
465
+ (width // 8) // int(pos_dev[1]),
466
+ )
467
+ .to(device)
468
+ .to(torch.float16)
469
+ )
470
+ one_line = one_block
471
+ one_block = (
472
+ torch.zeros(
473
+ batch_size,
474
+ 4,
475
+ (height // 8) // int(pos_dev[0]),
476
+ (width // 8) - one_line.size()[3],
477
+ )
478
+ .to(device)
479
+ .to(torch.float16)
480
+ )
481
+ one_line = torch.cat((one_line, one_block), 3)
482
+ if zero_f:
483
+ one_filter = torch.cat((one_filter, one_line), 2)
484
+ else:
485
+ zero_f = True
486
+ one_filter = one_line
487
+ mask_list.append(one_filter)
488
+ for i in range(len(mask_list)):
489
+ import torchvision
490
+
491
+ torchvision.transforms.functional.to_pil_image(
492
+ mask_list[i][0] * 256
493
+ ).save(str(i) + ".png")
494
+
495
+ result = None
496
+ noise_preds = []
497
+ for i in range(len(prompt)):
498
+ noise_pred = noise_pred_unconds[i] + guidance_scale * (
499
+ noise_pred_texts[i] - noise_pred_unconds[i]
500
+ )
501
+ noise_preds.append(noise_pred)
502
+ result = noise_preds[0] * mask_list[0]
503
+ for i in range(1, len(prompt)):
504
+ result += noise_preds[i] * mask_list[i]
505
+
506
+ # noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
507
+
508
+ # compute the previous noisy sample x_t -> x_t-1
509
+ latents = self.scheduler.step(
510
+ result, t, latents, **extra_step_kwargs
511
+ ).prev_sample
512
+
513
+ # call the callback, if provided
514
+ if callback is not None and i % callback_steps == 0:
515
+ callback(i, t, latents)
516
+
517
+ # 8. Post-processing
518
+ image = self.decode_latents(latents)
519
+
520
+ # 9. Run safety checker
521
+ image, has_nsfw_concept = self.run_safety_checker(
522
+ image, device, text_embeddings[0].dtype
523
+ )
524
+
525
+ # 10. Convert to PIL
526
+ if output_type == "pil":
527
+ image = self.numpy_to_pil(image)
528
+ output = []
529
+ import torchvision
530
+
531
+ for i in mask_list:
532
+ output.append(
533
+ torchvision.transforms.functional.to_pil_image(i[0] * 256)
534
+ )
535
+
536
+ if not return_dict:
537
+ return (image, has_nsfw_concept)
538
+
539
+ return StableDiffusionPipelineOutput(
540
+ images=image, nsfw_content_detected=has_nsfw_concept
541
+ )
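For orientation, here is a hypothetical call into the `multi_character_diffusion` method added above, assuming `pipe` is a loaded `two_step_pipeline`; prompts and the output file name are placeholders. Each `pos` entry follows the `"<rows>:<cols>-<row>:<col>"` convention parsed in the mask-building loop: the latent grid is split into rows x cols cells and the matching prompt is weighted by its `mix_val` inside the chosen cell.

# Two prompts rendered side by side on a 1x2 grid (left and right halves).
output = pipe.multi_character_diffusion(
    prompt=["a knight in silver armor", "a red-haired sorceress casting a spell"],
    negative_prompt=["blurry, low quality", "blurry, low quality"],
    pos=["1:2-0:0", "1:2-0:1"],   # same grid, different cells
    mix_val=[0.7, 0.7],           # per-region weight applied to the noise prediction
    width=768,
    height=512,
    num_inference_steps=30,
    guidance_scale=7.5,
)
output.images[0].save("two_characters.png")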
internals/util/config.py CHANGED
@@ -8,6 +8,8 @@ nsfw_access = False
8
  access_token = ""
9
  root_dir = ""
10
 
11
+ num_return_sequences = 4 # the number of results to generate
12
+
13
 
14
  def set_root_dir(main_file: str):
15
  global root_dir
internals/util/lora_style.py CHANGED
@@ -5,6 +5,7 @@ from typing import Any, Dict, List, Union
5
 
6
  import boto3
7
  import torch
8
+ from diffusers.models.attention_processor import AttnProcessor2_0
9
  from lora_diffusion import patch_pipe, tune_lora_scale
10
  from pydash import chain
11
 
@@ -32,7 +33,24 @@ class LoraStyle:
33
  def cleanup(self):
34
  tune_lora_scale(self.pipe.unet, 0.0)
35
  tune_lora_scale(self.pipe.text_encoder, 0.0)
- pass
36
+
37
+ class LoraDiffuserPatcher:
38
+ def __init__(self, pipe, style: Dict[str, Any]):
39
+ self.__style = style
40
+ self.pipe = pipe
41
+
42
+ @torch.inference_mode()
43
+ def patch(self):
44
+ path = self.__style["path"]
45
+ self.pipe.load_lora_weights(
46
+ os.path.dirname(path), weight_name=os.path.basename(path)
47
+ )
48
+
49
+ def kwargs(self):
50
+ return {}
51
+
52
+ def cleanup(self):
53
+ LoraStyle.unload_lora_weights(self.pipe)
54
 
55
  class EmptyLoraPatcher:
56
  def __init__(self, pipe):
@@ -41,7 +59,6 @@ class LoraStyle:
59
  def patch(self):
60
  "Patch will act as cleanup, to tune down any corrupted lora"
61
  self.cleanup()
- pass
62
 
63
  def kwargs(self):
64
  return {}
@@ -49,7 +66,7 @@ class LoraStyle:
66
  def cleanup(self):
67
  tune_lora_scale(self.pipe.unet, 0.0)
68
  tune_lora_scale(self.pipe.text_encoder, 0.0)
- pass
69
+ LoraStyle.unload_lora_weights(self.pipe)
70
 
71
  def load(self, model_dir: str):
72
  self.model = model_dir
@@ -70,9 +87,13 @@ class LoraStyle:
87
  return f"{', '.join(style['text'])}, {prompt}"
88
  return prompt
89
 
- def get_patcher(self, pipe, key: str) -> Union[LoraPatcher, EmptyLoraPatcher]:
90
+ def get_patcher(
91
+ self, pipe, key: str
92
+ ) -> Union[LoraPatcher, LoraDiffuserPatcher, EmptyLoraPatcher]:
93
  if key in self.__styles:
94
  style = self.__styles[key]
95
+ if style["type"] == "diffuser":
96
+ return self.LoraDiffuserPatcher(pipe, style)
97
  return self.LoraPatcher(pipe, style)
98
  return self.EmptyLoraPatcher(pipe)
99
 
@@ -152,3 +173,8 @@ class LoraStyle:
173
  + " not found at path: "
174
  + self.__styles[item]["path"]
175
  )
176
+
177
+ @staticmethod
178
+ def unload_lora_weights(pipe):
179
+ pipe.unet.set_attn_processor(AttnProcessor2_0()) # for pytorch 2.0
180
+ pipe._remove_text_encoder_monkey_patch()
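The new `LoraDiffuserPatcher` branch is selected purely by the style entry's `type` field. A hedged sketch of the expected round trip follows, assuming `lora_style` is an initialized `LoraStyle` instance, `pipe` a diffusers Stable Diffusion pipeline, and `"my_style"` a placeholder key whose entry carries `{"type": "diffuser", "path": ...}`.

# Pick the patcher for the requested style; "diffuser" entries load LoRA weights from the
# style's path via pipe.load_lora_weights(), other entries keep the lora_diffusion patching path.
patcher = lora_style.get_patcher(pipe, "my_style")
patcher.patch()
images = pipe(prompt="a castle above the clouds", **patcher.kwargs()).images
patcher.cleanup()   # resets the attention processors and removes the text-encoder monkey patch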
requirements.txt CHANGED
@@ -8,7 +8,6 @@ redis==4.3.4
8
  requests==2.28.1
9
  transformers
10
  rembg==2.0.30
- accelerate==0.17.0
11
  gfpgan==1.3.8
12
  rembg==2.0.30
13
  controlnet-aux==0.0.5
@@ -20,6 +19,7 @@ albumentations==0.5.2
19
  kornia==0.5.0
20
  pytorch-lightning==1.2.9
21
  pydash
22
+ accelerate
23
  pandas
24
  xformers
25
  torchvision