from __future__ import annotations import gc import numpy as np from PIL import Image import torch from diffusers import ( ControlNetModel, StableDiffusionXLControlNetPipeline, AutoencoderKL, EulerAncestralDiscreteScheduler ) import cv2 from torchvision import transforms CONTROLNET_MODEL_IDS = { "Canny": "briaai/BRIA-2.2-ControlNet-Canny", "Depth": "briaai/BRIA-2.2-ControlNet-Depth", "Recoloring": "briaai/BRIA-2.2-ControlNet-Recoloring", } def download_all_controlnet_weights() -> None: for model_id in CONTROLNET_MODEL_IDS.values(): ControlNetModel.from_pretrained(model_id) class Model: def __init__(self, base_model_id: str = "briaai/BRIA-2.2", task_name: str = "Canny"): self.device = torch.device("cuda:0") self.base_model_id = "" self.task_name = "" self.pipe = self.load_pipe(base_model_id, task_name) def load_pipe(self, base_model_id: str, task_name) -> DiffusionPipeline: if ( base_model_id == self.base_model_id and task_name == self.task_name and hasattr(self, "pipe") and self.pipe is not None ): return self.pipe model_id = CONTROLNET_MODEL_IDS[task_name] controlnet = ControlNetModel.from_pretrained(model_id, torch_dtype=torch.float16).to('cuda') pipe = StableDiffusionXLControlNetPipeline.from_pretrained( base_model_id, controlnet=controlnet, torch_dtype=torch.float16, device_map='auto', low_cpu_mem_usage=True, offload_state_dict=True, ).to('cuda') pipe.scheduler = EulerAncestralDiscreteScheduler( beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, steps_offset=1 ) # pipe.enable_freeu(b1=1.1, b2=1.1, s1=0.5, s2=0.7) pipe.enable_xformers_memory_efficient_attention() pipe.force_zeros_for_empty_prompt = False torch.cuda.empty_cache() gc.collect() self.base_model_id = base_model_id self.task_name = task_name return pipe def set_base_model(self, base_model_id: str) -> str: if not base_model_id or base_model_id == self.base_model_id: return self.base_model_id del self.pipe torch.cuda.empty_cache() gc.collect() try: self.pipe = self.load_pipe(base_model_id, self.task_name) except Exception: self.pipe = self.load_pipe(self.base_model_id, self.task_name) return self.base_model_id def load_controlnet_weight(self, task_name: str) -> None: if task_name == self.task_name: return if self.pipe is not None and hasattr(self.pipe, "controlnet"): del self.pipe.controlnet torch.cuda.empty_cache() gc.collect() model_id = CONTROLNET_MODEL_IDS[task_name] controlnet = ControlNetModel.from_pretrained(model_id, torch_dtype=torch.float16) controlnet.to(self.device) torch.cuda.empty_cache() gc.collect() self.pipe.controlnet = controlnet self.task_name = task_name def get_prompt(self, prompt: str, additional_prompt: str) -> str: if not prompt: prompt = additional_prompt else: prompt = f"{prompt}, {additional_prompt}" return prompt @torch.autocast("cuda") def run_pipe( self, prompt: str, negative_prompt: str, control_image: Image.Image, num_images: int, num_steps: int, controlnet_conditioning_scale: float, seed: int, ) -> list[Image.Image]: generator = torch.Generator().manual_seed(seed) return self.pipe( prompt=prompt, negative_prompt=negative_prompt, controlnet_conditioning_scale=controlnet_conditioning_scale, num_images_per_prompt=num_images, num_inference_steps=num_steps, generator=generator, image=control_image, ).images def resize_image(self, image): image = image.convert('RGB') current_size = image.size if current_size[0] > current_size[1]: center_cropped_image = transforms.functional.center_crop(image, (current_size[1], current_size[1])) else: center_cropped_image = transforms.functional.center_crop(image, (current_size[0], current_size[0])) resized_image = transforms.functional.resize(center_cropped_image, (1024, 1024)) return resized_image def get_canny_filter(self, image): low_threshold = 100 high_threshold = 200 if not isinstance(image, np.ndarray): image = np.array(image) image = cv2.Canny(image, low_threshold, high_threshold) image = image[:, :, None] image = np.concatenate([image, image, image], axis=2) canny_image = Image.fromarray(image) return canny_image @torch.inference_mode() def process_canny( self, image: np.ndarray, prompt: str, negative_prompt: str, image_resolution: int, num_steps: int, controlnet_conditioning_scale: float, seed: int, ) -> list[Image.Image]: # resize input_image to 1024x1024 input_image = self.resize_image(image) canny_image = self.get_canny_filter(input_image) self.load_controlnet_weight("Canny") results = self.run_pipe( prompt=prompt, negative_prompt=negative_prompt, control_image=canny_image, num_inference_steps=num_steps, controlnet_conditioning_scale=float(controlnet_conditioning_scale) ) return [control_image] + results ################################################################################################################################ # from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, AutoencoderKL # from diffusers.utils import load_image # from PIL import Image # import torch # import numpy as np # import cv2 # import gradio as gr # from torchvision import transforms # controlnet = ControlNetModel.from_pretrained( # "briaai/BRIA-2.2-ControlNet-Canny", # torch_dtype=torch.float16 # ).to('cuda') # pipe = StableDiffusionXLControlNetPipeline.from_pretrained( # "briaai/BRIA-2.2", # controlnet=controlnet, # torch_dtype=torch.float16, # device_map='auto', # low_cpu_mem_usage=True, # offload_state_dict=True, # ).to('cuda') # pipe.scheduler = EulerAncestralDiscreteScheduler( # beta_start=0.00085, # beta_end=0.012, # beta_schedule="scaled_linear", # num_train_timesteps=1000, # steps_offset=1 # ) # # pipe.enable_freeu(b1=1.1, b2=1.1, s1=0.5, s2=0.7) # pipe.enable_xformers_memory_efficient_attention() # pipe.force_zeros_for_empty_prompt = False # low_threshold = 100 # high_threshold = 200 # def resize_image(image): # image = image.convert('RGB') # current_size = image.size # if current_size[0] > current_size[1]: # center_cropped_image = transforms.functional.center_crop(image, (current_size[1], current_size[1])) # else: # center_cropped_image = transforms.functional.center_crop(image, (current_size[0], current_size[0])) # resized_image = transforms.functional.resize(center_cropped_image, (1024, 1024)) # return resized_image # def get_canny_filter(image): # if not isinstance(image, np.ndarray): # image = np.array(image) # image = cv2.Canny(image, low_threshold, high_threshold) # image = image[:, :, None] # image = np.concatenate([image, image, image], axis=2) # canny_image = Image.fromarray(image) # return canny_image # def process(input_image, prompt, negative_prompt, num_steps, controlnet_conditioning_scale, seed): # generator = torch.manual_seed(seed) # # resize input_image to 1024x1024 # input_image = resize_image(input_image) # canny_image = get_canny_filter(input_image) # images = pipe( # prompt, negative_prompt=negative_prompt, image=canny_image, num_inference_steps=num_steps, controlnet_conditioning_scale=float(controlnet_conditioning_scale), # generator=generator, # ).images # return [canny_image,images[0]] # block = gr.Blocks().queue() # with block: # gr.Markdown("## BRIA 2.2 ControlNet Canny") # gr.HTML(''' #
# This is a demo for ControlNet Canny that using # BRIA 2.2 text-to-image model as backbone. # Trained on licensed data, BRIA 2.2 provide full legal liability coverage for copyright and privacy infringement. #
# ''') # with gr.Row(): # with gr.Column(): # input_image = gr.Image(sources=None, type="pil") # None for upload, ctrl+v and webcam # prompt = gr.Textbox(label="Prompt") # negative_prompt = gr.Textbox(label="Negative prompt", value="Logo,Watermark,Text,Ugly,Morbid,Extra fingers,Poorly drawn hands,Mutation,Blurry,Extra limbs,Gross proportions,Missing arms,Mutated hands,Long neck,Duplicate,Mutilated,Mutilated hands,Poorly drawn face,Deformed,Bad anatomy,Cloned face,Malformed limbs,Missing legs,Too many fingers") # num_steps = gr.Slider(label="Number of steps", minimum=25, maximum=100, value=50, step=1) # controlnet_conditioning_scale = gr.Slider(label="ControlNet conditioning scale", minimum=0.1, maximum=2.0, value=1.0, step=0.05) # seed = gr.Slider(label="Seed", minimum=0, maximum=2147483647, step=1, randomize=True,) # run_button = gr.Button(value="Run") # with gr.Column(): # result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery", columns=[2], height='auto') # ips = [input_image, prompt, negative_prompt, num_steps, controlnet_conditioning_scale, seed] # run_button.click(fn=process, inputs=ips, outputs=[result_gallery]) # block.launch(debug = True)