import traceback

import torch
import torch.nn.functional as F
import gradio as gr
import spaces
from diffusers import DDIMScheduler, DiffusionPipeline
from torchvision.transforms.functional import to_tensor, to_pil_image, gaussian_blur
from PIL import Image, ImageFilter


def preprocess_image(input_image, device):
    image = to_tensor(input_image)
    image = image.unsqueeze_(0).float() * 2 - 1  # [0, 1] -> [-1, 1]
    if image.shape[1] != 3:
        # Grayscale input: replicate the single channel to RGB.
        image = image.expand(-1, 3, -1, -1)
    image = F.interpolate(image, (1024, 1024))
    image = image.to(dtype).to(device)
    return image


def load_description(fp):
    with open(fp, 'r', encoding='utf-8') as f:
        content = f.read()
    return content


def preprocess_mask(input_mask, device):
    # Split the channels of the RGBA editor layer.
    r, g, b, alpha = input_mask.split()
    # Build a new grayscale mask where:
    # - transparent areas (alpha = 0) become black (0), and
    # - painted, non-transparent areas become white (255).
    new_mask = alpha.point(lambda a: 0 if a == 0 else 255)
    mask = to_tensor(new_mask)
    mask = mask.unsqueeze_(0).float()  # 0 or 1
    mask = F.interpolate(mask, (1024, 1024))
    # Feather, then re-binarize to slightly dilate the painted region.
    mask = gaussian_blur(mask, kernel_size=(77, 77))
    mask[mask < 0.1] = 0
    mask[mask >= 0.1] = 1
    mask = mask.to(dtype).to(device)
    return mask


def make_redder(img, mask, increase_factor=0.4):
    # Debug helper: tint the masked region of a [1, 3, H, W] image red.
    img_redder = img.clone()
    mask_expanded = mask.expand_as(img)
    img_redder[0][mask_expanded[0] == 1] = torch.clamp(
        img_redder[0][mask_expanded[0] == 1] + increase_factor, 0, 1
    )
    return img_redder


# Model loading parameters
is_cpu_offload_enabled = False
is_attention_slicing_enabled = True

# Load model
dtype = torch.float16
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
scheduler = DDIMScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
    set_alpha_to_one=False,
)
model_path = "stabilityai/stable-diffusion-xl-base-1.0"
pipeline = DiffusionPipeline.from_pretrained(
    model_path,
    custom_pipeline="pipeline_stable_diffusion_xl_attentive_eraser.py",
    scheduler=scheduler,
    variant="fp16",
    use_safetensors=True,
    torch_dtype=dtype,
).to(device)

if is_attention_slicing_enabled:
    pipeline.enable_attention_slicing()
if is_cpu_offload_enabled:
    pipeline.enable_model_cpu_offload()


@spaces.GPU
def remove(gradio_image, rm_guidance_scale=9, num_inference_steps=50, seed=42, strength=0.8):
    try:
        # Seed on the same device the pipeline runs on.
        generator = torch.Generator(device).manual_seed(seed)
        prompt = ""  # set prompt to null
        source_image_pure = gradio_image["background"]
        mask_image_pure = gradio_image["layers"][0]
        source_image = preprocess_image(source_image_pure.convert('RGB'), device)
        mask = preprocess_mask(mask_image_pure, device)

        START_STEP = 0  # AAS start step
        END_STEP = int(strength * num_inference_steps)  # AAS end step
        LAYER = 34  # 0~23 down, 24~33 mid, 34~69 up / AAS start layer
        END_LAYER = 70  # AAS end layer
        ss_steps = 9  # similarity suppression steps
        ss_scale = 0.3  # similarity suppression scale

        image = pipeline(
            prompt=prompt,
            image=source_image,
            mask_image=mask,
            height=1024,
            width=1024,
            AAS=True,  # enable AAS
            strength=strength,  # inpainting strength
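            # The remaining keyword arguments are specific to the custom
            # pipeline_stable_diffusion_xl_attentive_eraser.py pipeline loaded
            # above; they bound which denoising steps and which U-Net attention
            # layers the attentive erasing is applied to.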
            rm_guidance_scale=rm_guidance_scale,  # removal guidance scale
            ss_steps=ss_steps,  # similarity suppression steps
            ss_scale=ss_scale,  # similarity suppression scale
            AAS_start_step=START_STEP,  # AAS start step
            AAS_start_layer=LAYER,  # AAS start layer
            AAS_end_layer=END_LAYER,  # AAS end layer
            num_inference_steps=num_inference_steps,  # number of inference steps
            # AAS_end_step = int(strength * num_inference_steps)
            generator=generator,
            guidance_scale=1,
        ).images[0]
        print('Inference: DONE.')

        pil_mask = to_pil_image(mask.squeeze(0))
        pil_mask_blurred = pil_mask.filter(ImageFilter.GaussianBlur(radius=15))
        mask_blurred = to_tensor(pil_mask_blurred).unsqueeze_(0).to(mask.device)
        mask_f = 1 - (1 - mask) * (1 - mask_blurred)  # feathered union of the two masks
        # image_1 = image.unsqueeze(0)
        return source_image_pure, pil_mask, image
    except Exception:
        print(traceback.format_exc())


title = """

Object Remove

""" with gr.Blocks() as demo: gr.HTML(load_description("assets/title.md")) with gr.Row(): with gr.Column(): with gr.Accordion("Advanced Options", open=False): guidance_scale = gr.Slider( minimum=1, maximum=20, value=9, step=0.1, label="Guidance Scale" ) num_steps = gr.Slider( minimum=5, maximum=100, value=50, step=1, label="Steps" ) seed = gr.Slider( minimum=42, maximum=999999, value=42, step=1, label="Seed" ) strength = gr.Slider( minimum=0, maximum=1, value=0.8, step=0.1, label="Strength" ) input_image = gr.ImageMask( type="pil", label="Input Image",crop_size=(1200,1200), layers=False ) with gr.Column(): with gr.Row(): with gr.Column(): run_button = gr.Button("Generate") result = gr.Gallery(label="Generated images", show_label=False, elem_id="gallery", columns=[3], rows=[1], object_fit="contain", height="auto") run_button.click( fn=remove, inputs=[input_image, guidance_scale, num_steps, seed, strength], outputs=result, ) demo.queue(max_size=12).launch(share=False)