"""Colorize a black-and-white image with a ControlNet-conditioned SD 2.1 model.

Loads the ColorizeNet checkpoint, runs DDIM sampling conditioned on the input
image (``c_concat``) and a text prompt (``c_crossattn``), recombines each
sample with the input via ``apply_color``, and writes the results to
``colorized_<i>.jpg`` in the current working directory.
"""
import random

import cv2
import einops
import numpy as np
import torch
from pytorch_lightning import seed_everything

from utils.data import HWC3, apply_color, resize_image
from utils.ddim import DDIMSampler
from utils.model import create_model, load_state_dict

INPUT_PATH = "sample_data/sample1_bw.jpg"

model = create_model('./models/cldm_v21.yaml').cpu()
model.load_state_dict(load_state_dict(
    'lightning_logs/version_6/checkpoints/colorizenet-sd21.ckpt', location='cuda'))
model = model.cuda()
ddim_sampler = DDIMSampler(model)

input_image = cv2.imread(INPUT_PATH)
# cv2.imread reports a missing/unreadable file by returning None, not raising.
if input_image is None:
    raise FileNotFoundError(f"Could not read input image: {INPUT_PATH}")
input_image = HWC3(input_image)
# OpenCV loads images as BGR, while the outputs below are treated as RGB
# (converted RGB->BGR right before writing). Convert here so the whole
# pipeline uses one channel order; for a truly grayscale image this is a no-op.
input_image = cv2.cvtColor(input_image, cv2.COLOR_BGR2RGB)
img = resize_image(input_image, resolution=512)
H, W, C = img.shape

num_samples = 1

# Fixed seed for reproducible output; use random.randint(0, 65535) instead
# to get a different colorization on each run.
seed = 1294574436

prompt = "Colorize this image"
n_prompt = ""
guess_mode = False
strength = 1.0
eta = 0.0
ddim_steps = 20
scale = 9.0

# Pure inference: disable autograd so no computation graph is built or kept.
with torch.no_grad():
    # HWC uint8 -> float in [0, 1], replicated per sample, then NCHW.
    control = torch.from_numpy(img.copy()).float().cuda() / 255.0
    control = torch.stack([control for _ in range(num_samples)], dim=0)
    control = einops.rearrange(control, 'b h w c -> b c h w').clone()

    seed_everything(seed)

    cond = {
        "c_concat": [control],
        "c_crossattn": [model.get_learned_conditioning([prompt] * num_samples)],
    }
    # In guess mode the unconditional branch drops the image control.
    un_cond = {
        "c_concat": None if guess_mode else [control],
        "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)],
    }
    # Latent shape: the first stage downsamples spatially by 8.
    shape = (4, H // 8, W // 8)

    # 13 control scales; guess mode decays them geometrically (0.825 per level).
    if guess_mode:
        model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)]
    else:
        model.control_scales = [strength] * 13

    samples, _ = ddim_sampler.sample(
        ddim_steps, num_samples, shape, cond, verbose=False, eta=eta,
        unconditional_guidance_scale=scale,
        unconditional_conditioning=un_cond)

    x_samples = model.decode_first_stage(samples)
    # Map decoder output from [-1, 1] to uint8 [0, 255], NHWC.
    x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5) \
        .cpu().numpy().clip(0, 255).astype(np.uint8)

results = [x_samples[i] for i in range(num_samples)]
colored_results = [apply_color(img, result) for result in results]

for i, result in enumerate(colored_results):
    out_path = f"colorized_{i}.jpg"
    # cv2.imwrite signals failure through its return value instead of raising.
    if not cv2.imwrite(out_path, cv2.cvtColor(result, cv2.COLOR_RGB2BGR)):
        raise OSError(f"Failed to write output image: {out_path}")