|
''' |
|
# -------------------------------------------------------------------------------- |
|
# |
|
# StableSR for Automatic1111 WebUI |
|
# |
|
# Introducing state-of-the super-resolution method: StableSR! |
|
# Techniques is originally proposed by my schoolmate Jianyi Wang et, al. |
|
# |
|
# Project Page: https://iceclear.github.io/projects/stablesr/ |
|
# Official Repo: https://github.com/IceClear/StableSR |
|
# Paper: https://arxiv.org/abs/2305.07015 |
|
# |
|
# @original author: Jianyi Wang et, al. |
|
# @migration: LI YI |
|
# @organization: Nanyang Technological University - Singapore |
|
# @date: 2023-05-20 |
|
# @license: |
|
# S-Lab License 1.0 (see LICENSE file) |
|
# CC BY-NC-SA 4.0 (required by NVIDIA SPADE module) |
|
# |
|
# @disclaimer: |
|
# All code in this extension is for research purpose only. |
|
# The commercial use of the code & checkpoint is strictly prohibited. |
|
# |
|
# -------------------------------------------------------------------------------- |
|
# |
|
# IMPORTANT NOTICE FOR OUTCOME IMAGES: |
|
# - Please be aware that the CC BY-NC-SA 4.0 license in SPADE module |
|
# also prohibits the commercial use of outcome images. |
|
# - Jianyi Wang may change the SPADE module to a commercial-friendly one. |
|
# If you want to use the outcome images for commercial purposes, please |
|
# contact Jianyi Wang for more information. |
|
# |
|
# Please give me a star (and also Jianyi's repo) if you like this project! |
|
# |
|
# -------------------------------------------------------------------------------- |
|
''' |
|
|
|
import os |
|
import torch |
|
import gradio as gr |
|
import numpy as np |
|
import PIL.Image as Image |
|
|
|
from pathlib import Path |
|
from torch import Tensor |
|
from tqdm import tqdm |
|
|
|
from modules import scripts, processing, sd_samplers, devices, images, shared |
|
from modules.processing import StableDiffusionProcessingImg2Img, Processed |
|
from modules.shared import opts |
|
from ldm.modules.diffusionmodules.openaimodel import UNetModel |
|
|
|
from srmodule.spade import SPADELayers |
|
from srmodule.struct_cond import EncoderUNetModelWT, build_unetwt |
|
from srmodule.colorfix import adain_color_fix, wavelet_color_fix |
|
|
|
SD_WEBUI_PATH = Path.cwd() |
|
ME_PATH = SD_WEBUI_PATH / 'extensions' / 'sd-webui-stablesr' |
|
MODEL_PATH = ME_PATH / 'models' |
|
FORWARD_CACHE_NAME = 'org_forward_stablesr' |
|
|
|
class StableSR: |
|
def __init__(self, path, dtype, device): |
|
state_dict = torch.load(path, map_location='cpu') |
|
self.struct_cond_model: EncoderUNetModelWT = build_unetwt() |
|
self.spade_layers: SPADELayers = SPADELayers() |
|
self.struct_cond_model.load_from_dict(state_dict) |
|
self.spade_layers.load_from_dict(state_dict) |
|
del state_dict |
|
self.struct_cond_model.apply(lambda x: x.to(dtype=dtype, device=device)) |
|
self.spade_layers.apply(lambda x: x.to(dtype=dtype, device=device)) |
|
|
|
self.latent_image: Tensor = None |
|
self.set_image_hooks = {} |
|
self.struct_cond: Tensor = None |
|
|
|
def set_latent_image(self, latent_image): |
|
self.latent_image = latent_image |
|
for hook in self.set_image_hooks.values(): |
|
hook(latent_image) |
|
|
|
def hook(self, unet: UNetModel): |
|
|
|
if not hasattr(unet, FORWARD_CACHE_NAME): |
|
setattr(unet, FORWARD_CACHE_NAME, unet.forward) |
|
|
|
def unet_forward(x, timesteps=None, context=None, y=None,**kwargs): |
|
self.latent_image = self.latent_image.to(x.device) |
|
|
|
|
|
self.spade_layers.to(x.device) |
|
self.struct_cond_model.to(x.device) |
|
timesteps = timesteps.to(x.device) |
|
self.struct_cond = None |
|
self.struct_cond = self.struct_cond_model(self.latent_image, timesteps[:self.latent_image.shape[0]]) |
|
return getattr(unet, FORWARD_CACHE_NAME)(x, timesteps, context, y, **kwargs) |
|
|
|
unet.forward = unet_forward |
|
|
|
self.spade_layers.hook(unet, lambda: self.struct_cond) |
|
|
|
|
|
def unhook(self, unet: UNetModel): |
|
|
|
self.latent_image = None |
|
self.struct_cond = None |
|
self.set_image_hooks = {} |
|
|
|
if hasattr(unet, FORWARD_CACHE_NAME): |
|
unet.forward = getattr(unet, FORWARD_CACHE_NAME) |
|
delattr(unet, FORWARD_CACHE_NAME) |
|
|
|
|
|
self.spade_layers.unhook() |
|
|
|
|
|
class Script(scripts.Script): |
|
def __init__(self) -> None: |
|
self.model_list = {} |
|
self.load_model_list() |
|
self.last_path = None |
|
self.stablesr_model: StableSR = None |
|
|
|
def load_model_list(self): |
|
|
|
self.model_list = {} |
|
if not MODEL_PATH.exists(): |
|
MODEL_PATH.mkdir() |
|
for file in MODEL_PATH.iterdir(): |
|
if file.is_file(): |
|
|
|
self.model_list[file.name] = str(file.absolute()) |
|
self.model_list['None'] = None |
|
|
|
def title(self): |
|
return "StableSR" |
|
|
|
def show(self, is_img2img): |
|
return is_img2img |
|
|
|
def ui(self, is_img2img): |
|
with gr.Row(): |
|
model = gr.Dropdown(list(self.model_list.keys()), label="SR Model") |
|
refresh = gr.Button(value='↻', variant='tool') |
|
def refresh_fn(selected): |
|
self.load_model_list() |
|
if selected not in self.model_list: |
|
selected = 'None' |
|
return gr.Dropdown.update(value=selected, choices=list(self.model_list.keys())) |
|
refresh.click(fn=refresh_fn,inputs=model, outputs=model) |
|
with gr.Row(): |
|
scale_factor = gr.Slider(minimum=1, maximum=16, step=0.1, value=2, label='Scale Factor', elem_id=f'StableSR-scale') |
|
with gr.Row(): |
|
color_fix = gr.Dropdown(['None', 'Wavelet', 'AdaIN'], label="Color Fix", value='Wavelet', elem_id=f'StableSR-color-fix') |
|
save_original = gr.Checkbox(label='Save Original', value=False, elem_id=f'StableSR-save-original', visible=color_fix.value != 'None') |
|
color_fix.change(fn=lambda selected: gr.Checkbox.update(visible=selected != 'None'), inputs=color_fix, outputs=save_original, show_progress=False) |
|
pure_noise = gr.Checkbox(label='Pure Noise', value=True, elem_id=f'StableSR-pure-noise') |
|
unload_model= gr.Button(value='Unload Model', variant='tool') |
|
def unload_model_fn(): |
|
if self.stablesr_model is not None: |
|
self.stablesr_model = None |
|
devices.torch_gc() |
|
print('[StableSR] Model unloaded!') |
|
else: |
|
print('[StableSR] No model loaded.') |
|
unload_model.click(fn=unload_model_fn) |
|
return [model, scale_factor, pure_noise, color_fix, save_original] |
|
|
|
def run(self, p: StableDiffusionProcessingImg2Img, model: str, scale_factor:float, pure_noise: bool, color_fix:str, save_original:bool) -> Processed: |
|
|
|
if model == 'None': |
|
|
|
self.stablesr_model = None |
|
self.last_model_path = None |
|
return |
|
|
|
if model not in self.model_list: |
|
raise gr.Error(f"Model {model} is not in the list! Please refresh your browser!") |
|
|
|
if not os.path.exists(self.model_list[model]): |
|
raise gr.Error(f"Model {model} is not on your disk! Please refresh the model list!") |
|
|
|
if color_fix not in ['None', 'Wavelet', 'AdaIN']: |
|
print(f'[StableSR] Invalid color fix method: {color_fix}') |
|
color_fix = 'None' |
|
|
|
|
|
init_img: Image = p.init_images[0] |
|
target_width = int(init_img.width * scale_factor) |
|
target_height = int(init_img.height * scale_factor) |
|
|
|
if target_width % 8 != 0: |
|
target_width = target_width + 8 - target_width % 8 |
|
|
|
if target_height % 8 != 0: |
|
target_height = target_height + 8 - target_height % 8 |
|
init_img = init_img.resize((target_width, target_height), Image.LANCZOS) |
|
p.init_images[0] = init_img |
|
p.width = init_img.width |
|
p.height = init_img.height |
|
|
|
print('[StableSR] Target image size: {}x{}'.format(init_img.width, init_img.height)) |
|
|
|
first_param = shared.sd_model.parameters().__next__() |
|
if self.last_path != self.model_list[model]: |
|
|
|
self.stablesr_model = None |
|
|
|
if self.stablesr_model is None: |
|
self.stablesr_model = StableSR(self.model_list[model], dtype=first_param.dtype, device=first_param.device) |
|
self.last_path = self.model_list[model] |
|
|
|
def sample_custom(conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts): |
|
try: |
|
unet: UNetModel = shared.sd_model.model.diffusion_model |
|
self.stablesr_model.hook(unet) |
|
self.stablesr_model.set_latent_image(p.init_latent) |
|
x = processing.create_random_tensors(p.init_latent.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w, p=p) |
|
sampler = sd_samplers.create_sampler(p.sampler_name, p.sd_model) |
|
if pure_noise: |
|
|
|
samples = sampler.sample(p, x, conditioning, unconditional_conditioning, image_conditioning=p.image_conditioning) |
|
else: |
|
if p.initial_noise_multiplier != 1.0: |
|
p.extra_generation_params["Noise multiplier"] =p.initial_noise_multiplier |
|
x *= p.initial_noise_multiplier |
|
samples = sampler.sample_img2img(p, p.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=p.image_conditioning) |
|
|
|
if p.mask is not None: |
|
samples = samples * p.nmask + p.init_latent * p.mask |
|
del x |
|
devices.torch_gc() |
|
return samples |
|
finally: |
|
self.stablesr_model.unhook(unet) |
|
|
|
self.stablesr_model.struct_cond_model.to(device=first_param.device) |
|
self.stablesr_model.spade_layers.to(device=first_param.device) |
|
|
|
|
|
|
|
p.sample = sample_custom |
|
|
|
if color_fix != 'None': |
|
p.do_not_save_samples = True |
|
|
|
result: Processed = processing.process_images(p) |
|
|
|
if color_fix != 'None': |
|
|
|
fixed_images = [] |
|
|
|
color_fix_func = wavelet_color_fix if color_fix == 'Wavelet' else adain_color_fix |
|
for i in range(len(result.images)): |
|
try: |
|
fixed_images.append(color_fix_func(result.images[i], init_img)) |
|
except Exception as e: |
|
print(f'[StableSR] Error fixing color with default method: {e}') |
|
|
|
|
|
for i in range(len(fixed_images)): |
|
try: |
|
images.save_image(fixed_images[i], p.outpath_samples, "", p.all_seeds[i], p.all_prompts[i], opts.samples_format, info=result.infotexts[i], p=p) |
|
except Exception as e: |
|
print(f'[StableSR] Error saving color fixed image: {e}') |
|
|
|
if save_original: |
|
for i in range(len(result.images)): |
|
try: |
|
images.save_image(result.images[i], p.outpath_samples, "", p.all_seeds[i], p.all_prompts[i], opts.samples_format, info=result.infotexts[i], p=p, suffix="-before-color-fix") |
|
except Exception as e: |
|
print(f'[StableSR] Error saving original image: {e}') |
|
result.images = result.images + fixed_images |
|
|
|
return result |
|
|