import os
from functools import partial
from glob import glob
from inspect import isfunction
from pathlib import Path as PythonPath

import cv2
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms.functional as TvF
from PIL import Image

from src import smplfusion
from src.smplfusion import share, router, attentionpatch, transformerpatch
from src.models.sd2_sr import predict_eps_from_z_and_v, predict_start_from_z_and_v
from src.utils import poisson_blend
from src.utils.iimage import IImage


def refine_mask(hr_image, hr_mask, lr_image, sam_predictor):
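    """Refine the inpainting mask with SAM.

    Segments the object inside the mask's bounding box in both the original
    high-res image and the inpainted low-res image, then restricts the user
    mask to the union of the two (dilated) segmentations.
    """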
    lr_mask = hr_mask.resize(512)

    x_min, y_min, rect_w, rect_h = cv2.boundingRect(lr_mask.data[0][:, :, 0])
    x_min = max(x_min - 1, 0)
    y_min = max(y_min - 1, 0)
    x_max = x_min + rect_w + 1
    y_max = y_min + rect_h + 1
    input_box = np.array([x_min, y_min, x_max, y_max])
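
    # First SAM pass: segment the original object in the high-res image
    # (resized to the 512-px working resolution the box prompt was computed at).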
    sam_predictor.set_image(hr_image.resize(512).data[0])
    masks, _, _ = sam_predictor.predict(
        point_coords=None,
        point_labels=None,
        box=input_box[None, :],
        multimask_output=True,
    )
    dilation_kernel = np.ones((13, 13))
    original_object_mask = (np.sum(masks, axis=0) > 0).astype(np.uint8)
    original_object_mask = cv2.dilate(original_object_mask, dilation_kernel)
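
    # Second SAM pass: segment whatever the inpainting produced in the same box.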
    sam_predictor.set_image(lr_image.resize(512).data[0])
    masks, _, _ = sam_predictor.predict(
        point_coords=None,
        point_labels=None,
        box=input_box[None, :],
        multimask_output=True,
    )
    dilation_kernel = np.ones((3, 3))
    inpainted_object_mask = (np.sum(masks, axis=0) > 0).astype(np.uint8)
    inpainted_object_mask = cv2.dilate(inpainted_object_mask, dilation_kernel)
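
    # Keep the user mask only where either segmentation fires.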
    lr_mask_masking = ((original_object_mask + inpainted_object_mask) > 0).astype(np.uint8)
    new_mask = lr_mask.data[0] * lr_mask_masking[:, :, np.newaxis]
    new_mask = IImage(new_mask).resize(2048, resample=Image.BICUBIC)
    return new_mask


def run(
    ddim,
    sam_predictor,
    lr_image,
    hr_image,
    hr_mask,
    prompt='high resolution professional photo',
    noise_level=20,
    blend_output=True,
    blend_trick=True,
    dt=50,
    seed=1,
    guidance_scale=7.5,
    negative_prompt='',
    use_sam_mask=False,
):
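    """Upscale an inpainted low-res result back into its 2048-px original.

    Runs the SD2 upscaler-style `ddim` model conditioned on the low-res
    inpainted image, optionally refining the mask with SAM, and blends the
    generated pixels into the original high-res photo.
    """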
    hr_image_info = hr_image.info
    lr_image = IImage(lr_image)
    hr_image = IImage(hr_image).resize(2048)
    hr_mask = IImage(hr_mask).resize(2048)
    torch.manual_seed(seed)

    dtype = ddim.vae.encoder.conv_in.weight.dtype
    device = ddim.vae.encoder.conv_in.weight.device

    router.attention_forward = attentionpatch.default.forward_xformers
    router.basic_transformer_forward = transformerpatch.default.forward

    hr_image_orig = hr_image
    hr_mask_orig = hr_mask

    if use_sam_mask:
        with torch.no_grad():
            hr_mask = refine_mask(hr_image, hr_mask, lr_image, sam_predictor)
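
    # Reflect-pad both resolutions so their sizes stay aligned with the 4x scale,
    # and prepare a dilated, blurred low-res mask for latent blending.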
    orig_h, orig_w = hr_image.torch().shape[2], hr_image.torch().shape[3]
    hr_image = hr_image.padx(256, padding_mode='reflect')
    hr_mask = hr_mask.padx(256, padding_mode='reflect').dilate(19)
    lr_image = lr_image.padx(64, padding_mode='reflect').torch()
    lr_mask = hr_mask.resize(lr_image.shape[2:], resample=Image.BICUBIC)
    lr_mask = lr_mask.alpha().torch(vmin=0).to(device)
    lr_mask = TvF.gaussian_blur(lr_mask, kernel_size=19)

    # Encode the hr image to get the known-region latents used by the blend trick.
    with torch.no_grad():
        hr_image = hr_image.torch().to(dtype=dtype, device=device)
        hr_z0 = ddim.vae.encode(hr_image).mean * ddim.config.scale_factor

    # The hr latent grid must line up 1:1 with the low-res conditioning image.
    assert hr_z0.shape[2] == lr_image.shape[2]
    assert hr_z0.shape[3] == lr_image.shape[3]
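
    # Encode negative and positive prompts together for classifier-free guidance.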
    with torch.no_grad():
        context = ddim.encoder.encode([negative_prompt, prompt])
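
    # Noise-augment the low-res conditioning image at the requested noise level,
    # as the upscaler's low-scale model expects.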
    noise_level = torch.tensor([noise_level], device=device, dtype=torch.long)
    unet_condition = lr_image.to(dtype=dtype, device=device, memory_format=torch.contiguous_format)
    unet_condition, noise_level = ddim.low_scale_model(unet_condition, noise_level=noise_level)
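
    # DDIM-style sampling: start from pure noise at t=999 and step down by dt.
    # The UNet is a v-prediction model; each step converts v to eps and x0.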
    with torch.autocast('cuda'), torch.no_grad():
        zt = torch.randn((1, 4, unet_condition.shape[2], unet_condition.shape[3]))
        zt = zt.to(dtype=dtype, device=device)
        for t in range(999, 0, -dt):
            # Concatenate the noise-augmented low-res image as extra UNet input channels.
            _zt = zt if unet_condition is None else torch.cat([zt, unet_condition], 1)
            eps_uncond, eps = ddim.unet(
                torch.cat([_zt, _zt]).to(dtype=dtype, device=device),
                timesteps=torch.tensor([t, t]).to(device=device),
                context=context,
                y=torch.cat([noise_level] * 2),
            ).chunk(2)
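
            # Classifier-free guidance on the v-prediction, then recover eps and x0.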
            ts = torch.full((zt.shape[0],), t, device=device, dtype=torch.long)
            model_output = eps_uncond + guidance_scale * (eps - eps_uncond)
            eps = predict_eps_from_z_and_v(ddim.schedule, zt, ts, model_output).to(dtype)
            z0 = predict_start_from_z_and_v(ddim.schedule, zt, ts, model_output).to(dtype)
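
            # Blend trick: outside the mask, pin x0 to the known hr latents so
            # only the masked region is actually re-generated.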
            if blend_trick:
                z0 = z0 * lr_mask + hr_z0 * (1 - lr_mask)

            zt = ddim.schedule.sqrt_alphas[t - dt] * z0 + ddim.schedule.sqrt_one_minus_alphas[t - dt] * eps
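
    # Decode the final latents back to image space.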
    with torch.no_grad():
        hr_result = ddim.vae.decode(z0.to(dtype) / ddim.config.scale_factor)

    # Postprocess: [-1, 1] float -> uint8 HWC, cropped back to the original size.
    hr_result = (255 * ((hr_result + 1) / 2).clip(0, 1)).to(torch.uint8)
    hr_result = hr_result.cpu().permute(0, 2, 3, 1)[0].numpy()
    hr_result = hr_result[:orig_h, :orig_w, :]
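
    # Poisson-blend the generated pixels into the untouched original photo.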
    if blend_output:
        hr_result = poisson_blend(
            orig_img=hr_image_orig.data[0],
            fake_img=hr_result,
            mask=hr_mask_orig.alpha().data[0],
        )

    hr_result = Image.fromarray(hr_result)
    hr_result.info = hr_image_info  # save metadata
    return hr_result
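
# Example usage (a sketch with hypothetical file names; `ddim` and
# `sam_predictor` must be built with the project's own loaders first):
#
#   hr_result = run(
#       ddim, sam_predictor,
#       lr_image=Image.open('inpainted_512.png'),
#       hr_image=Image.open('original_2048.png'),
#       hr_mask=Image.open('mask_2048.png'),
#       use_sam_mask=True,
#   )
#   hr_result.save('output_2048.png')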