jayparmr committed on
Commit
049a85c
1 Parent(s): 92207d3

Upload folder using huggingface_hub

Browse files
inference.py CHANGED
@@ -20,13 +20,9 @@ from internals.util.args import apply_style_args
20
  from internals.util.avatar import Avatar
21
  from internals.util.cache import auto_clear_cuda_and_gc, clear_cuda
22
  from internals.util.commons import download_image, upload_image, upload_images
23
- from internals.util.config import (
24
- get_model_dir,
25
- num_return_sequences,
26
- set_configs_from_task,
27
- set_model_dir,
28
- set_root_dir,
29
- )
30
  from internals.util.failure_hander import FailureHandler
31
  from internals.util.lora_style import LoraStyle
32
  from internals.util.slack import Slack
@@ -468,7 +464,7 @@ def replace_bg(task: Task):
468
  width=task.get_width(),
469
  height=task.get_height(),
470
  steps=task.get_steps(),
471
- resize_dimension=task.get_resize_dimension(),
472
  product_scale_width=task.get_image_scale(),
473
  apply_high_res=task.get_high_res_fix(),
474
  conditioning_scale=task.rbg_controlnet_conditioning_scale(),
@@ -477,6 +473,7 @@ def replace_bg(task: Task):
477
  generated_image_urls = upload_images(images, "_replace_bg", task.get_taskId())
478
 
479
  lora_patcher.cleanup()
 
480
 
481
  return {
482
  "modified_prompts": prompt,
@@ -486,6 +483,8 @@ def replace_bg(task: Task):
486
 
487
 
488
  def load_model_by_task(task: Task):
 
 
489
  if (
490
  task.get_type()
491
  in [
@@ -516,8 +515,6 @@ def load_model_by_task(task: Task):
516
  elif task.get_type() == TaskType.POSE:
517
  controlnet.load_pose()
518
 
519
- high_res.load()
520
-
521
  safety_checker.apply(controlnet)
522
 
523
 
 
20
  from internals.util.avatar import Avatar
21
  from internals.util.cache import auto_clear_cuda_and_gc, clear_cuda
22
  from internals.util.commons import download_image, upload_image, upload_images
23
+ from internals.util.config import (get_model_dir, num_return_sequences,
24
+ set_configs_from_task, set_model_dir,
25
+ set_root_dir)
 
 
 
 
26
  from internals.util.failure_hander import FailureHandler
27
  from internals.util.lora_style import LoraStyle
28
  from internals.util.slack import Slack
 
464
  width=task.get_width(),
465
  height=task.get_height(),
466
  steps=task.get_steps(),
467
+ extend_object=task.rbg_extend_object(),
468
  product_scale_width=task.get_image_scale(),
469
  apply_high_res=task.get_high_res_fix(),
470
  conditioning_scale=task.rbg_controlnet_conditioning_scale(),
 
473
  generated_image_urls = upload_images(images, "_replace_bg", task.get_taskId())
474
 
475
  lora_patcher.cleanup()
476
+ clear_cuda()
477
 
478
  return {
479
  "modified_prompts": prompt,
 
483
 
484
 
485
  def load_model_by_task(task: Task):
486
+ high_res.load()
487
+
488
  if (
489
  task.get_type()
490
  in [
 
515
  elif task.get_type() == TaskType.POSE:
516
  controlnet.load_pose()
517
 
 
 
518
  safety_checker.apply(controlnet)
519
 
520
 
inference2.py CHANGED
@@ -13,19 +13,17 @@ from internals.pipelines.img_to_text import Image2Text
13
  from internals.pipelines.inpainter import InPainter
14
  from internals.pipelines.object_remove import ObjectRemoval
15
  from internals.pipelines.prompt_modifier import PromptModifier
16
- from internals.pipelines.remove_background import RemoveBackground, RemoveBackgroundV2
 
17
  from internals.pipelines.replace_background import ReplaceBackground
18
  from internals.pipelines.safety_checker import SafetyChecker
19
  from internals.pipelines.upscaler import Upscaler
20
  from internals.util.avatar import Avatar
21
  from internals.util.cache import auto_clear_cuda_and_gc, clear_cuda
22
- from internals.util.commons import construct_default_s3_url, upload_image, upload_images
23
- from internals.util.config import (
24
- num_return_sequences,
25
- set_configs_from_task,
26
- set_model_dir,
27
- set_root_dir,
28
- )
29
  from internals.util.failure_hander import FailureHandler
30
  from internals.util.lora_style import LoraStyle
31
  from internals.util.slack import Slack
@@ -171,7 +169,7 @@ def replace_bg(task: Task):
171
  width=task.get_width(),
172
  height=task.get_height(),
173
  steps=task.get_steps(),
174
- resize_dimension=task.get_resize_dimension(),
175
  product_scale_width=task.get_image_scale(),
176
  conditioning_scale=task.rbg_controlnet_conditioning_scale(),
177
  )
@@ -232,7 +230,9 @@ def model_fn(model_dir):
232
  inpainter.load()
233
  high_res.load()
234
 
235
- replace_background.load(upscaler, remove_background_v2)
 
 
236
 
237
  print("Logs: model loaded ....")
238
  return
 
13
  from internals.pipelines.inpainter import InPainter
14
  from internals.pipelines.object_remove import ObjectRemoval
15
  from internals.pipelines.prompt_modifier import PromptModifier
16
+ from internals.pipelines.remove_background import (RemoveBackground,
17
+ RemoveBackgroundV2)
18
  from internals.pipelines.replace_background import ReplaceBackground
19
  from internals.pipelines.safety_checker import SafetyChecker
20
  from internals.pipelines.upscaler import Upscaler
21
  from internals.util.avatar import Avatar
22
  from internals.util.cache import auto_clear_cuda_and_gc, clear_cuda
23
+ from internals.util.commons import (construct_default_s3_url, upload_image,
24
+ upload_images)
25
+ from internals.util.config import (num_return_sequences, set_configs_from_task,
26
+ set_model_dir, set_root_dir)
 
 
 
27
  from internals.util.failure_hander import FailureHandler
28
  from internals.util.lora_style import LoraStyle
29
  from internals.util.slack import Slack
 
169
  width=task.get_width(),
170
  height=task.get_height(),
171
  steps=task.get_steps(),
172
+ extend_object=task.rbg_extend_object(),
173
  product_scale_width=task.get_image_scale(),
174
  conditioning_scale=task.rbg_controlnet_conditioning_scale(),
175
  )
 
230
  inpainter.load()
231
  high_res.load()
232
 
233
+ replace_background.load(
234
+ upscaler=upscaler, remove_background=remove_background_v2, high_res=high_res
235
+ )
236
 
237
  print("Logs: model loaded ....")
238
  return
internals/data/task.py CHANGED
@@ -141,6 +141,9 @@ class Task:
141
  def rbg_controlnet_conditioning_scale(self) -> float:
142
  return self.__data.get("rbg_conditioning_scale", 0.5)
143
 
 
 
 
144
  def get_nsfw_threshold(self) -> float:
145
  return self.__data.get("nsfw_threshold", 0.03)
146
 
 
141
  def rbg_controlnet_conditioning_scale(self) -> float:
142
  return self.__data.get("rbg_conditioning_scale", 0.5)
143
 
144
+ def rbg_extend_object(self) -> bool:
145
+ return self.__data.get("rbg_extend_object", False)
146
+
147
  def get_nsfw_threshold(self) -> float:
148
  return self.__data.get("nsfw_threshold", 0.03)
149
 
internals/pipelines/replace_background.py CHANGED
@@ -2,12 +2,9 @@ from io import BytesIO
2
  from typing import List, Optional, Union
3
 
4
  import torch
5
- from diffusers import (
6
- ControlNetModel,
7
- StableDiffusionControlNetInpaintPipeline,
8
- StableDiffusionInpaintPipeline,
9
- UniPCMultistepScheduler,
10
- )
11
  from PIL import Image, ImageFilter, ImageOps
12
 
13
  import internals.util.image as ImageUtil
@@ -46,7 +43,7 @@ class ReplaceBackground(AbstractPipeline):
46
  pipe.controlnet = controlnet_model
47
  else:
48
  pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
49
- get_model_dir(),
50
  controlnet=controlnet_model,
51
  torch_dtype=torch.float16,
52
  cache_dir=get_hf_cache_dir(),
@@ -81,12 +78,13 @@ class ReplaceBackground(AbstractPipeline):
81
  product_scale_width: float,
82
  prompt: List[str],
83
  negative_prompt: List[str],
84
- resize_dimension: int,
85
  conditioning_scale: float,
86
  seed: int,
87
  steps: int,
88
  apply_high_res: bool = False,
89
  ):
 
90
  if type(image) is str:
91
  image = download_image(image)
92
 
@@ -94,8 +92,8 @@ class ReplaceBackground(AbstractPipeline):
94
  torch.cuda.manual_seed(seed)
95
 
96
  image = image.convert("RGB")
97
- if max(image.size) > 1536:
98
- image = ImageUtil.resize_image(image, dimension=1536)
99
  image = self.remove_background.remove(image)
100
 
101
  width = int(width)
@@ -106,11 +104,15 @@ class ReplaceBackground(AbstractPipeline):
106
 
107
  print(width, height, n_width, n_height)
108
 
109
- image = ImageUtil.padd_image(image, n_width, n_height)
 
 
 
 
 
110
 
111
- f_image = Image.new("RGBA", (width, height), (0, 0, 0, 0))
112
- f_image.paste(image, ((width - n_width) // 2, (height - n_height) // 2))
113
- image = f_image
114
 
115
  mask = image.copy()
116
  pixdata = mask.load()
@@ -124,13 +126,13 @@ class ReplaceBackground(AbstractPipeline):
124
  else:
125
  pixdata[x, y] = (0, 0, 0, 255)
126
 
 
 
127
  mask = mask.convert("RGB")
128
 
129
- condition_image = ControlNet.linearart_condition_image(image)
130
-
131
  if apply_high_res and hasattr(self, "high_res"):
132
- (w, h) = self.high_res.get_intermediate_dimension(width, height)
133
- images = self.pipe.__call__(
134
  prompt=prompt,
135
  negative_prompt=negative_prompt,
136
  image=image,
@@ -142,15 +144,17 @@ class ReplaceBackground(AbstractPipeline):
142
  num_inference_steps=steps,
143
  height=w,
144
  width=h,
145
- ).images
146
- result = self.high_res.apply(
147
- prompt=prompt,
148
- negative_prompt=negative_prompt,
149
- images=images,
150
- width=width,
151
- height=width,
152
- steps=steps,
153
  )
 
 
 
 
 
 
 
 
 
 
154
  else:
155
  result = self.pipe.__call__(
156
  prompt=prompt,
 
2
  from typing import List, Optional, Union
3
 
4
  import torch
5
+ from diffusers import (ControlNetModel,
6
+ StableDiffusionControlNetInpaintPipeline,
7
+ StableDiffusionInpaintPipeline, UniPCMultistepScheduler)
 
 
 
8
  from PIL import Image, ImageFilter, ImageOps
9
 
10
  import internals.util.image as ImageUtil
 
43
  pipe.controlnet = controlnet_model
44
  else:
45
  pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
46
+ "runwayml/stable-diffusion-inpainting",
47
  controlnet=controlnet_model,
48
  torch_dtype=torch.float16,
49
  cache_dir=get_hf_cache_dir(),
 
78
  product_scale_width: float,
79
  prompt: List[str],
80
  negative_prompt: List[str],
81
+ extend_object: bool,
82
  conditioning_scale: float,
83
  seed: int,
84
  steps: int,
85
  apply_high_res: bool = False,
86
  ):
87
+ # image = Image.open("original.png")
88
  if type(image) is str:
89
  image = download_image(image)
90
 
 
92
  torch.cuda.manual_seed(seed)
93
 
94
  image = image.convert("RGB")
95
+ if max(image.size) > 1024:
96
+ image = ImageUtil.resize_image(image, dimension=1024)
97
  image = self.remove_background.remove(image)
98
 
99
  width = int(width)
 
104
 
105
  print(width, height, n_width, n_height)
106
 
107
+ if extend_object:
108
+ condition_image = ControlNet.linearart_condition_image(image).resize(
109
+ (n_width, n_height)
110
+ )
111
+ condition_image = ImageUtil.padd_image(condition_image, width, height)
112
+ condition_image = condition_image.convert("RGB")
113
 
114
+ image = image.resize((n_width, n_height))
115
+ image = ImageUtil.padd_image(image, width, height)
 
116
 
117
  mask = image.copy()
118
  pixdata = mask.load()
 
126
  else:
127
  pixdata[x, y] = (0, 0, 0, 255)
128
 
129
+ if not extend_object:
130
+ condition_image = ControlNet.linearart_condition_image(image)
131
  mask = mask.convert("RGB")
132
 
 
 
133
  if apply_high_res and hasattr(self, "high_res"):
134
+ w, h = HighRes.get_intermediate_dimension(width, height)
135
+ result = self.pipe.__call__(
136
  prompt=prompt,
137
  negative_prompt=negative_prompt,
138
  image=image,
 
144
  num_inference_steps=steps,
145
  height=w,
146
  width=h,
 
 
 
 
 
 
 
 
147
  )
148
+ for i, _ in enumerate(result.images):
149
+ out_bytes = self.upscaler.upscale(
150
+ image=result.images[i],
151
+ width=w,
152
+ height=h,
153
+ face_enhance=False,
154
+ resize_dimension=max(width, height),
155
+ )
156
+ result.images[i] = Image.open(BytesIO(out_bytes)).convert("RGB")
157
+ result = Result.from_result(result)
158
  else:
159
  result = self.pipe.__call__(
160
  prompt=prompt,
internals/util/image.py CHANGED
@@ -50,7 +50,7 @@ def padd_image(image: Image.Image, to_width: int, to_height: int) -> Image.Image
50
  # resize Image
51
  if iw > ih:
52
  image = image.resize((value, int(value * ih / iw)))
53
- else:
54
  image = image.resize((int(value * iw / ih), value))
55
 
56
  # padd Image
 
50
  # resize Image
51
  if iw > ih:
52
  image = image.resize((value, int(value * ih / iw)))
53
+ elif ih > iw:
54
  image = image.resize((int(value * iw / ih), value))
55
 
56
  # padd Image