Spaces:
PAIR
/
Running on A10G

Andranik Sargsyan committed on
Commit
bfd34e9
·
1 Parent(s): 919cdba

add demo code

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +2 -35
  2. .gitignore +7 -0
  3. README.md +9 -7
  4. app.py +350 -0
  5. assets/.gitignore +1 -0
  6. assets/config/ddpm/v1.yaml +14 -0
  7. assets/config/ddpm/v2-upsample.yaml +24 -0
  8. assets/config/encoders/clip.yaml +1 -0
  9. assets/config/encoders/openclip.yaml +4 -0
  10. assets/config/unet/inpainting/v1.yaml +15 -0
  11. assets/config/unet/inpainting/v2.yaml +16 -0
  12. assets/config/unet/upsample/v2.yaml +19 -0
  13. assets/config/vae-upsample.yaml +16 -0
  14. assets/config/vae.yaml +17 -0
  15. assets/examples/images/a19.jpg +3 -0
  16. assets/examples/images/a2.jpg +3 -0
  17. assets/examples/images/a4.jpg +3 -0
  18. assets/examples/images/a40.jpg +3 -0
  19. assets/examples/images/a46.jpg +3 -0
  20. assets/examples/images/a51.jpg +3 -0
  21. assets/examples/images/a54.jpg +3 -0
  22. assets/examples/images/a65.jpg +3 -0
  23. assets/examples/masked/a19.png +3 -0
  24. assets/examples/masked/a2.png +3 -0
  25. assets/examples/masked/a4.png +3 -0
  26. assets/examples/masked/a40.png +3 -0
  27. assets/examples/masked/a46.png +3 -0
  28. assets/examples/masked/a51.png +3 -0
  29. assets/examples/masked/a54.png +3 -0
  30. assets/examples/masked/a65.png +3 -0
  31. assets/examples/sbs/a19.png +3 -0
  32. assets/examples/sbs/a2.png +3 -0
  33. assets/examples/sbs/a4.png +3 -0
  34. assets/examples/sbs/a40.png +3 -0
  35. assets/examples/sbs/a46.png +3 -0
  36. assets/examples/sbs/a51.png +3 -0
  37. assets/examples/sbs/a54.png +3 -0
  38. assets/examples/sbs/a65.png +3 -0
  39. lib/__init__.py +0 -0
  40. lib/methods/__init__.py +0 -0
  41. lib/methods/rasg.py +88 -0
  42. lib/methods/sd.py +74 -0
  43. lib/methods/sr.py +141 -0
  44. lib/models/__init__.py +1 -0
  45. lib/models/common.py +49 -0
  46. lib/models/ds_inp.py +46 -0
  47. lib/models/sam.py +20 -0
  48. lib/models/sd15_inp.py +44 -0
  49. lib/models/sd2_inp.py +47 -0
  50. lib/models/sd2_sr.py +204 -0
.gitattributes CHANGED
@@ -1,35 +1,2 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
+ *.png filter=lfs diff=lfs merge=lfs -text
2
+ *.jpg filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.gitignore ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ .DS_Store
2
+
3
+ .gradio/
4
+
5
+ outputs/
6
+ gradio_tmp/
7
+ __pycache__/
README.md CHANGED
@@ -1,12 +1,14 @@
1
  ---
2
- title: HD Painter
3
- emoji: 💻
4
- colorFrom: red
5
- colorTo: red
6
  sdk: gradio
7
- sdk_version: 4.11.0
 
 
8
  app_file: app.py
9
  pinned: false
 
10
  ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: HD-Painter
3
+ emoji: 🧑‍🎨
4
+ colorFrom: green
5
+ colorTo: blue
6
  sdk: gradio
7
+ sdk_version: 3.47.1
8
+ python_version: 3.9
9
+ suggested_hardware: a100-large
10
  app_file: app.py
11
  pinned: false
12
+ pipeline_tag: hd-painter
13
  ---
14
+ Paper: https://arxiv.org/abs/2312.14091
 
app.py ADDED
@@ -0,0 +1,350 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from collections import OrderedDict
3
+
4
+ import gradio as gr
5
+ import shutil
6
+ import uuid
7
+ import torch
8
+ from pathlib import Path
9
+ from lib.utils.iimage import IImage
10
+ from PIL import Image
11
+
12
+ from lib import models
13
+ from lib.methods import rasg, sd, sr
14
+ from lib.utils import poisson_blend, image_from_url_text
15
+
16
+
17
+ TMP_DIR = 'gradio_tmp'
18
+ if Path(TMP_DIR).exists():
19
+ shutil.rmtree(TMP_DIR)
20
+ Path(TMP_DIR).mkdir(exist_ok=True, parents=True)
21
+
22
+ os.environ['GRADIO_TEMP_DIR'] = TMP_DIR
23
+
24
+ on_huggingspace = os.environ.get("SPACE_AUTHOR_NAME") == "PAIR"
25
+
26
+ negative_prompt_str = "text, bad anatomy, bad proportions, blurry, cropped, deformed, disfigured, duplicate, error, extra limbs, gross proportions, jpeg artifacts, long neck, low quality, lowres, malformed, morbid, mutated, mutilated, out of frame, ugly, worst quality"
27
+ positive_prompt_str = "Full HD, 4K, high quality, high resolution"
28
+
29
+ example_inputs = [
30
+ ['assets/examples/images/a40.jpg', 'medieval castle'],
31
+ ['assets/examples/images/a4.jpg', 'parrot'],
32
+ ['assets/examples/images/a65.jpg', 'hoodie'],
33
+ ['assets/examples/images/a54.jpg', 'salad'],
34
+ ['assets/examples/images/a51.jpg', 'space helmet'],
35
+ ['assets/examples/images/a46.jpg', 'teddy bear'],
36
+ ['assets/examples/images/a19.jpg', 'antique greek vase'],
37
+ ['assets/examples/images/a2.jpg', 'sunglasses'],
38
+ ]
39
+ thumbnails = [
40
+ 'https://lh3.googleusercontent.com/fife/AK0iWDxaRlJeZGIBuB3_oGKKhd5buKaL3kJ6moPp7r6svDYFkehrv5XKyF6mj_Pqy3yV-qDQQZj_n8CMpuYH_iDy5717rPL-qpXf-prcIv2pET4LDjFInFQVLoxuurB3_7fugCUogt5ZYIGlTgSbirJkHDqN5min3riiRJLd0ZuGN-ETDDCs5e0wohdX_Wl_Kv5RAdYjZqWFGfKcmzCF-1ny6bjCab-1hcDzIaokggkl3INTG23nhSLWhNB8EdeCdfkQmoF5fROfCPe0Lsvk6RwlAr-ZQ9jaszJ355oXOz4Y-IQLfWvnyfdyxQ02anJF7DiBZbfhH4WcA6faK7Pbjo7RIt-nTN6THwTGlEmxHEzrO-1iuy3j4UfMxPB7r-RjDrn7M9KwLuWSybIJ3dgSx_DF9OqIYkHeG_TSvs1Vi-ugbq9E0K20HNXkFlhe6ty7Ee5xr0nNqhD6lVr-6wbBvI1SJUQg03KoJkduYbTTQ-ibGIwCu7J24kdlo-_d1xBWC01zeW2MfqPjpdNHUtnPAF2IuldEsWMJMHEpWjmXwYfM1D1BmIcOuRLvdNEA6IyPD6VCgLVD27MdLWxdKHgpSvTZ1beTcMM_QuHV2vMJnBT4H7faPmLBe9yMTKlIe7Yf8fGiehOrfEgXZwNyRzmbOyGKfNRKiSVuqeqJt3h1Vze0UKFzAx3rYnibzc--58atwasvdKY_dj8f0_QKR0l2vzWogqh3NdhO9m2r77ni5JXaik25QQzF-BPe70ikVgEomHa4xySlf-Gr3Z1v_HWDa4kYg2KE3P0WjqUD_cmdylbZ45TIOkGAU1qmiEcTCs_wOkOIfCC58Z9P6Lff_BRxc_lhut8hp-Pe8tfRXhITYFFRthXforXyDuqzPmWBBz2EnUHqxa1aYOo4WeQc7KTXK2kF36qAzPwm2QFDreqV3QmS02Gev8MCz2U3bQQa9H8VSB4ZGhrzNWIwG1R8YU_G1Xb4BAgjYEZ5qX6WJaQrjuD6_Zw_pDxRew8t0mjj_tjCrmoZjpxjHsgtudH4IBah5xag0bGdUSThnszYJPM5g1weldimKE63HqaQTG-IN_N51nBken4K_0-liw73ABzUiA6EJzqCQKEoT2pejNTN88N9RFXXB5ZJ9x0NvtuMcy_JsrsVArfA5b7m4OGbwF6b5wN3Ag3XNQ3d58hVJ_Hw99HNIvrjTVCVmU0-DsYIu_njIfyjdqps1cyv6_f23F0X-q4ZsbooPoNg2lc3oeFtC2K58Dgr8JsBjt33Lnmra3YG2nBh3lkIycCvrUOS69xo8Aq9z2ODklCd9soUUQVNa2XKKsMofRi8ESGwuiWYKWdSI7XAXi4dbzFQhWwMRfsSAk3KGMpnUnlOd8Jx68fiGMwTKCCsgIPgo8mFZHSaN29ipzNoACG1bd0ZFX2qxa7UfMnnRobl-AcFLzTtOYD2T2PcIKKjnxy1-gG1Ff0mgGS-BCDl9TtCscAoKcXqTYWCSY5otnfVPcjgEFw_0KEmIbmrf6Rpyr7YGJFPTQfDRah5Ro5LhBIbXhFAkyyvWQZCecGWi3lRAh-pSm2zuGl0nVdykgMqzVwiat8_lhnHMBNRg9xWyLWrgw3bWpdf64Bjgshr0V1XV7kikgRDpyPMJPDNzX1jVASiL5S3dmWDnd49tdnaKBUXGAIWmTyUWfs1bDFZZoOfrmLrnLLUT9R11lMO0EgkZ48_UU7CMJKwgq6hegh-ErsV6S6_SLa2tQYqYlSNUESO7jw1w4_T9KTkM2635QlPH-A81yvgrSsRo0lq3uPtFEHP35-dfGT63yFd=w3385-h1364',
41
+ 'https://lh3.googleusercontent.com/fife/AK0iWDw7CKs6uhv5cJnJ48zW_5UW3IjrXwWqvemhbJGzB_xNhHgNyt4NIP0xPkCK_nyGFEsbb-4iaeWnoJbbeQxU4yBnnuckhD5d-t1qgOuFke8NRVaGCDgtDimbwNY7Jkkc9pj3pMd33P7UhBg7rVueTL7hmAyhtT_wLD9B5a50VfM_T9ErEeYUgsiEYPoE-msuALr2Sl5t0jXttYNt8R2JHdXG8Y6EegiQaEdZjZbJxdwA5S175pCmrOvVVBH334NwarLF_HDr8ESIVEkYxNtLDwb0QPDZSDeci-JS2WZcvEQUmClbxuk3hjtOR0Rm6VX_wXo0-nczybNDdTiZFsBc1Fj1L_08VLYP1OIz1fbcRF7iv2luyGmgUijGaJLr0wtW_C65eXK9aU-NJ5KVt67JT-TncumCqBpvn1-msbzbcL8krOmYqiZthOcvnrIQHeW1jIUVD_zMZADmuRdpjxO-Dn3yC3uev0ve0u4b3vEHR4iiLx-4Jf5DrMvpfdHHIL76verfqhLkiz3gtZENg1jsRTcjH35AWspo3-lMJnZqMP9wzvw7ubBBP1QrblSwfflz4aAIzmh-WisQ5UWMSltQ9DwpT0UggiXe0JbbtloWkbo2_VpaaMJh7VhBQbvRbFUxCm__UPVfglTfdiNv1m2777oGwyDbv682I_qGDW5nG91D22pqo9-enRGLvCnN-STKtWTnJQ5Qod2QYRAEI-1IR9_h-UWtCyBLLpqcGKxkHaLZvjpDmTdiPhVkoG1irracCLbPJGvrrclorr1k0nqTIhvVVH8pNZdx-yCK6KFFGNAz3aSsmxGJURWEt4TQEVDLfet8iFzuLfD2Tg9gJ4kozB88G8PrfLnGopwlPO4y6GkcfJggtshQr7yuo3xnxtci1FlLhOJNkMKCg0xhL-tVDIfWsMCzSIk08XtGgitU7DK8CE7DlkA7chCuq2BHUTeFCEALF-5DLFxa-3ru30gtNsbd71sNH01S8qhTVegWM8yQwwXXhIVqoWM-e-SSUyQldgklvatev1V6iDxLh9u2nbhoxxCcVZvyKx7rUyhSEEQzXF6nYNQn5HaAmp4jW1rCoVVvVMNP_wfvY76vamAPkw22z_nylCpbKW0BQmTBZStyoB4u_lZANVkL0fbyWoxyfZHKOoDgyTKdzDWmbeiIPmr7nvF8QWYsq160pKzIKIhj9ccF2NqyXHqjaDzmqVfZ6EUg0sJ8CeiCUskbSnyxZHVeM5GgG7cj9xMncUUsdvd0WTw3g0aqstb2MAxF3ucO0EXP8dAWqU2XrJ6E0rA51Jdn16CW9KV5HMGEcB5Y6bN-Md8-FT1r_0NQq8ZOJ5nmTzg5JNSTEw-FOBwSvukEMDmbf6ADs1emocozcqk9KiYv1ii3niD58XPIhQWfBcgyYBsqbJQ5x-UqijqcihN7i3RgQMOVJem954MKU8D-DzSet2FUcGWseyzF6Sr4GJQn5g0rOXtFP2HwT2kTXy_pqad1ukQkovfQG30gbRvIRAzMmigVJskadvadasF7Lc5eo5Pm7CPPZQ-ZJUlZuowbagWm0Kz_T-PIreVNJ5WxxQ4w_HH27QqKxe_LZvHJ9y74O8oVCywVJeKQirxDc-yKuUNJ7vagJ5a-paB_cXf6RECpu7caFM6U3g43xVMtxDf-3d0r6G7zs0oPHb4uDbovgLzqooAnjNY3_RGWldXUks0pmCqFbEhLgl8JJd4Lyv5mt34WTDqPggRnraAV5AEtZ8NvGNDWk-ItblCZQ-neXYVj0z1p1rWSJiG1XaDa8ro7G9-7XJtwfJDnFZsf=w3385-h1364',
42
+ 'https://lh3.googleusercontent.com/fife/AK0iWDxK7tdzr1tx6g49uS-pyR-7BdPJfk0_wWoErGJiQPMRbggXquguwth2I35go_GFsW8TUoj17p-jtCKCi0ryGH2gTGNvZEki8xfPn1aroIxOXRt7Ucl64Lu2cyixNSBqyJUBQEOL6LzY6DlwXJ5SVxRRvJkj1Vz-yJYz4mNA12YdDFPj7UkLs_7PEDRanydfEVFhIoiMrYitDW8fGsqn53PHaB8GoxoQHAXFyf82HPfhqqgKNRUR9yz1LAN5q72ERzxg5h4apStr16aLzijuuJF7wmtltf3HdTIMwAAd8Blsgk0d88rNZzItVZMdCeHgAyvoD7JKUCTWWJbu-uXf1PRd6mIg11OzweG04c-uloZFvJdF0pdfcfocibniKt-filYASFN43KYtO7Eyzc-YUHJn1qm73eDr0RkntHy1kQaDPmlRvlZrtbDgZIEau_FsMa-BRNq_jHlTDkyRR_8dyBf__na0I45lcChJWZCOfinbQwzQfiryK13yBB1bQDGk47JM5PTIEBK2JvsOEwezA9geZWhi3oIM6N4b1YXzvWbDYbN-cQvrBd6cuXoecUL-i_Qqda2-xGByMP4BUI0J1qF3mKwzTLvzZAk1f-zCUNHpOrzF2-WsXCUoL4R_t7MOZ2OeT7a5pCAnyiFX8Vdq3x9QzrJjjTqvttS5ElXVUGi5NBVeRV0hbXI5XLMKmiu-lWV3625eosYK0FE6hmFZk9ZAeSkrLDjkh9k1XwpB2x83kcpC_tkk25G9czy6utoNhC2YHc1uVvnUjZ0DjK2d0naLVOCKOCVRcTSvN3kt7PaLASArW5vOzZX0_LcjuhHSzK63mok6XvXXZ5AOiF2fDhUNUl1W8TaDz7aIssbibuLXxjpluone6WJglRrxAkpwtIw2rwk73icupAeVG6Frx2QctrHC0vLG4PKSfrIrZRiJC52Oxpp2dAMnhj76CIwX9wwbj86uSyIuxeIviiKumwTolZjbOKhPrHY-ZOydVDn_ZUsln4mOfDZXwUl9p1CanzpLCUsZTXcq12zTbHJRPwNw6SH9srQc3cTYYsWpBu77VmS5zVkIndUPItXWmUqds_2AI8LCUSWE9NVHiCRSw-B8J8j3SkqOD95-np2cMxDxoVX4nD11CwBr2W0xWS1kqc4mZL4aHdwhKdDcSKk6EGG4kmY1Eq1RsYxc0I08TJK-_nrbWTgA4NDTjh-oJpFFiF11ZHbEksKlSWBhH-MXF-0zxmiar-EIhAQe4hQX_suDty4GXxMwzcF84JthDAvWC0tGLJV5WQGLkLvBkTODythSNoN7l5HjnzoJf2JKI5AF8W3HJJWgntD0F2iFe5K0Ik_huovMbzGjkgPZDvWzrYpA2V6VZyPI2q08axvAfaevPrCd7G8Jy_gliK3hj3qxjIvqJXBUSO2puqum-TlUaSYgbjhWUCLXKE1BH8RRDt0brLklVvNwGpHxG6Eg9vmzp00fV3qhjoJqokTYOAxcAumtfUpyJFVk3cvpgm4826o4kKUM6eXRz8ke1L34YI9Rtf7Ppk2864hIBIy5xA6ajIpI_rIXWG8ogrhEp9GMbXGkLGiLMfd4t7P50JynPc9hYWfachaTlQfQWuz1NtV3ZGdTaEe5y_SpEwWG-UQ-_MglvY8p0EYNjGc6yhu4oEx9D0RN5H4QA2e0YjmVybG05Hfm7LwEBMBvHb8GHnpHGTxL-WDlZtU8DqDmGmhv5l4npcLnuVmSE4JcEtRJ651ZYn88wD2ghTAJtcEuyFW3nzZ2Na3z0LW4Y8Y42YdWb3hmgGsJAYBIa9YC_Cmh=w3385-h1364',
43
+ 'https://lh3.googleusercontent.com/fife/AK0iWDwPG9L-g7hX4NKVZ45PuqnRkQUV3XZlMEWZxXASV2zYVytOhbhds--yBA0ZXUzxeEAbpGa3qCl8sXlu-of8ZjVVNSdSen-zgpoN0BT-R7JGmTqjT9aEdhnysy85Gqr4e4A5KPrInLLvCCkFSHolBEhu-hO1u8kEZ7aDuk8FsPvchP1SvqjnuWtY_OCYa2FpjHH_i24cITjs58nlDxNTFdQQNCnX4KLJQAzh5cKxtG_7pqoBfBzpXBvVDwoxPafbFn0X0u8oJ7V-VOh7faO0JtJrsfcdS06tvw_J29RbDdWojpUS_pRUtA8w9Z6flNlbShj4Ib8X1V6veQLrEYWi3SojO5tfntpl5bKXPGgWhsT8mykvfhg4Tq-Ti9kyPBDPQqEqf0ll9-wFHYoAWSMsmCvZ7kakuDMH1766rOk8QgqYbDr6kMuJw_OFBR2qX7DaQw9XnJqGv7J3guCj-yU5vXes4sOtZ3n1IOheXnlvJL79KhXbsLgznYBjC9uv0fDLuqKaFL34YisLY5xY_2zi_uzTc5BXmtoAFH8otrUVUTVyt6sEPaCtjyPzG2uoSq094QFY_FxagW0E6OYPqskBUwPzg_Wc2eBwazaGy4MXEoDzgvK3PId5N4MWqU232uHsUrEEaCUUKm8-KX43c0C6O2daqjwsh-bxKOIic3pHqDoAOcq-QA83qB7pPyWwGddsaOWRIdzf-QLrB55YuvTeTmOEL84m2YDtxNEHWdYcnXYlZEXAex1xMOfqkbCQGM8jSgC2di_794HKMpsNtKwYH31WoI-Pl73t4THuq9CX1pWYdjhH0ss1j3PUMJ4hELE187E2m6fhQLHNNSHRajfesIwPVP1FgP0W5o-AKLaC7o53R0XrOZ7SqYjOua5TX0RXyJE6ceZATyiZ_2tzUpc4baqRGJb88vr8dhfECh_1J7O_8ufMvYL47HWhVNqltcIGjujtIXp64XEqShlut8TkqtB1-3OqE7gNwfiP4859pPkk0q-kfIdaLCjqB6PeqNFNgcX0dfK_-weXpM-LjrgNM_aVlQEBKwTDjJteUHY3pFTrEU8FoPnPeIjbU_rGIcwA3Lf_NE0CJUhKG1gzKYWG00CCg0srvQRipJpVrW-4SkUrxtW28iq6kVRMER83sN2RnVi7ugEuZ3S-OPdgXkUglk3bIz9ehDcFfQrll0iCabIGQxrN9-7GfKSpi3j2nU0aLeZ6DxgQ7x9f9f-hUgxM97i_90SX-S_M6Lo6bA28RB515HqqXc8FN72Wsp_XlFp3oTb4cJtOI5SSUYkdxGtrn0AQWXuE0Rp1DEaUDEdVbu757FQsct4Us0jLByauEfiZcQS_lGTjzGQxud-2NMKIeIYOWAxBGz08eTWv4b2k_IsnekNnXAHzP5WCYFoNtDiGbgURj7QWz6Qh8HbTsFJCrI3mPYkTN2jMUwviTUKpeElUwqDeA_DhT5tiHa7ldtdzncNzpxH-J75asams5J_O2W5dJMN4PYUxGmVw5mhWEClFo2stMJOfPkrmaga4D1OXc_C3Utf6OWB5CBBHGNjfAekE3QWm7ibEtwC91g1pIujyCUEYVs8YiFi0RWcMsmhWG2yrghA9Hu3kERWuVT0nHHfLRx8L0_PjlQBkNijUjK_gI-C66729qLNsmdwZxym1JCFgV4xRT7Vu9EQGL3tyhbOLRWYlHAjBE6itM-DM5T7idIyWamFBb9Nt6ZFehCpslKzEHy2VtEyiRlZ2z9kH-IEIuZ2qCu3kiGL8m6yQTyboTCI7LNLYbSn497VR8h3WkcqWqlVlWOzGUikL=w3385-h1364',
44
+ 'https://lh3.googleusercontent.com/fife/AK0iWDyoheia4yOtefRVFWnA5EHolnS-xa4pPy0yL4Wb3uIr1-mSHxQ54i0wKr3lk6uhAm5qjJ9UVqwut3UqR54log5XVnaeNu_3y9Gsn44quU_HsGAY0-84HygWkr9Ld0_Dt1JEefD_f82Ijp1TXQf1CS6IbruUOnrljFOraQ2Bu-1To3p2Pk-T3tU8xVCxU3pr6zvFz9UYHFTss8_Xw70ZtLRMT4x4suHtOaSPI42VTq_T8HEnm04Ie_0Yh0Ri2_P-qsaP2ysz-Wnw4Ykbj9nc2VIDqtvCRwti36mlNheyg_8xOLD9sMNWDu3PXoRtn7aBUpw50GMCqeGUMAUMPKhJuZoTDdQK2SVHJ8QwNhYcC8mhAmRFvt85hvuT6NyZ00SeYwyj9_rux77vZThx5ioDoUAH3CQBcwgH82xahatReym7ehL3DXm3JLHDdQLbRM6xvsE0X0MsMXkuNlx8wn6HyteW_yq8fK3wQJ_3XLh-gK5YOFdvd08A7IIK6qi__-o8nvEK_hfHgMMS-O9eT9acfa2Sr0rGGNvoUlpljVOyONyQftP2nGD92R51K2Xcq9oV2sjSu8TDDel2t2DYr5gMB3FsDfgEhQEWE-O9fRLkzIZnOTTAcUDoS-b3R_kB485Ry56FzKFbz3w3tvHhwJ3W-sqvygb8LDoF3qjURWrf7Pau6UMjTSPH6FTjbVzZeWITKsSnA14xA2wj6xi9Bp3JkCOsT0qOfXrPkK7uT3H2U1M8uqFjpNTj6u3tZyF8GgueprmH13rJjYst9d_vevhpXpXIhSDvAbJuA3xG-YNr_SG8BMXpi7N3NC0VHmBXhl_wDBVUnAD6VVqqNtXzB-6NdZzjZKnxApDdi5SGp8C9kDd6bkaXUmwG__BRNVqdbchMw3H1re5t9VxiqWTelfGl6UqAX1W8RzQR21bgu1x7EAGbVsC_UpMxeDeJq9PprMF9cCRC9ziT_H2-ubctN7O9qPpADPT0nqWN28vH-9CqB4jPeBqYwi658twIpLwRFvEukajsrvb-OqmeesnT9QpCPEpL4G0HrjB9rkRX7g-T6q3kqbGfvnqgW4Q8ilOUnkEsFK3qLCIVxwp6B3t2yg8XSOsM3tnzsA1Ua8MofFKqvmwaq7QbIBcMOHa50I1fxRMs4YEVLgu89fSZxRrKSr_8oUXcRWiqSgF-pLLU37GYMrn3yXUJxxO4bXiEifeK8id5H1khO-8ZEBXzwuZBQBYLXbCkCou6enZu98tfPA-prr_NpKfSZM4e0clZWhjo1761-BmJSlIG0JrTo2N2cKhVz-WM5BZjVr1FPYOri5fIjORUL17_RbqMw5MefYN6tLPnpVrOSUmKW9bVgdFdOVpj9Wg5lZxxswAM5qK-wOzEjfdBCW0xjKzxD97zhszCKxc7Rj2uoaJzk9CeauU83LYcihkyMHn5IKhLeAou2yKEwgqXkU0LUObdUxqjavnVgVMcVYRds_j7zpmpKNT_KOV9s2jus8aptJl8sXZ_Gzw0vi7wC_AuAGfmNCsZBEFhn_b_ZgONqdaR9EKwP0hVRDNw4dZkYn9MGqeiX77I40eEbnwVbKaUZvK2Nrt7ukjkgSP6FdvZFfs-aVIUFMc6rBAknAFDFHPzFYcy9ANPDgVAlms8fO5GGuid8kgpxtjSoneUG_A9Az6JY-suY8zFr6mDtJC1LuY9ftIKW37ZaMtutqhbBX5b5w4DLnO32Uv_ZtK1nbV8E4T0KY6qHm2kZiIYuYCsHysbvNW4MODUihrpDh1UfILdMEsae6zGpCTQ7gnKlrB7QJ5Ig28y96Om=w3385-h1364',
45
+ 'https://lh3.googleusercontent.com/fife/AK0iWDxriikbWcV-sJ5xBcy8xcJHsRC9EEBmCimXzsrBiSOy-UnRwSoGhdXNqnB92vMZk1LiOn0h3KbBMgbFr0I0SmQG5TwG__bM8AvtMrXA-DGAaNTktO1JcSPb-wVr_OepLK6P1hHyGYSvcGDdF03pIFNcKxN6QLCZB6rgFLdaWd72z3Dx8eB8QLtU8P_4G4sT2oJ0hAUlz9mKz9lTxWDWl0-1ufKUctmvWfU8EjNuQQzckJA0YwvV2jDl4ZA1r42UHDss6dy6hkjmPoDZN-p2UW_3Ju4vOVAdv0Pf73demNu-L1LALuK0rq4d7OHUaAuP0bubXJAH-wsuVervwPQDEsmBwR-FdW8jfdppKxy4MC2ISf-eyLmsTYR9dLPIlKkOAHh_84vdLeWdGtxs9gES-jhrqOiW-brFtIZoKbH_nR1yLeq9IJ7Z7-GJk3PVi_Ex7gT9WJyIuySNi6s7GH6AnDFf9wfHkzyJ1qLDKMddrNi4GfEyO89S0yScxZFW7hERAH7T2_1YqkeMv48ik9dMA0RcpJK0LYA8GDD5_MQaycXUjeoCO5tvlGQUEE8Dt815Ev3xh9dlKJiKJf_cClK_kL7iBICasjcVSNxP606Zn1fBTc_hF3QymDP8Q_Xl-g9p-ufobo_x2x6nOOdfiq1q37ik_3kPZCPPCnI2OtM3EYZ_yVlFmwWbtqxp2Rz3jaWo6t1TL8PWkFZno-aqJ8YEm4ppZVyX7ne6GTFKzFpW1SKnAnjqu67LjS0DCFhFATAcYhWRmdb5cMXve72eU5DNLgKZOaM0THDa5dnQPaRmu-c-7HlB3WISebcJHj0vIDw7DLwxaCnqqLgybvqW7O-Rt82bat5Lc-jIyMjjkZOutnc3OoYTPRN5PrLZ6HXGFBTq_s5fCimxpXvlw0bzNHqOzovgP4NC6UChXwn9CxSrbLou8vuqeD4YqyjBhh5Do3l7KcGMZtYUUMUhVf7fUrZIJcLb-ZBrg6UdiPRc3h9jcubLjXcPIrzxZeqgQ2ccJRljZqq4CBoX8WU6AiEe51Hbp0C96693G7rVomhzxa8JCMCCL3sy8v8nl1tqfqQ3kE53XnvzuqNMzJCNIfZg-GehqMBmJZp-Vaup9DOLWYWnzRqYsO9pC9r37Ajhh-wkxZyV0XhXD6EecYGcXBX8TOQRF9OMV9BYRDYVnpnGfTBGyxI7lkCU3prZQL0H4JlS48k38oWj9fE1F-PnGysDXLAxYqEERP-AXIwgsPyoKYrq7mG5UzBcEn6qminlmT2wM-EnInAb7F9e2aMxFwWv17RG2fWkO28p4t9hiMEyvAQfYyWKhGui1yqPnHnytNe3BJspV0uekBUBauvvHHLF0_tgHTJwcU03KlOLto7iIOUuwLgBT3z-_diWu_w9WTiOqZppBYKHUVfQlR1JsFn0j0Tg-kk5NbRfRU5BB4nMtQlkpW3vjTrQW88SrSKNOr_vepd1F39EqxBNG4tIyR6lvSCSGefWzUxwHBoHou8MtGPvcdxInB7imHdpwCFtoVy0wGh2kXy36CyxXqV82VehVvmfc8AnNtgVWTPJ8vvs2AQ2WL7xd7lE-DCfy-la9d31t6iLxVAkq_S067Z3iyTiJg508a_BGt_UF41VW2m2dW2v8KRjgXks4Wz7kAvO7rLLhL2CMdh-W7bBP4bfVa9BfcKVKGxaa6kCQB4rXtFs2gEbWx4o6iXf6KIUgl-RAdAgYYjIsavlkyihtgQ28-1--7JjaeO61d3wLYgJ5_t_POL8OjmFJfPeO1m9HTzDJf30C5hVP348dRhrkxSTPlSo=w3385-h1364',
46
+ 'https://lh3.googleusercontent.com/fife/AK0iWDzcrCgKc1IMiriL10KOqvN79VCXca9U5g657RP5HU1zsUNVcZPyefBpRCtevRE0k6FkGxUa4yKW5ELtqaSmVRE8R2jwZPHmpd0xeBvIHYeoySmntJD3wJl_iC9Ma20qubbH3OFIxNFFLCXMJDPY9cJ2D1xFzjD5t6jQi7tJKXc0B5W9vGTyQ3dzHlZuLEfTp4D58WszOFBLPsOq8zeve0ej-2wCEMkCrT8kTwfKTnsi_GXRpK1xFRZjczbA9kSwIxZ_x1iWNWW_XY8aw5Kgn5Du-4r3rBqpHr4_fzv5ehY20T7Cq6Rf15zCBnBad1HGxLLnrJrXujKwxHELaRfqXPItfQAmuoIfL-tgK6giSSNKDDT2Ynn7AXxHz5vzD3m4ALYWql9Qpe1Jw69AxGZt3vyuSC76LPhMmNIJzVIXVMLnw55hHCLH28GVg2WxLTLyUXoX7o-PHhFjGoBbx3X58yQjyrEQPbfmY3gtuijG_vrjoz3SPcoG4eXB6d_NnEM3b84Ml9eqE8BjlACMEJriwtqLEstPBhIsJygDZKB822bY63pG2hYzYi5-bWS19NRwun-jdNraiye8D9AIqxrXaknLvbCnWbgbKDJzlRnt9Nz8GJKU__oZ1wQgzq2DOQqEoJxPwpocJxwYjnzIb-dFqhcLkR9SQVx7rAEyJ_u23WBSp4AVTw6c7sYxd597C5EFBvzOv4qYtJp1G-hBrvHYFvYD13oDtgcFDO_F0nXg7XwmCJ7aQ28Hp79dWewQiSq0nBcgWSOic1Q4feWLFEL21Dw0pFmSVF9f3C67YNA7ZXukOAnv6xUN5RNVoBl58-KDOFSQZYysT8FGYaaY8bFRvYvI-VXcBtMuAw1VYMKLKHrhg2mSvvdPCxsGH_rQWijgGUm9pvqbmyrPzYRd-Tr_i6pKH7EMcBYgcvZNMk9_B2EMj6NUWLVkHEUTDiovgFQJDrYAuiUcKZgp_MXJcTDO79qJEsiP52C83C8Vg91w4t_Q0dPklY8EjjpbMZGvsr5NC-TlQraNFszsrUA_JsNdnMW6Te2GMnXHDaHjEDCdX2kY3XC1Ltzvl3f4tHY12OMOUzMKHFWHWyBNYNKlHnwPdxYNxvHKa-9p_okvxw547oBfsXwUrRpQQVxljmLZJbbGpxkfbEW6Rg70MRKHdbEUg0h0FS7F90kuD7pR1zUFv92fPoI5BUNUn2XjQb-DZZaC1oF5VtRMgO8RSFM_Dolrx_c6ZrPLy7bllakKTDj48CNvafL3UxaD6x9FFu4dNQVHKidcVS__EZ1SYMSZoDkD2sna5Pzpjl0Qz61U4KVAz5lSCJbF1stdiwO4jMwzmAgMNV5-fnJA3kkf9dDaIzqk2diKUh-WGSOwEwwizH54Y-e2EASTqCzGK7FGoVRVr1d1PdN7wcd8MuRMXqBfonrIGmf5cSiuvOL6odbSChO2WFOKkRbBHeV1pO46uaBVeLKGjcbAALymYsv23_veIW-RQdHnaVlvB2HE7wD1afI-3LaRKCjCUD_X5QJuTXff_EQayDwtW46_0hYRI76LGXjQuOc_LUdIi2QHlO4kt2kR8eLBPm09gKl0GJrZoS9HxKS0sERHCjUvugFcvVmgt7idjFKfi0AuUz9XmqYCoSRiMYj9z612Ot5L2D00SLTOcTLV6nlx-PeBpAyiu2ia0ehvVjzLUn9ai0_XOg1bU-ab8fvzrRJIrao5ZxdRp9wF41lJpkanwvurynNJXXu0uk2SoA1_soIeLshOsONxT4DG_PitVFMYjxY8rgx7hfNBJmAJW4GZPsnx1_P5Ojip=w3385-h1364',
47
+ 'https://lh3.googleusercontent.com/fife/AK0iWDziEIX1beA1lUdgcqMcnVcmRuODfH0IHpIkZW5YYZzhQRcRmYCfi9N7-vRWfcfMTuR28ZWDA5EngVjUpwIbVxRBF2DaIb_lJhd9zG3arGxRbi7CwmWdhAXeODEvniYR-IxtWB9lYpNd9hJ8wdleTP_ai10Xscy7iFeXOmFZ76dnr0r690LnZULOd7iyGv9EWZmhKhw4wtEJsqi3e1Yu4CXRsVrM5KLYKG-EWdRW_-m9H3o0G2W7KvOwDvqAewIz5zApPBHvuEE6x-XaUOuF_FuYQVhsKpfI_1Y70SCjOCphpbn51Bv8idg0tgTDn8oL8hkvSl50VqgQqjCLNxmCHlQE88xmjK_4NMI3kbIBLWfiPGCURr95dZt8eniqVi7yu8LNgkaMixAdRBCrQF_z56EsIvkozenXBdS2FiaUEh7LqIHLBcOa6ZaxV6_3t5Q83wgJTaM4cNkvH5_nCeQ9wkwKjf7zcBxFusa5LvhM-qSm4BJz3WzE1zgqTLVnDeh-EFNPMilPhevOdBuNfTY_VF8chvWS5Nwwcxlls8xSdVVqblYGw8YBlzWdi_X5PqynTKn6aWE0IiWOzA_O0hg2q1FtAHRT-PaINo4wjIbBar6fiNNwcZIeTKJrijHcpkIhnI8PHxrUtO3s0c2pfLFvuxCCRMSEfxpcwt0rz-ODEWIZkALajwE2SyFV6Qioc4fH_xWnI-jgRvzHjDf2c14vx3bXjM_gy-25mrECLQYcSWWZVINUbvKf6_YQDwwzAKL9zhMpyGa4EToTBhMSmroGi-NwIPxh8gAfdBCAh8TFAdg2aA7D3N_KpAv4Eh5bkovhCiALFYkGLch6KogZcn7NU3OX8qyn_wJ3oGO2CFmfKkMtLHqmjQpnLtM1U9BPnRELir7pNyG9bTNzs-Vz7-Hzu3vJavGeRhypl0JCoGfO08be-ee_7EnUcKSdepfd3dG39Gc1eulLgIVCRb82Ga5mAkNp_SDSa2BGI24--uOyAUwTBazpQjJ25W0wsHpLRF4obk8Tygl8Fgt2F-VPXYz1-q0x3_KZVWf-PJmKjYD6t3ICuBMoFeJtQUxp88WlSKC7KvhEZYdEaHmEabNNK7j-VTAgi0BeBaw_dTTO0tad9rXbCW9Co3Tc1YXv53oz96VURj-FAKHk_PKPRSV7-NO-BWAk1DOTq3ZDnlKUTA5-x6k4IR5HyNzW9C7rIPGzd_PRA9ddSiRxOjSiBru_P8xS0zQn6p75V48ZkoNsLPWEWCKhANJOaOB7Y01pg3wjjnftuxkp0KpokrlCZVUn2eKPmB0Oee6TP_6DVFhgM6ksqLHO-sNxpehUjWDx84znkN0MihGRgl6TK-6xnWzD9tjvIOsK0mBzk_XY3Vuvb5OEZvLzDJ5POqNHjLcAFaDtX7gsAUtEWk20qmRbpGBnHiZv2kLOUWCy6ICkc3yFv5uUMx7pxgfc_YO95ybO8-FTDG7m1yaoz-WdLV3tHao4_MfFaRXGKtV0_7xnlyXEZ3tMYwKu4hRx2lIOsL4Aff_O8-H0jmJId0llt__iOdVDkuypQWQDOKGGP9B1_gfLkV-ymEP0Bl59jQWNnAqE-jUpTeRRcUB6FkcH8XBPKL7F9N0sq-6XeOmPPpsecmm3SflF6zJ1YV8Uv6H4_9_uQLVBB8wXSvtcQuwgzYnrtpjpMQwFqSvJDhcCPGRfRCR6H7oa-T_ACYAMcICpl8felwVUOQs4O03ywLHNrZBY05hS13cj-_aYw69kw9TdetT-GbvTKC6eY5uwBTq4ytb4eeJQJc4zBlB2Dw1vKmcgIFfZ=w3385-h1364',
48
+ ]
49
+
50
+ example_previews = [
51
+ [thumbnails[0], 'Prompt: medieval castle'],
52
+ [thumbnails[1], 'Prompt: parrot'],
53
+ [thumbnails[2], 'Prompt: hoodie'],
54
+ [thumbnails[3], 'Prompt: salad'],
55
+ [thumbnails[4], 'Prompt: space helmet'],
56
+ [thumbnails[5], 'Prompt: laptop'],
57
+ [thumbnails[6], 'Prompt: antique greek vase'],
58
+ [thumbnails[7], 'Prompt: sunglasses'],
59
+ ]
60
+
61
+ # Load models
62
+ inpainting_models = OrderedDict([
63
+ ("Dreamshaper Inpainting V8", models.ds_inp.load_model()),
64
+ ("Stable-Inpainting 2.0", models.sd2_inp.load_model()),
65
+ ("Stable-Inpainting 1.5", models.sd15_inp.load_model())
66
+ ])
67
+ sr_model = models.sd2_sr.load_model()
68
+ sam_predictor = models.sam.load_model()
69
+
70
+ inp_model = None
71
+ cached_inp_model_name = ''
72
+
73
def remove_cached_inpainting_model():
    """Forget the currently active inpainting model and flush the CUDA cache.

    Resets the module-level ``inp_model`` to ``None`` and
    ``cached_inp_model_name`` to ``''``, then asks PyTorch to release its
    cached CUDA memory.

    NOTE: the model objects themselves stay referenced by the module-level
    ``inpainting_models`` mapping, so this only drops the "active" reference.
    """
    global inp_model, cached_inp_model_name

    # Unbind the old reference first, then rebind both names to their
    # "nothing selected" values.
    del inp_model
    inp_model, cached_inp_model_name = None, ''

    torch.cuda.empty_cache()
80
+
81
+
82
def set_model_from_name(inp_model_name):
    """Make the inpainting model registered under ``inp_model_name`` active.

    Updates the module-level ``inp_model`` and ``cached_inp_model_name``.
    When the requested model is already active, nothing is reassigned.

    Args:
        inp_model_name: Key into the module-level ``inpainting_models``
            mapping.

    Raises:
        KeyError: If ``inp_model_name`` is not a registered model name.
    """
    global cached_inp_model_name
    global inp_model

    # Only treat the request as a cache hit when a model is really active.
    # The cache name starts out as '' with ``inp_model`` set to None, so a
    # bare name comparison would accept an empty name and leave no model set.
    if inp_model is not None and inp_model_name == cached_inp_model_name:
        print(f"Activating Cached Inpainting Model: {inp_model_name}")
        return

    if inp_model_name not in inpainting_models:
        raise KeyError(
            f"Unknown inpainting model {inp_model_name!r}; "
            f"available models: {list(inpainting_models)}"
        )

    print(f"Activating Inpainting Model: {inp_model_name}")
    inp_model = inpainting_models[inp_model_name]
    cached_inp_model_name = inp_model_name
93
+
94
+
95
def rasg_run(use_painta, prompt, input, seed, eta, negative_prompt, positive_prompt, ddim_steps,
        guidance_scale=7.5, batch_size=4):
    """Run ``rasg.run`` inpainting with the active model, once per sample.

    Args:
        use_painta: When truthy, the method string is ``'rasg-painta'``
            instead of ``'rasg'``.
        prompt: Inpainting prompt forwarded to ``rasg.run``.
        input: Mapping with ``'image'`` and ``'mask'`` entries.
        seed: Base seed; sample ``i`` uses ``seed + i * 1000``.
        eta: Forwarded to ``rasg.run``.
        negative_prompt: Forwarded to ``rasg.run``.
        positive_prompt: Forwarded to ``rasg.run`` prefixed with ``', '``.
        ddim_steps: Number of diffusion steps; passed on as
            ``dt = 1000 // ddim_steps``.
        guidance_scale: Forwarded to ``rasg.run``.
        batch_size: Number of samples, clamped to the range 1..4.

    Returns:
        A pair ``(blended_images, inpainted_images)`` of equally long lists.
    """
    torch.cuda.empty_cache()

    base_seed = int(seed)
    num_samples = max(1, min(int(batch_size), 4))

    image = IImage(input['image']).resize(512)
    mask = IImage(input['mask']).rgb().resize(512)

    method_name = 'rasg-painta' if use_painta else 'rasg'

    blended_images, inpainted_images = [], []
    for sample_idx in range(num_samples):
        result = rasg.run(
            ddim=inp_model,
            method=method_name,
            prompt=prompt,
            image=image.padx(64),
            mask=mask.alpha().padx(64),
            seed=base_seed + sample_idx * 1000,  # distinct seed per sample
            eta=eta,
            prefix='{}',
            negative_prompt=negative_prompt,
            positive_prompt=f', {positive_prompt}',
            dt=1000 // ddim_steps,
            guidance_scale=guidance_scale,
        ).crop(image.size)  # crop the padded result back to the resized input

        blended_images.append(
            poisson_blend(
                orig_img=image.data[0],
                fake_img=result.data[0],
                mask=mask.data[0],
                dilation=12,
            )
        )
        inpainted_images.append(result.numpy()[0])

    return blended_images, inpainted_images
132
+
133
+
134
def sd_run(use_painta, prompt, input, seed, eta, negative_prompt, positive_prompt, ddim_steps,
        guidance_scale=7.5, batch_size=4):
    """Run ``sd.run`` inpainting with the active model, once per sample.

    Args:
        use_painta: When truthy, the method string is ``'default-painta'``
            instead of ``'default'``.
        prompt: Inpainting prompt forwarded to ``sd.run``.
        input: Mapping with ``'image'`` and ``'mask'`` entries.
        seed: Base seed; sample ``i`` uses ``seed + i * 1000``.
        eta: Forwarded to ``sd.run``.
        negative_prompt: Forwarded to ``sd.run``.
        positive_prompt: Forwarded to ``sd.run`` prefixed with ``', '``.
        ddim_steps: Number of diffusion steps; passed on as
            ``dt = 1000 // ddim_steps``.
        guidance_scale: Forwarded to ``sd.run``.
        batch_size: Number of samples, clamped to the range 1..4.

    Returns:
        A pair ``(blended_images, inpainted_images)`` of equally long lists.
    """
    torch.cuda.empty_cache()

    first_seed = int(seed)
    sample_count = max(1, min(int(batch_size), 4))

    image = IImage(input['image']).resize(512)
    mask = IImage(input['mask']).rgb().resize(512)

    method_parts = ['default']
    if use_painta:
        method_parts.append('painta')
    method_name = '-'.join(method_parts)

    blended_images = []
    inpainted_images = []
    for offset in range(sample_count):
        generated = sd.run(
            ddim=inp_model,
            method=method_name,
            prompt=prompt,
            image=image.padx(64),
            mask=mask.alpha().padx(64),
            seed=first_seed + offset * 1000,  # distinct seed per sample
            eta=eta,
            prefix='{}',
            negative_prompt=negative_prompt,
            positive_prompt=f', {positive_prompt}',
            dt=1000 // ddim_steps,
            guidance_scale=guidance_scale,
        )
        # Crop the padded result back to the size of the resized input.
        generated = generated.crop(image.size)

        blended = poisson_blend(
            orig_img=image.data[0],
            fake_img=generated.data[0],
            mask=mask.data[0],
            dilation=12,
        )
        blended_images.append(blended)
        inpainted_images.append(generated.numpy()[0])

    return blended_images, inpainted_images
172
+
173
+
174
def upscale_run(
        prompt, input, ddim_steps, seed, use_sam_mask, gallery, img_index,
        negative_prompt='', positive_prompt=', high resolution professional photo'):
    """Super-resolve one image from the inpainting gallery with ``sr.run``.

    Args:
        prompt: Inpainting prompt; ``positive_prompt`` is appended to it.
        input: Mapping with ``'image'`` and ``'mask'`` entries; both are
            resized to 2048 before being handed to ``sr.run``.
        ddim_steps: Accepted for interface compatibility; not used here.
        seed: Seed forwarded to ``sr.run`` as an int.
        use_sam_mask: Forwarded to ``sr.run``.
        gallery: Sequence of gallery entries from a previous inpainting run.
        img_index: Index of the gallery entry to upscale; out-of-range
            values are clamped to the nearest valid index.
        negative_prompt: Forwarded to ``sr.run``.
        positive_prompt: Suffix appended to ``prompt``.

    Returns:
        The upscaled image twice, one copy per output component.

    Raises:
        gr.Error: If ``gallery`` is empty or missing, i.e. there is nothing
            to upscale yet.
    """
    torch.cuda.empty_cache()

    # Without this guard an empty gallery ends in an opaque IndexError (or a
    # TypeError on None); report it as a user-facing error instead.
    if not gallery:
        raise gr.Error("Nothing to upscale yet: please run inpainting first.")

    seed = int(seed)
    # Clamp the requested index into [0, len(gallery) - 1].
    img_index = min(max(int(img_index), 0), len(gallery) - 1)

    inpainted_image = image_from_url_text(gallery[img_index])
    lr_image = IImage(inpainted_image)
    hr_image = IImage(input['image']).resize(2048)
    hr_mask = IImage(input['mask']).resize(2048)
    output_image = sr.run(
        sr_model, sam_predictor, lr_image, hr_image, hr_mask,
        prompt=prompt + positive_prompt,
        noise_level=0, blend_trick=True, blend_output=True,
        negative_prompt=negative_prompt,
        seed=seed, use_sam_mask=use_sam_mask)

    # Convert once and reuse the same array for both outputs.
    upscaled = output_image.numpy()[0]
    return upscaled, upscaled
199
+
200
+
201
def switch_run(use_rasg, model_name, *args):
    """Activate ``model_name``, then dispatch to the chosen inpainting runner.

    Args:
        use_rasg: Truthy selects ``rasg_run``; otherwise ``sd_run`` is used.
        model_name: Name passed to ``set_model_from_name``.
        *args: Positional arguments forwarded unchanged to the runner.

    Returns:
        Whatever the selected runner returns.
    """
    set_model_from_name(model_name)
    runner = rasg_run if use_rasg else sd_run
    return runner(*args)
206
+
207
+
208
+ with gr.Blocks(css='style.css') as demo:
209
+ gr.HTML(
210
+ """
211
+ <div style="text-align: center; max-width: 1200px; margin: 20px auto;">
212
+ <h1 style="font-weight: 900; font-size: 3rem; margin-bottom: 0.5rem">
213
+ 🧑‍🎨 HD-Painter Demo
214
+ </h1>
215
+ <h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
216
+ Hayk Manukyan<sup>1*</sup>, Andranik Sargsyan<sup>1*</sup>, Barsegh Atanyan<sup>1</sup>, Zhangyang Wang<sup>1,2</sup>, Shant Navasardyan<sup>1</sup>
217
+ and <a href="https://www.humphreyshi.com/home">Humphrey Shi</a><sup>1,3</sup>
218
+ </h2>
219
+ <h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
220
+ <sup>1</sup>Picsart AI Resarch (PAIR), <sup>2</sup>UT Austin, <sup>3</sup>Georgia Tech
221
+ </h2>
222
+ <h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
223
+ [<a href="https://arxiv.org/abs/2312.14091" style="color:blue;">arXiv</a>]
224
+ [<a href="https://github.com/Picsart-AI-Research/HD-Painter" style="color:blue;">GitHub</a>]
225
+ </h2>
226
+ <h2 style="font-weight: 450; font-size: 1rem; margin: 0.7rem auto; max-width: 1000px">
227
+ <b>HD-Painter</b> enables prompt-faithfull and high resolution (up to 2k) image inpainting upon any diffusion-based image inpainting method.
228
+ </h2>
229
+ </div>
230
+ """)
231
+
232
+ if on_huggingspace:
233
+ gr.HTML("""
234
+ <p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings.
235
+ <br/>
236
+ <a href="https://huggingface.co/spaces/PAIR/HD-Painter?duplicate=true">
237
+ <img style="margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
238
+ </p>""")
239
+
240
+ with open('script.js', 'r') as f:
241
+ js_str = f.read()
242
+
243
+ demo.load(_js=js_str)
244
+
245
+ with gr.Row():
246
+ with gr.Column():
247
# The selected value is passed to the click handler as the model name and is
# used as a key into `inpainting_models`, so the default must be one of the
# choice strings (here: the first registered model) rather than an index.
model_picker = gr.Dropdown(
    list(inpainting_models.keys()),
    value=next(iter(inpainting_models)),
    label="Please select a model!",
)
252
+ with gr.Column():
253
+ use_painta = gr.Checkbox(value = True, label = "Use PAIntA")
254
+ use_rasg = gr.Checkbox(value = True, label = "Use RASG")
255
+
256
+ prompt = gr.Textbox(label = "Inpainting Prompt")
257
+ with gr.Row():
258
+ with gr.Column():
259
+ input = gr.ImageMask(label = "Input Image", brush_color='#ff0000', elem_id="inputmask")
260
+
261
+ with gr.Row():
262
+ inpaint_btn = gr.Button("Inpaint", scale = 0)
263
+
264
+ with gr.Accordion('Advanced options', open=False):
265
+ guidance_scale = gr.Slider(minimum = 0, maximum = 30, value = 7.5, label = "Guidance Scale")
266
+ eta = gr.Slider(minimum = 0, maximum = 1, value = 0.1, label = "eta")
267
+ ddim_steps = gr.Slider(minimum = 10, maximum = 100, value = 50, step = 1, label = 'Number of diffusion steps')
268
+ with gr.Row():
269
+ seed = gr.Number(value = 49123, label = "Seed")
270
+ batch_size = gr.Number(value = 1, label = "Batch size", minimum=1, maximum=4)
271
+ negative_prompt = gr.Textbox(value=negative_prompt_str, label = "Negative prompt", lines=3)
272
+ positive_prompt = gr.Textbox(value=positive_prompt_str, label = "Positive prompt", lines=1)
273
+
274
+ with gr.Column():
275
+ with gr.Row():
276
+ output_gallery = gr.Gallery(
277
+ [],
278
+ columns = 4,
279
+ preview = True,
280
+ allow_preview = True,
281
+ object_fit='scale-down',
282
+ elem_id='outputgallery'
283
+ )
284
+ with gr.Row():
285
+ upscale_btn = gr.Button("Send to Inpainting-Specialized Super-Resolution (x4)", scale = 1)
286
+ with gr.Row():
287
+ use_sam_mask = gr.Checkbox(value = False, label = "Use SAM mask for background preservation (for SR only, experimental feature)")
288
+ with gr.Row():
289
+ hires_image = gr.Image(label = "Hi-res Image")
290
+
291
+ label = gr.Markdown("## High-Resolution Generation Samples (2048px large side)")
292
+
293
+ with gr.Column():
294
+ example_container = gr.Gallery(
295
+ example_previews,
296
+ columns = 4,
297
+ preview = True,
298
+ allow_preview = True,
299
+ object_fit='scale-down'
300
+ )
301
+
302
+ gr.Examples(
303
+ [
304
+ example_inputs[i] + [[example_previews[i]]]
305
+ for i in range(len(example_previews))
306
+ ],
307
+ [input, prompt, example_container]
308
+ )
309
+
310
+ mock_output_gallery = gr.Gallery([], columns = 4, visible=False)
311
+ mock_hires = gr.Image(label = "__MHRO__", visible = False)
312
+ html_info = gr.HTML(elem_id=f'html_info', elem_classes="infotext")
313
+
314
+ inpaint_btn.click(
315
+ fn=switch_run,
316
+ inputs=[
317
+ use_rasg,
318
+ model_picker,
319
+ use_painta,
320
+ prompt,
321
+ input,
322
+ seed,
323
+ eta,
324
+ negative_prompt,
325
+ positive_prompt,
326
+ ddim_steps,
327
+ guidance_scale,
328
+ batch_size
329
+ ],
330
+ outputs=[output_gallery, mock_output_gallery],
331
+ api_name="inpaint"
332
+ )
333
+ upscale_btn.click(
334
+ fn=upscale_run,
335
+ inputs=[
336
+ prompt,
337
+ input,
338
+ ddim_steps,
339
+ seed,
340
+ use_sam_mask,
341
+ mock_output_gallery,
342
+ html_info
343
+ ],
344
+ outputs=[hires_image, mock_hires],
345
+ api_name="upscale",
346
+ _js="function(a, b, c, d, e, f, g){ return [a, b, c, d, e, f, selected_gallery_index()] }",
347
+ )
348
+
349
+ demo.queue()
350
+ demo.launch(share=True, allowed_paths=[TMP_DIR])
assets/.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ models/
assets/config/ddpm/v1.yaml ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ linear_start: 0.00085
2
+ linear_end: 0.0120
3
+ num_timesteps_cond: 1
4
+ log_every_t: 200
5
+ timesteps: 1000
6
+ first_stage_key: "jpg"
7
+ cond_stage_key: "txt"
8
+ image_size: 64
9
+ channels: 4
10
+ cond_stage_trainable: false
11
+ conditioning_key: crossattn
12
+ monitor: val/loss_simple_ema
13
+ scale_factor: 0.18215
14
+ use_ema: False # we set this to false because this is an inference only config
assets/config/ddpm/v2-upsample.yaml ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ parameterization: "v"
2
+ low_scale_key: "lr"
3
+ linear_start: 0.0001
4
+ linear_end: 0.02
5
+ num_timesteps_cond: 1
6
+ log_every_t: 200
7
+ timesteps: 1000
8
+ first_stage_key: "jpg"
9
+ cond_stage_key: "txt"
10
+ image_size: 128
11
+ channels: 4
12
+ cond_stage_trainable: false
13
+ conditioning_key: "hybrid-adm"
14
+ monitor: val/loss_simple_ema
15
+ scale_factor: 0.08333
16
+ use_ema: False
17
+
18
+ low_scale_config:
19
+ target: ldm.modules.diffusionmodules.upscaling.ImageConcatWithNoiseAugmentation
20
+ params:
21
+ noise_schedule_config: # image space
22
+ linear_start: 0.0001
23
+ linear_end: 0.02
24
+ max_noise_level: 350
assets/config/encoders/clip.yaml ADDED
@@ -0,0 +1 @@
 
 
1
+ __class__: smplfusion.models.encoders.clip_embedder.FrozenCLIPEmbedder
assets/config/encoders/openclip.yaml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ __class__: smplfusion.models.encoders.open_clip_embedder.FrozenOpenCLIPEmbedder
2
+ __init__:
3
+ freeze: True
4
+ layer: "penultimate"
assets/config/unet/inpainting/v1.yaml ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __class__: smplfusion.models.unet.UNetModel
2
+ __init__:
3
+ image_size: 32 # unused
4
+ in_channels: 9 # 4 data + 4 downscaled image + 1 mask
5
+ out_channels: 4
6
+ model_channels: 320
7
+ attention_resolutions: [ 4, 2, 1 ]
8
+ num_res_blocks: 2
9
+ channel_mult: [ 1, 2, 4, 4 ]
10
+ num_heads: 8
11
+ use_spatial_transformer: True
12
+ transformer_depth: 1
13
+ context_dim: 768
14
+ use_checkpoint: False
15
+ legacy: False
assets/config/unet/inpainting/v2.yaml ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __class__: smplfusion.models.unet.UNetModel
2
+ __init__:
3
+ use_checkpoint: False
4
+ image_size: 32 # unused
5
+ in_channels: 9
6
+ out_channels: 4
7
+ model_channels: 320
8
+ attention_resolutions: [ 4, 2, 1 ]
9
+ num_res_blocks: 2
10
+ channel_mult: [ 1, 2, 4, 4 ]
11
+ num_head_channels: 64 # need to fix for flash-attn
12
+ use_spatial_transformer: True
13
+ use_linear_in_transformer: True
14
+ transformer_depth: 1
15
+ context_dim: 1024
16
+ legacy: False
assets/config/unet/upsample/v2.yaml ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __class__: smplfusion.models.unet.UNetModel
2
+ __init__:
3
+ use_checkpoint: False
4
+ num_classes: 1000 # timesteps for noise conditioning (here constant, just need one)
5
+ image_size: 128
6
+ in_channels: 7
7
+ out_channels: 4
8
+ model_channels: 256
9
+ attention_resolutions: [ 2,4,8]
10
+ num_res_blocks: 2
11
+ channel_mult: [ 1, 2, 2, 4]
12
+ disable_self_attentions: [True, True, True, False]
13
+ disable_middle_self_attn: False
14
+ num_heads: 8
15
+ use_spatial_transformer: True
16
+ transformer_depth: 1
17
+ context_dim: 1024
18
+ legacy: False
19
+ use_linear_in_transformer: True
assets/config/vae-upsample.yaml ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __class__: smplfusion.models.vae.AutoencoderKL
2
+ __init__:
3
+ embed_dim: 4
4
+ ddconfig:
5
+ double_z: True
6
+ z_channels: 4
7
+ resolution: 256
8
+ in_channels: 3
9
+ out_ch: 3
10
+ ch: 128
11
+ ch_mult: [ 1,2,4 ]
12
+ num_res_blocks: 2
13
+ attn_resolutions: [ ]
14
+ dropout: 0.0
15
+ lossconfig:
16
+ target: torch.nn.Identity
assets/config/vae.yaml ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __class__: smplfusion.models.vae.AutoencoderKL
2
+ __init__:
3
+ embed_dim: 4
4
+ monitor: val/rec_loss
5
+ ddconfig:
6
+ double_z: true
7
+ z_channels: 4
8
+ resolution: 256
9
+ in_channels: 3
10
+ out_ch: 3
11
+ ch: 128
12
+ ch_mult: [1,2,4,4]
13
+ num_res_blocks: 2
14
+ attn_resolutions: []
15
+ dropout: 0.0
16
+ lossconfig:
17
+ target: torch.nn.Identity
assets/examples/images/a19.jpg ADDED

Git LFS Details

  • SHA256: 4622138454df716ad6a8015c13cf7889a94c63e54c759f130a576eb5280eabf1
  • Pointer size: 131 Bytes
  • Size of remote file: 237 kB
assets/examples/images/a2.jpg ADDED

Git LFS Details

  • SHA256: 74cc2a7407234fc477e66a5a776a57d7b21618e5fea166a32f9b20d6dbd272ba
  • Pointer size: 132 Bytes
  • Size of remote file: 1.32 MB
assets/examples/images/a4.jpg ADDED

Git LFS Details

  • SHA256: 7811c416c6352720d853de721cc41f6d52b2d034c10277c35301c272c3843f7f
  • Pointer size: 131 Bytes
  • Size of remote file: 824 kB
assets/examples/images/a40.jpg ADDED

Git LFS Details

  • SHA256: 03c76588a2a8782e2bab2f5b309e1b3a69932b1c87c38b16599ea0fedb9d30e7
  • Pointer size: 131 Bytes
  • Size of remote file: 470 kB
assets/examples/images/a46.jpg ADDED

Git LFS Details

  • SHA256: c0e06678355e5798e1ae280932b806891d111ac9f23fbbc8fde7429df666aadb
  • Pointer size: 131 Bytes
  • Size of remote file: 119 kB
assets/examples/images/a51.jpg ADDED

Git LFS Details

  • SHA256: 3c18be20ed65489dfc851624923828a481275f680d200d1e57ba743d04208ff6
  • Pointer size: 131 Bytes
  • Size of remote file: 191 kB
assets/examples/images/a54.jpg ADDED

Git LFS Details

  • SHA256: a4333811c5597d088979d58ff778e854bc55f829e9c9b4f564148c1f59b99c36
  • Pointer size: 131 Bytes
  • Size of remote file: 639 kB
assets/examples/images/a65.jpg ADDED

Git LFS Details

  • SHA256: 7f7afcb446f4903c37df7a640f931278c9defe79d1d014a50268b3c0ae232543
  • Pointer size: 131 Bytes
  • Size of remote file: 239 kB
assets/examples/masked/a19.png ADDED

Git LFS Details

  • SHA256: af2982b36bce993fa9bc0152317a805dea3e765dd5ffae961eba05cf7c54164f
  • Pointer size: 132 Bytes
  • Size of remote file: 1.8 MB
assets/examples/masked/a2.png ADDED

Git LFS Details

  • SHA256: 7e860721ef68093859d6aae3dc4b595c07719149fb9f5c55cc3cc147ef82ed6b
  • Pointer size: 132 Bytes
  • Size of remote file: 2.07 MB
assets/examples/masked/a4.png ADDED

Git LFS Details

  • SHA256: 4d8f25579699a11a94dc9dafdcab95000a71e82707dc8f6f0d18ceaf4fe44c0a
  • Pointer size: 132 Bytes
  • Size of remote file: 3.24 MB
assets/examples/masked/a40.png ADDED

Git LFS Details

  • SHA256: 988a74bfd29509b1fede12e4234ca8afd9926ad8bca2b87c9a7e6c5e0e758bb9
  • Pointer size: 132 Bytes
  • Size of remote file: 3.73 MB
assets/examples/masked/a46.png ADDED

Git LFS Details

  • SHA256: a624e2155e1c88c11c7f221179c329e44e7a3f500e22775805ce0280ee39b4af
  • Pointer size: 131 Bytes
  • Size of remote file: 769 kB
assets/examples/masked/a51.png ADDED

Git LFS Details

  • SHA256: cfa02db8e37622915725c9a3bdafb8b16462bddca414f004f51a3c41d2faaa51
  • Pointer size: 132 Bytes
  • Size of remote file: 1.87 MB
assets/examples/masked/a54.png ADDED

Git LFS Details

  • SHA256: 31cd7b538ee33e1e83a5aea6879fa3b898dc379d4e2d210bcd80246d8bec9a40
  • Pointer size: 132 Bytes
  • Size of remote file: 5.33 MB
assets/examples/masked/a65.png ADDED

Git LFS Details

  • SHA256: 0039ca6e324c3603cb1ed4c830b1d9b69224de12a3fa33dddfaa558b4076cefe
  • Pointer size: 132 Bytes
  • Size of remote file: 2.31 MB
assets/examples/sbs/a19.png ADDED

Git LFS Details

  • SHA256: 305df57833fcaf6d75c15cb1a29d048a43e687bb23e400ee37b2b8d48004bb39
  • Pointer size: 132 Bytes
  • Size of remote file: 2.42 MB
assets/examples/sbs/a2.png ADDED

Git LFS Details

  • SHA256: a034a4eeef7f49654e7d0efc29517f0cf79b5acef398721df8db437bd91a3ada
  • Pointer size: 132 Bytes
  • Size of remote file: 4.05 MB
assets/examples/sbs/a4.png ADDED

Git LFS Details

  • SHA256: fd4e51f460c957abf6195d76a49915531faa3a0945d1f6dd37b870dd172f4de8
  • Pointer size: 132 Bytes
  • Size of remote file: 7.04 MB
assets/examples/sbs/a40.png ADDED

Git LFS Details

  • SHA256: 186d588650180703e7046ecc61bd8c53f35d53769ceae45083304f1adc7e31d7
  • Pointer size: 132 Bytes
  • Size of remote file: 5.87 MB
assets/examples/sbs/a46.png ADDED

Git LFS Details

  • SHA256: 87487caeccfa7fea5280b8fa6c97ab0179d7c4568cf0f01d363852d6a48168aa
  • Pointer size: 132 Bytes
  • Size of remote file: 1.3 MB
assets/examples/sbs/a51.png ADDED

Git LFS Details

  • SHA256: 8e5c983156abb921c25e43794e6a6be5dc5bc4af815121331bccb8259f35efd4
  • Pointer size: 132 Bytes
  • Size of remote file: 2.45 MB
assets/examples/sbs/a54.png ADDED

Git LFS Details

  • SHA256: 3e12ff2497806dfe1f04f2c29a62cec89eef7e8c9382b25ac04cc966cdfdd157
  • Pointer size: 132 Bytes
  • Size of remote file: 8.36 MB
assets/examples/sbs/a65.png ADDED

Git LFS Details

  • SHA256: 60f66a473f8e52bbcdfdd6c92ff99bae8a2af654a6e0f1b4c6d75385400fe0f9
  • Pointer size: 132 Bytes
  • Size of remote file: 3.29 MB
lib/__init__.py ADDED
File without changes
lib/methods/__init__.py ADDED
File without changes
lib/methods/rasg.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from lib.utils.iimage import IImage
3
+ from pytorch_lightning import seed_everything
4
+ from tqdm import tqdm
5
+
6
+ from lib.smplfusion import share, router, attentionpatch, transformerpatch
7
+ from lib.smplfusion.patches.attentionpatch import painta
8
+ from lib.utils import tokenize, scores
9
+
10
+ verbose = False
11
+
12
+
13
def init_painta(token_idx):
    """Switch the router to the PAIntA forward functions and arm the patch.

    Args:
        token_idx: Value stored on the patch module as ``painta.token_idx``.
    """
    # Patch state first: token positions, the ``painta_res`` list, on-switch.
    painta.token_idx = token_idx
    painta.painta_res = [16, 32]
    painta.painta_on = True

    # Then point both router hooks at the PAIntA implementations.
    router.basic_transformer_forward = transformerpatch.painta.forward
    router.attention_forward = attentionpatch.painta.forward
20
+
21
def init_guidance():
    """Route attention through the default forwards, using the saving variant.

    ``attentionpatch.default.forward_and_save`` is installed as the attention
    forward and ``transformerpatch.default.forward`` as the transformer-block
    forward. This is the routing used for guidance only.
    """
    router.attention_forward, router.basic_transformer_forward = (
        attentionpatch.default.forward_and_save,
        transformerpatch.default.forward,
    )
25
+
26
def run(ddim, method, prompt, image, mask, seed, eta, prefix, negative_prompt, positive_prompt, dt, guidance_scale):
    """Run guided inpainting sampling and return the decoded image.

    Each step predicts noise with classifier-free guidance, back-propagates a
    BCE score computed from the saved cross-attention similarities to the
    current latent, and adds the standardized gradient to the update.

    Args:
        ddim: Model bundle; ``encoder``, ``unet``, ``vae``,
            ``config.scale_factor`` and ``get_inpainting_condition`` are used.
        method: Checked with ``'painta' in method`` to enable PAIntA routing.
        prompt: Text substituted into ``prefix`` via ``prefix.format``.
        image: Passed to ``ddim.get_inpainting_condition``.
        mask: Passed to ``ddim.get_inpainting_condition`` and ``share.set_mask``.
        seed: Passed to ``seed_everything`` before the start latent is drawn.
        eta: Weight of the ``sigma * grad`` term in the update.
        prefix: Template that must contain ``{}`` as a space-separated word.
        negative_prompt: Text encoded for the unconditional branch.
        positive_prompt: Text appended to the formatted prompt.
        dt: Stride of the timestep range ``range(999, 0, -dt)``.
        guidance_scale: Classifier-free guidance scale.

    Returns:
        ``IImage`` wrapping the VAE decoding of the last predicted ``z0``.

    Raises:
        ValueError: If ``prefix`` has no standalone ``{}`` word.
    """
    # Text condition
    prompt = prefix.format(prompt)
    full_prompt = prompt + positive_prompt
    context = ddim.encoder.encode([negative_prompt, full_prompt])
    first_token = 1 + prefix.split(' ').index('{}')
    token_idx = list(range(first_token, tokenize(prompt).index('<end_of_text>')))
    token_idx.append(tokenize(full_prompt).index('<end_of_text>'))

    # Initialize painta, or plain guidance routing when it is not requested
    use_painta = 'painta' in method
    if use_painta:
        init_painta(token_idx)
    else:
        init_guidance()

    # Image condition
    unet_condition = ddim.get_inpainting_condition(image, mask)
    share.set_mask(mask)

    # Starting latent
    seed_everything(seed)
    zt = torch.randn((1, 4) + unet_condition.shape[2:]).cuda()

    # Setup unet for guidance
    ddim.unet.requires_grad_(True)

    steps = range(999, 0, -dt)
    pbar = tqdm(steps) if verbose else steps

    for timestep in share.DDIMIterator(pbar):
        # PAIntA routing is dropped once the shared timestep reaches 500.
        if use_painta and share.timestep <= 500:
            init_guidance()

        # Fresh leaf tensor so this step's gradient lands in zt.grad.
        zt = zt.detach()
        zt.requires_grad = True

        # Reset storage
        share._crossattn_similarity_res16 = []

        # Run the model (the output is detached; the score below is computed
        # from the similarities collected in ``share`` during this call).
        _zt = zt if unet_condition is None else torch.cat([zt, unet_condition], 1)
        eps_uncond, eps = ddim.unet(
            torch.cat([_zt, _zt]),
            timesteps=torch.tensor([timestep, timestep]).cuda(),
            context=context,
        ).detach().chunk(2)

        # Unconditional guidance
        eps = eps_uncond + guidance_scale * (eps - eps_uncond)
        z0 = (zt - share.schedule.sqrt_one_minus_alphas[timestep] * eps) / share.schedule.sqrt_alphas[timestep]

        # Gradient computation
        score = scores.bce(share._crossattn_similarity_res16, share.mask16, token_idx=token_idx)
        score.backward()
        grad = zt.grad.detach()
        ddim.unet.zero_grad()  # Cleanup already

        # DDIM step
        with torch.no_grad():
            sigma = share.schedule.sigma(share.timestep, dt)
            # Standardization
            grad -= grad.mean()
            grad /= grad.std()

            t_prev = share.timestep - dt
            zt = (
                share.schedule.sqrt_alphas[t_prev] * z0
                + torch.sqrt(1 - share.schedule.alphas[t_prev] - sigma ** 2) * eps
                + eta * sigma * grad
            )

    with torch.no_grad():
        output_image = IImage(ddim.vae.decode(z0 / ddim.config.scale_factor))
    return output_image
lib/methods/sd.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from pytorch_lightning import seed_everything
3
+ from tqdm import tqdm
4
+
5
+ from lib.utils.iimage import IImage
6
+ from lib.smplfusion import share, router, attentionpatch, transformerpatch
7
+ from lib.smplfusion.patches.attentionpatch import painta
8
+ from lib.utils import tokenize
9
+
10
+ verbose = False
11
+
12
+
13
def init_painta(token_idx):
    """Install the PAIntA forward functions on the router and enable the patch.

    Args:
        token_idx: Value stored on the patch module as ``painta.token_idx``.
    """
    # Configure the patch module itself.
    painta.token_idx = token_idx
    painta.painta_res = [16, 32]
    painta.painta_on = True

    # Redirect both router hooks to the PAIntA variants.
    router.basic_transformer_forward = transformerpatch.painta.forward
    router.attention_forward = attentionpatch.painta.forward
20
+
21
def run(
    ddim,
    method,
    prompt,
    image,
    mask,
    seed,
    eta,
    prefix,
    negative_prompt,
    positive_prompt,
    dt,
    guidance_scale
):
    """Run inpainting sampling with classifier-free guidance and decode the result.

    Args:
        ddim: Model bundle; ``encoder``, ``unet``, ``vae``,
            ``config.scale_factor`` and ``get_inpainting_condition`` are used.
        method: Checked with ``'painta' in method`` to enable PAIntA routing.
        prompt: Prompt text. It is used as given; ``prefix`` is not applied
            to it in this function.
        image: Passed to ``ddim.get_inpainting_condition``.
        mask: Passed to ``ddim.get_inpainting_condition`` and ``share.set_mask``.
        seed: Passed to ``seed_everything`` before the start latent is drawn.
        eta: Accepted for signature parity; not read by this function.
        prefix: Only used to locate the position of its standalone ``{}`` word.
        negative_prompt: Text encoded for the unconditional branch.
        positive_prompt: Text appended to ``prompt``.
        dt: Stride of the timestep range ``range(999, 0, -dt)``.
        guidance_scale: Classifier-free guidance scale.

    Returns:
        ``IImage`` wrapping the VAE decoding of the last predicted ``z0``.

    Raises:
        ValueError: If ``prefix`` has no standalone ``{}`` word.
    """
    # Text condition
    full_prompt = prompt + positive_prompt
    context = ddim.encoder.encode([negative_prompt, full_prompt])
    first_token = 1 + prefix.split(' ').index('{}')
    token_idx = list(range(first_token, tokenize(prompt).index('<end_of_text>')))
    token_idx.append(tokenize(full_prompt).index('<end_of_text>'))

    # Setup painta if needed
    if 'painta' in method:
        init_painta(token_idx)
    else:
        router.reset()

    # Image condition
    unet_condition = ddim.get_inpainting_condition(image, mask)
    share.set_mask(mask)

    # Starting latent
    seed_everything(seed)
    zt = torch.randn((1, 4) + unet_condition.shape[2:]).cuda()

    # Turn off gradients
    ddim.unet.requires_grad_(False)

    steps = range(999, 0, -dt)
    pbar = tqdm(steps) if verbose else steps

    for timestep in share.DDIMIterator(pbar):
        # The router is reset on every step once the shared timestep is <= 500.
        if share.timestep <= 500:
            router.reset()

        _zt = zt if unet_condition is None else torch.cat([zt, unet_condition], 1)
        eps_uncond, eps = ddim.unet(
            torch.cat([_zt, _zt]),
            timesteps=torch.tensor([timestep, timestep]).cuda(),
            context=context,
        ).chunk(2)

        # Guided noise, predicted clean latent, then the next latent.
        eps = eps_uncond + guidance_scale * (eps - eps_uncond)
        z0 = (zt - share.schedule.sqrt_one_minus_alphas[timestep] * eps) / share.schedule.sqrt_alphas[timestep]
        t_prev = timestep - dt
        zt = share.schedule.sqrt_alphas[t_prev] * z0 + share.schedule.sqrt_one_minus_alphas[t_prev] * eps

    with torch.no_grad():
        output_image = IImage(ddim.vae.decode(z0 / ddim.config.scale_factor))

    return output_image
lib/methods/sr.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from functools import partial
3
+ from glob import glob
4
+ from pathlib import Path as PythonPath
5
+
6
+ import cv2
7
+ import torchvision.transforms.functional as TvF
8
+ import torch
9
+ import torch.nn as nn
10
+ import numpy as np
11
+ from inspect import isfunction
12
+ from PIL import Image
13
+
14
+ from lib import smplfusion
15
+ from lib.smplfusion import share, router, attentionpatch, transformerpatch
16
+ from lib.utils.iimage import IImage
17
+ from lib.utils import poisson_blend
18
+ from lib.models.sd2_sr import predict_eps_from_z_and_v, predict_start_from_z_and_v
19
+
20
+
21
def _sam_box_object_mask(sam_predictor, image, input_box, kernel_size):
    """Box-prompt SAM on *image* at size 512 and dilate the union of its masks."""
    sam_predictor.set_image(image.resize(512).data[0])
    masks, _, _ = sam_predictor.predict(
        point_coords=None,
        point_labels=None,
        box=input_box[None, :],
        multimask_output=True,
    )
    # Any pixel claimed by at least one of the returned masks counts as object.
    union = (np.sum(masks, axis=0) > 0).astype(np.uint8)
    return cv2.dilate(union, np.ones((kernel_size, kernel_size)))


def refine_mask(hr_image, hr_mask, lr_image, sam_predictor):
    """Restrict the inpainting mask to the objects SAM finds inside its bounding box.

    SAM is prompted with the mask's bounding box on ``hr_image`` and on
    ``lr_image`` (both resized to 512). The union of the two dilated results
    multiplies the 512-sized mask, which is then resized to 2048.

    Args:
        hr_image: Image segmented first; its result is dilated with a 13x13 kernel.
        hr_mask: Mask whose bounding box is used as the SAM prompt.
        lr_image: Image segmented second; its result is dilated with a 3x3 kernel.
        sam_predictor: Object providing ``set_image`` and ``predict``.

    Returns:
        ``IImage`` holding the refined mask, resized to 2048 with bicubic resampling.
    """
    lr_mask = hr_mask.resize(512)

    # NOTE(review): the far corner is derived from the already shifted origin,
    # so the box only grows on that side when the origin was clamped to 0.
    # Presumably a one-pixel margin on every side was intended; confirm.
    x_min, y_min, rect_w, rect_h = cv2.boundingRect(lr_mask.data[0][:, :, 0])
    x_min = max(x_min - 1, 0)
    y_min = max(y_min - 1, 0)
    input_box = np.array([x_min, y_min, x_min + rect_w + 1, y_min + rect_h + 1])

    original_object_mask = _sam_box_object_mask(sam_predictor, hr_image, input_box, 13)
    inpainted_object_mask = _sam_box_object_mask(sam_predictor, lr_image, input_box, 3)

    lr_mask_masking = ((original_object_mask + inpainted_object_mask) > 0).astype(np.uint8)
    new_mask = lr_mask.data[0] * lr_mask_masking[:, :, np.newaxis]
    return IImage(new_mask).resize(2048, resample=Image.BICUBIC)
58
+
59
+
60
def _blend_into_original(output_uint8, hr_image, hr_mask, orig_h, orig_w):
    """Poisson-blend a 4-D uint8 result into ``hr_image`` and crop to the original size."""
    blended = poisson_blend(
        orig_img=hr_image.data[0][:orig_h, :orig_w, :],
        # Move dimension 1 to the end and take the first batch item.
        fake_img=output_uint8.cpu().permute(0, 2, 3, 1)[0].numpy()[:orig_h, :orig_w, :],
        mask=hr_mask.alpha().data[0][:orig_h, :orig_w, :]
    )
    return IImage(blended[:orig_h, :orig_w, :])


def run(ddim, sam_predictor, lr_image, hr_image, hr_mask, prompt = 'high resolution professional photo', noise_level=20,
        blend_output = True, blend_trick = True, no_superres = False,
        dt = 20, seed = 1, guidance_scale = 7.5, negative_prompt = '', use_sam_mask = False, dtype=torch.bfloat16):
    """Upscale ``lr_image`` with the super-resolution model, keeping ``hr_image`` outside the mask.

    Args:
        ddim: Model bundle; ``vae``, ``encoder``, ``unet``, ``low_scale_model``,
            ``schedule`` and ``config.scale_factor`` are used.
        sam_predictor: Only used when ``use_sam_mask`` is true.
        lr_image: Low-resolution image that conditions the UNet.
        hr_image: Image that is encoded and blended with the result.
        hr_mask: Mask of the region to regenerate.
        prompt: Conditional text prompt.
        noise_level: Integer level passed to ``ddim.low_scale_model``.
        blend_output: Poisson-blend the decoded result into ``hr_image``.
        blend_trick: At every step, mix the predicted ``z0`` with the encoded
            ``hr_image`` using the blurred low-resolution mask.
        no_superres: Skip diffusion; resize ``lr_image`` bicubically and blend.
        dt: Stride of the timestep range ``range(999, 0, -dt)``.
        seed: Passed to ``torch.manual_seed``.
        guidance_scale: Classifier-free guidance scale.
        negative_prompt: Text encoded for the unconditional branch.
        use_sam_mask: Refine ``hr_mask`` with ``refine_mask`` first.
        dtype: Dtype the tensors are cast to before entering the models.

    Returns:
        ``IImage`` with the result cropped to the original ``hr_image`` size.
    """
    torch.manual_seed(seed)

    router.attention_forward = attentionpatch.default.forward_xformers
    router.basic_transformer_forward = transformerpatch.default.forward

    if use_sam_mask:
        with torch.no_grad():
            hr_mask = refine_mask(hr_image, hr_mask, lr_image, sam_predictor)

    # Remember the unpadded size; every output is cropped back to it.
    orig_h, orig_w = hr_image.torch().shape[2], hr_image.torch().shape[3]
    hr_image = hr_image.padx(256, padding_mode='reflect')
    hr_mask = hr_mask.padx(256, padding_mode='reflect').dilate(19)
    lr_image = lr_image.padx(64, padding_mode='reflect')
    lr_size = (lr_image.torch().shape[2], lr_image.torch().shape[3])
    lr_mask = hr_mask.resize(lr_size, resample=Image.BICUBIC).alpha().torch(vmin=0).cuda()
    lr_mask = TvF.gaussian_blur(lr_mask, kernel_size=19)

    if no_superres:
        hr_size = (hr_image.torch().shape[2], hr_image.torch().shape[3])
        output_tensor = lr_image.resize(hr_size, resample=Image.BICUBIC).torch().cuda()
        output_tensor = (255*((output_tensor.clip(-1, 1) + 1) / 2)).to(torch.uint8)
        return _blend_into_original(output_tensor, hr_image, hr_mask, orig_h, orig_w)

    # encode hr image
    with torch.no_grad():
        hr_z0 = ddim.vae.encode(hr_image.torch().cuda().to(dtype)).mean * ddim.config.scale_factor

    # The latent of the padded hr image must match the padded lr image size.
    assert hr_z0.shape[2] == lr_image.torch().shape[2]
    assert hr_z0.shape[3] == lr_image.torch().shape[3]

    unet_condition = lr_image.cuda().torch().to(memory_format=torch.contiguous_format).to(dtype)
    zT = torch.randn((1, 4, unet_condition.shape[2], unet_condition.shape[3])).cuda().to(dtype)

    with torch.no_grad():
        context = ddim.encoder.encode([negative_prompt, prompt])

    noise_level = torch.Tensor(1 * [noise_level]).to('cuda').long()
    unet_condition, noise_level = ddim.low_scale_model(unet_condition, noise_level=noise_level)

    with torch.autocast('cuda'), torch.no_grad():
        zt = zT
        for t in range(999, 0, -dt):
            _zt = zt if unet_condition is None else torch.cat([zt, unet_condition], 1)

            eps_uncond, eps = ddim.unet(
                torch.cat([_zt, _zt]).to(dtype),
                timesteps=torch.tensor([t, t]).cuda(),
                context=context,
                y=torch.cat([noise_level]*2)
            ).chunk(2)

            # The guided output is converted to eps and z0 by the v helpers.
            ts = torch.full((zt.shape[0],), t, device='cuda', dtype=torch.long)
            model_output = eps_uncond + guidance_scale * (eps - eps_uncond)
            eps = predict_eps_from_z_and_v(ddim.schedule, zt, ts, model_output).to(dtype)
            z0 = predict_start_from_z_and_v(ddim.schedule, zt, ts, model_output).to(dtype)

            if blend_trick:
                z0 = z0 * lr_mask + hr_z0 * (1-lr_mask)

            zt = ddim.schedule.sqrt_alphas[t - dt] * z0 + ddim.schedule.sqrt_one_minus_alphas[t - dt] * eps

    with torch.no_grad():
        output_tensor = ddim.vae.decode(z0.to(dtype) / ddim.config.scale_factor)

    if not blend_output:
        return IImage(output_tensor[:, :, :orig_h, :orig_w])

    output_tensor = (255*((output_tensor + 1) / 2).clip(0, 1)).to(torch.uint8)
    return _blend_into_original(output_tensor, hr_image, hr_mask, orig_h, orig_w)
lib/models/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from . import sd2_inp, ds_inp, sd15_inp, sd2_sr, sam
lib/models/common.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+ import requests
3
+ from pathlib import Path
4
+ from os.path import dirname
5
+
6
+ from omegaconf import OmegaConf
7
+ from tqdm import tqdm
8
+
9
+
10
+ PROJECT_DIR = dirname(dirname(dirname(__file__)))
11
+ CONFIG_FOLDER = f'{PROJECT_DIR}/assets/config'
12
+ MODEL_FOLDER = f'{PROJECT_DIR}/assets/models'
13
+
14
+
15
def download_file(url, save_path, chunk_size=1024):
    """Download *url* to *save_path* unless that file already exists.

    The body is streamed into a sibling ``<name>.part`` file and renamed onto
    *save_path* only after the transfer finished, so an interrupted or failed
    download never leaves a truncated file that would later be mistaken for a
    complete one.

    Args:
        url: Address to fetch.
        save_path: Destination path; missing parent directories are created.
        chunk_size: Number of bytes requested per streamed chunk.

    Raises:
        Exception: With the message ``Download failed: <cause>`` on any
            failure, including a non-success HTTP status. The original error
            is attached as ``__cause__``.
    """
    tmp_path = None
    try:
        save_path = Path(save_path)
        if save_path.exists():
            print(f'{save_path.name} exists')
            return
        save_path.parent.mkdir(exist_ok=True, parents=True)
        tmp_path = save_path.with_name(save_path.name + '.part')
        # (connect, read) timeouts: fail instead of hanging on a dead peer.
        with requests.get(url, stream=True, timeout=(10, 120)) as resp:
            # Without this an HTTP error page would be stored as the model.
            resp.raise_for_status()
            total = int(resp.headers.get('content-length', 0))
            with open(tmp_path, 'wb') as file, tqdm(
                desc=save_path.name,
                total=total,
                unit='iB',
                unit_scale=True,
                unit_divisor=1024,
            ) as bar:
                for data in resp.iter_content(chunk_size=chunk_size):
                    size = file.write(data)
                    bar.update(size)
        # Publish the finished file under its final name.
        tmp_path.replace(save_path)
        print(f'{save_path.name} download finished')
    except Exception as e:
        # Drop the partial file so the next call starts from scratch.
        if tmp_path is not None and tmp_path.exists():
            tmp_path.unlink()
        raise Exception(f"Download failed: {e}") from e
37
+
38
+
39
def get_obj_from_str(string):
    """Resolve a dotted ``package.module.Name`` path to the named attribute.

    The module part is imported as given first. If that import fails or the
    module lacks the attribute, the lookup is retried with the module path
    prefixed by ``lib.``.

    Args:
        string: Dotted path whose last component is the attribute name.

    Returns:
        The attribute found on the imported module.

    Raises:
        ValueError: If *string* contains no dot.
        ImportError: If neither module path can be imported.
        AttributeError: If the fallback module lacks the attribute.
    """
    module, cls = string.rsplit(".", 1)
    try:
        return getattr(importlib.import_module(module, package=None), cls)
    except (ImportError, AttributeError):
        # Only lookup failures trigger the fallback; KeyboardInterrupt and
        # errors raised while a module executes are no longer swallowed.
        return getattr(importlib.import_module('lib.' + module, package=None), cls)
45
+
46
+
47
def load_obj(path):
    """Build the object described by the YAML file at *path*.

    The file's ``__class__`` entry names the callable (resolved with
    ``get_obj_from_str``); the optional ``__init__`` mapping supplies its
    keyword arguments.
    """
    spec = OmegaConf.load(path)
    factory = get_obj_from_str(spec['__class__'])
    init_kwargs = spec.get("__init__", {})
    return factory(**init_kwargs)
lib/models/ds_inp.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+ from omegaconf import OmegaConf
3
+ import torch
4
+ import safetensors
5
+ import safetensors.torch
6
+
7
+ from lib.smplfusion import DDIM, share, scheduler
8
+ from .common import *
9
+
10
+
11
+ MODEL_PATH = f'{MODEL_FOLDER}/dreamshaper/dreamshaper_8Inpainting.safetensors'
12
+ DOWNLOAD_URL = 'https://civitai.com/api/download/models/131004'
13
+
14
+ # pre-download
15
+ download_file(DOWNLOAD_URL, MODEL_PATH)
16
+
17
+
18
def load_model():
    """Load the Dreamshaper inpainting checkpoint and return it wrapped in ``DDIM``.

    Downloads the checkpoint if it is missing, builds the UNet, VAE and text
    encoder from their YAML configs, loads each one's weights from the shared
    state dict, and installs the linear noise schedule on ``share.schedule``.
    """
    print ("Loading model: Dreamshaper Inpainting V8")

    download_file(DOWNLOAD_URL, MODEL_PATH)

    state_dict = safetensors.torch.load_file(MODEL_PATH)

    config = OmegaConf.load(f'{CONFIG_FOLDER}/ddpm/v1.yaml')
    unet = load_obj(f'{CONFIG_FOLDER}/unet/inpainting/v1.yaml').eval().cuda()
    vae = load_obj(f'{CONFIG_FOLDER}/vae.yaml').eval().cuda()
    encoder = load_obj(f'{CONFIG_FOLDER}/encoders/clip.yaml').eval().cuda()

    def submodel_state(name):
        # Keep the entries whose key contains *name*, dropping "<name>." from it.
        skip = len(name) + 1
        return {key[skip:]: value for key, value in state_dict.items() if name in key}

    unet.load_state_dict(submodel_state('model.diffusion_model'))
    encoder.load_state_dict(submodel_state('cond_stage_model'))
    vae.load_state_dict(submodel_state('first_stage_model'))

    unet = unet.requires_grad_(False)
    encoder = encoder.requires_grad_(False)
    vae = vae.requires_grad_(False)

    ddim = DDIM(config, vae, encoder, unet)
    share.schedule = scheduler.linear(config.timesteps, config.linear_start, config.linear_end)

    return ddim
lib/models/sam.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from segment_anything import sam_model_registry, SamPredictor
2
+ from .common import *
3
+
4
+ MODEL_PATH = f'{MODEL_FOLDER}/sam/sam_vit_h_4b8939.pth'
5
+ DOWNLOAD_URL = 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth'
6
+
7
+ # pre-download
8
+ download_file(DOWNLOAD_URL, MODEL_PATH)
9
+
10
+
11
def load_model():
    """Load the ViT-H SAM checkpoint onto CUDA and return a ``SamPredictor`` for it."""
    print ("Loading model: SAM")
    download_file(DOWNLOAD_URL, MODEL_PATH)

    build_sam = sam_model_registry["vit_h"]
    sam = build_sam(checkpoint=MODEL_PATH)
    sam.to(device="cuda")

    predictor = SamPredictor(sam)
    print ("SAM loaded")
    return predictor
lib/models/sd15_inp.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from omegaconf import OmegaConf
2
+ import torch
3
+
4
+ from lib.smplfusion import DDIM, share, scheduler
5
+ from .common import *
6
+
7
+
8
+ DOWNLOAD_URL = 'https://huggingface.co/runwayml/stable-diffusion-inpainting/resolve/main/sd-v1-5-inpainting.ckpt?download=true'
9
+ MODEL_PATH = f'{MODEL_FOLDER}/sd-1-5-inpainting/sd-v1-5-inpainting.ckpt'
10
+
11
+ # pre-download
12
+ download_file(DOWNLOAD_URL, MODEL_PATH)
13
+
14
+
15
def load_model():
    """Load the Stable-Inpainting 1.5 checkpoint and return it wrapped in ``DDIM``.

    Downloads the checkpoint if it is missing, builds the UNet, VAE and text
    encoder from their YAML configs, loads each one's weights from the
    checkpoint's ``state_dict``, and installs the linear noise schedule on
    ``share.schedule``.
    """
    # Announce before the potentially long download and load.
    print ("Loading model: Stable-Inpainting 1.5")

    download_file(DOWNLOAD_URL, MODEL_PATH)

    # NOTE(security): a .ckpt file is a pickle, and unpickling can execute
    # code embedded in it. Only load checkpoints from a trusted source.
    # map_location='cpu' keeps the raw weights off the GPU; load_state_dict
    # below copies them into the modules that already live on CUDA.
    state_dict = torch.load(MODEL_PATH, map_location='cpu')['state_dict']

    config = OmegaConf.load(f'{CONFIG_FOLDER}/ddpm/v1.yaml')

    unet = load_obj(f'{CONFIG_FOLDER}/unet/inpainting/v1.yaml').eval().cuda()
    vae = load_obj(f'{CONFIG_FOLDER}/vae.yaml').eval().cuda()
    encoder = load_obj(f'{CONFIG_FOLDER}/encoders/clip.yaml').eval().cuda()

    # Keep the entries whose key contains the sub-model name, dropping
    # "<name>." from the front of each key.
    extract = lambda state_dict, model: {x[len(model)+1:]:y for x,y in state_dict.items() if model in x}
    unet_state = extract(state_dict, 'model.diffusion_model')
    encoder_state = extract(state_dict, 'cond_stage_model')
    vae_state = extract(state_dict, 'first_stage_model')

    unet.load_state_dict(unet_state)
    encoder.load_state_dict(encoder_state)
    vae.load_state_dict(vae_state)

    unet = unet.requires_grad_(False)
    encoder = encoder.requires_grad_(False)
    vae = vae.requires_grad_(False)

    ddim = DDIM(config, vae, encoder, unet)
    share.schedule = scheduler.linear(config.timesteps, config.linear_start, config.linear_end)

    return ddim
lib/models/sd2_inp.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import safetensors
2
+ import safetensors.torch
3
+ import torch
4
+ from omegaconf import OmegaConf
5
+
6
+ from lib.smplfusion import DDIM, share, scheduler
7
+ from .common import *
8
+
9
+ MODEL_PATH = f'{MODEL_FOLDER}/sd-2-0-inpainting/512-inpainting-ema.safetensors'
10
+ DOWNLOAD_URL = 'https://huggingface.co/stabilityai/stable-diffusion-2-inpainting/resolve/main/512-inpainting-ema.safetensors?download=true'
11
+
12
+ # pre-download
13
+ download_file(DOWNLOAD_URL, MODEL_PATH)
14
+
15
+
16
def load_model():
    """Load the Stable-Inpainting 2.0 checkpoint and return it wrapped in ``DDIM``.

    Downloads the checkpoint if it is missing, builds the UNet, VAE and text
    encoder from their YAML configs, loads each one's weights from the shared
    state dict, and installs the linear noise schedule on ``share.schedule``.
    """
    print ("Loading model: Stable-Inpainting 2.0")

    download_file(DOWNLOAD_URL, MODEL_PATH)

    state_dict = safetensors.torch.load_file(MODEL_PATH)

    config = OmegaConf.load(f'{CONFIG_FOLDER}/ddpm/v1.yaml')

    unet = load_obj(f'{CONFIG_FOLDER}/unet/inpainting/v2.yaml').eval().cuda()
    vae = load_obj(f'{CONFIG_FOLDER}/vae.yaml').eval().cuda()
    encoder = load_obj(f'{CONFIG_FOLDER}/encoders/openclip.yaml').eval().cuda()

    # Keep the entries whose key contains the sub-model name, dropping
    # "<name>." from the front of each key.
    extract = lambda state_dict, model: {x[len(model)+1:]:y for x,y in state_dict.items() if model in x}
    unet_state = extract(state_dict, 'model.diffusion_model')
    encoder_state = extract(state_dict, 'cond_stage_model')
    vae_state = extract(state_dict, 'first_stage_model')

    unet.load_state_dict(unet_state)
    encoder.load_state_dict(encoder_state)
    vae.load_state_dict(vae_state)

    unet = unet.requires_grad_(False)
    encoder = encoder.requires_grad_(False)
    vae = vae.requires_grad_(False)

    # Built once, after the weights are in place. The earlier construction
    # on freshly initialised modules was discarded and has been removed.
    ddim = DDIM(config, vae, encoder, unet)
    share.schedule = scheduler.linear(config.timesteps, config.linear_start, config.linear_end)

    print('Stable-Inpainting 2.0 loaded')
    return ddim
lib/models/sd2_sr.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+ from functools import partial
3
+
4
+ import cv2
5
+ import numpy as np
6
+ import safetensors
7
+ import safetensors.torch
8
+ import torch
9
+ import torch.nn as nn
10
+ from inspect import isfunction
11
+ from omegaconf import OmegaConf
12
+
13
+ from lib.smplfusion import DDIM, share, scheduler
14
+ from .common import *
15
+
16
+
17
# Local path of the Stable Diffusion x4 upscaler checkpoint and the remote
# URL it is fetched from.
MODEL_PATH = f'{MODEL_FOLDER}/sd-2-0-upsample/x4-upscaler-ema.safetensors'
DOWNLOAD_URL = 'https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler/resolve/main/x4-upscaler-ema.safetensors?download=true'

# pre-download: request the checkpoint as soon as this module is imported,
# ahead of the first load_model() call.
download_file(DOWNLOAD_URL, MODEL_PATH)
22
+
23
+
24
def exists(x):
    """Return ``True`` for any value other than ``None``."""
    if x is None:
        return False
    return True
26
+
27
+
28
def default(val, d):
    """Return ``val`` unless it is ``None``, otherwise the fallback ``d``.

    A fallback that is a plain Python function (``inspect.isfunction``) is
    called with no arguments and its result is returned; any other fallback
    is returned unchanged.
    """
    if val is not None:
        return val
    if isfunction(d):
        return d()
    return d
32
+
33
+
34
def extract_into_tensor(a, t, x_shape):
    """Gather entries of ``a`` at indices ``t`` and reshape them for broadcasting.

    The gathered values are returned with shape ``(b, 1, ..., 1)``, where
    ``b`` is the first dimension of ``t`` and the number of trailing singleton
    axes is ``len(x_shape) - 1``.
    """
    batch, *_ = t.shape
    picked = a.gather(-1, t)
    singleton_axes = [1] * (len(x_shape) - 1)
    return picked.reshape(batch, *singleton_axes)
38
+
39
+
40
def predict_eps_from_z_and_v(schedule, x_t, t, v):
    """Combine ``v`` and ``x_t`` into ``sqrt_alphas[t] * v + sqrt_one_minus_alphas[t] * x_t``.

    Both coefficient tables are read from ``schedule``, moved to the GPU and
    indexed per sample with ``extract_into_tensor`` so they broadcast over
    ``x_t``.
    """
    v_coeff = extract_into_tensor(schedule.sqrt_alphas.cuda(), t, x_t.shape)
    v_term = v_coeff * v
    x_coeff = extract_into_tensor(schedule.sqrt_one_minus_alphas.cuda(), t, x_t.shape)
    x_term = x_coeff * x_t
    return v_term + x_term
45
+
46
+
47
def predict_start_from_z_and_v(schedule, x_t, t, v):
    """Combine ``x_t`` and ``v`` into ``sqrt_alphas[t] * x_t - sqrt_one_minus_alphas[t] * v``.

    Both coefficient tables are read from ``schedule``, moved to the GPU and
    indexed per sample with ``extract_into_tensor`` so they broadcast over
    ``x_t``.
    """
    x_coeff = extract_into_tensor(schedule.sqrt_alphas.cuda(), t, x_t.shape)
    x_term = x_coeff * x_t
    v_coeff = extract_into_tensor(schedule.sqrt_one_minus_alphas.cuda(), t, x_t.shape)
    v_term = v_coeff * v
    return x_term - v_term
52
+
53
+
54
def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
    """Return the per-timestep beta values as a float64 NumPy array.

    Supported ``schedule`` names:

    * ``"linear"``: squares of values linearly spaced between
      ``sqrt(linear_start)`` and ``sqrt(linear_end)``.
    * ``"cosine"``: betas derived from a squared-cosine curve offset by
      ``cosine_s``, clamped to ``[0, 0.999]``.
    * ``"sqrt_linear"``: values linearly spaced between ``linear_start`` and
      ``linear_end``.
    * ``"sqrt"``: square roots of values linearly spaced between
      ``linear_start`` and ``linear_end``.

    Args:
        schedule: One of the names listed above.
        n_timestep: Number of betas to produce.
        linear_start: Lower endpoint used by the linear-style schedules.
        linear_end: Upper endpoint used by the linear-style schedules.
        cosine_s: Offset used by the cosine schedule.

    Returns:
        A 1-D ``numpy.ndarray`` of length ``n_timestep``.

    Raises:
        ValueError: If ``schedule`` is not a known name.
    """
    if schedule == "linear":
        betas = (
            torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
        )

    elif schedule == "cosine":
        timesteps = (
            torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
        )
        alphas = timesteps / (1 + cosine_s) * np.pi / 2
        alphas = torch.cos(alphas).pow(2)
        alphas = alphas / alphas[0]
        betas = 1 - alphas[1:] / alphas[:-1]
        # Clamp with torch so the value stays a tensor: np.clip on a
        # torch.Tensor depends on NumPy's array-wrapping fallback, and the
        # ``.numpy()`` call below needs a tensor.
        betas = torch.clamp(betas, min=0, max=0.999)

    elif schedule == "sqrt_linear":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
    elif schedule == "sqrt":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
    else:
        raise ValueError(f"schedule '{schedule}' unknown.")
    return betas.numpy()
77
+
78
+
79
def disabled_train(self, mode=True):
    """Replacement for a model's ``train`` method that does nothing.

    ``mode`` is ignored and ``self`` is returned as-is, so installing this
    in place of ``model.train`` keeps the train/eval state from changing.
    """
    return self
83
+
84
+
85
class AbstractLowScaleModel(nn.Module):
    """Base module for concatenating a downsampled image to the latent representation.

    When a noise-schedule config is supplied, the diffusion coefficient
    tables are registered as float32 buffers and ``q_sample`` can be used to
    noise an input. ``forward`` and ``decode`` are identity-style defaults.
    """

    def __init__(self, noise_schedule_config=None):
        super().__init__()
        # Without a config no schedule buffers are registered.
        if noise_schedule_config is None:
            return
        self.register_schedule(**noise_schedule_config)

    def register_schedule(self, beta_schedule="linear", timesteps=1000,
                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
        """Compute the beta schedule and register the derived tables as buffers."""
        betas = make_beta_schedule(
            beta_schedule,
            timesteps,
            linear_start=linear_start,
            linear_end=linear_end,
            cosine_s=cosine_s,
        )
        alphas_cumprod = np.cumprod(1. - betas, axis=0)
        # Same series shifted right by one step, starting from 1.
        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])

        (n_steps,) = betas.shape
        self.num_timesteps = int(n_steps)
        self.linear_start = linear_start
        self.linear_end = linear_end
        assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'

        tables = {
            'betas': betas,
            'alphas_cumprod': alphas_cumprod,
            'alphas_cumprod_prev': alphas_cumprod_prev,
            # calculations for diffusion q(x_t | x_{t-1}) and others
            'sqrt_alphas_cumprod': np.sqrt(alphas_cumprod),
            'sqrt_one_minus_alphas_cumprod': np.sqrt(1. - alphas_cumprod),
            'log_one_minus_alphas_cumprod': np.log(1. - alphas_cumprod),
            'sqrt_recip_alphas_cumprod': np.sqrt(1. / alphas_cumprod),
            'sqrt_recipm1_alphas_cumprod': np.sqrt(1. / alphas_cumprod - 1),
        }
        for buffer_name, values in tables.items():
            self.register_buffer(buffer_name, torch.tensor(values, dtype=torch.float32))

    def q_sample(self, x_start, t, noise=None):
        """Return ``sqrt_alphas_cumprod[t] * x_start + sqrt_one_minus_alphas_cumprod[t] * noise``.

        When ``noise`` is ``None``, standard normal noise shaped like
        ``x_start`` is drawn.
        """
        if noise is None:
            noise = torch.randn_like(x_start)
        signal_scale = extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape)
        noise_scale = extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
        return signal_scale * x_start + noise_scale * noise

    def forward(self, x):
        """Return the input unchanged, paired with ``None``."""
        return x, None

    def decode(self, x):
        """Return the input unchanged."""
        return x
129
+
130
+
131
class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel):
    """Low-scale model that noises its input with ``q_sample`` at a given or random level."""

    def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False):
        # ``to_cuda`` is accepted but not used by this constructor.
        super().__init__(noise_schedule_config=noise_schedule_config)
        self.max_noise_level = max_noise_level

    def forward(self, x, noise_level=None):
        """Noise ``x`` and return ``(noised_x, noise_level)``.

        When ``noise_level`` is ``None``, one integer level per batch element
        is drawn from ``[0, max_noise_level)`` on ``x``'s device; otherwise it
        must already be a ``torch.Tensor``.
        """
        if noise_level is not None:
            assert isinstance(noise_level, torch.Tensor)
        else:
            noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
        return self.q_sample(x, noise_level), noise_level
143
+
144
+
145
def get_obj_from_str(string):
    """Resolve a dotted path such as ``"package.module.Name"`` to the object it names.

    The part before the last dot is imported as a module and the part after
    it is looked up as an attribute. If that import fails or the attribute is
    missing, the lookup is retried with the module path prefixed by ``lib.``.

    Args:
        string: Dotted path containing at least one ``.``.

    Returns:
        The attribute named by the last path component.

    Raises:
        ValueError: If ``string`` contains no dot.
        ImportError: If neither module path can be imported.
        AttributeError: If the fallback module lacks the attribute.
    """
    module, cls = string.rsplit(".", 1)
    try:
        return getattr(importlib.import_module(module, package=None), cls)
    except (ImportError, AttributeError):
        # Only a failed lookup triggers the fallback; KeyboardInterrupt,
        # SystemExit and unrelated errors propagate instead of being swallowed.
        return getattr(importlib.import_module('lib.' + module, package=None), cls)
151
def load_obj(path):
    """Instantiate the object described by the YAML file at ``path``.

    The file's ``__class__`` entry is a dotted path resolved with
    ``get_obj_from_str``; the optional ``__init__`` mapping supplies the
    keyword arguments for the constructor call.
    """
    # NOTE(review): defined after `from .common import *`, so this presumably
    # overrides a `load_obj` provided by that import — confirm that is intended.
    spec = OmegaConf.load(path)
    factory = get_obj_from_str(spec['__class__'])
    init_kwargs = spec.get("__init__", {})
    return factory(**init_kwargs)
154
+
155
+
156
def load_model(dtype=torch.bfloat16):
    """Build the SD2 x4 super-resolution pipeline and return it as a ``DDIM``.

    Steps, in order: request the checkpoint via ``download_file``, read it
    with safetensors, instantiate the UNet, VAE and text encoder from their
    YAML configs (eval mode, on the GPU), load each sub-model's weights,
    disable gradients, cast the three sub-models to ``dtype``, and attach a
    frozen ``ImageConcatWithNoiseAugmentation`` as ``ddim.low_scale_model``.

    Args:
        dtype: Torch dtype the UNet, VAE and encoder are converted to.

    Returns:
        The ``DDIM`` object wrapping the loaded sub-models, with
        ``low_scale_model`` set.
    """
    print("Loading model: SD2 superresolution...")

    download_file(DOWNLOAD_URL, MODEL_PATH)

    state_dict = safetensors.torch.load_file(MODEL_PATH)

    config = OmegaConf.load(f'{CONFIG_FOLDER}/ddpm/v2-upsample.yaml')

    unet = load_obj(f'{CONFIG_FOLDER}/unet/upsample/v2.yaml').eval().cuda()
    vae = load_obj(f'{CONFIG_FOLDER}/vae-upsample.yaml').eval().cuda()
    encoder = load_obj(f'{CONFIG_FOLDER}/encoders/openclip.yaml').eval().cuda()

    def extract(weights, model):
        # Select keys by their full dotted prefix (not by substring) so the
        # slice below always removes exactly "<model>." and nothing else.
        prefix = model + '.'
        return {
            key[len(prefix):]: value
            for key, value in weights.items()
            if key.startswith(prefix)
        }

    unet.load_state_dict(extract(state_dict, 'model.diffusion_model'))
    encoder.load_state_dict(extract(state_dict, 'cond_stage_model'))
    vae.load_state_dict(extract(state_dict, 'first_stage_model'))

    # Inference only: no gradients are needed for any sub-model.
    unet = unet.requires_grad_(False)
    encoder = encoder.requires_grad_(False)
    vae = vae.requires_grad_(False)

    unet.to(dtype)
    vae.to(dtype)
    encoder.to(dtype)

    # Build the sampler once, after weights and dtype are final.
    ddim = DDIM(config, vae, encoder, unet)

    params = {
        'noise_schedule_config': {
            'linear_start': 0.0001,
            'linear_end': 0.02
        },
        'max_noise_level': 350
    }

    low_scale_model = ImageConcatWithNoiseAugmentation(**params).eval().to('cuda')
    # Bind the instance explicitly: a plain function stored on an instance is
    # not a bound method, so without this a later ``.train()`` call would
    # raise TypeError and ``.eval()`` would return ``False`` instead of the
    # module.
    low_scale_model.train = partial(disabled_train, low_scale_model)
    for param in low_scale_model.parameters():
        param.requires_grad = False

    ddim.low_scale_model = low_scale_model
    print('SD2 superresolution loaded')
    return ddim