yizhangliu
commited on
Commit
·
47dfe4c
1
Parent(s):
34fde06
update app.py
Browse files
app.py
CHANGED
@@ -45,7 +45,7 @@ plt = matplotlib.pyplot
|
|
45 |
|
46 |
groundingdino_enable = True
|
47 |
sam_enable = True
|
48 |
-
inpainting_enable =
|
49 |
ram_enable = False
|
50 |
|
51 |
lama_cleaner_enable = True
|
@@ -79,11 +79,13 @@ from io import BytesIO
|
|
79 |
from diffusers import StableDiffusionInpaintPipeline
|
80 |
from huggingface_hub import hf_hub_download
|
81 |
|
82 |
-
from
|
83 |
-
|
84 |
-
from
|
85 |
-
from kolors.
|
86 |
-
from
|
|
|
|
|
87 |
|
88 |
from util_computer import computer_info
|
89 |
|
@@ -329,6 +331,7 @@ def load_sd_model(device):
|
|
329 |
global sd_model
|
330 |
logger.info(f"initialize stable-diffusion-inpainting...")
|
331 |
sd_model = None
|
|
|
332 |
if os.environ.get('IS_MY_DEBUG') is None:
|
333 |
# sd_model = StableDiffusionInpaintPipeline.from_pretrained(
|
334 |
# "runwayml/stable-diffusion-inpainting",
|
@@ -355,6 +358,7 @@ def load_sd_model(device):
|
|
355 |
|
356 |
sd_model.to(device)
|
357 |
sd_model.enable_attention_slicing()
|
|
|
358 |
|
359 |
def load_lama_cleaner_model(device):
|
360 |
# initialize lama_cleaner
|
@@ -613,6 +617,29 @@ def get_time_cost(run_task_time, time_cost_str):
|
|
613 |
run_task_time = now_time
|
614 |
return run_task_time, time_cost_str
|
615 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
616 |
def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
|
617 |
iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend, num_relation, kosmos_input, cleaner_size_limit=1080):
|
618 |
|
@@ -624,6 +651,7 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
|
|
624 |
run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
625 |
|
626 |
# logger.info(f"input_image==={input_image}")
|
|
|
627 |
if 'background' in input_image.keys():
|
628 |
input_image['image'] = input_image['background'].convert("RGB")
|
629 |
if len(input_image['layers']) > 0:
|
@@ -794,7 +822,9 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
|
|
794 |
image_mask_for_inpaint = Image.fromarray(255*img_arr.astype('uint8'))
|
795 |
output_images.append(image_mask_for_inpaint.convert("RGB"))
|
796 |
run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
797 |
-
|
|
|
|
|
798 |
else:
|
799 |
# remove from mask
|
800 |
if mask_source_radio == mask_source_segment:
|
|
|
45 |
|
46 |
groundingdino_enable = True
|
47 |
sam_enable = True
|
48 |
+
inpainting_enable = True
|
49 |
ram_enable = False
|
50 |
|
51 |
lama_cleaner_enable = True
|
|
|
79 |
from diffusers import StableDiffusionInpaintPipeline
|
80 |
from huggingface_hub import hf_hub_download
|
81 |
|
82 |
+
from gradio_client import Client, handle_file
|
83 |
+
|
84 |
+
# from huggingface_hub import snapshot_download
|
85 |
+
# from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256_inpainting import StableDiffusionXLInpaintPipeline
|
86 |
+
# from kolors.models.modeling_chatglm import ChatGLMModel
|
87 |
+
# from kolors.models.tokenization_chatglm import ChatGLMTokenizer
|
88 |
+
# from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel
|
89 |
|
90 |
from util_computer import computer_info
|
91 |
|
|
|
331 |
global sd_model
|
332 |
logger.info(f"initialize stable-diffusion-inpainting...")
|
333 |
sd_model = None
|
334 |
+
'''
|
335 |
if os.environ.get('IS_MY_DEBUG') is None:
|
336 |
# sd_model = StableDiffusionInpaintPipeline.from_pretrained(
|
337 |
# "runwayml/stable-diffusion-inpainting",
|
|
|
358 |
|
359 |
sd_model.to(device)
|
360 |
sd_model.enable_attention_slicing()
|
361 |
+
'''
|
362 |
|
363 |
def load_lama_cleaner_model(device):
|
364 |
# initialize lama_cleaner
|
|
|
617 |
run_task_time = now_time
|
618 |
return run_task_time, time_cost_str
|
619 |
|
620 |
+
def load_kolors_inpainting(inpaint_prompt, image, mask_image):
|
621 |
+
# sd_model(prompt=inpaint_prompt, image=image_source_for_inpaint, mask_image=image_mask_for_inpaint).images[0]
|
622 |
+
|
623 |
+
client = Client("Kwai-Kolors/Kolors-Inpainting")
|
624 |
+
result = client.predict(
|
625 |
+
prompt=inpaint_prompt,
|
626 |
+
image=image,
|
627 |
+
mask_image = mask_image,
|
628 |
+
negative_prompt="broken fingers, deformed fingers, deformed hands, stumps, blurriness, low quality",
|
629 |
+
seed=0,
|
630 |
+
randomize_seed=True,
|
631 |
+
guidance_scale=6,
|
632 |
+
num_inference_steps=25,
|
633 |
+
api_name="/infer"
|
634 |
+
)
|
635 |
+
logger.info(f'load_kolors_inpainting_result={result}')
|
636 |
+
im = Image.open(result)
|
637 |
+
if im.mode == "RGBA":
|
638 |
+
im.load() # required for png.split()
|
639 |
+
background = Image.new("RGB", im.size, (255, 255, 255))
|
640 |
+
background.paste(im, mask=im.split()[3])
|
641 |
+
return result
|
642 |
+
|
643 |
def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
|
644 |
iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend, num_relation, kosmos_input, cleaner_size_limit=1080):
|
645 |
|
|
|
651 |
run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
652 |
|
653 |
# logger.info(f"input_image==={input_image}")
|
654 |
+
ori_input_image = input_image
|
655 |
if 'background' in input_image.keys():
|
656 |
input_image['image'] = input_image['background'].convert("RGB")
|
657 |
if len(input_image['layers']) > 0:
|
|
|
822 |
image_mask_for_inpaint = Image.fromarray(255*img_arr.astype('uint8'))
|
823 |
output_images.append(image_mask_for_inpaint.convert("RGB"))
|
824 |
run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
|
825 |
+
|
826 |
+
# image_inpainting = sd_model(prompt=inpaint_prompt, image=image_source_for_inpaint, mask_image=image_mask_for_inpaint).images[0]
|
827 |
+
image_inpainting = load_kolors_inpainting(ori_input_image, image_source_for_inpaint, image_mask_for_inpaint).images[0])
|
828 |
else:
|
829 |
# remove from mask
|
830 |
if mask_source_radio == mask_source_segment:
|