# game-of-life-controlnet / controlnet.py
# Author: jerpint
# Commit 87470f9: add safety checker
import torch
from diffusers import (
StableDiffusionControlNetImg2ImgPipeline,
ControlNetModel,
DDIMScheduler,
)
from PIL import Image
class QRControlNet:
    """Img2img Stable Diffusion 1.5 pipeline conditioned by a QR-code ControlNet.

    Wraps diffusers' ``StableDiffusionControlNetImg2ImgPipeline`` built from the
    ``runwayml/stable-diffusion-v1-5`` base weights and the
    ``DionTimmer/controlnet_qrcode-control_v1p_sd15`` ControlNet, with the
    scheduler swapped for a ``DDIMScheduler``.
    """

    def __init__(self, device: str = "cuda"):
        """Load the models and build the pipeline.

        Args:
            device: Target device. Exactly ``"cuda"`` selects float16 weights,
                xformers attention and model CPU offload; any other value loads
                float32 weights and moves the whole pipeline to ``device``.
        """
        # Half precision only on the CUDA path; float32 everywhere else.
        torch_dtype = torch.float16 if device == "cuda" else torch.float32
        controlnet = ControlNetModel.from_pretrained(
            "DionTimmer/controlnet_qrcode-control_v1p_sd15", torch_dtype=torch_dtype
        )
        # ``safety_checker`` is deliberately not overridden, so the pipeline
        # keeps the default component loaded with the checkpoint.
        pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            controlnet=controlnet,
            torch_dtype=torch_dtype,
        )
        if device == "cuda":
            pipe.enable_xformers_memory_efficient_attention()
            # Model CPU offload moves each sub-model to the GPU only while it
            # runs. Do NOT move the pipeline to CUDA beforehand: that would load
            # every component onto the GPU only for offload to move it back.
            pipe.enable_model_cpu_offload()
        else:
            pipe.to(device)
        pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
        self.pipe = pipe

    def generate_image(
        self,
        source_image: Image.Image,
        control_image: Image.Image,
        prompt: str,
        negative_prompt: str,
        img_size: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 20,
        controlnet_conditioning_scale: float = 3.0,
        strength: float = 0.9,
        seed: int = 42,
        **kwargs,
    ):
        """Run the img2img ControlNet pipeline and return its first output image.

        Args:
            source_image: Initial image, passed to the pipeline as ``image``.
            control_image: Conditioning image for the ControlNet.
            prompt: Text prompt.
            negative_prompt: Negative text prompt.
            img_size: Used as both ``width`` and ``height``.
            num_inference_steps, guidance_scale, controlnet_conditioning_scale,
            strength: Forwarded unchanged to the pipeline call.
            seed: Seed for a private CPU ``torch.Generator``; the same seed and
                inputs reproduce the same sampling noise.
            **kwargs: Extra keyword arguments forwarded to the pipeline call.

        Returns:
            ``images[0]`` of the pipeline output. This is a PIL image under
            diffusers' default ``output_type``; it has a different type if
            ``output_type`` is passed via ``kwargs``.
        """
        # A dedicated generator gives the same noise stream as
        # torch.manual_seed(seed) without reseeding torch's global RNGs as a
        # side effect on the rest of the process.
        generator = torch.Generator(device="cpu").manual_seed(seed)
        result = self.pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=source_image,
            control_image=control_image,
            width=img_size,
            height=img_size,
            guidance_scale=guidance_scale,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            generator=generator,
            strength=strength,
            num_inference_steps=num_inference_steps,
            **kwargs,
        )
        return result.images[0]