diff --git a/.gitignore b/.gitignore
index 34296f3f74a4fdc83385ae540d1a3b3e91a388ef..9587d1172e51842cd0515bbdf95e911e57a7aef1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,4 +4,6 @@
outputs/
gradio_tmp/
-__pycache__/
\ No newline at end of file
+__pycache__/
+
+checkpoints/
\ No newline at end of file
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..daa17f15be45251e1b2fe1306af2d361bc1681ab
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2023 Picsart AI Research (PAIR)
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/assets/config/ddpm/v1.yaml b/__assets__/demo/config/ddpm/v1.yaml
similarity index 100%
rename from assets/config/ddpm/v1.yaml
rename to __assets__/demo/config/ddpm/v1.yaml
diff --git a/assets/config/ddpm/v2-upsample.yaml b/__assets__/demo/config/ddpm/v2-upsample.yaml
similarity index 100%
rename from assets/config/ddpm/v2-upsample.yaml
rename to __assets__/demo/config/ddpm/v2-upsample.yaml
diff --git a/assets/config/encoders/clip.yaml b/__assets__/demo/config/encoders/clip.yaml
similarity index 100%
rename from assets/config/encoders/clip.yaml
rename to __assets__/demo/config/encoders/clip.yaml
diff --git a/assets/config/encoders/openclip.yaml b/__assets__/demo/config/encoders/openclip.yaml
similarity index 100%
rename from assets/config/encoders/openclip.yaml
rename to __assets__/demo/config/encoders/openclip.yaml
diff --git a/assets/config/unet/inpainting/v1.yaml b/__assets__/demo/config/unet/inpainting/v1.yaml
similarity index 100%
rename from assets/config/unet/inpainting/v1.yaml
rename to __assets__/demo/config/unet/inpainting/v1.yaml
diff --git a/assets/config/unet/inpainting/v2.yaml b/__assets__/demo/config/unet/inpainting/v2.yaml
similarity index 100%
rename from assets/config/unet/inpainting/v2.yaml
rename to __assets__/demo/config/unet/inpainting/v2.yaml
diff --git a/assets/config/unet/upsample/v2.yaml b/__assets__/demo/config/unet/upsample/v2.yaml
similarity index 100%
rename from assets/config/unet/upsample/v2.yaml
rename to __assets__/demo/config/unet/upsample/v2.yaml
diff --git a/assets/config/vae-upsample.yaml b/__assets__/demo/config/vae-upsample.yaml
similarity index 100%
rename from assets/config/vae-upsample.yaml
rename to __assets__/demo/config/vae-upsample.yaml
diff --git a/assets/config/vae.yaml b/__assets__/demo/config/vae.yaml
similarity index 100%
rename from assets/config/vae.yaml
rename to __assets__/demo/config/vae.yaml
diff --git a/assets/examples/images_1024/a19.jpg b/__assets__/demo/examples/images_1024/a19.jpg
similarity index 100%
rename from assets/examples/images_1024/a19.jpg
rename to __assets__/demo/examples/images_1024/a19.jpg
diff --git a/assets/examples/images_1024/a2.jpg b/__assets__/demo/examples/images_1024/a2.jpg
similarity index 100%
rename from assets/examples/images_1024/a2.jpg
rename to __assets__/demo/examples/images_1024/a2.jpg
diff --git a/assets/examples/images_1024/a4.jpg b/__assets__/demo/examples/images_1024/a4.jpg
similarity index 100%
rename from assets/examples/images_1024/a4.jpg
rename to __assets__/demo/examples/images_1024/a4.jpg
diff --git a/assets/examples/images_1024/a40.jpg b/__assets__/demo/examples/images_1024/a40.jpg
similarity index 100%
rename from assets/examples/images_1024/a40.jpg
rename to __assets__/demo/examples/images_1024/a40.jpg
diff --git a/assets/examples/images_1024/a46.jpg b/__assets__/demo/examples/images_1024/a46.jpg
similarity index 100%
rename from assets/examples/images_1024/a46.jpg
rename to __assets__/demo/examples/images_1024/a46.jpg
diff --git a/assets/examples/images_1024/a51.jpg b/__assets__/demo/examples/images_1024/a51.jpg
similarity index 100%
rename from assets/examples/images_1024/a51.jpg
rename to __assets__/demo/examples/images_1024/a51.jpg
diff --git a/assets/examples/images_1024/a54.jpg b/__assets__/demo/examples/images_1024/a54.jpg
similarity index 100%
rename from assets/examples/images_1024/a54.jpg
rename to __assets__/demo/examples/images_1024/a54.jpg
diff --git a/assets/examples/images_1024/a65.jpg b/__assets__/demo/examples/images_1024/a65.jpg
similarity index 100%
rename from assets/examples/images_1024/a65.jpg
rename to __assets__/demo/examples/images_1024/a65.jpg
diff --git a/assets/examples/images_2048/a19.jpg b/__assets__/demo/examples/images_2048/a19.jpg
similarity index 100%
rename from assets/examples/images_2048/a19.jpg
rename to __assets__/demo/examples/images_2048/a19.jpg
diff --git a/assets/examples/images_2048/a2.jpg b/__assets__/demo/examples/images_2048/a2.jpg
similarity index 100%
rename from assets/examples/images_2048/a2.jpg
rename to __assets__/demo/examples/images_2048/a2.jpg
diff --git a/assets/examples/images_2048/a4.jpg b/__assets__/demo/examples/images_2048/a4.jpg
similarity index 100%
rename from assets/examples/images_2048/a4.jpg
rename to __assets__/demo/examples/images_2048/a4.jpg
diff --git a/assets/examples/images_2048/a40.jpg b/__assets__/demo/examples/images_2048/a40.jpg
similarity index 100%
rename from assets/examples/images_2048/a40.jpg
rename to __assets__/demo/examples/images_2048/a40.jpg
diff --git a/assets/examples/images_2048/a46.jpg b/__assets__/demo/examples/images_2048/a46.jpg
similarity index 100%
rename from assets/examples/images_2048/a46.jpg
rename to __assets__/demo/examples/images_2048/a46.jpg
diff --git a/assets/examples/images_2048/a51.jpg b/__assets__/demo/examples/images_2048/a51.jpg
similarity index 100%
rename from assets/examples/images_2048/a51.jpg
rename to __assets__/demo/examples/images_2048/a51.jpg
diff --git a/assets/examples/images_2048/a54.jpg b/__assets__/demo/examples/images_2048/a54.jpg
similarity index 100%
rename from assets/examples/images_2048/a54.jpg
rename to __assets__/demo/examples/images_2048/a54.jpg
diff --git a/assets/examples/images_2048/a65.jpg b/__assets__/demo/examples/images_2048/a65.jpg
similarity index 100%
rename from assets/examples/images_2048/a65.jpg
rename to __assets__/demo/examples/images_2048/a65.jpg
diff --git a/assets/examples/sbs/a19.png b/__assets__/demo/examples/sbs/a19.png
similarity index 100%
rename from assets/examples/sbs/a19.png
rename to __assets__/demo/examples/sbs/a19.png
diff --git a/assets/examples/sbs/a2.png b/__assets__/demo/examples/sbs/a2.png
similarity index 100%
rename from assets/examples/sbs/a2.png
rename to __assets__/demo/examples/sbs/a2.png
diff --git a/assets/examples/sbs/a4.png b/__assets__/demo/examples/sbs/a4.png
similarity index 100%
rename from assets/examples/sbs/a4.png
rename to __assets__/demo/examples/sbs/a4.png
diff --git a/assets/examples/sbs/a40.png b/__assets__/demo/examples/sbs/a40.png
similarity index 100%
rename from assets/examples/sbs/a40.png
rename to __assets__/demo/examples/sbs/a40.png
diff --git a/assets/examples/sbs/a46.png b/__assets__/demo/examples/sbs/a46.png
similarity index 100%
rename from assets/examples/sbs/a46.png
rename to __assets__/demo/examples/sbs/a46.png
diff --git a/assets/examples/sbs/a51.png b/__assets__/demo/examples/sbs/a51.png
similarity index 100%
rename from assets/examples/sbs/a51.png
rename to __assets__/demo/examples/sbs/a51.png
diff --git a/assets/examples/sbs/a54.png b/__assets__/demo/examples/sbs/a54.png
similarity index 100%
rename from assets/examples/sbs/a54.png
rename to __assets__/demo/examples/sbs/a54.png
diff --git a/assets/examples/sbs/a65.png b/__assets__/demo/examples/sbs/a65.png
similarity index 100%
rename from assets/examples/sbs/a65.png
rename to __assets__/demo/examples/sbs/a65.png
diff --git a/assets/sr_info.png b/__assets__/demo/sr_info.png
similarity index 100%
rename from assets/sr_info.png
rename to __assets__/demo/sr_info.png
diff --git a/app.py b/app.py
index f77c0ff2de5cffe24d3fc24894a1f3cc6b68bbf9..071b36fa59e3e4f33d7662c6f96c48d1820a0e30 100644
--- a/app.py
+++ b/app.py
@@ -1,40 +1,44 @@
import os
+import sys
+from pathlib import Path
from collections import OrderedDict
import gradio as gr
import shutil
import uuid
import torch
-from pathlib import Path
-from lib.utils.iimage import IImage
from PIL import Image
-from lib import models
-from lib.methods import rasg, sd, sr
-from lib.utils import poisson_blend, image_from_url_text
+demo_path = Path(__file__).resolve().parent
+root_path = demo_path
+sys.path.append(str(root_path))
+from src import models
+from src.methods import rasg, sd, sr
+from src.utils import IImage, poisson_blend, image_from_url_text
-TMP_DIR = 'gradio_tmp'
-if Path(TMP_DIR).exists():
- shutil.rmtree(TMP_DIR)
-Path(TMP_DIR).mkdir(exist_ok=True, parents=True)
+TMP_DIR = root_path / 'gradio_tmp'
+if TMP_DIR.exists():
+ shutil.rmtree(str(TMP_DIR))
+TMP_DIR.mkdir(exist_ok=True, parents=True)
-os.environ['GRADIO_TEMP_DIR'] = TMP_DIR
+os.environ['GRADIO_TEMP_DIR'] = str(TMP_DIR)
on_huggingspace = os.environ.get("SPACE_AUTHOR_NAME") == "PAIR"
negative_prompt_str = "text, bad anatomy, bad proportions, blurry, cropped, deformed, disfigured, duplicate, error, extra limbs, gross proportions, jpeg artifacts, long neck, low quality, lowres, malformed, morbid, mutated, mutilated, out of frame, ugly, worst quality"
positive_prompt_str = "Full HD, 4K, high quality, high resolution"
+examples_path = root_path / '__assets__/demo/examples'
example_inputs = [
- ['assets/examples/images_1024/a40.jpg', 'assets/examples/images_2048/a40.jpg', 'medieval castle'],
- ['assets/examples/images_1024/a4.jpg', 'assets/examples/images_2048/a4.jpg', 'parrot'],
- ['assets/examples/images_1024/a65.jpg', 'assets/examples/images_2048/a65.jpg', 'hoodie'],
- ['assets/examples/images_1024/a54.jpg', 'assets/examples/images_2048/a54.jpg', 'salad'],
- ['assets/examples/images_1024/a51.jpg', 'assets/examples/images_2048/a51.jpg', 'space helmet'],
- ['assets/examples/images_1024/a46.jpg', 'assets/examples/images_2048/a46.jpg', 'stack of books'],
- ['assets/examples/images_1024/a19.jpg', 'assets/examples/images_2048/a19.jpg', 'antique greek vase'],
- ['assets/examples/images_1024/a2.jpg', 'assets/examples/images_2048/a2.jpg', 'sunglasses'],
+ [f'{examples_path}/images_1024/a40.jpg', f'{examples_path}/images_2048/a40.jpg', 'medieval castle'],
+ [f'{examples_path}/images_1024/a4.jpg', f'{examples_path}/images_2048/a4.jpg', 'parrot'],
+ [f'{examples_path}/images_1024/a65.jpg', f'{examples_path}/images_2048/a65.jpg', 'hoodie'],
+ [f'{examples_path}/images_1024/a54.jpg', f'{examples_path}/images_2048/a54.jpg', 'salad'],
+ [f'{examples_path}/images_1024/a51.jpg', f'{examples_path}/images_2048/a51.jpg', 'space helmet'],
+ [f'{examples_path}/images_1024/a46.jpg', f'{examples_path}/images_2048/a46.jpg', 'stack of books'],
+ [f'{examples_path}/images_1024/a19.jpg', f'{examples_path}/images_2048/a19.jpg', 'antique greek vase'],
+ [f'{examples_path}/images_1024/a2.jpg', f'{examples_path}/images_2048/a2.jpg', 'sunglasses'],
]
thumbnails = [
@@ -60,27 +64,35 @@ example_previews = [
]
# Load models
+models.pre_download_inpainting_models()
inpainting_models = OrderedDict([
- ("Dreamshaper Inpainting V8", models.ds_inp.load_model()),
- ("Stable-Inpainting 2.0", models.sd2_inp.load_model()),
- ("Stable-Inpainting 1.5", models.sd15_inp.load_model())
+ ("Dreamshaper Inpainting V8", 'ds8_inp'),
+ ("Stable-Inpainting 2.0", 'sd2_inp'),
+ ("Stable-Inpainting 1.5", 'sd15_inp')
])
sr_model = models.sd2_sr.load_model(device='cuda:1')
sam_predictor = models.sam.load_model(device='cuda:0')
-inp_model = inpainting_models[list(inpainting_models.keys())[0]]
-def set_model_from_name(inp_model_name):
+inp_model_name = list(inpainting_models.keys())[0]
+inp_model = models.load_inpainting_model(
+ inpainting_models[inp_model_name], device='cuda:0', cache=False)
+
+
+def set_model_from_name(new_inp_model_name):
global inp_model
- print (f"Activating Inpaintng Model: {inp_model_name}")
- inp_model = inpainting_models[inp_model_name]
+ global inp_model_name
+ if new_inp_model_name != inp_model_name:
+ print (f"Activating Inpaintng Model: {new_inp_model_name}")
+ inp_model = models.load_inpainting_model(
+ inpainting_models[new_inp_model_name], device='cuda:0', cache=False)
+ inp_model_name = new_inp_model_name
def save_user_session(hr_image, hr_mask, lr_results, prompt, session_id=None):
if session_id == '':
session_id = str(uuid.uuid4())
- tmp_dir = Path(TMP_DIR)
- session_dir = tmp_dir / session_id
+ session_dir = TMP_DIR / session_id
session_dir.mkdir(exist_ok=True, parents=True)
hr_image.save(session_dir / 'hr_image.png')
@@ -103,8 +115,7 @@ def recover_user_session(session_id):
if session_id == '':
return None, None, [], ''
- tmp_dir = Path(TMP_DIR)
- session_dir = tmp_dir / session_id
+ session_dir = TMP_DIR / session_id
lr_results_dir = session_dir / 'lr_results'
hr_image = Image.open(session_dir / 'hr_image.png')
@@ -121,64 +132,22 @@ def recover_user_session(session_id):
return hr_image, hr_mask, gallery, prompt
-def rasg_run(
- use_painta, prompt, imageMask, hr_image, seed, eta,
- negative_prompt, positive_prompt, ddim_steps,
- guidance_scale=7.5,
- batch_size=1, session_id=''
+def inpainting_run(model_name, use_rasg, use_painta, prompt, imageMask,
+ hr_image, seed, eta, negative_prompt, positive_prompt, ddim_steps,
+ guidance_scale=7.5, batch_size=1, session_id=''
):
torch.cuda.empty_cache()
+ set_model_from_name(model_name)
- seed = int(seed)
- batch_size = max(1, min(int(batch_size), 4))
-
- image = IImage(hr_image).resize(512)
- mask = IImage(imageMask['mask']).rgb().resize(512)
-
- method = ['rasg']
+ method = ['default']
if use_painta: method.append('painta')
+ if use_rasg: method.append('rasg')
method = '-'.join(method)
- inpainted_images = []
- blended_images = []
- for i in range(batch_size):
- seed = seed + i * 1000
-
- inpainted_image = rasg.run(
- ddim=inp_model,
- method=method,
- prompt=prompt,
- image=image,
- mask=mask,
- seed=seed,
- eta=eta,
- negative_prompt=negative_prompt,
- positive_prompt=positive_prompt,
- num_steps=ddim_steps,
- guidance_scale=guidance_scale
- ).crop(image.size)
-
- blended_image = poisson_blend(
- orig_img=image.data[0],
- fake_img=inpainted_image.data[0],
- mask=mask.data[0],
- dilation=12
- )
- blended_images.append(blended_image)
- inpainted_images.append(inpainted_image.pil())
-
- session_id = save_user_session(
- hr_image, imageMask['mask'], inpainted_images, prompt, session_id=session_id)
-
- return blended_images, session_id
-
-
-def sd_run(use_painta, prompt, imageMask, hr_image, seed, eta,
- negative_prompt, positive_prompt, ddim_steps,
- guidance_scale=7.5,
- batch_size=1, session_id=''
-):
- torch.cuda.empty_cache()
+ if use_rasg:
+ inpainting_f = rasg.run
+ else:
+ inpainting_f = sd.run
seed = int(seed)
batch_size = max(1, min(int(batch_size), 4))
@@ -195,7 +164,7 @@ def sd_run(use_painta, prompt, imageMask, hr_image, seed, eta,
for i in range(batch_size):
seed = seed + i * 1000
- inpainted_image = sd.run(
+ inpainted_image = inpainting_f(
ddim=inp_model,
method=method,
prompt=prompt,
@@ -226,13 +195,12 @@ def sd_run(use_painta, prompt, imageMask, hr_image, seed, eta,
def upscale_run(
ddim_steps, seed, use_sam_mask, session_id, img_index,
- negative_prompt='',
- positive_prompt=', high resolution professional photo'
+ negative_prompt='', positive_prompt='high resolution professional photo'
):
hr_image, hr_mask, gallery, prompt = recover_user_session(session_id)
if len(gallery) == 0:
- return Image.open('./assets/sr_info.png')
+ return Image.open(root_path / '__assets__/sr_info.png')
torch.cuda.empty_cache()
@@ -249,7 +217,7 @@ def upscale_run(
inpainted_image,
hr_image,
hr_mask,
- prompt=prompt + positive_prompt,
+ prompt=f'{prompt}, {positive_prompt}',
noise_level=20,
blend_trick=True,
blend_output=True,
@@ -261,14 +229,7 @@ def upscale_run(
return output_image
-def switch_run(use_rasg, model_name, *args):
- set_model_from_name(model_name)
- if use_rasg:
- return rasg_run(*args)
- return sd_run(*args)
-
-
-with gr.Blocks(css='style.css') as demo:
+with gr.Blocks(css=demo_path / 'style.css') as demo:
gr.HTML(
"""
@@ -300,7 +261,7 @@ with gr.Blocks(css='style.css') as demo:
""")
- with open('script.js', 'r') as f:
+ with open(demo_path / 'script.js', 'r') as f:
js_str = f.read()
demo.load(_js=js_str)
@@ -380,10 +341,10 @@ with gr.Blocks(css='style.css') as demo:
html_info = gr.HTML(elem_id=f'html_info', elem_classes="infotext")
inpaint_btn.click(
- fn=switch_run,
+ fn=inpainting_run,
inputs=[
- use_rasg,
model_picker,
+ use_rasg,
use_painta,
prompt,
imageMask,
@@ -415,4 +376,4 @@ with gr.Blocks(css='style.css') as demo:
)
demo.queue(max_size=20)
-demo.launch(share=True, allowed_paths=[TMP_DIR])
\ No newline at end of file
+demo.launch(share=True, allowed_paths=[str(TMP_DIR)])
\ No newline at end of file
diff --git a/assets/.gitignore b/assets/.gitignore
deleted file mode 100644
index 6ea8874968d000cd47f52f55f32a92f0127532b3..0000000000000000000000000000000000000000
--- a/assets/.gitignore
+++ /dev/null
@@ -1 +0,0 @@
-models/
\ No newline at end of file
diff --git a/config/ddpm/v1.yaml b/config/ddpm/v1.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..95c4053aac12d443ea8071c23f07c3d1a8b97488
--- /dev/null
+++ b/config/ddpm/v1.yaml
@@ -0,0 +1,14 @@
+linear_start: 0.00085
+linear_end: 0.0120
+num_timesteps_cond: 1
+log_every_t: 200
+timesteps: 1000
+first_stage_key: "jpg"
+cond_stage_key: "txt"
+image_size: 64
+channels: 4
+cond_stage_trainable: false
+conditioning_key: crossattn
+monitor: val/loss_simple_ema
+scale_factor: 0.18215
+use_ema: False # we set this to false because this is an inference only config
\ No newline at end of file
diff --git a/config/ddpm/v2-upsample.yaml b/config/ddpm/v2-upsample.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..450576d21d0cb33958465db0179151a521828606
--- /dev/null
+++ b/config/ddpm/v2-upsample.yaml
@@ -0,0 +1,24 @@
+parameterization: "v"
+low_scale_key: "lr"
+linear_start: 0.0001
+linear_end: 0.02
+num_timesteps_cond: 1
+log_every_t: 200
+timesteps: 1000
+first_stage_key: "jpg"
+cond_stage_key: "txt"
+image_size: 128
+channels: 4
+cond_stage_trainable: false
+conditioning_key: "hybrid-adm"
+monitor: val/loss_simple_ema
+scale_factor: 0.08333
+use_ema: False
+
+low_scale_config:
+ target: ldm.modules.diffusionmodules.upscaling.ImageConcatWithNoiseAugmentation
+ params:
+ noise_schedule_config: # image space
+ linear_start: 0.0001
+ linear_end: 0.02
+ max_noise_level: 350
diff --git a/config/encoders/clip.yaml b/config/encoders/clip.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8082b5b56b0cec7d586f4e0830d206ab3fccde10
--- /dev/null
+++ b/config/encoders/clip.yaml
@@ -0,0 +1 @@
+__class__: smplfusion.models.encoders.clip_embedder.FrozenCLIPEmbedder
\ No newline at end of file
diff --git a/config/encoders/openclip.yaml b/config/encoders/openclip.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ca74e9c97a230642a8023d843d793858c9e4c5c0
--- /dev/null
+++ b/config/encoders/openclip.yaml
@@ -0,0 +1,4 @@
+__class__: smplfusion.models.encoders.open_clip_embedder.FrozenOpenCLIPEmbedder
+__init__:
+ freeze: True
+ layer: "penultimate"
\ No newline at end of file
diff --git a/config/unet/inpainting/v1.yaml b/config/unet/inpainting/v1.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b6be2f6b5c129bb8c52f0432824723851bda325d
--- /dev/null
+++ b/config/unet/inpainting/v1.yaml
@@ -0,0 +1,15 @@
+__class__: smplfusion.models.unet.UNetModel
+__init__:
+ image_size: 32 # unused
+ in_channels: 9 # 4 data + 4 downscaled image + 1 mask
+ out_channels: 4
+ model_channels: 320
+ attention_resolutions: [ 4, 2, 1 ]
+ num_res_blocks: 2
+ channel_mult: [ 1, 2, 4, 4 ]
+ num_heads: 8
+ use_spatial_transformer: True
+ transformer_depth: 1
+ context_dim: 768
+ use_checkpoint: False
+ legacy: False
\ No newline at end of file
diff --git a/config/unet/inpainting/v2.yaml b/config/unet/inpainting/v2.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c78bc2a37a344dd1499aaab01580e0c2cd7e27bc
--- /dev/null
+++ b/config/unet/inpainting/v2.yaml
@@ -0,0 +1,16 @@
+__class__: smplfusion.models.unet.UNetModel
+__init__:
+ use_checkpoint: False
+ image_size: 32 # unused
+ in_channels: 9
+ out_channels: 4
+ model_channels: 320
+ attention_resolutions: [ 4, 2, 1 ]
+ num_res_blocks: 2
+ channel_mult: [ 1, 2, 4, 4 ]
+ num_head_channels: 64 # need to fix for flash-attn
+ use_spatial_transformer: True
+ use_linear_in_transformer: True
+ transformer_depth: 1
+ context_dim: 1024
+ legacy: False
\ No newline at end of file
diff --git a/config/unet/upsample/v2.yaml b/config/unet/upsample/v2.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1fcbd87e98a510aa718afc552bb111e1110d71fe
--- /dev/null
+++ b/config/unet/upsample/v2.yaml
@@ -0,0 +1,19 @@
+__class__: smplfusion.models.unet.UNetModel
+__init__:
+ use_checkpoint: False
+ num_classes: 1000 # timesteps for noise conditioning (here constant, just need one)
+ image_size: 128
+ in_channels: 7
+ out_channels: 4
+ model_channels: 256
+ attention_resolutions: [ 2,4,8]
+ num_res_blocks: 2
+ channel_mult: [ 1, 2, 2, 4]
+ disable_self_attentions: [True, True, True, False]
+ disable_middle_self_attn: False
+ num_heads: 8
+ use_spatial_transformer: True
+ transformer_depth: 1
+ context_dim: 1024
+ legacy: False
+ use_linear_in_transformer: True
\ No newline at end of file
diff --git a/config/vae-upsample.yaml b/config/vae-upsample.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..989a5c6d374ebc3bce469c151765250d49071330
--- /dev/null
+++ b/config/vae-upsample.yaml
@@ -0,0 +1,16 @@
+__class__: smplfusion.models.vae.AutoencoderKL
+__init__:
+ embed_dim: 4
+ ddconfig:
+ double_z: True
+ z_channels: 4
+ resolution: 256
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult: [ 1,2,4 ]
+ num_res_blocks: 2
+ attn_resolutions: [ ]
+ dropout: 0.0
+ lossconfig:
+ target: torch.nn.Identity
\ No newline at end of file
diff --git a/config/vae.yaml b/config/vae.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..52e68334ce00ab350838f4dd865aff9da6482dae
--- /dev/null
+++ b/config/vae.yaml
@@ -0,0 +1,17 @@
+__class__: smplfusion.models.vae.AutoencoderKL
+__init__:
+ embed_dim: 4
+ monitor: val/rec_loss
+ ddconfig:
+ double_z: true
+ z_channels: 4
+ resolution: 256
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult: [1,2,4,4]
+ num_res_blocks: 2
+ attn_resolutions: []
+ dropout: 0.0
+ lossconfig:
+ target: torch.nn.Identity
\ No newline at end of file
diff --git a/lib/models/__init__.py b/lib/models/__init__.py
deleted file mode 100644
index 0f38837760733e194c475623e9b3821017918940..0000000000000000000000000000000000000000
--- a/lib/models/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from . import sd2_inp, ds_inp, sd15_inp, sd2_sr, sam
\ No newline at end of file
diff --git a/lib/models/common.py b/lib/models/common.py
deleted file mode 100644
index 70c5473274efadf65e4bfc67da956a4317273998..0000000000000000000000000000000000000000
--- a/lib/models/common.py
+++ /dev/null
@@ -1,49 +0,0 @@
-import importlib
-import requests
-from pathlib import Path
-from os.path import dirname
-
-from omegaconf import OmegaConf
-from tqdm import tqdm
-
-
-PROJECT_DIR = dirname(dirname(dirname(__file__)))
-CONFIG_FOLDER = f'{PROJECT_DIR}/assets/config'
-MODEL_FOLDER = f'{PROJECT_DIR}/assets/models'
-
-
-def download_file(url, save_path, chunk_size=1024):
- try:
- save_path = Path(save_path)
- if save_path.exists():
- print(f'{save_path.name} exists')
- return
- save_path.parent.mkdir(exist_ok=True, parents=True)
- resp = requests.get(url, stream=True)
- total = int(resp.headers.get('content-length', 0))
- with open(save_path, 'wb') as file, tqdm(
- desc=save_path.name,
- total=total,
- unit='iB',
- unit_scale=True,
- unit_divisor=1024,
- ) as bar:
- for data in resp.iter_content(chunk_size=chunk_size):
- size = file.write(data)
- bar.update(size)
- print(f'{save_path.name} download finished')
- except Exception as e:
- raise Exception(f"Download failed: {e}")
-
-
-def get_obj_from_str(string):
- module, cls = string.rsplit(".", 1)
- try:
- return getattr(importlib.import_module(module, package=None), cls)
- except:
- return getattr(importlib.import_module('lib.' + module, package=None), cls)
-
-
-def load_obj(path):
- objyaml = OmegaConf.load(path)
- return get_obj_from_str(objyaml['__class__'])(**objyaml.get("__init__", {}))
diff --git a/lib/models/ds_inp.py b/lib/models/ds_inp.py
deleted file mode 100644
index 45c368e2b01233b4cbae97c04dc66c9b8d212f33..0000000000000000000000000000000000000000
--- a/lib/models/ds_inp.py
+++ /dev/null
@@ -1,51 +0,0 @@
-import importlib
-from omegaconf import OmegaConf
-import torch
-import safetensors
-import safetensors.torch
-
-from lib.smplfusion import DDIM, share, scheduler
-from .common import *
-
-
-MODEL_PATH = f'{MODEL_FOLDER}/dreamshaper/dreamshaper_8Inpainting.safetensors'
-DOWNLOAD_URL = 'https://civitai.com/api/download/models/131004'
-
-# pre-download
-download_file(DOWNLOAD_URL, MODEL_PATH)
-
-
-def load_model(dtype=torch.float16):
- print ("Loading model: Dreamshaper Inpainting V8")
-
- download_file(DOWNLOAD_URL, MODEL_PATH)
-
- state_dict = safetensors.torch.load_file(MODEL_PATH)
-
- config = OmegaConf.load(f'{CONFIG_FOLDER}/ddpm/v1.yaml')
- unet = load_obj(f'{CONFIG_FOLDER}/unet/inpainting/v1.yaml').eval().cuda()
- vae = load_obj(f'{CONFIG_FOLDER}/vae.yaml').eval().cuda()
- encoder = load_obj(f'{CONFIG_FOLDER}/encoders/clip.yaml').eval().cuda()
-
- extract = lambda state_dict, model: {x[len(model)+1:]:y for x,y in state_dict.items() if model in x}
- unet_state = extract(state_dict, 'model.diffusion_model')
- encoder_state = extract(state_dict, 'cond_stage_model')
- vae_state = extract(state_dict, 'first_stage_model')
-
- unet.load_state_dict(unet_state)
- encoder.load_state_dict(encoder_state)
- vae.load_state_dict(vae_state)
-
- if dtype == torch.float16:
- unet.convert_to_fp16()
- vae.to(dtype)
- encoder.to(dtype)
-
- unet = unet.requires_grad_(False)
- encoder = encoder.requires_grad_(False)
- vae = vae.requires_grad_(False)
-
- ddim = DDIM(config, vae, encoder, unet)
- share.schedule = scheduler.linear(config.timesteps, config.linear_start, config.linear_end)
-
- return ddim
diff --git a/lib/models/sd15_inp.py b/lib/models/sd15_inp.py
deleted file mode 100644
index ea6d6c740fc5b8aee00d92dc34ffe3e0d9fe3756..0000000000000000000000000000000000000000
--- a/lib/models/sd15_inp.py
+++ /dev/null
@@ -1,49 +0,0 @@
-from omegaconf import OmegaConf
-import torch
-
-from lib.smplfusion import DDIM, share, scheduler
-from .common import *
-
-
-DOWNLOAD_URL = 'https://huggingface.co/runwayml/stable-diffusion-inpainting/resolve/main/sd-v1-5-inpainting.ckpt?download=true'
-MODEL_PATH = f'{MODEL_FOLDER}/sd-1-5-inpainting/sd-v1-5-inpainting.ckpt'
-
-# pre-download
-download_file(DOWNLOAD_URL, MODEL_PATH)
-
-
-def load_model(dtype=torch.float16):
- download_file(DOWNLOAD_URL, MODEL_PATH)
-
- state_dict = torch.load(MODEL_PATH)['state_dict']
-
- config = OmegaConf.load(f'{CONFIG_FOLDER}/ddpm/v1.yaml')
-
- print ("Loading model: Stable-Inpainting 1.5")
-
- unet = load_obj(f'{CONFIG_FOLDER}/unet/inpainting/v1.yaml').eval().cuda()
- vae = load_obj(f'{CONFIG_FOLDER}/vae.yaml').eval().cuda()
- encoder = load_obj(f'{CONFIG_FOLDER}/encoders/clip.yaml').eval().cuda()
-
- extract = lambda state_dict, model: {x[len(model)+1:]:y for x,y in state_dict.items() if model in x}
- unet_state = extract(state_dict, 'model.diffusion_model')
- encoder_state = extract(state_dict, 'cond_stage_model')
- vae_state = extract(state_dict, 'first_stage_model')
-
- unet.load_state_dict(unet_state)
- encoder.load_state_dict(encoder_state)
- vae.load_state_dict(vae_state)
-
- if dtype == torch.float16:
- unet.convert_to_fp16()
- vae.to(dtype)
- encoder.to(dtype)
-
- unet = unet.requires_grad_(False)
- encoder = encoder.requires_grad_(False)
- vae = vae.requires_grad_(False)
-
- ddim = DDIM(config, vae, encoder, unet)
- share.schedule = scheduler.linear(config.timesteps, config.linear_start, config.linear_end)
-
- return ddim
diff --git a/lib/models/sd2_inp.py b/lib/models/sd2_inp.py
deleted file mode 100644
index 42aa9b5bdcced92c296b5cdf89d62a17eb8cf7df..0000000000000000000000000000000000000000
--- a/lib/models/sd2_inp.py
+++ /dev/null
@@ -1,54 +0,0 @@
-import safetensors
-import safetensors.torch
-import torch
-from omegaconf import OmegaConf
-
-from lib.smplfusion import DDIM, share, scheduler
-from .common import *
-
-MODEL_PATH = f'{MODEL_FOLDER}/sd-2-0-inpainting/512-inpainting-ema.safetensors'
-DOWNLOAD_URL = 'https://huggingface.co/stabilityai/stable-diffusion-2-inpainting/resolve/main/512-inpainting-ema.safetensors?download=true'
-
-# pre-download
-download_file(DOWNLOAD_URL, MODEL_PATH)
-
-
-def load_model(dtype=torch.float16, device='cuda:0'):
- print ("Loading model: Stable-Inpainting 2.0")
-
- download_file(DOWNLOAD_URL, MODEL_PATH)
-
- state_dict = safetensors.torch.load_file(MODEL_PATH)
-
- config = OmegaConf.load(f'{CONFIG_FOLDER}/ddpm/v1.yaml')
-
- unet = load_obj(f'{CONFIG_FOLDER}/unet/inpainting/v2.yaml').eval().cuda()
- vae = load_obj(f'{CONFIG_FOLDER}/vae.yaml').eval().cuda()
- encoder = load_obj(f'{CONFIG_FOLDER}/encoders/openclip.yaml').eval().cuda()
- ddim = DDIM(config, vae, encoder, unet)
-
- extract = lambda state_dict, model: {x[len(model)+1:]:y for x,y in state_dict.items() if model in x}
- unet_state = extract(state_dict, 'model.diffusion_model')
- encoder_state = extract(state_dict, 'cond_stage_model')
- vae_state = extract(state_dict, 'first_stage_model')
-
- unet.load_state_dict(unet_state)
- encoder.load_state_dict(encoder_state)
- vae.load_state_dict(vae_state)
-
- if dtype == torch.float16:
- unet.convert_to_fp16()
- unet.to(device=device)
- vae.to(dtype=dtype, device=device)
- encoder.to(dtype=dtype, device=device)
- encoder.device = device
-
- unet = unet.requires_grad_(False)
- encoder = encoder.requires_grad_(False)
- vae = vae.requires_grad_(False)
-
- ddim = DDIM(config, vae, encoder, unet)
- share.schedule = scheduler.linear(config.timesteps, config.linear_start, config.linear_end)
-
- print('Stable-Inpainting 2.0 loaded')
- return ddim
diff --git a/lib/__init__.py b/src/__init__.py
similarity index 100%
rename from lib/__init__.py
rename to src/__init__.py
diff --git a/lib/methods/__init__.py b/src/methods/__init__.py
similarity index 100%
rename from lib/methods/__init__.py
rename to src/methods/__init__.py
diff --git a/lib/methods/rasg.py b/src/methods/rasg.py
similarity index 86%
rename from lib/methods/rasg.py
rename to src/methods/rasg.py
index 0b3a64b8fe16add290415e1d9523ebb999db0f35..ac41c2ce605e8592d69e9cb8320f9ba0e9bb2d8c 100644
--- a/lib/methods/rasg.py
+++ b/src/methods/rasg.py
@@ -1,11 +1,11 @@
import torch
-from lib.utils.iimage import IImage
+from src.utils.iimage import IImage
from pytorch_lightning import seed_everything
from tqdm import tqdm
-from lib.smplfusion import share, router, attentionpatch, transformerpatch
-from lib.smplfusion.patches.attentionpatch import painta
-from lib.utils import tokenize, scores
+from src.smplfusion import share, router, attentionpatch, transformerpatch
+from src.smplfusion.patches.attentionpatch import painta
+from src.utils import tokenize, scores
verbose = False
@@ -37,8 +37,10 @@ def run(
guidance_scale=7.5
):
image = image.padx(64)
- mask = mask.alpha().padx(64)
- full_prompt = f'{prompt}, {positive_prompt}'
+ mask = mask.dilate(1).alpha().padx(64)
+ full_prompt = prompt
+ if positive_prompt != '':
+ full_prompt = f'{prompt}, {positive_prompt}'
dt = 1000 // num_steps
# Text condition
@@ -91,18 +93,15 @@ def run(
score = scores.bce(share._crossattn_similarity_res16, share.mask16, token_idx = token_idx)
score.backward()
grad = zt.grad.detach()
- ddim.unet.zero_grad() # Cleanup already
+ ddim.unet.zero_grad()
# DDIM Step
with torch.no_grad():
sigma = share.schedule.sigma(share.timestep, dt)
- # Standartization
- grad -= grad.mean()
grad /= grad.std()
-
zt = share.schedule.sqrt_alphas[share.timestep - dt] * z0 + \
- torch.sqrt(1 - share.schedule.alphas[share.timestep - dt] - sigma ** 2) * eps + \
- eta * sigma * grad
+ torch.sqrt(1 - share.schedule.alphas[share.timestep - dt] - (eta * sigma) ** 2) * eps + \
+ (eta * sigma) * grad
with torch.no_grad():
output_image = IImage(ddim.vae.decode(z0 / ddim.config.scale_factor))
diff --git a/lib/methods/sd.py b/src/methods/sd.py
similarity index 86%
rename from lib/methods/sd.py
rename to src/methods/sd.py
index 45d94b2a2a55317236664338bdae53d77528cf2e..130fa59325f450ec34c7b161abeebbe8a9575142 100644
--- a/lib/methods/sd.py
+++ b/src/methods/sd.py
@@ -2,10 +2,10 @@ import torch
from pytorch_lightning import seed_everything
from tqdm import tqdm
-from lib.utils.iimage import IImage
-from lib.smplfusion import share, router, attentionpatch, transformerpatch
-from lib.smplfusion.patches.attentionpatch import painta
-from lib.utils import tokenize
+from src.utils.iimage import IImage
+from src.smplfusion import share, router, attentionpatch, transformerpatch
+from src.smplfusion.patches.attentionpatch import painta
+from src.utils import tokenize
verbose = False
@@ -25,15 +25,17 @@ def run(
image,
mask,
seed=0,
- eta=0.1,
+ eta=0.0,
negative_prompt='',
positive_prompt='',
num_steps=50,
guidance_scale=7.5
):
image = image.padx(64)
- mask = mask.alpha().padx(64)
- full_prompt = f'{prompt}, {positive_prompt}'
+ mask = mask.dilate(1).alpha().padx(64)
+ full_prompt = prompt
+ if positive_prompt != '':
+ full_prompt = f'{prompt}, {positive_prompt}'
dt = 1000 // num_steps
# Text condition
diff --git a/lib/methods/sr.py b/src/methods/sr.py
similarity index 96%
rename from lib/methods/sr.py
rename to src/methods/sr.py
index ed0e3caea9a06b1452f1d5fbfa6ea6c177d5705e..32c07cc191acb49dadd66af5d33644da901a3bc3 100644
--- a/lib/methods/sr.py
+++ b/src/methods/sr.py
@@ -11,11 +11,11 @@ import numpy as np
from inspect import isfunction
from PIL import Image
-from lib import smplfusion
-from lib.smplfusion import share, router, attentionpatch, transformerpatch
-from lib.utils.iimage import IImage
-from lib.utils import poisson_blend
-from lib.models.sd2_sr import predict_eps_from_z_and_v, predict_start_from_z_and_v
+from src import smplfusion
+from src.smplfusion import share, router, attentionpatch, transformerpatch
+from src.utils.iimage import IImage
+from src.utils import poisson_blend
+from src.models.sd2_sr import predict_eps_from_z_and_v, predict_start_from_z_and_v
def refine_mask(hr_image, hr_mask, lr_image, sam_predictor):
diff --git a/src/models/__init__.py b/src/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..dfe39f57e8bf9790a079b00bc8135b43649fe0e2
--- /dev/null
+++ b/src/models/__init__.py
@@ -0,0 +1,2 @@
+from . import sd2_sr, sam
+from .inpainting import load_inpainting_model, pre_download_inpainting_models
diff --git a/src/models/common.py b/src/models/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..764b48bc6a91eed2de06624d3985745234262cf6
--- /dev/null
+++ b/src/models/common.py
@@ -0,0 +1,149 @@
+import importlib
+import requests
+from collections import OrderedDict
+from pathlib import Path
+from os.path import dirname
+
+import torch
+import safetensors
+import safetensors.torch
+from omegaconf import OmegaConf
+from tqdm import tqdm
+
+from src.smplfusion import DDIM, share, scheduler
+from src.utils.convert_diffusers_to_sd import (
+ convert_vae_state_dict,
+ convert_unet_state_dict,
+ convert_text_enc_state_dict,
+ convert_text_enc_state_dict_v20
+)
+
+
+PROJECT_DIR = dirname(dirname(dirname(__file__)))
+CONFIG_FOLDER = f'{PROJECT_DIR}/config'
+MODEL_FOLDER = f'{PROJECT_DIR}/checkpoints'
+
+
+def download_file(url, save_path, chunk_size=1024):
+ try:
+ save_path = Path(save_path)
+ if save_path.exists():
+ print(f'{save_path.name} exists')
+ return
+ save_path.parent.mkdir(exist_ok=True, parents=True)
+ resp = requests.get(url, stream=True)
+ total = int(resp.headers.get('content-length', 0))
+ with open(save_path, 'wb') as file, tqdm(
+ desc=save_path.name,
+ total=total,
+ unit='iB',
+ unit_scale=True,
+ unit_divisor=1024,
+ ) as bar:
+ for data in resp.iter_content(chunk_size=chunk_size):
+ size = file.write(data)
+ bar.update(size)
+ print(f'{save_path.name} download finished')
+ except Exception as e:
+ raise Exception(f"Download failed: {e}")
+
+
+def get_obj_from_str(string):
+ module, cls = string.rsplit(".", 1)
+ try:
+ return getattr(importlib.import_module(module, package=None), cls)
+ except:
+ return getattr(importlib.import_module('src.' + module, package=None), cls)
+
+
+def load_obj(path):
+ objyaml = OmegaConf.load(path)
+ return get_obj_from_str(objyaml['__class__'])(**objyaml.get("__init__", {}))
+
+
+def load_state_dict(model_path):
+ model_ext = Path(model_path).suffix
+ if model_ext == '.safetensors':
+ state_dict = safetensors.torch.load_file(model_path)
+ elif model_ext == '.ckpt':
+ state_dict = torch.load(model_path)['state_dict']
+ elif model_ext == '.bin':
+ state_dict = torch.load(model_path)
+ else:
+ raise Exception(f'Unsupported model extension {model_ext}')
+ return state_dict
+
+
+def load_sd_inpainting_model(
+ download_url,
+ model_path,
+ sd_version,
+ diffusers_ckpt=False,
+ dtype=torch.float16,
+ device='cuda:0'
+):
+ if type(download_url) == str and type(model_path) == str:
+ model_path = f'{MODEL_FOLDER}/{model_path}'
+ download_file(download_url, model_path)
+ state_dict = load_state_dict(model_path)
+ if diffusers_ckpt:
+ raise Exception('Not implemented')
+ extract = lambda state_dict, model: {x[len(model)+1:]:y for x,y in state_dict.items() if model in x}
+ unet_state = extract(state_dict, 'model.diffusion_model')
+ encoder_state = extract(state_dict, 'cond_stage_model')
+ vae_state = extract(state_dict, 'first_stage_model')
+ elif type(download_url) == OrderedDict and type(model_path) == OrderedDict:
+ for key in download_url.keys():
+ download_file(download_url[key], f'{MODEL_FOLDER}/{model_path[key]}')
+ unet_state = load_state_dict(f'{MODEL_FOLDER}/{model_path["unet"]}')
+ encoder_state = load_state_dict(f'{MODEL_FOLDER}/{model_path["encoder"]}')
+ vae_state = load_state_dict(f'{MODEL_FOLDER}/{model_path["vae"]}')
+ if diffusers_ckpt:
+ unet_state = convert_unet_state_dict(unet_state)
+ is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in encoder_state
+ if is_v20_model:
+ encoder_state = {"transformer." + k: v for k, v in encoder_state .items()}
+ encoder_state = convert_text_enc_state_dict_v20(encoder_state)
+ encoder_state = {"model." + k: v for k, v in encoder_state .items()}
+ else:
+ encoder_state = convert_text_enc_state_dict(encoder_state)
+ encoder_state = {"transformer." + k: v for k, v in encoder_state .items()}
+ vae_state = convert_vae_state_dict(vae_state)
+ else:
+ raise Exception('download_url or model_path definition type is not supported')
+
+ # Load common config files
+ config = OmegaConf.load(f'{CONFIG_FOLDER}/ddpm/v1.yaml')
+ vae = load_obj(f'{CONFIG_FOLDER}/vae.yaml').eval().cuda()
+
+ # Load version specific config files
+ if sd_version == 1:
+ encoder = load_obj(f'{CONFIG_FOLDER}/encoders/clip.yaml').eval().cuda()
+ unet = load_obj(f'{CONFIG_FOLDER}/unet/inpainting/v1.yaml').eval().cuda()
+ elif sd_version == 2:
+ encoder = load_obj(f'{CONFIG_FOLDER}/encoders/openclip.yaml').eval().cuda()
+ unet = load_obj(f'{CONFIG_FOLDER}/unet/inpainting/v2.yaml').eval().cuda()
+ else:
+ raise Exception(f'Unsupported SD version {sd_version}.')
+
+ ddim = DDIM(config, vae, encoder, unet)
+
+ unet.load_state_dict(unet_state)
+ encoder.load_state_dict(encoder_state, strict=False)
+ vae.load_state_dict(vae_state)
+
+ if dtype == torch.float16:
+ unet.convert_to_fp16()
+ unet.to(device=device)
+ vae.to(dtype=dtype, device=device)
+ encoder.to(dtype=dtype, device=device)
+ encoder.device = device
+
+ unet = unet.requires_grad_(False)
+ encoder = encoder.requires_grad_(False)
+ vae = vae.requires_grad_(False)
+
+ ddim = DDIM(config, vae, encoder, unet)
+ share.schedule = scheduler.linear(config.timesteps, config.linear_start, config.linear_end)
+
+ return ddim
diff --git a/src/models/inpainting.py b/src/models/inpainting.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e08e84f761e1f400e3170ecdd2f09fe2df8b1d5
--- /dev/null
+++ b/src/models/inpainting.py
@@ -0,0 +1,66 @@
+from collections import OrderedDict
+
+import torch
+from .common import MODEL_FOLDER, load_sd_inpainting_model, download_file
+
+model_dict = {
+ 'sd15_inp': {
+ 'sd_version': 1,
+ 'diffusers_ckpt': False,
+ 'model_path': 'sd-1-5-inpainting/sd-v1-5-inpainting.ckpt',
+ 'download_url': 'https://huggingface.co/runwayml/stable-diffusion-inpainting/resolve/main/sd-v1-5-inpainting.ckpt?download=true'
+ },
+ 'ds8_inp': {
+ 'sd_version': 1,
+ 'diffusers_ckpt': True,
+ 'model_path': OrderedDict([
+ ('unet', 'ds-8-inpainting/unet.fp16.safetensors'),
+ ('encoder', 'ds-8-inpainting/encoder.fp16.safetensors'),
+ ('vae', 'ds-8-inpainting/vae.fp16.safetensors')
+ ]),
+ 'download_url': OrderedDict([
+ ('unet', 'https://huggingface.co/Lykon/dreamshaper-8-inpainting/resolve/main/unet/diffusion_pytorch_model.fp16.safetensors?download=true'),
+ ('encoder', 'https://huggingface.co/Lykon/dreamshaper-8-inpainting/resolve/main/text_encoder/model.fp16.safetensors?download=true'),
+ ('vae', 'https://huggingface.co/Lykon/dreamshaper-8-inpainting/resolve/main/vae/diffusion_pytorch_model.fp16.safetensors?download=true')
+ ])
+ },
+ 'sd2_inp': {
+ 'sd_version': 2,
+ 'diffusers_ckpt': False,
+ 'model_path': 'sd-2-0-inpainting/512-inpainting-ema.safetensors',
+ 'download_url': 'https://huggingface.co/stabilityai/stable-diffusion-2-inpainting/resolve/main/512-inpainting-ema.safetensors?download=true'
+ }
+}
+
+model_cache = {}
+
+
+def pre_download_inpainting_models():
+ for model_id, model_details in model_dict.items():
+ download_url = model_details['download_url']
+ model_path = model_details["model_path"]
+
+ if type(download_url) == str and type(model_path) == str:
+ download_file(download_url, f'{MODEL_FOLDER}/{model_path}')
+ elif type(download_url) == OrderedDict and type(model_path) == OrderedDict:
+ for key in download_url.keys():
+ download_file(download_url[key], f'{MODEL_FOLDER}/{model_path[key]}')
+ else:
+ raise Exception('download_url definition type is not supported')
+
+
+def load_inpainting_model(model_id, dtype=torch.float16, device='cuda:0', cache=False):
+ if cache and model_id in model_cache:
+ return model_cache[model_id]
+ else:
+ if model_id not in model_dict:
+ raise Exception(f'Unsupported model-id. Choose one from {list(model_dict.keys())}.')
+
+ model = load_sd_inpainting_model(
+ **model_dict[model_id],
+ dtype=dtype,
+ device=device
+ )
+ if cache:
+ model_cache[model_id] = model
+ return model
diff --git a/lib/models/sam.py b/src/models/sam.py
similarity index 83%
rename from lib/models/sam.py
rename to src/models/sam.py
index 23f0e19db83f324576b15480a968f601b1bd14e1..8ba94f40bb47cb11b45c78ea1726cde0af7bd329 100644
--- a/lib/models/sam.py
+++ b/src/models/sam.py
@@ -6,14 +6,12 @@ MODEL_PATH = f'{MODEL_FOLDER}/sam/sam_vit_h_4b8939.pth'
DOWNLOAD_URL = 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth'
# pre-download
-download_file(DOWNLOAD_URL, MODEL_PATH)
+# download_file(DOWNLOAD_URL, MODEL_PATH)
def load_model(device='cuda:0'):
- print ("Loading model: SAM")
download_file(DOWNLOAD_URL, MODEL_PATH)
sam = sam_model_registry["vit_h"](checkpoint=MODEL_PATH)
sam.to(device=device)
sam_predictor = SamPredictor(sam)
- print ("SAM loaded")
return sam_predictor
diff --git a/lib/models/sd2_sr.py b/src/models/sd2_sr.py
similarity index 96%
rename from lib/models/sd2_sr.py
rename to src/models/sd2_sr.py
index 96d26214a23421a76cbd01c7d15202772184c4c8..2116ca32e84bdc51c8132e1d701cdf536f376992 100644
--- a/lib/models/sd2_sr.py
+++ b/src/models/sd2_sr.py
@@ -10,7 +10,7 @@ import torch.nn as nn
from inspect import isfunction
from omegaconf import OmegaConf
-from lib.smplfusion import DDIM, share, scheduler
+from src.smplfusion import DDIM, share, scheduler
from .common import *
@@ -18,7 +18,7 @@ DOWNLOAD_URL = 'https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler/
MODEL_PATH = f'{MODEL_FOLDER}/sd-2-0-upsample/x4-upscaler-ema.safetensors'
# pre-download
-download_file(DOWNLOAD_URL, MODEL_PATH)
+# download_file(DOWNLOAD_URL, MODEL_PATH)
def exists(x):
@@ -147,15 +147,13 @@ def get_obj_from_str(string):
try:
return getattr(importlib.import_module(module, package=None), cls)
except:
- return getattr(importlib.import_module('lib.' + module, package=None), cls)
+ return getattr(importlib.import_module('src.' + module, package=None), cls)
def load_obj(path):
objyaml = OmegaConf.load(path)
return get_obj_from_str(objyaml['__class__'])(**objyaml.get("__init__", {}))
def load_model(dtype=torch.bfloat16, device='cuda:0'):
- print ("Loading model: SD2 superresolution...")
-
download_file(DOWNLOAD_URL, MODEL_PATH)
state_dict = safetensors.torch.load_file(MODEL_PATH)
@@ -203,5 +201,4 @@ def load_model(dtype=torch.bfloat16, device='cuda:0'):
low_scale_model = low_scale_model.to(dtype=dtype, device=device)
ddim.low_scale_model = low_scale_model
- print('SD2 superresolution loaded')
return ddim
diff --git a/lib/smplfusion/__init__.py b/src/smplfusion/__init__.py
similarity index 100%
rename from lib/smplfusion/__init__.py
rename to src/smplfusion/__init__.py
diff --git a/lib/smplfusion/ddim.py b/src/smplfusion/ddim.py
similarity index 95%
rename from lib/smplfusion/ddim.py
rename to src/smplfusion/ddim.py
index d95f57a7682c58aec564ca4dbecace7e1de0554e..6e1b5658c33ebcdc4aca0b69c6d3c3fd6fe63990 100644
--- a/lib/smplfusion/ddim.py
+++ b/src/smplfusion/ddim.py
@@ -3,7 +3,7 @@ from tqdm.notebook import tqdm
from . import scheduler
from . import share
-from lib.utils.iimage import IImage
+from src.utils.iimage import IImage
class DDIM:
def __init__(self, config, vae, encoder, unet):
@@ -48,9 +48,7 @@ class DDIM:
masked_image = image.torch().cuda() * ~mask.torch(0).bool().cuda()
masked_image = masked_image.to(dtype)
condition_x0 = self.vae.encode(masked_image).mean * self.config.scale_factor
-
condition_mask = mask.resize(latent_size[::-1]).cuda().torch(0).bool().to(dtype)
- condition_x0 += 0.01 * condition_mask * torch.randn_like(condition_mask)
return torch.cat([condition_mask, condition_x0], 1)
inpainting_condition = get_inpainting_condition
diff --git a/lib/smplfusion/models/__init__.py b/src/smplfusion/models/__init__.py
similarity index 100%
rename from lib/smplfusion/models/__init__.py
rename to src/smplfusion/models/__init__.py
diff --git a/lib/smplfusion/models/encoders/clip_embedder.py b/src/smplfusion/models/encoders/clip_embedder.py
similarity index 100%
rename from lib/smplfusion/models/encoders/clip_embedder.py
rename to src/smplfusion/models/encoders/clip_embedder.py
diff --git a/lib/smplfusion/models/encoders/open_clip_embedder.py b/src/smplfusion/models/encoders/open_clip_embedder.py
similarity index 100%
rename from lib/smplfusion/models/encoders/open_clip_embedder.py
rename to src/smplfusion/models/encoders/open_clip_embedder.py
diff --git a/lib/smplfusion/models/unet.py b/src/smplfusion/models/unet.py
similarity index 100%
rename from lib/smplfusion/models/unet.py
rename to src/smplfusion/models/unet.py
diff --git a/lib/smplfusion/models/util.py b/src/smplfusion/models/util.py
similarity index 100%
rename from lib/smplfusion/models/util.py
rename to src/smplfusion/models/util.py
diff --git a/lib/smplfusion/models/vae.py b/src/smplfusion/models/vae.py
similarity index 100%
rename from lib/smplfusion/models/vae.py
rename to src/smplfusion/models/vae.py
diff --git a/lib/smplfusion/modules/__init__.py b/src/smplfusion/modules/__init__.py
similarity index 100%
rename from lib/smplfusion/modules/__init__.py
rename to src/smplfusion/modules/__init__.py
diff --git a/lib/smplfusion/modules/attention/__init__.py b/src/smplfusion/modules/attention/__init__.py
similarity index 100%
rename from lib/smplfusion/modules/attention/__init__.py
rename to src/smplfusion/modules/attention/__init__.py
diff --git a/lib/smplfusion/modules/attention/basic_transformer_block.py b/src/smplfusion/modules/attention/basic_transformer_block.py
similarity index 100%
rename from lib/smplfusion/modules/attention/basic_transformer_block.py
rename to src/smplfusion/modules/attention/basic_transformer_block.py
diff --git a/lib/smplfusion/modules/attention/cross_attention.py b/src/smplfusion/modules/attention/cross_attention.py
similarity index 100%
rename from lib/smplfusion/modules/attention/cross_attention.py
rename to src/smplfusion/modules/attention/cross_attention.py
diff --git a/lib/smplfusion/modules/attention/feed_forward.py b/src/smplfusion/modules/attention/feed_forward.py
similarity index 100%
rename from lib/smplfusion/modules/attention/feed_forward.py
rename to src/smplfusion/modules/attention/feed_forward.py
diff --git a/lib/smplfusion/modules/attention/memory_efficient_cross_attention.py b/src/smplfusion/modules/attention/memory_efficient_cross_attention.py
similarity index 100%
rename from lib/smplfusion/modules/attention/memory_efficient_cross_attention.py
rename to src/smplfusion/modules/attention/memory_efficient_cross_attention.py
diff --git a/lib/smplfusion/modules/attention/spatial_transformer.py b/src/smplfusion/modules/attention/spatial_transformer.py
similarity index 100%
rename from lib/smplfusion/modules/attention/spatial_transformer.py
rename to src/smplfusion/modules/attention/spatial_transformer.py
diff --git a/lib/smplfusion/modules/autoencoder.py b/src/smplfusion/modules/autoencoder.py
similarity index 100%
rename from lib/smplfusion/modules/autoencoder.py
rename to src/smplfusion/modules/autoencoder.py
diff --git a/lib/smplfusion/modules/distributions.py b/src/smplfusion/modules/distributions.py
similarity index 100%
rename from lib/smplfusion/modules/distributions.py
rename to src/smplfusion/modules/distributions.py
diff --git a/lib/smplfusion/modules/ema.py b/src/smplfusion/modules/ema.py
similarity index 100%
rename from lib/smplfusion/modules/ema.py
rename to src/smplfusion/modules/ema.py
diff --git a/lib/smplfusion/modules/util.py b/src/smplfusion/modules/util.py
similarity index 100%
rename from lib/smplfusion/modules/util.py
rename to src/smplfusion/modules/util.py
diff --git a/lib/smplfusion/patches/__init__.py b/src/smplfusion/patches/__init__.py
similarity index 100%
rename from lib/smplfusion/patches/__init__.py
rename to src/smplfusion/patches/__init__.py
diff --git a/lib/smplfusion/patches/attentionpatch/__init__.py b/src/smplfusion/patches/attentionpatch/__init__.py
similarity index 100%
rename from lib/smplfusion/patches/attentionpatch/__init__.py
rename to src/smplfusion/patches/attentionpatch/__init__.py
diff --git a/lib/smplfusion/patches/attentionpatch/default.py b/src/smplfusion/patches/attentionpatch/default.py
similarity index 100%
rename from lib/smplfusion/patches/attentionpatch/default.py
rename to src/smplfusion/patches/attentionpatch/default.py
diff --git a/lib/smplfusion/patches/attentionpatch/painta.py b/src/smplfusion/patches/attentionpatch/painta.py
similarity index 99%
rename from lib/smplfusion/patches/attentionpatch/painta.py
rename to src/smplfusion/patches/attentionpatch/painta.py
index 0ffbe84f02eb433b0eb7303d55cbc632739a02a5..d96069302229a6e598dcfacea2d7b1b6004e779e 100644
--- a/lib/smplfusion/patches/attentionpatch/painta.py
+++ b/src/smplfusion/patches/attentionpatch/painta.py
@@ -9,7 +9,7 @@ from torch import nn, einsum
from einops import rearrange, repeat
from ... import share
-from lib.utils.iimage import IImage
+from src.utils.iimage import IImage
# params
painta_res = [16, 32]
diff --git a/lib/smplfusion/patches/router.py b/src/smplfusion/patches/router.py
similarity index 100%
rename from lib/smplfusion/patches/router.py
rename to src/smplfusion/patches/router.py
diff --git a/lib/smplfusion/patches/transformerpatch/__init__.py b/src/smplfusion/patches/transformerpatch/__init__.py
similarity index 100%
rename from lib/smplfusion/patches/transformerpatch/__init__.py
rename to src/smplfusion/patches/transformerpatch/__init__.py
diff --git a/lib/smplfusion/patches/transformerpatch/default.py b/src/smplfusion/patches/transformerpatch/default.py
similarity index 100%
rename from lib/smplfusion/patches/transformerpatch/default.py
rename to src/smplfusion/patches/transformerpatch/default.py
diff --git a/lib/smplfusion/patches/transformerpatch/painta.py b/src/smplfusion/patches/transformerpatch/painta.py
similarity index 100%
rename from lib/smplfusion/patches/transformerpatch/painta.py
rename to src/smplfusion/patches/transformerpatch/painta.py
diff --git a/lib/smplfusion/scheduler.py b/src/smplfusion/scheduler.py
similarity index 100%
rename from lib/smplfusion/scheduler.py
rename to src/smplfusion/scheduler.py
diff --git a/lib/smplfusion/share.py b/src/smplfusion/share.py
similarity index 98%
rename from lib/smplfusion/share.py
rename to src/smplfusion/share.py
index 3807657fdb54db02ff21f4831ad8dff0513fe559..b9b5a66966982ccbb987552ca992c11e35fec6dd 100644
--- a/lib/smplfusion/share.py
+++ b/src/smplfusion/share.py
@@ -1,5 +1,5 @@
import torchvision.transforms.functional as TF
-from lib.utils.iimage import IImage
+from src.utils.iimage import IImage
import torch
import sys
from .utils import *
diff --git a/lib/smplfusion/util.py b/src/smplfusion/util.py
similarity index 95%
rename from lib/smplfusion/util.py
rename to src/smplfusion/util.py
index 5cfadaf9a094146e67a0bc1fe12e4b27e9cfc6f1..ee8b2e540b4a8a4cb9aaba3227d4ef3c2a81c6b3 100644
--- a/lib/smplfusion/util.py
+++ b/src/smplfusion/util.py
@@ -1,5 +1,5 @@
import importlib
-from lib.utils import IImage
+from src.utils import IImage
def instantiate_from_config(config):
diff --git a/lib/smplfusion/utils/__init__.py b/src/smplfusion/utils/__init__.py
similarity index 100%
rename from lib/smplfusion/utils/__init__.py
rename to src/smplfusion/utils/__init__.py
diff --git a/lib/smplfusion/utils/input_image.py b/src/smplfusion/utils/input_image.py
similarity index 98%
rename from lib/smplfusion/utils/input_image.py
rename to src/smplfusion/utils/input_image.py
index 5b56a7087f9f455d4676bdec304f5533724067aa..32127b9dabebc511be786d659422fb006adeb4a0 100644
--- a/lib/smplfusion/utils/input_image.py
+++ b/src/smplfusion/utils/input_image.py
@@ -1,5 +1,5 @@
import torch
-from lib.utils.iimage import IImage
+from src.utils.iimage import IImage
class InputImage:
def to(self, device): return InputImage(self.image, device = device)
diff --git a/lib/smplfusion/utils/input_mask.py b/src/smplfusion/utils/input_mask.py
similarity index 99%
rename from lib/smplfusion/utils/input_mask.py
rename to src/smplfusion/utils/input_mask.py
index 755c297226b100ef2a83860c72ca364f62551ad2..c4b3d1aaad65612988f82f48f98792e5b40eb305 100644
--- a/lib/smplfusion/utils/input_mask.py
+++ b/src/smplfusion/utils/input_mask.py
@@ -1,5 +1,5 @@
import torch
-from lib.utils.iimage import IImage
+from src.utils.iimage import IImage
class InputMask:
def to(self, device): return InputMask(self.image, device = device)
diff --git a/lib/smplfusion/utils/input_shape.py b/src/smplfusion/utils/input_shape.py
similarity index 100%
rename from lib/smplfusion/utils/input_shape.py
rename to src/smplfusion/utils/input_shape.py
diff --git a/lib/utils/__init__.py b/src/utils/__init__.py
similarity index 88%
rename from lib/utils/__init__.py
rename to src/utils/__init__.py
index 2f07d571faa4022a8312a7094ea9a2706bc89e11..7c0cdf81e509bf0339c91dbb4351d4dc2d12e62f 100644
--- a/lib/utils/__init__.py
+++ b/src/utils/__init__.py
@@ -1,4 +1,5 @@
import base64
+from typing import Tuple, Union
import cv2
import numpy as np
@@ -74,3 +75,13 @@ def image_from_url_text(filedata):
filedata = base64.decodebytes(filedata.encode('utf-8'))
image = Image.open(io.BytesIO(filedata))
return image
+
+
+def resize(image: Image, size: Union[int, Tuple[int, int]], resample=Image.BICUBIC):
+ if isinstance(size, int):
+ w, h = image.size
+ aspect_ratio = w / h
+ size = (min(size, int(size * aspect_ratio)),
+ min(size, int(size / aspect_ratio)))
+ return image.resize(size, resample=resample)
+
diff --git a/src/utils/convert_diffusers_to_sd.py b/src/utils/convert_diffusers_to_sd.py
new file mode 100644
index 0000000000000000000000000000000000000000..690fb4f449e4c0979bc85fd6793828b011698322
--- /dev/null
+++ b/src/utils/convert_diffusers_to_sd.py
@@ -0,0 +1,329 @@
+# Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint.
+# *Only* converts the UNet, VAE, and Text Encoder.
+# Does not convert optimizer state or any other thing.
+
+import argparse
+import os.path as osp
+import re
+
+import torch
+
+from safetensors.torch import save_file
+
+
+# =================#
+# UNet Conversion #
+# =================#
+
+unet_conversion_map = [
+ # (stable-diffusion, HF Diffusers)
+ ("time_embed.0.weight", "time_embedding.linear_1.weight"),
+ ("time_embed.0.bias", "time_embedding.linear_1.bias"),
+ ("time_embed.2.weight", "time_embedding.linear_2.weight"),
+ ("time_embed.2.bias", "time_embedding.linear_2.bias"),
+ ("input_blocks.0.0.weight", "conv_in.weight"),
+ ("input_blocks.0.0.bias", "conv_in.bias"),
+ ("out.0.weight", "conv_norm_out.weight"),
+ ("out.0.bias", "conv_norm_out.bias"),
+ ("out.2.weight", "conv_out.weight"),
+ ("out.2.bias", "conv_out.bias"),
+]
+
+unet_conversion_map_resnet = [
+ # (stable-diffusion, HF Diffusers)
+ ("in_layers.0", "norm1"),
+ ("in_layers.2", "conv1"),
+ ("out_layers.0", "norm2"),
+ ("out_layers.3", "conv2"),
+ ("emb_layers.1", "time_emb_proj"),
+ ("skip_connection", "conv_shortcut"),
+]
+
+unet_conversion_map_layer = []
+# hardcoded number of downblocks and resnets/attentions...
+# would need smarter logic for other networks.
+for i in range(4):
+ # loop over downblocks/upblocks
+
+ for j in range(2):
+ # loop over resnets/attentions for downblocks
+ hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
+ sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
+ unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
+
+ if i < 3:
+ # no attention layers in down_blocks.3
+ hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
+ sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
+ unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
+
+ for j in range(3):
+ # loop over resnets/attentions for upblocks
+ hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
+ sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
+ unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
+
+ if i > 0:
+ # no attention layers in up_blocks.0
+ hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
+ sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
+ unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
+
+ if i < 3:
+ # no downsample in down_blocks.3
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
+ sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
+ unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
+
+ # no upsample in up_blocks.3
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
+ sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
+ unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
+
+hf_mid_atn_prefix = "mid_block.attentions.0."
+sd_mid_atn_prefix = "middle_block.1."
+unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
+
+for j in range(2):
+ hf_mid_res_prefix = f"mid_block.resnets.{j}."
+ sd_mid_res_prefix = f"middle_block.{2*j}."
+ unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
+
+
+def convert_unet_state_dict(unet_state_dict):
+ # buyer beware: this is a *brittle* function,
+ # and correct output requires that all of these pieces interact in
+ # the exact order in which I have arranged them.
+ mapping = {k: k for k in unet_state_dict.keys()}
+ for sd_name, hf_name in unet_conversion_map:
+ mapping[hf_name] = sd_name
+ for k, v in mapping.items():
+ if "resnets" in k:
+ for sd_part, hf_part in unet_conversion_map_resnet:
+ v = v.replace(hf_part, sd_part)
+ mapping[k] = v
+ for k, v in mapping.items():
+ for sd_part, hf_part in unet_conversion_map_layer:
+ v = v.replace(hf_part, sd_part)
+ mapping[k] = v
+ new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
+ return new_state_dict
+
+
+# ================#
+# VAE Conversion #
+# ================#
+
+vae_conversion_map = [
+ # (stable-diffusion, HF Diffusers)
+ ("nin_shortcut", "conv_shortcut"),
+ ("norm_out", "conv_norm_out"),
+ ("mid.attn_1.", "mid_block.attentions.0."),
+]
+
+for i in range(4):
+ # down_blocks have two resnets
+ for j in range(2):
+ hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
+ sd_down_prefix = f"encoder.down.{i}.block.{j}."
+ vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
+
+ if i < 3:
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
+ sd_downsample_prefix = f"down.{i}.downsample."
+ vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
+
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
+ sd_upsample_prefix = f"up.{3-i}.upsample."
+ vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
+
+ # up_blocks have three resnets
+ # also, up blocks in hf are numbered in reverse from sd
+ for j in range(3):
+ hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
+ sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
+ vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
+
+# this part accounts for mid blocks in both the encoder and the decoder
+for i in range(2):
+ hf_mid_res_prefix = f"mid_block.resnets.{i}."
+ sd_mid_res_prefix = f"mid.block_{i+1}."
+ vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
+
+
+vae_conversion_map_attn = [
+ # (stable-diffusion, HF Diffusers)
+ ("norm.", "group_norm."),
+ ("q.", "to_q."), #
+ ("k.", "to_k."), #
+ ("v.", "to_v."), #
+ ("proj_out.", "to_out.0."), #
+]
+
+# Original
+# vae_conversion_map_attn = [
+# # (stable-diffusion, HF Diffusers)
+# ("norm.", "group_norm."),
+# ("q.", "query."),
+# ("k.", "key."),
+# ("v.", "value."),
+# ("proj_out.", "proj_attn."),
+# ]
+
+
+def reshape_weight_for_sd(w):
+ # convert HF linear weights to SD conv2d weights
+ return w.reshape(*w.shape, 1, 1)
+
+
+def convert_vae_state_dict(vae_state_dict):
+ mapping = {k: k for k in vae_state_dict.keys()}
+ for k, v in mapping.items():
+ for sd_part, hf_part in vae_conversion_map:
+ v = v.replace(hf_part, sd_part)
+ mapping[k] = v
+ for k, v in mapping.items():
+ if "attentions" in k:
+ for sd_part, hf_part in vae_conversion_map_attn:
+ v = v.replace(hf_part, sd_part)
+ mapping[k] = v
+ new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
+ weights_to_convert = ["q", "k", "v", "proj_out"]
+ for k, v in new_state_dict.items():
+ for weight_name in weights_to_convert:
+ if f"mid.attn_1.{weight_name}.weight" in k:
+ print(f"Reshaping {k} for SD format")
+ new_state_dict[k] = reshape_weight_for_sd(v)
+ return new_state_dict
+
+
+# =========================#
+# Text Encoder Conversion #
+# =========================#
+
+
+textenc_conversion_lst = [
+ # (stable-diffusion, HF Diffusers)
+ ("resblocks.", "text_model.encoder.layers."),
+ ("ln_1", "layer_norm1"),
+ ("ln_2", "layer_norm2"),
+ (".c_fc.", ".fc1."),
+ (".c_proj.", ".fc2."),
+ (".attn", ".self_attn"),
+ ("ln_final.", "transformer.text_model.final_layer_norm."),
+ ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
+ ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
+]
+protected = {re.escape(x[1]): x[0] for x in textenc_conversion_lst}
+textenc_pattern = re.compile("|".join(protected.keys()))
+
+# Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
+code2idx = {"q": 0, "k": 1, "v": 2}
+
+
+def convert_text_enc_state_dict_v20(text_enc_dict):
+ new_state_dict = {}
+ capture_qkv_weight = {}
+ capture_qkv_bias = {}
+ for k, v in text_enc_dict.items():
+ if (
+ k.endswith(".self_attn.q_proj.weight")
+ or k.endswith(".self_attn.k_proj.weight")
+ or k.endswith(".self_attn.v_proj.weight")
+ ):
+ k_pre = k[: -len(".q_proj.weight")]
+ k_code = k[-len("q_proj.weight")]
+ if k_pre not in capture_qkv_weight:
+ capture_qkv_weight[k_pre] = [None, None, None]
+ capture_qkv_weight[k_pre][code2idx[k_code]] = v
+ continue
+
+ if (
+ k.endswith(".self_attn.q_proj.bias")
+ or k.endswith(".self_attn.k_proj.bias")
+ or k.endswith(".self_attn.v_proj.bias")
+ ):
+ k_pre = k[: -len(".q_proj.bias")]
+ k_code = k[-len("q_proj.bias")]
+ if k_pre not in capture_qkv_bias:
+ capture_qkv_bias[k_pre] = [None, None, None]
+ capture_qkv_bias[k_pre][code2idx[k_code]] = v
+ continue
+
+ relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k)
+ new_state_dict[relabelled_key] = v
+
+ for k_pre, tensors in capture_qkv_weight.items():
+ if None in tensors:
+ raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
+ relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
+ new_state_dict[relabelled_key + ".in_proj_weight"] = torch.cat(tensors)
+
+ for k_pre, tensors in capture_qkv_bias.items():
+ if None in tensors:
+ raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
+ relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
+ new_state_dict[relabelled_key + ".in_proj_bias"] = torch.cat(tensors)
+
+ return new_state_dict
+
+
+def convert_text_enc_state_dict(text_enc_dict):
+ return text_enc_dict
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.")
+ parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
+ parser.add_argument("--half", action="store_true", help="Save weights in half precision.")
+ parser.add_argument(
+ "--use_safetensors", action="store_true", help="Save weights use safetensors, default is ckpt."
+ )
+
+ args = parser.parse_args()
+
+ assert args.model_path is not None, "Must provide a model path!"
+
+ assert args.checkpoint_path is not None, "Must provide a checkpoint path!"
+
+ unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.bin")
+ vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.bin")
+ text_enc_path = osp.join(args.model_path, "text_encoder", "pytorch_model.bin")
+
+ # Convert the UNet model
+ unet_state_dict = torch.load(unet_path, map_location="cpu")
+ unet_state_dict = convert_unet_state_dict(unet_state_dict)
+ unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
+
+ # Convert the VAE model
+ vae_state_dict = torch.load(vae_path, map_location="cpu")
+ vae_state_dict = convert_vae_state_dict(vae_state_dict)
+ vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
+
+ # Convert the text encoder model
+ text_enc_dict = torch.load(text_enc_path, map_location="cpu")
+
+ # Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper
+ is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict
+
+ if is_v20_model:
+ # Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm
+ text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()}
+ text_enc_dict = convert_text_enc_state_dict_v20(text_enc_dict)
+ text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()}
+ else:
+ text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
+ text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}
+
+ # Put together new checkpoint
+ state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
+ if args.half:
+ state_dict = {k: v.half() for k, v in state_dict.items()}
+
+ if args.use_safetensors:
+ save_file(state_dict, args.checkpoint_path)
+ else:
+ state_dict = {"state_dict": state_dict}
+ torch.save(state_dict, args.checkpoint_path)
\ No newline at end of file
diff --git a/lib/utils/iimage.py b/src/utils/iimage.py
similarity index 100%
rename from lib/utils/iimage.py
rename to src/utils/iimage.py
diff --git a/lib/utils/scores.py b/src/utils/scores.py
similarity index 100%
rename from lib/utils/scores.py
rename to src/utils/scores.py