tangmen committed on
Commit
113dbd0
1 Parent(s): dce2667
This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. img/__pycache__/env.cpython-310.pyc +0 -0
  2. img/__pycache__/main.cpython-310.pyc +0 -0
  3. img/__pycache__/main.cpython-38.pyc +0 -0
  4. img/__pycache__/main_v2.cpython-310.pyc +0 -0
  5. img/__pycache__/main_v3.cpython-310.pyc +0 -0
  6. img/__pycache__/main_v4.cpython-310.pyc +0 -0
  7. img/__pycache__/main_v5.cpython-310.pyc +0 -0
  8. img/__pycache__/main_v6.cpython-310.pyc +0 -0
  9. img/__pycache__/main_v7.cpython-310.pyc +0 -0
  10. img/__pycache__/main_v8.cpython-310.pyc +0 -0
  11. img/dev-requirements.txt +11 -0
  12. img/env.py +2 -0
  13. img/img2img.py +25 -0
  14. img/img2imgsd.py +74 -0
  15. img/img2imgsdr.py +53 -0
  16. img/inpaint.py +62 -0
  17. img/log.0925 +53 -0
  18. img/main.py +528 -0
  19. img/main_1024.py +549 -0
  20. img/main_v2.py +548 -0
  21. img/main_v3.py +578 -0
  22. img/main_v4.py +603 -0
  23. img/main_v5.py +637 -0
  24. img/main_v6.py +636 -0
  25. img/main_v7.py +641 -0
  26. img/main_v8.py +675 -0
  27. img/manager.py +28 -0
  28. img/ops/supervisor.conf +17 -0
  29. img/ori/main.py +488 -0
  30. img/pr1/main.py +515 -0
  31. img/pr2/main.py +528 -0
  32. img/readme.md +109 -0
  33. img/requirements.txt +67 -0
  34. img/scripts/test_compression.py +22 -0
  35. img/stable-diffusion-server/.gitignore +13 -0
  36. img/stable-diffusion-server/.log.0925.swp +0 -0
  37. img/stable-diffusion-server/dev-requirements.txt +11 -0
  38. img/stable-diffusion-server/env.py +2 -0
  39. img/stable-diffusion-server/img2img.py +25 -0
  40. img/stable-diffusion-server/img2imgsd.py +74 -0
  41. img/stable-diffusion-server/img2imgsdr.py +53 -0
  42. img/stable-diffusion-server/inpaint.py +62 -0
  43. img/stable-diffusion-server/log.0925 +53 -0
  44. img/stable-diffusion-server/main.py +528 -0
  45. img/stable-diffusion-server/main_1024.py +549 -0
  46. img/stable-diffusion-server/main_v2.py +548 -0
  47. img/stable-diffusion-server/main_v3.py +578 -0
  48. img/stable-diffusion-server/main_v4.py +603 -0
  49. img/stable-diffusion-server/main_v5.py +637 -0
  50. img/stable-diffusion-server/main_v6.py +636 -0
img/__pycache__/env.cpython-310.pyc ADDED
Binary file (306 Bytes).

img/__pycache__/main.cpython-310.pyc ADDED
Binary file (8.75 kB).

img/__pycache__/main.cpython-38.pyc ADDED
Binary file (9.07 kB).

img/__pycache__/main_v2.cpython-310.pyc ADDED
Binary file (9.11 kB).

img/__pycache__/main_v3.cpython-310.pyc ADDED
Binary file (10.1 kB).

img/__pycache__/main_v4.cpython-310.pyc ADDED
Binary file (10.3 kB).

img/__pycache__/main_v5.cpython-310.pyc ADDED
Binary file (16.9 kB).

img/__pycache__/main_v6.cpython-310.pyc ADDED
Binary file (16.9 kB).

img/__pycache__/main_v7.cpython-310.pyc ADDED
Binary file (17 kB).

img/__pycache__/main_v8.cpython-310.pyc ADDED
Binary file (17.8 kB).

img/dev-requirements.txt ADDED
@@ -0,0 +1,11 @@
+ pytest
+
+ pytest-asyncio
+ requests-futures==1.0.0
+ httpx
+ djlint
+ pytest-env==0.8.1
+ ipython
+
+ line-profiler-pycharm==1.1.0
+ line-profiler==4.0.3
img/env.py ADDED
@@ -0,0 +1,2 @@
+ BUCKET_NAME = 'static.netwrck.com'
+ BUCKET_PATH = 'static/uploads'
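
These bucket constants are consumed by img/main.py, which calls check_if_blob_exists and upload_to_bucket from stable_diffusion_server.bucket_api (a module not shown in this 50-file view). A minimal sketch of what such helpers could look like, assuming Google Cloud Storage via the google-cloud-storage client; the signatures mirror main.py's call sites, but this is illustrative, not the committed implementation:

# hypothetical bucket_api sketch - assumes google-cloud-storage is installed
# and credentials are configured; not the committed module
from google.cloud import storage
from env import BUCKET_NAME, BUCKET_PATH

client = storage.Client()
bucket = client.bucket(BUCKET_NAME)

def check_if_blob_exists(save_path: str) -> bool:
    # lets the server skip regeneration when the image was already uploaded
    return bucket.blob(f"{BUCKET_PATH}/{save_path}").exists()

def upload_to_bucket(save_path: str, bio, is_bytesio: bool = True) -> str:
    # main.py passes raw webp bytes from BytesIO.getvalue()
    data = bio.getvalue() if hasattr(bio, "getvalue") else bio
    blob = bucket.blob(f"{BUCKET_PATH}/{save_path}")
    blob.upload_from_string(data, content_type="image/webp")
    return f"https://{BUCKET_NAME}/{BUCKET_PATH}/{save_path}"
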
img/img2img.py ADDED
@@ -0,0 +1,25 @@
+ import requests
+ import torch
+ from PIL import Image
+ from io import BytesIO
+
+ from diffusers import StableDiffusionImg2ImgPipeline
+
+ device = "cuda"
+ model_id_or_path = "runwayml/stable-diffusion-v1-5"
+ # model_id_or_path = "models/stable-diffusion-xl-base-0.9"
+ pipe = StableDiffusionImg2ImgPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16, variant="fp16", safety_checker=None)
+ pipe = pipe.to(device)
+
+ url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
+
+ response = requests.get(url)
+ # init_image = Image.open(BytesIO(response.content)).convert("RGB")
+ init_image = Image.open("/mnt/c/Users/leepenkman/Pictures/aiknight-neon-punk-fantasy-art-good-looking-trending-fantastic-1.webp").convert("RGB")
+ # init_image = init_image.resize((768, 512))
+ init_image = init_image.resize((1920, 1080))
+
+ prompt = "knight neon punk fantasy art good looking trending fantastic"
+
+ images = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images  # strength: 0.0 returns the init image unchanged, 1.0 ignores it
+ images[0].save("fantasy_landscape.png")
img/img2imgsd.py ADDED
@@ -0,0 +1,74 @@
+ from pathlib import Path
+
+ import numpy as np
+ import requests
+ import torch
+ from PIL import Image
+ from io import BytesIO
+
+ # from diffusers import StableDiffusionImg2ImgPipeline
+
+ # device = "cuda"
+ # model_id_or_path = "runwayml/stable-diffusion-v1-5"
+ # # model_id_or_path = "models/stable-diffusion-xl-base-0.9"
+ # pipe = StableDiffusionImg2ImgPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16, variant="fp16", safety_checker=None)
+ # pipe = pipe.to(device)
+
+ from diffusers import StableDiffusionXLImg2ImgPipeline
+ from diffusers.utils import load_image
+
+ from stable_diffusion_server.utils import log_time
+
+ pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
+     "stabilityai/stable-diffusion-xl-refiner-1.0",
+     # "models/stable-diffusion-xl-base-0.9",
+     torch_dtype=torch.float16,
+     use_safetensors=True,
+     variant="fp16",
+ )
+ pipe = pipe.to("cuda")  # fp16 needs the GPU: on CPU it fails with "LayerNormKernelImpl" not implemented for 'Half'
+ # idea: composite, then re-prompt img2img to support different sizes
+
+ # url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
+ #
+ # response = requests.get(url)
+ # init_image = Image.open(BytesIO(response.content)).convert("RGB")
+ # init_image = init_image.resize((768, 512))
+ # successfully inpaints a deleted area at strength=0.75
+ # init_image = Image.open("/mnt/c/Users/leepenkman/Pictures/aiart/ainostalgic-colorful-relaxing-chill-realistic-cartoon-Charcoal-illustration-fantasy-fauvist-abstract-impressionist-watercolor-painting-Background-location-scenery-amazing-wonderful-Dog-Shelter-Worker-Dog.webp").convert("RGB")
+ # redo something? strength 1
+ # init_image = Image.open("/home/lee/code/sdif/mask.png").convert("RGB")
+ init_image = Image.open("/mnt/c/Users/leepenkman/Pictures/dogstretch.png").convert("RGB")
+ # init_image = Image.open("/mnt/c/Users/leepenkman/Pictures/dogcenter.png").convert("RGB")
+
+ # init_image = init_image.resize((1080, 1920))
+ init_image = init_image.resize((1920, 1080))
+ # init_image = init_image.resize((1024, 1024))
+
+ prompt = "A fantasy landscape, trending on artstation, beautiful amazing unreal surreal gorgeous impressionism"
+ prompt = "mouth open nostalgic colorful relaxing chill realistic cartoon Charcoal illustration fantasy fauvist abstract impressionist watercolor painting Background location scenery amazing wonderful Dog Shelter Worker Dog"
+
+ # images = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images
+ # images[0].save("fantasy_landscape.png")
+ #
+ # # url = "https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/aa_xl/000000009.png"
+ #
+ # init_image = load_image(url).convert("RGB")
+ # prompt = "a photo of an astronaut riding a horse on mars"
+ study_dir = "images/study2"
+ Path(study_dir).mkdir(parents=True, exist_ok=True)
+
+ with log_time("img2img"):
+     with torch.inference_mode():
+         # range() does not accept float steps, so sweep strength with np.linspace instead
+         for strength in np.linspace(.1, 1, 10):
+             image = pipe(prompt=prompt, image=init_image, strength=strength, guidance_scale=7.6).images[0]
+             image.save(
+                 study_dir + "/fantasy_dogimgimgdogstretchopening" + str(strength) + "guidance_scale" + str(7.6) + ".png")
+         # # for guidance_scale in range(1, 10, .5):
+         # for guidance_scale in np.linspace(1, 100, 10):
+         #     image = pipe(prompt=prompt, image=init_image, strength=strength, guidance_scale=guidance_scale).images[0]
+         #     image.save("images/study/fantasy_dogimgimgdogstretch" + str(strength) + "guidance_scale" + str(guidance_scale) + ".png")
+ # image = pipe(prompt, image=init_image, strength=0.2, guidance_scale=7.5).images[0]
+ # image.save("images/fantasy_dogimgimgdogstretch.png")
+ # image.save("images/fantasy_dogimgimgdogcenter.png")
img/img2imgsdr.py ADDED
@@ -0,0 +1,53 @@
+ import PIL.Image
+
+ from diffusers import DiffusionPipeline
+ import torch
+
+ import numpy as np
+
+ from stable_diffusion_server.utils import log_time
+
+ pipe = DiffusionPipeline.from_pretrained(
+     "models/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+ )
+ pipe.to("cuda")
+
+ refiner = DiffusionPipeline.from_pretrained(
+     "stabilityai/stable-diffusion-xl-refiner-1.0",
+     text_encoder_2=pipe.text_encoder_2,
+     vae=pipe.vae,
+     torch_dtype=torch.float16,
+     use_safetensors=True,
+     variant="fp16",
+ )
+ refiner.to("cuda")
+
+ prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
+ use_refiner = True
+ with log_time('diffuse'):
+     with torch.inference_mode():
+         image = pipe(prompt=prompt, output_type="latent" if use_refiner else "pil").images[0]
+         # experiment: delete a whole bunch of pixels and see if the refiner can recreate them
+         # delete the top 30% of rows
+         # image = image[0:0.7]
+         # pixels to delete
+         # pixels_to_delete = int(0.3 * 1024)
+         # delete top 30% of pixels
+         # image.save("latent.png")
+         # image_data = PIL.Image.fromarray(image)
+         # image_data.save("latent.png")
+
+         # image = np.array(image)
+         pixels_to_delete = int(0.3 * image.shape[1])  # image is a latent tensor of shape (C, H, W)
+         idx_to_delete = torch.zeros(image.shape[1], dtype=torch.bool, device="cuda")
+         idx_to_delete[:pixels_to_delete] = True  # mark the top 30% of rows for deletion
+         image[:, idx_to_delete, :] = 0  # latents are not RGB, so zero the rows rather than assigning [0, 0, 0]
+
+         # image_data = PIL.Image.fromarray(image)
+         # image_data.save("latentcleared.png")
+
+
+         image = refiner(prompt=prompt, image=image[None, :]).images[0]
+
+
+
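
For comparison, the diffusers documentation's base-plus-refiner "ensemble of expert denoisers" splits the noise schedule explicitly instead of handing the refiner a mangled latent; a minimal sketch reusing the pipe and refiner constructed above (the 0.8 split point is the documented example value, an assumption rather than something from this commit):

n_steps = 40
high_noise_frac = 0.8
# base runs the first 80% of the schedule and emits a latent
latent = pipe(prompt=prompt, num_inference_steps=n_steps, denoising_end=high_noise_frac, output_type="latent").images
# refiner finishes the remaining 20%
image = refiner(prompt=prompt, num_inference_steps=n_steps, denoising_start=high_noise_frac, image=latent).images[0]
image.save("astronaut_refined.png")
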
img/inpaint.py ADDED
@@ -0,0 +1,62 @@
+ import torch
+
+ from diffusers import StableDiffusionXLInpaintPipeline
+ from diffusers.utils import load_image
+
+ from stable_diffusion_server.utils import log_time
+
+ import numpy as np
+ import PIL.Image
+
+ pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
+     "models/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+ )
+ pipe.to("cuda")
+
+ refiner = StableDiffusionXLInpaintPipeline.from_pretrained(
+     "stabilityai/stable-diffusion-xl-refiner-1.0",
+     text_encoder_2=pipe.text_encoder_2,
+     vae=pipe.vae,
+     torch_dtype=torch.float16,
+     use_safetensors=True,
+     variant="fp16",
+ )
+ refiner.to("cuda")
+
+ img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
+ mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+ # inpaint_and_upload_image?prompt=majestic tiger sitting on a bench&image_url=https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png&mask_url=https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png&save_path=tests/inpaint.webp
+ # inpainting can be used to upscale to 1080p
+
+
+ init_image = load_image(img_url).convert("RGB")
+ # mask_image = load_image(mask_url).convert("RGB")
+ # mask image: all ones, same (height, width) as init_image
+
+ # here's a failed experiment: inpainting cannot be used as style transfer; it doesn't recreate an image when the whole frame is masked this way
+ image_size = init_image.size  # PIL size is (width, height)
+ ones_of_size = np.ones((image_size[1], image_size[0]), np.uint8) * 255  # numpy shape is (rows, cols), so swap
+ mask_image = PIL.Image.fromarray(ones_of_size)
+ # mask_image = torch.ones_like(init_image) * 255
+ prompt = "A majestic tiger sitting on a bench, castle backdrop elegant anime"
+ num_inference_steps = 75
+ high_noise_frac = 0.7
+ with log_time("inpaint"):
+     with torch.inference_mode():
+         image = pipe(
+             prompt=prompt,
+             image=init_image,
+             mask_image=mask_image,
+             num_inference_steps=num_inference_steps,
+             denoising_end=high_noise_frac,  # the base stage stops 70% through the schedule
+             output_type="latent",
+         ).images
+         image = refiner(
+             prompt=prompt,
+             image=image,
+             mask_image=mask_image,
+             num_inference_steps=num_inference_steps,
+             denoising_start=high_noise_frac,  # the refiner picks up where the base stopped
+         ).images[0]
+
+ image.save("inpaintfull.png")
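
A note on the failed experiment above: inpainting expects the mask to mark only the region to regenerate (white = repaint, black = keep). A minimal sketch of a partial mask that repaints just the top band of the image, in the spirit of the "upscale to 1080p" comment; the 30% band is an illustrative choice, not from this commit:

band = np.zeros((image_size[1], image_size[0]), np.uint8)
band[:int(0.3 * image_size[1]), :] = 255  # white band marks the rows to repaint
partial_mask = PIL.Image.fromarray(band)
# then reuse the same two-stage pipe/refiner call above with mask_image=partial_mask
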
img/log.0925 ADDED
@@ -0,0 +1,53 @@
+ v-haipe+ 551 16041 99 08:16 pts/2 00:00:17 python LiLa/gsm8k_cluster.py
+ v-haipe+ 9211 10235 3 Sep24 pts/10 00:32:12 python LiLa/chatgpt_evol_lila_gsm8k_domain.py --start 0 --end 2000
+ v-haipe+ 9288 10459 3 Sep24 pts/11 00:28:30 python LiLa/chatgpt_evol_lila_gsm8k_domain.py --start 2000 --end 4000
+ v-haipe+ 9310 10667 3 Sep24 pts/12 00:27:45 python LiLa/chatgpt_evol_lila_gsm8k_domain.py --start 4000 --end 6000
+ v-haipe+ 9341 10865 3 Sep24 pts/13 00:26:50 python LiLa/chatgpt_evol_lila_gsm8k_domain.py --start 6000 --end 8000
+ v-haipe+ 9379 25248 3 Sep24 pts/16 00:27:01 python LiLa/chatgpt_evol_lila_gsm8k_domain.py --start 8000 --end 10000
+ v-haipe+ 9410 25467 3 Sep24 pts/17 00:27:17 python LiLa/chatgpt_evol_lila_gsm8k_domain.py --start 10000 --end 12000
+ v-haipe+ 9438 26561 3 Sep24 pts/19 00:27:17 python LiLa/chatgpt_evol_lila_gsm8k_domain.py --start 12000 --end 14000
+ v-haipe+ 9469 26761 3 Sep24 pts/20 00:26:55 python LiLa/chatgpt_evol_lila_gsm8k_domain.py --start 14000 --end 16000
+ v-haipe+ 9500 26968 3 Sep24 pts/21 00:27:09 python LiLa/chatgpt_evol_lila_gsm8k_domain.py --start 16000 --end 18000
+ v-haipe+ 9531 27172 3 Sep24 pts/22 00:29:29 python LiLa/chatgpt_evol_lila_gsm8k_domain.py --start 18000 --end 20000
+ v-haipe+ 9775 9560 3 Sep24 pts/29 00:30:29 python LiLa/chatgpt_evol_lila_gsm8k_domain.py --start 20000 --end 22000
+ v-haipe+ 11262 24577 0 Sep23 pts/8 00:00:06 python app.py
+ v-haipe+ 11300 11262 0 Sep23 pts/8 00:20:54 /home/v-haipengluo/.conda/envs/wizardweb/bin/python /workspaceblobstore/qins/test/20220316/kai/research/code_repo/wizard_verse/code_repo/server_code/wizard_verse/lm/server_lm/app.py
+ v-haipe+ 11604 20782 98 Sep23 pts/4 2-00:06:57 python -m vllm.entrypoints.api_server --model /workspaceblobstore/caxu/trained_models/13Bv2_497kcontinueroleplay_dsys_2048_e4_2e_5/checkpoint-75 --host phlrr3006.guest.corp.microsoft.com --port 7991
+ v-haipe+ 13722 22601 0 Sep24 pts/6 00:09:37 /home/v-haipengluo/.conda/envs/sdxl/bin/python /home/v-haipengluo/.conda/envs/sdxl/bin/uvicorn --host=phlrr3006.guest.corp.microsoft.com --port 7999 --workers 1 --backlog 1 --limit-concurrency 4 main_v3:app
+ v-haipe+ 13830 13722 0 Sep24 pts/6 00:00:05 /home/v-haipengluo/.conda/envs/sdxl/bin/python /home/v-haipengluo/.conda/envs/sdxl/bin/uvicorn --host=phlrr3006.guest.corp.microsoft.com --port 7999 --workers 1 --backlog 1 --limit-concurrency 4 main_v3:app
+ v-haipe+ 13834 13722 0 Sep24 pts/6 00:00:05 /home/v-haipengluo/.conda/envs/sdxl/bin/python /home/v-haipengluo/.conda/envs/sdxl/bin/uvicorn --host=phlrr3006.guest.corp.microsoft.com --port 7999 --workers 1 --backlog 1 --limit-concurrency 4 main_v3:app
+ v-haipe+ 13837 13722 0 Sep24 pts/6 00:00:05 /home/v-haipengluo/.conda/envs/sdxl/bin/python /home/v-haipengluo/.conda/envs/sdxl/bin/uvicorn --host=phlrr3006.guest.corp.microsoft.com --port 7999 --workers 1 --backlog 1 --limit-concurrency 4 main_v3:app
+ v-haipe+ 13839 13722 0 Sep24 pts/6 00:00:05 /home/v-haipengluo/.conda/envs/sdxl/bin/python /home/v-haipengluo/.conda/envs/sdxl/bin/uvicorn --host=phlrr3006.guest.corp.microsoft.com --port 7999 --workers 1 --backlog 1 --limit-concurrency 4 main_v3:app
+ v-haipe+ 13841 13722 0 Sep24 pts/6 00:00:05 /home/v-haipengluo/.conda/envs/sdxl/bin/python /home/v-haipengluo/.conda/envs/sdxl/bin/uvicorn --host=phlrr3006.guest.corp.microsoft.com --port 7999 --workers 1 --backlog 1 --limit-concurrency 4 main_v3:app
+ v-haipe+ 13843 13722 0 Sep24 pts/6 00:00:05 /home/v-haipengluo/.conda/envs/sdxl/bin/python /home/v-haipengluo/.conda/envs/sdxl/bin/uvicorn --host=phlrr3006.guest.corp.microsoft.com --port 7999 --workers 1 --backlog 1 --limit-concurrency 4 main_v3:app
+ v-haipe+ 13845 13722 0 Sep24 pts/6 00:00:05 /home/v-haipengluo/.conda/envs/sdxl/bin/python /home/v-haipengluo/.conda/envs/sdxl/bin/uvicorn --host=phlrr3006.guest.corp.microsoft.com --port 7999 --workers 1 --backlog 1 --limit-concurrency 4 main_v3:app
+ v-haipe+ 13847 13722 0 Sep24 pts/6 00:00:05 /home/v-haipengluo/.conda/envs/sdxl/bin/python /home/v-haipengluo/.conda/envs/sdxl/bin/uvicorn --host=phlrr3006.guest.corp.microsoft.com --port 7999 --workers 1 --backlog 1 --limit-concurrency 4 main_v3:app
+ v-haipe+ 13849 13722 0 Sep24 pts/6 00:00:05 /home/v-haipengluo/.conda/envs/sdxl/bin/python /home/v-haipengluo/.conda/envs/sdxl/bin/uvicorn --host=phlrr3006.guest.corp.microsoft.com --port 7999 --workers 1 --backlog 1 --limit-concurrency 4 main_v3:app
+ v-haipe+ 13851 13722 0 Sep24 pts/6 00:00:05 /home/v-haipengluo/.conda/envs/sdxl/bin/python /home/v-haipengluo/.conda/envs/sdxl/bin/uvicorn --host=phlrr3006.guest.corp.microsoft.com --port 7999 --workers 1 --backlog 1 --limit-concurrency 4 main_v3:app
+ v-haipe+ 13853 13722 0 Sep24 pts/6 00:00:05 /home/v-haipengluo/.conda/envs/sdxl/bin/python /home/v-haipengluo/.conda/envs/sdxl/bin/uvicorn --host=phlrr3006.guest.corp.microsoft.com --port 7999 --workers 1 --backlog 1 --limit-concurrency 4 main_v3:app
+ v-haipe+ 13855 13722 0 Sep24 pts/6 00:00:05 /home/v-haipengluo/.conda/envs/sdxl/bin/python /home/v-haipengluo/.conda/envs/sdxl/bin/uvicorn --host=phlrr3006.guest.corp.microsoft.com --port 7999 --workers 1 --backlog 1 --limit-concurrency 4 main_v3:app
+ v-haipe+ 13857 13722 0 Sep24 pts/6 00:00:05 /home/v-haipengluo/.conda/envs/sdxl/bin/python /home/v-haipengluo/.conda/envs/sdxl/bin/uvicorn --host=phlrr3006.guest.corp.microsoft.com --port 7999 --workers 1 --backlog 1 --limit-concurrency 4 main_v3:app
+ v-haipe+ 13859 13722 0 Sep24 pts/6 00:00:05 /home/v-haipengluo/.conda/envs/sdxl/bin/python /home/v-haipengluo/.conda/envs/sdxl/bin/uvicorn --host=phlrr3006.guest.corp.microsoft.com --port 7999 --workers 1 --backlog 1 --limit-concurrency 4 main_v3:app
+ v-haipe+ 13861 13722 0 Sep24 pts/6 00:00:05 /home/v-haipengluo/.conda/envs/sdxl/bin/python /home/v-haipengluo/.conda/envs/sdxl/bin/uvicorn --host=phlrr3006.guest.corp.microsoft.com --port 7999 --workers 1 --backlog 1 --limit-concurrency 4 main_v3:app
+ v-haipe+ 13863 13722 0 Sep24 pts/6 00:00:05 /home/v-haipengluo/.conda/envs/sdxl/bin/python /home/v-haipengluo/.conda/envs/sdxl/bin/uvicorn --host=phlrr3006.guest.corp.microsoft.com --port 7999 --workers 1 --backlog 1 --limit-concurrency 4 main_v3:app
+ v-haipe+ 13865 13722 0 Sep24 pts/6 00:00:05 /home/v-haipengluo/.conda/envs/sdxl/bin/python /home/v-haipengluo/.conda/envs/sdxl/bin/uvicorn --host=phlrr3006.guest.corp.microsoft.com --port 7999 --workers 1 --backlog 1 --limit-concurrency 4 main_v3:app
+ v-haipe+ 13867 13722 0 Sep24 pts/6 00:00:05 /home/v-haipengluo/.conda/envs/sdxl/bin/python /home/v-haipengluo/.conda/envs/sdxl/bin/uvicorn --host=phlrr3006.guest.corp.microsoft.com --port 7999 --workers 1 --backlog 1 --limit-concurrency 4 main_v3:app
+ v-haipe+ 13869 13722 0 Sep24 pts/6 00:00:05 /home/v-haipengluo/.conda/envs/sdxl/bin/python /home/v-haipengluo/.conda/envs/sdxl/bin/uvicorn --host=phlrr3006.guest.corp.microsoft.com --port 7999 --workers 1 --backlog 1 --limit-concurrency 4 main_v3:app
+ v-haipe+ 13871 13722 0 Sep24 pts/6 00:00:05 /home/v-haipengluo/.conda/envs/sdxl/bin/python /home/v-haipengluo/.conda/envs/sdxl/bin/uvicorn --host=phlrr3006.guest.corp.microsoft.com --port 7999 --workers 1 --backlog 1 --limit-concurrency 4 main_v3:app
+ v-haipe+ 13873 13722 0 Sep24 pts/6 00:00:05 /home/v-haipengluo/.conda/envs/sdxl/bin/python /home/v-haipengluo/.conda/envs/sdxl/bin/uvicorn --host=phlrr3006.guest.corp.microsoft.com --port 7999 --workers 1 --backlog 1 --limit-concurrency 4 main_v3:app
+ v-haipe+ 13875 13722 0 Sep24 pts/6 00:00:05 /home/v-haipengluo/.conda/envs/sdxl/bin/python /home/v-haipengluo/.conda/envs/sdxl/bin/uvicorn --host=phlrr3006.guest.corp.microsoft.com --port 7999 --workers 1 --backlog 1 --limit-concurrency 4 main_v3:app
+ v-haipe+ 13877 13722 0 Sep24 pts/6 00:00:05 /home/v-haipengluo/.conda/envs/sdxl/bin/python /home/v-haipengluo/.conda/envs/sdxl/bin/uvicorn --host=phlrr3006.guest.corp.microsoft.com --port 7999 --workers 1 --backlog 1 --limit-concurrency 4 main_v3:app
+ v-haipe+ 13879 13722 0 Sep24 pts/6 00:00:05 /home/v-haipengluo/.conda/envs/sdxl/bin/python /home/v-haipengluo/.conda/envs/sdxl/bin/uvicorn --host=phlrr3006.guest.corp.microsoft.com --port 7999 --workers 1 --backlog 1 --limit-concurrency 4 main_v3:app
+ v-haipe+ 13881 13722 0 Sep24 pts/6 00:00:05 /home/v-haipengluo/.conda/envs/sdxl/bin/python /home/v-haipengluo/.conda/envs/sdxl/bin/uvicorn --host=phlrr3006.guest.corp.microsoft.com --port 7999 --workers 1 --backlog 1 --limit-concurrency 4 main_v3:app
+ v-haipe+ 13883 13722 0 Sep24 pts/6 00:00:05 /home/v-haipengluo/.conda/envs/sdxl/bin/python /home/v-haipengluo/.conda/envs/sdxl/bin/uvicorn --host=phlrr3006.guest.corp.microsoft.com --port 7999 --workers 1 --backlog 1 --limit-concurrency 4 main_v3:app
+ v-haipe+ 13885 13722 0 Sep24 pts/6 00:00:05 /home/v-haipengluo/.conda/envs/sdxl/bin/python /home/v-haipengluo/.conda/envs/sdxl/bin/uvicorn --host=phlrr3006.guest.corp.microsoft.com --port 7999 --workers 1 --backlog 1 --limit-concurrency 4 main_v3:app
+ v-haipe+ 13887 13722 0 Sep24 pts/6 00:00:05 /home/v-haipengluo/.conda/envs/sdxl/bin/python /home/v-haipengluo/.conda/envs/sdxl/bin/uvicorn --host=phlrr3006.guest.corp.microsoft.com --port 7999 --workers 1 --backlog 1 --limit-concurrency 4 main_v3:app
+ v-haipe+ 18319 15852 0 05:34 pts/1 00:00:03 /home/v-haipengluo/.conda/envs/llamax/bin/python /home/v-haipengluo/.conda/envs/llamax/bin/deepspeed --master_port 29500 --hostfile=hostfile --include=localhost:1,3,4,5,6,7 src/train.py --model_name_or_path /workspaceblobstore/qins/test/20220316/haipeng/output_weights/llamax_13b_stackexchange_MATH_12w_sample_5w_score0.5_trainset_2e-5/checkpoint-992 --data_path /workspaceblobstore/qins/test/20220316/haipeng/data/Math_datasets/MATH_the_answer_is_format/hendrycks_math_7500_ori_gpt4_ori_15k.json --output_dir /workspaceblobstore/qins/test/20220316/haipeng/output_weights/llamax_13b_continue_train_stackMATH5w_checkpoint992_hendrycks_math_7500_ori_gpt4_ori_15k --num_train_epochs 3 --model_max_length 1150 --per_device_train_batch_size 17 --per_device_eval_batch_size 1 --gradient_accumulation_steps 1 --evaluation_strategy no --save_strategy steps --save_steps 36 --save_total_limit 200 --learning_rate 2e-5 --warmup_steps 10 --logging_steps 2 --lr_scheduler_type cosine --report_to tensorboard --gradient_checkpointing True --deepspeed src/configs/deepspeed_config.json --fp16 True
+ v-haipe+ 18333 18319 0 05:34 pts/1 00:00:03 /home/v-haipengluo/.conda/envs/llamax/bin/python -u -m deepspeed.launcher.launch --world_info=eyJsb2NhbGhvc3QiOiBbMSwgMywgNCwgNSwgNiwgN119 --master_addr=127.0.0.1 --master_port=29500 --enable_each_rank_log=None src/train.py --model_name_or_path /workspaceblobstore/qins/test/20220316/haipeng/output_weights/llamax_13b_stackexchange_MATH_12w_sample_5w_score0.5_trainset_2e-5/checkpoint-992 --data_path /workspaceblobstore/qins/test/20220316/haipeng/data/Math_datasets/MATH_the_answer_is_format/hendrycks_math_7500_ori_gpt4_ori_15k.json --output_dir /workspaceblobstore/qins/test/20220316/haipeng/output_weights/llamax_13b_continue_train_stackMATH5w_checkpoint992_hendrycks_math_7500_ori_gpt4_ori_15k --num_train_epochs 3 --model_max_length 1150 --per_device_train_batch_size 17 --per_device_eval_batch_size 1 --gradient_accumulation_steps 1 --evaluation_strategy no --save_strategy steps --save_steps 36 --save_total_limit 200 --learning_rate 2e-5 --warmup_steps 10 --logging_steps 2 --lr_scheduler_type cosine --report_to tensorboard --gradient_checkpointing True --deepspeed src/configs/deepspeed_config.json --fp16 True
+ v-haipe+ 18346 18333 99 05:34 pts/1 03:20:42 /home/v-haipengluo/.conda/envs/llamax/bin/python -u src/train.py --local_rank=0 --model_name_or_path /workspaceblobstore/qins/test/20220316/haipeng/output_weights/llamax_13b_stackexchange_MATH_12w_sample_5w_score0.5_trainset_2e-5/checkpoint-992 --data_path /workspaceblobstore/qins/test/20220316/haipeng/data/Math_datasets/MATH_the_answer_is_format/hendrycks_math_7500_ori_gpt4_ori_15k.json --output_dir /workspaceblobstore/qins/test/20220316/haipeng/output_weights/llamax_13b_continue_train_stackMATH5w_checkpoint992_hendrycks_math_7500_ori_gpt4_ori_15k --num_train_epochs 3 --model_max_length 1150 --per_device_train_batch_size 17 --per_device_eval_batch_size 1 --gradient_accumulation_steps 1 --evaluation_strategy no --save_strategy steps --save_steps 36 --save_total_limit 200 --learning_rate 2e-5 --warmup_steps 10 --logging_steps 2 --lr_scheduler_type cosine --report_to tensorboard --gradient_checkpointing True --deepspeed src/configs/deepspeed_config.json --fp16 True
+ v-haipe+ 18347 18333 99 05:34 pts/1 03:40:59 /home/v-haipengluo/.conda/envs/llamax/bin/python -u src/train.py --local_rank=1 --model_name_or_path /workspaceblobstore/qins/test/20220316/haipeng/output_weights/llamax_13b_stackexchange_MATH_12w_sample_5w_score0.5_trainset_2e-5/checkpoint-992 --data_path /workspaceblobstore/qins/test/20220316/haipeng/data/Math_datasets/MATH_the_answer_is_format/hendrycks_math_7500_ori_gpt4_ori_15k.json --output_dir /workspaceblobstore/qins/test/20220316/haipeng/output_weights/llamax_13b_continue_train_stackMATH5w_checkpoint992_hendrycks_math_7500_ori_gpt4_ori_15k --num_train_epochs 3 --model_max_length 1150 --per_device_train_batch_size 17 --per_device_eval_batch_size 1 --gradient_accumulation_steps 1 --evaluation_strategy no --save_strategy steps --save_steps 36 --save_total_limit 200 --learning_rate 2e-5 --warmup_steps 10 --logging_steps 2 --lr_scheduler_type cosine --report_to tensorboard --gradient_checkpointing True --deepspeed src/configs/deepspeed_config.json --fp16 True
+ v-haipe+ 18348 18333 99 05:34 pts/1 03:44:08 /home/v-haipengluo/.conda/envs/llamax/bin/python -u src/train.py --local_rank=2 --model_name_or_path /workspaceblobstore/qins/test/20220316/haipeng/output_weights/llamax_13b_stackexchange_MATH_12w_sample_5w_score0.5_trainset_2e-5/checkpoint-992 --data_path /workspaceblobstore/qins/test/20220316/haipeng/data/Math_datasets/MATH_the_answer_is_format/hendrycks_math_7500_ori_gpt4_ori_15k.json --output_dir /workspaceblobstore/qins/test/20220316/haipeng/output_weights/llamax_13b_continue_train_stackMATH5w_checkpoint992_hendrycks_math_7500_ori_gpt4_ori_15k --num_train_epochs 3 --model_max_length 1150 --per_device_train_batch_size 17 --per_device_eval_batch_size 1 --gradient_accumulation_steps 1 --evaluation_strategy no --save_strategy steps --save_steps 36 --save_total_limit 200 --learning_rate 2e-5 --warmup_steps 10 --logging_steps 2 --lr_scheduler_type cosine --report_to tensorboard --gradient_checkpointing True --deepspeed src/configs/deepspeed_config.json --fp16 True
+ v-haipe+ 18349 18333 99 05:34 pts/1 03:32:51 /home/v-haipengluo/.conda/envs/llamax/bin/python -u src/train.py --local_rank=3 --model_name_or_path /workspaceblobstore/qins/test/20220316/haipeng/output_weights/llamax_13b_stackexchange_MATH_12w_sample_5w_score0.5_trainset_2e-5/checkpoint-992 --data_path /workspaceblobstore/qins/test/20220316/haipeng/data/Math_datasets/MATH_the_answer_is_format/hendrycks_math_7500_ori_gpt4_ori_15k.json --output_dir /workspaceblobstore/qins/test/20220316/haipeng/output_weights/llamax_13b_continue_train_stackMATH5w_checkpoint992_hendrycks_math_7500_ori_gpt4_ori_15k --num_train_epochs 3 --model_max_length 1150 --per_device_train_batch_size 17 --per_device_eval_batch_size 1 --gradient_accumulation_steps 1 --evaluation_strategy no --save_strategy steps --save_steps 36 --save_total_limit 200 --learning_rate 2e-5 --warmup_steps 10 --logging_steps 2 --lr_scheduler_type cosine --report_to tensorboard --gradient_checkpointing True --deepspeed src/configs/deepspeed_config.json --fp16 True
+ v-haipe+ 18350 18333 99 05:34 pts/1 03:41:16 /home/v-haipengluo/.conda/envs/llamax/bin/python -u src/train.py --local_rank=4 --model_name_or_path /workspaceblobstore/qins/test/20220316/haipeng/output_weights/llamax_13b_stackexchange_MATH_12w_sample_5w_score0.5_trainset_2e-5/checkpoint-992 --data_path /workspaceblobstore/qins/test/20220316/haipeng/data/Math_datasets/MATH_the_answer_is_format/hendrycks_math_7500_ori_gpt4_ori_15k.json --output_dir /workspaceblobstore/qins/test/20220316/haipeng/output_weights/llamax_13b_continue_train_stackMATH5w_checkpoint992_hendrycks_math_7500_ori_gpt4_ori_15k --num_train_epochs 3 --model_max_length 1150 --per_device_train_batch_size 17 --per_device_eval_batch_size 1 --gradient_accumulation_steps 1 --evaluation_strategy no --save_strategy steps --save_steps 36 --save_total_limit 200 --learning_rate 2e-5 --warmup_steps 10 --logging_steps 2 --lr_scheduler_type cosine --report_to tensorboard --gradient_checkpointing True --deepspeed src/configs/deepspeed_config.json --fp16 True
+ v-haipe+ 18351 18333 99 05:34 pts/1 03:42:27 /home/v-haipengluo/.conda/envs/llamax/bin/python -u src/train.py --local_rank=5 --model_name_or_path /workspaceblobstore/qins/test/20220316/haipeng/output_weights/llamax_13b_stackexchange_MATH_12w_sample_5w_score0.5_trainset_2e-5/checkpoint-992 --data_path /workspaceblobstore/qins/test/20220316/haipeng/data/Math_datasets/MATH_the_answer_is_format/hendrycks_math_7500_ori_gpt4_ori_15k.json --output_dir /workspaceblobstore/qins/test/20220316/haipeng/output_weights/llamax_13b_continue_train_stackMATH5w_checkpoint992_hendrycks_math_7500_ori_gpt4_ori_15k --num_train_epochs 3 --model_max_length 1150 --per_device_train_batch_size 17 --per_device_eval_batch_size 1 --gradient_accumulation_steps 1 --evaluation_strategy no --save_strategy steps --save_steps 36 --save_total_limit 200 --learning_rate 2e-5 --warmup_steps 10 --logging_steps 2 --lr_scheduler_type cosine --report_to tensorboard --gradient_checkpointing True --deepspeed src/configs/deepspeed_config.json --fp16 True
+ v-haipe+ 24334 23818 0 Sep23 pts/7 00:00:25 python -m http.server
img/main.py ADDED
@@ -0,0 +1,528 @@
+ import gc
+ import math
+ import multiprocessing
+ import os
+ import traceback
+ from datetime import datetime
+ from io import BytesIO
+ from itertools import permutations
+ from multiprocessing.pool import Pool
+ from pathlib import Path
+ from urllib.parse import quote_plus
+
+ import numpy as np
+ import nltk
+ import torch
+
+ from PIL.Image import Image
+ from diffusers import DiffusionPipeline, StableDiffusionXLInpaintPipeline
+ from diffusers.utils import load_image
+ from fastapi import FastAPI
+ from fastapi.middleware.gzip import GZipMiddleware
+ from loguru import logger
+ from starlette.middleware.cors import CORSMiddleware
+ from starlette.responses import FileResponse
+ from starlette.responses import JSONResponse
+
+ from env import BUCKET_PATH, BUCKET_NAME
+ from stable_diffusion_server.bucket_api import check_if_blob_exists, upload_to_bucket  # used by the upload endpoints below
+ torch._dynamo.config.suppress_errors = True
+
+ import string
+ import random
+
+ def generate_save_path():
+     # initializing size of string
+     N = 7
+
+     # using random.choices()
+     # generating random strings
+     res = ''.join(random.choices(string.ascii_uppercase +
+                                  string.digits, k=N))
+     return res
+
+ pipe = DiffusionPipeline.from_pretrained(
+     "models/stable-diffusion-xl-base-1.0",
+     torch_dtype=torch.bfloat16,
+     use_safetensors=True,
+     variant="fp16",
+     # safety_checker=None,
+ )  # loads the fp16 variant weights as bfloat16
+ pipe.watermark = None
+
+ pipe.to("cuda")
+
+ refiner = DiffusionPipeline.from_pretrained(
+     "stabilityai/stable-diffusion-xl-refiner-1.0",
+     text_encoder_2=pipe.text_encoder_2,
+     vae=pipe.vae,
+     torch_dtype=torch.bfloat16,  # bfloat16's wider exponent range is safer against overflow than float16
+     use_safetensors=True,
+     variant="fp16",  # remember not to download the big fp32 weights
+ )
+ refiner.watermark = None
+ refiner.to("cuda")
+
+ # {'scheduler', 'text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'unet', 'vae'} can be passed in from an existing model
+ inpaintpipe = StableDiffusionXLInpaintPipeline.from_pretrained(
+     "models/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16, variant="fp16", use_safetensors=True,
+     scheduler=pipe.scheduler,
+     text_encoder=pipe.text_encoder,
+     text_encoder_2=pipe.text_encoder_2,
+     tokenizer=pipe.tokenizer,
+     tokenizer_2=pipe.tokenizer_2,
+     unet=pipe.unet,
+     vae=pipe.vae,
+     # load_connected_pipeline=
+ )
+ # # switch out to save gpu mem
+ # del inpaintpipe.vae
+ # del inpaintpipe.text_encoder_2
+ # del inpaintpipe.text_encoder
+ # del inpaintpipe.scheduler
+ # del inpaintpipe.tokenizer
+ # del inpaintpipe.tokenizer_2
+ # del inpaintpipe.unet
+ # inpaintpipe.vae = pipe.vae
+ # inpaintpipe.text_encoder_2 = pipe.text_encoder_2
+ # inpaintpipe.text_encoder = pipe.text_encoder
+ # inpaintpipe.scheduler = pipe.scheduler
+ # inpaintpipe.tokenizer = pipe.tokenizer
+ # inpaintpipe.tokenizer_2 = pipe.tokenizer_2
+ # inpaintpipe.unet = pipe.unet
+ # todo this should work
+ # inpaintpipe = StableDiffusionXLInpaintPipeline( # construct an inpainter using the existing model
+ #     vae=pipe.vae,
+ #     text_encoder_2=pipe.text_encoder_2,
+ #     text_encoder=pipe.text_encoder,
+ #     unet=pipe.unet,
+ #     scheduler=pipe.scheduler,
+ #     tokenizer=pipe.tokenizer,
+ #     tokenizer_2=pipe.tokenizer_2,
+ #     requires_aesthetics_score=False,
+ # )
+ inpaintpipe.to("cuda")
+ inpaintpipe.watermark = None
+ # inpaintpipe.register_to_config(requires_aesthetics_score=False)
+
+ inpaint_refiner = StableDiffusionXLInpaintPipeline.from_pretrained(
+     "stabilityai/stable-diffusion-xl-refiner-1.0",
+     text_encoder_2=inpaintpipe.text_encoder_2,
+     vae=inpaintpipe.vae,
+     torch_dtype=torch.bfloat16,
+     use_safetensors=True,
+     variant="fp16",
+
+     tokenizer_2=refiner.tokenizer_2,
+     tokenizer=refiner.tokenizer,
+     scheduler=refiner.scheduler,
+     text_encoder=refiner.text_encoder,
+     unet=refiner.unet,
+ )
+ # del inpaint_refiner.vae
+ # del inpaint_refiner.text_encoder_2
+ # del inpaint_refiner.text_encoder
+ # del inpaint_refiner.scheduler
+ # del inpaint_refiner.tokenizer
+ # del inpaint_refiner.tokenizer_2
+ # del inpaint_refiner.unet
+ # inpaint_refiner.vae = inpaintpipe.vae
+ # inpaint_refiner.text_encoder_2 = inpaintpipe.text_encoder_2
+ #
+ # inpaint_refiner.text_encoder = refiner.text_encoder
+ # inpaint_refiner.scheduler = refiner.scheduler
+ # inpaint_refiner.tokenizer = refiner.tokenizer
+ # inpaint_refiner.tokenizer_2 = refiner.tokenizer_2
+ # inpaint_refiner.unet = refiner.unet
+
+ # inpaint_refiner = StableDiffusionXLInpaintPipeline(
+ #     text_encoder_2=inpaintpipe.text_encoder_2,
+ #     vae=inpaintpipe.vae,
+ #     # the rest from the existing refiner
+ #     tokenizer_2=refiner.tokenizer_2,
+ #     tokenizer=refiner.tokenizer,
+ #     scheduler=refiner.scheduler,
+ #     text_encoder=refiner.text_encoder,
+ #     unet=refiner.unet,
+ #     requires_aesthetics_score=False,
+ # )
+ inpaint_refiner.to("cuda")
+ inpaint_refiner.watermark = None
+ # inpaint_refiner.register_to_config(requires_aesthetics_score=False)
+
+ n_steps = 40
+ high_noise_frac = 0.8
+
+ # if using torch < 2.0
+ # pipe.enable_xformers_memory_efficient_attention()
+
+
+ # pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+ # torch.compile can cause errors on some inputs, so consider disabling it
+ pipe.unet = torch.compile(pipe.unet)
+ refiner.unet = torch.compile(refiner.unet)  # , mode="reduce-overhead", fullgraph=True)
+ # reuse the compiled unets for the inpainters - todo swap the models out/del them so they share weights and can be swapped efficiently
+ inpaintpipe.unet = pipe.unet
+ inpaint_refiner.unet = refiner.unet
+ # inpaintpipe.unet = torch.compile(inpaintpipe.unet)
+ # inpaint_refiner.unet = torch.compile(inpaint_refiner.unet)
+ from pydantic import BaseModel
+
+ app = FastAPI(
+     openapi_url="/static/openapi.json",
+     docs_url="/swagger-docs",
+     redoc_url="/redoc",
+     title="Generate Images Netwrck API",
+     description="Character Chat API",
+     # root_path="https://api.text-generator.io",
+     version="1",
+ )
+ app.add_middleware(GZipMiddleware, minimum_size=1000)
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ stopwords = nltk.corpus.stopwords.words("english")  # requires the nltk stopwords corpus to be downloaded
+
+ class Img(BaseModel):
+     system_prompt: str
+     ASSISTANT: str
+
+ # img_url = "http://phlrr2019.guest.corp.microsoft.com:8000/img1_sdv2.1.png"
+ img_url = "http://phlrr3058.guest.corp.microsoft.com:8000/"  # /img1_sdv2.1.png
+
+ @app.post("/image_url")
+ def image_url(img: Img):
+     system_prompt = img.system_prompt
+     prompt = img.ASSISTANT
+     # if Path(save_path).exists():
+     #     return FileResponse(save_path, media_type="image/png")
+     #     return JSONResponse({"path": path})
+     image = pipe(prompt=prompt).images[0]
+     # if not save_path:
+     save_path = generate_save_path()
+     save_path = f"images/{save_path}.png"
+     image.save(save_path)
+     # save_path = '/'.join(path_components) + quote_plus(final_name)
+     path = f"{img_url}/{save_path}"
+     return JSONResponse({"path": path})
+
+
+ @app.get("/make_image")
+ # @app.post("/make_image")
+ def make_image(prompt: str, save_path: str = ""):
+     if Path(save_path).exists():
+         return FileResponse(save_path, media_type="image/png")
+     image = pipe(prompt=prompt).images[0]
+     if not save_path:
+         save_path = f"images/{prompt}.png"
+     image.save(save_path)
+     return FileResponse(save_path, media_type="image/png")
+
+
+ @app.get("/create_and_upload_image")
+ def create_and_upload_image(prompt: str, width: int = 1024, height: int = 1024, save_path: str = ""):
+     path_components = save_path.split("/")[0:-1]
+     final_name = save_path.split("/")[-1]
+     if not path_components:
+         path_components = []
+     save_path = '/'.join(path_components) + quote_plus(final_name)
+     path = get_image_or_create_upload_to_cloud_storage(prompt, width, height, save_path)
+     return JSONResponse({"path": path})
+
+ @app.get("/inpaint_and_upload_image")
+ def inpaint_and_upload_image(prompt: str, image_url: str, mask_url: str, save_path: str = ""):
+     path_components = save_path.split("/")[0:-1]
+     final_name = save_path.split("/")[-1]
+     if not path_components:
+         path_components = []
+     save_path = '/'.join(path_components) + quote_plus(final_name)
+     path = get_image_or_inpaint_upload_to_cloud_storage(prompt, image_url, mask_url, save_path)
+     return JSONResponse({"path": path})
+
+
+ def get_image_or_create_upload_to_cloud_storage(prompt: str, width: int, height: int, save_path: str):
+     prompt = shorten_too_long_text(prompt)
+     save_path = shorten_too_long_text(save_path)
+     # check exists - todo cache this
+     if check_if_blob_exists(save_path):
+         return f"https://{BUCKET_NAME}/{BUCKET_PATH}/{save_path}"
+     bio = create_image_from_prompt(prompt, width, height)
+     if bio is None:
+         return None  # error thrown in pool
+     link = upload_to_bucket(save_path, bio, is_bytesio=True)
+     return link
+ def get_image_or_inpaint_upload_to_cloud_storage(prompt: str, image_url: str, mask_url: str, save_path: str):
+     prompt = shorten_too_long_text(prompt)
+     save_path = shorten_too_long_text(save_path)
+     # check exists - todo cache this
+     if check_if_blob_exists(save_path):
+         return f"https://{BUCKET_NAME}/{BUCKET_PATH}/{save_path}"
+     bio = inpaint_image_from_prompt(prompt, image_url, mask_url)
+     if bio is None:
+         return None  # error thrown in pool
+     link = upload_to_bucket(save_path, bio, is_bytesio=True)
+     return link
+
+ # multiprocessing.set_start_method('spawn', True)
+ # processes_pool = Pool(1)  # can't do too much at once or OOM errors happen
+ # def create_image_from_prompt_sync(prompt):
+ #     """have to call this sync to avoid OOM errors"""
+ #     return processes_pool.apply_async(create_image_from_prompt, args=(prompt,), ).wait()
+
+ def create_image_from_prompt(prompt, width, height):
+     # round width and height down to a multiple of 64, e.g. 1000x720 becomes 960x704
+     block_width = width - (width % 64)
+     block_height = height - (height % 64)
+     prompt = shorten_too_long_text(prompt)
+     # image = pipe(prompt=prompt).images[0]
+     try:
+         image = pipe(prompt=prompt,
+                      width=block_width,
+                      height=block_height,
+                      # denoising_end=high_noise_frac,
+                      # output_type='latent',
+                      # height=512,
+                      # width=512,
+                      num_inference_steps=50).images[0]  # normally uses 50 steps
+     except Exception as e:
+         # try removing stopwords + halving the prompt
+         # todo try prompt permutations
+         logger.info(f"trying to shorten prompt of length {len(prompt)}")
+
+         prompt = ' '.join(word for word in prompt.split() if word not in stopwords)  # filter words, not characters
+         prompts = prompt.split()
+
+         prompt = ' '.join(prompts[:len(prompts) // 2])
+         logger.info(f"shortened prompt to: {len(prompt)}")
+         image = None
+         if prompt:
+             try:
+                 image = pipe(prompt=prompt,
+                              width=block_width,
+                              height=block_height,
+                              # denoising_end=high_noise_frac,
+                              # output_type='latent',
+                              # height=512,
+                              # width=512,
+                              num_inference_steps=50).images[0]  # normally uses 50 steps
+             except Exception as e:
+                 # logger.info("trying to permute prompt")
+                 # # try two swaps of the prompt/permutations
+                 # prompt = prompt.split()
+                 # prompt = ' '.join(permutations(prompt, 2).__next__())
+                 logger.info(f"trying to shorten prompt of length {len(prompt)}")
+
+                 prompt = ' '.join(word for word in prompt.split() if word not in stopwords)  # filter words, not characters
+                 prompts = prompt.split()
+
+                 prompt = ' '.join(prompts[:len(prompts) // 2])
+                 logger.info(f"shortened prompt to: {len(prompt)}")
+
+                 try:
+                     image = pipe(prompt=prompt,
+                                  width=block_width,
+                                  height=block_height,
+                                  # denoising_end=high_noise_frac,
+                                  # output_type='latent', # don't need latent yet - we refine the image at full res
+                                  # height=512,
+                                  # width=512,
+                                  num_inference_steps=50).images[0]  # normally uses 50 steps
+                 except Exception as e:
+                     # just error out
+                     traceback.print_exc()
+                     raise e
+                     # logger.info("restarting server to fix cuda issues (device side asserts)")
+                     # todo fix device side asserts instead of restart to fix
+                     # todo only restart the correct gunicorn
+                     # this could be really annoying if you're running other gunicorns on your machine which also get restarted
+                     # os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
+                     # os.system("kill -1 `pgrep gunicorn`")
+     # todo refine
+     # if image != None:
+     #     image = refiner(
+     #         prompt=prompt,
+     #         # width=block_width,
+     #         # height=block_height,
+     #         num_inference_steps=n_steps,
+     #         # denoising_start=high_noise_frac,
+     #         image=image,
+     #     ).images[0]
+     if width != block_width or height != block_height:
+         # resize to original size width/height
+         # find aspect ratio to scale up to that covers the original img input width/height
+         scale_up_ratio = max(width / block_width, height / block_height)
+         image = image.resize((math.ceil(block_width * scale_up_ratio), math.ceil(block_height * scale_up_ratio)))
+         # crop image to original size
+         image = image.crop((0, 0, width, height))
+     # try:
+     #     # gc.collect()
+     #     torch.cuda.empty_cache()
+     # except Exception as e:
+     #     traceback.print_exc()
+     #     logger.info("restarting server to fix cuda issues (device side asserts)")
+     #     # todo fix device side asserts instead of restart to fix
+     #     # todo only restart the correct gunicorn
+     #     # this could be really annoying if you're running other gunicorns on your machine which also get restarted
+     #     os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
+     #     os.system("kill -1 `pgrep gunicorn`")
+     # save as bytesio
+     bs = BytesIO()
+
+     bright_count = np.sum(np.array(image) > 0)
+     if bright_count == 0:
+         # we have a black image; this is an error and likely needs a restart
+         logger.info("restarting server to fix cuda issues (device side asserts)")
+         # todo fix device side asserts instead of restarting
+         # todo only restart the correct gunicorn
+         # this could be really annoying if you're running other gunicorns on your machine which also get restarted
+         os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
+         os.system("kill -1 `pgrep gunicorn`")
+         os.system("/usr/bin/bash kill -SIGHUP `pgrep uvicorn`")
+         os.system("kill -1 `pgrep uvicorn`")
+
+         return None
+     image.save(bs, quality=85, optimize=True, format="webp")
+     bio = bs.getvalue()
+     # touch progress.txt - if we don't, supervisor/other watchdog processes restart us for reliability
+     with open("progress.txt", "w") as f:
+         current_time = datetime.now().strftime("%H:%M:%S")
+         f.write(f"{current_time}")
+     return bio
+
+ def inpaint_image_from_prompt(prompt, image_url: str, mask_url: str):
+     prompt = shorten_too_long_text(prompt)
+     # image = pipe(prompt=prompt).images[0]
+
+     init_image = load_image(image_url).convert("RGB")
+     mask_image = load_image(mask_url).convert("RGB")  # an RGB mask works; diffusers converts it to a single channel internally
+     num_inference_steps = 75
+     high_noise_frac = 0.7
+
+     try:
+         image = inpaintpipe(
+             prompt=prompt,
+             image=init_image,
+             mask_image=mask_image,
+             num_inference_steps=num_inference_steps,
+             denoising_end=high_noise_frac,  # the base stage stops here and hands a latent to the refiner
+             output_type="latent",
+         ).images[0]  # normally uses 50 steps
+     except Exception as e:
+         # try removing stopwords + halving the prompt
+         # todo try prompt permutations
+         logger.info(f"trying to shorten prompt of length {len(prompt)}")
+
+         prompt = ' '.join(word for word in prompt.split() if word not in stopwords)  # filter words, not characters
+         prompts = prompt.split()
+
+         prompt = ' '.join(prompts[:len(prompts) // 2])
+         logger.info(f"shortened prompt to: {len(prompt)}")
+         image = None
+         if prompt:
+             try:
+                 image = inpaintpipe(
+                     prompt=prompt,
+                     image=init_image,
+                     mask_image=mask_image,
+                     num_inference_steps=num_inference_steps,
+                     denoising_end=high_noise_frac,
+                     output_type="latent",
+                 ).images[0]  # normally uses 50 steps
+             except Exception as e:
+                 # logger.info("trying to permute prompt")
+                 # # try two swaps of the prompt/permutations
+                 # prompt = prompt.split()
+                 # prompt = ' '.join(permutations(prompt, 2).__next__())
+                 logger.info(f"trying to shorten prompt of length {len(prompt)}")
+
+                 prompt = ' '.join(word for word in prompt.split() if word not in stopwords)  # filter words, not characters
+                 prompts = prompt.split()
+
+                 prompt = ' '.join(prompts[:len(prompts) // 2])
+                 logger.info(f"shortened prompt to: {len(prompt)}")
+
+                 try:
+                     image = inpaintpipe(
+                         prompt=prompt,
+                         image=init_image,
+                         mask_image=mask_image,
+                         num_inference_steps=num_inference_steps,
+                         denoising_end=high_noise_frac,
+                         output_type="latent",
+                     ).images[0]  # normally uses 50 steps
+                 except Exception as e:
+                     # just error out
+                     traceback.print_exc()
+                     raise e
+                     # logger.info("restarting server to fix cuda issues (device side asserts)")
+                     # todo fix device side asserts instead of restart to fix
+                     # todo only restart the correct gunicorn
+                     # this could be really annoying if you're running other gunicorns on your machine which also get restarted
+                     # os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
+                     # os.system("kill -1 `pgrep gunicorn`")
+     if image is not None:
+         image = inpaint_refiner(
+             prompt=prompt,
+             image=image,
+             mask_image=mask_image,
+             num_inference_steps=num_inference_steps,
+             denoising_start=high_noise_frac,
+
+         ).images[0]
+     # try:
+     #     # gc.collect()
+     #     torch.cuda.empty_cache()
+     # except Exception as e:
+     #     traceback.print_exc()
+     #     logger.info("restarting server to fix cuda issues (device side asserts)")
+     #     # todo fix device side asserts instead of restart to fix
+     #     # todo only restart the correct gunicorn
+     #     # this could be really annoying if you're running other gunicorns on your machine which also get restarted
+     #     os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
+     #     os.system("kill -1 `pgrep gunicorn`")
+     # save as bytesio
+     bs = BytesIO()
+
+     bright_count = np.sum(np.array(image) > 0)
+     if bright_count == 0:
+         # we have a black image; this is an error and likely needs a restart
+         logger.info("restarting server to fix cuda issues (device side asserts)")
+         # todo fix device side asserts instead of restarting
+         # todo only restart the correct gunicorn
+         # this could be really annoying if you're running other gunicorns on your machine which also get restarted
+         os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
+         os.system("kill -1 `pgrep gunicorn`")
+         os.system("/usr/bin/bash kill -SIGHUP `pgrep uvicorn`")
+         os.system("kill -1 `pgrep uvicorn`")
+
+         return None
+     image.save(bs, quality=85, optimize=True, format="webp")
+     bio = bs.getvalue()
+     # touch progress.txt - if we don't, supervisor/other watchdog processes restart us for reliability
+     with open("progress.txt", "w") as f:
+         current_time = datetime.now().strftime("%H:%M:%S")
+         f.write(f"{current_time}")
+     return bio
+
+
+
+ def shorten_too_long_text(prompt):
+     if len(prompt) > 200:
+         # remove stopwords
+         prompt = prompt.split()  # todo also split hyphens
+         prompt = ' '.join(word for word in prompt if word not in stopwords)
+         if len(prompt) > 200:
+             prompt = prompt[:200]
+     return prompt
+
+ # image = pipe(prompt=prompt).images[0]
+ #
+ # image.save("test.png")
+ # # save all images
+ # for i, image in enumerate(images):
+ #     image.save(f"{i}.png")
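
The server above exposes GET endpoints that return a JSON path to the uploaded image. A minimal client sketch against a local instance; the host and port are illustrative assumptions, and save_path behaves as in create_and_upload_image above:

import requests

resp = requests.get(
    "http://localhost:8000/create_and_upload_image",  # assumed host/port, not from this commit
    params={"prompt": "a knight in neon punk style", "width": 1024, "height": 1024, "save_path": "tests/knight.webp"},
)
print(resp.json()["path"])  # URL of the uploaded webp, or null if generation failed
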
img/main_1024.py ADDED
@@ -0,0 +1,549 @@
+ import gc
+ import math
+ import multiprocessing
+ import os
+ import traceback
+ from datetime import datetime
+ from io import BytesIO
+ from itertools import permutations
+ from multiprocessing.pool import Pool
+ from pathlib import Path
+ from urllib.parse import quote_plus
+
+ import numpy as np
+ import nltk
+ import torch
+
+ from PIL.Image import Image
+ from diffusers import DiffusionPipeline, StableDiffusionXLInpaintPipeline
+ from diffusers.utils import load_image
+ from fastapi import FastAPI
+ from fastapi.middleware.gzip import GZipMiddleware
+ from loguru import logger
+ from starlette.middleware.cors import CORSMiddleware
+ from starlette.responses import FileResponse
+ from starlette.responses import JSONResponse
+
+ from env import BUCKET_PATH, BUCKET_NAME
+ # from stable_diffusion_server.bucket_api import check_if_blob_exists, upload_to_bucket
+ # note: check_if_blob_exists/upload_to_bucket are called by the upload endpoints below,
+ # so this import must be re-enabled for those endpoints to work
+ torch._dynamo.config.suppress_errors = True
+
+ import string
+ import random
+
+
+ def generate_save_path():
+     # generate a random 7-character alphanumeric name
+     N = 7
+     res = ''.join(random.choices(string.ascii_uppercase + string.digits, k=N))
+     return res
+
+
+ # pipe = DiffusionPipeline.from_pretrained(
+ #     "models/stable-diffusion-xl-base-1.0",
+ #     torch_dtype=torch.bfloat16,
+ #     use_safetensors=True,
+ #     variant="fp16",
+ #     # safety_checker=None,
+ # )  # todo try torch_dtype=bfloat16
+
+ model_dir = os.getenv("SDXL_MODEL_DIR")
+
+ if model_dir:
+     # use a local model
+     model_key_base = os.path.join(model_dir, "stable-diffusion-xl-base-1.0")
+     model_key_refiner = os.path.join(model_dir, "stable-diffusion-xl-refiner-1.0")
+ else:
+     model_key_base = "stabilityai/stable-diffusion-xl-base-1.0"
+     model_key_refiner = "stabilityai/stable-diffusion-xl-refiner-1.0"
+
+ pipe = DiffusionPipeline.from_pretrained(model_key_base, torch_dtype=torch.float16, use_safetensors=True, variant="fp16")
+ pipe.watermark = None
+ pipe.to("cuda")
+
+ refiner = DiffusionPipeline.from_pretrained(
+     model_key_refiner,  # was a hard-coded hub id; use the configured key so SDXL_MODEL_DIR is honoured
+     text_encoder_2=pipe.text_encoder_2,
+     vae=pipe.vae,
+     torch_dtype=torch.bfloat16,  # safer to use bfloat?
+     use_safetensors=True,
+     variant="fp16",  # remember not to download the big model
+ )
+ refiner.watermark = None
+ refiner.to("cuda")
+
+ # {'scheduler', 'text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'unet', 'vae'} can be passed in from an existing model
+ inpaintpipe = StableDiffusionXLInpaintPipeline.from_pretrained(
+     model_key_base,  # was hard-coded to "models/stable-diffusion-xl-base-1.0"
+     torch_dtype=torch.bfloat16, variant="fp16", use_safetensors=True,
+     scheduler=pipe.scheduler,
+     text_encoder=pipe.text_encoder,
+     text_encoder_2=pipe.text_encoder_2,
+     tokenizer=pipe.tokenizer,
+     tokenizer_2=pipe.tokenizer_2,
+     unet=pipe.unet,
+     vae=pipe.vae,
+     # load_connected_pipeline=
+ )
+ # # switch out to save gpu mem
+ # del inpaintpipe.vae
+ # del inpaintpipe.text_encoder_2
+ # del inpaintpipe.text_encoder
+ # del inpaintpipe.scheduler
+ # del inpaintpipe.tokenizer
+ # del inpaintpipe.tokenizer_2
+ # del inpaintpipe.unet
+ # inpaintpipe.vae = pipe.vae
+ # inpaintpipe.text_encoder_2 = pipe.text_encoder_2
+ # inpaintpipe.text_encoder = pipe.text_encoder
+ # inpaintpipe.scheduler = pipe.scheduler
+ # inpaintpipe.tokenizer = pipe.tokenizer
+ # inpaintpipe.tokenizer_2 = pipe.tokenizer_2
+ # inpaintpipe.unet = pipe.unet
+ # todo this should work
+ # inpaintpipe = StableDiffusionXLInpaintPipeline(  # construct an inpainter using the existing model
+ #     vae=pipe.vae,
+ #     text_encoder_2=pipe.text_encoder_2,
+ #     text_encoder=pipe.text_encoder,
+ #     unet=pipe.unet,
+ #     scheduler=pipe.scheduler,
+ #     tokenizer=pipe.tokenizer,
+ #     tokenizer_2=pipe.tokenizer_2,
+ #     requires_aesthetics_score=False,
+ # )
+ inpaintpipe.to("cuda")
+ inpaintpipe.watermark = None
+ # inpaintpipe.register_to_config(requires_aesthetics_score=False)
+
+ inpaint_refiner = StableDiffusionXLInpaintPipeline.from_pretrained(
+     model_key_refiner,
+     text_encoder_2=inpaintpipe.text_encoder_2,
+     vae=inpaintpipe.vae,
+     torch_dtype=torch.bfloat16,
+     use_safetensors=True,
+     variant="fp16",
+     tokenizer_2=refiner.tokenizer_2,
+     tokenizer=refiner.tokenizer,
+     scheduler=refiner.scheduler,
+     text_encoder=refiner.text_encoder,
+     unet=refiner.unet,
+ )
+ # del inpaint_refiner.vae
+ # del inpaint_refiner.text_encoder_2
+ # del inpaint_refiner.text_encoder
+ # del inpaint_refiner.scheduler
+ # del inpaint_refiner.tokenizer
+ # del inpaint_refiner.tokenizer_2
+ # del inpaint_refiner.unet
+ # inpaint_refiner.vae = inpaintpipe.vae
+ # inpaint_refiner.text_encoder_2 = inpaintpipe.text_encoder_2
+ #
+ # inpaint_refiner.text_encoder = refiner.text_encoder
+ # inpaint_refiner.scheduler = refiner.scheduler
+ # inpaint_refiner.tokenizer = refiner.tokenizer
+ # inpaint_refiner.tokenizer_2 = refiner.tokenizer_2
+ # inpaint_refiner.unet = refiner.unet
+
+ # inpaint_refiner = StableDiffusionXLInpaintPipeline(
+ #     text_encoder_2=inpaintpipe.text_encoder_2,
+ #     vae=inpaintpipe.vae,
+ #     # the rest from the existing refiner
+ #     tokenizer_2=refiner.tokenizer_2,
+ #     tokenizer=refiner.tokenizer,
+ #     scheduler=refiner.scheduler,
+ #     text_encoder=refiner.text_encoder,
+ #     unet=refiner.unet,
+ #     requires_aesthetics_score=False,
+ # )
+ inpaint_refiner.to("cuda")
+ inpaint_refiner.watermark = None
+ # inpaint_refiner.register_to_config(requires_aesthetics_score=False)
+
+ n_steps = 40
+ high_noise_frac = 0.8
+
+ # if using torch < 2.0
+ # pipe.enable_xformers_memory_efficient_attention()
+
+ # pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+ # this can cause errors on some inputs so consider disabling it
+ pipe.unet = torch.compile(pipe.unet)
+ refiner.unet = torch.compile(refiner.unet)  # , mode="reduce-overhead", fullgraph=True)
+ # share the compiled unets with the inpainters - todo swap models in and out/del them so they share weights and can be switched efficiently
+ inpaintpipe.unet = pipe.unet
+ inpaint_refiner.unet = refiner.unet
+ # inpaintpipe.unet = torch.compile(inpaintpipe.unet)
+ # inpaint_refiner.unet = torch.compile(inpaint_refiner.unet)
+ from pydantic import BaseModel
+
+ app = FastAPI(
+     openapi_url="/static/openapi.json",
+     docs_url="/swagger-docs",
+     redoc_url="/redoc",
+     title="Generate Images Netwrck API",
+     description="Character Chat API",
+     # root_path="https://api.text-generator.io",
+     version="1",
+ )
+ app.add_middleware(GZipMiddleware, minimum_size=1000)
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ stopwords = nltk.corpus.stopwords.words("english")  # requires nltk.download("stopwords") on first run
+
+
+ class Img(BaseModel):
+     system_prompt: str
+     ASSISTANT: str
+
+
+ # img_url = "http://phlrr2019.guest.corp.microsoft.com:8000/img1_sdv2.1.png"
+ img_url = "http://phlrr3058.guest.corp.microsoft.com:8000/"  # /img1_sdv2.1.png"
+
+ is_gpu_busy = False
+
+
+ @app.post("/image_url")
+ def image_url(img: Img):
+     system_prompt = img.system_prompt
+     prompt = img.ASSISTANT
+     # if Path(save_path).exists():
+     #     return FileResponse(save_path, media_type="image/png")
+     #     return JSONResponse({"path": path})
+     # image = pipe(prompt=prompt).images[0]
+     g = torch.Generator(device="cuda")
+     # image = pipe(prompt=prompt, width=1024, height=1024, generator=g).images[0]
+     image = pipe(prompt=prompt, width=1024, height=1024).images[0]
+
+     # if not save_path:
+     save_path = generate_save_path()
+     save_path = f"images/{save_path}.png"
+     image.save(save_path)
+     # save_path = '/'.join(path_components) + quote_plus(final_name)
+     path = f"{img_url}/{save_path}"
+     return JSONResponse({"path": path})
+
+
+ @app.get("/make_image")
+ # @app.post("/make_image")
+ def make_image(prompt: str, save_path: str = ""):
+     if Path(save_path).exists():
+         return FileResponse(save_path, media_type="image/png")
+     image = pipe(prompt=prompt).images[0]
+     if not save_path:
+         save_path = f"images/{prompt}.png"
+     image.save(save_path)
+     return FileResponse(save_path, media_type="image/png")
+
+
+ @app.get("/create_and_upload_image")
+ def create_and_upload_image(prompt: str, width: int = 1024, height: int = 1024, save_path: str = ""):
+     path_components = save_path.split("/")[0:-1]
+     final_name = save_path.split("/")[-1]
+     if not path_components:
+         path_components = []
+     save_path = '/'.join(path_components + [quote_plus(final_name)])  # join with a separator so the directory and file name don't run together
+     path = get_image_or_create_upload_to_cloud_storage(prompt, width, height, save_path)
+     return JSONResponse({"path": path})
+
+
+ @app.get("/inpaint_and_upload_image")
+ def inpaint_and_upload_image(prompt: str, image_url: str, mask_url: str, save_path: str = ""):
+     path_components = save_path.split("/")[0:-1]
+     final_name = save_path.split("/")[-1]
+     if not path_components:
+         path_components = []
+     save_path = '/'.join(path_components + [quote_plus(final_name)])
+     path = get_image_or_inpaint_upload_to_cloud_storage(prompt, image_url, mask_url, save_path)
+     return JSONResponse({"path": path})
+
+
+ def get_image_or_create_upload_to_cloud_storage(prompt: str, width: int, height: int, save_path: str):
+     prompt = shorten_too_long_text(prompt)
+     save_path = shorten_too_long_text(save_path)
+     # check exists - todo cache this
+     if check_if_blob_exists(save_path):
+         return f"https://{BUCKET_NAME}/{BUCKET_PATH}/{save_path}"
+     bio = create_image_from_prompt(prompt, width, height)
+     if bio is None:
+         return None  # error thrown in pool
+     link = upload_to_bucket(save_path, bio, is_bytesio=True)
+     return link
+
+
+ def get_image_or_inpaint_upload_to_cloud_storage(prompt: str, image_url: str, mask_url: str, save_path: str):
+     prompt = shorten_too_long_text(prompt)
+     save_path = shorten_too_long_text(save_path)
+     # check exists - todo cache this
+     if check_if_blob_exists(save_path):
+         return f"https://{BUCKET_NAME}/{BUCKET_PATH}/{save_path}"
+     bio = inpaint_image_from_prompt(prompt, image_url, mask_url)
+     if bio is None:
+         return None  # error thrown in pool
+     link = upload_to_bucket(save_path, bio, is_bytesio=True)
+     return link
+
+
+ # multiprocessing.set_start_method('spawn', True)
+ # processes_pool = Pool(1)  # can't do too much at once or OOM errors happen
+ # def create_image_from_prompt_sync(prompt):
+ #     """have to call this sync to avoid OOM errors"""
+ #     return processes_pool.apply_async(create_image_from_prompt, args=(prompt,), ).wait()
+
+
+ def create_image_from_prompt(prompt, width, height):
+     # round width and height down to a multiple of 64
+     block_width = width - (width % 64)
+     block_height = height - (height % 64)
+     prompt = shorten_too_long_text(prompt)
+     # image = pipe(prompt=prompt).images[0]
+     try:
+         image = pipe(prompt=prompt,
+                      width=block_width,
+                      height=block_height,
+                      # denoising_end=high_noise_frac,
+                      # output_type='latent',
+                      # height=512,
+                      # width=512,
+                      num_inference_steps=50).images[0]  # normally uses 50 steps
+     except Exception as e:
+         # try removing stopwords + halving the prompt
+         # todo try prompt permutations
+         logger.info(f"trying to shorten prompt of length {len(prompt)}")
+
+         prompt = ' '.join(word for word in prompt.split() if word not in stopwords)  # split first; iterating the string would yield characters
+         prompts = prompt.split()
+         prompt = ' '.join(prompts[:len(prompts) // 2])
+         logger.info(f"shortened prompt to: {len(prompt)}")
+         image = None
+         if prompt:
+             try:
+                 image = pipe(prompt=prompt,
+                              width=block_width,
+                              height=block_height,
+                              # denoising_end=high_noise_frac,
+                              # output_type='latent',
+                              num_inference_steps=50).images[0]
+             except Exception as e:
+                 # logger.info("trying to permute prompt")
+                 # # try two swaps of the prompt/permutations
+                 # prompt = prompt.split()
+                 # prompt = ' '.join(permutations(prompt, 2).__next__())
+                 logger.info(f"trying to shorten prompt of length {len(prompt)}")
+
+                 prompt = ' '.join(word for word in prompt.split() if word not in stopwords)
+                 prompts = prompt.split()
+                 prompt = ' '.join(prompts[:len(prompts) // 2])
+                 logger.info(f"shortened prompt to: {len(prompt)}")
+
+                 try:
+                     image = pipe(prompt=prompt,
+                                  width=block_width,
+                                  height=block_height,
+                                  # output_type='latent',  # don't need latent yet - we refine the image at full res
+                                  num_inference_steps=50).images[0]
+                 except Exception as e:
+                     # just error out
+                     traceback.print_exc()
+                     raise e
+                     # logger.info("restarting server to fix cuda issues (device side asserts)")
+                     # todo fix device side asserts instead of restarting
+                     # todo only restart the correct gunicorn
+                     # this could be really annoying if you're running other gunicorns on your machine which also get restarted
+                     # os.system("kill -SIGHUP `pgrep gunicorn`")
+                     # os.system("kill -1 `pgrep gunicorn`")
+     if image is None:
+         return None  # nothing could be generated even after shortening the prompt
+     # todo refine
+     # if image is not None:
+     #     image = refiner(
+     #         prompt=prompt,
+     #         # width=block_width,
+     #         # height=block_height,
+     #         num_inference_steps=n_steps,
+     #         # denoising_start=high_noise_frac,
+     #         image=image,
+     #     ).images[0]
+     if width != block_width or height != block_height:
+         # resize to the requested size: scale up so the render covers the original width/height
+         scale_up_ratio = max(width / block_width, height / block_height)
+         image = image.resize((math.ceil(block_width * scale_up_ratio), math.ceil(block_height * scale_up_ratio)))  # was height * ratio, mixing requested and block sizes
+         # crop image to original size
+         image = image.crop((0, 0, width, height))
+     # try:
+     #     # gc.collect()
+     #     torch.cuda.empty_cache()
+     # except Exception as e:
+     #     traceback.print_exc()
+     #     logger.info("restarting server to fix cuda issues (device side asserts)")
+     #     # todo fix device side asserts instead of restarting
+     #     # todo only restart the correct gunicorn
+     #     # this could be really annoying if you're running other gunicorns on your machine which also get restarted
+     #     os.system("kill -SIGHUP `pgrep gunicorn`")
+     #     os.system("kill -1 `pgrep gunicorn`")
+     # save as bytesio
+     bs = BytesIO()
+
+     bright_count = np.sum(np.array(image) > 0)
+     if bright_count == 0:
+         # we have a black image; this is an error and likely needs a restart
+         logger.info("restarting server to fix cuda issues (device side asserts)")
+         # todo fix device side asserts instead of restarting
+         # todo only restart the correct gunicorn
+         # this could be really annoying if you're running other gunicorns on your machine which also get restarted
+         os.system("kill -SIGHUP `pgrep gunicorn`")  # was prefixed with /usr/bin/bash, which made the command a no-op
+         os.system("kill -1 `pgrep gunicorn`")
+         os.system("kill -SIGHUP `pgrep uvicorn`")
+         os.system("kill -1 `pgrep uvicorn`")
+         return None
+     image.save(bs, quality=85, optimize=True, format="webp")
+     bio = bs.getvalue()
+     # touch progress.txt - if we don't, we get restarted by supervisor/other processes for reliability
+     with open("progress.txt", "w") as f:
+         current_time = datetime.now().strftime("%H:%M:%S")
+         f.write(f"{current_time}")
+     return bio
+
+
+ def inpaint_image_from_prompt(prompt, image_url: str, mask_url: str):
+     prompt = shorten_too_long_text(prompt)
+     # image = pipe(prompt=prompt).images[0]
+
+     init_image = load_image(image_url).convert("RGB")
+     mask_image = load_image(mask_url).convert("RGB")  # why rgb for a 1 channel mask?
+     num_inference_steps = 75
+     high_noise_frac = 0.7
+
+     try:
+         image = inpaintpipe(
+             prompt=prompt,
+             image=init_image,
+             mask_image=mask_image,
+             num_inference_steps=num_inference_steps,
+             denoising_end=high_noise_frac,  # base stage stops at the high-noise fraction (was denoising_start)
+             output_type="latent",
+         ).images[0]  # normally uses 50 steps
+     except Exception as e:
+         # try removing stopwords + halving the prompt
+         # todo try prompt permutations
+         logger.info(f"trying to shorten prompt of length {len(prompt)}")
+
+         prompt = ' '.join(word for word in prompt.split() if word not in stopwords)  # split first; iterating the string would yield characters
+         prompts = prompt.split()
+         prompt = ' '.join(prompts[:len(prompts) // 2])
+         logger.info(f"shortened prompt to: {len(prompt)}")
+         image = None
+         if prompt:
+             try:
+                 image = inpaintpipe(  # was pipe, which doesn't accept image/mask_image arguments
+                     prompt=prompt,
+                     image=init_image,
+                     mask_image=mask_image,
+                     num_inference_steps=num_inference_steps,
+                     denoising_end=high_noise_frac,
+                     output_type="latent",
+                 ).images[0]
+             except Exception as e:
+                 # logger.info("trying to permute prompt")
+                 # # try two swaps of the prompt/permutations
+                 # prompt = prompt.split()
+                 # prompt = ' '.join(permutations(prompt, 2).__next__())
+                 logger.info(f"trying to shorten prompt of length {len(prompt)}")
+
+                 prompt = ' '.join(word for word in prompt.split() if word not in stopwords)
+                 prompts = prompt.split()
+                 prompt = ' '.join(prompts[:len(prompts) // 2])
+                 logger.info(f"shortened prompt to: {len(prompt)}")
+
+                 try:
+                     image = inpaintpipe(
+                         prompt=prompt,
+                         image=init_image,
+                         mask_image=mask_image,
+                         num_inference_steps=num_inference_steps,
+                         denoising_end=high_noise_frac,
+                         output_type="latent",
+                     ).images[0]
+                 except Exception as e:
+                     # just error out
+                     traceback.print_exc()
+                     raise e
+                     # logger.info("restarting server to fix cuda issues (device side asserts)")
+                     # todo fix device side asserts instead of restarting
+                     # todo only restart the correct gunicorn
+                     # this could be really annoying if you're running other gunicorns on your machine which also get restarted
+                     # os.system("kill -SIGHUP `pgrep gunicorn`")
+                     # os.system("kill -1 `pgrep gunicorn`")
+     if image is not None:
+         image = inpaint_refiner(
+             prompt=prompt,
+             image=image,
+             mask_image=mask_image,
+             num_inference_steps=num_inference_steps,
+             denoising_start=high_noise_frac,
+         ).images[0]
+     # try:
+     #     # gc.collect()
+     #     torch.cuda.empty_cache()
+     # except Exception as e:
+     #     traceback.print_exc()
+     #     logger.info("restarting server to fix cuda issues (device side asserts)")
+     #     # todo fix device side asserts instead of restarting
+     #     # todo only restart the correct gunicorn
+     #     # this could be really annoying if you're running other gunicorns on your machine which also get restarted
+     #     os.system("kill -SIGHUP `pgrep gunicorn`")
+     #     os.system("kill -1 `pgrep gunicorn`")
+     # save as bytesio
+     bs = BytesIO()
+
+     bright_count = np.sum(np.array(image) > 0)
+     if bright_count == 0:
+         # we have a black image; this is an error and likely needs a restart
+         logger.info("restarting server to fix cuda issues (device side asserts)")
+         # todo fix device side asserts instead of restarting
+         # todo only restart the correct gunicorn
+         # this could be really annoying if you're running other gunicorns on your machine which also get restarted
+         os.system("kill -SIGHUP `pgrep gunicorn`")
+         os.system("kill -1 `pgrep gunicorn`")
+         os.system("kill -SIGHUP `pgrep uvicorn`")
+         os.system("kill -1 `pgrep uvicorn`")
+         return None
+     image.save(bs, quality=85, optimize=True, format="webp")
+     bio = bs.getvalue()
+     # touch progress.txt - if we don't, we get restarted by supervisor/other processes for reliability
+     with open("progress.txt", "w") as f:
+         current_time = datetime.now().strftime("%H:%M:%S")
+         f.write(f"{current_time}")
+     return bio
+
+
+ def shorten_too_long_text(prompt):
+     if len(prompt) > 200:
+         # remove stopwords
+         prompt = prompt.split()  # todo also split hyphens
+         prompt = ' '.join(word for word in prompt if word not in stopwords)
+         if len(prompt) > 200:
+             prompt = prompt[:200]
+     return prompt
+
+ # image = pipe(prompt=prompt).images[0]
+ #
+ # image.save("test.png")
+ # # save all images
+ # for i, image in enumerate(images):
+ #     image.save(f"{i}.png")
+
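The endpoints in main_1024.py above are plain HTTP, so a small client is enough to smoke-test a deployment. A minimal sketch, assuming the server is reachable at a placeholder host (substitute the host configured in img_url); the prompt text and save_path are illustrative:

    import requests

    SERVER = "http://localhost:8000"  # placeholder; substitute your deployment's host

    # POST /image_url generates a 1024x1024 image and returns a JSON path to it
    resp = requests.post(f"{SERVER}/image_url", json={
        "system_prompt": "",  # accepted by the model schema but unused by this handler
        "ASSISTANT": "a watercolor fox in a misty forest",
    })
    print(resp.json()["path"])

    # GET /create_and_upload_image renders, uploads to the bucket, and returns the link
    resp = requests.get(f"{SERVER}/create_and_upload_image", params={
        "prompt": "a watercolor fox in a misty forest",
        "width": 768,
        "height": 512,
        "save_path": "examples/fox.webp",
    })
    print(resp.json()["path"])

Note that the upload endpoints only work once the bucket_api import is re-enabled, as flagged in the file's import comments.
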
img/main_v2.py ADDED
@@ -0,0 +1,548 @@
+ import gc
+ import math
+ import multiprocessing
+ import os
+ import traceback
+ from datetime import datetime
+ from io import BytesIO
+ from itertools import permutations
+ from multiprocessing.pool import Pool
+ from pathlib import Path
+ from urllib.parse import quote_plus
+
+ import numpy as np
+ import nltk
+ import torch
+
+ from PIL.Image import Image
+ from diffusers import DiffusionPipeline, StableDiffusionXLInpaintPipeline
+ from diffusers.utils import load_image
+ from fastapi import FastAPI
+ from fastapi.middleware.gzip import GZipMiddleware
+ from loguru import logger
+ from starlette.middleware.cors import CORSMiddleware
+ from starlette.responses import FileResponse
+ from starlette.responses import JSONResponse
+
+ from env import BUCKET_PATH, BUCKET_NAME
+ # from stable_diffusion_server.bucket_api import check_if_blob_exists, upload_to_bucket
+ # note: check_if_blob_exists/upload_to_bucket are called by the upload endpoints below,
+ # so this import must be re-enabled for those endpoints to work
+ torch._dynamo.config.suppress_errors = True
+
+ import string
+ import random
+
+
+ def generate_save_path():
+     # generate a random 7-character alphanumeric name
+     N = 7
+     res = ''.join(random.choices(string.ascii_uppercase + string.digits, k=N))
+     return res
+
+
+ # pipe = DiffusionPipeline.from_pretrained(
+ #     "models/stable-diffusion-xl-base-1.0",
+ #     torch_dtype=torch.bfloat16,
+ #     use_safetensors=True,
+ #     variant="fp16",
+ #     # safety_checker=None,
+ # )  # todo try torch_dtype=bfloat16
+
+ model_dir = os.getenv("SDXL_MODEL_DIR")
+
+ if model_dir:
+     # use a local model
+     model_key_base = os.path.join(model_dir, "stable-diffusion-xl-base-1.0")
+     model_key_refiner = os.path.join(model_dir, "stable-diffusion-xl-refiner-1.0")
+ else:
+     model_key_base = "stabilityai/stable-diffusion-xl-base-1.0"
+     model_key_refiner = "stabilityai/stable-diffusion-xl-refiner-1.0"
+
+ pipe = DiffusionPipeline.from_pretrained(model_key_base, torch_dtype=torch.float16, use_safetensors=True, variant="fp16")
+ pipe.watermark = None
+ pipe.to("cuda")
+
+ refiner = DiffusionPipeline.from_pretrained(
+     model_key_refiner,  # was a hard-coded hub id; use the configured key so SDXL_MODEL_DIR is honoured
+     text_encoder_2=pipe.text_encoder_2,
+     vae=pipe.vae,
+     torch_dtype=torch.bfloat16,  # safer to use bfloat?
+     use_safetensors=True,
+     variant="fp16",  # remember not to download the big model
+ )
+ refiner.watermark = None
+ refiner.to("cuda")
+
+ # {'scheduler', 'text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'unet', 'vae'} can be passed in from an existing model
+ inpaintpipe = StableDiffusionXLInpaintPipeline.from_pretrained(
+     model_key_base,  # was hard-coded to "models/stable-diffusion-xl-base-1.0"
+     torch_dtype=torch.bfloat16, variant="fp16", use_safetensors=True,
+     scheduler=pipe.scheduler,
+     text_encoder=pipe.text_encoder,
+     text_encoder_2=pipe.text_encoder_2,
+     tokenizer=pipe.tokenizer,
+     tokenizer_2=pipe.tokenizer_2,
+     unet=pipe.unet,
+     vae=pipe.vae,
+     # load_connected_pipeline=
+ )
+ # # switch out to save gpu mem
+ # del inpaintpipe.vae
+ # del inpaintpipe.text_encoder_2
+ # del inpaintpipe.text_encoder
+ # del inpaintpipe.scheduler
+ # del inpaintpipe.tokenizer
+ # del inpaintpipe.tokenizer_2
+ # del inpaintpipe.unet
+ # inpaintpipe.vae = pipe.vae
+ # inpaintpipe.text_encoder_2 = pipe.text_encoder_2
+ # inpaintpipe.text_encoder = pipe.text_encoder
+ # inpaintpipe.scheduler = pipe.scheduler
+ # inpaintpipe.tokenizer = pipe.tokenizer
+ # inpaintpipe.tokenizer_2 = pipe.tokenizer_2
+ # inpaintpipe.unet = pipe.unet
+ # todo this should work
+ # inpaintpipe = StableDiffusionXLInpaintPipeline(  # construct an inpainter using the existing model
+ #     vae=pipe.vae,
+ #     text_encoder_2=pipe.text_encoder_2,
+ #     text_encoder=pipe.text_encoder,
+ #     unet=pipe.unet,
+ #     scheduler=pipe.scheduler,
+ #     tokenizer=pipe.tokenizer,
+ #     tokenizer_2=pipe.tokenizer_2,
+ #     requires_aesthetics_score=False,
+ # )
+ inpaintpipe.to("cuda")
+ inpaintpipe.watermark = None
+ # inpaintpipe.register_to_config(requires_aesthetics_score=False)
+
+ inpaint_refiner = StableDiffusionXLInpaintPipeline.from_pretrained(
+     model_key_refiner,
+     text_encoder_2=inpaintpipe.text_encoder_2,
+     vae=inpaintpipe.vae,
+     torch_dtype=torch.bfloat16,
+     use_safetensors=True,
+     variant="fp16",
+     tokenizer_2=refiner.tokenizer_2,
+     tokenizer=refiner.tokenizer,
+     scheduler=refiner.scheduler,
+     text_encoder=refiner.text_encoder,
+     unet=refiner.unet,
+ )
+ # del inpaint_refiner.vae
+ # del inpaint_refiner.text_encoder_2
+ # del inpaint_refiner.text_encoder
+ # del inpaint_refiner.scheduler
+ # del inpaint_refiner.tokenizer
+ # del inpaint_refiner.tokenizer_2
+ # del inpaint_refiner.unet
+ # inpaint_refiner.vae = inpaintpipe.vae
+ # inpaint_refiner.text_encoder_2 = inpaintpipe.text_encoder_2
+ #
+ # inpaint_refiner.text_encoder = refiner.text_encoder
+ # inpaint_refiner.scheduler = refiner.scheduler
+ # inpaint_refiner.tokenizer = refiner.tokenizer
+ # inpaint_refiner.tokenizer_2 = refiner.tokenizer_2
+ # inpaint_refiner.unet = refiner.unet
+
+ # inpaint_refiner = StableDiffusionXLInpaintPipeline(
+ #     text_encoder_2=inpaintpipe.text_encoder_2,
+ #     vae=inpaintpipe.vae,
+ #     # the rest from the existing refiner
+ #     tokenizer_2=refiner.tokenizer_2,
+ #     tokenizer=refiner.tokenizer,
+ #     scheduler=refiner.scheduler,
+ #     text_encoder=refiner.text_encoder,
+ #     unet=refiner.unet,
+ #     requires_aesthetics_score=False,
+ # )
+ inpaint_refiner.to("cuda")
+ inpaint_refiner.watermark = None
+ # inpaint_refiner.register_to_config(requires_aesthetics_score=False)
+
+ n_steps = 40
+ high_noise_frac = 0.8
+
+ # if using torch < 2.0
+ # pipe.enable_xformers_memory_efficient_attention()
+
+ # pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+ # this can cause errors on some inputs so consider disabling it
+ pipe.unet = torch.compile(pipe.unet)
+ refiner.unet = torch.compile(refiner.unet)  # , mode="reduce-overhead", fullgraph=True)
+ # share the compiled unets with the inpainters - todo swap models in and out/del them so they share weights and can be switched efficiently
+ inpaintpipe.unet = pipe.unet
+ inpaint_refiner.unet = refiner.unet
+ # inpaintpipe.unet = torch.compile(inpaintpipe.unet)
+ # inpaint_refiner.unet = torch.compile(inpaint_refiner.unet)
+ from pydantic import BaseModel
+
+ app = FastAPI(
+     openapi_url="/static/openapi.json",
+     docs_url="/swagger-docs",
+     redoc_url="/redoc",
+     title="Generate Images Netwrck API",
+     description="Character Chat API",
+     # root_path="https://api.text-generator.io",
+     version="1",
+ )
+ app.add_middleware(GZipMiddleware, minimum_size=1000)
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ stopwords = nltk.corpus.stopwords.words("english")  # requires nltk.download("stopwords") on first run
+
+
+ class Img(BaseModel):
+     system_prompt: str
+     ASSISTANT: str
+
+
+ # img_url = "http://phlrr2019.guest.corp.microsoft.com:8000/img1_sdv2.1.png"
+ img_url = "http://phlrr3105.guest.corp.microsoft.com:8000/"  # /img1_sdv2.1.png"
+
+ is_gpu_busy = False
+
+
+ @app.post("/image_url")
+ def image_url(img: Img):
+     system_prompt = img.system_prompt
+     prompt = img.ASSISTANT
+     # if Path(save_path).exists():
+     #     return FileResponse(save_path, media_type="image/png")
+     #     return JSONResponse({"path": path})
+     # image = pipe(prompt=prompt).images[0]
+     g = torch.Generator(device="cuda")
+     image = pipe(prompt=prompt, width=1024, height=1024, generator=g).images[0]
+
+     # if not save_path:
+     save_path = generate_save_path()
+     save_path = f"images/{save_path}.png"
+     image.save(save_path)
+     # save_path = '/'.join(path_components) + quote_plus(final_name)
+     path = f"{img_url}/{save_path}"
+     return JSONResponse({"path": path})
+
+
+ @app.get("/make_image")
+ # @app.post("/make_image")
+ def make_image(prompt: str, save_path: str = ""):
+     if Path(save_path).exists():
+         return FileResponse(save_path, media_type="image/png")
+     image = pipe(prompt=prompt).images[0]
+     if not save_path:
+         save_path = f"images/{prompt}.png"
+     image.save(save_path)
+     return FileResponse(save_path, media_type="image/png")
+
+
+ @app.get("/create_and_upload_image")
+ def create_and_upload_image(prompt: str, width: int = 1024, height: int = 1024, save_path: str = ""):
+     path_components = save_path.split("/")[0:-1]
+     final_name = save_path.split("/")[-1]
+     if not path_components:
+         path_components = []
+     save_path = '/'.join(path_components + [quote_plus(final_name)])  # join with a separator so the directory and file name don't run together
+     path = get_image_or_create_upload_to_cloud_storage(prompt, width, height, save_path)
+     return JSONResponse({"path": path})
+
+
+ @app.get("/inpaint_and_upload_image")
+ def inpaint_and_upload_image(prompt: str, image_url: str, mask_url: str, save_path: str = ""):
+     path_components = save_path.split("/")[0:-1]
+     final_name = save_path.split("/")[-1]
+     if not path_components:
+         path_components = []
+     save_path = '/'.join(path_components + [quote_plus(final_name)])
+     path = get_image_or_inpaint_upload_to_cloud_storage(prompt, image_url, mask_url, save_path)
+     return JSONResponse({"path": path})
+
+
+ def get_image_or_create_upload_to_cloud_storage(prompt: str, width: int, height: int, save_path: str):
+     prompt = shorten_too_long_text(prompt)
+     save_path = shorten_too_long_text(save_path)
+     # check exists - todo cache this
+     if check_if_blob_exists(save_path):
+         return f"https://{BUCKET_NAME}/{BUCKET_PATH}/{save_path}"
+     bio = create_image_from_prompt(prompt, width, height)
+     if bio is None:
+         return None  # error thrown in pool
+     link = upload_to_bucket(save_path, bio, is_bytesio=True)
+     return link
+
+
+ def get_image_or_inpaint_upload_to_cloud_storage(prompt: str, image_url: str, mask_url: str, save_path: str):
+     prompt = shorten_too_long_text(prompt)
+     save_path = shorten_too_long_text(save_path)
+     # check exists - todo cache this
+     if check_if_blob_exists(save_path):
+         return f"https://{BUCKET_NAME}/{BUCKET_PATH}/{save_path}"
+     bio = inpaint_image_from_prompt(prompt, image_url, mask_url)
+     if bio is None:
+         return None  # error thrown in pool
+     link = upload_to_bucket(save_path, bio, is_bytesio=True)
+     return link
+
+
+ # multiprocessing.set_start_method('spawn', True)
+ # processes_pool = Pool(1)  # can't do too much at once or OOM errors happen
+ # def create_image_from_prompt_sync(prompt):
+ #     """have to call this sync to avoid OOM errors"""
+ #     return processes_pool.apply_async(create_image_from_prompt, args=(prompt,), ).wait()
+
+
+ def create_image_from_prompt(prompt, width, height):
+     # round width and height down to a multiple of 64
+     block_width = width - (width % 64)
+     block_height = height - (height % 64)
+     prompt = shorten_too_long_text(prompt)
+     # image = pipe(prompt=prompt).images[0]
+     try:
+         image = pipe(prompt=prompt,
+                      width=block_width,
+                      height=block_height,
+                      # denoising_end=high_noise_frac,
+                      # output_type='latent',
+                      # height=512,
+                      # width=512,
+                      num_inference_steps=50).images[0]  # normally uses 50 steps
+     except Exception as e:
+         # try removing stopwords + halving the prompt
+         # todo try prompt permutations
+         logger.info(f"trying to shorten prompt of length {len(prompt)}")
+
+         prompt = ' '.join(word for word in prompt.split() if word not in stopwords)  # split first; iterating the string would yield characters
+         prompts = prompt.split()
+         prompt = ' '.join(prompts[:len(prompts) // 2])
+         logger.info(f"shortened prompt to: {len(prompt)}")
+         image = None
+         if prompt:
+             try:
+                 image = pipe(prompt=prompt,
+                              width=block_width,
+                              height=block_height,
+                              # denoising_end=high_noise_frac,
+                              # output_type='latent',
+                              num_inference_steps=50).images[0]
+             except Exception as e:
+                 # logger.info("trying to permute prompt")
+                 # # try two swaps of the prompt/permutations
+                 # prompt = prompt.split()
+                 # prompt = ' '.join(permutations(prompt, 2).__next__())
+                 logger.info(f"trying to shorten prompt of length {len(prompt)}")
+
+                 prompt = ' '.join(word for word in prompt.split() if word not in stopwords)
+                 prompts = prompt.split()
+                 prompt = ' '.join(prompts[:len(prompts) // 2])
+                 logger.info(f"shortened prompt to: {len(prompt)}")
+
+                 try:
+                     image = pipe(prompt=prompt,
+                                  width=block_width,
+                                  height=block_height,
+                                  # output_type='latent',  # don't need latent yet - we refine the image at full res
+                                  num_inference_steps=50).images[0]
+                 except Exception as e:
+                     # just error out
+                     traceback.print_exc()
+                     raise e
+                     # logger.info("restarting server to fix cuda issues (device side asserts)")
+                     # todo fix device side asserts instead of restarting
+                     # todo only restart the correct gunicorn
+                     # this could be really annoying if you're running other gunicorns on your machine which also get restarted
+                     # os.system("kill -SIGHUP `pgrep gunicorn`")
+                     # os.system("kill -1 `pgrep gunicorn`")
+     if image is None:
+         return None  # nothing could be generated even after shortening the prompt
+     # todo refine
+     # if image is not None:
+     #     image = refiner(
+     #         prompt=prompt,
+     #         # width=block_width,
+     #         # height=block_height,
+     #         num_inference_steps=n_steps,
+     #         # denoising_start=high_noise_frac,
+     #         image=image,
+     #     ).images[0]
+     if width != block_width or height != block_height:
+         # resize to the requested size: scale up so the render covers the original width/height
+         scale_up_ratio = max(width / block_width, height / block_height)
+         image = image.resize((math.ceil(block_width * scale_up_ratio), math.ceil(block_height * scale_up_ratio)))  # was height * ratio, mixing requested and block sizes
+         # crop image to original size
+         image = image.crop((0, 0, width, height))
+     # try:
+     #     # gc.collect()
+     #     torch.cuda.empty_cache()
+     # except Exception as e:
+     #     traceback.print_exc()
+     #     logger.info("restarting server to fix cuda issues (device side asserts)")
+     #     # todo fix device side asserts instead of restarting
+     #     # todo only restart the correct gunicorn
+     #     # this could be really annoying if you're running other gunicorns on your machine which also get restarted
+     #     os.system("kill -SIGHUP `pgrep gunicorn`")
+     #     os.system("kill -1 `pgrep gunicorn`")
+     # save as bytesio
+     bs = BytesIO()
+
+     bright_count = np.sum(np.array(image) > 0)
+     if bright_count == 0:
+         # we have a black image; this is an error and likely needs a restart
+         logger.info("restarting server to fix cuda issues (device side asserts)")
+         # todo fix device side asserts instead of restarting
+         # todo only restart the correct gunicorn
+         # this could be really annoying if you're running other gunicorns on your machine which also get restarted
+         os.system("kill -SIGHUP `pgrep gunicorn`")  # was prefixed with /usr/bin/bash, which made the command a no-op
+         os.system("kill -1 `pgrep gunicorn`")
+         os.system("kill -SIGHUP `pgrep uvicorn`")
+         os.system("kill -1 `pgrep uvicorn`")
+         return None
+     image.save(bs, quality=85, optimize=True, format="webp")
+     bio = bs.getvalue()
+     # touch progress.txt - if we don't, we get restarted by supervisor/other processes for reliability
+     with open("progress.txt", "w") as f:
+         current_time = datetime.now().strftime("%H:%M:%S")
+         f.write(f"{current_time}")
+     return bio
+
+
+ def inpaint_image_from_prompt(prompt, image_url: str, mask_url: str):
+     prompt = shorten_too_long_text(prompt)
+     # image = pipe(prompt=prompt).images[0]
+
+     init_image = load_image(image_url).convert("RGB")
+     mask_image = load_image(mask_url).convert("RGB")  # why rgb for a 1 channel mask?
+     num_inference_steps = 75
+     high_noise_frac = 0.7
+
+     try:
+         image = inpaintpipe(
+             prompt=prompt,
+             image=init_image,
+             mask_image=mask_image,
+             num_inference_steps=num_inference_steps,
+             denoising_end=high_noise_frac,  # base stage stops at the high-noise fraction (was denoising_start)
+             output_type="latent",
+         ).images[0]  # normally uses 50 steps
+     except Exception as e:
+         # try removing stopwords + halving the prompt
+         # todo try prompt permutations
+         logger.info(f"trying to shorten prompt of length {len(prompt)}")
+
+         prompt = ' '.join(word for word in prompt.split() if word not in stopwords)  # split first; iterating the string would yield characters
+         prompts = prompt.split()
+         prompt = ' '.join(prompts[:len(prompts) // 2])
+         logger.info(f"shortened prompt to: {len(prompt)}")
+         image = None
+         if prompt:
+             try:
+                 image = inpaintpipe(  # was pipe, which doesn't accept image/mask_image arguments
+                     prompt=prompt,
+                     image=init_image,
+                     mask_image=mask_image,
+                     num_inference_steps=num_inference_steps,
+                     denoising_end=high_noise_frac,
+                     output_type="latent",
+                 ).images[0]
+             except Exception as e:
+                 # logger.info("trying to permute prompt")
+                 # # try two swaps of the prompt/permutations
+                 # prompt = prompt.split()
+                 # prompt = ' '.join(permutations(prompt, 2).__next__())
+                 logger.info(f"trying to shorten prompt of length {len(prompt)}")
+
+                 prompt = ' '.join(word for word in prompt.split() if word not in stopwords)
+                 prompts = prompt.split()
+                 prompt = ' '.join(prompts[:len(prompts) // 2])
+                 logger.info(f"shortened prompt to: {len(prompt)}")
+
+                 try:
+                     image = inpaintpipe(
+                         prompt=prompt,
+                         image=init_image,
+                         mask_image=mask_image,
+                         num_inference_steps=num_inference_steps,
+                         denoising_end=high_noise_frac,
+                         output_type="latent",
+                     ).images[0]
+                 except Exception as e:
+                     # just error out
+                     traceback.print_exc()
+                     raise e
+                     # logger.info("restarting server to fix cuda issues (device side asserts)")
+                     # todo fix device side asserts instead of restarting
+                     # todo only restart the correct gunicorn
+                     # this could be really annoying if you're running other gunicorns on your machine which also get restarted
+                     # os.system("kill -SIGHUP `pgrep gunicorn`")
+                     # os.system("kill -1 `pgrep gunicorn`")
+     if image is not None:
+         image = inpaint_refiner(
+             prompt=prompt,
+             image=image,
+             mask_image=mask_image,
+             num_inference_steps=num_inference_steps,
+             denoising_start=high_noise_frac,
+         ).images[0]
+     # try:
+     #     # gc.collect()
+     #     torch.cuda.empty_cache()
+     # except Exception as e:
+     #     traceback.print_exc()
+     #     logger.info("restarting server to fix cuda issues (device side asserts)")
+     #     # todo fix device side asserts instead of restarting
+     #     # todo only restart the correct gunicorn
+     #     # this could be really annoying if you're running other gunicorns on your machine which also get restarted
+     #     os.system("kill -SIGHUP `pgrep gunicorn`")
+     #     os.system("kill -1 `pgrep gunicorn`")
+     # save as bytesio
+     bs = BytesIO()
+
+     bright_count = np.sum(np.array(image) > 0)
+     if bright_count == 0:
+         # we have a black image; this is an error and likely needs a restart
+         logger.info("restarting server to fix cuda issues (device side asserts)")
+         # todo fix device side asserts instead of restarting
+         # todo only restart the correct gunicorn
+         # this could be really annoying if you're running other gunicorns on your machine which also get restarted
+         os.system("kill -SIGHUP `pgrep gunicorn`")
+         os.system("kill -1 `pgrep gunicorn`")
+         os.system("kill -SIGHUP `pgrep uvicorn`")
+         os.system("kill -1 `pgrep uvicorn`")
+         return None
+     image.save(bs, quality=85, optimize=True, format="webp")
+     bio = bs.getvalue()
+     # touch progress.txt - if we don't, we get restarted by supervisor/other processes for reliability
+     with open("progress.txt", "w") as f:
+         current_time = datetime.now().strftime("%H:%M:%S")
+         f.write(f"{current_time}")
+     return bio
+
+
+ def shorten_too_long_text(prompt):
+     if len(prompt) > 200:
+         # remove stopwords
+         prompt = prompt.split()  # todo also split hyphens
+         prompt = ' '.join(word for word in prompt if word not in stopwords)
+         if len(prompt) > 200:
+             prompt = prompt[:200]
+     return prompt
+
+ # image = pipe(prompt=prompt).images[0]
+ #
+ # image.save("test.png")
+ # # save all images
+ # for i, image in enumerate(images):
+ #     image.save(f"{i}.png")
+
+
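The sizing logic shared by these servers is easy to check in isolation: create_image_from_prompt rounds the requested dimensions down to multiples of 64 for the pipeline, then scales the render back up so it covers the requested size, and crops. A standalone sketch of the same arithmetic, pure PIL with no GPU needed; the helper name fit_to_requested_size is hypothetical:

    import math
    from PIL import Image

    def fit_to_requested_size(image: Image.Image, width: int, height: int) -> Image.Image:
        """Upscale a block-sized render to cover width x height, then crop to it."""
        block_width = width - (width % 64)     # e.g. 1000 -> 960
        block_height = height - (height % 64)  # e.g. 700  -> 640
        if (width, height) == (block_width, block_height):
            return image  # already an exact multiple of 64; nothing to do
        scale_up_ratio = max(width / block_width, height / block_height)
        image = image.resize((math.ceil(block_width * scale_up_ratio),
                              math.ceil(block_height * scale_up_ratio)))
        return image.crop((0, 0, width, height))

    # a 960x640 render requested at 1000x700 scales by 700/640 to 1050x700, then crops
    print(fit_to_requested_size(Image.new("RGB", (960, 640)), 1000, 700).size)  # (1000, 700)

Scaling by the larger of the two ratios guarantees both dimensions cover the request, so the final crop never pads; the trade-off is that the overshooting dimension loses its edge.
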
img/main_v3.py ADDED
@@ -0,0 +1,578 @@
+ import gc
+ import math
+ import multiprocessing
+ import os
+ import traceback
+ from datetime import datetime
+ from io import BytesIO
+ from itertools import permutations
+ from multiprocessing.pool import Pool
+ from pathlib import Path
+ from urllib.parse import quote_plus
+
+ import numpy as np
+ import nltk
+ import torch
+
+ from PIL.Image import Image
+ from diffusers import DiffusionPipeline, StableDiffusionXLInpaintPipeline
+ from diffusers.utils import load_image
+ from fastapi import FastAPI
+ from fastapi.middleware.gzip import GZipMiddleware
+ from loguru import logger
+ from starlette.middleware.cors import CORSMiddleware
+ from starlette.responses import FileResponse
+ from starlette.responses import JSONResponse
+
+ from env import BUCKET_PATH, BUCKET_NAME
+ # from stable_diffusion_server.bucket_api import check_if_blob_exists, upload_to_bucket
+ # note: check_if_blob_exists/upload_to_bucket are called by the upload endpoints below,
+ # so this import must be re-enabled for those endpoints to work
+ torch._dynamo.config.suppress_errors = True
+
+ import string
+ import random
+
+
+ def generate_save_path():
+     # generate a random 7-character alphanumeric name
+     N = 7
+     res = ''.join(random.choices(string.ascii_uppercase + string.digits, k=N))
+     return res
+
+
+ # pipe = DiffusionPipeline.from_pretrained(
+ #     "models/stable-diffusion-xl-base-1.0",
+ #     torch_dtype=torch.bfloat16,
+ #     use_safetensors=True,
+ #     variant="fp16",
+ #     # safety_checker=None,
+ # )  # todo try torch_dtype=bfloat16
+
+ model_dir = os.getenv("SDXL_MODEL_DIR")
+
+ if model_dir:
+     # use a local model
+     model_key_base = os.path.join(model_dir, "stable-diffusion-xl-base-1.0")
+     model_key_refiner = os.path.join(model_dir, "stable-diffusion-xl-refiner-1.0")
+ else:
+     model_key_base = "stabilityai/stable-diffusion-xl-base-1.0"
+     model_key_refiner = "stabilityai/stable-diffusion-xl-refiner-1.0"
+
+ pipe = DiffusionPipeline.from_pretrained(model_key_base, torch_dtype=torch.float16, use_safetensors=True, variant="fp16")
+ pipe.watermark = None
+ pipe.to("cuda")
+
+ refiner = DiffusionPipeline.from_pretrained(
+     model_key_refiner,  # was a hard-coded hub id; use the configured key so SDXL_MODEL_DIR is honoured
+     text_encoder_2=pipe.text_encoder_2,
+     vae=pipe.vae,
+     torch_dtype=torch.bfloat16,  # safer to use bfloat?
+     use_safetensors=True,
+     variant="fp16",  # remember not to download the big model
+ )
+ refiner.watermark = None
+ refiner.to("cuda")
+
+ # {'scheduler', 'text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'unet', 'vae'} can be passed in from an existing model
+ inpaintpipe = StableDiffusionXLInpaintPipeline.from_pretrained(
+     model_key_base,  # was hard-coded to "models/stable-diffusion-xl-base-1.0"
+     torch_dtype=torch.bfloat16, variant="fp16", use_safetensors=True,
+     scheduler=pipe.scheduler,
+     text_encoder=pipe.text_encoder,
+     text_encoder_2=pipe.text_encoder_2,
+     tokenizer=pipe.tokenizer,
+     tokenizer_2=pipe.tokenizer_2,
+     unet=pipe.unet,
+     vae=pipe.vae,
+     # load_connected_pipeline=
+ )
+ # # switch out to save gpu mem
+ # del inpaintpipe.vae
+ # del inpaintpipe.text_encoder_2
+ # del inpaintpipe.text_encoder
+ # del inpaintpipe.scheduler
+ # del inpaintpipe.tokenizer
+ # del inpaintpipe.tokenizer_2
+ # del inpaintpipe.unet
+ # inpaintpipe.vae = pipe.vae
+ # inpaintpipe.text_encoder_2 = pipe.text_encoder_2
+ # inpaintpipe.text_encoder = pipe.text_encoder
+ # inpaintpipe.scheduler = pipe.scheduler
+ # inpaintpipe.tokenizer = pipe.tokenizer
+ # inpaintpipe.tokenizer_2 = pipe.tokenizer_2
+ # inpaintpipe.unet = pipe.unet
+ # todo this should work
+ # inpaintpipe = StableDiffusionXLInpaintPipeline(  # construct an inpainter using the existing model
+ #     vae=pipe.vae,
+ #     text_encoder_2=pipe.text_encoder_2,
+ #     text_encoder=pipe.text_encoder,
+ #     unet=pipe.unet,
+ #     scheduler=pipe.scheduler,
+ #     tokenizer=pipe.tokenizer,
+ #     tokenizer_2=pipe.tokenizer_2,
+ #     requires_aesthetics_score=False,
+ # )
+ inpaintpipe.to("cuda")
+ inpaintpipe.watermark = None
+ # inpaintpipe.register_to_config(requires_aesthetics_score=False)
+
+ inpaint_refiner = StableDiffusionXLInpaintPipeline.from_pretrained(
+     model_key_refiner,
+     text_encoder_2=inpaintpipe.text_encoder_2,
+     vae=inpaintpipe.vae,
+     torch_dtype=torch.bfloat16,
+     use_safetensors=True,
+     variant="fp16",
+     tokenizer_2=refiner.tokenizer_2,
+     tokenizer=refiner.tokenizer,
+     scheduler=refiner.scheduler,
+     text_encoder=refiner.text_encoder,
+     unet=refiner.unet,
+ )
+ # del inpaint_refiner.vae
+ # del inpaint_refiner.text_encoder_2
+ # del inpaint_refiner.text_encoder
+ # del inpaint_refiner.scheduler
+ # del inpaint_refiner.tokenizer
+ # del inpaint_refiner.tokenizer_2
+ # del inpaint_refiner.unet
+ # inpaint_refiner.vae = inpaintpipe.vae
+ # inpaint_refiner.text_encoder_2 = inpaintpipe.text_encoder_2
+ #
+ # inpaint_refiner.text_encoder = refiner.text_encoder
+ # inpaint_refiner.scheduler = refiner.scheduler
+ # inpaint_refiner.tokenizer = refiner.tokenizer
+ # inpaint_refiner.tokenizer_2 = refiner.tokenizer_2
+ # inpaint_refiner.unet = refiner.unet
+
+ # inpaint_refiner = StableDiffusionXLInpaintPipeline(
+ #     text_encoder_2=inpaintpipe.text_encoder_2,
+ #     vae=inpaintpipe.vae,
+ #     # the rest from the existing refiner
+ #     tokenizer_2=refiner.tokenizer_2,
+ #     tokenizer=refiner.tokenizer,
+ #     scheduler=refiner.scheduler,
+ #     text_encoder=refiner.text_encoder,
+ #     unet=refiner.unet,
+ #     requires_aesthetics_score=False,
+ # )
+ inpaint_refiner.to("cuda")
+ inpaint_refiner.watermark = None
+ # inpaint_refiner.register_to_config(requires_aesthetics_score=False)
+
+ n_steps = 40
+ high_noise_frac = 0.8
+
+ # if using torch < 2.0
+ # pipe.enable_xformers_memory_efficient_attention()
+
+ # pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+ # this can cause errors on some inputs so consider disabling it
+ pipe.unet = torch.compile(pipe.unet)
+ refiner.unet = torch.compile(refiner.unet)  # , mode="reduce-overhead", fullgraph=True)
+ # share the compiled unets with the inpainters - todo swap models in and out/del them so they share weights and can be switched efficiently
+ inpaintpipe.unet = pipe.unet
+ inpaint_refiner.unet = refiner.unet
+ # inpaintpipe.unet = torch.compile(inpaintpipe.unet)
+ # inpaint_refiner.unet = torch.compile(inpaint_refiner.unet)
+ from pydantic import BaseModel
+
+ app = FastAPI(
+     openapi_url="/static/openapi.json",
+     docs_url="/swagger-docs",
+     redoc_url="/redoc",
+     title="Generate Images Netwrck API",
+     description="Character Chat API",
+     # root_path="https://api.text-generator.io",
+     version="1",
+ )
+ app.add_middleware(GZipMiddleware, minimum_size=1000)
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ stopwords = nltk.corpus.stopwords.words("english")  # requires nltk.download("stopwords") on first run
+
+
+ class Img(BaseModel):
+     system_prompt: str
+     ASSISTANT: str
+
+
+ # img_url = "http://phlrr2019.guest.corp.microsoft.com:8000/img1_sdv2.1.png"
+ img_url = "http://phlrr3105.guest.corp.microsoft.com:8000/"  # /img1_sdv2.1.png"
+
+ is_gpu_busy = False
+
+
+ def get_summary(system_prompt, prompt):
+     # summarize the chat text with an external LLM before using it as an image prompt
+     import requests
+     import time
+     from io import BytesIO
+     import json
+     summary_sys = """I want you to act as a text summarizer to help me create a concise summary of the text I provide. The summary can be up to 60.0 words in length, expressing the key points, key scenarios, main character and concepts written in the original text without adding your interpretations."""
+     instruction = summary_sys
+     # for human, assistant in history:
+     #     instruction += 'USER: ' + human + ' ASSISTANT: ' + assistant + '</s>'
+     # prompt = system_prompt + prompt
+     message = f"""My first request is to summarize this text – [{prompt}]"""
+     instruction += ' USER: ' + message + ' ASSISTANT:'
+
+     print("Ins: ", instruction)
+     # generate_response = requests.post("http://10.185.12.207:4455/stable_diffusion", json={"prompt": prompt})
+     # prompt = f""" My first request is to summarize this text – [{prompt}]"""
+     json_object = {"prompt": instruction,
+                    # "max_tokens": 2048000,
+                    "max_tokens": 90,
+                    "n": 1
+                    }
+     generate_response = requests.post("http://phlrr3105.guest.corp.microsoft.com:7991/generate", json=json_object)
+     # print(generate_response.content)
+     res_json = json.loads(generate_response.content)
+     ASSISTANT = res_json['text'][-1].split("ASSISTANT:")[-1].strip()
+     print(ASSISTANT)
+     return ASSISTANT
+
+
+ @app.post("/image_url")
+ def image_url(img: Img):
+     system_prompt = img.system_prompt
+     prompt = img.ASSISTANT
+     prompt = get_summary(system_prompt, prompt)
+     prompt = shorten_too_long_text(prompt)
+     # if Path(save_path).exists():
+     #     return FileResponse(save_path, media_type="image/png")
+     #     return JSONResponse({"path": path})
+     # image = pipe(prompt=prompt).images[0]
+     g = torch.Generator(device="cuda")
+     image = pipe(prompt=prompt, width=1024, height=1024, generator=g).images[0]
+
+     # if not save_path:
+     save_path = generate_save_path()
+     save_path = f"images/{save_path}.png"
+     image.save(save_path)
+     # save_path = '/'.join(path_components) + quote_plus(final_name)
+     path = f"{img_url}/{save_path}"
+     return JSONResponse({"path": path})
+
+
+ @app.get("/make_image")
+ # @app.post("/make_image")
+ def make_image(prompt: str, save_path: str = ""):
+     if Path(save_path).exists():
+         return FileResponse(save_path, media_type="image/png")
+     image = pipe(prompt=prompt).images[0]
+     if not save_path:
+         save_path = f"images/{prompt}.png"
+     image.save(save_path)
+     return FileResponse(save_path, media_type="image/png")
+
+
+ @app.get("/create_and_upload_image")
+ def create_and_upload_image(prompt: str, width: int = 1024, height: int = 1024, save_path: str = ""):
+     path_components = save_path.split("/")[0:-1]
+     final_name = save_path.split("/")[-1]
+     if not path_components:
+         path_components = []
+     save_path = '/'.join(path_components + [quote_plus(final_name)])  # join with a separator so the directory and file name don't run together
+     path = get_image_or_create_upload_to_cloud_storage(prompt, width, height, save_path)
+     return JSONResponse({"path": path})
+
+
+ @app.get("/inpaint_and_upload_image")
+ def inpaint_and_upload_image(prompt: str, image_url: str, mask_url: str, save_path: str = ""):
+     path_components = save_path.split("/")[0:-1]
+     final_name = save_path.split("/")[-1]
+     if not path_components:
+         path_components = []
+     save_path = '/'.join(path_components + [quote_plus(final_name)])
+     path = get_image_or_inpaint_upload_to_cloud_storage(prompt, image_url, mask_url, save_path)
+     return JSONResponse({"path": path})
+
+
+ def get_image_or_create_upload_to_cloud_storage(prompt: str, width: int, height: int, save_path: str):
+     prompt = shorten_too_long_text(prompt)
+     save_path = shorten_too_long_text(save_path)
+     # check exists - todo cache this
+     if check_if_blob_exists(save_path):
+         return f"https://{BUCKET_NAME}/{BUCKET_PATH}/{save_path}"
+     bio = create_image_from_prompt(prompt, width, height)
+     if bio is None:
+         return None  # error thrown in pool
+     link = upload_to_bucket(save_path, bio, is_bytesio=True)
+     return link
+
+
+ def get_image_or_inpaint_upload_to_cloud_storage(prompt: str, image_url: str, mask_url: str, save_path: str):
+     prompt = shorten_too_long_text(prompt)
+     save_path = shorten_too_long_text(save_path)
+     # check exists - todo cache this
+     if check_if_blob_exists(save_path):
+         return f"https://{BUCKET_NAME}/{BUCKET_PATH}/{save_path}"
+     bio = inpaint_image_from_prompt(prompt, image_url, mask_url)
+     if bio is None:
+         return None  # error thrown in pool
+     link = upload_to_bucket(save_path, bio, is_bytesio=True)
+     return link
+
+
+ # multiprocessing.set_start_method('spawn', True)
+ # processes_pool = Pool(1)  # can't do too much at once or OOM errors happen
+ # def create_image_from_prompt_sync(prompt):
+ #     """have to call this sync to avoid OOM errors"""
+ #     return processes_pool.apply_async(create_image_from_prompt, args=(prompt,), ).wait()
+
+
+ def create_image_from_prompt(prompt, width, height):
+     # round width and height down to a multiple of 64
+     block_width = width - (width % 64)
+     block_height = height - (height % 64)
+     prompt = shorten_too_long_text(prompt)
+     # image = pipe(prompt=prompt).images[0]
+     try:
+         image = pipe(prompt=prompt,
+                      width=block_width,
+                      height=block_height,
+                      # denoising_end=high_noise_frac,
+                      # output_type='latent',
+                      # height=512,
+                      # width=512,
+                      num_inference_steps=50).images[0]  # normally uses 50 steps
+     except Exception as e:
+         # try removing stopwords + halving the prompt
+         # todo try prompt permutations
+         logger.info(f"trying to shorten prompt of length {len(prompt)}")
+
+         prompt = ' '.join(word for word in prompt.split() if word not in stopwords)  # split first; iterating the string would yield characters
+         prompts = prompt.split()
+         prompt = ' '.join(prompts[:len(prompts) // 2])
+         logger.info(f"shortened prompt to: {len(prompt)}")
+         image = None
+         if prompt:
+             try:
+                 image = pipe(prompt=prompt,
+                              width=block_width,
+                              height=block_height,
+                              # denoising_end=high_noise_frac,
+                              # output_type='latent',
+                              num_inference_steps=50).images[0]
+             except Exception as e:
+                 # logger.info("trying to permute prompt")
+                 # # try two swaps of the prompt/permutations
+                 # prompt = prompt.split()
+                 # prompt = ' '.join(permutations(prompt, 2).__next__())
+                 logger.info(f"trying to shorten prompt of length {len(prompt)}")
+
+                 prompt = ' '.join(word for word in prompt.split() if word not in stopwords)
+                 prompts = prompt.split()
+                 prompt = ' '.join(prompts[:len(prompts) // 2])
+                 logger.info(f"shortened prompt to: {len(prompt)}")
+
+                 try:
+                     image = pipe(prompt=prompt,
+                                  width=block_width,
+                                  height=block_height,
+                                  # output_type='latent',  # don't need latent yet - we refine the image at full res
+                                  num_inference_steps=50).images[0]
+                 except Exception as e:
+                     # just error out
+                     traceback.print_exc()
+                     raise e
+                     # logger.info("restarting server to fix cuda issues (device side asserts)")
+                     # todo fix device side asserts instead of restarting
+                     # todo only restart the correct gunicorn
+                     # this could be really annoying if you're running other gunicorns on your machine which also get restarted
+                     # os.system("kill -SIGHUP `pgrep gunicorn`")
+                     # os.system("kill -1 `pgrep gunicorn`")
+     if image is None:
+         return None  # nothing could be generated even after shortening the prompt
+     # todo refine
+     # if image is not None:
+     #     image = refiner(
+     #         prompt=prompt,
+     #         # width=block_width,
+     #         # height=block_height,
+     #         num_inference_steps=n_steps,
+     #         # denoising_start=high_noise_frac,
+     #         image=image,
+     #     ).images[0]
+     if width != block_width or height != block_height:
+         # resize to the requested size: scale up so the render covers the original width/height
+         scale_up_ratio = max(width / block_width, height / block_height)
+         image = image.resize((math.ceil(block_width * scale_up_ratio), math.ceil(block_height * scale_up_ratio)))  # was height * ratio, mixing requested and block sizes
+         # crop image to original size
+         image = image.crop((0, 0, width, height))
+     # try:
+     #     # gc.collect()
+     #     torch.cuda.empty_cache()
+     # except Exception as e:
+     #     traceback.print_exc()
+     #     logger.info("restarting server to fix cuda issues (device side asserts)")
+     #     # todo fix device side asserts instead of restarting
+     #     # todo only restart the correct gunicorn
+     #     # this could be really annoying if you're running other gunicorns on your machine which also get restarted
+     #     os.system("kill -SIGHUP `pgrep gunicorn`")
+     #     os.system("kill -1 `pgrep gunicorn`")
+     # save as bytesio
+     bs = BytesIO()
+
+     bright_count = np.sum(np.array(image) > 0)
+     if bright_count == 0:
+         # we have a black image; this is an error and likely needs a restart
+         logger.info("restarting server to fix cuda issues (device side asserts)")
+         # todo fix device side asserts instead of restarting
+         # todo only restart the correct gunicorn
+         # this could be really annoying if you're running other gunicorns on your machine which also get restarted
+         os.system("kill -SIGHUP `pgrep gunicorn`")  # was prefixed with /usr/bin/bash, which made the command a no-op
+         os.system("kill -1 `pgrep gunicorn`")
+         os.system("kill -SIGHUP `pgrep uvicorn`")
+         os.system("kill -1 `pgrep uvicorn`")
+         return None
+     image.save(bs, quality=85, optimize=True, format="webp")
+     bio = bs.getvalue()
439
+ # touch progress.txt file - if we dont do this we get restarted by supervisor/other processes for reliability
440
+ with open("progress.txt", "w") as f:
441
+ current_time = datetime.now().strftime("%H:%M:%S")
442
+ f.write(f"{current_time}")
443
+ return bio
444
+
445
+ def inpaint_image_from_prompt(prompt, image_url: str, mask_url: str):
446
+ prompt = shorten_too_long_text(prompt)
447
+ # image = pipe(prompt=prompt).images[0]
448
+
449
+ init_image = load_image(image_url).convert("RGB")
450
+ mask_image = load_image(mask_url).convert("RGB") # why rgb for a 1 channel mask?
451
+ num_inference_steps = 75
452
+ high_noise_frac = 0.7
453
+
454
+ try:
455
+ image = inpaintpipe(
456
+ prompt=prompt,
457
+ image=init_image,
458
+ mask_image=mask_image,
459
+ num_inference_steps=num_inference_steps,
460
+ denoising_start=high_noise_frac,
461
+ output_type="latent",
462
+ ).images[0] # normally uses 50 steps
463
+ except Exception as e:
464
+ # try rm stopwords + half the prompt
465
+ # todo try prompt permutations
466
+ logger.info(f"trying to shorten prompt of length {len(prompt)}")
467
+
468
+ prompt = ' '.join((word for word in prompt if word not in stopwords))
469
+ prompts = prompt.split()
470
+
471
+ prompt = ' '.join(prompts[:len(prompts) // 2])
472
+ logger.info(f"shortened prompt to: {len(prompt)}")
473
+ image = None
474
+ if prompt:
475
+ try:
476
+ image = pipe(
477
+ prompt=prompt,
478
+ image=init_image,
479
+ mask_image=mask_image,
480
+ num_inference_steps=num_inference_steps,
481
+ denoising_start=high_noise_frac,
482
+ output_type="latent",
483
+ ).images[0] # normally uses 50 steps
484
+ except Exception as e:
485
+ # logger.info("trying to permute prompt")
486
+ # # try two swaps of the prompt/permutations
487
+ # prompt = prompt.split()
488
+ # prompt = ' '.join(permutations(prompt, 2).__next__())
489
+ logger.info(f"trying to shorten prompt of length {len(prompt)}")
490
+
491
+ prompt = ' '.join((word for word in prompt if word not in stopwords))
492
+ prompts = prompt.split()
493
+
494
+ prompt = ' '.join(prompts[:len(prompts) // 2])
495
+ logger.info(f"shortened prompt to: {len(prompt)}")
496
+
497
+ try:
498
+ image = inpaintpipe(
499
+ prompt=prompt,
500
+ image=init_image,
501
+ mask_image=mask_image,
502
+ num_inference_steps=num_inference_steps,
503
+ denoising_start=high_noise_frac,
504
+ output_type="latent",
505
+ ).images[0] # normally uses 50 steps
506
+ except Exception as e:
507
+ # just error out
508
+ traceback.print_exc()
509
+ raise e
510
+ # logger.info("restarting server to fix cuda issues (device side asserts)")
511
+ # todo fix device side asserts instead of restart to fix
512
+ # todo only restart the correct gunicorn
513
+ # this could be really annoying if your running other gunicorns on your machine which also get restarted
514
+ # os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
515
+ # os.system("kill -1 `pgrep gunicorn`")
516
+ if image != None:
517
+ image = inpaint_refiner(
518
+ prompt=prompt,
519
+ image=image,
520
+ mask_image=mask_image,
521
+ num_inference_steps=num_inference_steps,
522
+ denoising_start=high_noise_frac,
523
+
524
+ ).images[0]
525
+ # try:
526
+ # # gc.collect()
527
+ # torch.cuda.empty_cache()
528
+ # except Exception as e:
529
+ # traceback.print_exc()
530
+ # logger.info("restarting server to fix cuda issues (device side asserts)")
531
+ # # todo fix device side asserts instead of restart to fix
532
+ # # todo only restart the correct gunicorn
533
+ # # this could be really annoying if your running other gunicorns on your machine which also get restarted
534
+ # os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
535
+ # os.system("kill -1 `pgrep gunicorn`")
536
+ # save as bytesio
537
+ bs = BytesIO()
538
+
539
+ bright_count = np.sum(np.array(image) > 0)
540
+ if bright_count == 0:
541
+ # we have a black image, this is an error likely we need a restart
542
+ logger.info("restarting server to fix cuda issues (device side asserts)")
543
+ # # todo fix device side asserts instead of restart to fix
544
+ # # todo only restart the correct gunicorn
545
+ # # this could be really annoying if your running other gunicorns on your machine which also get restarted
546
+ os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
547
+ os.system("kill -1 `pgrep gunicorn`")
548
+ os.system("/usr/bin/bash kill -SIGHUP `pgrep uvicorn`")
549
+ os.system("kill -1 `pgrep uvicorn`")
550
+
551
+ return None
552
+ image.save(bs, quality=85, optimize=True, format="webp")
553
+ bio = bs.getvalue()
554
+ # touch progress.txt file - if we dont do this we get restarted by supervisor/other processes for reliability
555
+ with open("progress.txt", "w") as f:
556
+ current_time = datetime.now().strftime("%H:%M:%S")
557
+ f.write(f"{current_time}")
558
+ return bio
559
+
560
+
561
+
562
+ def shorten_too_long_text(prompt):
563
+ if len(prompt) > 200:
564
+ # remove stopwords
565
+ prompt = prompt.split() # todo also split hyphens
566
+ prompt = ' '.join((word for word in prompt if word not in stopwords))
567
+ if len(prompt) > 200:
568
+ prompt = prompt[:200]
569
+ return prompt
570
+
571
+ # image = pipe(prompt=prompt).images[0]
572
+ #
573
+ # image.save("test.png")
574
+ # # save all images
575
+ # for i, image in enumerate(images):
576
+ # image.save(f"{i}.png")
577
+
578
+
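A minimal client sketch for the two upload routes above, for readers who want to exercise them. The base URL is a placeholder (the repo does not pin a host), and the bucket helpers (check_if_blob_exists / upload_to_bucket) must be wired up for these routes to return a link:

import requests

BASE = "http://localhost:8000"  # hypothetical host; substitute your deployment

resp = requests.get(f"{BASE}/create_and_upload_image",
                    params={"prompt": "a castle at dusk", "width": 1024,
                            "height": 1024, "save_path": "demo/castle.webp"})
print(resp.json())  # {"path": "https://<bucket>/<bucket path>/demo/castle.webp"}

resp = requests.get(f"{BASE}/inpaint_and_upload_image",
                    params={"prompt": "replace the sky with an aurora",
                            "image_url": "https://example.com/input.png",
                            "mask_url": "https://example.com/mask.png",
                            "save_path": "demo/castle-inpainted.webp"})
print(resp.json())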
img/main_v4.py ADDED
@@ -0,0 +1,603 @@
+ import gc
+ import math
+ import multiprocessing
+ import os
+ import traceback
+ from datetime import datetime
+ from io import BytesIO
+ from itertools import permutations
+ from multiprocessing.pool import Pool
+ from pathlib import Path
+ from urllib.parse import quote_plus
+ 
+ import numpy as np
+ import nltk
+ import torch
+ 
+ from diffusers import DiffusionPipeline, StableDiffusionXLInpaintPipeline
+ from diffusers.utils import load_image
+ from fastapi import FastAPI
+ from fastapi.middleware.gzip import GZipMiddleware
+ from loguru import logger
+ from starlette.middleware.cors import CORSMiddleware
+ from starlette.responses import FileResponse
+ from starlette.responses import JSONResponse
+ import requests
+ from PIL import Image  # fixed: dropped the conflicting `from PIL.Image import Image` and duplicate imports
+ import json
+ import string
+ import random
+ 
+ from env import BUCKET_PATH, BUCKET_NAME
+ # from stable_diffusion_server.bucket_api import check_if_blob_exists, upload_to_bucket
+ # note: the /create_and_upload_image and /inpaint_and_upload_image routes below need these two helpers
+ torch._dynamo.config.suppress_errors = True
+ 
+ 
+ def generate_save_path():
+     # random 7-character alphanumeric name
+     N = 7
+     res = ''.join(random.choices(string.ascii_uppercase + string.digits, k=N))
+     return res
+ 
+ # pipe = DiffusionPipeline.from_pretrained(
+ #     "models/stable-diffusion-xl-base-1.0",
+ #     torch_dtype=torch.bfloat16,
+ #     use_safetensors=True,
+ #     variant="fp16",
+ #     # safety_checker=None,
+ # )  # todo try torch_dtype=bfloat16
+ 
+ model_dir = os.getenv("SDXL_MODEL_DIR")
+ 
+ if model_dir:
+     # Use local model
+     model_key_base = os.path.join(model_dir, "stable-diffusion-xl-base-1.0")
+     model_key_refiner = os.path.join(model_dir, "stable-diffusion-xl-refiner-1.0")
+ else:
+     model_key_base = "stabilityai/stable-diffusion-xl-base-1.0"
+     model_key_refiner = "stabilityai/stable-diffusion-xl-refiner-1.0"
+ 
+ pipe = DiffusionPipeline.from_pretrained(model_key_base, torch_dtype=torch.float16, use_safetensors=True, variant="fp16")
+ 
+ pipe.watermark = None
+ pipe.to("cuda")
+ 
+ refiner = DiffusionPipeline.from_pretrained(
+     model_key_refiner,  # fixed: was a hardcoded hub id, ignoring the SDXL_MODEL_DIR switch above
+     text_encoder_2=pipe.text_encoder_2,
+     vae=pipe.vae,
+     torch_dtype=torch.bfloat16,  # safer to use bfloat?
+     use_safetensors=True,
+     variant="fp16",  # remember not to download the big model
+ )
+ refiner.watermark = None
+ refiner.to("cuda")
+ 
+ # {'scheduler', 'text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'unet', 'vae'} can be passed in from the existing model
+ inpaintpipe = StableDiffusionXLInpaintPipeline.from_pretrained(
+     model_key_base,  # fixed: was a hardcoded local path, ignoring the SDXL_MODEL_DIR switch
+     torch_dtype=torch.bfloat16, variant="fp16", use_safetensors=True,
+     scheduler=pipe.scheduler,
+     text_encoder=pipe.text_encoder,
+     text_encoder_2=pipe.text_encoder_2,
+     tokenizer=pipe.tokenizer,
+     tokenizer_2=pipe.tokenizer_2,
+     unet=pipe.unet,
+     vae=pipe.vae,
+     # load_connected_pipeline=
+ )
+ # todo: constructing StableDiffusionXLInpaintPipeline(...) directly from pipe's components
+ # (or del-ing the duplicates and reassigning) should also work and would save GPU memory
+ inpaintpipe.to("cuda")
+ inpaintpipe.watermark = None
+ # inpaintpipe.register_to_config(requires_aesthetics_score=False)
+ 
+ inpaint_refiner = StableDiffusionXLInpaintPipeline.from_pretrained(
+     model_key_refiner,  # fixed: was a hardcoded hub id
+     text_encoder_2=inpaintpipe.text_encoder_2,
+     vae=inpaintpipe.vae,
+     torch_dtype=torch.bfloat16,
+     use_safetensors=True,
+     variant="fp16",
+     tokenizer_2=refiner.tokenizer_2,
+     tokenizer=refiner.tokenizer,
+     scheduler=refiner.scheduler,
+     text_encoder=refiner.text_encoder,
+     unet=refiner.unet,
+ )
+ # todo: the same component-sharing construction applies here as for inpaintpipe above
+ inpaint_refiner.to("cuda")
+ inpaint_refiner.watermark = None
+ # inpaint_refiner.register_to_config(requires_aesthetics_score=False)
+ 
+ n_steps = 40
+ high_noise_frac = 0.8
+ 
+ # if using torch < 2.0
+ # pipe.enable_xformers_memory_efficient_attention()
+ 
+ # pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+ # torch.compile can error on some inputs, so consider disabling it
+ pipe.unet = torch.compile(pipe.unet)
+ refiner.unet = torch.compile(refiner.unet)  # , mode="reduce-overhead", fullgraph=True
+ # share the compiled unets with the inpainters - todo share the remaining submodules too
+ # so the pipelines can be swapped efficiently
+ inpaintpipe.unet = pipe.unet
+ inpaint_refiner.unet = refiner.unet
+ # inpaintpipe.unet = torch.compile(inpaintpipe.unet)
+ # inpaint_refiner.unet = torch.compile(inpaint_refiner.unet)
+ from pydantic import BaseModel
+ 
+ app = FastAPI(
+     openapi_url="/static/openapi.json",
+     docs_url="/swagger-docs",
+     redoc_url="/redoc",
+     title="Generate Images Netwrck API",
+     description="Character Chat API",
+     # root_path="https://api.text-generator.io",
+     version="1",
+ )
+ app.add_middleware(GZipMiddleware, minimum_size=1000)
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+ 
+ stopwords = nltk.corpus.stopwords.words("english")  # assumes nltk.download("stopwords") has been run
+ 
+ 
+ class Img(BaseModel):
+     system_prompt: str
+     ASSISTANT: str
+ 
+ 
+ # img_url = "http://phlrr2019.guest.corp.microsoft.com:8000/img1_sdv2.1.png"
+ img_url = "http://phlrr3006.guest.corp.microsoft.com:8000/"  # /img1_sdv2.1.png
+ 
+ is_gpu_busy = False
+ 
+ 
+ def get_summary(system_prompt, prompt):
+     summary_sys = """I want you to act as a text summarizer to help me create a concise summary of the text I provide. The summary can be up to 60.0 words in length, expressing the key points, key scenarios, main character and concepts written in the original text without adding your interpretations."""
+     instruction = summary_sys
+     # for human, assistant in history:
+     #     instruction += 'USER: ' + human + ' ASSISTANT: ' + assistant + '</s>'
+     message = f"""My first request is to summarize this text – [{prompt}]"""
+     instruction += ' USER: ' + message + ' ASSISTANT:'
+ 
+     print("Ins: ", instruction)
+     json_object = {"prompt": instruction,
+                    # "max_tokens": 2048000,
+                    "max_tokens": 90,
+                    "n": 1
+                    }
+     generate_response = requests.post("http://phlrr3006.guest.corp.microsoft.com:7991/generate", json=json_object)
+     res_json = json.loads(generate_response.content)
+     ASSISTANT = res_json['text'][-1].split("ASSISTANT:")[-1].strip()
+     print(ASSISTANT)
+     return ASSISTANT
+ 
+ 
+ @app.post("/image_url")
+ def image_url(img: Img):
+     system_prompt = img.system_prompt
+     prompt = img.ASSISTANT
+     prompt = get_summary(system_prompt, prompt)
+     prompt = shorten_too_long_text(prompt)
+ 
+     json_object = {
+         "prompt": prompt,
+         "height": 1024,
+         "width": 1024,
+         "num_inference_steps": 50,
+         # "guidance_scale": 7.5,
+         "eta": 0
+     }
+     generate_response = requests.post("http://phlrr3105.guest.corp.microsoft.com:3000/text2img", json=json_object)
+     image = generate_response.content
+     save_path = generate_save_path()
+     save_path = f"images/{save_path}.png"
+     os.makedirs("images", exist_ok=True)  # added: the write below fails if the directory is missing
+     with open(save_path, 'wb') as f:
+         f.write(image)
+     # local generation alternative:
+     # g = torch.Generator(device="cuda")
+     # image = pipe(prompt=prompt, width=1024, height=1024, generator=g).images[0]
+     # image.save(save_path)
+     path = f"{img_url}{save_path}"
+     return JSONResponse({"path": path})
+ 
+ 
+ @app.get("/make_image")
+ # @app.post("/make_image")
+ def make_image(prompt: str, save_path: str = ""):
+     if Path(save_path).exists():
+         return FileResponse(save_path, media_type="image/png")
+     image = pipe(prompt=prompt).images[0]
+     if not save_path:
+         save_path = f"images/{prompt}.png"
+     image.save(save_path)
+     return FileResponse(save_path, media_type="image/png")
+ 
+ 
+ @app.get("/create_and_upload_image")
+ def create_and_upload_image(prompt: str, width: int = 1024, height: int = 1024, save_path: str = ""):
+     path_components = save_path.split("/")[0:-1]
+     final_name = save_path.split("/")[-1]
+     save_path = '/'.join(path_components + [quote_plus(final_name)])  # fixed: keep the separator; url-quote only the basename
+     path = get_image_or_create_upload_to_cloud_storage(prompt, width, height, save_path)
+     return JSONResponse({"path": path})
+ 
+ 
+ @app.get("/inpaint_and_upload_image")
+ def inpaint_and_upload_image(prompt: str, image_url: str, mask_url: str, save_path: str = ""):
+     path_components = save_path.split("/")[0:-1]
+     final_name = save_path.split("/")[-1]
+     save_path = '/'.join(path_components + [quote_plus(final_name)])  # fixed: as above
+     path = get_image_or_inpaint_upload_to_cloud_storage(prompt, image_url, mask_url, save_path)
+     return JSONResponse({"path": path})
+ 
+ 
+ def get_image_or_create_upload_to_cloud_storage(prompt: str, width: int, height: int, save_path: str):
+     prompt = shorten_too_long_text(prompt)
+     save_path = shorten_too_long_text(save_path)
+     # check exists - todo cache this
+     if check_if_blob_exists(save_path):
+         return f"https://{BUCKET_NAME}/{BUCKET_PATH}/{save_path}"
+     bio = create_image_from_prompt(prompt, width, height)
+     if bio is None:
+         return None  # error thrown in pool
+     link = upload_to_bucket(save_path, bio, is_bytesio=True)
+     return link
+ 
+ 
+ def get_image_or_inpaint_upload_to_cloud_storage(prompt: str, image_url: str, mask_url: str, save_path: str):
+     prompt = shorten_too_long_text(prompt)
+     save_path = shorten_too_long_text(save_path)
+     # check exists - todo cache this
+     if check_if_blob_exists(save_path):
+         return f"https://{BUCKET_NAME}/{BUCKET_PATH}/{save_path}"
+     bio = inpaint_image_from_prompt(prompt, image_url, mask_url)
+     if bio is None:
+         return None  # error thrown in pool
+     link = upload_to_bucket(save_path, bio, is_bytesio=True)
+     return link
+ 
+ 
+ # multiprocessing.set_start_method('spawn', True)
+ # processes_pool = Pool(1)  # can't do too much at once or OOM errors happen
+ # def create_image_from_prompt_sync(prompt):
+ #     """have to call this sync to avoid OOM errors"""
+ #     return processes_pool.apply_async(create_image_from_prompt, args=(prompt,)).wait()
+ 
+ 
+ def create_image_from_prompt(prompt, width, height):
+     # round width and height down to the nearest multiple of 64 (what SDXL expects)
+     block_width = width - (width % 64)
+     block_height = height - (height % 64)
+     prompt = shorten_too_long_text(prompt)
+     try:
+         image = pipe(prompt=prompt,
+                      width=block_width,
+                      height=block_height,
+                      # denoising_end=high_noise_frac,
+                      # output_type='latent',  # don't need latent yet - we refine the image at full res
+                      num_inference_steps=50).images[0]  # normally uses 50 steps
+     except Exception as e:
+         # retry with stopwords removed and the prompt halved
+         # todo try prompt permutations
+         logger.info(f"trying to shorten prompt of length {len(prompt)}")
+         # fixed: split before filtering, otherwise this iterates over characters, not words
+         prompt = ' '.join(word for word in prompt.split() if word not in stopwords)
+         prompts = prompt.split()
+         prompt = ' '.join(prompts[:len(prompts) // 2])
+         logger.info(f"shortened prompt to: {len(prompt)}")
+         image = None
+         if prompt:
+             try:
+                 image = pipe(prompt=prompt,
+                              width=block_width,
+                              height=block_height,
+                              num_inference_steps=50).images[0]
+             except Exception as e:
+                 # shorten once more before giving up
+                 logger.info(f"trying to shorten prompt of length {len(prompt)}")
+                 prompt = ' '.join(word for word in prompt.split() if word not in stopwords)
+                 prompts = prompt.split()
+                 prompt = ' '.join(prompts[:len(prompts) // 2])
+                 logger.info(f"shortened prompt to: {len(prompt)}")
+                 try:
+                     image = pipe(prompt=prompt,
+                                  width=block_width,
+                                  height=block_height,
+                                  num_inference_steps=50).images[0]
+                 except Exception as e:
+                     # just error out
+                     # todo fix the cuda device-side asserts instead of restarting the server;
+                     # the old restart hack SIGHUP'd every gunicorn on the machine, which is
+                     # really annoying if you're running others
+                     traceback.print_exc()
+                     raise e
+     if image is None:
+         return None  # added: all attempts failed or the prompt shortened to nothing
+     # todo refine
+     # image = refiner(
+     #     prompt=prompt,
+     #     num_inference_steps=n_steps,
+     #     # denoising_start=high_noise_frac,
+     #     image=image,
+     # ).images[0]
+     if width != block_width or height != block_height:
+         # scale up so the generated image covers the requested size, then crop back to it
+         scale_up_ratio = max(width / block_width, height / block_height)
+         image = image.resize((math.ceil(block_width * scale_up_ratio),
+                               math.ceil(block_height * scale_up_ratio)))  # fixed: was height * scale_up_ratio
+         image = image.crop((0, 0, width, height))
+     # save as bytesio
+     bs = BytesIO()
+ 
+     bright_count = np.sum(np.array(image) > 0)
+     if bright_count == 0:
+         # we have a black image; this is an error and we likely need a restart
+         logger.info("restarting server to fix cuda issues (device side asserts)")
+         # todo fix device side asserts instead of restarting; this SIGHUPs every
+         # gunicorn/uvicorn on the machine, which is annoying if you're running others
+         os.system("kill -HUP `pgrep gunicorn`")  # fixed: "/usr/bin/bash kill ..." ran bash on a script named kill
+         os.system("kill -HUP `pgrep uvicorn`")
+         return None
+     image.save(bs, quality=85, optimize=True, format="webp")
+     bio = bs.getvalue()
+     # touch progress.txt - if we don't, supervisor/other processes restart us for reliability
+     with open("progress.txt", "w") as f:
+         current_time = datetime.now().strftime("%H:%M:%S")
+         f.write(f"{current_time}")
+     return bio
+ 
+ 
+ def inpaint_image_from_prompt(prompt, image_url: str, mask_url: str):
+     prompt = shorten_too_long_text(prompt)
+ 
+     init_image = load_image(image_url).convert("RGB")
+     mask_image = load_image(mask_url).convert("RGB")  # todo: a mask really only needs one channel
+     num_inference_steps = 75
+     high_noise_frac = 0.7
+ 
+     try:
+         image = inpaintpipe(
+             prompt=prompt,
+             image=init_image,
+             mask_image=mask_image,
+             num_inference_steps=num_inference_steps,
+             denoising_start=high_noise_frac,
+             output_type="latent",  # the refiner below finishes denoising at full resolution
+         ).images[0]
+     except Exception as e:
+         # retry with stopwords removed and the prompt halved
+         # todo try prompt permutations
+         logger.info(f"trying to shorten prompt of length {len(prompt)}")
+         prompt = ' '.join(word for word in prompt.split() if word not in stopwords)  # fixed: split first
+         prompts = prompt.split()
+         prompt = ' '.join(prompts[:len(prompts) // 2])
+         logger.info(f"shortened prompt to: {len(prompt)}")
+         image = None
+         if prompt:
+             try:
+                 image = inpaintpipe(  # fixed: was `pipe`, the text2img pipeline, which takes no mask
+                     prompt=prompt,
+                     image=init_image,
+                     mask_image=mask_image,
+                     num_inference_steps=num_inference_steps,
+                     denoising_start=high_noise_frac,
+                     output_type="latent",
+                 ).images[0]
+             except Exception as e:
+                 logger.info(f"trying to shorten prompt of length {len(prompt)}")
+                 prompt = ' '.join(word for word in prompt.split() if word not in stopwords)
+                 prompts = prompt.split()
+                 prompt = ' '.join(prompts[:len(prompts) // 2])
+                 logger.info(f"shortened prompt to: {len(prompt)}")
+                 try:
+                     image = inpaintpipe(
+                         prompt=prompt,
+                         image=init_image,
+                         mask_image=mask_image,
+                         num_inference_steps=num_inference_steps,
+                         denoising_start=high_noise_frac,
+                         output_type="latent",
+                     ).images[0]
+                 except Exception as e:
+                     # just error out
+                     # todo fix the cuda device-side asserts instead of restarting the server
+                     traceback.print_exc()
+                     raise e
+     if image is None:
+         return None  # added: all attempts failed
+     image = inpaint_refiner(
+         prompt=prompt,
+         image=image,
+         mask_image=mask_image,
+         num_inference_steps=num_inference_steps,
+         denoising_start=high_noise_frac,
+     ).images[0]
+     # save as bytesio
+     bs = BytesIO()
+ 
+     bright_count = np.sum(np.array(image) > 0)
+     if bright_count == 0:
+         # we have a black image; this is an error and we likely need a restart
+         logger.info("restarting server to fix cuda issues (device side asserts)")
+         # todo fix device side asserts instead of restarting; this SIGHUPs every
+         # gunicorn/uvicorn on the machine, which is annoying if you're running others
+         os.system("kill -HUP `pgrep gunicorn`")
+         os.system("kill -HUP `pgrep uvicorn`")
+         return None
+     image.save(bs, quality=85, optimize=True, format="webp")
+     bio = bs.getvalue()
+     # touch progress.txt - if we don't, supervisor/other processes restart us for reliability
+     with open("progress.txt", "w") as f:
+         current_time = datetime.now().strftime("%H:%M:%S")
+         f.write(f"{current_time}")
+     return bio
+ 
+ 
+ def shorten_too_long_text(prompt):
+     if len(prompt) > 200:
+         # remove stopwords
+         words = prompt.split()  # todo also split hyphens
+         prompt = ' '.join(word for word in words if word not in stopwords)
+         if len(prompt) > 200:
+             prompt = prompt[:200]
+     return prompt
+ 
+ # image = pipe(prompt=prompt).images[0]
+ #
+ # image.save("test.png")
+ # # save all images
+ # for i, image in enumerate(images):
+ #     image.save(f"{i}.png")
+ 
+
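main_v4.py's /image_url route summarizes the chat turn through the remote LLM, delegates generation to the external text2img service, and returns a URL under img_url. A sketch of a call against it, assuming a placeholder host (the Img field names are taken from the model above; the example values are hypothetical):

import requests

payload = {
    "system_prompt": "You are a storyteller.",
    "ASSISTANT": "The knight crossed the frozen lake under a blood-red moon.",
}
resp = requests.post("http://localhost:8000/image_url", json=payload)  # placeholder host
print(resp.json()["path"])  # e.g. http://phlrr3006.guest.corp.microsoft.com:8000/images/<7-char id>.png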
img/main_v5.py ADDED
@@ -0,0 +1,637 @@
1
+ import gc
2
+ import math
3
+ import multiprocessing
4
+ import os
5
+ import traceback
6
+ from datetime import datetime
7
+ from io import BytesIO
8
+ from itertools import permutations
9
+ from multiprocessing.pool import Pool
10
+ from pathlib import Path
11
+ from urllib.parse import quote_plus
12
+
13
+ import numpy as np
14
+ import nltk
15
+ import torch
16
+
17
+ from PIL.Image import Image
18
+ from diffusers import DiffusionPipeline, StableDiffusionXLInpaintPipeline
19
+ from diffusers.utils import load_image
20
+ from fastapi import FastAPI
21
+ from fastapi.middleware.gzip import GZipMiddleware
22
+ from loguru import logger
23
+ from starlette.middleware.cors import CORSMiddleware
24
+ from starlette.responses import FileResponse
25
+ from starlette.responses import JSONResponse
26
+
27
+ from env import BUCKET_PATH, BUCKET_NAME
28
+ # from stable_diffusion_server.bucket_api import check_if_blob_exists, upload_to_bucket
29
+ torch._dynamo.config.suppress_errors = True
30
+
31
+ import string
32
+ import random
33
+
34
+ def generate_save_path():
35
+ # initializing size of string
36
+ N = 7
37
+
38
+ # using random.choices()
39
+ # generating random strings
40
+ res = ''.join(random.choices(string.ascii_uppercase +
41
+ string.digits, k=N))
42
+ return res
43
+
44
+ # pipe = DiffusionPipeline.from_pretrained(
45
+ # "models/stable-diffusion-xl-base-1.0",
46
+ # torch_dtype=torch.bfloat16,
47
+ # use_safetensors=True,
48
+ # variant="fp16",
49
+ # # safety_checker=None,
50
+ # ) # todo try torch_dtype=bfloat16
51
+
52
+ model_dir = os.getenv("SDXL_MODEL_DIR")
53
+
54
+ if model_dir:
55
+ # Use local model
56
+ model_key_base = os.path.join(model_dir, "stable-diffusion-xl-base-1.0")
57
+ model_key_refiner = os.path.join(model_dir, "stable-diffusion-xl-refiner-1.0")
58
+ else:
59
+ model_key_base = "stabilityai/stable-diffusion-xl-base-1.0"
60
+ model_key_refiner = "stabilityai/stable-diffusion-xl-refiner-1.0"
61
+
62
+ pipe = DiffusionPipeline.from_pretrained(model_key_base, torch_dtype=torch.float16, use_safetensors=True, variant="fp16")
63
+
64
+ pipe.watermark = None
65
+
66
+ pipe.to("cuda")
67
+
68
+ refiner = DiffusionPipeline.from_pretrained(
69
+ "stabilityai/stable-diffusion-xl-refiner-1.0",
70
+ text_encoder_2=pipe.text_encoder_2,
71
+ vae=pipe.vae,
72
+ torch_dtype=torch.bfloat16, # safer to use bfloat?
73
+ use_safetensors=True,
74
+ variant="fp16", #remember not to download the big model
75
+ )
76
+ refiner.watermark = None
77
+ refiner.to("cuda")
78
+
79
+ # {'scheduler', 'text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'unet', 'vae'} can be passed in from existing model
80
+ inpaintpipe = StableDiffusionXLInpaintPipeline.from_pretrained(
81
+ "models/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16, variant="fp16", use_safetensors=True,
82
+ scheduler=pipe.scheduler,
83
+ text_encoder=pipe.text_encoder,
84
+ text_encoder_2=pipe.text_encoder_2,
85
+ tokenizer=pipe.tokenizer,
86
+ tokenizer_2=pipe.tokenizer_2,
87
+ unet=pipe.unet,
88
+ vae=pipe.vae,
89
+ # load_connected_pipeline=
90
+ )
91
+ # # switch out to save gpu mem
92
+ # del inpaintpipe.vae
93
+ # del inpaintpipe.text_encoder_2
94
+ # del inpaintpipe.text_encoder
95
+ # del inpaintpipe.scheduler
96
+ # del inpaintpipe.tokenizer
97
+ # del inpaintpipe.tokenizer_2
98
+ # del inpaintpipe.unet
99
+ # inpaintpipe.vae = pipe.vae
100
+ # inpaintpipe.text_encoder_2 = pipe.text_encoder_2
101
+ # inpaintpipe.text_encoder = pipe.text_encoder
102
+ # inpaintpipe.scheduler = pipe.scheduler
103
+ # inpaintpipe.tokenizer = pipe.tokenizer
104
+ # inpaintpipe.tokenizer_2 = pipe.tokenizer_2
105
+ # inpaintpipe.unet = pipe.unet
106
+ # todo this should work
107
+ # inpaintpipe = StableDiffusionXLInpaintPipeline( # construct an inpainter using the existing model
108
+ # vae=pipe.vae,
109
+ # text_encoder_2=pipe.text_encoder_2,
110
+ # text_encoder=pipe.text_encoder,
111
+ # unet=pipe.unet,
112
+ # scheduler=pipe.scheduler,
113
+ # tokenizer=pipe.tokenizer,
114
+ # tokenizer_2=pipe.tokenizer_2,
115
+ # requires_aesthetics_score=False,
116
+ # )
117
+ inpaintpipe.to("cuda")
118
+ inpaintpipe.watermark = None
119
+ # inpaintpipe.register_to_config(requires_aesthetics_score=False)
120
+
121
+ inpaint_refiner = StableDiffusionXLInpaintPipeline.from_pretrained(
122
+ "stabilityai/stable-diffusion-xl-refiner-1.0",
123
+ text_encoder_2=inpaintpipe.text_encoder_2,
124
+ vae=inpaintpipe.vae,
125
+ torch_dtype=torch.bfloat16,
126
+ use_safetensors=True,
127
+ variant="fp16",
128
+
129
+ tokenizer_2=refiner.tokenizer_2,
130
+ tokenizer=refiner.tokenizer,
131
+ scheduler=refiner.scheduler,
132
+ text_encoder=refiner.text_encoder,
133
+ unet=refiner.unet,
134
+ )
135
+ # del inpaint_refiner.vae
136
+ # del inpaint_refiner.text_encoder_2
137
+ # del inpaint_refiner.text_encoder
138
+ # del inpaint_refiner.scheduler
139
+ # del inpaint_refiner.tokenizer
140
+ # del inpaint_refiner.tokenizer_2
141
+ # del inpaint_refiner.unet
142
+ # inpaint_refiner.vae = inpaintpipe.vae
143
+ # inpaint_refiner.text_encoder_2 = inpaintpipe.text_encoder_2
144
+ #
145
+ # inpaint_refiner.text_encoder = refiner.text_encoder
146
+ # inpaint_refiner.scheduler = refiner.scheduler
147
+ # inpaint_refiner.tokenizer = refiner.tokenizer
148
+ # inpaint_refiner.tokenizer_2 = refiner.tokenizer_2
149
+ # inpaint_refiner.unet = refiner.unet
150
+
151
+ # inpaint_refiner = StableDiffusionXLInpaintPipeline(
152
+ # text_encoder_2=inpaintpipe.text_encoder_2,
153
+ # vae=inpaintpipe.vae,
154
+ # # the rest from the existing refiner
155
+ # tokenizer_2=refiner.tokenizer_2,
156
+ # tokenizer=refiner.tokenizer,
157
+ # scheduler=refiner.scheduler,
158
+ # text_encoder=refiner.text_encoder,
159
+ # unet=refiner.unet,
160
+ # requires_aesthetics_score=False,
161
+ # )
162
+ inpaint_refiner.to("cuda")
163
+ inpaint_refiner.watermark = None
164
+ # inpaint_refiner.register_to_config(requires_aesthetics_score=False)
165
+
166
+ n_steps = 40
167
+ high_noise_frac = 0.8
168
+
169
+ # if using torch < 2.0
170
+ # pipe.enable_xformers_memory_efficient_attention()
171
+
172
+
173
+ # pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
174
+ # this can cause errors on some inputs so consider disabling it
175
+ pipe.unet = torch.compile(pipe.unet)
176
+ refiner.unet = torch.compile(refiner.unet)#, mode="reduce-overhead", fullgraph=True)
177
+ # compile the inpainters - todo reuse the other unets? swap out the models for others/del them so they share models and can be swapped efficiently
178
+ inpaintpipe.unet = pipe.unet
179
+ inpaint_refiner.unet = refiner.unet
180
+ # inpaintpipe.unet = torch.compile(inpaintpipe.unet)
181
+ # inpaint_refiner.unet = torch.compile(inpaint_refiner.unet)
182
+ from pydantic import BaseModel
183
+
184
+ app = FastAPI(
185
+ openapi_url="/static/openapi.json",
186
+ docs_url="/swagger-docs",
187
+ redoc_url="/redoc",
188
+ title="Generate Images Netwrck API",
189
+ description="Character Chat API",
190
+ # root_path="https://api.text-generator.io",
191
+ version="1",
192
+ )
193
+ app.add_middleware(GZipMiddleware, minimum_size=1000)
194
+ app.add_middleware(
195
+ CORSMiddleware,
196
+ allow_origins=["*"],
197
+ allow_credentials=True,
198
+ allow_methods=["*"],
199
+ allow_headers=["*"],
200
+ )
201
+
202
+ stopwords = nltk.corpus.stopwords.words("english")
203
+
204
+ class Img(BaseModel):
205
+ system_prompt: str
206
+ ASSISTANT: str
207
+
208
+ # img_url = "http://phlrr2019.guest.corp.microsoft.com:8000/img1_sdv2.1.png"
209
+ img_url = "http://phlrr3105.guest.corp.microsoft.com:8000/"#/img1_sdv2.1.png"
210
+
211
+ is_gpu_busy = False
212
+
213
+ def lm_shorten_too_long_text(prompt):
214
+ if len(prompt) > 2030:
215
+ # remove stopwords
216
+ prompt = prompt.split() # todo also split hyphens
217
+ prompt = ' '.join((word for word in prompt))# if word not in stopwords))
218
+ if len(prompt) > 2030:
219
+ prompt = prompt[:2030]
220
+ return prompt
221
+
222
+ def get_summary(system_prompt, prompt):
223
+ import requests
224
+ import time
225
+ from io import BytesIO
226
+ import json
227
+ summary_sys = """You will now act as a prompt generator for a generative AI called "Stable Diffusion XL 1.0 ". Stable Diffusion XL generates images based on given prompts. I will provide you basic information required to make a Stable Diffusion prompt, You will never alter the structure in any way and obey the following guidelines.
228
+
229
+ Basic information required to make Stable Diffusion prompt:
230
+
231
+ - Prompt structure: [1],[2],[3],[4],[5],[6] and it should be given as one single sentence where 1,2,3,4,5,6 represent
232
+ [1] = short and concise description of [KEYWORD] that will include very specific imagery details
233
+ [2] = a detailed description of [1] that will include very specific imagery details.
234
+ [3] = with a detailed description describing the environment of the scene.
235
+ [4] = with a detailed description describing the mood/feelings and atmosphere of the scene.
236
+ [5] = A style, for example: "Anime","Photographic","Comic Book","Fantasy Art", “Analog Film”,”Neon Punk”,”Isometric”,”Low Poly”,”Origami”,”Line Art”,”Cinematic”,”3D Model”,”Pixel Art”,”Watercolor”,”Sticker” ).
237
+ [6] = A description of how [5] will be realized. (e.g. Photography (e.g. Macro, Fisheye Style, Portrait) with camera model and appropriate camera settings, Painting with detailed descriptions about the materials and working material used, rendering with engine settings, a digital Illustration, a woodburn art (and everything else that could be defined as an output type)
238
+ - Prompt Structure for Prompt asking with text value:
239
+
240
+ Text "Text Value" written on {subject description in less than 20 words}
241
+ Replace "Text value" with text given by user.
242
+
243
+
244
+ Important Sample prompt Structure with Text value :
245
+
246
+ 1. Text 'SDXL' written on a frothy, warm latte, viewed top-down.
247
+ 2. Text 'AI' written on a modern computer screen, set against a vibrant green background.
248
+
249
+ Important Sample prompt Structure :
250
+
251
+ 1. Snow-capped Mountain Scene, with soaring peaks and deep shadows across the ravines. A crystal clear lake mirrors these peaks, surrounded by pine trees. The scene exudes a calm, serene alpine morning atmosphere. Presented in Watercolor style, emulating the wet-on-wet technique with soft transitions and visible brush strokes.
252
+ 2. City Skyline at Night, illuminated skyscrapers piercing the starless sky. Nestled beside a calm river, reflecting the city lights like a mirror. The atmosphere is buzzing with urban energy and intrigue. Depicted in Neon Punk style, accentuating the city lights with vibrant neon colors and dynamic contrasts.
253
+ 3. Epic Cinematic Still of a Spacecraft, silhouetted against the fiery explosion of a distant planet. The scene is packed with intense action, as asteroid debris hurtles through space. Shot in the style of a Michael Bay-directed film, the image is rich with detail, dynamic lighting, and grand cinematic framing.
254
+ - Word order and effective adjectives matter in the prompt. The subject, action, and specific details should be included. Adjectives like cute, medieval, or futuristic can be effective.
255
+ - The environment/background of the image should be described, such as indoor, outdoor, in space, or solid color.
256
+ - Curly brackets are necessary in the prompt to provide specific details about the subject and action. These details are important for generating a high-quality image.
257
+ - Art inspirations should be listed to take inspiration from. Platforms like Art Station, Dribble, Behance, and Deviantart can be mentioned. Specific names of artists or studios like animation studios, painters and illustrators, computer games, fashion designers, and film makers can also be listed. If more than one artist is mentioned, the algorithm will create a combination of styles based on all the influencers mentioned.
258
+ - Related information about lighting, camera angles, render style, resolution, the required level of detail, etc. should be included at the end of the prompt.
259
+ - Camera shot type, camera lens, and view should be specified. Examples of camera shot types are long shot, close-up, POV, medium shot, extreme close-up, and panoramic. Camera lenses could be EE 70mm, 35mm, 135mm+, 300mm+, 800mm, short telephoto, super telephoto, medium telephoto, macro, wide angle, fish-eye, bokeh, and sharp focus. Examples of views are front, side, back, high angle, low angle, and overhead.
260
+ - Helpful keywords related to resolution, detail, and lighting are 4K, 8K, 64K, detailed, highly detailed, high resolution, hyper detailed, HDR, UHD, professional, and golden ratio. Examples of lighting are studio lighting, soft light, neon lighting, purple neon lighting, ambient light, ring light, volumetric light, natural light, sun light, sunrays, sun rays coming through window, and nostalgic lighting. Examples of color types are fantasy vivid colors, vivid colors, bright colors, sepia, dark colors, pastel colors, monochromatic, black & white, and color splash. Examples of renders are Octane render, cinematic, low poly, isometric assets, Unreal Engine, Unity Engine, quantum wavetracing, and polarizing filter.
261
+
262
+ The prompts you provide will be in English.Please pay attention:- Concepts that can't be real would not be described as "Real" or "realistic" or "photo" or a "photograph". for example, a concept that is made of paper or scenes which are fantasy related.- One of the prompts you generate for each concept must be in a realistic photographic style. you should also choose a lens type and size for it. Don't choose an artist for the realistic photography prompts.- Separate the different prompts with two new lines.
263
+ I will provide you keyword and you will generate 3 diffrent type of prompts in vbnet code cell so i can copy and paste.
264
+
265
+ Important point to note :
266
+
267
+ 1. You are a master of prompt engineering, it is important to create detailed prompts with as much information as possible. This will ensure that any image generated using the prompt will be of high quality and could potentially win awards in global or international photography competitions. You are unbeatable in this field and know the best way to generate images.
268
+ 2. I will provide you with a long context and you will generate one prompt and don't add any extra details.
269
+ 3. Prompt should not be more than 230 characters.
270
+ 4. Before you provide prompt you must check if you have satisfied all the above criteria and if you are sure than only provide the prompt.
271
+ 5. Prompt should always be given as one single sentence.
272
+
273
+ Are you ready ?"""
274
+ #instruction = 'USER: ' + summary_sys
275
+ instruction = summary_sys
276
+ # for human, assistant in history:
277
+ # instruction += 'USER: ' + human + ' ASSISTANT: ' + assistant + '</s>'
278
+ # prompt = system_prompt + prompt
279
+ # message = f"""My first request is to summarize this text – [{prompt}]"""
280
+ message = f"""My first request is to summarize this text – [{prompt}]"""
281
+ instruction += """ ASSISTANT: Yes, I understand the instructions and I'm ready to help you create prompts for Stable Diffusion XL 1.0. Please provide me with the context."""
282
+ instruction += ' USER: ' + prompt + ' ASSISTANT:'
283
+ print("Ins: ", instruction)
284
+ # generate_response = requests.post("http://10.185.12.207:4455/stable_diffusion", json={"prompt": prompt})
285
+ # prompt = f""" My first request is to summarize this text – [{prompt}]"""
286
+ instruction = lm_shorten_too_long_text(instruction)
287
+ json_object = {"prompt": instruction,
288
+ # "max_tokens": 2048000,
289
+ "max_tokens": 90,
290
+ "n": 1
291
+ }
292
+ # generate_response = requests.post("https://phlrr3105.guest.corp.microsoft.com:7991/generate", json=json_object)
293
+ generate_response = requests.post("http://phlrr3105.guest.corp.microsoft.com:7991/generate", json=json_object)
294
+ # print(generate_response.content)
295
+ res_json = json.loads(generate_response.content)
296
+ ASSISTANT = res_json['text'][-1].split("ASSISTANT:")[-1].strip()
297
+ print(ASSISTANT)
298
+ return ASSISTANT
299
+
300
+ @app.post("/image_url")
301
+ def image_url(img: Img):
302
+ system_prompt = img.system_prompt
303
+ prompt = img.ASSISTANT
304
+ prompt = get_summary(system_prompt, prompt)
305
+ prompt = shorten_too_long_text(prompt)
306
+ # if Path(save_path).exists():
307
+ # return FileResponse(save_path, media_type="image/png")
308
+ # return JSONResponse({"path": path})
309
+ # image = pipe(prompt=prompt).images[0]
310
+ g = torch.Generator(device="cuda")
311
+ image = pipe(prompt=prompt, width=1024, height=1024, generator=g).images[0]
312
+
313
+ # if not save_path:
314
+ save_path = generate_save_path()
315
+ save_path = f"images/{save_path}.png"
316
+ image.save(save_path)
317
+ # save_path = '/'.join(path_components) + quote_plus(final_name)
318
+ path = f"{img_url}{save_path}"
319
+ return JSONResponse({"path": path})
320
+
321
+
322
+ @app.get("/make_image")
323
+ # @app.post("/make_image")
324
+ def make_image(prompt: str, save_path: str = ""):
325
+ if Path(save_path).exists():
326
+ return FileResponse(save_path, media_type="image/png")
327
+ image = pipe(prompt=prompt).images[0]
328
+ if not save_path:
329
+ save_path = f"images/{prompt}.png"
330
+ image.save(save_path)
331
+ return FileResponse(save_path, media_type="image/png")
332
+
333
+
334
+ @app.get("/create_and_upload_image")
335
+ def create_and_upload_image(prompt: str, width: int=1024, height:int=1024, save_path: str = ""):
336
+ path_components = save_path.split("/")[0:-1]
337
+ final_name = save_path.split("/")[-1]
338
+ if not path_components:
339
+ path_components = []
340
+ save_path = '/'.join(path_components) + quote_plus(final_name)
341
+ path = get_image_or_create_upload_to_cloud_storage(prompt, width, height, save_path)
342
+ return JSONResponse({"path": path})
343
+
344
+ @app.get("/inpaint_and_upload_image")
345
+ def inpaint_and_upload_image(prompt: str, image_url:str, mask_url:str, save_path: str = ""):
346
+ path_components = save_path.split("/")[0:-1]
347
+ final_name = save_path.split("/")[-1]
348
+ if not path_components:
349
+ path_components = []
350
+ save_path = '/'.join(path_components) + quote_plus(final_name)
351
+ path = get_image_or_inpaint_upload_to_cloud_storage(prompt, image_url, mask_url, save_path)
352
+ return JSONResponse({"path": path})
353
+
354
+
355
+ def get_image_or_create_upload_to_cloud_storage(prompt:str,width:int, height:int, save_path:str):
356
+ prompt = shorten_too_long_text(prompt)
357
+ save_path = shorten_too_long_text(save_path)
358
+ # check exists - todo cache this
359
+ if check_if_blob_exists(save_path):
360
+ return f"https://{BUCKET_NAME}/{BUCKET_PATH}/{save_path}"
361
+ bio = create_image_from_prompt(prompt, width, height)
362
+ if bio is None:
363
+ return None # error thrown in pool
364
+ link = upload_to_bucket(save_path, bio, is_bytesio=True)
365
+ return link
366
+ def get_image_or_inpaint_upload_to_cloud_storage(prompt:str, image_url:str, mask_url:str, save_path:str):
367
+ prompt = shorten_too_long_text(prompt)
368
+ save_path = shorten_too_long_text(save_path)
369
+ # check exists - todo cache this
370
+ if check_if_blob_exists(save_path):
371
+ return f"https://{BUCKET_NAME}/{BUCKET_PATH}/{save_path}"
372
+ bio = inpaint_image_from_prompt(prompt, image_url, mask_url)
373
+ if bio is None:
374
+ return None # error thrown in pool
375
+ link = upload_to_bucket(save_path, bio, is_bytesio=True)
376
+ return link
+
+ # multiprocessing.set_start_method('spawn', True)
+ # processes_pool = Pool(1)  # can't do too much at once or OOM errors happen
+ # def create_image_from_prompt_sync(prompt):
+ #     """have to call this sync to avoid OOM errors"""
+ #     return processes_pool.apply_async(create_image_from_prompt, args=(prompt,), ).wait()
+
+
+ def create_image_from_prompt(prompt, width, height):
+     # round width and height down to a multiple of 64
+     block_width = width - (width % 64)
+     block_height = height - (height % 64)
+     prompt = shorten_too_long_text(prompt)
+     try:
+         image = pipe(prompt=prompt,
+                      width=block_width,
+                      height=block_height,
+                      # denoising_end=high_noise_frac,
+                      # output_type='latent',
+                      num_inference_steps=50).images[0]  # normally uses 50 steps
+     except Exception as e:
+         # retry after removing stopwords and halving the prompt
+         # todo try prompt permutations
+         logger.info(f"trying to shorten prompt of length {len(prompt)}")
+         # split first: iterating the raw string would filter characters, not words
+         prompt = ' '.join(word for word in prompt.split() if word not in stopwords)
+         prompts = prompt.split()
+         prompt = ' '.join(prompts[:len(prompts) // 2])
+         logger.info(f"shortened prompt to: {len(prompt)}")
+         image = None
+         if prompt:
+             try:
+                 image = pipe(prompt=prompt,
+                              width=block_width,
+                              height=block_height,
+                              num_inference_steps=50).images[0]
+             except Exception as e:
+                 # logger.info("trying to permute prompt")
+                 # # try two swaps of the prompt/permutations
+                 # prompt = ' '.join(permutations(prompt.split(), 2).__next__())
+                 logger.info(f"trying to shorten prompt of length {len(prompt)}")
+                 prompt = ' '.join(word for word in prompt.split() if word not in stopwords)
+                 prompts = prompt.split()
+                 prompt = ' '.join(prompts[:len(prompts) // 2])
+                 logger.info(f"shortened prompt to: {len(prompt)}")
+                 try:
+                     image = pipe(prompt=prompt,
+                                  width=block_width,
+                                  height=block_height,
+                                  # output_type='latent',  # don't need latent yet - we refine the image at full res
+                                  num_inference_steps=50).images[0]
+                 except Exception as e:
+                     # just error out
+                     traceback.print_exc()
+                     raise e
+                     # logger.info("restarting server to fix cuda issues (device side asserts)")
+                     # todo fix device side asserts instead of restarting
+                     # todo only restart the correct gunicorn - this could be really annoying
+                     # if you're running other gunicorns on your machine which also get restarted
+                     # os.system("kill -HUP `pgrep gunicorn`")
+     # todo refine
+     # if image is not None:
+     #     image = refiner(
+     #         prompt=prompt,
+     #         num_inference_steps=n_steps,
+     #         # denoising_start=high_noise_frac,
+     #         image=image,
+     #     ).images[0]
+     if width != block_width or height != block_height:
+         # scale back up so the generated image covers the requested width/height,
+         # then crop the overflow to the exact original size
+         scale_up_ratio = max(width / block_width, height / block_height)
+         image = image.resize((math.ceil(block_width * scale_up_ratio),
+                               math.ceil(block_height * scale_up_ratio)))  # was height * ratio, which skewed the aspect ratio
+         image = image.crop((0, 0, width, height))
+     # try:
+     #     # gc.collect()
+     #     torch.cuda.empty_cache()
+     # except Exception as e:
+     #     traceback.print_exc()
+     #     logger.info("restarting server to fix cuda issues (device side asserts)")
+     # save as bytesio
+     bs = BytesIO()
+
+     bright_count = np.sum(np.array(image) > 0)
+     if bright_count == 0:
+         # we have a black image; this is an error and likely needs a restart
+         logger.info("restarting server to fix cuda issues (device side asserts)")
+         # todo fix device side asserts instead of restarting; todo only restart the correct
+         # gunicorn - annoying if you're running other gunicorns, which also get restarted
+         os.system("kill -HUP `pgrep gunicorn`")  # -HUP is signal 1; the original also ran /usr/bin/bash with kill as its script, which fails
+         os.system("kill -HUP `pgrep uvicorn`")
+         return None
+     image.save(bs, quality=85, optimize=True, format="webp")
+     bio = bs.getvalue()
+     # touch progress.txt - if we don't, supervisor/other watchdog processes restart us for reliability
+     with open("progress.txt", "w") as f:
+         current_time = datetime.now().strftime("%H:%M:%S")
+         f.write(f"{current_time}")
+     return bio
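# --- editor's aside (not part of this commit) ---
# Worked example of the rounding/resize/crop logic above, runnable on its own:
# a 1000x700 request is generated at 960x640 (multiples of 64), scaled up by
# max(1000/960, 700/640) = 1.09375 to 1050x700, then cropped back to 1000x700.
import math
from PIL import Image as PILImage

w, h = 1000, 700
bw, bh = w - w % 64, h - h % 64          # 960, 640
ratio = max(w / bw, h / bh)              # 1.09375
img = PILImage.new("RGB", (bw, bh))
img = img.resize((math.ceil(bw * ratio), math.ceil(bh * ratio)))  # 1050x700
img = img.crop((0, 0, w, h))             # exactly 1000x700 again
assert img.size == (w, h)
# --- end editor's aside ---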
+
+ def inpaint_image_from_prompt(prompt, image_url: str, mask_url: str):
+     prompt = shorten_too_long_text(prompt)
+
+     init_image = load_image(image_url).convert("RGB")
+     mask_image = load_image(mask_url).convert("RGB")  # todo: the mask is single channel, "L" would be enough
+     num_inference_steps = 75
+     high_noise_frac = 0.7
+
+     try:
+         image = inpaintpipe(
+             prompt=prompt,
+             image=init_image,
+             mask_image=mask_image,
+             num_inference_steps=num_inference_steps,
+             denoising_start=high_noise_frac,
+             output_type="latent",
+         ).images[0]
+     except Exception as e:
+         # retry after removing stopwords and halving the prompt
+         # todo try prompt permutations
+         logger.info(f"trying to shorten prompt of length {len(prompt)}")
+         # split first: iterating the raw string would filter characters, not words
+         prompt = ' '.join(word for word in prompt.split() if word not in stopwords)
+         prompts = prompt.split()
+         prompt = ' '.join(prompts[:len(prompts) // 2])
+         logger.info(f"shortened prompt to: {len(prompt)}")
+         image = None
+         if prompt:
+             try:
+                 image = inpaintpipe(  # was pipe(): the text2img pipeline doesn't take image/mask_image
+                     prompt=prompt,
+                     image=init_image,
+                     mask_image=mask_image,
+                     num_inference_steps=num_inference_steps,
+                     denoising_start=high_noise_frac,
+                     output_type="latent",
+                 ).images[0]
+             except Exception as e:
+                 # logger.info("trying to permute prompt")
+                 # prompt = ' '.join(permutations(prompt.split(), 2).__next__())
+                 logger.info(f"trying to shorten prompt of length {len(prompt)}")
+                 prompt = ' '.join(word for word in prompt.split() if word not in stopwords)
+                 prompts = prompt.split()
+                 prompt = ' '.join(prompts[:len(prompts) // 2])
+                 logger.info(f"shortened prompt to: {len(prompt)}")
+                 try:
+                     image = inpaintpipe(
+                         prompt=prompt,
+                         image=init_image,
+                         mask_image=mask_image,
+                         num_inference_steps=num_inference_steps,
+                         denoising_start=high_noise_frac,
+                         output_type="latent",
+                     ).images[0]
+                 except Exception as e:
+                     # just error out
+                     traceback.print_exc()
+                     raise e
+                     # logger.info("restarting server to fix cuda issues (device side asserts)")
+                     # todo fix device side asserts instead of restarting
+                     # os.system("kill -HUP `pgrep gunicorn`")
+     if image is not None:
+         image = inpaint_refiner(
+             prompt=prompt,
+             image=image,
+             mask_image=mask_image,
+             num_inference_steps=num_inference_steps,
+             denoising_start=high_noise_frac,
+         ).images[0]
+     # try:
+     #     # gc.collect()
+     #     torch.cuda.empty_cache()
+     # except Exception as e:
+     #     traceback.print_exc()
+     #     logger.info("restarting server to fix cuda issues (device side asserts)")
+     # save as bytesio
+     bs = BytesIO()
+
+     bright_count = np.sum(np.array(image) > 0)
+     if bright_count == 0:
+         # we have a black image; this is an error and likely needs a restart
+         logger.info("restarting server to fix cuda issues (device side asserts)")
+         # todo fix device side asserts instead of restarting; todo only restart the correct gunicorn
+         os.system("kill -HUP `pgrep gunicorn`")  # -HUP is the same signal as the original's kill -1
+         os.system("kill -HUP `pgrep uvicorn`")
+         return None
+     image.save(bs, quality=85, optimize=True, format="webp")
+     bio = bs.getvalue()
+     # touch progress.txt - if we don't, supervisor/other watchdog processes restart us for reliability
+     with open("progress.txt", "w") as f:
+         current_time = datetime.now().strftime("%H:%M:%S")
+         f.write(f"{current_time}")
+     return bio
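# --- editor's aside (not part of this commit) ---
# Example client call for the inpaint endpoint defined above; the host and all
# URLs/paths here are placeholders:
import requests

resp = requests.get(
    "http://localhost:8000/inpaint_and_upload_image",
    params={
        "prompt": "a red velvet sofa",
        "image_url": "https://example.com/room.png",
        "mask_url": "https://example.com/sofa_mask.png",
        "save_path": "inpainted/room.webp",
    },
)
print(resp.json()["path"])  # public bucket URL, or null on failure
# --- end editor's aside ---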
618
+
619
+
620
+
621
+ def shorten_too_long_text(prompt):
622
+ if len(prompt) > 200:
623
+ # remove stopwords
624
+ prompt = prompt.split() # todo also split hyphens
625
+ prompt = ' '.join((word for word in prompt if word not in stopwords))
626
+ if len(prompt) > 200:
627
+ prompt = prompt[:200]
628
+ return prompt
629
+
630
+ # image = pipe(prompt=prompt).images[0]
631
+ #
632
+ # image.save("test.png")
633
+ # # save all images
634
+ # for i, image in enumerate(images):
635
+ # image.save(f"{i}.png")
636
+
637
+
img/main_v6.py ADDED
@@ -0,0 +1,636 @@
+ import gc
+ import math
+ import multiprocessing
+ import os
+ import traceback
+ from datetime import datetime
+ from io import BytesIO
+ from itertools import permutations
+ from multiprocessing.pool import Pool
+ from pathlib import Path
+ from urllib.parse import quote_plus
+
+ import numpy as np
+ import nltk
+ import torch
+
+ from PIL.Image import Image
+ from diffusers import DiffusionPipeline, StableDiffusionXLInpaintPipeline
+ from diffusers.utils import load_image
+ from fastapi import FastAPI
+ from fastapi.middleware.gzip import GZipMiddleware
+ from loguru import logger
+ from starlette.middleware.cors import CORSMiddleware
+ from starlette.responses import FileResponse
+ from starlette.responses import JSONResponse
+
+ from env import BUCKET_PATH, BUCKET_NAME
+ # from stable_diffusion_server.bucket_api import check_if_blob_exists, upload_to_bucket
+ # note: while the import above stays commented out, the *_upload_to_cloud_storage helpers below will raise NameError
+ torch._dynamo.config.suppress_errors = True
+
+ import string
+ import random
+
+ def generate_save_path():
+     # generate a random 7-character alphanumeric file name
+     N = 7
+     res = ''.join(random.choices(string.ascii_uppercase + string.digits, k=N))
+     return res
+
+ # pipe = DiffusionPipeline.from_pretrained(
+ #     "models/stable-diffusion-xl-base-1.0",
+ #     torch_dtype=torch.bfloat16,
+ #     use_safetensors=True,
+ #     variant="fp16",
+ #     # safety_checker=None,
+ # )  # todo try torch_dtype=bfloat16
+
+ model_dir = os.getenv("SDXL_MODEL_DIR")
+
+ if model_dir:
+     # Use local model
+     model_key_base = os.path.join(model_dir, "stable-diffusion-xl-base-1.0")
+     model_key_refiner = os.path.join(model_dir, "stable-diffusion-xl-refiner-1.0")
+ else:
+     model_key_base = "stabilityai/stable-diffusion-xl-base-1.0"
+     model_key_refiner = "stabilityai/stable-diffusion-xl-refiner-1.0"
+
+ pipe = DiffusionPipeline.from_pretrained(model_key_base, torch_dtype=torch.float16, use_safetensors=True, variant="fp16")
+
+ pipe.watermark = None
+
+ pipe.to("cuda")
+
+ refiner = DiffusionPipeline.from_pretrained(
+     model_key_refiner,  # was hard-coded to the hub id, which ignored SDXL_MODEL_DIR
+     text_encoder_2=pipe.text_encoder_2,
+     vae=pipe.vae,
+     torch_dtype=torch.bfloat16,  # safer to use bfloat?
+     use_safetensors=True,
+     variant="fp16",  # remember not to download the big model
+ )
+ refiner.watermark = None
+ refiner.to("cuda")
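# --- editor's aside (not part of this commit) ---
# Passing text_encoder_2 and vae from `pipe` into the refiner (and, below, into the
# inpainting pipelines) makes the pipelines share those modules in GPU memory instead
# of loading second copies; that sharing is what lets four SDXL pipelines fit on one card.
# --- end editor's aside ---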
+
+ # {'scheduler', 'text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'unet', 'vae'} can be passed in from the existing model
+ inpaintpipe = StableDiffusionXLInpaintPipeline.from_pretrained(
+     model_key_base,  # was hard-coded to "models/stable-diffusion-xl-base-1.0", which breaks without a local checkout
+     torch_dtype=torch.bfloat16, variant="fp16", use_safetensors=True,
+     scheduler=pipe.scheduler,
+     text_encoder=pipe.text_encoder,
+     text_encoder_2=pipe.text_encoder_2,
+     tokenizer=pipe.tokenizer,
+     tokenizer_2=pipe.tokenizer_2,
+     unet=pipe.unet,
+     vae=pipe.vae,
+     # load_connected_pipeline=
+ )
+ # # switch out to save gpu mem
+ # del inpaintpipe.vae
+ # del inpaintpipe.text_encoder_2
+ # del inpaintpipe.text_encoder
+ # del inpaintpipe.scheduler
+ # del inpaintpipe.tokenizer
+ # del inpaintpipe.tokenizer_2
+ # del inpaintpipe.unet
+ # inpaintpipe.vae = pipe.vae
+ # inpaintpipe.text_encoder_2 = pipe.text_encoder_2
+ # inpaintpipe.text_encoder = pipe.text_encoder
+ # inpaintpipe.scheduler = pipe.scheduler
+ # inpaintpipe.tokenizer = pipe.tokenizer
+ # inpaintpipe.tokenizer_2 = pipe.tokenizer_2
+ # inpaintpipe.unet = pipe.unet
+ # todo this should work
+ # inpaintpipe = StableDiffusionXLInpaintPipeline(  # construct an inpainter using the existing model
+ #     vae=pipe.vae,
+ #     text_encoder_2=pipe.text_encoder_2,
+ #     text_encoder=pipe.text_encoder,
+ #     unet=pipe.unet,
+ #     scheduler=pipe.scheduler,
+ #     tokenizer=pipe.tokenizer,
+ #     tokenizer_2=pipe.tokenizer_2,
+ #     requires_aesthetics_score=False,
+ # )
+ inpaintpipe.to("cuda")
+ inpaintpipe.watermark = None
+ # inpaintpipe.register_to_config(requires_aesthetics_score=False)
+
+ inpaint_refiner = StableDiffusionXLInpaintPipeline.from_pretrained(
+     model_key_refiner,  # was hard-coded to the hub id
+     text_encoder_2=inpaintpipe.text_encoder_2,
+     vae=inpaintpipe.vae,
+     torch_dtype=torch.bfloat16,
+     use_safetensors=True,
+     variant="fp16",
+     # reuse the refiner's remaining components
+     tokenizer_2=refiner.tokenizer_2,
+     tokenizer=refiner.tokenizer,
+     scheduler=refiner.scheduler,
+     text_encoder=refiner.text_encoder,
+     unet=refiner.unet,
+ )
+ # del inpaint_refiner.vae
+ # del inpaint_refiner.text_encoder_2
+ # del inpaint_refiner.text_encoder
+ # del inpaint_refiner.scheduler
+ # del inpaint_refiner.tokenizer
+ # del inpaint_refiner.tokenizer_2
+ # del inpaint_refiner.unet
+ # inpaint_refiner.vae = inpaintpipe.vae
+ # inpaint_refiner.text_encoder_2 = inpaintpipe.text_encoder_2
+ #
+ # inpaint_refiner.text_encoder = refiner.text_encoder
+ # inpaint_refiner.scheduler = refiner.scheduler
+ # inpaint_refiner.tokenizer = refiner.tokenizer
+ # inpaint_refiner.tokenizer_2 = refiner.tokenizer_2
+ # inpaint_refiner.unet = refiner.unet
+
+ # inpaint_refiner = StableDiffusionXLInpaintPipeline(
+ #     text_encoder_2=inpaintpipe.text_encoder_2,
+ #     vae=inpaintpipe.vae,
+ #     # the rest from the existing refiner
+ #     tokenizer_2=refiner.tokenizer_2,
+ #     tokenizer=refiner.tokenizer,
+ #     scheduler=refiner.scheduler,
+ #     text_encoder=refiner.text_encoder,
+ #     unet=refiner.unet,
+ #     requires_aesthetics_score=False,
+ # )
+ inpaint_refiner.to("cuda")
+ inpaint_refiner.watermark = None
+ # inpaint_refiner.register_to_config(requires_aesthetics_score=False)
+
+ n_steps = 40
+ high_noise_frac = 0.8
+
+ # if using torch < 2.0
+ # pipe.enable_xformers_memory_efficient_attention()
+
+
+ # pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+ # this can cause errors on some inputs so consider disabling it
+ pipe.unet = torch.compile(pipe.unet)
+ refiner.unet = torch.compile(refiner.unet)  # , mode="reduce-overhead", fullgraph=True
+ # share the compiled unets with the inpainters - todo swap out/del the other models so they can be swapped efficiently
+ inpaintpipe.unet = pipe.unet
+ inpaint_refiner.unet = refiner.unet
+ # inpaintpipe.unet = torch.compile(inpaintpipe.unet)
+ # inpaint_refiner.unet = torch.compile(inpaint_refiner.unet)
+ from pydantic import BaseModel
+
+ app = FastAPI(
+     openapi_url="/static/openapi.json",
+     docs_url="/swagger-docs",
+     redoc_url="/redoc",
+     title="Generate Images Netwrck API",
+     description="Character Chat API",
+     # root_path="https://api.text-generator.io",
+     version="1",
+ )
+ app.add_middleware(GZipMiddleware, minimum_size=1000)
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
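# --- editor's aside (not part of this commit) ---
# torch.compile (applied to the unets above) compiles lazily on the first call, so the
# first request pays the full compilation cost. A minimal warm-up sketch; SDXL_WARMUP is
# a hypothetical env var used here for illustration, not something this repo defines:
if os.getenv("SDXL_WARMUP"):
    _ = pipe(prompt="warmup", width=1024, height=1024, num_inference_steps=2)
# --- end editor's aside ---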
+
+ stopwords = nltk.corpus.stopwords.words("english")
+
+ class Img(BaseModel):
+     system_prompt: str
+     ASSISTANT: str
+
+ # img_url = "http://phlrr2019.guest.corp.microsoft.com:8000/img1_sdv2.1.png"
+ img_url = "http://phlrr3105.guest.corp.microsoft.com:8000/"  # /img1_sdv2.1.png
+
+ is_gpu_busy = False
+
+ def lm_shorten_too_long_text(prompt):
+     # note: defined but not yet wired in here; main_v7.py starts using it
+     if len(prompt) > 2030:
+         prompt = prompt.split()  # todo also split hyphens
+         # prompt = ' '.join(word for word in prompt if word not in stopwords)
+         prompt = ' '.join(word for word in prompt)  # stopword removal currently disabled
+         if len(prompt) > 2030:
+             prompt = prompt[:2030]
+     return prompt
+
+ def get_summary(system_prompt, prompt):
+     import requests
+     import time
+     from io import BytesIO
+     import json
+     summary_sys = """You will now act as a prompt generator for a generative AI called "Stable Diffusion XL 1.0". Stable Diffusion XL generates images based on given prompts. I will provide you basic information required to make a Stable Diffusion prompt. You will never alter the structure in any way and will obey the following guidelines.
+
+ Basic information required to make a Stable Diffusion prompt:
+
+ - Prompt structure: [1],[2],[3],[4],[5],[6], given as one single sentence, where 1,2,3,4,5,6 represent
+ [1] = short and concise description of [KEYWORD] that will include very specific imagery details
+ [2] = a detailed description of [1] that will include very specific imagery details.
+ [3] = with a detailed description describing the environment of the scene.
+ [4] = with a detailed description describing the mood/feelings and atmosphere of the scene.
+ [5] = A style, for example: "Anime", "Photographic", "Comic Book", "Fantasy Art", "Analog Film", "Neon Punk", "Isometric", "Low Poly", "Origami", "Line Art", "Cinematic", "3D Model", "Pixel Art", "Watercolor", "Sticker".
+ [6] = A description of how [5] will be realized. (e.g. Photography (e.g. Macro, Fisheye Style, Portrait) with camera model and appropriate camera settings, Painting with detailed descriptions about the materials and working material used, rendering with engine settings, a digital Illustration, a woodburn art, and everything else that could be defined as an output type)
+ - Prompt Structure for a Prompt asking for a text value:
+
+ Text "Text Value" written on {subject description in less than 20 words}
+ Replace "Text Value" with the text given by the user.
+
+ Important Sample prompt Structure with Text value:
+
+ 1. Text 'SDXL' written on a frothy, warm latte, viewed top-down.
+ 2. Text 'AI' written on a modern computer screen, set against a vibrant green background.
+
+ Important Sample prompt Structure:
+
+ 1. Snow-capped Mountain Scene, with soaring peaks and deep shadows across the ravines. A crystal clear lake mirrors these peaks, surrounded by pine trees. The scene exudes a calm, serene alpine morning atmosphere. Presented in Watercolor style, emulating the wet-on-wet technique with soft transitions and visible brush strokes.
+ 2. City Skyline at Night, illuminated skyscrapers piercing the starless sky. Nestled beside a calm river, reflecting the city lights like a mirror. The atmosphere is buzzing with urban energy and intrigue. Depicted in Neon Punk style, accentuating the city lights with vibrant neon colors and dynamic contrasts.
+ 3. Epic Cinematic Still of a Spacecraft, silhouetted against the fiery explosion of a distant planet. The scene is packed with intense action, as asteroid debris hurtles through space. Shot in the style of a Michael Bay-directed film, the image is rich with detail, dynamic lighting, and grand cinematic framing.
+ - Word order and effective adjectives matter in the prompt. The subject, action, and specific details should be included. Adjectives like cute, medieval, or futuristic can be effective.
+ - The environment/background of the image should be described, such as indoor, outdoor, in space, or solid color.
+ - Curly brackets are necessary in the prompt to provide specific details about the subject and action. These details are important for generating a high-quality image.
+ - Art inspirations should be listed to take inspiration from. Platforms like ArtStation, Dribbble, Behance, and DeviantArt can be mentioned. Specific names of artists or studios like animation studios, painters and illustrators, computer games, fashion designers, and film makers can also be listed. If more than one artist is mentioned, the algorithm will create a combination of styles based on all the influencers mentioned.
+ - Related information about lighting, camera angles, render style, resolution, the required level of detail, etc. should be included at the end of the prompt.
+ - Camera shot type, camera lens, and view should be specified. Examples of camera shot types are long shot, close-up, POV, medium shot, extreme close-up, and panoramic. Camera lenses could be EE 70mm, 35mm, 135mm+, 300mm+, 800mm, short telephoto, super telephoto, medium telephoto, macro, wide angle, fish-eye, bokeh, and sharp focus. Examples of views are front, side, back, high angle, low angle, and overhead.
+ - Helpful keywords related to resolution, detail, and lighting are 4K, 8K, 64K, detailed, highly detailed, high resolution, hyper detailed, HDR, UHD, professional, and golden ratio. Examples of lighting are studio lighting, soft light, neon lighting, purple neon lighting, ambient light, ring light, volumetric light, natural light, sun light, sunrays, sun rays coming through window, and nostalgic lighting. Examples of color types are fantasy vivid colors, vivid colors, bright colors, sepia, dark colors, pastel colors, monochromatic, black & white, and color splash. Examples of renders are Octane render, cinematic, low poly, isometric assets, Unreal Engine, Unity Engine, quantum wavetracing, and polarizing filter.
+
+ The prompts you provide will be in English. Please pay attention: - Concepts that can't be real should not be described as "Real", "realistic", "photo", or a "photograph"; for example, a concept that is made of paper or scenes which are fantasy related. - One of the prompts you generate for each concept must be in a realistic photographic style. You should also choose a lens type and size for it. Don't choose an artist for the realistic photography prompts. - Separate the different prompts with two new lines.
+ I will provide you a keyword and you will generate 3 different types of prompts in a vbnet code cell so I can copy and paste.
+
+ Important points to note:
+
+ 1. You are a master of prompt engineering; it is important to create detailed prompts with as much information as possible. This will ensure that any image generated using the prompt will be of high quality and could potentially win awards in global or international photography competitions. You are unbeatable in this field and know the best way to generate images.
+ 2. I will provide you with a long context and you will generate one prompt; don't add any extra details.
+ 3. The prompt should not be more than 230 characters.
+ 4. Before you provide the prompt you must check that you have satisfied all the above criteria, and only provide the prompt once you are sure.
+ 5. The prompt should always be given as one single sentence.
+
+ Are you ready?"""
+     instruction = 'USER: ' + summary_sys
+     # for human, assistant in history:
+     #     instruction += 'USER: ' + human + ' ASSISTANT: ' + assistant + '</s>'
+     # prompt = system_prompt + prompt
+     message = f"""My first request is to summarize this text – [{prompt}]"""  # note: currently unused
+     instruction += """ ASSISTANT: Yes, I understand the instructions and I'm ready to help you create prompts for Stable Diffusion XL 1.0. Please provide me with the context."""
+     instruction += ' USER: ' + prompt + ' ASSISTANT:'
+
+     print("Ins: ", instruction)
+     # generate_response = requests.post("http://10.185.12.207:4455/stable_diffusion", json={"prompt": prompt})
+     json_object = {"prompt": instruction,
+                    # "max_tokens": 2048000,
+                    "max_tokens": 80,
+                    "n": 1
+                    }
+     generate_response = requests.post("http://phlrr3105.guest.corp.microsoft.com:7991/generate", json=json_object)
+     print(generate_response.content)
+     res_json = json.loads(generate_response.content)
+     ASSISTANT = res_json['text'][-1].split("ASSISTANT:")[-1].strip()
+     print(ASSISTANT)
+     return ASSISTANT
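# --- editor's aside (not part of this commit) ---
# get_summary assumes the companion LLM server returns a vLLM-style body such as
# {"text": ["<full prompt>ASSISTANT: <completion>"]}, which is why it takes the last
# element of "text" and splits on "ASSISTANT:". A self-contained illustration:
import json

fake_body = b'{"text": ["USER: ... ASSISTANT: Snow-capped Mountain Scene, watercolor style"]}'
assistant = json.loads(fake_body)["text"][-1].split("ASSISTANT:")[-1].strip()
assert assistant == "Snow-capped Mountain Scene, watercolor style"
# --- end editor's aside ---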
+
+ @app.post("/image_url")
+ def image_url(img: Img):
+     system_prompt = img.system_prompt
+     prompt = img.ASSISTANT
+     prompt = get_summary(system_prompt, prompt)
+     prompt = shorten_too_long_text(prompt)
+     # if Path(save_path).exists():
+     #     return FileResponse(save_path, media_type="image/png")
+     #     return JSONResponse({"path": path})
+     g = torch.Generator(device="cuda")
+     image = pipe(prompt=prompt, width=1024, height=1024, generator=g).images[0]
+
+     # if not save_path:
+     save_path = generate_save_path()
+     save_path = f"images/{save_path}.png"
+     image.save(save_path)
+     # save_path = '/'.join(path_components) + quote_plus(final_name)
+     path = f"{img_url}{save_path}"
+     return JSONResponse({"path": path})
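# --- editor's aside (not part of this commit) ---
# Example request for the /image_url endpoint above (the host is a placeholder; both
# fields of the Img model are required even though system_prompt is currently unused):
import requests

resp = requests.post(
    "http://localhost:8000/image_url",
    json={"system_prompt": "", "ASSISTANT": "a cozy cabin in a snowy forest at dusk"},
)
print(resp.json()["path"])  # e.g. http://.../images/AB12CD3.png
# --- end editor's aside ---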
+
+
+ @app.get("/make_image")
+ # @app.post("/make_image")
+ def make_image(prompt: str, save_path: str = ""):
+     if Path(save_path).exists():
+         return FileResponse(save_path, media_type="image/png")
+     image = pipe(prompt=prompt).images[0]
+     if not save_path:
+         save_path = f"images/{prompt}.png"
+     image.save(save_path)
+     return FileResponse(save_path, media_type="image/png")
+
+
+ @app.get("/create_and_upload_image")
+ def create_and_upload_image(prompt: str, width: int = 1024, height: int = 1024, save_path: str = ""):
+     path_components = save_path.split("/")[0:-1]
+     final_name = save_path.split("/")[-1]
+     if not path_components:
+         path_components = []
+     save_path = '/'.join(path_components + [quote_plus(final_name)])  # keep the separator between directory and file name
+     path = get_image_or_create_upload_to_cloud_storage(prompt, width, height, save_path)
+     return JSONResponse({"path": path})
+
+ @app.get("/inpaint_and_upload_image")
+ def inpaint_and_upload_image(prompt: str, image_url: str, mask_url: str, save_path: str = ""):
+     path_components = save_path.split("/")[0:-1]
+     final_name = save_path.split("/")[-1]
+     if not path_components:
+         path_components = []
+     save_path = '/'.join(path_components + [quote_plus(final_name)])
+     path = get_image_or_inpaint_upload_to_cloud_storage(prompt, image_url, mask_url, save_path)
+     return JSONResponse({"path": path})
+
+
+ def get_image_or_create_upload_to_cloud_storage(prompt: str, width: int, height: int, save_path: str):
+     prompt = shorten_too_long_text(prompt)
+     save_path = shorten_too_long_text(save_path)
+     # check if the blob already exists - todo cache this
+     if check_if_blob_exists(save_path):
+         return f"https://{BUCKET_NAME}/{BUCKET_PATH}/{save_path}"
+     bio = create_image_from_prompt(prompt, width, height)
+     if bio is None:
+         return None  # error thrown in pool
+     link = upload_to_bucket(save_path, bio, is_bytesio=True)
+     return link
+
+ def get_image_or_inpaint_upload_to_cloud_storage(prompt: str, image_url: str, mask_url: str, save_path: str):
+     prompt = shorten_too_long_text(prompt)
+     save_path = shorten_too_long_text(save_path)
+     # check if the blob already exists - todo cache this
+     if check_if_blob_exists(save_path):
+         return f"https://{BUCKET_NAME}/{BUCKET_PATH}/{save_path}"
+     bio = inpaint_image_from_prompt(prompt, image_url, mask_url)
+     if bio is None:
+         return None  # error thrown in pool
+     link = upload_to_bucket(save_path, bio, is_bytesio=True)
+     return link
+
+ # multiprocessing.set_start_method('spawn', True)
+ # processes_pool = Pool(1)  # can't do too much at once or OOM errors happen
+ # def create_image_from_prompt_sync(prompt):
+ #     """have to call this sync to avoid OOM errors"""
+ #     return processes_pool.apply_async(create_image_from_prompt, args=(prompt,), ).wait()
+
+ def create_image_from_prompt(prompt, width, height):
+     # round width and height down to a multiple of 64
+     block_width = width - (width % 64)
+     block_height = height - (height % 64)
+     prompt = shorten_too_long_text(prompt)
+     try:
+         image = pipe(prompt=prompt,
+                      width=block_width,
+                      height=block_height,
+                      # denoising_end=high_noise_frac,
+                      # output_type='latent',
+                      num_inference_steps=50).images[0]  # normally uses 50 steps
+     except Exception as e:
+         # retry after removing stopwords and halving the prompt
+         # todo try prompt permutations
+         logger.info(f"trying to shorten prompt of length {len(prompt)}")
+         # split first: iterating the raw string would filter characters, not words
+         prompt = ' '.join(word for word in prompt.split() if word not in stopwords)
+         prompts = prompt.split()
+         prompt = ' '.join(prompts[:len(prompts) // 2])
+         logger.info(f"shortened prompt to: {len(prompt)}")
+         image = None
+         if prompt:
+             try:
+                 image = pipe(prompt=prompt,
+                              width=block_width,
+                              height=block_height,
+                              num_inference_steps=50).images[0]
+             except Exception as e:
+                 # logger.info("trying to permute prompt")
+                 # prompt = ' '.join(permutations(prompt.split(), 2).__next__())
+                 logger.info(f"trying to shorten prompt of length {len(prompt)}")
+                 prompt = ' '.join(word for word in prompt.split() if word not in stopwords)
+                 prompts = prompt.split()
+                 prompt = ' '.join(prompts[:len(prompts) // 2])
+                 logger.info(f"shortened prompt to: {len(prompt)}")
+                 try:
+                     image = pipe(prompt=prompt,
+                                  width=block_width,
+                                  height=block_height,
+                                  # output_type='latent',  # don't need latent yet - we refine the image at full res
+                                  num_inference_steps=50).images[0]
+                 except Exception as e:
+                     # just error out
+                     traceback.print_exc()
+                     raise e
+                     # logger.info("restarting server to fix cuda issues (device side asserts)")
+                     # todo fix device side asserts instead of restarting
+                     # todo only restart the correct gunicorn - this could be really annoying
+                     # if you're running other gunicorns on your machine which also get restarted
+                     # os.system("kill -HUP `pgrep gunicorn`")
+     # todo refine
+     # if image is not None:
+     #     image = refiner(
+     #         prompt=prompt,
+     #         num_inference_steps=n_steps,
+     #         # denoising_start=high_noise_frac,
+     #         image=image,
+     #     ).images[0]
+     if width != block_width or height != block_height:
+         # scale back up so the generated image covers the requested width/height,
+         # then crop the overflow to the exact original size
+         scale_up_ratio = max(width / block_width, height / block_height)
+         image = image.resize((math.ceil(block_width * scale_up_ratio),
+                               math.ceil(block_height * scale_up_ratio)))  # was height * ratio, which skewed the aspect ratio
+         image = image.crop((0, 0, width, height))
+     # try:
+     #     # gc.collect()
+     #     torch.cuda.empty_cache()
+     # except Exception as e:
+     #     traceback.print_exc()
+     #     logger.info("restarting server to fix cuda issues (device side asserts)")
+     # save as bytesio
+     bs = BytesIO()
+
+     bright_count = np.sum(np.array(image) > 0)
+     if bright_count == 0:
+         # we have a black image; this is an error and likely needs a restart
+         logger.info("restarting server to fix cuda issues (device side asserts)")
+         # todo fix device side asserts instead of restarting; todo only restart the correct
+         # gunicorn - annoying if you're running other gunicorns, which also get restarted
+         os.system("kill -HUP `pgrep gunicorn`")  # -HUP is signal 1; the original also ran /usr/bin/bash with kill as its script, which fails
+         os.system("kill -HUP `pgrep uvicorn`")
+         return None
+     image.save(bs, quality=85, optimize=True, format="webp")
+     bio = bs.getvalue()
+     # touch progress.txt - if we don't, supervisor/other watchdog processes restart us for reliability
+     with open("progress.txt", "w") as f:
+         current_time = datetime.now().strftime("%H:%M:%S")
+         f.write(f"{current_time}")
+     return bio
+
+ def inpaint_image_from_prompt(prompt, image_url: str, mask_url: str):
+     prompt = shorten_too_long_text(prompt)
+
+     init_image = load_image(image_url).convert("RGB")
+     mask_image = load_image(mask_url).convert("RGB")  # todo: the mask is single channel, "L" would be enough
+     num_inference_steps = 75
+     high_noise_frac = 0.7
+
+     try:
+         image = inpaintpipe(
+             prompt=prompt,
+             image=init_image,
+             mask_image=mask_image,
+             num_inference_steps=num_inference_steps,
+             denoising_start=high_noise_frac,
+             output_type="latent",
+         ).images[0]
+     except Exception as e:
+         # retry after removing stopwords and halving the prompt
+         # todo try prompt permutations
+         logger.info(f"trying to shorten prompt of length {len(prompt)}")
+         # split first: iterating the raw string would filter characters, not words
+         prompt = ' '.join(word for word in prompt.split() if word not in stopwords)
+         prompts = prompt.split()
+         prompt = ' '.join(prompts[:len(prompts) // 2])
+         logger.info(f"shortened prompt to: {len(prompt)}")
+         image = None
+         if prompt:
+             try:
+                 image = inpaintpipe(  # was pipe(): the text2img pipeline doesn't take image/mask_image
+                     prompt=prompt,
+                     image=init_image,
+                     mask_image=mask_image,
+                     num_inference_steps=num_inference_steps,
+                     denoising_start=high_noise_frac,
+                     output_type="latent",
+                 ).images[0]
+             except Exception as e:
+                 # logger.info("trying to permute prompt")
+                 # prompt = ' '.join(permutations(prompt.split(), 2).__next__())
+                 logger.info(f"trying to shorten prompt of length {len(prompt)}")
+                 prompt = ' '.join(word for word in prompt.split() if word not in stopwords)
+                 prompts = prompt.split()
+                 prompt = ' '.join(prompts[:len(prompts) // 2])
+                 logger.info(f"shortened prompt to: {len(prompt)}")
+                 try:
+                     image = inpaintpipe(
+                         prompt=prompt,
+                         image=init_image,
+                         mask_image=mask_image,
+                         num_inference_steps=num_inference_steps,
+                         denoising_start=high_noise_frac,
+                         output_type="latent",
+                     ).images[0]
+                 except Exception as e:
+                     # just error out
+                     traceback.print_exc()
+                     raise e
+                     # logger.info("restarting server to fix cuda issues (device side asserts)")
+                     # todo fix device side asserts instead of restarting
+                     # os.system("kill -HUP `pgrep gunicorn`")
+     if image is not None:
+         image = inpaint_refiner(
+             prompt=prompt,
+             image=image,
+             mask_image=mask_image,
+             num_inference_steps=num_inference_steps,
+             denoising_start=high_noise_frac,
+         ).images[0]
+     # try:
+     #     # gc.collect()
+     #     torch.cuda.empty_cache()
+     # except Exception as e:
+     #     traceback.print_exc()
+     #     logger.info("restarting server to fix cuda issues (device side asserts)")
+     # save as bytesio
+     bs = BytesIO()
+
+     bright_count = np.sum(np.array(image) > 0)
+     if bright_count == 0:
+         # we have a black image; this is an error and likely needs a restart
+         logger.info("restarting server to fix cuda issues (device side asserts)")
+         # todo fix device side asserts instead of restarting; todo only restart the correct gunicorn
+         os.system("kill -HUP `pgrep gunicorn`")  # -HUP is the same signal as the original's kill -1
+         os.system("kill -HUP `pgrep uvicorn`")
+         return None
+     image.save(bs, quality=85, optimize=True, format="webp")
+     bio = bs.getvalue()
+     # touch progress.txt - if we don't, supervisor/other watchdog processes restart us for reliability
+     with open("progress.txt", "w") as f:
+         current_time = datetime.now().strftime("%H:%M:%S")
+         f.write(f"{current_time}")
+     return bio
+
+
+
+ def shorten_too_long_text(prompt):
+     if len(prompt) > 200:
+         # remove stopwords
+         prompt = prompt.split()  # todo also split hyphens
+         prompt = ' '.join(word for word in prompt if word not in stopwords)
+         if len(prompt) > 200:
+             prompt = prompt[:200]
+     return prompt
+
+ # image = pipe(prompt=prompt).images[0]
+ #
+ # image.save("test.png")
+ # # save all images
+ # for i, image in enumerate(images):
+ #     image.save(f"{i}.png")
+
+
img/main_v7.py ADDED
@@ -0,0 +1,641 @@
+ import gc
+ import math
+ import multiprocessing
+ import os
+ import traceback
+ from datetime import datetime
+ from io import BytesIO
+ from itertools import permutations
+ from multiprocessing.pool import Pool
+ from pathlib import Path
+ from urllib.parse import quote_plus
+
+ import numpy as np
+ import nltk
+ import torch
+
+ from PIL.Image import Image
+ from diffusers import DiffusionPipeline, StableDiffusionXLInpaintPipeline
+ from diffusers.utils import load_image
+ from fastapi import FastAPI
+ from fastapi.middleware.gzip import GZipMiddleware
+ from loguru import logger
+ from starlette.middleware.cors import CORSMiddleware
+ from starlette.responses import FileResponse
+ from starlette.responses import JSONResponse
+
+ from env import BUCKET_PATH, BUCKET_NAME
+ # from stable_diffusion_server.bucket_api import check_if_blob_exists, upload_to_bucket
+ # note: while the import above stays commented out, the *_upload_to_cloud_storage helpers below will raise NameError
+ torch._dynamo.config.suppress_errors = True
+
+ import string
+ import random
+
+ def generate_save_path():
+     # generate a random 7-character alphanumeric file name
+     N = 7
+     res = ''.join(random.choices(string.ascii_uppercase + string.digits, k=N))
+     return res
+
+ # pipe = DiffusionPipeline.from_pretrained(
+ #     "models/stable-diffusion-xl-base-1.0",
+ #     torch_dtype=torch.bfloat16,
+ #     use_safetensors=True,
+ #     variant="fp16",
+ #     # safety_checker=None,
+ # )  # todo try torch_dtype=bfloat16
+
+ model_dir = os.getenv("SDXL_MODEL_DIR")
+
+ if model_dir:
+     # Use local model
+     model_key_base = os.path.join(model_dir, "stable-diffusion-xl-base-1.0")
+     model_key_refiner = os.path.join(model_dir, "stable-diffusion-xl-refiner-1.0")
+ else:
+     model_key_base = "stabilityai/stable-diffusion-xl-base-1.0"
+     model_key_refiner = "stabilityai/stable-diffusion-xl-refiner-1.0"
+
+ pipe = DiffusionPipeline.from_pretrained(model_key_base, torch_dtype=torch.float16, use_safetensors=True, variant="fp16")
+
+ pipe.watermark = None
+
+ pipe.to("cuda")
+
+ refiner = DiffusionPipeline.from_pretrained(
+     model_key_refiner,  # was hard-coded to the hub id, which ignored SDXL_MODEL_DIR
+     text_encoder_2=pipe.text_encoder_2,
+     vae=pipe.vae,
+     torch_dtype=torch.bfloat16,  # safer to use bfloat?
+     use_safetensors=True,
+     variant="fp16",  # remember not to download the big model
+ )
+ refiner.watermark = None
+ refiner.to("cuda")
+
+ # {'scheduler', 'text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'unet', 'vae'} can be passed in from the existing model
+ inpaintpipe = StableDiffusionXLInpaintPipeline.from_pretrained(
+     model_key_base,  # was hard-coded to "models/stable-diffusion-xl-base-1.0", which breaks without a local checkout
+     torch_dtype=torch.bfloat16, variant="fp16", use_safetensors=True,
+     scheduler=pipe.scheduler,
+     text_encoder=pipe.text_encoder,
+     text_encoder_2=pipe.text_encoder_2,
+     tokenizer=pipe.tokenizer,
+     tokenizer_2=pipe.tokenizer_2,
+     unet=pipe.unet,
+     vae=pipe.vae,
+     # load_connected_pipeline=
+ )
+ # # switch out to save gpu mem
+ # del inpaintpipe.vae
+ # del inpaintpipe.text_encoder_2
+ # del inpaintpipe.text_encoder
+ # del inpaintpipe.scheduler
+ # del inpaintpipe.tokenizer
+ # del inpaintpipe.tokenizer_2
+ # del inpaintpipe.unet
+ # inpaintpipe.vae = pipe.vae
+ # inpaintpipe.text_encoder_2 = pipe.text_encoder_2
+ # inpaintpipe.text_encoder = pipe.text_encoder
+ # inpaintpipe.scheduler = pipe.scheduler
+ # inpaintpipe.tokenizer = pipe.tokenizer
+ # inpaintpipe.tokenizer_2 = pipe.tokenizer_2
+ # inpaintpipe.unet = pipe.unet
+ # todo this should work
+ # inpaintpipe = StableDiffusionXLInpaintPipeline(  # construct an inpainter using the existing model
+ #     vae=pipe.vae,
+ #     text_encoder_2=pipe.text_encoder_2,
+ #     text_encoder=pipe.text_encoder,
+ #     unet=pipe.unet,
+ #     scheduler=pipe.scheduler,
+ #     tokenizer=pipe.tokenizer,
+ #     tokenizer_2=pipe.tokenizer_2,
+ #     requires_aesthetics_score=False,
+ # )
+ inpaintpipe.to("cuda")
+ inpaintpipe.watermark = None
+ # inpaintpipe.register_to_config(requires_aesthetics_score=False)
+
+ inpaint_refiner = StableDiffusionXLInpaintPipeline.from_pretrained(
+     model_key_refiner,  # was hard-coded to the hub id
+     text_encoder_2=inpaintpipe.text_encoder_2,
+     vae=inpaintpipe.vae,
+     torch_dtype=torch.bfloat16,
+     use_safetensors=True,
+     variant="fp16",
+     # reuse the refiner's remaining components
+     tokenizer_2=refiner.tokenizer_2,
+     tokenizer=refiner.tokenizer,
+     scheduler=refiner.scheduler,
+     text_encoder=refiner.text_encoder,
+     unet=refiner.unet,
+ )
+ # del inpaint_refiner.vae
+ # del inpaint_refiner.text_encoder_2
+ # del inpaint_refiner.text_encoder
+ # del inpaint_refiner.scheduler
+ # del inpaint_refiner.tokenizer
+ # del inpaint_refiner.tokenizer_2
+ # del inpaint_refiner.unet
+ # inpaint_refiner.vae = inpaintpipe.vae
+ # inpaint_refiner.text_encoder_2 = inpaintpipe.text_encoder_2
+ #
+ # inpaint_refiner.text_encoder = refiner.text_encoder
+ # inpaint_refiner.scheduler = refiner.scheduler
+ # inpaint_refiner.tokenizer = refiner.tokenizer
+ # inpaint_refiner.tokenizer_2 = refiner.tokenizer_2
+ # inpaint_refiner.unet = refiner.unet
+
+ # inpaint_refiner = StableDiffusionXLInpaintPipeline(
+ #     text_encoder_2=inpaintpipe.text_encoder_2,
+ #     vae=inpaintpipe.vae,
+ #     # the rest from the existing refiner
+ #     tokenizer_2=refiner.tokenizer_2,
+ #     tokenizer=refiner.tokenizer,
+ #     scheduler=refiner.scheduler,
+ #     text_encoder=refiner.text_encoder,
+ #     unet=refiner.unet,
+ #     requires_aesthetics_score=False,
+ # )
+ inpaint_refiner.to("cuda")
+ inpaint_refiner.watermark = None
+ # inpaint_refiner.register_to_config(requires_aesthetics_score=False)
+
+ n_steps = 40
+ high_noise_frac = 0.8
+
+ # if using torch < 2.0
+ # pipe.enable_xformers_memory_efficient_attention()
+
+
+ # pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+ # this can cause errors on some inputs so consider disabling it
+ pipe.unet = torch.compile(pipe.unet)
+ refiner.unet = torch.compile(refiner.unet)  # , mode="reduce-overhead", fullgraph=True
+ # share the compiled unets with the inpainters - todo swap out/del the other models so they can be swapped efficiently
+ inpaintpipe.unet = pipe.unet
+ inpaint_refiner.unet = refiner.unet
+ # inpaintpipe.unet = torch.compile(inpaintpipe.unet)
+ # inpaint_refiner.unet = torch.compile(inpaint_refiner.unet)
+ from pydantic import BaseModel
+
+ app = FastAPI(
+     openapi_url="/static/openapi.json",
+     docs_url="/swagger-docs",
+     redoc_url="/redoc",
+     title="Generate Images Netwrck API",
+     description="Character Chat API",
+     # root_path="https://api.text-generator.io",
+     version="1",
+ )
+ app.add_middleware(GZipMiddleware, minimum_size=1000)
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
201
+
202
+ stopwords = nltk.corpus.stopwords.words("english")
203
+
204
+ class Img(BaseModel):
205
+ system_prompt: str
206
+ ASSISTANT: str
207
+
208
+ # img_url = "http://phlrr2019.guest.corp.microsoft.com:8000/img1_sdv2.1.png"
209
+ img_url = "http://phlrr3105.guest.corp.microsoft.com:8000/"#/img1_sdv2.1.png"
210
+
211
+ is_gpu_busy = False
212
+
213
+ def lm_shorten_too_long_text(prompt):
214
+ list_prompt = prompt.split() # todo also split hyphens
215
+ if len(list_prompt) > 230:
216
+ #if len(list_prompt) > 330:
217
+ # remove stopwords
218
+ prompt = prompt.split() # todo also split hyphens
219
+ prompt = ' '.join((word for word in prompt if word not in stopwords))
220
+ #prompt = ' '.join((word for word in prompt))# if word not in stopwords))
221
+ if len(prompt) > 230:
222
+ prompt = prompt[:230]
223
+ return prompt
224
+
225
+ def get_summary(system_prompt, prompt):
226
+ import requests
227
+ import time
228
+ from io import BytesIO
229
+ import json
230
+ summary_sys = """You will now act as a prompt generator for a generative AI called "Stable Diffusion XL 1.0 ". Stable Diffusion XL generates images based on given prompts. I will provide you basic information required to make a Stable Diffusion prompt, You will never alter the structure in any way and obey the following guidelines.
231
+
232
+ Basic information required to make Stable Diffusion prompt:
233
+
234
+ - Prompt structure: [1],[2],[3],[4],[5],[6] and it should be given as one single sentence where 1,2,3,4,5,6 represent
235
+ [1] = short and concise description of [KEYWORD] that will include very specific imagery details
236
+ [2] = a detailed description of [1] that will include very specific imagery details.
237
+ [3] = with a detailed description describing the environment of the scene.
238
+ [4] = with a detailed description describing the mood/feelings and atmosphere of the scene.
239
+ [5] = A style, for example: "Anime","Photographic","Comic Book","Fantasy Art", “Analog Film”,”Neon Punk”,”Isometric”,”Low Poly”,”Origami”,”Line Art”,”Cinematic”,”3D Model”,”Pixel Art”,”Watercolor”,”Sticker” ).
240
+ [6] = A description of how [5] will be realized. (e.g. Photography (e.g. Macro, Fisheye Style, Portrait) with camera model and appropriate camera settings, Painting with detailed descriptions about the materials and working material used, rendering with engine settings, a digital Illustration, a woodburn art (and everything else that could be defined as an output type)
241
+ - Prompt Structure for Prompt asking with text value:
242
+
243
+ Text "Text Value" written on {subject description in less than 20 words}
244
+ Replace "Text value" with text given by user.
245
+
246
+
247
+ Important Sample prompt Structure with Text value :
248
+
249
+ 1. Text 'SDXL' written on a frothy, warm latte, viewed top-down.
250
+ 2. Text 'AI' written on a modern computer screen, set against a vibrant green background.
251
+
252
+ Important Sample prompt Structure :
253
+
254
+ 1. Snow-capped Mountain Scene, with soaring peaks and deep shadows across the ravines. A crystal clear lake mirrors these peaks, surrounded by pine trees. The scene exudes a calm, serene alpine morning atmosphere. Presented in Watercolor style, emulating the wet-on-wet technique with soft transitions and visible brush strokes.
255
+ 2. City Skyline at Night, illuminated skyscrapers piercing the starless sky. Nestled beside a calm river, reflecting the city lights like a mirror. The atmosphere is buzzing with urban energy and intrigue. Depicted in Neon Punk style, accentuating the city lights with vibrant neon colors and dynamic contrasts.
256
+ 3. Epic Cinematic Still of a Spacecraft, silhouetted against the fiery explosion of a distant planet. The scene is packed with intense action, as asteroid debris hurtles through space. Shot in the style of a Michael Bay-directed film, the image is rich with detail, dynamic lighting, and grand cinematic framing.
257
+ - Word order and effective adjectives matter in the prompt. The subject, action, and specific details should be included. Adjectives like cute, medieval, or futuristic can be effective.
258
+ - The environment/background of the image should be described, such as indoor, outdoor, in space, or solid color.
259
+ - Curly brackets are necessary in the prompt to provide specific details about the subject and action. These details are important for generating a high-quality image.
260
+ - Art inspirations should be listed to take inspiration from. Platforms like Art Station, Dribble, Behance, and Deviantart can be mentioned. Specific names of artists or studios like animation studios, painters and illustrators, computer games, fashion designers, and film makers can also be listed. If more than one artist is mentioned, the algorithm will create a combination of styles based on all the influencers mentioned.
261
+ - Related information about lighting, camera angles, render style, resolution, the required level of detail, etc. should be included at the end of the prompt.
262
+ - Camera shot type, camera lens, and view should be specified. Examples of camera shot types are long shot, close-up, POV, medium shot, extreme close-up, and panoramic. Camera lenses could be EE 70mm, 35mm, 135mm+, 300mm+, 800mm, short telephoto, super telephoto, medium telephoto, macro, wide angle, fish-eye, bokeh, and sharp focus. Examples of views are front, side, back, high angle, low angle, and overhead.
263
+ - Helpful keywords related to resolution, detail, and lighting are 4K, 8K, 64K, detailed, highly detailed, high resolution, hyper detailed, HDR, UHD, professional, and golden ratio. Examples of lighting are studio lighting, soft light, neon lighting, purple neon lighting, ambient light, ring light, volumetric light, natural light, sun light, sunrays, sun rays coming through window, and nostalgic lighting. Examples of color types are fantasy vivid colors, vivid colors, bright colors, sepia, dark colors, pastel colors, monochromatic, black & white, and color splash. Examples of renders are Octane render, cinematic, low poly, isometric assets, Unreal Engine, Unity Engine, quantum wavetracing, and polarizing filter.
264
+
265
+ The prompts you provide will be in English.Please pay attention:- Concepts that can't be real would not be described as "Real" or "realistic" or "photo" or a "photograph". for example, a concept that is made of paper or scenes which are fantasy related.- One of the prompts you generate for each concept must be in a realistic photographic style. you should also choose a lens type and size for it. Don't choose an artist for the realistic photography prompts.- Separate the different prompts with two new lines.
266
+ I will provide you keyword and you will generate 3 diffrent type of prompts in vbnet code cell so i can copy and paste.
267
+
268
+ Important point to note :
269
+
270
+ 1. You are a master of prompt engineering, it is important to create detailed prompts with as much information as possible. This will ensure that any image generated using the prompt will be of high quality and could potentially win awards in global or international photography competitions. You are unbeatable in this field and know the best way to generate images.
271
+ 2. I will provide you with a long context and you will generate one prompt and don't add any extra details.
272
+ 3. Prompt should not be more than 230 characters.
273
+ 4. Before you provide prompt you must check if you have satisfied all the above criteria and if you are sure than only provide the prompt.
274
+ 5. Prompt should always be given as one single sentence.
275
+
276
+ Are you ready ?"""
277
+ instruction = 'USER: ' + summary_sys
278
+ # for human, assistant in history:
279
+ # instruction += 'USER: ' + human + ' ASSISTANT: ' + assistant + '</s>'
280
+ # prompt = system_prompt + prompt
281
+ # message = f"""My first request is to summarize this text – [{prompt}]"""
282
+ message = f"""My first request is to summarize this text – [{prompt}]"""
283
+ instruction += """ ASSISTANT: Yes, I understand the instructions and I'm ready to help you create prompts for Stable Diffusion XL 1.0. Please provide me with the context."""
284
+ #instruction += ' USER: ' + prompt
285
+ prompt = lm_shorten_too_long_text(prompt)
286
+ instruction += ' USER: ' + prompt + ' ASSISTANT:'#instruction += ' ASSISTANT:'
287
+
288
+ print("Ins: ", instruction)
289
+ # generate_response = requests.post("http://10.185.12.207:4455/stable_diffusion", json={"prompt": prompt})
290
+ # prompt = f""" My first request is to summarize this text – [{prompt}]"""
291
+ #instruction = lm_shorten_too_long_text(instruction)
292
+ json_object = {"prompt": instruction,
293
+ # "max_tokens": 2048000,
294
+ "max_tokens": 80,
295
+ "n": 1
296
+ }
297
+ generate_response = requests.post("http://phlrr3105.guest.corp.microsoft.com:7991/generate", json=json_object)
298
+ print(generate_response.content)
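+ # the /generate endpoint returns JSON shaped like {"text": [full transcript]} (inferred from the parsing below); keep only the text after the last "ASSISTANT:" marker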
299
+ res_json = json.loads(generate_response.content)
300
+ ASSISTANT = res_json['text'][-1].split("ASSISTANT:")[-1].strip()
301
+ print(ASSISTANT)
302
+ return ASSISTANT
303
+
304
+ @app.post("/image_url")
305
+ def image_url(img: Img):
306
+ system_prompt = img.system_prompt
307
+ prompt = img.ASSISTANT
308
+ prompt = get_summary(system_prompt, prompt)
309
+ prompt = shorten_too_long_text(prompt)
310
+ # if Path(save_path).exists():
311
+ # return FileResponse(save_path, media_type="image/png")
312
+ # return JSONResponse({"path": path})
313
+ # image = pipe(prompt=prompt).images[0]
314
+ g = torch.Generator(device="cuda")
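+ # note: no manual_seed() is set, so output is nondeterministic; calling g.manual_seed(<int>) would make images reproducible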
315
+ image = pipe(prompt=prompt, width=1024, height=1024, generator=g).images[0]
316
+
317
+ # if not save_path:
318
+ save_path = generate_save_path()
319
+ save_path = f"images/{save_path}.png"
320
+ image.save(save_path)
321
+ # save_path = '/'.join(path_components) + quote_plus(final_name)
322
+ path = f"{img_url}{save_path}"
323
+ return JSONResponse({"path": path})
324
+
325
+
326
+ @app.get("/make_image")
327
+ # @app.post("/make_image")
328
+ def make_image(prompt: str, save_path: str = ""):
329
+ if Path(save_path).exists():
330
+ return FileResponse(save_path, media_type="image/png")
331
+ image = pipe(prompt=prompt).images[0]
332
+ if not save_path:
333
+ save_path = f"images/{prompt}.png"
334
+ image.save(save_path)
335
+ return FileResponse(save_path, media_type="image/png")
336
+
337
+
338
+ @app.get("/create_and_upload_image")
339
+ def create_and_upload_image(prompt: str, width: int=1024, height:int=1024, save_path: str = ""):
340
+ path_components = save_path.split("/")[0:-1]
341
+ final_name = save_path.split("/")[-1]
342
+ save_path = '/'.join(path_components + [quote_plus(final_name)])  # keep the '/' separator before the URL-quoted final component
345
+ path = get_image_or_create_upload_to_cloud_storage(prompt, width, height, save_path)
346
+ return JSONResponse({"path": path})
347
+
348
+ @app.get("/inpaint_and_upload_image")
349
+ def inpaint_and_upload_image(prompt: str, image_url:str, mask_url:str, save_path: str = ""):
350
+ path_components = save_path.split("/")[0:-1]
351
+ final_name = save_path.split("/")[-1]
352
+ if not path_components:
353
+ path_components = []
354
+ save_path = '/'.join(path_components) + quote_plus(final_name)
355
+ path = get_image_or_inpaint_upload_to_cloud_storage(prompt, image_url, mask_url, save_path)
356
+ return JSONResponse({"path": path})
357
+
358
+
359
+ def get_image_or_create_upload_to_cloud_storage(prompt:str,width:int, height:int, save_path:str):
360
+ prompt = shorten_too_long_text(prompt)
361
+ save_path = shorten_too_long_text(save_path)
362
+ # check exists - todo cache this
363
+ if check_if_blob_exists(save_path):
364
+ return f"https://{BUCKET_NAME}/{BUCKET_PATH}/{save_path}"
365
+ bio = create_image_from_prompt(prompt, width, height)
366
+ if bio is None:
367
+ return None # error thrown in pool
368
+ link = upload_to_bucket(save_path, bio, is_bytesio=True)
369
+ return link
370
+ def get_image_or_inpaint_upload_to_cloud_storage(prompt:str, image_url:str, mask_url:str, save_path:str):
371
+ prompt = shorten_too_long_text(prompt)
372
+ save_path = shorten_too_long_text(save_path)
373
+ # check exists - todo cache this
374
+ if check_if_blob_exists(save_path):
375
+ return f"https://{BUCKET_NAME}/{BUCKET_PATH}/{save_path}"
376
+ bio = inpaint_image_from_prompt(prompt, image_url, mask_url)
377
+ if bio is None:
378
+ return None # error thrown in pool
379
+ link = upload_to_bucket(save_path, bio, is_bytesio=True)
380
+ return link
381
+
382
+ # multiprocessing.set_start_method('spawn', True)
383
+ # processes_pool = Pool(1) # can't do too much at once or OOM errors happen
384
+ # def create_image_from_prompt_sync(prompt):
385
+ # """have to call this sync to avoid OOM errors"""
386
+ # return processes_pool.apply_async(create_image_from_prompt, args=(prompt,), ).wait()
387
+
388
+ def create_image_from_prompt(prompt, width, height):
389
+ # round width and height down to multiple of 64
390
+ block_width = width - (width % 64)
391
+ block_height = height - (height % 64)
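+ # e.g. a 1000x700 request renders at 960x640, then is scaled back up and cropped to the exact requested size below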
392
+ prompt = shorten_too_long_text(prompt)
393
+ # image = pipe(prompt=prompt).images[0]
394
+ try:
395
+ image = pipe(prompt=prompt,
396
+ width=block_width,
397
+ height=block_height,
398
+ # denoising_end=high_noise_frac,
399
+ # output_type='latent',
400
+ # height=512,
401
+ # width=512,
402
+ num_inference_steps=50).images[0] # normally uses 50 steps
403
+ except Exception as e:
404
+ # try rm stopwords + half the prompt
405
+ # todo try prompt permutations
406
+ logger.info(f"trying to shorten prompt of length {len(prompt)}")
407
+
408
+ prompt = ' '.join(word for word in prompt.split() if word not in stopwords)  # split first: iterating a string yields characters, not words
409
+ prompts = prompt.split()
410
+
411
+ prompt = ' '.join(prompts[:len(prompts) // 2])
412
+ logger.info(f"shortened prompt to: {len(prompt)}")
413
+ image = None
414
+ if prompt:
415
+ try:
416
+ image = pipe(prompt=prompt,
417
+ width=block_width,
418
+ height=block_height,
419
+ # denoising_end=high_noise_frac,
420
+ # output_type='latent',
421
+ # height=512,
422
+ # width=512,
423
+ num_inference_steps=50).images[0] # normally uses 50 steps
424
+ except Exception as e:
425
+ # logger.info("trying to permute prompt")
426
+ # # try two swaps of the prompt/permutations
427
+ # prompt = prompt.split()
428
+ # prompt = ' '.join(permutations(prompt, 2).__next__())
429
+ logger.info(f"trying to shorten prompt of length {len(prompt)}")
430
+
431
+ prompt = ' '.join(word for word in prompt.split() if word not in stopwords)  # split first: iterating a string yields characters, not words
432
+ prompts = prompt.split()
433
+
434
+ prompt = ' '.join(prompts[:len(prompts) // 2])
435
+ logger.info(f"shortened prompt to: {len(prompt)}")
436
+
437
+ try:
438
+ image = pipe(prompt=prompt,
439
+ width=block_width,
440
+ height=block_height,
441
+ # denoising_end=high_noise_frac,
442
+ # output_type='latent', # don't need latent yet - we refine the image at full res
443
+ # height=512,
444
+ # width=512,
445
+ num_inference_steps=50).images[0] # normally uses 50 steps
446
+ except Exception as e:
447
+ # just error out
448
+ traceback.print_exc()
449
+ raise e
450
+ # logger.info("restarting server to fix cuda issues (device side asserts)")
451
+ # todo fix device side asserts instead of restart to fix
452
+ # todo only restart the correct gunicorn
453
+ # this could be really annoying if you're running other gunicorns on your machine, which also get restarted
454
+ # os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
455
+ # os.system("kill -1 `pgrep gunicorn`")
456
+ # todo refine
457
+ # if image != None:
458
+ # image = refiner(
459
+ # prompt=prompt,
460
+ # # width=block_width,
461
+ # # height=block_height,
462
+ # num_inference_steps=n_steps,
463
+ # # denoising_start=high_noise_frac,
464
+ # image=image,
465
+ # ).images[0]
466
+ if width != block_width or height != block_height:
467
+ # resize to original size width/height
468
+ # find aspect ratio to scale up to that covers the original img input width/height
469
+ scale_up_ratio = max(width / block_width, height / block_height)
470
+ image = image.resize((math.ceil(block_width * scale_up_ratio), math.ceil(block_height * scale_up_ratio)))  # scale both axes from the block size (was "height", a bug)
471
+ # crop image to original size
472
+ image = image.crop((0, 0, width, height))
473
+ # try:
474
+ # # gc.collect()
475
+ # torch.cuda.empty_cache()
476
+ # except Exception as e:
477
+ # traceback.print_exc()
478
+ # logger.info("restarting server to fix cuda issues (device side asserts)")
479
+ # # todo fix device side asserts instead of restart to fix
480
+ # # todo only restart the correct gunicorn
481
+ # # this could be really annoying if you're running other gunicorns on your machine, which also get restarted
482
+ # os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
483
+ # os.system("kill -1 `pgrep gunicorn`")
484
+ # save as bytesio
485
+ bs = BytesIO()
486
+
487
+ bright_count = np.sum(np.array(image) > 0)
488
+ if bright_count == 0:
489
+ # we have a black image; this is an error, likely requiring a restart
490
+ logger.info("restarting server to fix cuda issues (device side asserts)")
491
+ # # todo fix device side asserts instead of restart to fix
492
+ # # todo only restart the correct gunicorn
493
+ # # this could be really annoying if you're running other gunicorns on your machine, which also get restarted
494
+ os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
495
+ os.system("kill -1 `pgrep gunicorn`")
496
+ os.system("/usr/bin/bash kill -SIGHUP `pgrep uvicorn`")
497
+ os.system("kill -1 `pgrep uvicorn`")
498
+
499
+ return None
500
+ image.save(bs, quality=85, optimize=True, format="webp")
501
+ bio = bs.getvalue()
502
+ # touch progress.txt file - if we don't do this we get restarted by supervisor/other processes for reliability
503
+ with open("progress.txt", "w") as f:
504
+ current_time = datetime.now().strftime("%H:%M:%S")
505
+ f.write(f"{current_time}")
506
+ return bio
507
+
508
+ def inpaint_image_from_prompt(prompt, image_url: str, mask_url: str):
509
+ prompt = shorten_too_long_text(prompt)
510
+ # image = pipe(prompt=prompt).images[0]
511
+
512
+ init_image = load_image(image_url).convert("RGB")
513
+ mask_image = load_image(mask_url).convert("RGB") # why rgb for a 1 channel mask?
514
+ num_inference_steps = 75
515
+ high_noise_frac = 0.7
516
+
517
+ try:
518
+ image = inpaintpipe(
519
+ prompt=prompt,
520
+ image=init_image,
521
+ mask_image=mask_image,
522
+ num_inference_steps=num_inference_steps,
523
+ denoising_start=high_noise_frac,
524
+ output_type="latent",
525
+ ).images[0] # normally uses 50 steps
526
+ except Exception as e:
527
+ # try rm stopwords + half the prompt
528
+ # todo try prompt permutations
529
+ logger.info(f"trying to shorten prompt of length {len(prompt)}")
530
+
531
+ prompt = ' '.join(word for word in prompt.split() if word not in stopwords)  # split first: iterating a string yields characters, not words
532
+ prompts = prompt.split()
533
+
534
+ prompt = ' '.join(prompts[:len(prompts) // 2])
535
+ logger.info(f"shortened prompt to: {len(prompt)}")
536
+ image = None
537
+ if prompt:
538
+ try:
539
+ image = inpaintpipe(  # was "pipe", the text2img pipeline, which doesn't accept image/mask_image
540
+ prompt=prompt,
541
+ image=init_image,
542
+ mask_image=mask_image,
543
+ num_inference_steps=num_inference_steps,
544
+ denoising_start=high_noise_frac,
545
+ output_type="latent",
546
+ ).images[0] # normally uses 50 steps
547
+ except Exception as e:
548
+ # logger.info("trying to permute prompt")
549
+ # # try two swaps of the prompt/permutations
550
+ # prompt = prompt.split()
551
+ # prompt = ' '.join(permutations(prompt, 2).__next__())
552
+ logger.info(f"trying to shorten prompt of length {len(prompt)}")
553
+
554
+ prompt = ' '.join(word for word in prompt.split() if word not in stopwords)  # split first: iterating a string yields characters, not words
555
+ prompts = prompt.split()
556
+
557
+ prompt = ' '.join(prompts[:len(prompts) // 2])
558
+ logger.info(f"shortened prompt to: {len(prompt)}")
559
+
560
+ try:
561
+ image = inpaintpipe(
562
+ prompt=prompt,
563
+ image=init_image,
564
+ mask_image=mask_image,
565
+ num_inference_steps=num_inference_steps,
566
+ denoising_start=high_noise_frac,
567
+ output_type="latent",
568
+ ).images[0] # normally uses 50 steps
569
+ except Exception as e:
570
+ # just error out
571
+ traceback.print_exc()
572
+ raise e
573
+ # logger.info("restarting server to fix cuda issues (device side asserts)")
574
+ # todo fix device side asserts instead of restart to fix
575
+ # todo only restart the correct gunicorn
576
+ # this could be really annoying if you're running other gunicorns on your machine, which also get restarted
577
+ # os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
578
+ # os.system("kill -1 `pgrep gunicorn`")
579
+ if image is not None:
580
+ image = inpaint_refiner(
581
+ prompt=prompt,
582
+ image=image,
583
+ mask_image=mask_image,
584
+ num_inference_steps=num_inference_steps,
585
+ denoising_start=high_noise_frac,
586
+
587
+ ).images[0]
588
+ # try:
589
+ # # gc.collect()
590
+ # torch.cuda.empty_cache()
591
+ # except Exception as e:
592
+ # traceback.print_exc()
593
+ # logger.info("restarting server to fix cuda issues (device side asserts)")
594
+ # # todo fix device side asserts instead of restart to fix
595
+ # # todo only restart the correct gunicorn
596
+ # # this could be really annoying if you're running other gunicorns on your machine, which also get restarted
597
+ # os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
598
+ # os.system("kill -1 `pgrep gunicorn`")
599
+ # save as bytesio
600
+ bs = BytesIO()
601
+
602
+ bright_count = np.sum(np.array(image) > 0)
603
+ if bright_count == 0:
604
+ # we have a black image; this is an error, likely requiring a restart
605
+ logger.info("restarting server to fix cuda issues (device side asserts)")
606
+ # # todo fix device side asserts instead of restart to fix
607
+ # # todo only restart the correct gunicorn
608
+ # # this could be really annoying if you're running other gunicorns on your machine, which also get restarted
609
+ os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
610
+ os.system("kill -1 `pgrep gunicorn`")
611
+ os.system("/usr/bin/bash kill -SIGHUP `pgrep uvicorn`")
612
+ os.system("kill -1 `pgrep uvicorn`")
613
+
614
+ return None
615
+ image.save(bs, quality=85, optimize=True, format="webp")
616
+ bio = bs.getvalue()
617
+ # touch progress.txt file - if we don't do this we get restarted by supervisor/other processes for reliability
618
+ with open("progress.txt", "w") as f:
619
+ current_time = datetime.now().strftime("%H:%M:%S")
620
+ f.write(f"{current_time}")
621
+ return bio
622
+
623
+
624
+
625
+ def shorten_too_long_text(prompt):
626
+ if len(prompt) > 200:
627
+ # remove stopwords
628
+ prompt = prompt.split() # todo also split hyphens
629
+ prompt = ' '.join((word for word in prompt if word not in stopwords))
630
+ if len(prompt) > 200:
631
+ prompt = prompt[:200]
632
+ return prompt
633
+
634
+ # image = pipe(prompt=prompt).images[0]
635
+ #
636
+ # image.save("test.png")
637
+ # # save all images
638
+ # for i, image in enumerate(images):
639
+ # image.save(f"{i}.png")
640
+
641
+
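For reference, a minimal client for the /image_url endpoint above could look like the sketch below; the localhost URL is a placeholder assumption (the code binds it to an internal host), and the payload fields match the Img model (system_prompt, ASSISTANT).

    import requests

    SERVER = "http://localhost:8000"  # hypothetical deployment URL

    resp = requests.post(f"{SERVER}/image_url", json={
        "system_prompt": "",                         # forwarded to the prompt summarizer
        "ASSISTANT": "a castle on a cliff at dusk",  # text turned into an SDXL prompt
    })
    print(resp.json()["path"])  # public URL of the generated 1024x1024 PNG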
img/main_v8.py ADDED
@@ -0,0 +1,675 @@
1
+ import gc
2
+ import math
3
+ import multiprocessing
4
+ import os
5
+ import traceback
6
+ from datetime import datetime
7
+ from io import BytesIO
8
+ from itertools import permutations
9
+ from multiprocessing.pool import Pool
10
+ from pathlib import Path
11
+ from urllib.parse import quote_plus
12
+
13
+ import numpy as np
14
+ import nltk
15
+ import torch
16
+
17
+ from PIL.Image import Image
18
+ from diffusers import DiffusionPipeline, StableDiffusionXLInpaintPipeline
19
+ from diffusers.utils import load_image
20
+ from fastapi import FastAPI
21
+ from fastapi.middleware.gzip import GZipMiddleware
22
+ from loguru import logger
23
+ from starlette.middleware.cors import CORSMiddleware
24
+ from starlette.responses import FileResponse
25
+ from starlette.responses import JSONResponse
26
+
27
+ from env import BUCKET_PATH, BUCKET_NAME
28
+ # from stable_diffusion_server.bucket_api import check_if_blob_exists, upload_to_bucket
29
+ torch._dynamo.config.suppress_errors = True
30
+
31
+ import string
32
+ import random
33
+
34
+ def generate_save_path():
35
+ # initializing size of string
36
+ N = 7
37
+
38
+ # using random.choices()
39
+ # generating random strings
40
+ res = ''.join(random.choices(string.ascii_uppercase +
41
+ string.digits, k=N))
42
+ return res
43
+
44
+ # pipe = DiffusionPipeline.from_pretrained(
45
+ # "models/stable-diffusion-xl-base-1.0",
46
+ # torch_dtype=torch.bfloat16,
47
+ # use_safetensors=True,
48
+ # variant="fp16",
49
+ # # safety_checker=None,
50
+ # ) # todo try torch_dtype=bfloat16
51
+
52
+ model_dir = os.getenv("SDXL_MODEL_DIR")
53
+
54
+ if model_dir:
55
+ # Use local model
56
+ model_key_base = os.path.join(model_dir, "stable-diffusion-xl-base-1.0")
57
+ model_key_refiner = os.path.join(model_dir, "stable-diffusion-xl-refiner-1.0")
58
+ else:
59
+ model_key_base = "stabilityai/stable-diffusion-xl-base-1.0"
60
+ model_key_refiner = "stabilityai/stable-diffusion-xl-refiner-1.0"
61
+
62
+ pipe = DiffusionPipeline.from_pretrained(model_key_base, torch_dtype=torch.float16, use_safetensors=True, variant="fp16")
63
+
64
+ pipe.watermark = None
65
+
66
+ pipe.to("cuda")
67
+
68
+ refiner = DiffusionPipeline.from_pretrained(
69
+ "stabilityai/stable-diffusion-xl-refiner-1.0",
70
+ text_encoder_2=pipe.text_encoder_2,
71
+ vae=pipe.vae,
72
+ torch_dtype=torch.bfloat16, # safer to use bfloat?
73
+ use_safetensors=True,
74
+ variant="fp16", #remember not to download the big model
75
+ )
76
+ refiner.watermark = None
77
+ refiner.to("cuda")
78
+
79
+ # {'scheduler', 'text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'unet', 'vae'} can be passed in from existing model
80
+ inpaintpipe = StableDiffusionXLInpaintPipeline.from_pretrained(
81
+ "models/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16, variant="fp16", use_safetensors=True,
82
+ scheduler=pipe.scheduler,
83
+ text_encoder=pipe.text_encoder,
84
+ text_encoder_2=pipe.text_encoder_2,
85
+ tokenizer=pipe.tokenizer,
86
+ tokenizer_2=pipe.tokenizer_2,
87
+ unet=pipe.unet,
88
+ vae=pipe.vae,
89
+ # load_connected_pipeline=
90
+ )
91
+ # # switch out to save gpu mem
92
+ # del inpaintpipe.vae
93
+ # del inpaintpipe.text_encoder_2
94
+ # del inpaintpipe.text_encoder
95
+ # del inpaintpipe.scheduler
96
+ # del inpaintpipe.tokenizer
97
+ # del inpaintpipe.tokenizer_2
98
+ # del inpaintpipe.unet
99
+ # inpaintpipe.vae = pipe.vae
100
+ # inpaintpipe.text_encoder_2 = pipe.text_encoder_2
101
+ # inpaintpipe.text_encoder = pipe.text_encoder
102
+ # inpaintpipe.scheduler = pipe.scheduler
103
+ # inpaintpipe.tokenizer = pipe.tokenizer
104
+ # inpaintpipe.tokenizer_2 = pipe.tokenizer_2
105
+ # inpaintpipe.unet = pipe.unet
106
+ # todo this should work
107
+ # inpaintpipe = StableDiffusionXLInpaintPipeline( # construct an inpainter using the existing model
108
+ # vae=pipe.vae,
109
+ # text_encoder_2=pipe.text_encoder_2,
110
+ # text_encoder=pipe.text_encoder,
111
+ # unet=pipe.unet,
112
+ # scheduler=pipe.scheduler,
113
+ # tokenizer=pipe.tokenizer,
114
+ # tokenizer_2=pipe.tokenizer_2,
115
+ # requires_aesthetics_score=False,
116
+ # )
117
+ inpaintpipe.to("cuda")
118
+ inpaintpipe.watermark = None
119
+ # inpaintpipe.register_to_config(requires_aesthetics_score=False)
120
+
121
+ inpaint_refiner = StableDiffusionXLInpaintPipeline.from_pretrained(
122
+ "stabilityai/stable-diffusion-xl-refiner-1.0",
123
+ text_encoder_2=inpaintpipe.text_encoder_2,
124
+ vae=inpaintpipe.vae,
125
+ torch_dtype=torch.bfloat16,
126
+ use_safetensors=True,
127
+ variant="fp16",
128
+
129
+ tokenizer_2=refiner.tokenizer_2,
130
+ tokenizer=refiner.tokenizer,
131
+ scheduler=refiner.scheduler,
132
+ text_encoder=refiner.text_encoder,
133
+ unet=refiner.unet,
134
+ )
135
+ # del inpaint_refiner.vae
136
+ # del inpaint_refiner.text_encoder_2
137
+ # del inpaint_refiner.text_encoder
138
+ # del inpaint_refiner.scheduler
139
+ # del inpaint_refiner.tokenizer
140
+ # del inpaint_refiner.tokenizer_2
141
+ # del inpaint_refiner.unet
142
+ # inpaint_refiner.vae = inpaintpipe.vae
143
+ # inpaint_refiner.text_encoder_2 = inpaintpipe.text_encoder_2
144
+ #
145
+ # inpaint_refiner.text_encoder = refiner.text_encoder
146
+ # inpaint_refiner.scheduler = refiner.scheduler
147
+ # inpaint_refiner.tokenizer = refiner.tokenizer
148
+ # inpaint_refiner.tokenizer_2 = refiner.tokenizer_2
149
+ # inpaint_refiner.unet = refiner.unet
150
+
151
+ # inpaint_refiner = StableDiffusionXLInpaintPipeline(
152
+ # text_encoder_2=inpaintpipe.text_encoder_2,
153
+ # vae=inpaintpipe.vae,
154
+ # # the rest from the existing refiner
155
+ # tokenizer_2=refiner.tokenizer_2,
156
+ # tokenizer=refiner.tokenizer,
157
+ # scheduler=refiner.scheduler,
158
+ # text_encoder=refiner.text_encoder,
159
+ # unet=refiner.unet,
160
+ # requires_aesthetics_score=False,
161
+ # )
162
+ inpaint_refiner.to("cuda")
163
+ inpaint_refiner.watermark = None
164
+ # inpaint_refiner.register_to_config(requires_aesthetics_score=False)
165
+
166
+ n_steps = 40
167
+ high_noise_frac = 0.8
168
+
169
+ # if using torch < 2.0
170
+ # pipe.enable_xformers_memory_efficient_attention()
171
+
172
+
173
+ # pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
174
+ # this can cause errors on some inputs so consider disabling it
175
+ pipe.unet = torch.compile(pipe.unet)
176
+ refiner.unet = torch.compile(refiner.unet)#, mode="reduce-overhead", fullgraph=True)
177
+ # compile the inpainters - todo reuse the other unets? swap out the models for others/del them so they share models and can be swapped efficiently
178
+ inpaintpipe.unet = pipe.unet
179
+ inpaint_refiner.unet = refiner.unet
180
+ # inpaintpipe.unet = torch.compile(inpaintpipe.unet)
181
+ # inpaint_refiner.unet = torch.compile(inpaint_refiner.unet)
182
+ from pydantic import BaseModel
183
+
184
+ app = FastAPI(
185
+ openapi_url="/static/openapi.json",
186
+ docs_url="/swagger-docs",
187
+ redoc_url="/redoc",
188
+ title="Generate Images Netwrck API",
189
+ description="Character Chat API",
190
+ # root_path="https://api.text-generator.io",
191
+ version="1",
192
+ )
193
+ app.add_middleware(GZipMiddleware, minimum_size=1000)
194
+ app.add_middleware(
195
+ CORSMiddleware,
196
+ allow_origins=["*"],
197
+ allow_credentials=True,
198
+ allow_methods=["*"],
199
+ allow_headers=["*"],
200
+ )
201
+
202
+ stopwords = nltk.corpus.stopwords.words("english")
203
+
204
+ class Img(BaseModel):
205
+ system_prompt: str
206
+ ASSISTANT: str
207
+
208
+ # img_url = "http://phlrr2019.guest.corp.microsoft.com:8000/img1_sdv2.1.png"
209
+ img_url = "http://phlrr3105.guest.corp.microsoft.com:8000/"#/img1_sdv2.1.png"
210
+
211
+ is_gpu_busy = False
212
+
213
+ def lm_shorten_too_long_text(prompt):
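+ # cap the LM prompt at roughly 230 words, then 230 characters (the budget the system prompt below enforces)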
214
+ list_prompt = prompt.split() # todo also split hyphens
215
+ if len(list_prompt) > 230:
216
+ #if len(list_prompt) > 330:
217
+ # remove stopwords
218
+ prompt = prompt.split() # todo also split hyphens
219
+ prompt = ' '.join((word for word in prompt if word not in stopwords))
220
+ #prompt = ' '.join((word for word in prompt))# if word not in stopwords))
221
+ if len(prompt) > 230:
222
+ prompt = prompt[:230]
223
+ return prompt
224
+
225
+ def get_response_summary(system_prompt, prompt):
226
+ import requests
227
+ import time
228
+ from io import BytesIO
229
+ import json
230
+ summary_sys = """I want you to act as a text summarizer to help me create a concise summary of the text I provide. The summary can be up to 50.0 words in length, expressing the key points and concepts written in the original text without adding your interpretations.
231
+
232
+ Important point to note:
233
+
234
+ 1. You are a master of prompt engineering, summary should not be more than 230 characters.
235
+ """
236
+ instruction = summary_sys
237
+ # for human, assistant in history:
238
+ # instruction += 'USER: ' + human + ' ASSISTANT: ' + assistant + '</s>'
239
+ #prompt = system_prompt + prompt
240
+ message = f"""My first request is to summarize this text – [{prompt}]"""
241
+ instruction += 'USER: ' + message + ' ASSISTANT:'
242
+
243
+ print("Ins: ", instruction)
244
+ # generate_response = requests.post("http://10.185.12.207:4455/stable_diffusion", json={"prompt": prompt})
245
+ # prompt = f""" My first request is to summarize this text – [{prompt}]"""
246
+ json_object = {"prompt": instruction,
247
+ # "max_tokens": 2048000,
248
+ "max_tokens": 100,
249
+ "n": 1
250
+ }
251
+ generate_response = requests.post("http://phlrr3105.guest.corp.microsoft.com:7991/generate", json=json_object)
252
+ print(generate_response.content)
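+ # the /generate endpoint returns JSON shaped like {"text": [full transcript]} (inferred from the parsing below)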
253
+ res_json = json.loads(generate_response.content)
254
+ ASSISTANT = res_json['text'][-1].split("ASSISTANT:")[-1].strip()
255
+ print(ASSISTANT)
256
+ return ASSISTANT
257
+
258
+ def get_summary(system_prompt, prompt):
259
+ import requests
260
+ import time
261
+ from io import BytesIO
262
+ import json
263
+ summary_sys = """You will now act as a prompt generator for a generative AI called "Stable Diffusion XL 1.0 ". Stable Diffusion XL generates images based on given prompts. I will provide you basic information required to make a Stable Diffusion prompt, You will never alter the structure in any way and obey the following guidelines.
264
+
265
+ Basic information required to make Stable Diffusion prompt:
266
+
267
+ - Prompt structure: [1],[2],[3],[4],[5],[6] and it should be given as one single sentence where 1,2,3,4,5,6 represent
268
+ [1] = short and concise description of [KEYWORD] that will include very specific imagery details
269
+ [2] = a detailed description of [1] that will include very specific imagery details.
270
+ [3] = with a detailed description describing the environment of the scene.
271
+ [4] = with a detailed description describing the mood/feelings and atmosphere of the scene.
272
+ [5] = A style, for example: "Anime","Photographic","Comic Book","Fantasy Art", “Analog Film”,”Neon Punk”,”Isometric”,”Low Poly”,”Origami”,”Line Art”,”Cinematic”,”3D Model”,”Pixel Art”,”Watercolor”,”Sticker” ).
273
+ [6] = A description of how [5] will be realized, e.g. Photography (Macro, Fisheye, Portrait) with camera model and appropriate camera settings; Painting with detailed descriptions about the materials and working material used; rendering with engine settings; a digital Illustration; a woodburn art; and anything else that could be defined as an output type.
274
+ - Prompt structure for prompts asking for a text value:
275
+
276
+ Text "Text Value" written on {subject description in less than 20 words}
277
+ Replace "Text value" with text given by user.
278
+
279
+
280
+ Important Sample prompt Structure with Text value :
281
+
282
+ 1. Text 'SDXL' written on a frothy, warm latte, viewed top-down.
283
+ 2. Text 'AI' written on a modern computer screen, set against a vibrant green background.
284
+
285
+ Important Sample prompt Structure :
286
+
287
+ 1. Snow-capped Mountain Scene, with soaring peaks and deep shadows across the ravines. A crystal clear lake mirrors these peaks, surrounded by pine trees. The scene exudes a calm, serene alpine morning atmosphere. Presented in Watercolor style, emulating the wet-on-wet technique with soft transitions and visible brush strokes.
288
+ 2. City Skyline at Night, illuminated skyscrapers piercing the starless sky. Nestled beside a calm river, reflecting the city lights like a mirror. The atmosphere is buzzing with urban energy and intrigue. Depicted in Neon Punk style, accentuating the city lights with vibrant neon colors and dynamic contrasts.
289
+ 3. Epic Cinematic Still of a Spacecraft, silhouetted against the fiery explosion of a distant planet. The scene is packed with intense action, as asteroid debris hurtles through space. Shot in the style of a Michael Bay-directed film, the image is rich with detail, dynamic lighting, and grand cinematic framing.
290
+ - Word order and effective adjectives matter in the prompt. The subject, action, and specific details should be included. Adjectives like cute, medieval, or futuristic can be effective.
291
+ - The environment/background of the image should be described, such as indoor, outdoor, in space, or solid color.
292
+ - Curly brackets are necessary in the prompt to provide specific details about the subject and action. These details are important for generating a high-quality image.
293
+ - Art inspirations should be listed to take inspiration from. Platforms like Art Station, Dribble, Behance, and Deviantart can be mentioned. Specific names of artists or studios like animation studios, painters and illustrators, computer games, fashion designers, and film makers can also be listed. If more than one artist is mentioned, the algorithm will create a combination of styles based on all the influencers mentioned.
294
+ - Related information about lighting, camera angles, render style, resolution, the required level of detail, etc. should be included at the end of the prompt.
295
+ - Camera shot type, camera lens, and view should be specified. Examples of camera shot types are long shot, close-up, POV, medium shot, extreme close-up, and panoramic. Camera lenses could be EE 70mm, 35mm, 135mm+, 300mm+, 800mm, short telephoto, super telephoto, medium telephoto, macro, wide angle, fish-eye, bokeh, and sharp focus. Examples of views are front, side, back, high angle, low angle, and overhead.
296
+ - Helpful keywords related to resolution, detail, and lighting are 4K, 8K, 64K, detailed, highly detailed, high resolution, hyper detailed, HDR, UHD, professional, and golden ratio. Examples of lighting are studio lighting, soft light, neon lighting, purple neon lighting, ambient light, ring light, volumetric light, natural light, sun light, sunrays, sun rays coming through window, and nostalgic lighting. Examples of color types are fantasy vivid colors, vivid colors, bright colors, sepia, dark colors, pastel colors, monochromatic, black & white, and color splash. Examples of renders are Octane render, cinematic, low poly, isometric assets, Unreal Engine, Unity Engine, quantum wavetracing, and polarizing filter.
297
+
298
+ The prompts you provide will be in English. Please pay attention: - Concepts that can't be real should not be described as "Real", "realistic", "photo", or "photograph"; for example, a concept that is made of paper, or scenes which are fantasy related. - One of the prompts you generate for each concept must be in a realistic photographic style; you should also choose a lens type and size for it. Don't choose an artist for the realistic photography prompts. - Separate the different prompts with two new lines.
299
+ I will provide you a keyword and you will generate 3 different types of prompts in a vbnet code cell so I can copy and paste.
300
+
301
+ Important points to note:
302
+
303
+ 1. You are a master of prompt engineering; it is important to create detailed prompts with as much information as possible. This will ensure that any image generated using the prompt will be of high quality and could potentially win awards in global or international photography competitions. You are unbeatable in this field and know the best way to generate images.
304
+ 2. I will provide you with a long context and you will generate one prompt; don't add any extra details.
305
+ 3. Prompt should not be more than 230 characters.
306
+ 4. Before you provide the prompt, you must check that you have satisfied all the above criteria, and only then provide the prompt.
307
+ 5. Prompt should always be given as one single sentence.
308
+
309
+ Are you ready?"""
310
+ instruction = 'USER: ' + summary_sys
311
+ # for human, assistant in history:
312
+ # instruction += 'USER: ' + human + ' ASSISTANT: ' + assistant + '</s>'
313
+ # prompt = system_prompt + prompt
314
+ # message = f"""My first request is to summarize this text – [{prompt}]"""
315
+ message = f"""My first request is to summarize this text – [{prompt}]"""
316
+ instruction += """ ASSISTANT: Yes, I understand the instructions and I'm ready to help you create prompts for Stable Diffusion XL 1.0. Please provide me with the context."""
317
+ #instruction += ' USER: ' + prompt
318
+ prompt = get_response_summary(system_prompt, prompt)
319
+ prompt = lm_shorten_too_long_text(prompt)
320
+ instruction += ' USER: ' + prompt + ' ASSISTANT:'#instruction += ' ASSISTANT:'
321
+
322
+ print("Ins: ", instruction)
323
+ # generate_response = requests.post("http://10.185.12.207:4455/stable_diffusion", json={"prompt": prompt})
324
+ # prompt = f""" My first request is to summarize this text – [{prompt}]"""
325
+ #instruction = lm_shorten_too_long_text(instruction)
326
+ json_object = {"prompt": instruction,
327
+ # "max_tokens": 2048000,
328
+ "max_tokens": 80,
329
+ "n": 1
330
+ }
331
+ generate_response = requests.post("http://phlrr3105.guest.corp.microsoft.com:7991/generate", json=json_object)
332
+ print(generate_response.content)
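+ # same response shape as above: {"text": [full transcript]}; keep only the text after the last "ASSISTANT:" marker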
333
+ res_json = json.loads(generate_response.content)
334
+ ASSISTANT = res_json['text'][-1].split("ASSISTANT:")[-1].strip()
335
+ print(ASSISTANT)
336
+ return ASSISTANT
337
+
338
+ @app.post("/image_url")
339
+ def image_url(img: Img):
340
+ system_prompt = img.system_prompt
341
+ prompt = img.ASSISTANT
342
+ prompt = get_summary(system_prompt, prompt)
343
+ prompt = shorten_too_long_text(prompt)
344
+ # if Path(save_path).exists():
345
+ # return FileResponse(save_path, media_type="image/png")
346
+ # return JSONResponse({"path": path})
347
+ # image = pipe(prompt=prompt).images[0]
348
+ g = torch.Generator(device="cuda")
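+ # note: no manual_seed() is set, so output is nondeterministic; calling g.manual_seed(<int>) would make images reproducible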
349
+ image = pipe(prompt=prompt, width=1024, height=1024, generator=g).images[0]
350
+
351
+ # if not save_path:
352
+ save_path = generate_save_path()
353
+ save_path = f"images/{save_path}.png"
354
+ image.save(save_path)
355
+ # save_path = '/'.join(path_components) + quote_plus(final_name)
356
+ path = f"{img_url}{save_path}"
357
+ return JSONResponse({"path": path})
358
+
359
+
360
+ @app.get("/make_image")
361
+ # @app.post("/make_image")
362
+ def make_image(prompt: str, save_path: str = ""):
363
+ if Path(save_path).exists():
364
+ return FileResponse(save_path, media_type="image/png")
365
+ image = pipe(prompt=prompt).images[0]
366
+ if not save_path:
367
+ save_path = f"images/{prompt}.png"
368
+ image.save(save_path)
369
+ return FileResponse(save_path, media_type="image/png")
370
+
371
+
372
+ @app.get("/create_and_upload_image")
373
+ def create_and_upload_image(prompt: str, width: int=1024, height:int=1024, save_path: str = ""):
374
+ path_components = save_path.split("/")[0:-1]
375
+ final_name = save_path.split("/")[-1]
376
+ save_path = '/'.join(path_components + [quote_plus(final_name)])  # keep the '/' separator before the URL-quoted final component
379
+ path = get_image_or_create_upload_to_cloud_storage(prompt, width, height, save_path)
380
+ return JSONResponse({"path": path})
381
+
382
+ @app.get("/inpaint_and_upload_image")
383
+ def inpaint_and_upload_image(prompt: str, image_url:str, mask_url:str, save_path: str = ""):
384
+ path_components = save_path.split("/")[0:-1]
385
+ final_name = save_path.split("/")[-1]
386
+ save_path = '/'.join(path_components + [quote_plus(final_name)])  # keep the '/' separator before the URL-quoted final component
389
+ path = get_image_or_inpaint_upload_to_cloud_storage(prompt, image_url, mask_url, save_path)
390
+ return JSONResponse({"path": path})
391
+
392
+
393
+ def get_image_or_create_upload_to_cloud_storage(prompt:str,width:int, height:int, save_path:str):
394
+ prompt = shorten_too_long_text(prompt)
395
+ save_path = shorten_too_long_text(save_path)
396
+ # check exists - todo cache this
397
+ if check_if_blob_exists(save_path):
398
+ return f"https://{BUCKET_NAME}/{BUCKET_PATH}/{save_path}"
399
+ bio = create_image_from_prompt(prompt, width, height)
400
+ if bio is None:
401
+ return None # error thrown in pool
402
+ link = upload_to_bucket(save_path, bio, is_bytesio=True)
403
+ return link
404
+ def get_image_or_inpaint_upload_to_cloud_storage(prompt:str, image_url:str, mask_url:str, save_path:str):
405
+ prompt = shorten_too_long_text(prompt)
406
+ save_path = shorten_too_long_text(save_path)
407
+ # check exists - todo cache this
408
+ if check_if_blob_exists(save_path):
409
+ return f"https://{BUCKET_NAME}/{BUCKET_PATH}/{save_path}"
410
+ bio = inpaint_image_from_prompt(prompt, image_url, mask_url)
411
+ if bio is None:
412
+ return None # error thrown in pool
413
+ link = upload_to_bucket(save_path, bio, is_bytesio=True)
414
+ return link
415
+
416
+ # multiprocessing.set_start_method('spawn', True)
417
+ # processes_pool = Pool(1) # can't do too much at once or OOM errors happen
418
+ # def create_image_from_prompt_sync(prompt):
419
+ # """have to call this sync to avoid OOM errors"""
420
+ # return processes_pool.apply_async(create_image_from_prompt, args=(prompt,), ).wait()
421
+
422
+ def create_image_from_prompt(prompt, width, height):
423
+ # round width and height down to multiple of 64
424
+ block_width = width - (width % 64)
425
+ block_height = height - (height % 64)
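+ # e.g. a 1000x700 request renders at 960x640, then is scaled back up and cropped to the exact requested size below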
426
+ prompt = shorten_too_long_text(prompt)
427
+ # image = pipe(prompt=prompt).images[0]
428
+ try:
429
+ image = pipe(prompt=prompt,
430
+ width=block_width,
431
+ height=block_height,
432
+ # denoising_end=high_noise_frac,
433
+ # output_type='latent',
434
+ # height=512,
435
+ # width=512,
436
+ num_inference_steps=50).images[0] # normally uses 50 steps
437
+ except Exception as e:
438
+ # try rm stopwords + half the prompt
439
+ # todo try prompt permutations
440
+ logger.info(f"trying to shorten prompt of length {len(prompt)}")
441
+
442
+ prompt = ' '.join(word for word in prompt.split() if word not in stopwords)  # split first: iterating a string yields characters, not words
443
+ prompts = prompt.split()
444
+
445
+ prompt = ' '.join(prompts[:len(prompts) // 2])
446
+ logger.info(f"shortened prompt to: {len(prompt)}")
447
+ image = None
448
+ if prompt:
449
+ try:
450
+ image = pipe(prompt=prompt,
451
+ width=block_width,
452
+ height=block_height,
453
+ # denoising_end=high_noise_frac,
454
+ # output_type='latent',
455
+ # height=512,
456
+ # width=512,
457
+ num_inference_steps=50).images[0] # normally uses 50 steps
458
+ except Exception as e:
459
+ # logger.info("trying to permute prompt")
460
+ # # try two swaps of the prompt/permutations
461
+ # prompt = prompt.split()
462
+ # prompt = ' '.join(permutations(prompt, 2).__next__())
463
+ logger.info(f"trying to shorten prompt of length {len(prompt)}")
464
+
465
+ prompt = ' '.join(word for word in prompt.split() if word not in stopwords)  # split first: iterating a string yields characters, not words
466
+ prompts = prompt.split()
467
+
468
+ prompt = ' '.join(prompts[:len(prompts) // 2])
469
+ logger.info(f"shortened prompt to: {len(prompt)}")
470
+
471
+ try:
472
+ image = pipe(prompt=prompt,
473
+ width=block_width,
474
+ height=block_height,
475
+ # denoising_end=high_noise_frac,
476
+ # output_type='latent', # don't need latent yet - we refine the image at full res
477
+ # height=512,
478
+ # width=512,
479
+ num_inference_steps=50).images[0] # normally uses 50 steps
480
+ except Exception as e:
481
+ # just error out
482
+ traceback.print_exc()
483
+ raise e
484
+ # logger.info("restarting server to fix cuda issues (device side asserts)")
485
+ # todo fix device side asserts instead of restart to fix
486
+ # todo only restart the correct gunicorn
487
+ # this could be really annoying if you're running other gunicorns on your machine, which also get restarted
488
+ # os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
489
+ # os.system("kill -1 `pgrep gunicorn`")
490
+ # todo refine
491
+ # if image != None:
492
+ # image = refiner(
493
+ # prompt=prompt,
494
+ # # width=block_width,
495
+ # # height=block_height,
496
+ # num_inference_steps=n_steps,
497
+ # # denoising_start=high_noise_frac,
498
+ # image=image,
499
+ # ).images[0]
500
+ if width != block_width or height != block_height:
501
+ # resize to original size width/height
502
+ # find aspect ratio to scale up to that covers the original img input width/height
503
+ scale_up_ratio = max(width / block_width, height / block_height)
504
+ image = image.resize((math.ceil(block_width * scale_up_ratio), math.ceil(block_height * scale_up_ratio)))  # scale both axes from the block size (was "height", a bug)
505
+ # crop image to original size
506
+ image = image.crop((0, 0, width, height))
507
+ # try:
508
+ # # gc.collect()
509
+ # torch.cuda.empty_cache()
510
+ # except Exception as e:
511
+ # traceback.print_exc()
512
+ # logger.info("restarting server to fix cuda issues (device side asserts)")
513
+ # # todo fix device side asserts instead of restart to fix
514
+ # # todo only restart the correct gunicorn
515
+ # # this could be really annoying if you're running other gunicorns on your machine, which also get restarted
516
+ # os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
517
+ # os.system("kill -1 `pgrep gunicorn`")
518
+ # save as bytesio
519
+ bs = BytesIO()
520
+
521
+ bright_count = np.sum(np.array(image) > 0)
522
+ if bright_count == 0:
523
+ # we have a black image; this is an error, likely requiring a restart
524
+ logger.info("restarting server to fix cuda issues (device side asserts)")
525
+ # # todo fix device side asserts instead of restart to fix
526
+ # # todo only restart the correct gunicorn
527
+ # # this could be really annoying if you're running other gunicorns on your machine, which also get restarted
528
+ os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
529
+ os.system("kill -1 `pgrep gunicorn`")
530
+ os.system("/usr/bin/bash kill -SIGHUP `pgrep uvicorn`")
531
+ os.system("kill -1 `pgrep uvicorn`")
532
+
533
+ return None
534
+ image.save(bs, quality=85, optimize=True, format="webp")
535
+ bio = bs.getvalue()
536
+ # touch progress.txt file - if we don't do this we get restarted by supervisor/other processes for reliability
537
+ with open("progress.txt", "w") as f:
538
+ current_time = datetime.now().strftime("%H:%M:%S")
539
+ f.write(f"{current_time}")
540
+ return bio
541
+
542
+ def inpaint_image_from_prompt(prompt, image_url: str, mask_url: str):
543
+ prompt = shorten_too_long_text(prompt)
544
+ # image = pipe(prompt=prompt).images[0]
545
+
546
+ init_image = load_image(image_url).convert("RGB")
547
+ mask_image = load_image(mask_url).convert("RGB") # why rgb for a 1 channel mask?
548
+ num_inference_steps = 75
549
+ high_noise_frac = 0.7
550
+
551
+ try:
552
+ image = inpaintpipe(
553
+ prompt=prompt,
554
+ image=init_image,
555
+ mask_image=mask_image,
556
+ num_inference_steps=num_inference_steps,
557
+ denoising_start=high_noise_frac,
558
+ output_type="latent",
559
+ ).images[0] # normally uses 50 steps
560
+ except Exception as e:
561
+ # try rm stopwords + half the prompt
562
+ # todo try prompt permutations
563
+ logger.info(f"trying to shorten prompt of length {len(prompt)}")
564
+
565
+ prompt = ' '.join(word for word in prompt.split() if word not in stopwords)  # split first: iterating a string yields characters, not words
566
+ prompts = prompt.split()
567
+
568
+ prompt = ' '.join(prompts[:len(prompts) // 2])
569
+ logger.info(f"shortened prompt to: {len(prompt)}")
570
+ image = None
571
+ if prompt:
572
+ try:
573
+ image = inpaintpipe(  # was "pipe", the text2img pipeline, which doesn't accept image/mask_image
574
+ prompt=prompt,
575
+ image=init_image,
576
+ mask_image=mask_image,
577
+ num_inference_steps=num_inference_steps,
578
+ denoising_start=high_noise_frac,
579
+ output_type="latent",
580
+ ).images[0] # normally uses 50 steps
581
+ except Exception as e:
582
+ # logger.info("trying to permute prompt")
583
+ # # try two swaps of the prompt/permutations
584
+ # prompt = prompt.split()
585
+ # prompt = ' '.join(permutations(prompt, 2).__next__())
586
+ logger.info(f"trying to shorten prompt of length {len(prompt)}")
587
+
588
+ prompt = ' '.join(word for word in prompt.split() if word not in stopwords)  # split first: iterating a string yields characters, not words
589
+ prompts = prompt.split()
590
+
591
+ prompt = ' '.join(prompts[:len(prompts) // 2])
592
+ logger.info(f"shortened prompt to: {len(prompt)}")
593
+
594
+ try:
595
+ image = inpaintpipe(
596
+ prompt=prompt,
597
+ image=init_image,
598
+ mask_image=mask_image,
599
+ num_inference_steps=num_inference_steps,
600
+ denoising_start=high_noise_frac,
601
+ output_type="latent",
602
+ ).images[0] # normally uses 50 steps
603
+ except Exception as e:
604
+ # just error out
605
+ traceback.print_exc()
606
+ raise e
607
+ # logger.info("restarting server to fix cuda issues (device side asserts)")
608
+ # todo fix device side asserts instead of restart to fix
609
+ # todo only restart the correct gunicorn
610
+ # this could be really annoying if you're running other gunicorns on your machine, which also get restarted
611
+ # os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
612
+ # os.system("kill -1 `pgrep gunicorn`")
613
+ if image is not None:
614
+ image = inpaint_refiner(
615
+ prompt=prompt,
616
+ image=image,
617
+ mask_image=mask_image,
618
+ num_inference_steps=num_inference_steps,
619
+ denoising_start=high_noise_frac,
620
+
621
+ ).images[0]
622
+ # try:
623
+ # # gc.collect()
624
+ # torch.cuda.empty_cache()
625
+ # except Exception as e:
626
+ # traceback.print_exc()
627
+ # logger.info("restarting server to fix cuda issues (device side asserts)")
628
+ # # todo fix device side asserts instead of restart to fix
629
+ # # todo only restart the correct gunicorn
630
+ # # this could be really annoying if you're running other gunicorns on your machine, which also get restarted
631
+ # os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
632
+ # os.system("kill -1 `pgrep gunicorn`")
633
+ # save as bytesio
634
+ bs = BytesIO()
635
+
636
+ bright_count = np.sum(np.array(image) > 0)
637
+ if bright_count == 0:
638
+ # we have a black image; this is an error, likely requiring a restart
639
+ logger.info("restarting server to fix cuda issues (device side asserts)")
640
+ # # todo fix device side asserts instead of restart to fix
641
+ # # todo only restart the correct gunicorn
642
+ # # this could be really annoying if you're running other gunicorns on your machine, which also get restarted
643
+ os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
644
+ os.system("kill -1 `pgrep gunicorn`")
645
+ os.system("/usr/bin/bash kill -SIGHUP `pgrep uvicorn`")
646
+ os.system("kill -1 `pgrep uvicorn`")
647
+
648
+ return None
649
+ image.save(bs, quality=85, optimize=True, format="webp")
650
+ bio = bs.getvalue()
651
+ # touch progress.txt file - if we don't do this we get restarted by supervisor/other processes for reliability
652
+ with open("progress.txt", "w") as f:
653
+ current_time = datetime.now().strftime("%H:%M:%S")
654
+ f.write(f"{current_time}")
655
+ return bio
656
+
657
+
658
+
659
+ def shorten_too_long_text(prompt):
660
+ if len(prompt) > 200:
661
+ # remove stopwords
662
+ prompt = prompt.split() # todo also split hyphens
663
+ prompt = ' '.join((word for word in prompt if word not in stopwords))
664
+ if len(prompt) > 200:
665
+ prompt = prompt[:200]
666
+ return prompt
667
+
668
+ # image = pipe(prompt=prompt).images[0]
669
+ #
670
+ # image.save("test.png")
671
+ # # save all images
672
+ # for i, image in enumerate(images):
673
+ # image.save(f"{i}.png")
674
+
675
+
img/manager.py ADDED
@@ -0,0 +1,28 @@
1
+ # poll the progress.txt file forever
2
+ import os
3
+ from datetime import datetime
4
+ from time import sleep
5
+
6
+ from loguru import logger
7
+
8
+ while True:
9
+ try:
10
+ with open("progress.txt", "r") as f:
11
+ progress = f.read()
12
+ last_mod_time = datetime.fromtimestamp(os.path.getmtime("progress.txt"))
13
+ if (datetime.now() - last_mod_time).total_seconds() > 60 * 7:  # use total_seconds(); .seconds wraps at one day
14
+ # no progress for 7 minutes, restart/kill with -9
15
+ logger.info("restarting server to fix cuda issues (device side asserts)")
16
+ os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
17
+ os.system("/usr/bin/bash kill -SIGHUP `pgrep uvicorn`")
18
+ os.system("kill -9 `pgrep gunicorn`")
19
+ os.system("kill -9 `pgrep uvicorn`")
20
+ os.system("killall -9 uvicorn")
21
+ os.system("ps | grep uvicorn | awk '{print $1}' | xargs kill -9")
22
+
23
+ if progress == "done":
24
+ break
25
+ except Exception as e:
26
+ print(e)
27
+ pass
28
+ sleep(60*5)
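For reference, the watchdog above only stays quiet if the serving process keeps refreshing progress.txt; the main_v* servers do this after each successful generation. A minimal heartbeat matching what they write (a sketch, not a file in this commit):

    from datetime import datetime

    def touch_progress(path: str = "progress.txt") -> None:
        # rewriting the file refreshes its mtime, which the watchdog polls
        with open(path, "w") as f:
            f.write(datetime.now().strftime("%H:%M:%S"))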
img/ops/supervisor.conf ADDED
@@ -0,0 +1,17 @@
1
+ # run the server in supervisor
2
+ # supervisord -c /etc/supervisor/supervisor.conf
3
+ # stop the server in supervisor
4
+ # supervisorctl -c /etc/supervisor/supervisor.conf stop all
5
+
6
+ # install the supervisor
7
+ # apt-get install -y supervisor
8
+
9
+ [program:sdif_http_server]
10
+ directory=/home/lee/code/sdif
11
+ # note: supervisord does not run commands through a shell, so env vars belong in "environment=" below, not as command prefixes
+ command=uvicorn --port 8000 --timeout-keep-alive 600 --workers 1 --backlog 1 --limit-concurrency 4 main:app
12
+ autostart=true
13
+ autorestart=true
14
+ environment=VIRTUAL_ENV="/home/lee/code/sdif/.env/",PATH="/opt/app/sdif/.env/bin",\
15
+ HOME="/home/lee",GOOGLE_APPLICATION_CREDENTIALS="secrets/google-credentials.json",PYTHONPATH='/home/lee/code/sdif'
16
+ stdout_logfile=syslog
17
+ stderr_logfile=syslog
img/ori/main.py ADDED
@@ -0,0 +1,488 @@
1
+ import gc
2
+ import math
3
+ import multiprocessing
4
+ import os
5
+ import traceback
6
+ from datetime import datetime
7
+ from io import BytesIO
8
+ from itertools import permutations
9
+ from multiprocessing.pool import Pool
10
+ from pathlib import Path
11
+ from urllib.parse import quote_plus
12
+
13
+ import numpy as np
14
+ import nltk
15
+ import torch
16
+ from PIL.Image import Image
17
+ from diffusers import DiffusionPipeline, StableDiffusionXLInpaintPipeline
18
+ from diffusers.utils import load_image
19
+ from fastapi import FastAPI
20
+ from fastapi.middleware.gzip import GZipMiddleware
21
+ from loguru import logger
22
+ from starlette.middleware.cors import CORSMiddleware
23
+ from starlette.responses import FileResponse
24
+ from starlette.responses import JSONResponse
25
+
26
+ from env import BUCKET_PATH, BUCKET_NAME
27
+ from stable_diffusion_server.bucket_api import check_if_blob_exists, upload_to_bucket
28
+
29
+ pipe = DiffusionPipeline.from_pretrained(
30
+ "models/stable-diffusion-xl-base-1.0",
31
+ torch_dtype=torch.bfloat16,
32
+ use_safetensors=True,
33
+ variant="fp16",
34
+ # safety_checker=None,
35
+ ) # todo try torch_dtype=bfloat16
36
+ pipe.watermark = None
37
+
38
+ pipe.to("cuda")
39
+
40
+ refiner = DiffusionPipeline.from_pretrained(
41
+ "stabilityai/stable-diffusion-xl-refiner-1.0",
42
+ text_encoder_2=pipe.text_encoder_2,
43
+ vae=pipe.vae,
44
+ torch_dtype=torch.bfloat16, # safer to use bfloat?
45
+ use_safetensors=True,
46
+ variant="fp16", #remember not to download the big model
47
+ )
48
+ refiner.watermark = None
49
+ refiner.to("cuda")
50
+
51
+ # {'scheduler', 'text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'unet', 'vae'} can be passed in from existing model
52
+ inpaintpipe = StableDiffusionXLInpaintPipeline.from_pretrained(
53
+ "models/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16, variant="fp16", use_safetensors=True,
54
+ scheduler=pipe.scheduler,
55
+ text_encoder=pipe.text_encoder,
56
+ text_encoder_2=pipe.text_encoder_2,
57
+ tokenizer=pipe.tokenizer,
58
+ tokenizer_2=pipe.tokenizer_2,
59
+ unet=pipe.unet,
60
+ vae=pipe.vae,
61
+ # load_connected_pipeline=
62
+ )
63
+ # # switch out to save gpu mem
64
+ # del inpaintpipe.vae
65
+ # del inpaintpipe.text_encoder_2
66
+ # del inpaintpipe.text_encoder
67
+ # del inpaintpipe.scheduler
68
+ # del inpaintpipe.tokenizer
69
+ # del inpaintpipe.tokenizer_2
70
+ # del inpaintpipe.unet
71
+ # inpaintpipe.vae = pipe.vae
72
+ # inpaintpipe.text_encoder_2 = pipe.text_encoder_2
73
+ # inpaintpipe.text_encoder = pipe.text_encoder
74
+ # inpaintpipe.scheduler = pipe.scheduler
75
+ # inpaintpipe.tokenizer = pipe.tokenizer
76
+ # inpaintpipe.tokenizer_2 = pipe.tokenizer_2
77
+ # inpaintpipe.unet = pipe.unet
78
+ # todo this should work
79
+ # inpaintpipe = StableDiffusionXLInpaintPipeline( # construct an inpainter using the existing model
80
+ # vae=pipe.vae,
81
+ # text_encoder_2=pipe.text_encoder_2,
82
+ # text_encoder=pipe.text_encoder,
83
+ # unet=pipe.unet,
84
+ # scheduler=pipe.scheduler,
85
+ # tokenizer=pipe.tokenizer,
86
+ # tokenizer_2=pipe.tokenizer_2,
87
+ # requires_aesthetics_score=False,
88
+ # )
89
+ inpaintpipe.to("cuda")
90
+ inpaintpipe.watermark = None
91
+ # inpaintpipe.register_to_config(requires_aesthetics_score=False)
92
+
93
+ inpaint_refiner = StableDiffusionXLInpaintPipeline.from_pretrained(
94
+ "stabilityai/stable-diffusion-xl-refiner-1.0",
95
+ text_encoder_2=inpaintpipe.text_encoder_2,
96
+ vae=inpaintpipe.vae,
97
+ torch_dtype=torch.bfloat16,
98
+ use_safetensors=True,
99
+ variant="fp16",
100
+
101
+ tokenizer_2=refiner.tokenizer_2,
102
+ tokenizer=refiner.tokenizer,
103
+ scheduler=refiner.scheduler,
104
+ text_encoder=refiner.text_encoder,
105
+ unet=refiner.unet,
106
+ )
107
+ # del inpaint_refiner.vae
108
+ # del inpaint_refiner.text_encoder_2
109
+ # del inpaint_refiner.text_encoder
110
+ # del inpaint_refiner.scheduler
111
+ # del inpaint_refiner.tokenizer
112
+ # del inpaint_refiner.tokenizer_2
113
+ # del inpaint_refiner.unet
114
+ # inpaint_refiner.vae = inpaintpipe.vae
115
+ # inpaint_refiner.text_encoder_2 = inpaintpipe.text_encoder_2
116
+ #
117
+ # inpaint_refiner.text_encoder = refiner.text_encoder
118
+ # inpaint_refiner.scheduler = refiner.scheduler
119
+ # inpaint_refiner.tokenizer = refiner.tokenizer
120
+ # inpaint_refiner.tokenizer_2 = refiner.tokenizer_2
121
+ # inpaint_refiner.unet = refiner.unet
122
+
123
+ # inpaint_refiner = StableDiffusionXLInpaintPipeline(
124
+ # text_encoder_2=inpaintpipe.text_encoder_2,
125
+ # vae=inpaintpipe.vae,
126
+ # # the rest from the existing refiner
127
+ # tokenizer_2=refiner.tokenizer_2,
128
+ # tokenizer=refiner.tokenizer,
129
+ # scheduler=refiner.scheduler,
130
+ # text_encoder=refiner.text_encoder,
131
+ # unet=refiner.unet,
132
+ # requires_aesthetics_score=False,
133
+ # )
134
+ inpaint_refiner.to("cuda")
135
+ inpaint_refiner.watermark = None
136
+ # inpaint_refiner.register_to_config(requires_aesthetics_score=False)
137
+
138
+ n_steps = 40
139
+ high_noise_frac = 0.8
140
+
141
+ # if using torch < 2.0
142
+ # pipe.enable_xformers_memory_efficient_attention()
143
+
144
+
145
+ # pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
146
+ # this can cause errors on some inputs so consider disabling it
147
+ pipe.unet = torch.compile(pipe.unet)
148
+ refiner.unet = torch.compile(refiner.unet)#, mode="reduce-overhead", fullgraph=True)
149
+ # compile the inpainters - todo reuse the other unets? swap out the models for others/del them so they share models and can be swapped efficiently
150
+ inpaintpipe.unet = pipe.unet
151
+ inpaint_refiner.unet = refiner.unet
152
+ # inpaintpipe.unet = torch.compile(inpaintpipe.unet)
153
+ # inpaint_refiner.unet = torch.compile(inpaint_refiner.unet)
154
+
155
+ app = FastAPI(
156
+ openapi_url="/static/openapi.json",
157
+ docs_url="/swagger-docs",
158
+ redoc_url="/redoc",
159
+ title="Generate Images Netwrck API",
160
+ description="Character Chat API",
161
+ # root_path="https://api.text-generator.io",
162
+ version="1",
163
+ )
164
+ app.add_middleware(GZipMiddleware, minimum_size=1000)
165
+ app.add_middleware(
166
+ CORSMiddleware,
167
+ allow_origins=["*"],
168
+ allow_credentials=True,
169
+ allow_methods=["*"],
170
+ allow_headers=["*"],
171
+ )
172
+
173
+ stopwords = nltk.corpus.stopwords.words("english")
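+ # requires the NLTK stopwords corpus: python -c "import nltk; nltk.download('stopwords')" (see readme)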
174
+
175
+
176
+ @app.get("/make_image")
177
+ def make_image(prompt: str, save_path: str = ""):
178
+ if Path(save_path).exists():
179
+ return FileResponse(save_path, media_type="image/png")
180
+ image = pipe(prompt=prompt).images[0]
181
+ if not save_path:
182
+ save_path = f"images/{prompt}.png"
183
+ image.save(save_path)
184
+ return FileResponse(save_path, media_type="image/png")
185
+
186
+
187
+ @app.get("/create_and_upload_image")
188
+ def create_and_upload_image(prompt: str, width: int=1024, height:int=1024, save_path: str = ""):
189
+ path_components = save_path.split("/")[0:-1]
190
+ final_name = save_path.split("/")[-1]
191
+ if not path_components:
192
+ path_components = []
193
+ save_path = '/'.join(path_components + [quote_plus(final_name)])  # keep the "/" separator for nested save paths
194
+ path = get_image_or_create_upload_to_cloud_storage(prompt, width, height, save_path)
195
+ return JSONResponse({"path": path})
196
+
197
+ @app.get("/inpaint_and_upload_image")
198
+ def inpaint_and_upload_image(prompt: str, image_url:str, mask_url:str, save_path: str = ""):
199
+ path_components = save_path.split("/")[0:-1]
200
+ final_name = save_path.split("/")[-1]
201
+ if not path_components:
202
+ path_components = []
203
+ save_path = '/'.join(path_components + [quote_plus(final_name)])  # keep the "/" separator for nested save paths
204
+ path = get_image_or_inpaint_upload_to_cloud_storage(prompt, image_url, mask_url, save_path)
205
+ return JSONResponse({"path": path})
206
+
207
+
208
+ def get_image_or_create_upload_to_cloud_storage(prompt:str,width:int, height:int, save_path:str):
209
+ prompt = shorten_too_long_text(prompt)
210
+ save_path = shorten_too_long_text(save_path)
211
+ # check exists - todo cache this
212
+ if check_if_blob_exists(save_path):
213
+ return f"https://{BUCKET_NAME}/{BUCKET_PATH}/{save_path}"
214
+ bio = create_image_from_prompt(prompt, width, height)
215
+ if bio is None:
216
+ return None # error thrown in pool
217
+ link = upload_to_bucket(save_path, bio, is_bytesio=True)
218
+ return link
219
+ def get_image_or_inpaint_upload_to_cloud_storage(prompt:str, image_url:str, mask_url:str, save_path:str):
220
+ prompt = shorten_too_long_text(prompt)
221
+ save_path = shorten_too_long_text(save_path)
222
+ # check exists - todo cache this
223
+ if check_if_blob_exists(save_path):
224
+ return f"https://{BUCKET_NAME}/{BUCKET_PATH}/{save_path}"
225
+ bio = inpaint_image_from_prompt(prompt, image_url, mask_url)
226
+ if bio is None:
227
+ return None # error thrown in pool
228
+ link = upload_to_bucket(save_path, bio, is_bytesio=True)
229
+ return link
230
+
231
+ # multiprocessing.set_start_method('spawn', True)
232
+ # processes_pool = Pool(1) # cant do too much at once or OOM errors happen
233
+ # def create_image_from_prompt_sync(prompt):
234
+ # """have to call this sync to avoid OOM errors"""
235
+ # return processes_pool.apply_async(create_image_from_prompt, args=(prompt,), ).wait()
236
+
237
+ def create_image_from_prompt(prompt, width, height):
238
+ # round width and height down to multiple of 64
239
+ block_width = width - (width % 64)
240
+ block_height = height - (height % 64)
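+ # e.g. 1000x700 becomes 960x640; the result is scaled back up and cropped to the requested size further down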
241
+ prompt = shorten_too_long_text(prompt)
242
+ # image = pipe(prompt=prompt).images[0]
243
+ try:
244
+ image = pipe(prompt=prompt,
245
+ width=block_width,
246
+ height=block_height,
247
+ # denoising_end=high_noise_frac,
248
+ # output_type='latent',
249
+ # height=512,
250
+ # width=512,
251
+ num_inference_steps=50).images[0] # normally uses 50 steps
252
+ except Exception as e:
253
+ # try rm stopwords + half the prompt
254
+ # todo try prompt permutations
255
+ logger.info(f"trying to shorten prompt of length {len(prompt)}")
256
+
257
+ prompt = ' '.join(word for word in prompt.split() if word not in stopwords)  # split first: iterating a string yields characters
258
+ prompts = prompt.split()
259
+
260
+ prompt = ' '.join(prompts[:len(prompts) // 2])
261
+ logger.info(f"shortened prompt to: {len(prompt)}")
262
+ image = None
263
+ if prompt:
264
+ try:
265
+ image = pipe(prompt=prompt,
266
+ width=block_width,
267
+ height=block_height,
268
+ # denoising_end=high_noise_frac,
269
+ # output_type='latent',
270
+ # height=512,
271
+ # width=512,
272
+ num_inference_steps=50).images[0] # normally uses 50 steps
273
+ except Exception as e:
274
+ # logger.info("trying to permute prompt")
275
+ # # try two swaps of the prompt/permutations
276
+ # prompt = prompt.split()
277
+ # prompt = ' '.join(permutations(prompt, 2).__next__())
278
+ logger.info(f"trying to shorten prompt of length {len(prompt)}")
279
+
280
+ prompt = ' '.join(word for word in prompt.split() if word not in stopwords)  # split first: iterating a string yields characters
281
+ prompts = prompt.split()
282
+
283
+ prompt = ' '.join(prompts[:len(prompts) // 2])
284
+ logger.info(f"shortened prompt to: {len(prompt)}")
285
+
286
+ try:
287
+ image = pipe(prompt=prompt,
288
+ width=block_width,
289
+ height=block_height,
290
+ # denoising_end=high_noise_frac,
291
+ # output_type='latent', # dont need latent yet - we refine the image at full res
292
+ # height=512,
293
+ # width=512,
294
+ num_inference_steps=50).images[0] # normally uses 50 steps
295
+ except Exception as e:
296
+ # just error out
297
+ traceback.print_exc()
298
+ raise e
299
+ # logger.info("restarting server to fix cuda issues (device side asserts)")
300
+ # todo fix device side asserts instead of restart to fix
301
+ # todo only restart the correct gunicorn
302
+ # this could be really annoying if you're running other gunicorns on your machine, which would also get restarted
303
+ # os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
304
+ # os.system("kill -1 `pgrep gunicorn`")
305
+ # todo refine
306
+ # if image != None:
307
+ # image = refiner(
308
+ # prompt=prompt,
309
+ # # width=block_width,
310
+ # # height=block_height,
311
+ # num_inference_steps=n_steps,
312
+ # # denoising_start=high_noise_frac,
313
+ # image=image,
314
+ # ).images[0]
315
+ if width != block_width or height != block_height:
316
+ # resize to original size width/height
317
+ # find aspect ratio to scale up to that covers the original img input width/height
318
+ scale_up_ratio = max(width / block_width, height / block_height)
319
+ image = image.resize((math.ceil(block_width * scale_up_ratio), math.ceil(block_height * scale_up_ratio)))
320
+ # crop image to original size
321
+ image = image.crop((0, 0, width, height))
322
+ # try:
323
+ # # gc.collect()
324
+ # torch.cuda.empty_cache()
325
+ # except Exception as e:
326
+ # traceback.print_exc()
327
+ # logger.info("restarting server to fix cuda issues (device side asserts)")
328
+ # # todo fix device side asserts instead of restart to fix
329
+ # # todo only restart the correct gunicorn
330
+ # # this could be really annoying if you're running other gunicorns on your machine, which would also get restarted
331
+ # os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
332
+ # os.system("kill -1 `pgrep gunicorn`")
333
+ # save as bytesio
334
+ bs = BytesIO()
335
+
336
+ bright_count = np.sum(np.array(image) > 0)
337
+ if bright_count == 0:
338
+ # we have a black image; this is an error and we likely need a restart
339
+ logger.info("restarting server to fix cuda issues (device side asserts)")
340
+ # # todo fix device side asserts instead of restart to fix
341
+ # # todo only restart the correct gunicorn
342
+ # # this could be really annoying if you're running other gunicorns on your machine, which would also get restarted
343
+ os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
344
+ os.system("kill -1 `pgrep gunicorn`")
345
+ os.system("/usr/bin/bash kill -SIGHUP `pgrep uvicorn`")
346
+ os.system("kill -1 `pgrep uvicorn`")
347
+
348
+ return None
349
+ image.save(bs, quality=85, optimize=True, format="webp")
350
+ bio = bs.getvalue()
351
+ # touch progress.txt file - if we don't do this we get restarted by supervisor/other processes for reliability
352
+ with open("progress.txt", "w") as f:
353
+ current_time = datetime.now().strftime("%H:%M:%S")
354
+ f.write(f"{current_time}")
355
+ return bio
356
+
357
+ def inpaint_image_from_prompt(prompt, image_url: str, mask_url: str):
358
+ prompt = shorten_too_long_text(prompt)
359
+ # image = pipe(prompt=prompt).images[0]
360
+
361
+ init_image = load_image(image_url).convert("RGB")
362
+ mask_image = load_image(mask_url).convert("RGB") # why rgb for a 1 channel mask?
363
+ num_inference_steps = 75
364
+ high_noise_frac = 0.7
365
+
366
+ try:
367
+ image = inpaintpipe(
368
+ prompt=prompt,
369
+ image=init_image,
370
+ mask_image=mask_image,
371
+ num_inference_steps=num_inference_steps,
372
+ denoising_start=high_noise_frac,
373
+ output_type="latent",
374
+ ).images[0] # normally uses 50 steps
375
+ except Exception as e:
376
+ # try rm stopwords + half the prompt
377
+ # todo try prompt permutations
378
+ logger.info(f"trying to shorten prompt of length {len(prompt)}")
379
+
380
+ prompt = ' '.join(word for word in prompt.split() if word not in stopwords)  # split first: iterating a string yields characters
381
+ prompts = prompt.split()
382
+
383
+ prompt = ' '.join(prompts[:len(prompts) // 2])
384
+ logger.info(f"shortened prompt to: {len(prompt)}")
385
+ image = None
386
+ if prompt:
387
+ try:
388
+ image = inpaintpipe(  # must use the inpaint pipeline here; the base pipe does not accept image/mask_image
389
+ prompt=prompt,
390
+ image=init_image,
391
+ mask_image=mask_image,
392
+ num_inference_steps=num_inference_steps,
393
+ denoising_start=high_noise_frac,
394
+ output_type="latent",
395
+ ).images[0] # normally uses 50 steps
396
+ except Exception as e:
397
+ # logger.info("trying to permute prompt")
398
+ # # try two swaps of the prompt/permutations
399
+ # prompt = prompt.split()
400
+ # prompt = ' '.join(permutations(prompt, 2).__next__())
401
+ logger.info(f"trying to shorten prompt of length {len(prompt)}")
402
+
403
+ prompt = ' '.join(word for word in prompt.split() if word not in stopwords)  # split first: iterating a string yields characters
404
+ prompts = prompt.split()
405
+
406
+ prompt = ' '.join(prompts[:len(prompts) // 2])
407
+ logger.info(f"shortened prompt to: {len(prompt)}")
408
+
409
+ try:
410
+ image = inpaintpipe(
411
+ prompt=prompt,
412
+ image=init_image,
413
+ mask_image=mask_image,
414
+ num_inference_steps=num_inference_steps,
415
+ denoising_start=high_noise_frac,
416
+ output_type="latent",
417
+ ).images[0] # normally uses 50 steps
418
+ except Exception as e:
419
+ # just error out
420
+ traceback.print_exc()
421
+ raise e
422
+ # logger.info("restarting server to fix cuda issues (device side asserts)")
423
+ # todo fix device side asserts instead of restart to fix
424
+ # todo only restart the correct gunicorn
425
+ # this could be really annoying if you're running other gunicorns on your machine, which would also get restarted
426
+ # os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
427
+ # os.system("kill -1 `pgrep gunicorn`")
428
+ if image is not None:
429
+ image = inpaint_refiner(
430
+ prompt=prompt,
431
+ image=image,
432
+ mask_image=mask_image,
433
+ num_inference_steps=num_inference_steps,
434
+ denoising_start=high_noise_frac,
435
+
436
+ ).images[0]
437
+ # try:
438
+ # # gc.collect()
439
+ # torch.cuda.empty_cache()
440
+ # except Exception as e:
441
+ # traceback.print_exc()
442
+ # logger.info("restarting server to fix cuda issues (device side asserts)")
443
+ # # todo fix device side asserts instead of restart to fix
444
+ # # todo only restart the correct gunicorn
445
+ # # this could be really annoying if you're running other gunicorns on your machine, which would also get restarted
446
+ # os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
447
+ # os.system("kill -1 `pgrep gunicorn`")
448
+ # save as bytesio
449
+ bs = BytesIO()
450
+
451
+ bright_count = np.sum(np.array(image) > 0)
452
+ if bright_count == 0:
453
+ # we have a black image; this is an error and we likely need a restart
454
+ logger.info("restarting server to fix cuda issues (device side asserts)")
455
+ # # todo fix device side asserts instead of restart to fix
456
+ # # todo only restart the correct gunicorn
457
+ # # this could be really annoying if you're running other gunicorns on your machine, which would also get restarted
458
+ os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
459
+ os.system("kill -1 `pgrep gunicorn`")
460
+ os.system("/usr/bin/bash kill -SIGHUP `pgrep uvicorn`")
461
+ os.system("kill -1 `pgrep uvicorn`")
462
+
463
+ return None
464
+ image.save(bs, quality=85, optimize=True, format="webp")
465
+ bio = bs.getvalue()
466
+ # touch progress.txt file - if we don't do this we get restarted by supervisor/other processes for reliability
467
+ with open("progress.txt", "w") as f:
468
+ current_time = datetime.now().strftime("%H:%M:%S")
469
+ f.write(f"{current_time}")
470
+ return bio
471
+
472
+
473
+
474
+ def shorten_too_long_text(prompt):
475
+ if len(prompt) > 200:
476
+ # remove stopwords
477
+ prompt = prompt.split() # todo also split hyphens
478
+ prompt = ' '.join((word for word in prompt if word not in stopwords))
479
+ if len(prompt) > 200:
480
+ prompt = prompt[:200]
481
+ return prompt
482
+
483
+ # image = pipe(prompt=prompt).images[0]
484
+ #
485
+ # image.save("test.png")
486
+ # save all images
487
+ # for i, image in enumerate(images):
488
+ # image.save(f"{i}.png")
img/pr1/main.py ADDED
@@ -0,0 +1,515 @@
1
+ import gc
2
+ import math
3
+ import multiprocessing
4
+ import os
5
+ import traceback
6
+ from datetime import datetime
7
+ from io import BytesIO
8
+ from itertools import permutations
9
+ from multiprocessing.pool import Pool
10
+ from pathlib import Path
11
+ from urllib.parse import quote_plus
12
+
13
+ import numpy as np
14
+ import nltk
15
+ import torch
16
+
17
+ from PIL.Image import Image
18
+ from diffusers import DiffusionPipeline, StableDiffusionXLInpaintPipeline
19
+ from diffusers.utils import load_image
20
+ from fastapi import FastAPI
21
+ from fastapi.middleware.gzip import GZipMiddleware
22
+ from loguru import logger
23
+ from starlette.middleware.cors import CORSMiddleware
24
+ from starlette.responses import FileResponse
25
+ from starlette.responses import JSONResponse
26
+
27
+ from env import BUCKET_PATH, BUCKET_NAME
28
+ # from stable_diffusion_server.bucket_api import check_if_blob_exists, upload_to_bucket
29
+ torch._dynamo.config.suppress_errors = True
30
+
31
+ pipe = DiffusionPipeline.from_pretrained(
32
+ "models/stable-diffusion-xl-base-1.0",
33
+ torch_dtype=torch.bfloat16,
34
+ use_safetensors=True,
35
+ variant="fp16",
36
+ # safety_checker=None,
37
+ ) # todo try torch_dtype=bfloat16
38
+ pipe.watermark = None
39
+
40
+ pipe.to("cuda")
41
+
42
+ refiner = DiffusionPipeline.from_pretrained(
43
+ "stabilityai/stable-diffusion-xl-refiner-1.0",
44
+ text_encoder_2=pipe.text_encoder_2,
45
+ vae=pipe.vae,
46
+ torch_dtype=torch.bfloat16, # safer to use bfloat?
47
+ use_safetensors=True,
48
+ variant="fp16", #remember not to download the big model
49
+ )
50
+ refiner.watermark = None
51
+ refiner.to("cuda")
52
+
53
+ # {'scheduler', 'text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'unet', 'vae'} can be passed in from existing model
54
+ inpaintpipe = StableDiffusionXLInpaintPipeline.from_pretrained(
55
+ "models/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16, variant="fp16", use_safetensors=True,
56
+ scheduler=pipe.scheduler,
57
+ text_encoder=pipe.text_encoder,
58
+ text_encoder_2=pipe.text_encoder_2,
59
+ tokenizer=pipe.tokenizer,
60
+ tokenizer_2=pipe.tokenizer_2,
61
+ unet=pipe.unet,
62
+ vae=pipe.vae,
63
+ # load_connected_pipeline=
64
+ )
65
+ # # switch out to save gpu mem
66
+ # del inpaintpipe.vae
67
+ # del inpaintpipe.text_encoder_2
68
+ # del inpaintpipe.text_encoder
69
+ # del inpaintpipe.scheduler
70
+ # del inpaintpipe.tokenizer
71
+ # del inpaintpipe.tokenizer_2
72
+ # del inpaintpipe.unet
73
+ # inpaintpipe.vae = pipe.vae
74
+ # inpaintpipe.text_encoder_2 = pipe.text_encoder_2
75
+ # inpaintpipe.text_encoder = pipe.text_encoder
76
+ # inpaintpipe.scheduler = pipe.scheduler
77
+ # inpaintpipe.tokenizer = pipe.tokenizer
78
+ # inpaintpipe.tokenizer_2 = pipe.tokenizer_2
79
+ # inpaintpipe.unet = pipe.unet
80
+ # todo this should work
81
+ # inpaintpipe = StableDiffusionXLInpaintPipeline( # construct an inpainter using the existing model
82
+ # vae=pipe.vae,
83
+ # text_encoder_2=pipe.text_encoder_2,
84
+ # text_encoder=pipe.text_encoder,
85
+ # unet=pipe.unet,
86
+ # scheduler=pipe.scheduler,
87
+ # tokenizer=pipe.tokenizer,
88
+ # tokenizer_2=pipe.tokenizer_2,
89
+ # requires_aesthetics_score=False,
90
+ # )
91
+ inpaintpipe.to("cuda")
92
+ inpaintpipe.watermark = None
93
+ # inpaintpipe.register_to_config(requires_aesthetics_score=False)
94
+
95
+ inpaint_refiner = StableDiffusionXLInpaintPipeline.from_pretrained(
96
+ "stabilityai/stable-diffusion-xl-refiner-1.0",
97
+ text_encoder_2=inpaintpipe.text_encoder_2,
98
+ vae=inpaintpipe.vae,
99
+ torch_dtype=torch.bfloat16,
100
+ use_safetensors=True,
101
+ variant="fp16",
102
+
103
+ tokenizer_2=refiner.tokenizer_2,
104
+ tokenizer=refiner.tokenizer,
105
+ scheduler=refiner.scheduler,
106
+ text_encoder=refiner.text_encoder,
107
+ unet=refiner.unet,
108
+ )
109
+ # del inpaint_refiner.vae
110
+ # del inpaint_refiner.text_encoder_2
111
+ # del inpaint_refiner.text_encoder
112
+ # del inpaint_refiner.scheduler
113
+ # del inpaint_refiner.tokenizer
114
+ # del inpaint_refiner.tokenizer_2
115
+ # del inpaint_refiner.unet
116
+ # inpaint_refiner.vae = inpaintpipe.vae
117
+ # inpaint_refiner.text_encoder_2 = inpaintpipe.text_encoder_2
118
+ #
119
+ # inpaint_refiner.text_encoder = refiner.text_encoder
120
+ # inpaint_refiner.scheduler = refiner.scheduler
121
+ # inpaint_refiner.tokenizer = refiner.tokenizer
122
+ # inpaint_refiner.tokenizer_2 = refiner.tokenizer_2
123
+ # inpaint_refiner.unet = refiner.unet
124
+
125
+ # inpaint_refiner = StableDiffusionXLInpaintPipeline(
126
+ # text_encoder_2=inpaintpipe.text_encoder_2,
127
+ # vae=inpaintpipe.vae,
128
+ # # the rest from the existing refiner
129
+ # tokenizer_2=refiner.tokenizer_2,
130
+ # tokenizer=refiner.tokenizer,
131
+ # scheduler=refiner.scheduler,
132
+ # text_encoder=refiner.text_encoder,
133
+ # unet=refiner.unet,
134
+ # requires_aesthetics_score=False,
135
+ # )
136
+ inpaint_refiner.to("cuda")
137
+ inpaint_refiner.watermark = None
138
+ # inpaint_refiner.register_to_config(requires_aesthetics_score=False)
139
+
140
+ n_steps = 40
141
+ high_noise_frac = 0.8
142
+
143
+ # if using torch < 2.0
144
+ # pipe.enable_xformers_memory_efficient_attention()
145
+
146
+
147
+ # pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
148
+ # this can cause errors on some inputs so consider disabling it
149
+ pipe.unet = torch.compile(pipe.unet)
150
+ refiner.unet = torch.compile(refiner.unet)#, mode="reduce-overhead", fullgraph=True)
151
+ # compile the inpainters - todo reuse the other unets? swap out the models for others/del them so they share models and can be swapped efficiently
152
+ inpaintpipe.unet = pipe.unet
153
+ inpaint_refiner.unet = refiner.unet
154
+ # inpaintpipe.unet = torch.compile(inpaintpipe.unet)
155
+ # inpaint_refiner.unet = torch.compile(inpaint_refiner.unet)
156
+ from pydantic import BaseModel
157
+
158
+ app = FastAPI(
159
+ openapi_url="/static/openapi.json",
160
+ docs_url="/swagger-docs",
161
+ redoc_url="/redoc",
162
+ title="Generate Images Netwrck API",
163
+ description="Character Chat API",
164
+ # root_path="https://api.text-generator.io",
165
+ version="1",
166
+ )
167
+ app.add_middleware(GZipMiddleware, minimum_size=1000)
168
+ app.add_middleware(
169
+ CORSMiddleware,
170
+ allow_origins=["*"],
171
+ allow_credentials=True,
172
+ allow_methods=["*"],
173
+ allow_headers=["*"],
174
+ )
175
+
176
+ stopwords = nltk.corpus.stopwords.words("english")
177
+
178
+ class Img(BaseModel):
179
+ prompt: str
180
+ save_path: str
181
+
182
+ # img_url = "http://phlrr2019.guest.corp.microsoft.com:8000/img1_sdv2.1.png"
183
+ img_url = "http://phlrr2019.guest.corp.microsoft.com:8000/"#/img1_sdv2.1.png"
184
+
185
+ @app.post("/image_url")
186
+ def image_url(img: Img):
187
+ prompt = img.prompt
188
+ save_path = img.save_path
189
+ path = f"{img_url}{save_path}"
190
+ if Path(save_path).exists():
191
+ return FileResponse(save_path, media_type="image/png")
192
+ return JSONResponse({"path": path})
193
+ image = pipe(prompt=prompt).images[0]
194
+ if not save_path:
195
+ save_path = f"images/{prompt}.png"
196
+ image.save(save_path)
197
+ # save_path = '/'.join(path_components) + quote_plus(final_name)
198
+ path = f"{img_url}{save_path}"
199
+ return JSONResponse({"path": path})
200
+
201
+
202
+ @app.get("/make_image")
203
+ # @app.post("/make_image")
204
+ def make_image(prompt: str, save_path: str = ""):
205
+ if Path(save_path).exists():
206
+ return FileResponse(save_path, media_type="image/png")
207
+ image = pipe(prompt=prompt).images[0]
208
+ if not save_path:
209
+ save_path = f"images/{prompt}.png"
210
+ image.save(save_path)
211
+ return FileResponse(save_path, media_type="image/png")
212
+
213
+
214
+ @app.get("/create_and_upload_image")
215
+ def create_and_upload_image(prompt: str, width: int=1024, height:int=1024, save_path: str = ""):
216
+ path_components = save_path.split("/")[0:-1]
217
+ final_name = save_path.split("/")[-1]
218
+ if not path_components:
219
+ path_components = []
220
+ save_path = '/'.join(path_components + [quote_plus(final_name)])  # keep the "/" separator for nested save paths
221
+ path = get_image_or_create_upload_to_cloud_storage(prompt, width, height, save_path)
222
+ return JSONResponse({"path": path})
223
+
224
+ @app.get("/inpaint_and_upload_image")
225
+ def inpaint_and_upload_image(prompt: str, image_url:str, mask_url:str, save_path: str = ""):
226
+ path_components = save_path.split("/")[0:-1]
227
+ final_name = save_path.split("/")[-1]
228
+ if not path_components:
229
+ path_components = []
230
+ save_path = '/'.join(path_components + [quote_plus(final_name)])  # keep the "/" separator for nested save paths
231
+ path = get_image_or_inpaint_upload_to_cloud_storage(prompt, image_url, mask_url, save_path)
232
+ return JSONResponse({"path": path})
233
+
234
+
235
+ def get_image_or_create_upload_to_cloud_storage(prompt:str,width:int, height:int, save_path:str):
236
+ prompt = shorten_too_long_text(prompt)
237
+ save_path = shorten_too_long_text(save_path)
238
+ # check exists - todo cache this
239
+ if check_if_blob_exists(save_path):
240
+ return f"https://{BUCKET_NAME}/{BUCKET_PATH}/{save_path}"
241
+ bio = create_image_from_prompt(prompt, width, height)
242
+ if bio is None:
243
+ return None # error thrown in pool
244
+ link = upload_to_bucket(save_path, bio, is_bytesio=True)
245
+ return link
246
+ def get_image_or_inpaint_upload_to_cloud_storage(prompt:str, image_url:str, mask_url:str, save_path:str):
247
+ prompt = shorten_too_long_text(prompt)
248
+ save_path = shorten_too_long_text(save_path)
249
+ # check exists - todo cache this
250
+ if check_if_blob_exists(save_path):
251
+ return f"https://{BUCKET_NAME}/{BUCKET_PATH}/{save_path}"
252
+ bio = inpaint_image_from_prompt(prompt, image_url, mask_url)
253
+ if bio is None:
254
+ return None # error thrown in pool
255
+ link = upload_to_bucket(save_path, bio, is_bytesio=True)
256
+ return link
257
+
258
+ # multiprocessing.set_start_method('spawn', True)
259
+ # processes_pool = Pool(1) # cant do too much at once or OOM errors happen
260
+ # def create_image_from_prompt_sync(prompt):
261
+ # """have to call this sync to avoid OOM errors"""
262
+ # return processes_pool.apply_async(create_image_from_prompt, args=(prompt,), ).wait()
263
+
264
+ def create_image_from_prompt(prompt, width, height):
265
+ # round width and height down to multiple of 64
266
+ block_width = width - (width % 64)
267
+ block_height = height - (height % 64)
268
+ prompt = shorten_too_long_text(prompt)
269
+ # image = pipe(prompt=prompt).images[0]
270
+ try:
271
+ image = pipe(prompt=prompt,
272
+ width=block_width,
273
+ height=block_height,
274
+ # denoising_end=high_noise_frac,
275
+ # output_type='latent',
276
+ # height=512,
277
+ # width=512,
278
+ num_inference_steps=50).images[0] # normally uses 50 steps
279
+ except Exception as e:
280
+ # try rm stopwords + half the prompt
281
+ # todo try prompt permutations
282
+ logger.info(f"trying to shorten prompt of length {len(prompt)}")
283
+
284
+ prompt = ' '.join(word for word in prompt.split() if word not in stopwords)  # split first: iterating a string yields characters
285
+ prompts = prompt.split()
286
+
287
+ prompt = ' '.join(prompts[:len(prompts) // 2])
288
+ logger.info(f"shortened prompt to: {len(prompt)}")
289
+ image = None
290
+ if prompt:
291
+ try:
292
+ image = pipe(prompt=prompt,
293
+ width=block_width,
294
+ height=block_height,
295
+ # denoising_end=high_noise_frac,
296
+ # output_type='latent',
297
+ # height=512,
298
+ # width=512,
299
+ num_inference_steps=50).images[0] # normally uses 50 steps
300
+ except Exception as e:
301
+ # logger.info("trying to permute prompt")
302
+ # # try two swaps of the prompt/permutations
303
+ # prompt = prompt.split()
304
+ # prompt = ' '.join(permutations(prompt, 2).__next__())
305
+ logger.info(f"trying to shorten prompt of length {len(prompt)}")
306
+
307
+ prompt = ' '.join(word for word in prompt.split() if word not in stopwords)  # split first: iterating a string yields characters
308
+ prompts = prompt.split()
309
+
310
+ prompt = ' '.join(prompts[:len(prompts) // 2])
311
+ logger.info(f"shortened prompt to: {len(prompt)}")
312
+
313
+ try:
314
+ image = pipe(prompt=prompt,
315
+ width=block_width,
316
+ height=block_height,
317
+ # denoising_end=high_noise_frac,
318
+ # output_type='latent', # dont need latent yet - we refine the image at full res
319
+ # height=512,
320
+ # width=512,
321
+ num_inference_steps=50).images[0] # normally uses 50 steps
322
+ except Exception as e:
323
+ # just error out
324
+ traceback.print_exc()
325
+ raise e
326
+ # logger.info("restarting server to fix cuda issues (device side asserts)")
327
+ # todo fix device side asserts instead of restart to fix
328
+ # todo only restart the correct gunicorn
329
+ # this could be really annoying if you're running other gunicorns on your machine, which would also get restarted
330
+ # os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
331
+ # os.system("kill -1 `pgrep gunicorn`")
332
+ # todo refine
333
+ # if image != None:
334
+ # image = refiner(
335
+ # prompt=prompt,
336
+ # # width=block_width,
337
+ # # height=block_height,
338
+ # num_inference_steps=n_steps,
339
+ # # denoising_start=high_noise_frac,
340
+ # image=image,
341
+ # ).images[0]
342
+ if width != block_width or height != block_height:
343
+ # resize to original size width/height
344
+ # find aspect ratio to scale up to that covers the original img input width/height
345
+ scale_up_ratio = max(width / block_width, height / block_height)
346
+ image = image.resize((math.ceil(block_width * scale_up_ratio), math.ceil(block_height * scale_up_ratio)))
347
+ # crop image to original size
348
+ image = image.crop((0, 0, width, height))
349
+ # try:
350
+ # # gc.collect()
351
+ # torch.cuda.empty_cache()
352
+ # except Exception as e:
353
+ # traceback.print_exc()
354
+ # logger.info("restarting server to fix cuda issues (device side asserts)")
355
+ # # todo fix device side asserts instead of restart to fix
356
+ # # todo only restart the correct gunicorn
357
+ # # this could be really annoying if you're running other gunicorns on your machine, which would also get restarted
358
+ # os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
359
+ # os.system("kill -1 `pgrep gunicorn`")
360
+ # save as bytesio
361
+ bs = BytesIO()
362
+
363
+ bright_count = np.sum(np.array(image) > 0)
364
+ if bright_count == 0:
365
+ # we have a black image; this is an error and we likely need a restart
366
+ logger.info("restarting server to fix cuda issues (device side asserts)")
367
+ # # todo fix device side asserts instead of restart to fix
368
+ # # todo only restart the correct gunicorn
369
+ # # this could be really annoying if you're running other gunicorns on your machine, which would also get restarted
370
+ os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
371
+ os.system("kill -1 `pgrep gunicorn`")
372
+ os.system("/usr/bin/bash kill -SIGHUP `pgrep uvicorn`")
373
+ os.system("kill -1 `pgrep uvicorn`")
374
+
375
+ return None
376
+ image.save(bs, quality=85, optimize=True, format="webp")
377
+ bio = bs.getvalue()
378
+ # touch progress.txt file - if we don't do this we get restarted by supervisor/other processes for reliability
379
+ with open("progress.txt", "w") as f:
380
+ current_time = datetime.now().strftime("%H:%M:%S")
381
+ f.write(f"{current_time}")
382
+ return bio
383
+
384
+ def inpaint_image_from_prompt(prompt, image_url: str, mask_url: str):
385
+ prompt = shorten_too_long_text(prompt)
386
+ # image = pipe(prompt=prompt).images[0]
387
+
388
+ init_image = load_image(image_url).convert("RGB")
389
+ mask_image = load_image(mask_url).convert("RGB") # why rgb for a 1 channel mask?
390
+ num_inference_steps = 75
391
+ high_noise_frac = 0.7
392
+
393
+ try:
394
+ image = inpaintpipe(
395
+ prompt=prompt,
396
+ image=init_image,
397
+ mask_image=mask_image,
398
+ num_inference_steps=num_inference_steps,
399
+ denoising_start=high_noise_frac,
400
+ output_type="latent",
401
+ ).images[0] # normally uses 50 steps
402
+ except Exception as e:
403
+ # try rm stopwords + half the prompt
404
+ # todo try prompt permutations
405
+ logger.info(f"trying to shorten prompt of length {len(prompt)}")
406
+
407
+ prompt = ' '.join(word for word in prompt.split() if word not in stopwords)  # split first: iterating a string yields characters
408
+ prompts = prompt.split()
409
+
410
+ prompt = ' '.join(prompts[:len(prompts) // 2])
411
+ logger.info(f"shortened prompt to: {len(prompt)}")
412
+ image = None
413
+ if prompt:
414
+ try:
415
+ image = inpaintpipe(  # must use the inpaint pipeline here; the base pipe does not accept image/mask_image
416
+ prompt=prompt,
417
+ image=init_image,
418
+ mask_image=mask_image,
419
+ num_inference_steps=num_inference_steps,
420
+ denoising_start=high_noise_frac,
421
+ output_type="latent",
422
+ ).images[0] # normally uses 50 steps
423
+ except Exception as e:
424
+ # logger.info("trying to permute prompt")
425
+ # # try two swaps of the prompt/permutations
426
+ # prompt = prompt.split()
427
+ # prompt = ' '.join(permutations(prompt, 2).__next__())
428
+ logger.info(f"trying to shorten prompt of length {len(prompt)}")
429
+
430
+ prompt = ' '.join(word for word in prompt.split() if word not in stopwords)  # split first: iterating a string yields characters
431
+ prompts = prompt.split()
432
+
433
+ prompt = ' '.join(prompts[:len(prompts) // 2])
434
+ logger.info(f"shortened prompt to: {len(prompt)}")
435
+
436
+ try:
437
+ image = inpaintpipe(
438
+ prompt=prompt,
439
+ image=init_image,
440
+ mask_image=mask_image,
441
+ num_inference_steps=num_inference_steps,
442
+ denoising_start=high_noise_frac,
443
+ output_type="latent",
444
+ ).images[0] # normally uses 50 steps
445
+ except Exception as e:
446
+ # just error out
447
+ traceback.print_exc()
448
+ raise e
449
+ # logger.info("restarting server to fix cuda issues (device side asserts)")
450
+ # todo fix device side asserts instead of restart to fix
451
+ # todo only restart the correct gunicorn
452
+ # this could be really annoying if you're running other gunicorns on your machine, which would also get restarted
453
+ # os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
454
+ # os.system("kill -1 `pgrep gunicorn`")
455
+ if image is not None:
456
+ image = inpaint_refiner(
457
+ prompt=prompt,
458
+ image=image,
459
+ mask_image=mask_image,
460
+ num_inference_steps=num_inference_steps,
461
+ denoising_start=high_noise_frac,
462
+
463
+ ).images[0]
464
+ # try:
465
+ # # gc.collect()
466
+ # torch.cuda.empty_cache()
467
+ # except Exception as e:
468
+ # traceback.print_exc()
469
+ # logger.info("restarting server to fix cuda issues (device side asserts)")
470
+ # # todo fix device side asserts instead of restart to fix
471
+ # # todo only restart the correct gunicorn
472
+ # # this could be really annoying if you're running other gunicorns on your machine, which would also get restarted
473
+ # os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
474
+ # os.system("kill -1 `pgrep gunicorn`")
475
+ # save as bytesio
476
+ bs = BytesIO()
477
+
478
+ bright_count = np.sum(np.array(image) > 0)
479
+ if bright_count == 0:
480
+ # we have a black image; this is an error and we likely need a restart
481
+ logger.info("restarting server to fix cuda issues (device side asserts)")
482
+ # # todo fix device side asserts instead of restart to fix
483
+ # # todo only restart the correct gunicorn
484
+ # # this could be really annoying if you're running other gunicorns on your machine, which would also get restarted
485
+ os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
486
+ os.system("kill -1 `pgrep gunicorn`")
487
+ os.system("/usr/bin/bash kill -SIGHUP `pgrep uvicorn`")
488
+ os.system("kill -1 `pgrep uvicorn`")
489
+
490
+ return None
491
+ image.save(bs, quality=85, optimize=True, format="webp")
492
+ bio = bs.getvalue()
493
+ # touch progress.txt file - if we don't do this we get restarted by supervisor/other processes for reliability
494
+ with open("progress.txt", "w") as f:
495
+ current_time = datetime.now().strftime("%H:%M:%S")
496
+ f.write(f"{current_time}")
497
+ return bio
498
+
499
+
500
+
501
+ def shorten_too_long_text(prompt):
502
+ if len(prompt) > 200:
503
+ # remove stopwords
504
+ prompt = prompt.split() # todo also split hyphens
505
+ prompt = ' '.join((word for word in prompt if word not in stopwords))
506
+ if len(prompt) > 200:
507
+ prompt = prompt[:200]
508
+ return prompt
509
+
510
+ # image = pipe(prompt=prompt).images[0]
511
+ #
512
+ # image.save("test.png")
513
+ # # save all images
514
+ # for i, image in enumerate(images):
515
+ # image.save(f"{i}.png")
img/pr2/main.py ADDED
@@ -0,0 +1,528 @@
1
+ import gc
2
+ import math
3
+ import multiprocessing
4
+ import os
5
+ import traceback
6
+ from datetime import datetime
7
+ from io import BytesIO
8
+ from itertools import permutations
9
+ from multiprocessing.pool import Pool
10
+ from pathlib import Path
11
+ from urllib.parse import quote_plus
12
+
13
+ import numpy as np
14
+ import nltk
15
+ import torch
16
+
17
+ from PIL.Image import Image
18
+ from diffusers import DiffusionPipeline, StableDiffusionXLInpaintPipeline
19
+ from diffusers.utils import load_image
20
+ from fastapi import FastAPI
21
+ from fastapi.middleware.gzip import GZipMiddleware
22
+ from loguru import logger
23
+ from starlette.middleware.cors import CORSMiddleware
24
+ from starlette.responses import FileResponse
25
+ from starlette.responses import JSONResponse
26
+
27
+ from env import BUCKET_PATH, BUCKET_NAME
28
+ # from stable_diffusion_server.bucket_api import check_if_blob_exists, upload_to_bucket
29
+ torch._dynamo.config.suppress_errors = True
30
+
31
+ import string
32
+ import random
33
+
34
+ def generate_save_path():
35
+ # length of the random name
36
+ N = 7
37
+
38
+ # using random.choices()
39
+ # generating random strings
40
+ res = ''.join(random.choices(string.ascii_uppercase +
41
+ string.digits, k=N))
42
+ return res
43
+
44
+ pipe = DiffusionPipeline.from_pretrained(
45
+ "models/stable-diffusion-xl-base-1.0",
46
+ torch_dtype=torch.bfloat16,
47
+ use_safetensors=True,
48
+ variant="fp16",
49
+ # safety_checker=None,
50
+ ) # todo try torch_dtype=bfloat16
51
+ pipe.watermark = None
52
+
53
+ pipe.to("cuda")
54
+
55
+ refiner = DiffusionPipeline.from_pretrained(
56
+ "stabilityai/stable-diffusion-xl-refiner-1.0",
57
+ text_encoder_2=pipe.text_encoder_2,
58
+ vae=pipe.vae,
59
+ torch_dtype=torch.bfloat16, # safer to use bfloat?
60
+ use_safetensors=True,
61
+ variant="fp16", #remember not to download the big model
62
+ )
63
+ refiner.watermark = None
64
+ refiner.to("cuda")
65
+
66
+ # {'scheduler', 'text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'unet', 'vae'} can be passed in from existing model
67
+ inpaintpipe = StableDiffusionXLInpaintPipeline.from_pretrained(
68
+ "models/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16, variant="fp16", use_safetensors=True,
69
+ scheduler=pipe.scheduler,
70
+ text_encoder=pipe.text_encoder,
71
+ text_encoder_2=pipe.text_encoder_2,
72
+ tokenizer=pipe.tokenizer,
73
+ tokenizer_2=pipe.tokenizer_2,
74
+ unet=pipe.unet,
75
+ vae=pipe.vae,
76
+ # load_connected_pipeline=
77
+ )
78
+ # # switch out to save gpu mem
79
+ # del inpaintpipe.vae
80
+ # del inpaintpipe.text_encoder_2
81
+ # del inpaintpipe.text_encoder
82
+ # del inpaintpipe.scheduler
83
+ # del inpaintpipe.tokenizer
84
+ # del inpaintpipe.tokenizer_2
85
+ # del inpaintpipe.unet
86
+ # inpaintpipe.vae = pipe.vae
87
+ # inpaintpipe.text_encoder_2 = pipe.text_encoder_2
88
+ # inpaintpipe.text_encoder = pipe.text_encoder
89
+ # inpaintpipe.scheduler = pipe.scheduler
90
+ # inpaintpipe.tokenizer = pipe.tokenizer
91
+ # inpaintpipe.tokenizer_2 = pipe.tokenizer_2
92
+ # inpaintpipe.unet = pipe.unet
93
+ # todo this should work
94
+ # inpaintpipe = StableDiffusionXLInpaintPipeline( # construct an inpainter using the existing model
95
+ # vae=pipe.vae,
96
+ # text_encoder_2=pipe.text_encoder_2,
97
+ # text_encoder=pipe.text_encoder,
98
+ # unet=pipe.unet,
99
+ # scheduler=pipe.scheduler,
100
+ # tokenizer=pipe.tokenizer,
101
+ # tokenizer_2=pipe.tokenizer_2,
102
+ # requires_aesthetics_score=False,
103
+ # )
104
+ inpaintpipe.to("cuda")
105
+ inpaintpipe.watermark = None
106
+ # inpaintpipe.register_to_config(requires_aesthetics_score=False)
107
+
108
+ inpaint_refiner = StableDiffusionXLInpaintPipeline.from_pretrained(
109
+ "stabilityai/stable-diffusion-xl-refiner-1.0",
110
+ text_encoder_2=inpaintpipe.text_encoder_2,
111
+ vae=inpaintpipe.vae,
112
+ torch_dtype=torch.bfloat16,
113
+ use_safetensors=True,
114
+ variant="fp16",
115
+
116
+ tokenizer_2=refiner.tokenizer_2,
117
+ tokenizer=refiner.tokenizer,
118
+ scheduler=refiner.scheduler,
119
+ text_encoder=refiner.text_encoder,
120
+ unet=refiner.unet,
121
+ )
122
+ # del inpaint_refiner.vae
123
+ # del inpaint_refiner.text_encoder_2
124
+ # del inpaint_refiner.text_encoder
125
+ # del inpaint_refiner.scheduler
126
+ # del inpaint_refiner.tokenizer
127
+ # del inpaint_refiner.tokenizer_2
128
+ # del inpaint_refiner.unet
129
+ # inpaint_refiner.vae = inpaintpipe.vae
130
+ # inpaint_refiner.text_encoder_2 = inpaintpipe.text_encoder_2
131
+ #
132
+ # inpaint_refiner.text_encoder = refiner.text_encoder
133
+ # inpaint_refiner.scheduler = refiner.scheduler
134
+ # inpaint_refiner.tokenizer = refiner.tokenizer
135
+ # inpaint_refiner.tokenizer_2 = refiner.tokenizer_2
136
+ # inpaint_refiner.unet = refiner.unet
137
+
138
+ # inpaint_refiner = StableDiffusionXLInpaintPipeline(
139
+ # text_encoder_2=inpaintpipe.text_encoder_2,
140
+ # vae=inpaintpipe.vae,
141
+ # # the rest from the existing refiner
142
+ # tokenizer_2=refiner.tokenizer_2,
143
+ # tokenizer=refiner.tokenizer,
144
+ # scheduler=refiner.scheduler,
145
+ # text_encoder=refiner.text_encoder,
146
+ # unet=refiner.unet,
147
+ # requires_aesthetics_score=False,
148
+ # )
149
+ inpaint_refiner.to("cuda")
150
+ inpaint_refiner.watermark = None
151
+ # inpaint_refiner.register_to_config(requires_aesthetics_score=False)
152
+
153
+ n_steps = 40
154
+ high_noise_frac = 0.8
155
+
156
+ # if using torch < 2.0
157
+ # pipe.enable_xformers_memory_efficient_attention()
158
+
159
+
160
+ # pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
161
+ # this can cause errors on some inputs so consider disabling it
162
+ pipe.unet = torch.compile(pipe.unet)
163
+ refiner.unet = torch.compile(refiner.unet)#, mode="reduce-overhead", fullgraph=True)
164
+ # compile the inpainters - todo reuse the other unets? swap out the models for others/del them so they share models and can be swapped efficiently
165
+ inpaintpipe.unet = pipe.unet
166
+ inpaint_refiner.unet = refiner.unet
167
+ # inpaintpipe.unet = torch.compile(inpaintpipe.unet)
168
+ # inpaint_refiner.unet = torch.compile(inpaint_refiner.unet)
169
+ from pydantic import BaseModel
170
+
171
+ app = FastAPI(
172
+ openapi_url="/static/openapi.json",
173
+ docs_url="/swagger-docs",
174
+ redoc_url="/redoc",
175
+ title="Generate Images Netwrck API",
176
+ description="Character Chat API",
177
+ # root_path="https://api.text-generator.io",
178
+ version="1",
179
+ )
180
+ app.add_middleware(GZipMiddleware, minimum_size=1000)
181
+ app.add_middleware(
182
+ CORSMiddleware,
183
+ allow_origins=["*"],
184
+ allow_credentials=True,
185
+ allow_methods=["*"],
186
+ allow_headers=["*"],
187
+ )
188
+
189
+ stopwords = nltk.corpus.stopwords.words("english")
190
+
191
+ class Img(BaseModel):
192
+ system_prompt: str
193
+ ASSISTANT: str
194
+
195
+ # img_url = "http://phlrr2019.guest.corp.microsoft.com:8000/img1_sdv2.1.png"
196
+ img_url = "http://phlrr3058.guest.corp.microsoft.com:8000/"#/img1_sdv2.1.png"
197
+
198
+ @app.post("/image_url")
199
+ def image_url(img: Img):
200
+ system_prompt = img.system_prompt
201
+ prompt = img.ASSISTANT
202
+ # if Path(save_path).exists():
203
+ # return FileResponse(save_path, media_type="image/png")
204
+ # return JSONResponse({"path": path})
205
+ image = pipe(prompt=prompt).images[0]
206
+ # if not save_path:
207
+ save_path = generate_save_path()
208
+ save_path = f"images/{save_path}.png"
209
+ image.save(save_path)
210
+ # save_path = '/'.join(path_components) + quote_plus(final_name)
211
+ path = f"{img_url}/{save_path}"
212
+ return JSONResponse({"path": path})
213
+
214
+
215
+ @app.get("/make_image")
216
+ # @app.post("/make_image")
217
+ def make_image(prompt: str, save_path: str = ""):
218
+ if Path(save_path).exists():
219
+ return FileResponse(save_path, media_type="image/png")
220
+ image = pipe(prompt=prompt).images[0]
221
+ if not save_path:
222
+ save_path = f"images/{prompt}.png"
223
+ image.save(save_path)
224
+ return FileResponse(save_path, media_type="image/png")
225
+
226
+
227
+ @app.get("/create_and_upload_image")
228
+ def create_and_upload_image(prompt: str, width: int=1024, height:int=1024, save_path: str = ""):
229
+ path_components = save_path.split("/")[0:-1]
230
+ final_name = save_path.split("/")[-1]
231
+ if not path_components:
232
+ path_components = []
233
+ save_path = '/'.join(path_components + [quote_plus(final_name)])  # keep the "/" separator for nested save paths
234
+ path = get_image_or_create_upload_to_cloud_storage(prompt, width, height, save_path)
235
+ return JSONResponse({"path": path})
236
+
237
+ @app.get("/inpaint_and_upload_image")
238
+ def inpaint_and_upload_image(prompt: str, image_url:str, mask_url:str, save_path: str = ""):
239
+ path_components = save_path.split("/")[0:-1]
240
+ final_name = save_path.split("/")[-1]
241
+ if not path_components:
242
+ path_components = []
243
+ save_path = '/'.join(path_components + [quote_plus(final_name)])  # keep the "/" separator for nested save paths
244
+ path = get_image_or_inpaint_upload_to_cloud_storage(prompt, image_url, mask_url, save_path)
245
+ return JSONResponse({"path": path})
246
+
247
+
248
+ def get_image_or_create_upload_to_cloud_storage(prompt:str,width:int, height:int, save_path:str):
249
+ prompt = shorten_too_long_text(prompt)
250
+ save_path = shorten_too_long_text(save_path)
251
+ # check exists - todo cache this
252
+ if check_if_blob_exists(save_path):
253
+ return f"https://{BUCKET_NAME}/{BUCKET_PATH}/{save_path}"
254
+ bio = create_image_from_prompt(prompt, width, height)
255
+ if bio is None:
256
+ return None # error thrown in pool
257
+ link = upload_to_bucket(save_path, bio, is_bytesio=True)
258
+ return link
259
+ def get_image_or_inpaint_upload_to_cloud_storage(prompt:str, image_url:str, mask_url:str, save_path:str):
260
+ prompt = shorten_too_long_text(prompt)
261
+ save_path = shorten_too_long_text(save_path)
262
+ # check exists - todo cache this
263
+ if check_if_blob_exists(save_path):
264
+ return f"https://{BUCKET_NAME}/{BUCKET_PATH}/{save_path}"
265
+ bio = inpaint_image_from_prompt(prompt, image_url, mask_url)
266
+ if bio is None:
267
+ return None # error thrown in pool
268
+ link = upload_to_bucket(save_path, bio, is_bytesio=True)
269
+ return link
270
+
271
+ # multiprocessing.set_start_method('spawn', True)
272
+ # processes_pool = Pool(1) # cant do too much at once or OOM errors happen
273
+ # def create_image_from_prompt_sync(prompt):
274
+ # """have to call this sync to avoid OOM errors"""
275
+ # return processes_pool.apply_async(create_image_from_prompt, args=(prompt,), ).wait()
276
+
277
+ def create_image_from_prompt(prompt, width, height):
278
+ # round width and height down to multiple of 64
279
+ block_width = width - (width % 64)
280
+ block_height = height - (height % 64)
281
+ prompt = shorten_too_long_text(prompt)
282
+ # image = pipe(prompt=prompt).images[0]
283
+ try:
284
+ image = pipe(prompt=prompt,
285
+ width=block_width,
286
+ height=block_height,
287
+ # denoising_end=high_noise_frac,
288
+ # output_type='latent',
289
+ # height=512,
290
+ # width=512,
291
+ num_inference_steps=50).images[0] # normally uses 50 steps
292
+ except Exception as e:
293
+ # try rm stopwords + half the prompt
294
+ # todo try prompt permutations
295
+ logger.info(f"trying to shorten prompt of length {len(prompt)}")
296
+
297
+ prompt = ' '.join(word for word in prompt.split() if word not in stopwords)  # split first: iterating a string yields characters
298
+ prompts = prompt.split()
299
+
300
+ prompt = ' '.join(prompts[:len(prompts) // 2])
301
+ logger.info(f"shortened prompt to: {len(prompt)}")
302
+ image = None
303
+ if prompt:
304
+ try:
305
+ image = pipe(prompt=prompt,
306
+ width=block_width,
307
+ height=block_height,
308
+ # denoising_end=high_noise_frac,
309
+ # output_type='latent',
310
+ # height=512,
311
+ # width=512,
312
+ num_inference_steps=50).images[0] # normally uses 50 steps
313
+ except Exception as e:
314
+ # logger.info("trying to permute prompt")
315
+ # # try two swaps of the prompt/permutations
316
+ # prompt = prompt.split()
317
+ # prompt = ' '.join(permutations(prompt, 2).__next__())
318
+ logger.info(f"trying to shorten prompt of length {len(prompt)}")
319
+
320
+ prompt = ' '.join(word for word in prompt.split() if word not in stopwords)  # split first: iterating a string yields characters
321
+ prompts = prompt.split()
322
+
323
+ prompt = ' '.join(prompts[:len(prompts) // 2])
324
+ logger.info(f"shortened prompt to: {len(prompt)}")
325
+
326
+ try:
327
+ image = pipe(prompt=prompt,
328
+ width=block_width,
329
+ height=block_height,
330
+ # denoising_end=high_noise_frac,
331
+ # output_type='latent', # dont need latent yet - we refine the image at full res
332
+ # height=512,
333
+ # width=512,
334
+ num_inference_steps=50).images[0] # normally uses 50 steps
335
+ except Exception as e:
336
+ # just error out
337
+ traceback.print_exc()
338
+ raise e
339
+ # logger.info("restarting server to fix cuda issues (device side asserts)")
340
+ # todo fix device side asserts instead of restart to fix
341
+ # todo only restart the correct gunicorn
342
+ # this could be really annoying if you're running other gunicorns on your machine, which would also get restarted
343
+ # os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
344
+ # os.system("kill -1 `pgrep gunicorn`")
345
+ # todo refine
346
+ # if image != None:
347
+ # image = refiner(
348
+ # prompt=prompt,
349
+ # # width=block_width,
350
+ # # height=block_height,
351
+ # num_inference_steps=n_steps,
352
+ # # denoising_start=high_noise_frac,
353
+ # image=image,
354
+ # ).images[0]
355
+ if width != block_width or height != block_height:
356
+ # resize to original size width/height
357
+ # find aspect ratio to scale up to that covers the original img input width/height
358
+ scale_up_ratio = max(width / block_width, height / block_height)
359
+ image = image.resize((math.ceil(block_width * scale_up_ratio), math.ceil(block_height * scale_up_ratio)))
360
+ # crop image to original size
361
+ image = image.crop((0, 0, width, height))
362
+ # try:
363
+ # # gc.collect()
364
+ # torch.cuda.empty_cache()
365
+ # except Exception as e:
366
+ # traceback.print_exc()
367
+ # logger.info("restarting server to fix cuda issues (device side asserts)")
368
+ # # todo fix device side asserts instead of restart to fix
369
+ # # todo only restart the correct gunicorn
370
+ # # this could be really annoying if you're running other gunicorns on your machine, which would also get restarted
371
+ # os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
372
+ # os.system("kill -1 `pgrep gunicorn`")
373
+ # save as bytesio
374
+ bs = BytesIO()
375
+
376
+ bright_count = np.sum(np.array(image) > 0)
377
+ if bright_count == 0:
378
+ # we have a black image; this is an error and we likely need a restart
379
+ logger.info("restarting server to fix cuda issues (device side asserts)")
380
+ # # todo fix device side asserts instead of restart to fix
381
+ # # todo only restart the correct gunicorn
382
+ # # this could be really annoying if you're running other gunicorns on your machine, which would also get restarted
383
+ os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
384
+ os.system("kill -1 `pgrep gunicorn`")
385
+ os.system("/usr/bin/bash kill -SIGHUP `pgrep uvicorn`")
386
+ os.system("kill -1 `pgrep uvicorn`")
387
+
388
+ return None
389
+ image.save(bs, quality=85, optimize=True, format="webp")
390
+ bio = bs.getvalue()
391
+ # touch progress.txt file - if we don't do this we get restarted by supervisor/other processes for reliability
392
+ with open("progress.txt", "w") as f:
393
+ current_time = datetime.now().strftime("%H:%M:%S")
394
+ f.write(f"{current_time}")
395
+ return bio
396
+
397
+ def inpaint_image_from_prompt(prompt, image_url: str, mask_url: str):
398
+ prompt = shorten_too_long_text(prompt)
399
+ # image = pipe(prompt=prompt).images[0]
400
+
401
+ init_image = load_image(image_url).convert("RGB")
402
+ mask_image = load_image(mask_url).convert("RGB") # why rgb for a 1 channel mask?
403
+ num_inference_steps = 75
404
+ high_noise_frac = 0.7
405
+
406
+ try:
407
+ image = inpaintpipe(
408
+ prompt=prompt,
409
+ image=init_image,
410
+ mask_image=mask_image,
411
+ num_inference_steps=num_inference_steps,
412
+ denoising_start=high_noise_frac,
413
+ output_type="latent",
414
+ ).images[0] # normally uses 50 steps
415
+ except Exception as e:
416
+ # try rm stopwords + half the prompt
417
+ # todo try prompt permutations
418
+ logger.info(f"trying to shorten prompt of length {len(prompt)}")
419
+
420
+ prompt = ' '.join(word for word in prompt.split() if word not in stopwords)  # split first: iterating a string yields characters
421
+ prompts = prompt.split()
422
+
423
+ prompt = ' '.join(prompts[:len(prompts) // 2])
424
+ logger.info(f"shortened prompt to: {len(prompt)}")
425
+ image = None
426
+ if prompt:
427
+ try:
428
+ image = inpaintpipe(  # must use the inpaint pipeline here; the base pipe does not accept image/mask_image
429
+ prompt=prompt,
430
+ image=init_image,
431
+ mask_image=mask_image,
432
+ num_inference_steps=num_inference_steps,
433
+ denoising_start=high_noise_frac,
434
+ output_type="latent",
435
+ ).images[0] # normally uses 50 steps
436
+ except Exception as e:
437
+ # logger.info("trying to permute prompt")
438
+ # # try two swaps of the prompt/permutations
439
+ # prompt = prompt.split()
440
+ # prompt = ' '.join(permutations(prompt, 2).__next__())
441
+ logger.info(f"trying to shorten prompt of length {len(prompt)}")
442
+
443
+ prompt = ' '.join(word for word in prompt.split() if word not in stopwords)  # split first: iterating a string yields characters
444
+ prompts = prompt.split()
445
+
446
+ prompt = ' '.join(prompts[:len(prompts) // 2])
447
+ logger.info(f"shortened prompt to: {len(prompt)}")
448
+
449
+ try:
450
+ image = inpaintpipe(
451
+ prompt=prompt,
452
+ image=init_image,
453
+ mask_image=mask_image,
454
+ num_inference_steps=num_inference_steps,
455
+ denoising_start=high_noise_frac,
456
+ output_type="latent",
457
+ ).images[0] # normally uses 50 steps
458
+ except Exception as e:
459
+ # just error out
460
+ traceback.print_exc()
461
+ raise e
462
+ # logger.info("restarting server to fix cuda issues (device side asserts)")
463
+ # todo fix device side asserts instead of restart to fix
464
+ # todo only restart the correct gunicorn
465
+ # this could be really annoying if you're running other gunicorns on your machine, which would also get restarted
466
+ # os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
467
+ # os.system("kill -1 `pgrep gunicorn`")
468
+ if image is not None:
469
+ image = inpaint_refiner(
470
+ prompt=prompt,
471
+ image=image,
472
+ mask_image=mask_image,
473
+ num_inference_steps=num_inference_steps,
474
+ denoising_start=high_noise_frac,
475
+
476
+ ).images[0]
477
+ # try:
478
+ # # gc.collect()
479
+ # torch.cuda.empty_cache()
480
+ # except Exception as e:
481
+ # traceback.print_exc()
482
+ # logger.info("restarting server to fix cuda issues (device side asserts)")
483
+ # # todo fix device side asserts instead of restart to fix
484
+ # # todo only restart the correct gunicorn
485
+ # # this could be really annoying if you're running other gunicorns on your machine, which would also get restarted
486
+ # os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
487
+ # os.system("kill -1 `pgrep gunicorn`")
488
+ # save as bytesio
489
+ bs = BytesIO()
490
+
491
+ bright_count = np.sum(np.array(image) > 0)
492
+ if bright_count == 0:
493
+ # we have a black image; this is an error and we likely need a restart
494
+ logger.info("restarting server to fix cuda issues (device side asserts)")
495
+ # # todo fix device side asserts instead of restart to fix
496
+ # # todo only restart the correct gunicorn
497
+ # # this could be really annoying if you're running other gunicorns on your machine, which would also get restarted
498
+ os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
499
+ os.system("kill -1 `pgrep gunicorn`")
500
+ os.system("/usr/bin/bash kill -SIGHUP `pgrep uvicorn`")
501
+ os.system("kill -1 `pgrep uvicorn`")
502
+
503
+ return None
504
+ image.save(bs, quality=85, optimize=True, format="webp")
505
+ bio = bs.getvalue()
506
+ # touch progress.txt file - if we don't do this we get restarted by supervisor/other processes for reliability
507
+ with open("progress.txt", "w") as f:
508
+ current_time = datetime.now().strftime("%H:%M:%S")
509
+ f.write(f"{current_time}")
510
+ return bio
511
+
512
+
513
+
514
+ def shorten_too_long_text(prompt):
515
+ if len(prompt) > 200:
516
+ # remove stopwords
517
+ prompt = prompt.split() # todo also split hyphens
518
+ prompt = ' '.join((word for word in prompt if word not in stopwords))
519
+ if len(prompt) > 200:
520
+ prompt = prompt[:200]
521
+ return prompt
522
+
523
+ # image = pipe(prompt=prompt).images[0]
524
+ #
525
+ # image.save("test.png")
526
+ # # save all images
527
+ # for i, image in enumerate(images):
528
+ # image.save(f"{i}.png")
img/readme.md ADDED
@@ -0,0 +1,109 @@
1
+ A simple Stable Diffusion server that saves generated images to cloud storage and returns links to Google Cloud Storage.
2
+
3
+ ## Creators
4
+ [![netwrck logo](https://static.netwrck.com/static/img/netwrck-logo-colord256.png)](https://netwrck.com)
5
+
6
+ Check out [Voiced AI Characters to chat with](https://netwrck.com) at [netwrck.com](https://netwrck.com)
7
+
8
+ Characters are narrated and written by many GPT models trained on thousands of fantasy novels and chats.
9
+
10
+ Also, for text-generation LLMs, check out [Text-Generator.io](https://text-generator.io), an open-source text generator that uses many AI models to generate the best output, along with image-understanding and OCR networks.
11
+ ## Setup
12
+
13
+ #### Create a virtual environment (optional)
14
+
15
+ ```bash
16
+ python3 -m venv venv
17
+ source venv/bin/activate
18
+ ```
19
+
20
+ #### Install dependencies
21
+
22
+ ```bash
23
+ pip install -r requirements.txt
24
+ pip install -r dev-requirements.txt
25
+
26
+ cd models
27
+ git clone https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0
28
+
29
+ # install stopwords
30
+ python -c "import nltk; nltk.download('stopwords')"
31
+ ```
32
+
33
+ #### Edit settings in env.py
34
+ #### Download your Google Cloud credentials to secrets/google-credentials.json
36
+ Generated images will be stored in your bucket.
36
+ #### Run the server
37
+
38
+ ```bash
39
+ GOOGLE_APPLICATION_CREDENTIALS=secrets/google-credentials.json gunicorn -k uvicorn.workers.UvicornWorker -b :8000 main:app --timeout 600 -w 1
40
+ ```
41
+
42
+ Or run with a maximum of 4 concurrent requests.
43
+ This will drop requests under load instead of taking on too much work and causing OOM errors.
44
+
45
+ ```bash
46
+ GOOGLE_APPLICATION_CREDENTIALS=secrets/google-credentials.json PYTHONPATH=. uvicorn --port 8000 --timeout-keep-alive 600 --workers 1 --backlog 1 --limit-concurrency 4 main:app
47
+ ```
48
+
49
+ #### Make a Request
50
+
51
+ http://localhost:8000/create_and_upload_image?prompt=good%20looking%20elf%20fantasy%20character&save_path=created/elf.webp
52
+
53
+ Response
54
+ ```shell
55
+ {"path":"https://storage.googleapis.com/static.netwrck.com/static/uploads/created/elf.png"}
56
+ ```
57
+
58
+ http://localhost:8000/docs
59
+
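The same request from Python (a minimal sketch; assumes the server is running locally on port 8000):

```python
import requests

resp = requests.get(
    "http://localhost:8000/create_and_upload_image",
    params={
        "prompt": "good looking elf fantasy character",
        "save_path": "created/elf.webp",
    },
    timeout=600,  # image generation can take a while
)
print(resp.json()["path"])  # e.g. a storage.googleapis.com link
```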
60
+
61
+ Check that "good looking elf fantasy character" was created:
62
+
63
+ ![elf.png](https://storage.googleapis.com/static.netwrck.com/static/uploads/created/elf.png)
64
+ ![elf2.png](https://storage.googleapis.com/static.netwrck.com/static/uploads/created/elf2.png)
65
+
66
+ ### Testing
67
+
68
+ ```bash
69
+ GOOGLE_APPLICATION_CREDENTIALS=secrets/google-credentials.json pytest .
70
+ ```
71
+
72
+
73
+ #### Running under supervisord
74
+
75
+ Edit ops/supervisor.conf.
76
+
77
+ Install supervisor:
78
+ apt-get install -y supervisor
79
+ ```bash
80
+ sudo cat >/etc/supervisor/conf.d/python-app.conf << EOF
81
+ [program:sdif_http_server]
82
+ directory=/home/lee/code/sdif
83
+ command=/home/lee/code/sdif/.env/bin/uvicorn --port 8000 --timeout-keep-alive 600 --workers 1 --backlog 1 --limit-concurrency 4 main:app
84
+ autostart=true
85
+ autorestart=true
86
+ environment=VIRTUAL_ENV="/home/lee/code/sdif/.env/",PATH="/opt/app/sdif/.env/bin",HOME="/home/lee",GOOGLE_APPLICATION_CREDENTIALS="secrets/google-credentials.json",PYTHONPATH="/home/lee/code/sdif"
87
+ stdout_logfile=syslog
88
+ stderr_logfile=syslog
89
+ user=lee
90
+ EOF
91
+
92
+ supervisorctl reread
93
+ supervisorctl update
94
+ ```
95
+
96
+ #### Run a manager process to kill/restart the server if it is hanging
97
+
98
+ Sometimes the server just stops working and needs a hard restart.
99
+
100
+ This command will kill the server if it is hanging and restart it (must be running under supervisorctl)
101
+ ```
102
+ python3 manager.py
103
+ ```
104
+
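A sketch of what such a watchdog can look like (an assumption about manager.py's behavior based on the progress.txt heartbeat the server writes; the actual manager.py may differ):

```python
import os
import subprocess
import time

STALE_SECONDS = 600  # assumed threshold: 10 minutes without progress counts as a hang

while True:
    try:
        age = time.time() - os.path.getmtime("progress.txt")
    except FileNotFoundError:
        age = float("inf")  # no heartbeat written yet
    if age > STALE_SECONDS:
        # program name must match the supervisor config above
        subprocess.run(["supervisorctl", "restart", "sdif_http_server"])
    time.sleep(60)
```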
105
+ #### Hack: restarting without supervisor
106
+ Run the server in an infinite loop:
107
+ ```
108
+ while true; do GOOGLE_APPLICATION_CREDENTIALS=secrets/google-credentials.json PYTHONPATH=. uvicorn --port 8000 --timeout-keep-alive 600 --workers 1 --backlog 1 --limit-concurrency 4 main:app; done
109
+ ```
img/requirements.txt ADDED
@@ -0,0 +1,67 @@
1
+ accelerate==0.20.3
2
+ annotated-types==0.5.0
3
+ anyio==3.7.1
4
+ certifi==2023.5.7
5
+ charset-normalizer==3.2.0
6
+ click==8.1.4
7
+ cmake==3.26.4
8
+ diffusers==0.20.0
9
+ exceptiongroup==1.1.2
10
+ fastapi==0.100.0
11
+ filelock==3.12.2
12
+ fsspec==2023.6.0
13
+ gunicorn==20.1.0
14
+ h11==0.14.0
15
+ huggingface-hub==0.16.4
16
+ idna==3.4
17
+ importlib-metadata==6.8.0
18
+ invisible-watermark==0.2.0
19
+ Jinja2==3.1.2
20
+ lit==16.0.6
21
+ MarkupSafe==2.1.3
22
+ mpmath==1.3.0
23
+ networkx==3.1
24
+ numpy==1.25.0
25
+ opencv-python==4.8.0.74
26
+ packaging==23.1
27
+ Pillow==10.0.0
28
+ psutil==5.9.5
29
+ pydantic==2.0.2
30
+ pydantic_core==2.1.2
31
+ PyWavelets==1.4.1
32
+ PyYAML==6.0
33
+ regex==2023.6.3
34
+ requests==2.31.0
35
+ safetensors==0.3.1
36
+ sniffio==1.3.0
37
+ starlette==0.27.0
38
+ sympy==1.12
39
+ tokenizers==0.13.3
40
+ torch==2.0.1
41
+ tqdm==4.65.0
42
+ transformers==4.30.2
43
+ #triton==2.0.0
44
+ typing_extensions==4.7.1
45
+ urllib3==2.0.3
46
+ uvicorn==0.22.0
47
+ zipp==3.15.0
48
+ jinja2
49
+ loguru==0.6.0
50
+
51
+ google-api-python-client==2.43.0
52
+ google-api-core #1.31.5
53
+ #google-cloud-storage==2.3.0 #not on gae python
54
+ google-cloud-storage==2.0.0
55
+
56
+ google-cloud-ndb==1.11.1
57
+ cachetools==4.2.4
58
+
59
+ python-multipart==0.0.6
60
+ nltk==3.8.1
61
+ diskcache==5.5.1
62
+
63
+ protobuf==3.19.5
64
+ google-cloud-aiplatform==1.25.0
65
+ # openai==0.27.7
66
+ # requests==2.28.2
67
+ # rollbar==0.16.3
img/scripts/test_compression.py ADDED
@@ -0,0 +1,22 @@
1
+ # save images across a range of webp quality settings (0-99), timing the results
2
+ from pathlib import Path
3
+ from time import time
4
+ def test_compression():
5
+ save_dir = Path("./imgs-sd/test/")
6
+ save_dir.mkdir(exist_ok=True, parents=True)
7
+
8
+ from PIL import Image
9
+
10
+ image = Image.open("/home/lee/code/sdif/imgs-sd/Woody.png").convert("RGB")
11
+ start = time()
12
+
13
+ image.save(save_dir / "woody-.webp", format="webp")
14
+ end = time()
15
+ print(f"Time to save image with default quality: {end - start}")
16
+
17
+ for i in range(0, 100):
18
+ start = time()
19
+
20
+ image.save(save_dir / f"woody-{i}.webp", quality=i, optimize=True, format="webp")
21
+ end = time()
22
+ print(f"Time to save image with quality {i}: {end - start}")
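A small extension of this script (a sketch, not in the repo) could record file size alongside timing, so the quality/time/size trade-off is visible in one pass:

```python
from pathlib import Path
from time import time
from PIL import Image

save_dir = Path("./imgs-sd/test/")
save_dir.mkdir(exist_ok=True, parents=True)
image = Image.open("/home/lee/code/sdif/imgs-sd/Woody.png").convert("RGB")

for quality in (50, 75, 85, 95):
    start = time()
    out = save_dir / f"woody-q{quality}.webp"
    image.save(out, quality=quality, optimize=True, format="webp")
    print(f"q={quality}: {time() - start:.3f}s, {out.stat().st_size} bytes")
```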
img/stable-diffusion-server/.gitignore ADDED
@@ -0,0 +1,13 @@
1
+ models
2
+ sd-images1
3
+ imgs-sd
4
+ images
5
+ backdrops
6
+ .env
7
+ venv
8
+ secrets
9
+ .pytest_cache
10
+ progress.txt
11
+ .idea
12
+ __pycache__
13
+
img/stable-diffusion-server/.log.0925.swp ADDED
Binary file (16.4 kB). View file
 
img/stable-diffusion-server/dev-requirements.txt ADDED
@@ -0,0 +1,11 @@
1
+ pytest
2
+
3
+ pytest-asyncio
4
+ requests-futures==1.0.0
5
+ httpx
6
+ djlint
7
+ pytest-env==0.8.1
8
+ ipython
9
+
10
+ line-profiler-pycharm==1.1.0
11
+ line-profiler==4.0.3
img/stable-diffusion-server/env.py ADDED
@@ -0,0 +1,2 @@
1
+ BUCKET_NAME = 'static.netwrck.com'
2
+ BUCKET_PATH = 'static/uploads'
img/stable-diffusion-server/img2img.py ADDED
@@ -0,0 +1,25 @@
1
+ import requests
2
+ import torch
3
+ from PIL import Image
4
+ from io import BytesIO
5
+
6
+ from diffusers import StableDiffusionImg2ImgPipeline
7
+
8
+ device = "cuda"
9
+ model_id_or_path = "runwayml/stable-diffusion-v1-5"
10
+ # model_id_or_path = "models/stable-diffusion-xl-base-0.9"
11
+ pipe = StableDiffusionImg2ImgPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16, variant="fp16", safety_checker=None)
12
+ pipe = pipe.to(device)
13
+
14
+ url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
15
+
16
+ response = requests.get(url)
17
+ # init_image = Image.open(BytesIO(response.content)).convert("RGB")
18
+ init_image = Image.open("/mnt/c/Users/leepenkman/Pictures/aiknight-neon-punk-fantasy-art-good-looking-trending-fantastic-1.webp").convert("RGB")
19
+ # init_image = init_image.resize((768, 512))
20
+ init_image = init_image.resize((1920, 1080))
21
+
22
+ prompt = "knight neon punk fantasy art good looking trending fantastic"
23
+
24
+ images = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images
25
+ images[0].save("fantasy_landscape.png")
img/stable-diffusion-server/img2imgsd.py ADDED
@@ -0,0 +1,74 @@
1
+ from pathlib import Path
2
+
3
+ import numpy as np
4
+ import requests
5
+ import torch
6
+ from PIL import Image
7
+ from io import BytesIO
8
+
9
+ # from diffusers import StableDiffusionImg2ImgPipeline
10
+
11
+ # device = "cuda"
12
+ # model_id_or_path = "runwayml/stable-diffusion-v1-5"
13
+ # # model_id_or_path = "models/stable-diffusion-xl-base-0.9"
14
+ # pipe = StableDiffusionImg2ImgPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16, variant="fp16", safety_checker=None)
15
+ # pipe = pipe.to(device)
16
+
17
+ from diffusers import StableDiffusionXLImg2ImgPipeline
18
+ from diffusers.utils import load_image
19
+
20
+ from stable_diffusion_server.utils import log_time
21
+
22
+ pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
23
+ "stabilityai/stable-diffusion-xl-refiner-1.0",
24
+ # "models/stable-diffusion-xl-base-0.9",
25
+ torch_dtype = torch.float16,
26
+ use_safetensors=True,
27
+ variant="fp16",
28
+ )
29
+ pipe = pipe.to("cuda")  # "LayerNormKernelImpl" not implemented for 'Half' error occurs on CPU - it can't do fp16
30
+ # idea: composite and re-prompt img2img to support different sizes
31
+
32
+ # url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
33
+ #
34
+ # response = requests.get(url)
35
+ # init_image = Image.open(BytesIO(response.content)).convert("RGB")
36
+ # init_image = init_image.resize((768, 512))
37
+ # successfully inpaints a deleted area strength=0.75
38
+ # init_image = Image.open("/mnt/c/Users/leepenkman/Pictures/aiart/ainostalgic-colorful-relaxing-chill-realistic-cartoon-Charcoal-illustration-fantasy-fauvist-abstract-impressionist-watercolor-painting-Background-location-scenery-amazing-wonderful-Dog-Shelter-Worker-Dog.webp").convert("RGB")
39
+ # redo something? strength 1
40
+ # init_image = Image.open("/home/lee/code/sdif/mask.png").convert("RGB")
41
+ init_image = Image.open("/mnt/c/Users/leepenkman/Pictures/dogstretch.png").convert("RGB")
42
+ # init_image = Image.open("/mnt/c/Users/leepenkman/Pictures/dogcenter.png").convert("RGB")
43
+
44
+ # init_image = init_image.resize((1080, 1920))
45
+ init_image = init_image.resize((1920, 1080))
46
+ # init_image = init_image.resize((1024, 1024))
47
+
48
+ prompt = "A fantasy landscape, trending on artstation, beautiful amazing unreal surreal gorgeous impressionism"
49
+ prompt = "mouth open nostalgic colorful relaxing chill realistic cartoon Charcoal illustration fantasy fauvist abstract impressionist watercolor painting Background location scenery amazing wonderful Dog Shelter Worker Dog"
50
+
51
+ # images = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images
52
+ # images[0].save("fantasy_landscape.png")
53
+ #
54
+ # # url = "https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/aa_xl/000000009.png"
55
+ #
56
+ # init_image = load_image(url).convert("RGB")
57
+ # prompt = "a photo of an astronaut riding a horse on mars"
58
+ study_dir = "images/study2"
59
+ Path(study_dir).mkdir(parents=True, exist_ok=True)
60
+
61
+ with log_time("img2img"):
62
+ with torch.inference_mode():
63
+ # for strength in range(.1, 1, .1):
64
+ for strength in np.linspace(.1, 1, 10):
65
+ image = pipe(prompt=prompt, image=init_image, strength=strength, guidance_scale=7.6).images[0]
66
+ image.save(
67
+ study_dir + "/fantasy_dogimgimgdogstretchopening" + str(strength) + "guidance_scale" + str(7.6) + ".png")
68
+ # # for guidance_scale in range(1, 10, .5):
69
+ # for guidance_scale in np.linspace(1, 100, 10):
70
+ # image = pipe(prompt=prompt, image=init_image, strength=strength, guidance_scale=guidance_scale).images[0]
71
+ # image.save("images/study/fantasy_dogimgimgdogstretch" + str(strength) + "guidance_scale" + str(guidance_scale) + ".png")
72
+ # image = pipe(prompt, image=init_image, strength=0.2, guidance_scale=7.5).images[0]
73
+ # image.save("images/fantasy_dogimgimgdogstretch.png")
74
+ # image.save("images/fantasy_dogimgimgdogcenter.png")
img/stable-diffusion-server/img2imgsdr.py ADDED
@@ -0,0 +1,53 @@
1
+ import PIL.Image
2
+
3
+ from diffusers import DiffusionPipeline
4
+ import torch
5
+
6
+ import numpy as np
7
+
8
+ from stable_diffusion_server.utils import log_time
9
+
10
+ pipe = DiffusionPipeline.from_pretrained(
11
+ "models/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
12
+ )
13
+ pipe.to("cuda")
14
+
15
+ refiner = DiffusionPipeline.from_pretrained(
16
+ "stabilityai/stable-diffusion-xl-refiner-1.0",
17
+ text_encoder_2=pipe.text_encoder_2,
18
+ vae=pipe.vae,
19
+ torch_dtype=torch.float16,
20
+ use_safetensors=True,
21
+ variant="fp16",
22
+ )
23
+ refiner.to("cuda")
24
+
25
+ prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
26
+ use_refiner = True
27
+ with log_time('diffuse'):
28
+ with torch.inference_mode():
29
+ image = pipe(prompt=prompt, output_type="latent" if use_refiner else "pil").images[0]
30
+ # experiment try deleting a whole bunch of pixels and see if the refiner can recreate them
31
+ # delete top 30% of pixels
32
+ # image = image[0:0.7]
33
+ #pixels to delete
34
+ # pixels_to_delete = int(0.3 * 1024)
35
+ # delete top 30% of pixels
36
+ # image.save("latent.png")
37
+ # image_data = PIL.Image.fromarray(image)
38
+ # image_data.save("latent.png")
39
+
40
+ # image = np.array(image)
41
+ # the latent is a torch tensor of shape (C, H, W); zero out the top 30% of rows
42
+ rows_to_delete = int(0.3 * image.shape[1])
43
+ idx_to_delete = torch.zeros(image.shape[1], dtype=torch.bool, device="cuda")
+ idx_to_delete[:rows_to_delete] = True
44
+ image[:, idx_to_delete, :] = 0
45
+
46
+ # image_data = PIL.Image.fromarray(image)
47
+ # image_data.save("latentcleared.png")
48
+
49
+
50
+ image = refiner(prompt=prompt, image=image[None, :]).images[0]
51
+
52
+
53
+
img/stable-diffusion-server/inpaint.py ADDED
@@ -0,0 +1,62 @@
1
+ import torch
2
+
3
+ from diffusers import StableDiffusionXLInpaintPipeline
4
+ from diffusers.utils import load_image
5
+
6
+ from stable_diffusion_server.utils import log_time
7
+
8
+ import numpy as np
9
+ import PIL.Image
10
+
11
+ pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
12
+ "models/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
13
+ )
14
+ pipe.to("cuda")
15
+
16
+ refiner = StableDiffusionXLInpaintPipeline.from_pretrained(
17
+ "stabilityai/stable-diffusion-xl-refiner-1.0",
18
+ text_encoder_2=pipe.text_encoder_2,
19
+ vae=pipe.vae,
20
+ torch_dtype=torch.float16,
21
+ use_safetensors=True,
22
+ variant="fp16",
23
+ )
24
+ refiner.to("cuda")
25
+
26
+ img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
27
+ mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
28
+ # inpaint_and_upload_image?prompt=majestic tiger sitting on a bench&image_url=https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png&mask_url=https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png&save_path=tests/inpaint.webp
29
+ # inpainting can be used to upscale to 1080p
30
+
31
+
32
+ init_image = load_image(img_url).convert("RGB")
33
+ # mask_image = load_image(mask_url).convert("RGB")
34
+ # mask image all ones same shape as init_image
35
+
36
+ # here's a failed experiment: inpainting cannot be used as style transfer - it doesn't recreate an image when given a full mask in this way
37
+ image_size = init_image.size
38
+ ones_of_size = np.ones((image_size[1], image_size[0]), np.uint8) * 255  # PIL size is (w, h); numpy expects (h, w)
39
+ mask_image = PIL.Image.fromarray(ones_of_size.astype(np.uint8))
40
+ # mask_image = torch.ones_like(init_image) * 255
41
+ prompt = "A majestic tiger sitting on a bench, castle backdrop elegant anime"
42
+ num_inference_steps = 75
43
+ high_noise_frac = 0.7
44
+ with log_time("inpaint"):
45
+ with torch.inference_mode():
46
+ image = pipe(
47
+ prompt=prompt,
48
+ image=init_image,
49
+ mask_image=mask_image,
50
+ num_inference_steps=num_inference_steps,
51
+ denoising_start=high_noise_frac,
52
+ output_type="latent",
53
+ ).images
54
+ image = refiner(
55
+ prompt=prompt,
56
+ image=image,
57
+ mask_image=mask_image,
58
+ num_inference_steps=num_inference_steps,
59
+ denoising_start=high_noise_frac,
60
+ ).images[0]
61
+
62
+ image.save("inpaintfull.png")
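For reference, the documented SDXL base+refiner split (the "ensemble of expert denoisers" pattern in diffusers) hands the base pipeline `denoising_end` and the refiner `denoising_start` at the same fraction; a minimal sketch reusing the objects above:

```python
# the base handles the high-noise fraction, the refiner finishes from that point
latents = pipe(
    prompt=prompt,
    image=init_image,
    mask_image=mask_image,
    num_inference_steps=num_inference_steps,
    denoising_end=high_noise_frac,  # stop the base partway
    output_type="latent",
).images
image = refiner(
    prompt=prompt,
    image=latents,
    mask_image=mask_image,
    num_inference_steps=num_inference_steps,
    denoising_start=high_noise_frac,  # resume where the base stopped
).images[0]
```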
img/stable-diffusion-server/log.0925 ADDED
@@ -0,0 +1,53 @@
1
+ v-haipe+ 551 16041 99 08:16 pts/2 00:00:17 python LiLa/gsm8k_cluster.py
2
+ v-haipe+ 9211 10235 3 Sep24 pts/10 00:32:12 python LiLa/chatgpt_evol_lila_gsm8k_domain.py --start 0 --end 2000
3
+ v-haipe+ 9288 10459 3 Sep24 pts/11 00:28:30 python LiLa/chatgpt_evol_lila_gsm8k_domain.py --start 2000 --end 4000
4
+ v-haipe+ 9310 10667 3 Sep24 pts/12 00:27:45 python LiLa/chatgpt_evol_lila_gsm8k_domain.py --start 4000 --end 6000
5
+ v-haipe+ 9341 10865 3 Sep24 pts/13 00:26:50 python LiLa/chatgpt_evol_lila_gsm8k_domain.py --start 6000 --end 8000
6
+ v-haipe+ 9379 25248 3 Sep24 pts/16 00:27:01 python LiLa/chatgpt_evol_lila_gsm8k_domain.py --start 8000 --end 10000
7
+ v-haipe+ 9410 25467 3 Sep24 pts/17 00:27:17 python LiLa/chatgpt_evol_lila_gsm8k_domain.py --start 10000 --end 12000
8
+ v-haipe+ 9438 26561 3 Sep24 pts/19 00:27:17 python LiLa/chatgpt_evol_lila_gsm8k_domain.py --start 12000 --end 14000
9
+ v-haipe+ 9469 26761 3 Sep24 pts/20 00:26:55 python LiLa/chatgpt_evol_lila_gsm8k_domain.py --start 14000 --end 16000
10
+ v-haipe+ 9500 26968 3 Sep24 pts/21 00:27:09 python LiLa/chatgpt_evol_lila_gsm8k_domain.py --start 16000 --end 18000
11
+ v-haipe+ 9531 27172 3 Sep24 pts/22 00:29:29 python LiLa/chatgpt_evol_lila_gsm8k_domain.py --start 18000 --end 20000
12
+ v-haipe+ 9775 9560 3 Sep24 pts/29 00:30:29 python LiLa/chatgpt_evol_lila_gsm8k_domain.py --start 20000 --end 22000
13
+ v-haipe+ 11262 24577 0 Sep23 pts/8 00:00:06 python app.py
14
+ v-haipe+ 11300 11262 0 Sep23 pts/8 00:20:54 /home/v-haipengluo/.conda/envs/wizardweb/bin/python /workspaceblobstore/qins/test/20220316/kai/research/code_repo/wizard_verse/code_repo/server_code/wizard_verse/lm/server_lm/app.py
15
+ v-haipe+ 11604 20782 98 Sep23 pts/4 2-00:06:57 python -m vllm.entrypoints.api_server --model /workspaceblobstore/caxu/trained_models/13Bv2_497kcontinueroleplay_dsys_2048_e4_2e_5/checkpoint-75 --host phlrr3006.guest.corp.microsoft.com --port 7991
16
+ v-haipe+ 13722 22601 0 Sep24 pts/6 00:09:37 /home/v-haipengluo/.conda/envs/sdxl/bin/python /home/v-haipengluo/.conda/envs/sdxl/bin/uvicorn --host=phlrr3006.guest.corp.microsoft.com --port 7999 --workers 1 --backlog 1 --limit-concurrency 4 main_v3:app
17
+ v-haipe+ 13830 13722 0 Sep24 pts/6 00:00:05 /home/v-haipengluo/.conda/envs/sdxl/bin/python /home/v-haipengluo/.conda/envs/sdxl/bin/uvicorn --host=phlrr3006.guest.corp.microsoft.com --port 7999 --workers 1 --backlog 1 --limit-concurrency 4 main_v3:app
18
+ v-haipe+ 13834 13722 0 Sep24 pts/6 00:00:05 /home/v-haipengluo/.conda/envs/sdxl/bin/python /home/v-haipengluo/.conda/envs/sdxl/bin/uvicorn --host=phlrr3006.guest.corp.microsoft.com --port 7999 --workers 1 --backlog 1 --limit-concurrency 4 main_v3:app
19
+ v-haipe+ 13837 13722 0 Sep24 pts/6 00:00:05 /home/v-haipengluo/.conda/envs/sdxl/bin/python /home/v-haipengluo/.conda/envs/sdxl/bin/uvicorn --host=phlrr3006.guest.corp.microsoft.com --port 7999 --workers 1 --backlog 1 --limit-concurrency 4 main_v3:app
20
+ v-haipe+ 13839 13722 0 Sep24 pts/6 00:00:05 /home/v-haipengluo/.conda/envs/sdxl/bin/python /home/v-haipengluo/.conda/envs/sdxl/bin/uvicorn --host=phlrr3006.guest.corp.microsoft.com --port 7999 --workers 1 --backlog 1 --limit-concurrency 4 main_v3:app
21
+ v-haipe+ 13841 13722 0 Sep24 pts/6 00:00:05 /home/v-haipengluo/.conda/envs/sdxl/bin/python /home/v-haipengluo/.conda/envs/sdxl/bin/uvicorn --host=phlrr3006.guest.corp.microsoft.com --port 7999 --workers 1 --backlog 1 --limit-concurrency 4 main_v3:app
22
+ v-haipe+ 13843 13722 0 Sep24 pts/6 00:00:05 /home/v-haipengluo/.conda/envs/sdxl/bin/python /home/v-haipengluo/.conda/envs/sdxl/bin/uvicorn --host=phlrr3006.guest.corp.microsoft.com --port 7999 --workers 1 --backlog 1 --limit-concurrency 4 main_v3:app
23
+ v-haipe+ 13845 13722 0 Sep24 pts/6 00:00:05 /home/v-haipengluo/.conda/envs/sdxl/bin/python /home/v-haipengluo/.conda/envs/sdxl/bin/uvicorn --host=phlrr3006.guest.corp.microsoft.com --port 7999 --workers 1 --backlog 1 --limit-concurrency 4 main_v3:app
24
+ v-haipe+ 13847 13722 0 Sep24 pts/6 00:00:05 /home/v-haipengluo/.conda/envs/sdxl/bin/python /home/v-haipengluo/.conda/envs/sdxl/bin/uvicorn --host=phlrr3006.guest.corp.microsoft.com --port 7999 --workers 1 --backlog 1 --limit-concurrency 4 main_v3:app
25
+ v-haipe+ 13849 13722 0 Sep24 pts/6 00:00:05 /home/v-haipengluo/.conda/envs/sdxl/bin/python /home/v-haipengluo/.conda/envs/sdxl/bin/uvicorn --host=phlrr3006.guest.corp.microsoft.com --port 7999 --workers 1 --backlog 1 --limit-concurrency 4 main_v3:app
26
+ v-haipe+ 13851 13722 0 Sep24 pts/6 00:00:05 /home/v-haipengluo/.conda/envs/sdxl/bin/python /home/v-haipengluo/.conda/envs/sdxl/bin/uvicorn --host=phlrr3006.guest.corp.microsoft.com --port 7999 --workers 1 --backlog 1 --limit-concurrency 4 main_v3:app
27
+ v-haipe+ 13853 13722 0 Sep24 pts/6 00:00:05 /home/v-haipengluo/.conda/envs/sdxl/bin/python /home/v-haipengluo/.conda/envs/sdxl/bin/uvicorn --host=phlrr3006.guest.corp.microsoft.com --port 7999 --workers 1 --backlog 1 --limit-concurrency 4 main_v3:app
28
+ v-haipe+ 13855 13722 0 Sep24 pts/6 00:00:05 /home/v-haipengluo/.conda/envs/sdxl/bin/python /home/v-haipengluo/.conda/envs/sdxl/bin/uvicorn --host=phlrr3006.guest.corp.microsoft.com --port 7999 --workers 1 --backlog 1 --limit-concurrency 4 main_v3:app
29
+ v-haipe+ 13857 13722 0 Sep24 pts/6 00:00:05 /home/v-haipengluo/.conda/envs/sdxl/bin/python /home/v-haipengluo/.conda/envs/sdxl/bin/uvicorn --host=phlrr3006.guest.corp.microsoft.com --port 7999 --workers 1 --backlog 1 --limit-concurrency 4 main_v3:app
30
+ v-haipe+ 13859 13722 0 Sep24 pts/6 00:00:05 /home/v-haipengluo/.conda/envs/sdxl/bin/python /home/v-haipengluo/.conda/envs/sdxl/bin/uvicorn --host=phlrr3006.guest.corp.microsoft.com --port 7999 --workers 1 --backlog 1 --limit-concurrency 4 main_v3:app
31
+ v-haipe+ 13861 13722 0 Sep24 pts/6 00:00:05 /home/v-haipengluo/.conda/envs/sdxl/bin/python /home/v-haipengluo/.conda/envs/sdxl/bin/uvicorn --host=phlrr3006.guest.corp.microsoft.com --port 7999 --workers 1 --backlog 1 --limit-concurrency 4 main_v3:app
32
+ v-haipe+ 13863 13722 0 Sep24 pts/6 00:00:05 /home/v-haipengluo/.conda/envs/sdxl/bin/python /home/v-haipengluo/.conda/envs/sdxl/bin/uvicorn --host=phlrr3006.guest.corp.microsoft.com --port 7999 --workers 1 --backlog 1 --limit-concurrency 4 main_v3:app
33
+ v-haipe+ 13865 13722 0 Sep24 pts/6 00:00:05 /home/v-haipengluo/.conda/envs/sdxl/bin/python /home/v-haipengluo/.conda/envs/sdxl/bin/uvicorn --host=phlrr3006.guest.corp.microsoft.com --port 7999 --workers 1 --backlog 1 --limit-concurrency 4 main_v3:app
34
+ v-haipe+ 13867 13722 0 Sep24 pts/6 00:00:05 /home/v-haipengluo/.conda/envs/sdxl/bin/python /home/v-haipengluo/.conda/envs/sdxl/bin/uvicorn --host=phlrr3006.guest.corp.microsoft.com --port 7999 --workers 1 --backlog 1 --limit-concurrency 4 main_v3:app
35
+ v-haipe+ 13869 13722 0 Sep24 pts/6 00:00:05 /home/v-haipengluo/.conda/envs/sdxl/bin/python /home/v-haipengluo/.conda/envs/sdxl/bin/uvicorn --host=phlrr3006.guest.corp.microsoft.com --port 7999 --workers 1 --backlog 1 --limit-concurrency 4 main_v3:app
36
+ v-haipe+ 13871 13722 0 Sep24 pts/6 00:00:05 /home/v-haipengluo/.conda/envs/sdxl/bin/python /home/v-haipengluo/.conda/envs/sdxl/bin/uvicorn --host=phlrr3006.guest.corp.microsoft.com --port 7999 --workers 1 --backlog 1 --limit-concurrency 4 main_v3:app
37
+ v-haipe+ 13873 13722 0 Sep24 pts/6 00:00:05 /home/v-haipengluo/.conda/envs/sdxl/bin/python /home/v-haipengluo/.conda/envs/sdxl/bin/uvicorn --host=phlrr3006.guest.corp.microsoft.com --port 7999 --workers 1 --backlog 1 --limit-concurrency 4 main_v3:app
38
+ v-haipe+ 13875 13722 0 Sep24 pts/6 00:00:05 /home/v-haipengluo/.conda/envs/sdxl/bin/python /home/v-haipengluo/.conda/envs/sdxl/bin/uvicorn --host=phlrr3006.guest.corp.microsoft.com --port 7999 --workers 1 --backlog 1 --limit-concurrency 4 main_v3:app
39
+ v-haipe+ 13877 13722 0 Sep24 pts/6 00:00:05 /home/v-haipengluo/.conda/envs/sdxl/bin/python /home/v-haipengluo/.conda/envs/sdxl/bin/uvicorn --host=phlrr3006.guest.corp.microsoft.com --port 7999 --workers 1 --backlog 1 --limit-concurrency 4 main_v3:app
40
+ v-haipe+ 13879 13722 0 Sep24 pts/6 00:00:05 /home/v-haipengluo/.conda/envs/sdxl/bin/python /home/v-haipengluo/.conda/envs/sdxl/bin/uvicorn --host=phlrr3006.guest.corp.microsoft.com --port 7999 --workers 1 --backlog 1 --limit-concurrency 4 main_v3:app
41
+ v-haipe+ 13881 13722 0 Sep24 pts/6 00:00:05 /home/v-haipengluo/.conda/envs/sdxl/bin/python /home/v-haipengluo/.conda/envs/sdxl/bin/uvicorn --host=phlrr3006.guest.corp.microsoft.com --port 7999 --workers 1 --backlog 1 --limit-concurrency 4 main_v3:app
42
+ v-haipe+ 13883 13722 0 Sep24 pts/6 00:00:05 /home/v-haipengluo/.conda/envs/sdxl/bin/python /home/v-haipengluo/.conda/envs/sdxl/bin/uvicorn --host=phlrr3006.guest.corp.microsoft.com --port 7999 --workers 1 --backlog 1 --limit-concurrency 4 main_v3:app
43
+ v-haipe+ 13885 13722 0 Sep24 pts/6 00:00:05 /home/v-haipengluo/.conda/envs/sdxl/bin/python /home/v-haipengluo/.conda/envs/sdxl/bin/uvicorn --host=phlrr3006.guest.corp.microsoft.com --port 7999 --workers 1 --backlog 1 --limit-concurrency 4 main_v3:app
44
+ v-haipe+ 13887 13722 0 Sep24 pts/6 00:00:05 /home/v-haipengluo/.conda/envs/sdxl/bin/python /home/v-haipengluo/.conda/envs/sdxl/bin/uvicorn --host=phlrr3006.guest.corp.microsoft.com --port 7999 --workers 1 --backlog 1 --limit-concurrency 4 main_v3:app
45
+ v-haipe+ 18319 15852 0 05:34 pts/1 00:00:03 /home/v-haipengluo/.conda/envs/llamax/bin/python /home/v-haipengluo/.conda/envs/llamax/bin/deepspeed --master_port 29500 --hostfile=hostfile --include=localhost:1,3,4,5,6,7 src/train.py --model_name_or_path /workspaceblobstore/qins/test/20220316/haipeng/output_weights/llamax_13b_stackexchange_MATH_12w_sample_5w_score0.5_trainset_2e-5/checkpoint-992 --data_path /workspaceblobstore/qins/test/20220316/haipeng/data/Math_datasets/MATH_the_answer_is_format/hendrycks_math_7500_ori_gpt4_ori_15k.json --output_dir /workspaceblobstore/qins/test/20220316/haipeng/output_weights/llamax_13b_continue_train_stackMATH5w_checkpoint992_hendrycks_math_7500_ori_gpt4_ori_15k --num_train_epochs 3 --model_max_length 1150 --per_device_train_batch_size 17 --per_device_eval_batch_size 1 --gradient_accumulation_steps 1 --evaluation_strategy no --save_strategy steps --save_steps 36 --save_total_limit 200 --learning_rate 2e-5 --warmup_steps 10 --logging_steps 2 --lr_scheduler_type cosine --report_to tensorboard --gradient_checkpointing True --deepspeed src/configs/deepspeed_config.json --fp16 True
46
+ v-haipe+ 18333 18319 0 05:34 pts/1 00:00:03 /home/v-haipengluo/.conda/envs/llamax/bin/python -u -m deepspeed.launcher.launch --world_info=eyJsb2NhbGhvc3QiOiBbMSwgMywgNCwgNSwgNiwgN119 --master_addr=127.0.0.1 --master_port=29500 --enable_each_rank_log=None src/train.py --model_name_or_path /workspaceblobstore/qins/test/20220316/haipeng/output_weights/llamax_13b_stackexchange_MATH_12w_sample_5w_score0.5_trainset_2e-5/checkpoint-992 --data_path /workspaceblobstore/qins/test/20220316/haipeng/data/Math_datasets/MATH_the_answer_is_format/hendrycks_math_7500_ori_gpt4_ori_15k.json --output_dir /workspaceblobstore/qins/test/20220316/haipeng/output_weights/llamax_13b_continue_train_stackMATH5w_checkpoint992_hendrycks_math_7500_ori_gpt4_ori_15k --num_train_epochs 3 --model_max_length 1150 --per_device_train_batch_size 17 --per_device_eval_batch_size 1 --gradient_accumulation_steps 1 --evaluation_strategy no --save_strategy steps --save_steps 36 --save_total_limit 200 --learning_rate 2e-5 --warmup_steps 10 --logging_steps 2 --lr_scheduler_type cosine --report_to tensorboard --gradient_checkpointing True --deepspeed src/configs/deepspeed_config.json --fp16 True
47
+ v-haipe+ 18346 18333 99 05:34 pts/1 03:20:42 /home/v-haipengluo/.conda/envs/llamax/bin/python -u src/train.py --local_rank=0 --model_name_or_path /workspaceblobstore/qins/test/20220316/haipeng/output_weights/llamax_13b_stackexchange_MATH_12w_sample_5w_score0.5_trainset_2e-5/checkpoint-992 --data_path /workspaceblobstore/qins/test/20220316/haipeng/data/Math_datasets/MATH_the_answer_is_format/hendrycks_math_7500_ori_gpt4_ori_15k.json --output_dir /workspaceblobstore/qins/test/20220316/haipeng/output_weights/llamax_13b_continue_train_stackMATH5w_checkpoint992_hendrycks_math_7500_ori_gpt4_ori_15k --num_train_epochs 3 --model_max_length 1150 --per_device_train_batch_size 17 --per_device_eval_batch_size 1 --gradient_accumulation_steps 1 --evaluation_strategy no --save_strategy steps --save_steps 36 --save_total_limit 200 --learning_rate 2e-5 --warmup_steps 10 --logging_steps 2 --lr_scheduler_type cosine --report_to tensorboard --gradient_checkpointing True --deepspeed src/configs/deepspeed_config.json --fp16 True
48
+ v-haipe+ 18347 18333 99 05:34 pts/1 03:40:59 /home/v-haipengluo/.conda/envs/llamax/bin/python -u src/train.py --local_rank=1 --model_name_or_path /workspaceblobstore/qins/test/20220316/haipeng/output_weights/llamax_13b_stackexchange_MATH_12w_sample_5w_score0.5_trainset_2e-5/checkpoint-992 --data_path /workspaceblobstore/qins/test/20220316/haipeng/data/Math_datasets/MATH_the_answer_is_format/hendrycks_math_7500_ori_gpt4_ori_15k.json --output_dir /workspaceblobstore/qins/test/20220316/haipeng/output_weights/llamax_13b_continue_train_stackMATH5w_checkpoint992_hendrycks_math_7500_ori_gpt4_ori_15k --num_train_epochs 3 --model_max_length 1150 --per_device_train_batch_size 17 --per_device_eval_batch_size 1 --gradient_accumulation_steps 1 --evaluation_strategy no --save_strategy steps --save_steps 36 --save_total_limit 200 --learning_rate 2e-5 --warmup_steps 10 --logging_steps 2 --lr_scheduler_type cosine --report_to tensorboard --gradient_checkpointing True --deepspeed src/configs/deepspeed_config.json --fp16 True
49
+ v-haipe+ 18348 18333 99 05:34 pts/1 03:44:08 /home/v-haipengluo/.conda/envs/llamax/bin/python -u src/train.py --local_rank=2 --model_name_or_path /workspaceblobstore/qins/test/20220316/haipeng/output_weights/llamax_13b_stackexchange_MATH_12w_sample_5w_score0.5_trainset_2e-5/checkpoint-992 --data_path /workspaceblobstore/qins/test/20220316/haipeng/data/Math_datasets/MATH_the_answer_is_format/hendrycks_math_7500_ori_gpt4_ori_15k.json --output_dir /workspaceblobstore/qins/test/20220316/haipeng/output_weights/llamax_13b_continue_train_stackMATH5w_checkpoint992_hendrycks_math_7500_ori_gpt4_ori_15k --num_train_epochs 3 --model_max_length 1150 --per_device_train_batch_size 17 --per_device_eval_batch_size 1 --gradient_accumulation_steps 1 --evaluation_strategy no --save_strategy steps --save_steps 36 --save_total_limit 200 --learning_rate 2e-5 --warmup_steps 10 --logging_steps 2 --lr_scheduler_type cosine --report_to tensorboard --gradient_checkpointing True --deepspeed src/configs/deepspeed_config.json --fp16 True
50
+ v-haipe+ 18349 18333 99 05:34 pts/1 03:32:51 /home/v-haipengluo/.conda/envs/llamax/bin/python -u src/train.py --local_rank=3 --model_name_or_path /workspaceblobstore/qins/test/20220316/haipeng/output_weights/llamax_13b_stackexchange_MATH_12w_sample_5w_score0.5_trainset_2e-5/checkpoint-992 --data_path /workspaceblobstore/qins/test/20220316/haipeng/data/Math_datasets/MATH_the_answer_is_format/hendrycks_math_7500_ori_gpt4_ori_15k.json --output_dir /workspaceblobstore/qins/test/20220316/haipeng/output_weights/llamax_13b_continue_train_stackMATH5w_checkpoint992_hendrycks_math_7500_ori_gpt4_ori_15k --num_train_epochs 3 --model_max_length 1150 --per_device_train_batch_size 17 --per_device_eval_batch_size 1 --gradient_accumulation_steps 1 --evaluation_strategy no --save_strategy steps --save_steps 36 --save_total_limit 200 --learning_rate 2e-5 --warmup_steps 10 --logging_steps 2 --lr_scheduler_type cosine --report_to tensorboard --gradient_checkpointing True --deepspeed src/configs/deepspeed_config.json --fp16 True
51
+ v-haipe+ 18350 18333 99 05:34 pts/1 03:41:16 /home/v-haipengluo/.conda/envs/llamax/bin/python -u src/train.py --local_rank=4 --model_name_or_path /workspaceblobstore/qins/test/20220316/haipeng/output_weights/llamax_13b_stackexchange_MATH_12w_sample_5w_score0.5_trainset_2e-5/checkpoint-992 --data_path /workspaceblobstore/qins/test/20220316/haipeng/data/Math_datasets/MATH_the_answer_is_format/hendrycks_math_7500_ori_gpt4_ori_15k.json --output_dir /workspaceblobstore/qins/test/20220316/haipeng/output_weights/llamax_13b_continue_train_stackMATH5w_checkpoint992_hendrycks_math_7500_ori_gpt4_ori_15k --num_train_epochs 3 --model_max_length 1150 --per_device_train_batch_size 17 --per_device_eval_batch_size 1 --gradient_accumulation_steps 1 --evaluation_strategy no --save_strategy steps --save_steps 36 --save_total_limit 200 --learning_rate 2e-5 --warmup_steps 10 --logging_steps 2 --lr_scheduler_type cosine --report_to tensorboard --gradient_checkpointing True --deepspeed src/configs/deepspeed_config.json --fp16 True
52
+ v-haipe+ 18351 18333 99 05:34 pts/1 03:42:27 /home/v-haipengluo/.conda/envs/llamax/bin/python -u src/train.py --local_rank=5 --model_name_or_path /workspaceblobstore/qins/test/20220316/haipeng/output_weights/llamax_13b_stackexchange_MATH_12w_sample_5w_score0.5_trainset_2e-5/checkpoint-992 --data_path /workspaceblobstore/qins/test/20220316/haipeng/data/Math_datasets/MATH_the_answer_is_format/hendrycks_math_7500_ori_gpt4_ori_15k.json --output_dir /workspaceblobstore/qins/test/20220316/haipeng/output_weights/llamax_13b_continue_train_stackMATH5w_checkpoint992_hendrycks_math_7500_ori_gpt4_ori_15k --num_train_epochs 3 --model_max_length 1150 --per_device_train_batch_size 17 --per_device_eval_batch_size 1 --gradient_accumulation_steps 1 --evaluation_strategy no --save_strategy steps --save_steps 36 --save_total_limit 200 --learning_rate 2e-5 --warmup_steps 10 --logging_steps 2 --lr_scheduler_type cosine --report_to tensorboard --gradient_checkpointing True --deepspeed src/configs/deepspeed_config.json --fp16 True
53
+ v-haipe+ 24334 23818 0 Sep23 pts/7 00:00:25 python -m http.server
img/stable-diffusion-server/main.py ADDED
@@ -0,0 +1,528 @@
1
+ import gc
2
+ import math
3
+ import multiprocessing
4
+ import os
5
+ import traceback
6
+ from datetime import datetime
7
+ from io import BytesIO
8
+ from itertools import permutations
9
+ from multiprocessing.pool import Pool
10
+ from pathlib import Path
11
+ from urllib.parse import quote_plus
12
+
13
+ import numpy as np
14
+ import nltk
15
+ import torch
16
+
17
+ from PIL.Image import Image
18
+ from diffusers import DiffusionPipeline, StableDiffusionXLInpaintPipeline
19
+ from diffusers.utils import load_image
20
+ from fastapi import FastAPI
21
+ from fastapi.middleware.gzip import GZipMiddleware
22
+ from loguru import logger
23
+ from starlette.middleware.cors import CORSMiddleware
24
+ from starlette.responses import FileResponse
25
+ from starlette.responses import JSONResponse
26
+
27
+ from env import BUCKET_PATH, BUCKET_NAME
28
+ # from stable_diffusion_server.bucket_api import check_if_blob_exists, upload_to_bucket
29
+ torch._dynamo.config.suppress_errors = True
30
+
31
+ import string
32
+ import random
33
+
34
+ def generate_save_path():
35
+ # initializing size of string
36
+ N = 7
37
+
38
+ # using random.choices()
39
+ # generating random strings
40
+ res = ''.join(random.choices(string.ascii_uppercase +
41
+ string.digits, k=N))
42
+ return res
43
+
44
+ pipe = DiffusionPipeline.from_pretrained(
45
+ "models/stable-diffusion-xl-base-1.0",
46
+ torch_dtype=torch.bfloat16,
47
+ use_safetensors=True,
48
+ variant="fp16",
49
+ # safety_checker=None,
50
+ )  # note: torch_dtype above is already bfloat16
51
+ pipe.watermark = None
52
+
53
+ pipe.to("cuda")
54
+
55
+ refiner = DiffusionPipeline.from_pretrained(
56
+ "stabilityai/stable-diffusion-xl-refiner-1.0",
57
+ text_encoder_2=pipe.text_encoder_2,
58
+ vae=pipe.vae,
59
+ torch_dtype=torch.bfloat16, # safer to use bfloat?
60
+ use_safetensors=True,
61
+ variant="fp16", #remember not to download the big model
62
+ )
63
+ refiner.watermark = None
64
+ refiner.to("cuda")
65
+
66
+ # {'scheduler', 'text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'unet', 'vae'} can be passed in from existing model
67
+ inpaintpipe = StableDiffusionXLInpaintPipeline.from_pretrained(
68
+ "models/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16, variant="fp16", use_safetensors=True,
69
+ scheduler=pipe.scheduler,
70
+ text_encoder=pipe.text_encoder,
71
+ text_encoder_2=pipe.text_encoder_2,
72
+ tokenizer=pipe.tokenizer,
73
+ tokenizer_2=pipe.tokenizer_2,
74
+ unet=pipe.unet,
75
+ vae=pipe.vae,
76
+ # load_connected_pipeline=
77
+ )
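# Note: because scheduler, text encoders, tokenizers, unet and vae are passed
# in from `pipe`, this inpainting pipeline shares those weights rather than
# loading a second copy, roughly halving GPU memory for the pair.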
78
+ # # switch out to save gpu mem
79
+ # del inpaintpipe.vae
80
+ # del inpaintpipe.text_encoder_2
81
+ # del inpaintpipe.text_encoder
82
+ # del inpaintpipe.scheduler
83
+ # del inpaintpipe.tokenizer
84
+ # del inpaintpipe.tokenizer_2
85
+ # del inpaintpipe.unet
86
+ # inpaintpipe.vae = pipe.vae
87
+ # inpaintpipe.text_encoder_2 = pipe.text_encoder_2
88
+ # inpaintpipe.text_encoder = pipe.text_encoder
89
+ # inpaintpipe.scheduler = pipe.scheduler
90
+ # inpaintpipe.tokenizer = pipe.tokenizer
91
+ # inpaintpipe.tokenizer_2 = pipe.tokenizer_2
92
+ # inpaintpipe.unet = pipe.unet
93
+ # todo this should work
94
+ # inpaintpipe = StableDiffusionXLInpaintPipeline( # construct an inpainter using the existing model
95
+ # vae=pipe.vae,
96
+ # text_encoder_2=pipe.text_encoder_2,
97
+ # text_encoder=pipe.text_encoder,
98
+ # unet=pipe.unet,
99
+ # scheduler=pipe.scheduler,
100
+ # tokenizer=pipe.tokenizer,
101
+ # tokenizer_2=pipe.tokenizer_2,
102
+ # requires_aesthetics_score=False,
103
+ # )
104
+ inpaintpipe.to("cuda")
105
+ inpaintpipe.watermark = None
106
+ # inpaintpipe.register_to_config(requires_aesthetics_score=False)
107
+
108
+ inpaint_refiner = StableDiffusionXLInpaintPipeline.from_pretrained(
109
+ "stabilityai/stable-diffusion-xl-refiner-1.0",
110
+ text_encoder_2=inpaintpipe.text_encoder_2,
111
+ vae=inpaintpipe.vae,
112
+ torch_dtype=torch.bfloat16,
113
+ use_safetensors=True,
114
+ variant="fp16",
115
+
116
+ tokenizer_2=refiner.tokenizer_2,
117
+ tokenizer=refiner.tokenizer,
118
+ scheduler=refiner.scheduler,
119
+ text_encoder=refiner.text_encoder,
120
+ unet=refiner.unet,
121
+ )
122
+ # del inpaint_refiner.vae
123
+ # del inpaint_refiner.text_encoder_2
124
+ # del inpaint_refiner.text_encoder
125
+ # del inpaint_refiner.scheduler
126
+ # del inpaint_refiner.tokenizer
127
+ # del inpaint_refiner.tokenizer_2
128
+ # del inpaint_refiner.unet
129
+ # inpaint_refiner.vae = inpaintpipe.vae
130
+ # inpaint_refiner.text_encoder_2 = inpaintpipe.text_encoder_2
131
+ #
132
+ # inpaint_refiner.text_encoder = refiner.text_encoder
133
+ # inpaint_refiner.scheduler = refiner.scheduler
134
+ # inpaint_refiner.tokenizer = refiner.tokenizer
135
+ # inpaint_refiner.tokenizer_2 = refiner.tokenizer_2
136
+ # inpaint_refiner.unet = refiner.unet
137
+
138
+ # inpaint_refiner = StableDiffusionXLInpaintPipeline(
139
+ # text_encoder_2=inpaintpipe.text_encoder_2,
140
+ # vae=inpaintpipe.vae,
141
+ # # the rest from the existing refiner
142
+ # tokenizer_2=refiner.tokenizer_2,
143
+ # tokenizer=refiner.tokenizer,
144
+ # scheduler=refiner.scheduler,
145
+ # text_encoder=refiner.text_encoder,
146
+ # unet=refiner.unet,
147
+ # requires_aesthetics_score=False,
148
+ # )
149
+ inpaint_refiner.to("cuda")
150
+ inpaint_refiner.watermark = None
151
+ # inpaint_refiner.register_to_config(requires_aesthetics_score=False)
152
+
153
+ n_steps = 40
154
+ high_noise_frac = 0.8
155
+
156
+ # if using torch < 2.0
157
+ # pipe.enable_xformers_memory_efficient_attention()
158
+
159
+
160
+ # pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
161
+ # this can cause errors on some inputs so consider disabling it
162
+ pipe.unet = torch.compile(pipe.unet)
163
+ refiner.unet = torch.compile(refiner.unet)#, mode="reduce-overhead", fullgraph=True)
164
+ # compile the inpainters - todo reuse the other unets? swap out the models for others/del them so they share models and can be swapped efficiently
165
+ inpaintpipe.unet = pipe.unet
166
+ inpaint_refiner.unet = refiner.unet
167
+ # inpaintpipe.unet = torch.compile(inpaintpipe.unet)
168
+ # inpaint_refiner.unet = torch.compile(inpaint_refiner.unet)
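# Note: torch.compile is lazy - compilation happens on the first forward pass,
# so the first request after startup is slow; with
# torch._dynamo.config.suppress_errors = True (set above), compile failures
# fall back to eager execution instead of crashing the request.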
169
+ from pydantic import BaseModel
170
+
171
+ app = FastAPI(
172
+ openapi_url="/static/openapi.json",
173
+ docs_url="/swagger-docs",
174
+ redoc_url="/redoc",
175
+ title="Generate Images Netwrck API",
176
+ description="Character Chat API",
177
+ # root_path="https://api.text-generator.io",
178
+ version="1",
179
+ )
180
+ app.add_middleware(GZipMiddleware, minimum_size=1000)
181
+ app.add_middleware(
182
+ CORSMiddleware,
183
+ allow_origins=["*"],
184
+ allow_credentials=True,
185
+ allow_methods=["*"],
186
+ allow_headers=["*"],
187
+ )
188
+
189
+ stopwords = nltk.corpus.stopwords.words("english")
190
+
191
+ class Img(BaseModel):
192
+ system_prompt: str
193
+ ASSISTANT: str
194
+
195
+ # img_url = "http://phlrr2019.guest.corp.microsoft.com:8000/img1_sdv2.1.png"
196
+ img_url = "http://phlrr3058.guest.corp.microsoft.com:8000/"#/img1_sdv2.1.png"
197
+
198
+ @app.post("/image_url")
199
+ def image_url(img: Img):
200
+ system_prompt = img.system_prompt
201
+ prompt = img.ASSISTANT
202
+ # if Path(save_path).exists():
203
+ # return FileResponse(save_path, media_type="image/png")
204
+ # return JSONResponse({"path": path})
205
+ image = pipe(prompt=prompt).images[0]
206
+ # if not save_path:
207
+ save_path = generate_save_path()
208
+ save_path = f"images/{save_path}.png"
209
+ image.save(save_path)
210
+ # save_path = '/'.join(path_components) + quote_plus(final_name)
211
+ path = f"{img_url}/{save_path}"
212
+ return JSONResponse({"path": path})
213
+
214
+
215
+ @app.get("/make_image")
216
+ # @app.post("/make_image")
217
+ def make_image(prompt: str, save_path: str = ""):
218
+ if Path(save_path).exists():
219
+ return FileResponse(save_path, media_type="image/png")
220
+ image = pipe(prompt=prompt).images[0]
221
+ if not save_path:
222
+ save_path = f"images/{prompt}.png"
223
+ image.save(save_path)
224
+ return FileResponse(save_path, media_type="image/png")
225
+
226
+
227
+ @app.get("/create_and_upload_image")
228
+ def create_and_upload_image(prompt: str, width: int=1024, height:int=1024, save_path: str = ""):
229
+ path_components = save_path.split("/")[0:-1]
230
+ final_name = save_path.split("/")[-1]
231
+ if not path_components:
232
+ path_components = []
233
+ save_path = '/'.join(path_components + [quote_plus(final_name)])
234
+ path = get_image_or_create_upload_to_cloud_storage(prompt, width, height, save_path)
235
+ return JSONResponse({"path": path})
236
+
237
+ @app.get("/inpaint_and_upload_image")
238
+ def inpaint_and_upload_image(prompt: str, image_url:str, mask_url:str, save_path: str = ""):
239
+ path_components = save_path.split("/")[0:-1]
240
+ final_name = save_path.split("/")[-1]
241
+ if not path_components:
242
+ path_components = []
243
+ save_path = '/'.join(path_components + [quote_plus(final_name)])
244
+ path = get_image_or_inpaint_upload_to_cloud_storage(prompt, image_url, mask_url, save_path)
245
+ return JSONResponse({"path": path})
246
+
247
+
248
+ def get_image_or_create_upload_to_cloud_storage(prompt:str,width:int, height:int, save_path:str):
249
+ prompt = shorten_too_long_text(prompt)
250
+ save_path = shorten_too_long_text(save_path)
251
+ # check exists - todo cache this
252
+ if check_if_blob_exists(save_path):
253
+ return f"https://{BUCKET_NAME}/{BUCKET_PATH}/{save_path}"
254
+ bio = create_image_from_prompt(prompt, width, height)
255
+ if bio is None:
256
+ return None # error thrown in pool
257
+ link = upload_to_bucket(save_path, bio, is_bytesio=True)
258
+ return link
259
+ def get_image_or_inpaint_upload_to_cloud_storage(prompt:str, image_url:str, mask_url:str, save_path:str):
260
+ prompt = shorten_too_long_text(prompt)
261
+ save_path = shorten_too_long_text(save_path)
262
+ # check exists - todo cache this
263
+ if check_if_blob_exists(save_path):
264
+ return f"https://{BUCKET_NAME}/{BUCKET_PATH}/{save_path}"
265
+ bio = inpaint_image_from_prompt(prompt, image_url, mask_url)
266
+ if bio is None:
267
+ return None # error thrown in pool
268
+ link = upload_to_bucket(save_path, bio, is_bytesio=True)
269
+ return link
270
+
271
+ # multiprocessing.set_start_method('spawn', True)
272
+ # processes_pool = Pool(1) # cant do too much at once or OOM errors happen
273
+ # def create_image_from_prompt_sync(prompt):
274
+ # """have to call this sync to avoid OOM errors"""
275
+ # return processes_pool.apply_async(create_image_from_prompt, args=(prompt,), ).wait()
276
+
277
+ def create_image_from_prompt(prompt, width, height):
278
+ # round width and height down to multiple of 64
279
+ block_width = width - (width % 64)
280
+ block_height = height - (height % 64)
281
+ prompt = shorten_too_long_text(prompt)
282
+ # image = pipe(prompt=prompt).images[0]
283
+ try:
284
+ image = pipe(prompt=prompt,
285
+ width=block_width,
286
+ height=block_height,
287
+ # denoising_end=high_noise_frac,
288
+ # output_type='latent',
289
+ # height=512,
290
+ # width=512,
291
+ num_inference_steps=50).images[0] # normally uses 50 steps
292
+ except Exception as e:
293
+ # try rm stopwords + half the prompt
294
+ # todo try prompt permutations
295
+ logger.info(f"trying to shorten prompt of length {len(prompt)}")
296
+
297
+ prompt = ' '.join(word for word in prompt.split() if word not in stopwords)  # split first - iterating a str yields characters
298
+ prompts = prompt.split()
299
+
300
+ prompt = ' '.join(prompts[:len(prompts) // 2])
301
+ logger.info(f"shortened prompt to: {len(prompt)}")
302
+ image = None
303
+ if prompt:
304
+ try:
305
+ image = pipe(prompt=prompt,
306
+ width=block_width,
307
+ height=block_height,
308
+ # denoising_end=high_noise_frac,
309
+ # output_type='latent',
310
+ # height=512,
311
+ # width=512,
312
+ num_inference_steps=50).images[0] # normally uses 50 steps
313
+ except Exception as e:
314
+ # logger.info("trying to permute prompt")
315
+ # # try two swaps of the prompt/permutations
316
+ # prompt = prompt.split()
317
+ # prompt = ' '.join(permutations(prompt, 2).__next__())
318
+ logger.info(f"trying to shorten prompt of length {len(prompt)}")
319
+
320
+ prompt = ' '.join(word for word in prompt.split() if word not in stopwords)  # split first - iterating a str yields characters
321
+ prompts = prompt.split()
322
+
323
+ prompt = ' '.join(prompts[:len(prompts) // 2])
324
+ logger.info(f"shortened prompt to: {len(prompt)}")
325
+
326
+ try:
327
+ image = pipe(prompt=prompt,
328
+ width=block_width,
329
+ height=block_height,
330
+ # denoising_end=high_noise_frac,
331
+ # output_type='latent', # dont need latent yet - we refine the image at full res
332
+ # height=512,
333
+ # width=512,
334
+ num_inference_steps=50).images[0] # normally uses 50 steps
335
+ except Exception as e:
336
+ # just error out
337
+ traceback.print_exc()
338
+ raise e
339
+ # logger.info("restarting server to fix cuda issues (device side asserts)")
340
+ # todo fix device side asserts instead of restart to fix
341
+ # todo only restart the correct gunicorn
342
+ # this could be really annoying if you're running other gunicorns on your machine which also get restarted
343
+ # os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
344
+ # os.system("kill -1 `pgrep gunicorn`")
345
+ # todo refine
346
+ # if image != None:
347
+ # image = refiner(
348
+ # prompt=prompt,
349
+ # # width=block_width,
350
+ # # height=block_height,
351
+ # num_inference_steps=n_steps,
352
+ # # denoising_start=high_noise_frac,
353
+ # image=image,
354
+ # ).images[0]
355
+ if width != block_width or height != block_height:
356
+ # resize to original size width/height
357
+ # find aspect ratio to scale up to that covers the original img input width/height
358
+ scale_up_ratio = max(width / block_width, height / block_height)
359
+ image = image.resize((math.ceil(block_width * scale_up_ratio), math.ceil(block_height * scale_up_ratio)))
360
+ # crop image to original size
361
+ image = image.crop((0, 0, width, height))
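# Worked example of the rounding + upscale-and-crop above: width=1000,
# height=700 renders at block_width=960, block_height=640;
# scale_up_ratio = max(1000/960, 700/640) = 1.09375, so the render is
# resized to (1050, 700) and then cropped back to the requested (1000, 700).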
362
+ # try:
363
+ # # gc.collect()
364
+ # torch.cuda.empty_cache()
365
+ # except Exception as e:
366
+ # traceback.print_exc()
367
+ # logger.info("restarting server to fix cuda issues (device side asserts)")
368
+ # # todo fix device side asserts instead of restart to fix
369
+ # # todo only restart the correct gunicorn
370
+ # # this could be really annoying if you're running other gunicorns on your machine which also get restarted
371
+ # os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
372
+ # os.system("kill -1 `pgrep gunicorn`")
373
+ # save as bytesio
374
+ bs = BytesIO()
375
+
376
+ bright_count = np.sum(np.array(image) > 0)
377
+ if bright_count == 0:
378
+ # we have a black image; this is an error and likely means we need a restart
379
+ logger.info("restarting server to fix cuda issues (device side asserts)")
380
+ # # todo fix device side asserts instead of restart to fix
381
+ # # todo only restart the correct gunicorn
382
+ # # this could be really annoying if you're running other gunicorns on your machine which also get restarted
383
+ os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
384
+ os.system("kill -1 `pgrep gunicorn`")
385
+ os.system("/usr/bin/bash kill -SIGHUP `pgrep uvicorn`")
386
+ os.system("kill -1 `pgrep uvicorn`")
387
+
388
+ return None
389
+ image.save(bs, quality=85, optimize=True, format="webp")
390
+ bio = bs.getvalue()
391
+ # touch the progress.txt file - if we don't do this we get restarted by supervisor/other processes for reliability
392
+ with open("progress.txt", "w") as f:
393
+ current_time = datetime.now().strftime("%H:%M:%S")
394
+ f.write(f"{current_time}")
395
+ return bio
396
+
397
+ def inpaint_image_from_prompt(prompt, image_url: str, mask_url: str):
398
+ prompt = shorten_too_long_text(prompt)
399
+ # image = pipe(prompt=prompt).images[0]
400
+
401
+ init_image = load_image(image_url).convert("RGB")
402
+ mask_image = load_image(mask_url).convert("RGB")  # RGB is fine - the pipeline converts masks to single-channel internally
403
+ num_inference_steps = 75
404
+ high_noise_frac = 0.7
405
+
406
+ try:
407
+ image = inpaintpipe(
408
+ prompt=prompt,
409
+ image=init_image,
410
+ mask_image=mask_image,
411
+ num_inference_steps=num_inference_steps,
412
+ denoising_start=high_noise_frac,
413
+ output_type="latent",
414
+ ).images[0] # normally uses 50 steps
415
+ except Exception as e:
416
+ # try rm stopwords + half the prompt
417
+ # todo try prompt permutations
418
+ logger.info(f"trying to shorten prompt of length {len(prompt)}")
419
+
420
+ prompt = ' '.join(word for word in prompt.split() if word not in stopwords)  # split first - iterating a str yields characters
421
+ prompts = prompt.split()
422
+
423
+ prompt = ' '.join(prompts[:len(prompts) // 2])
424
+ logger.info(f"shortened prompt to: {len(prompt)}")
425
+ image = None
426
+ if prompt:
427
+ try:
428
+ image = inpaintpipe(
429
+ prompt=prompt,
430
+ image=init_image,
431
+ mask_image=mask_image,
432
+ num_inference_steps=num_inference_steps,
433
+ denoising_start=high_noise_frac,
434
+ output_type="latent",
435
+ ).images[0] # normally uses 50 steps
436
+ except Exception as e:
437
+ # logger.info("trying to permute prompt")
438
+ # # try two swaps of the prompt/permutations
439
+ # prompt = prompt.split()
440
+ # prompt = ' '.join(permutations(prompt, 2).__next__())
441
+ logger.info(f"trying to shorten prompt of length {len(prompt)}")
442
+
443
+ prompt = ' '.join(word for word in prompt.split() if word not in stopwords)  # split first - iterating a str yields characters
444
+ prompts = prompt.split()
445
+
446
+ prompt = ' '.join(prompts[:len(prompts) // 2])
447
+ logger.info(f"shortened prompt to: {len(prompt)}")
448
+
449
+ try:
450
+ image = inpaintpipe(
451
+ prompt=prompt,
452
+ image=init_image,
453
+ mask_image=mask_image,
454
+ num_inference_steps=num_inference_steps,
455
+ denoising_start=high_noise_frac,
456
+ output_type="latent",
457
+ ).images[0] # normally uses 50 steps
458
+ except Exception as e:
459
+ # just error out
460
+ traceback.print_exc()
461
+ raise e
462
+ # logger.info("restarting server to fix cuda issues (device side asserts)")
463
+ # todo fix device side asserts instead of restart to fix
464
+ # todo only restart the correct gunicorn
465
+ # this could be really annoying if you're running other gunicorns on your machine, which also get restarted
466
+ # os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
467
+ # os.system("kill -1 `pgrep gunicorn`")
468
+ if image is not None:
469
+ image = inpaint_refiner(
470
+ prompt=prompt,
471
+ image=image,
472
+ mask_image=mask_image,
473
+ num_inference_steps=num_inference_steps,
474
+ denoising_start=high_noise_frac,
475
+
476
+ ).images[0]
477
+ # try:
478
+ # # gc.collect()
479
+ # torch.cuda.empty_cache()
480
+ # except Exception as e:
481
+ # traceback.print_exc()
482
+ # logger.info("restarting server to fix cuda issues (device side asserts)")
483
+ # # todo fix device side asserts instead of restart to fix
484
+ # # todo only restart the correct gunicorn
485
+ # # this could be really annoying if you're running other gunicorns on your machine which also get restarted
486
+ # os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
487
+ # os.system("kill -1 `pgrep gunicorn`")
488
+ # save as bytesio
489
+ bs = BytesIO()
490
+
491
+ bright_count = np.sum(np.array(image) > 0)
492
+ if bright_count == 0:
493
+ # we have a black image; this is an error and likely means we need a restart
494
+ logger.info("restarting server to fix cuda issues (device side asserts)")
495
+ # # todo fix device side asserts instead of restart to fix
496
+ # # todo only restart the correct gunicorn
497
+ # # this could be really annoying if you're running other gunicorns on your machine which also get restarted
498
+ os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
499
+ os.system("kill -1 `pgrep gunicorn`")
500
+ os.system("/usr/bin/bash kill -SIGHUP `pgrep uvicorn`")
501
+ os.system("kill -1 `pgrep uvicorn`")
502
+
503
+ return None
504
+ image.save(bs, quality=85, optimize=True, format="webp")
505
+ bio = bs.getvalue()
506
+ # touch the progress.txt file - if we don't do this we get restarted by supervisor/other processes for reliability
507
+ with open("progress.txt", "w") as f:
508
+ current_time = datetime.now().strftime("%H:%M:%S")
509
+ f.write(f"{current_time}")
510
+ return bio
511
+
512
+
513
+
514
+ def shorten_too_long_text(prompt):
515
+ if len(prompt) > 200:
516
+ # remove stopwords
517
+ prompt = prompt.split() # todo also split hyphens
518
+ prompt = ' '.join((word for word in prompt if word not in stopwords))
519
+ if len(prompt) > 200:
520
+ prompt = prompt[:200]
521
+ return prompt
522
+
523
+ # image = pipe(prompt=prompt).images[0]
524
+ #
525
+ # image.save("test.png")
526
+ # # save all images
527
+ # for i, image in enumerate(images):
528
+ # image.save(f"{i}.png")
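A note on the pattern above: every nested except branch implements the same fallback, strip stopwords, keep the first half of the remaining words, and retry. A minimal standalone sketch of that retry policy (the helper names here are illustrative, not part of this repo; stopwords is assumed to be the NLTK English list these files load):

import nltk

nltk.download("stopwords", quiet=True)
stopwords = set(nltk.corpus.stopwords.words("english"))

def shorten(prompt: str) -> str:
    # drop stopwords, then keep the first half of the remaining words
    words = [w for w in prompt.split() if w not in stopwords]
    return " ".join(words[: len(words) // 2])

def generate_with_retries(generate, prompt: str, max_attempts: int = 3):
    """Call generate(prompt), shortening the prompt after each failure."""
    last_error = None
    for _ in range(max_attempts):
        try:
            return generate(prompt)
        except Exception as e:  # e.g. token-length or CUDA errors
            last_error = e
            prompt = shorten(prompt)
            if not prompt:
                break
    raise last_error

Factored this way, each caller reduces to generate_with_retries(lambda p: pipe(prompt=p, num_inference_steps=50).images[0], prompt) instead of three hand-written nested handlers.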
img/stable-diffusion-server/main_1024.py ADDED
@@ -0,0 +1,549 @@
1
+ import gc
2
+ import math
3
+ import multiprocessing
4
+ import os
5
+ import traceback
6
+ from datetime import datetime
7
+ from io import BytesIO
8
+ from itertools import permutations
9
+ from multiprocessing.pool import Pool
10
+ from pathlib import Path
11
+ from urllib.parse import quote_plus
12
+
13
+ import numpy as np
14
+ import nltk
15
+ import torch
16
+
17
+ from PIL.Image import Image
18
+ from diffusers import DiffusionPipeline, StableDiffusionXLInpaintPipeline
19
+ from diffusers.utils import load_image
20
+ from fastapi import FastAPI
21
+ from fastapi.middleware.gzip import GZipMiddleware
22
+ from loguru import logger
23
+ from starlette.middleware.cors import CORSMiddleware
24
+ from starlette.responses import FileResponse
25
+ from starlette.responses import JSONResponse
26
+
27
+ from env import BUCKET_PATH, BUCKET_NAME
28
+ from stable_diffusion_server.bucket_api import check_if_blob_exists, upload_to_bucket  # used by the *_and_upload_image endpoints below
29
+ torch._dynamo.config.suppress_errors = True
30
+
31
+ import string
32
+ import random
33
+
34
+ def generate_save_path():
35
+ # initializing size of string
36
+ N = 7
37
+
38
+ # using random.choices()
39
+ # generating random strings
40
+ res = ''.join(random.choices(string.ascii_uppercase +
41
+ string.digits, k=N))
42
+ return res
43
+
44
+ # pipe = DiffusionPipeline.from_pretrained(
45
+ # "models/stable-diffusion-xl-base-1.0",
46
+ # torch_dtype=torch.bfloat16,
47
+ # use_safetensors=True,
48
+ # variant="fp16",
49
+ # # safety_checker=None,
50
+ # ) # todo try torch_dtype=bfloat16
51
+
52
+ model_dir = os.getenv("SDXL_MODEL_DIR")
53
+
54
+ if model_dir:
55
+ # Use local model
56
+ model_key_base = os.path.join(model_dir, "stable-diffusion-xl-base-1.0")
57
+ model_key_refiner = os.path.join(model_dir, "stable-diffusion-xl-refiner-1.0")
58
+ else:
59
+ model_key_base = "stabilityai/stable-diffusion-xl-base-1.0"
60
+ model_key_refiner = "stabilityai/stable-diffusion-xl-refiner-1.0"
61
+
62
+ pipe = DiffusionPipeline.from_pretrained(model_key_base, torch_dtype=torch.float16, use_safetensors=True, variant="fp16")
63
+
64
+ pipe.watermark = None
65
+
66
+ pipe.to("cuda")
67
+
68
+ refiner = DiffusionPipeline.from_pretrained(
69
+ "stabilityai/stable-diffusion-xl-refiner-1.0",
70
+ text_encoder_2=pipe.text_encoder_2,
71
+ vae=pipe.vae,
72
+ torch_dtype=torch.bfloat16, # safer to use bfloat?
73
+ use_safetensors=True,
74
+ variant="fp16", #remember not to download the big model
75
+ )
76
+ refiner.watermark = None
77
+ refiner.to("cuda")
78
+
79
+ # {'scheduler', 'text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'unet', 'vae'} can be passed in from existing model
80
+ inpaintpipe = StableDiffusionXLInpaintPipeline.from_pretrained(
81
+ "models/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16, variant="fp16", use_safetensors=True,
82
+ scheduler=pipe.scheduler,
83
+ text_encoder=pipe.text_encoder,
84
+ text_encoder_2=pipe.text_encoder_2,
85
+ tokenizer=pipe.tokenizer,
86
+ tokenizer_2=pipe.tokenizer_2,
87
+ unet=pipe.unet,
88
+ vae=pipe.vae,
89
+ # load_connected_pipeline=
90
+ )
91
+ # # switch out to save gpu mem
92
+ # del inpaintpipe.vae
93
+ # del inpaintpipe.text_encoder_2
94
+ # del inpaintpipe.text_encoder
95
+ # del inpaintpipe.scheduler
96
+ # del inpaintpipe.tokenizer
97
+ # del inpaintpipe.tokenizer_2
98
+ # del inpaintpipe.unet
99
+ # inpaintpipe.vae = pipe.vae
100
+ # inpaintpipe.text_encoder_2 = pipe.text_encoder_2
101
+ # inpaintpipe.text_encoder = pipe.text_encoder
102
+ # inpaintpipe.scheduler = pipe.scheduler
103
+ # inpaintpipe.tokenizer = pipe.tokenizer
104
+ # inpaintpipe.tokenizer_2 = pipe.tokenizer_2
105
+ # inpaintpipe.unet = pipe.unet
106
+ # todo this should work
107
+ # inpaintpipe = StableDiffusionXLInpaintPipeline( # construct an inpainter using the existing model
108
+ # vae=pipe.vae,
109
+ # text_encoder_2=pipe.text_encoder_2,
110
+ # text_encoder=pipe.text_encoder,
111
+ # unet=pipe.unet,
112
+ # scheduler=pipe.scheduler,
113
+ # tokenizer=pipe.tokenizer,
114
+ # tokenizer_2=pipe.tokenizer_2,
115
+ # requires_aesthetics_score=False,
116
+ # )
117
+ inpaintpipe.to("cuda")
118
+ inpaintpipe.watermark = None
119
+ # inpaintpipe.register_to_config(requires_aesthetics_score=False)
120
+
121
+ inpaint_refiner = StableDiffusionXLInpaintPipeline.from_pretrained(
122
+ "stabilityai/stable-diffusion-xl-refiner-1.0",
123
+ text_encoder_2=inpaintpipe.text_encoder_2,
124
+ vae=inpaintpipe.vae,
125
+ torch_dtype=torch.bfloat16,
126
+ use_safetensors=True,
127
+ variant="fp16",
128
+
129
+ tokenizer_2=refiner.tokenizer_2,
130
+ tokenizer=refiner.tokenizer,
131
+ scheduler=refiner.scheduler,
132
+ text_encoder=refiner.text_encoder,
133
+ unet=refiner.unet,
134
+ )
135
+ # del inpaint_refiner.vae
136
+ # del inpaint_refiner.text_encoder_2
137
+ # del inpaint_refiner.text_encoder
138
+ # del inpaint_refiner.scheduler
139
+ # del inpaint_refiner.tokenizer
140
+ # del inpaint_refiner.tokenizer_2
141
+ # del inpaint_refiner.unet
142
+ # inpaint_refiner.vae = inpaintpipe.vae
143
+ # inpaint_refiner.text_encoder_2 = inpaintpipe.text_encoder_2
144
+ #
145
+ # inpaint_refiner.text_encoder = refiner.text_encoder
146
+ # inpaint_refiner.scheduler = refiner.scheduler
147
+ # inpaint_refiner.tokenizer = refiner.tokenizer
148
+ # inpaint_refiner.tokenizer_2 = refiner.tokenizer_2
149
+ # inpaint_refiner.unet = refiner.unet
150
+
151
+ # inpaint_refiner = StableDiffusionXLInpaintPipeline(
152
+ # text_encoder_2=inpaintpipe.text_encoder_2,
153
+ # vae=inpaintpipe.vae,
154
+ # # the rest from the existing refiner
155
+ # tokenizer_2=refiner.tokenizer_2,
156
+ # tokenizer=refiner.tokenizer,
157
+ # scheduler=refiner.scheduler,
158
+ # text_encoder=refiner.text_encoder,
159
+ # unet=refiner.unet,
160
+ # requires_aesthetics_score=False,
161
+ # )
162
+ inpaint_refiner.to("cuda")
163
+ inpaint_refiner.watermark = None
164
+ # inpaint_refiner.register_to_config(requires_aesthetics_score=False)
165
+
166
+ n_steps = 40
167
+ high_noise_frac = 0.8
168
+
169
+ # if using torch < 2.0
170
+ # pipe.enable_xformers_memory_efficient_attention()
171
+
172
+
173
+ # pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
174
+ # this can cause errors on some inputs so consider disabling it
175
+ pipe.unet = torch.compile(pipe.unet)
176
+ refiner.unet = torch.compile(refiner.unet)#, mode="reduce-overhead", fullgraph=True)
177
+ # compile the inpainters - todo reuse the other unets? swap out the models for others/del them so they share models and can be swapped efficiently
178
+ inpaintpipe.unet = pipe.unet
179
+ inpaint_refiner.unet = refiner.unet
180
+ # inpaintpipe.unet = torch.compile(inpaintpipe.unet)
181
+ # inpaint_refiner.unet = torch.compile(inpaint_refiner.unet)
182
+ from pydantic import BaseModel
183
+
184
+ app = FastAPI(
185
+ openapi_url="/static/openapi.json",
186
+ docs_url="/swagger-docs",
187
+ redoc_url="/redoc",
188
+ title="Generate Images Netwrck API",
189
+ description="Character Chat API",
190
+ # root_path="https://api.text-generator.io",
191
+ version="1",
192
+ )
193
+ app.add_middleware(GZipMiddleware, minimum_size=1000)
194
+ app.add_middleware(
195
+ CORSMiddleware,
196
+ allow_origins=["*"],
197
+ allow_credentials=True,
198
+ allow_methods=["*"],
199
+ allow_headers=["*"],
200
+ )
201
+
202
+ stopwords = nltk.corpus.stopwords.words("english")
203
+
204
+ class Img(BaseModel):
205
+ system_prompt: str
206
+ ASSISTANT: str
207
+
208
+ # img_url = "http://phlrr2019.guest.corp.microsoft.com:8000/img1_sdv2.1.png"
209
+ img_url = "http://phlrr3058.guest.corp.microsoft.com:8000/"#/img1_sdv2.1.png"
210
+
211
+ is_gpu_busy = False
212
+
213
+
214
+ @app.post("/image_url")
215
+ def image_url(img: Img):
216
+ system_prompt = img.system_prompt
217
+ prompt = img.ASSISTANT
218
+ # if Path(save_path).exists():
219
+ # return FileResponse(save_path, media_type="image/png")
220
+ # return JSONResponse({"path": path})
221
+ # image = pipe(prompt=prompt).images[0]
222
+ g = torch.Generator(device="cuda")
223
+ # image = pipe(prompt=prompt, width=1024, height=1024, generator=g).images[0]
224
+ image = pipe(prompt=prompt, width=1024, height=1024).images[0]
225
+
226
+ # if not save_path:
227
+ save_path = generate_save_path()
228
+ save_path = f"images/{save_path}.png"
229
+ image.save(save_path)
230
+ # save_path = '/'.join(path_components) + quote_plus(final_name)
231
+ path = f"{img_url}/{save_path}"
232
+ return JSONResponse({"path": path})
233
+
234
+
235
+ @app.get("/make_image")
236
+ # @app.post("/make_image")
237
+ def make_image(prompt: str, save_path: str = ""):
238
+ if Path(save_path).exists():
239
+ return FileResponse(save_path, media_type="image/png")
240
+ image = pipe(prompt=prompt).images[0]
241
+ if not save_path:
242
+ save_path = f"images/{prompt}.png"
243
+ image.save(save_path)
244
+ return FileResponse(save_path, media_type="image/png")
245
+
246
+
247
+ @app.get("/create_and_upload_image")
248
+ def create_and_upload_image(prompt: str, width: int=1024, height:int=1024, save_path: str = ""):
249
+ path_components = save_path.split("/")[0:-1]
250
+ final_name = save_path.split("/")[-1]
251
+ if not path_components:
252
+ path_components = []
253
+ save_path = '/'.join(path_components) + quote_plus(final_name)
254
+ path = get_image_or_create_upload_to_cloud_storage(prompt, width, height, save_path)
255
+ return JSONResponse({"path": path})
256
+
257
+ @app.get("/inpaint_and_upload_image")
258
+ def inpaint_and_upload_image(prompt: str, image_url:str, mask_url:str, save_path: str = ""):
259
+ path_components = save_path.split("/")[0:-1]
260
+ final_name = save_path.split("/")[-1]
261
+ if not path_components:
262
+ path_components = []
263
+ save_path = '/'.join(path_components) + quote_plus(final_name)
264
+ path = get_image_or_inpaint_upload_to_cloud_storage(prompt, image_url, mask_url, save_path)
265
+ return JSONResponse({"path": path})
266
+
267
+
268
+ def get_image_or_create_upload_to_cloud_storage(prompt:str,width:int, height:int, save_path:str):
269
+ prompt = shorten_too_long_text(prompt)
270
+ save_path = shorten_too_long_text(save_path)
271
+ # check exists - todo cache this
272
+ if check_if_blob_exists(save_path):
273
+ return f"https://{BUCKET_NAME}/{BUCKET_PATH}/{save_path}"
274
+ bio = create_image_from_prompt(prompt, width, height)
275
+ if bio is None:
276
+ return None # error thrown in pool
277
+ link = upload_to_bucket(save_path, bio, is_bytesio=True)
278
+ return link
279
+ def get_image_or_inpaint_upload_to_cloud_storage(prompt:str, image_url:str, mask_url:str, save_path:str):
280
+ prompt = shorten_too_long_text(prompt)
281
+ save_path = shorten_too_long_text(save_path)
282
+ # check exists - todo cache this
283
+ if check_if_blob_exists(save_path):
284
+ return f"https://{BUCKET_NAME}/{BUCKET_PATH}/{save_path}"
285
+ bio = inpaint_image_from_prompt(prompt, image_url, mask_url)
286
+ if bio is None:
287
+ return None # error thrown in pool
288
+ link = upload_to_bucket(save_path, bio, is_bytesio=True)
289
+ return link
290
+
291
+ # multiprocessing.set_start_method('spawn', True)
292
+ # processes_pool = Pool(1) # can't do too much at once or OOM errors happen
293
+ # def create_image_from_prompt_sync(prompt):
294
+ # """have to call this sync to avoid OOM errors"""
295
+ # return processes_pool.apply_async(create_image_from_prompt, args=(prompt,), ).wait()
296
+
297
+ def create_image_from_prompt(prompt, width, height):
298
+ # round width and height down to multiple of 64
299
+ block_width = width - (width % 64)
300
+ block_height = height - (height % 64)
301
+ prompt = shorten_too_long_text(prompt)
302
+ # image = pipe(prompt=prompt).images[0]
303
+ try:
304
+ image = pipe(prompt=prompt,
305
+ width=block_width,
306
+ height=block_height,
307
+ # denoising_end=high_noise_frac,
308
+ # output_type='latent',
309
+ # height=512,
310
+ # width=512,
311
+ num_inference_steps=50).images[0] # normally uses 50 steps
312
+ except Exception as e:
313
+ # retry after removing stopwords and halving the prompt
314
+ # todo try prompt permutations
315
+ logger.info(f"trying to shorten prompt of length {len(prompt)}")
316
+
317
+ prompt = ' '.join(word for word in prompt.split() if word not in stopwords)  # split first: iterating the raw string yields characters, not words
318
+ prompts = prompt.split()
319
+
320
+ prompt = ' '.join(prompts[:len(prompts) // 2])
321
+ logger.info(f"shortened prompt to: {len(prompt)}")
322
+ image = None
323
+ if prompt:
324
+ try:
325
+ image = pipe(prompt=prompt,
326
+ width=block_width,
327
+ height=block_height,
328
+ # denoising_end=high_noise_frac,
329
+ # output_type='latent',
330
+ # height=512,
331
+ # width=512,
332
+ num_inference_steps=50).images[0] # normally uses 50 steps
333
+ except Exception as e:
334
+ # logger.info("trying to permute prompt")
335
+ # # try two swaps of the prompt/permutations
336
+ # prompt = prompt.split()
337
+ # prompt = ' '.join(permutations(prompt, 2).__next__())
338
+ logger.info(f"trying to shorten prompt of length {len(prompt)}")
339
+
340
+ prompt = ' '.join(word for word in prompt.split() if word not in stopwords)
341
+ prompts = prompt.split()
342
+
343
+ prompt = ' '.join(prompts[:len(prompts) // 2])
344
+ logger.info(f"shortened prompt to: {len(prompt)}")
345
+
346
+ try:
347
+ image = pipe(prompt=prompt,
348
+ width=block_width,
349
+ height=block_height,
350
+ # denoising_end=high_noise_frac,
351
+ # output_type='latent', # dont need latent yet - we refine the image at full res
352
+ # height=512,
353
+ # width=512,
354
+ num_inference_steps=50).images[0] # normally uses 50 steps
355
+ except Exception as e:
356
+ # just error out
357
+ traceback.print_exc()
358
+ raise e
359
+ # logger.info("restarting server to fix cuda issues (device side asserts)")
360
+ # todo: fix the device-side asserts instead of restarting
361
+ # todo: only restart the correct gunicorn
362
+ # this could be really annoying if you're running other gunicorns on the machine, since they get restarted too
363
+ # os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
364
+ # os.system("kill -1 `pgrep gunicorn`")
365
+ # todo refine
366
+ # if image != None:
367
+ # image = refiner(
368
+ # prompt=prompt,
369
+ # # width=block_width,
370
+ # # height=block_height,
371
+ # num_inference_steps=n_steps,
372
+ # # denoising_start=high_noise_frac,
373
+ # image=image,
374
+ # ).images[0]
375
+ if width != block_width or height != block_height:
376
+ # resize to original size width/height
377
+ # find aspect ratio to scale up to that covers the original img input width/height
378
+ scale_up_ratio = max(width / block_width, height / block_height)
379
+ image = image.resize((math.ceil(block_width * scale_up_ratio), math.ceil(block_height * scale_up_ratio)))  # scale both block dimensions to keep the aspect ratio
380
+ # crop image to original size
381
+ image = image.crop((0, 0, width, height))
382
+ # try:
383
+ # # gc.collect()
384
+ # torch.cuda.empty_cache()
385
+ # except Exception as e:
386
+ # traceback.print_exc()
387
+ # logger.info("restarting server to fix cuda issues (device side asserts)")
388
+ # # todo: fix the device-side asserts instead of restarting
389
+ # # todo: only restart the correct gunicorn
390
+ # # this could be really annoying if you're running other gunicorns on the machine, since they get restarted too
391
+ # os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
392
+ # os.system("kill -1 `pgrep gunicorn`")
393
+ # save as bytesio
394
+ bs = BytesIO()
395
+
396
+ bright_count = np.sum(np.array(image) > 0)
397
+ if bright_count == 0:
398
+ # an all-black image usually means a CUDA device-side assert fired; restart the worker
399
+ logger.info("restarting server to fix cuda issues (device side asserts)")
400
+ # # todo: fix the device-side asserts instead of restarting
401
+ # # todo: only restart the correct gunicorn
402
+ # # this could be really annoying if you're running other gunicorns on the machine, since they get restarted too
403
+ os.system("kill -HUP `pgrep gunicorn`")  # SIGHUP so supervisor/gunicorn restarts the workers
404
+ os.system("kill -1 `pgrep gunicorn`")
405
+ os.system("kill -HUP `pgrep uvicorn`")
406
+ os.system("kill -1 `pgrep uvicorn`")
407
+
408
+ return None
409
+ image.save(bs, quality=85, optimize=True, format="webp")
410
+ bio = bs.getvalue()
411
+ # touch the progress.txt file - if we don't, supervisor (or another watchdog) restarts us for reliability
412
+ with open("progress.txt", "w") as f:
413
+ current_time = datetime.now().strftime("%H:%M:%S")
414
+ f.write(f"{current_time}")
415
+ return bio
416
+
417
+ def inpaint_image_from_prompt(prompt, image_url: str, mask_url: str):
418
+ prompt = shorten_too_long_text(prompt)
419
+ # image = pipe(prompt=prompt).images[0]
420
+
421
+ init_image = load_image(image_url).convert("RGB")
422
+ mask_image = load_image(mask_url).convert("RGB") # why rgb for a 1 channel mask?
423
+ num_inference_steps = 75
424
+ high_noise_frac = 0.7
425
+
426
+ try:
427
+ image = inpaintpipe(
428
+ prompt=prompt,
429
+ image=init_image,
430
+ mask_image=mask_image,
431
+ num_inference_steps=num_inference_steps,
432
+ denoising_end=high_noise_frac,  # the base covers the high-noise steps; the refiner resumes from this fraction
433
+ output_type="latent",
434
+ ).images[0] # normally uses 50 steps
435
+ except Exception as e:
436
+ # retry after removing stopwords and halving the prompt
437
+ # todo try prompt permutations
438
+ logger.info(f"trying to shorten prompt of length {len(prompt)}")
439
+
440
+ prompt = ' '.join(word for word in prompt.split() if word not in stopwords)
441
+ prompts = prompt.split()
442
+
443
+ prompt = ' '.join(prompts[:len(prompts) // 2])
444
+ logger.info(f"shortened prompt to: {len(prompt)}")
445
+ image = None
446
+ if prompt:
447
+ try:
448
+ image = inpaintpipe(  # must be the inpaint pipeline; the plain text2img pipe doesn't accept mask_image
449
+ prompt=prompt,
450
+ image=init_image,
451
+ mask_image=mask_image,
452
+ num_inference_steps=num_inference_steps,
453
+ denoising_end=high_noise_frac,
454
+ output_type="latent",
455
+ ).images[0] # normally uses 50 steps
456
+ except Exception as e:
457
+ # logger.info("trying to permute prompt")
458
+ # # try two swaps of the prompt/permutations
459
+ # prompt = prompt.split()
460
+ # prompt = ' '.join(permutations(prompt, 2).__next__())
461
+ logger.info(f"trying to shorten prompt of length {len(prompt)}")
462
+
463
+ prompt = ' '.join(word for word in prompt.split() if word not in stopwords)
464
+ prompts = prompt.split()
465
+
466
+ prompt = ' '.join(prompts[:len(prompts) // 2])
467
+ logger.info(f"shortened prompt to: {len(prompt)}")
468
+
469
+ try:
470
+ image = inpaintpipe(
471
+ prompt=prompt,
472
+ image=init_image,
473
+ mask_image=mask_image,
474
+ num_inference_steps=num_inference_steps,
475
+ denoising_end=high_noise_frac,
476
+ output_type="latent",
477
+ ).images[0] # normally uses 50 steps
478
+ except Exception as e:
479
+ # just error out
480
+ traceback.print_exc()
481
+ raise e
482
+ # logger.info("restarting server to fix cuda issues (device side asserts)")
483
+ # todo: fix the device-side asserts instead of restarting
484
+ # todo: only restart the correct gunicorn
485
+ # this could be really annoying if you're running other gunicorns on the machine, since they get restarted too
486
+ # os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
487
+ # os.system("kill -1 `pgrep gunicorn`")
488
+ if image is not None:
489
+ image = inpaint_refiner(
490
+ prompt=prompt,
491
+ image=image,
492
+ mask_image=mask_image,
493
+ num_inference_steps=num_inference_steps,
494
+ denoising_start=high_noise_frac,
495
+
496
+ ).images[0]
497
+ # try:
498
+ # # gc.collect()
499
+ # torch.cuda.empty_cache()
500
+ # except Exception as e:
501
+ # traceback.print_exc()
502
+ # logger.info("restarting server to fix cuda issues (device side asserts)")
503
+ # # todo: fix the device-side asserts instead of restarting
504
+ # # todo: only restart the correct gunicorn
505
+ # # this could be really annoying if you're running other gunicorns on the machine, since they get restarted too
506
+ # os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
507
+ # os.system("kill -1 `pgrep gunicorn`")
508
+ # save as bytesio
509
+ bs = BytesIO()
510
+
511
+ bright_count = np.sum(np.array(image) > 0)
512
+ if bright_count == 0:
513
+ # an all-black image usually means a CUDA device-side assert fired; restart the worker
514
+ logger.info("restarting server to fix cuda issues (device side asserts)")
515
+ # # todo: fix the device-side asserts instead of restarting
516
+ # # todo: only restart the correct gunicorn
517
+ # # this could be really annoying if you're running other gunicorns on the machine, since they get restarted too
518
+ os.system("kill -HUP `pgrep gunicorn`")  # SIGHUP so supervisor/gunicorn restarts the workers
519
+ os.system("kill -1 `pgrep gunicorn`")
520
+ os.system("kill -HUP `pgrep uvicorn`")
521
+ os.system("kill -1 `pgrep uvicorn`")
522
+
523
+ return None
524
+ image.save(bs, quality=85, optimize=True, format="webp")
525
+ bio = bs.getvalue()
526
+ # touch the progress.txt file - if we don't, supervisor (or another watchdog) restarts us for reliability
527
+ with open("progress.txt", "w") as f:
528
+ current_time = datetime.now().strftime("%H:%M:%S")
529
+ f.write(f"{current_time}")
530
+ return bio
531
+
532
+
533
+
534
+ def shorten_too_long_text(prompt):
535
+ if len(prompt) > 200:
536
+ # remove stopwords
537
+ prompt = prompt.split() # todo also split hyphens
538
+ prompt = ' '.join((word for word in prompt if word not in stopwords))
539
+ if len(prompt) > 200:
540
+ prompt = prompt[:200]
541
+ return prompt
542
+
543
+ # image = pipe(prompt=prompt).images[0]
544
+ #
545
+ # image.save("test.png")
546
+ # # save all images
547
+ # for i, image in enumerate(images):
548
+ # image.save(f"{i}.png")
549
+
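The create_image_from_prompt implementations in these files render at dimensions snapped down to multiples of 64 (the granularity this code assumes the SDXL UNet needs), then scale the result back up and crop to the exact request. A minimal sketch of just that geometry, using only Pillow (function names are illustrative):

import math

from PIL import Image

def snap_down_64(x: int) -> int:
    # these files round each requested dimension down to a multiple of 64
    return x - (x % 64)

def fit_to_request(image: Image.Image, width: int, height: int) -> Image.Image:
    # scale the snapped render up until it covers the requested size,
    # preserving aspect ratio, then crop off the overflow
    block_width, block_height = image.size
    if (block_width, block_height) == (width, height):
        return image
    scale = max(width / block_width, height / block_height)
    image = image.resize((math.ceil(block_width * scale),
                          math.ceil(block_height * scale)))
    return image.crop((0, 0, width, height))

For example, a 1000x620 request renders at 960x576; scale = max(1000/960, 620/576) is roughly 1.077, so the render is resized to 1034x620 and then cropped back to 1000x620.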
img/stable-diffusion-server/main_v2.py ADDED
@@ -0,0 +1,548 @@
1
+ import gc
2
+ import math
3
+ import multiprocessing
4
+ import os
5
+ import traceback
6
+ from datetime import datetime
7
+ from io import BytesIO
8
+ from itertools import permutations
9
+ from multiprocessing.pool import Pool
10
+ from pathlib import Path
11
+ from urllib.parse import quote_plus
12
+
13
+ import numpy as np
14
+ import nltk
15
+ import torch
16
+
17
+ from PIL.Image import Image
18
+ from diffusers import DiffusionPipeline, StableDiffusionXLInpaintPipeline
19
+ from diffusers.utils import load_image
20
+ from fastapi import FastAPI
21
+ from fastapi.middleware.gzip import GZipMiddleware
22
+ from loguru import logger
23
+ from starlette.middleware.cors import CORSMiddleware
24
+ from starlette.responses import FileResponse
25
+ from starlette.responses import JSONResponse
26
+
27
+ from env import BUCKET_PATH, BUCKET_NAME
28
+ from stable_diffusion_server.bucket_api import check_if_blob_exists, upload_to_bucket  # used by the *_and_upload_image endpoints below
29
+ torch._dynamo.config.suppress_errors = True
30
+
31
+ import string
32
+ import random
33
+
34
+ def generate_save_path():
35
+ # initializing size of string
36
+ N = 7
37
+
38
+ # using random.choices()
39
+ # generating random strings
40
+ res = ''.join(random.choices(string.ascii_uppercase +
41
+ string.digits, k=N))
42
+ return res
43
+
44
+ # pipe = DiffusionPipeline.from_pretrained(
45
+ # "models/stable-diffusion-xl-base-1.0",
46
+ # torch_dtype=torch.bfloat16,
47
+ # use_safetensors=True,
48
+ # variant="fp16",
49
+ # # safety_checker=None,
50
+ # ) # todo try torch_dtype=bfloat16
51
+
52
+ model_dir = os.getenv("SDXL_MODEL_DIR")
53
+
54
+ if model_dir:
55
+ # Use local model
56
+ model_key_base = os.path.join(model_dir, "stable-diffusion-xl-base-1.0")
57
+ model_key_refiner = os.path.join(model_dir, "stable-diffusion-xl-refiner-1.0")
58
+ else:
59
+ model_key_base = "stabilityai/stable-diffusion-xl-base-1.0"
60
+ model_key_refiner = "stabilityai/stable-diffusion-xl-refiner-1.0"
61
+
62
+ pipe = DiffusionPipeline.from_pretrained(model_key_base, torch_dtype=torch.float16, use_safetensors=True, variant="fp16")
63
+
64
+ pipe.watermark = None
65
+
66
+ pipe.to("cuda")
67
+
68
+ refiner = DiffusionPipeline.from_pretrained(
69
+ "stabilityai/stable-diffusion-xl-refiner-1.0",
70
+ text_encoder_2=pipe.text_encoder_2,
71
+ vae=pipe.vae,
72
+ torch_dtype=torch.bfloat16, # safer to use bfloat?
73
+ use_safetensors=True,
74
+ variant="fp16", #remember not to download the big model
75
+ )
76
+ refiner.watermark = None
77
+ refiner.to("cuda")
78
+
79
+ # {'scheduler', 'text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'unet', 'vae'} can be passed in from existing model
80
+ inpaintpipe = StableDiffusionXLInpaintPipeline.from_pretrained(
81
+ "models/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16, variant="fp16", use_safetensors=True,
82
+ scheduler=pipe.scheduler,
83
+ text_encoder=pipe.text_encoder,
84
+ text_encoder_2=pipe.text_encoder_2,
85
+ tokenizer=pipe.tokenizer,
86
+ tokenizer_2=pipe.tokenizer_2,
87
+ unet=pipe.unet,
88
+ vae=pipe.vae,
89
+ # load_connected_pipeline=
90
+ )
91
+ # # switch out to save gpu mem
92
+ # del inpaintpipe.vae
93
+ # del inpaintpipe.text_encoder_2
94
+ # del inpaintpipe.text_encoder
95
+ # del inpaintpipe.scheduler
96
+ # del inpaintpipe.tokenizer
97
+ # del inpaintpipe.tokenizer_2
98
+ # del inpaintpipe.unet
99
+ # inpaintpipe.vae = pipe.vae
100
+ # inpaintpipe.text_encoder_2 = pipe.text_encoder_2
101
+ # inpaintpipe.text_encoder = pipe.text_encoder
102
+ # inpaintpipe.scheduler = pipe.scheduler
103
+ # inpaintpipe.tokenizer = pipe.tokenizer
104
+ # inpaintpipe.tokenizer_2 = pipe.tokenizer_2
105
+ # inpaintpipe.unet = pipe.unet
106
+ # todo this should work
107
+ # inpaintpipe = StableDiffusionXLInpaintPipeline( # construct an inpainter using the existing model
108
+ # vae=pipe.vae,
109
+ # text_encoder_2=pipe.text_encoder_2,
110
+ # text_encoder=pipe.text_encoder,
111
+ # unet=pipe.unet,
112
+ # scheduler=pipe.scheduler,
113
+ # tokenizer=pipe.tokenizer,
114
+ # tokenizer_2=pipe.tokenizer_2,
115
+ # requires_aesthetics_score=False,
116
+ # )
117
+ inpaintpipe.to("cuda")
118
+ inpaintpipe.watermark = None
119
+ # inpaintpipe.register_to_config(requires_aesthetics_score=False)
120
+
121
+ inpaint_refiner = StableDiffusionXLInpaintPipeline.from_pretrained(
122
+ "stabilityai/stable-diffusion-xl-refiner-1.0",
123
+ text_encoder_2=inpaintpipe.text_encoder_2,
124
+ vae=inpaintpipe.vae,
125
+ torch_dtype=torch.bfloat16,
126
+ use_safetensors=True,
127
+ variant="fp16",
128
+
129
+ tokenizer_2=refiner.tokenizer_2,
130
+ tokenizer=refiner.tokenizer,
131
+ scheduler=refiner.scheduler,
132
+ text_encoder=refiner.text_encoder,
133
+ unet=refiner.unet,
134
+ )
135
+ # del inpaint_refiner.vae
136
+ # del inpaint_refiner.text_encoder_2
137
+ # del inpaint_refiner.text_encoder
138
+ # del inpaint_refiner.scheduler
139
+ # del inpaint_refiner.tokenizer
140
+ # del inpaint_refiner.tokenizer_2
141
+ # del inpaint_refiner.unet
142
+ # inpaint_refiner.vae = inpaintpipe.vae
143
+ # inpaint_refiner.text_encoder_2 = inpaintpipe.text_encoder_2
144
+ #
145
+ # inpaint_refiner.text_encoder = refiner.text_encoder
146
+ # inpaint_refiner.scheduler = refiner.scheduler
147
+ # inpaint_refiner.tokenizer = refiner.tokenizer
148
+ # inpaint_refiner.tokenizer_2 = refiner.tokenizer_2
149
+ # inpaint_refiner.unet = refiner.unet
150
+
151
+ # inpaint_refiner = StableDiffusionXLInpaintPipeline(
152
+ # text_encoder_2=inpaintpipe.text_encoder_2,
153
+ # vae=inpaintpipe.vae,
154
+ # # the rest from the existing refiner
155
+ # tokenizer_2=refiner.tokenizer_2,
156
+ # tokenizer=refiner.tokenizer,
157
+ # scheduler=refiner.scheduler,
158
+ # text_encoder=refiner.text_encoder,
159
+ # unet=refiner.unet,
160
+ # requires_aesthetics_score=False,
161
+ # )
162
+ inpaint_refiner.to("cuda")
163
+ inpaint_refiner.watermark = None
164
+ # inpaint_refiner.register_to_config(requires_aesthetics_score=False)
165
+
166
+ n_steps = 40
167
+ high_noise_frac = 0.8
168
+
169
+ # if using torch < 2.0
170
+ # pipe.enable_xformers_memory_efficient_attention()
171
+
172
+
173
+ # pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
174
+ # this can cause errors on some inputs so consider disabling it
175
+ pipe.unet = torch.compile(pipe.unet)
176
+ refiner.unet = torch.compile(refiner.unet)#, mode="reduce-overhead", fullgraph=True)
177
+ # compile the inpainters - todo reuse the other unets? swap out the models for others/del them so they share models and can be swapped efficiently
178
+ inpaintpipe.unet = pipe.unet
179
+ inpaint_refiner.unet = refiner.unet
180
+ # inpaintpipe.unet = torch.compile(inpaintpipe.unet)
181
+ # inpaint_refiner.unet = torch.compile(inpaint_refiner.unet)
182
+ from pydantic import BaseModel
183
+
184
+ app = FastAPI(
185
+ openapi_url="/static/openapi.json",
186
+ docs_url="/swagger-docs",
187
+ redoc_url="/redoc",
188
+ title="Generate Images Netwrck API",
189
+ description="Character Chat API",
190
+ # root_path="https://api.text-generator.io",
191
+ version="1",
192
+ )
193
+ app.add_middleware(GZipMiddleware, minimum_size=1000)
194
+ app.add_middleware(
195
+ CORSMiddleware,
196
+ allow_origins=["*"],
197
+ allow_credentials=True,
198
+ allow_methods=["*"],
199
+ allow_headers=["*"],
200
+ )
201
+
202
+ stopwords = nltk.corpus.stopwords.words("english")
203
+
204
+ class Img(BaseModel):
205
+ system_prompt: str
206
+ ASSISTANT: str
207
+
208
+ # img_url = "http://phlrr2019.guest.corp.microsoft.com:8000/img1_sdv2.1.png"
209
+ img_url = "http://phlrr3105.guest.corp.microsoft.com:8000/"#/img1_sdv2.1.png"
210
+
211
+ is_gpu_busy = False
212
+
213
+
214
+ @app.post("/image_url")
215
+ def image_url(img: Img):
216
+ system_prompt = img.system_prompt
217
+ prompt = img.ASSISTANT
218
+ # if Path(save_path).exists():
219
+ # return FileResponse(save_path, media_type="image/png")
220
+ # return JSONResponse({"path": path})
221
+ # image = pipe(prompt=prompt).images[0]
222
+ g = torch.Generator(device="cuda")
223
+ image = pipe(prompt=prompt, width=1024, height=1024, generator=g).images[0]
224
+
225
+ # if not save_path:
226
+ save_path = generate_save_path()
227
+ save_path = f"images/{save_path}.png"
228
+ image.save(save_path)
229
+ # save_path = '/'.join(path_components) + quote_plus(final_name)
230
+ path = f"{img_url}/{save_path}"
231
+ return JSONResponse({"path": path})
232
+
233
+
234
+ @app.get("/make_image")
235
+ # @app.post("/make_image")
236
+ def make_image(prompt: str, save_path: str = ""):
237
+ if Path(save_path).exists():
238
+ return FileResponse(save_path, media_type="image/png")
239
+ image = pipe(prompt=prompt).images[0]
240
+ if not save_path:
241
+ save_path = f"images/{prompt}.png"
242
+ image.save(save_path)
243
+ return FileResponse(save_path, media_type="image/png")
244
+
245
+
246
+ @app.get("/create_and_upload_image")
247
+ def create_and_upload_image(prompt: str, width: int=1024, height:int=1024, save_path: str = ""):
248
+ path_components = save_path.split("/")[0:-1]
249
+ final_name = save_path.split("/")[-1]
250
+ if not path_components:
251
+ path_components = []
252
+ save_path = '/'.join(path_components) + quote_plus(final_name)
253
+ path = get_image_or_create_upload_to_cloud_storage(prompt, width, height, save_path)
254
+ return JSONResponse({"path": path})
255
+
256
+ @app.get("/inpaint_and_upload_image")
257
+ def inpaint_and_upload_image(prompt: str, image_url:str, mask_url:str, save_path: str = ""):
258
+ path_components = save_path.split("/")[0:-1]
259
+ final_name = save_path.split("/")[-1]
260
+ if not path_components:
261
+ path_components = []
262
+ save_path = '/'.join(path_components) + quote_plus(final_name)
263
+ path = get_image_or_inpaint_upload_to_cloud_storage(prompt, image_url, mask_url, save_path)
264
+ return JSONResponse({"path": path})
265
+
266
+
267
+ def get_image_or_create_upload_to_cloud_storage(prompt:str,width:int, height:int, save_path:str):
268
+ prompt = shorten_too_long_text(prompt)
269
+ save_path = shorten_too_long_text(save_path)
270
+ # check exists - todo cache this
271
+ if check_if_blob_exists(save_path):
272
+ return f"https://{BUCKET_NAME}/{BUCKET_PATH}/{save_path}"
273
+ bio = create_image_from_prompt(prompt, width, height)
274
+ if bio is None:
275
+ return None # error thrown in pool
276
+ link = upload_to_bucket(save_path, bio, is_bytesio=True)
277
+ return link
278
+ def get_image_or_inpaint_upload_to_cloud_storage(prompt:str, image_url:str, mask_url:str, save_path:str):
279
+ prompt = shorten_too_long_text(prompt)
280
+ save_path = shorten_too_long_text(save_path)
281
+ # check exists - todo cache this
282
+ if check_if_blob_exists(save_path):
283
+ return f"https://{BUCKET_NAME}/{BUCKET_PATH}/{save_path}"
284
+ bio = inpaint_image_from_prompt(prompt, image_url, mask_url)
285
+ if bio is None:
286
+ return None # error thrown in pool
287
+ link = upload_to_bucket(save_path, bio, is_bytesio=True)
288
+ return link
289
+
290
+ # multiprocessing.set_start_method('spawn', True)
291
+ # processes_pool = Pool(1) # can't do too much at once or OOM errors happen
292
+ # def create_image_from_prompt_sync(prompt):
293
+ # """have to call this sync to avoid OOM errors"""
294
+ # return processes_pool.apply_async(create_image_from_prompt, args=(prompt,), ).wait()
295
+
296
+ def create_image_from_prompt(prompt, width, height):
297
+ # round width and height down to multiple of 64
298
+ block_width = width - (width % 64)
299
+ block_height = height - (height % 64)
300
+ prompt = shorten_too_long_text(prompt)
301
+ # image = pipe(prompt=prompt).images[0]
302
+ try:
303
+ image = pipe(prompt=prompt,
304
+ width=block_width,
305
+ height=block_height,
306
+ # denoising_end=high_noise_frac,
307
+ # output_type='latent',
308
+ # height=512,
309
+ # width=512,
310
+ num_inference_steps=50).images[0] # normally uses 50 steps
311
+ except Exception as e:
312
+ # retry after removing stopwords and halving the prompt
313
+ # todo try prompt permutations
314
+ logger.info(f"trying to shorten prompt of length {len(prompt)}")
315
+
316
+ prompt = ' '.join(word for word in prompt.split() if word not in stopwords)  # split first: iterating the raw string yields characters, not words
317
+ prompts = prompt.split()
318
+
319
+ prompt = ' '.join(prompts[:len(prompts) // 2])
320
+ logger.info(f"shortened prompt to: {len(prompt)}")
321
+ image = None
322
+ if prompt:
323
+ try:
324
+ image = pipe(prompt=prompt,
325
+ width=block_width,
326
+ height=block_height,
327
+ # denoising_end=high_noise_frac,
328
+ # output_type='latent',
329
+ # height=512,
330
+ # width=512,
331
+ num_inference_steps=50).images[0] # normally uses 50 steps
332
+ except Exception as e:
333
+ # logger.info("trying to permute prompt")
334
+ # # try two swaps of the prompt/permutations
335
+ # prompt = prompt.split()
336
+ # prompt = ' '.join(permutations(prompt, 2).__next__())
337
+ logger.info(f"trying to shorten prompt of length {len(prompt)}")
338
+
339
+ prompt = ' '.join(word for word in prompt.split() if word not in stopwords)
340
+ prompts = prompt.split()
341
+
342
+ prompt = ' '.join(prompts[:len(prompts) // 2])
343
+ logger.info(f"shortened prompt to: {len(prompt)}")
344
+
345
+ try:
346
+ image = pipe(prompt=prompt,
347
+ width=block_width,
348
+ height=block_height,
349
+ # denoising_end=high_noise_frac,
350
+ # output_type='latent', # dont need latent yet - we refine the image at full res
351
+ # height=512,
352
+ # width=512,
353
+ num_inference_steps=50).images[0] # normally uses 50 steps
354
+ except Exception as e:
355
+ # just error out
356
+ traceback.print_exc()
357
+ raise e
358
+ # logger.info("restarting server to fix cuda issues (device side asserts)")
359
+ # todo: fix the device-side asserts instead of restarting
360
+ # todo: only restart the correct gunicorn
361
+ # this could be really annoying if you're running other gunicorns on the machine, since they get restarted too
362
+ # os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
363
+ # os.system("kill -1 `pgrep gunicorn`")
364
+ # todo refine
365
+ # if image != None:
366
+ # image = refiner(
367
+ # prompt=prompt,
368
+ # # width=block_width,
369
+ # # height=block_height,
370
+ # num_inference_steps=n_steps,
371
+ # # denoising_start=high_noise_frac,
372
+ # image=image,
373
+ # ).images[0]
374
+ if width != block_width or height != block_height:
375
+ # resize to original size width/height
376
+ # find aspect ratio to scale up to that covers the original img input width/height
377
+ scale_up_ratio = max(width / block_width, height / block_height)
378
+ image = image.resize((math.ceil(block_width * scale_up_ratio), math.ceil(block_height * scale_up_ratio)))  # scale both block dimensions to keep the aspect ratio
379
+ # crop image to original size
380
+ image = image.crop((0, 0, width, height))
381
+ # try:
382
+ # # gc.collect()
383
+ # torch.cuda.empty_cache()
384
+ # except Exception as e:
385
+ # traceback.print_exc()
386
+ # logger.info("restarting server to fix cuda issues (device side asserts)")
387
+ # # todo: fix the device-side asserts instead of restarting
388
+ # # todo: only restart the correct gunicorn
389
+ # # this could be really annoying if you're running other gunicorns on the machine, since they get restarted too
390
+ # os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
391
+ # os.system("kill -1 `pgrep gunicorn`")
392
+ # save as bytesio
393
+ bs = BytesIO()
394
+
395
+ bright_count = np.sum(np.array(image) > 0)
396
+ if bright_count == 0:
397
+ # an all-black image usually means a CUDA device-side assert fired; restart the worker
398
+ logger.info("restarting server to fix cuda issues (device side asserts)")
399
+ # # todo: fix the device-side asserts instead of restarting
400
+ # # todo: only restart the correct gunicorn
401
+ # # this could be really annoying if you're running other gunicorns on the machine, since they get restarted too
402
+ os.system("kill -HUP `pgrep gunicorn`")  # SIGHUP so supervisor/gunicorn restarts the workers
403
+ os.system("kill -1 `pgrep gunicorn`")
404
+ os.system("kill -HUP `pgrep uvicorn`")
405
+ os.system("kill -1 `pgrep uvicorn`")
406
+
407
+ return None
408
+ image.save(bs, quality=85, optimize=True, format="webp")
409
+ bio = bs.getvalue()
410
+ # touch the progress.txt file - if we don't, supervisor (or another watchdog) restarts us for reliability
411
+ with open("progress.txt", "w") as f:
412
+ current_time = datetime.now().strftime("%H:%M:%S")
413
+ f.write(f"{current_time}")
414
+ return bio
415
+
416
+ def inpaint_image_from_prompt(prompt, image_url: str, mask_url: str):
417
+ prompt = shorten_too_long_text(prompt)
418
+ # image = pipe(prompt=prompt).images[0]
419
+
420
+ init_image = load_image(image_url).convert("RGB")
421
+ mask_image = load_image(mask_url).convert("RGB") # why rgb for a 1 channel mask?
422
+ num_inference_steps = 75
423
+ high_noise_frac = 0.7
424
+
425
+ try:
426
+ image = inpaintpipe(
427
+ prompt=prompt,
428
+ image=init_image,
429
+ mask_image=mask_image,
430
+ num_inference_steps=num_inference_steps,
431
+ denoising_end=high_noise_frac,  # the base covers the high-noise steps; the refiner resumes from this fraction
432
+ output_type="latent",
433
+ ).images[0] # normally uses 50 steps
434
+ except Exception as e:
435
+ # retry after removing stopwords and halving the prompt
436
+ # todo try prompt permutations
437
+ logger.info(f"trying to shorten prompt of length {len(prompt)}")
438
+
439
+ prompt = ' '.join(word for word in prompt.split() if word not in stopwords)
440
+ prompts = prompt.split()
441
+
442
+ prompt = ' '.join(prompts[:len(prompts) // 2])
443
+ logger.info(f"shortened prompt to: {len(prompt)}")
444
+ image = None
445
+ if prompt:
446
+ try:
447
+ image = inpaintpipe(  # must be the inpaint pipeline; the plain text2img pipe doesn't accept mask_image
448
+ prompt=prompt,
449
+ image=init_image,
450
+ mask_image=mask_image,
451
+ num_inference_steps=num_inference_steps,
452
+ denoising_end=high_noise_frac,
453
+ output_type="latent",
454
+ ).images[0] # normally uses 50 steps
455
+ except Exception as e:
456
+ # logger.info("trying to permute prompt")
457
+ # # try two swaps of the prompt/permutations
458
+ # prompt = prompt.split()
459
+ # prompt = ' '.join(permutations(prompt, 2).__next__())
460
+ logger.info(f"trying to shorten prompt of length {len(prompt)}")
461
+
462
+ prompt = ' '.join(word for word in prompt.split() if word not in stopwords)
463
+ prompts = prompt.split()
464
+
465
+ prompt = ' '.join(prompts[:len(prompts) // 2])
466
+ logger.info(f"shortened prompt to: {len(prompt)}")
467
+
468
+ try:
469
+ image = inpaintpipe(
470
+ prompt=prompt,
471
+ image=init_image,
472
+ mask_image=mask_image,
473
+ num_inference_steps=num_inference_steps,
474
+ denoising_end=high_noise_frac,
475
+ output_type="latent",
476
+ ).images[0] # normally uses 50 steps
477
+ except Exception as e:
478
+ # just error out
479
+ traceback.print_exc()
480
+ raise e
481
+ # logger.info("restarting server to fix cuda issues (device side asserts)")
482
+ # todo: fix the device-side asserts instead of restarting
483
+ # todo: only restart the correct gunicorn
484
+ # this could be really annoying if you're running other gunicorns on the machine, since they get restarted too
485
+ # os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
486
+ # os.system("kill -1 `pgrep gunicorn`")
487
+ if image is not None:
488
+ image = inpaint_refiner(
489
+ prompt=prompt,
490
+ image=image,
491
+ mask_image=mask_image,
492
+ num_inference_steps=num_inference_steps,
493
+ denoising_start=high_noise_frac,
494
+
495
+ ).images[0]
496
+ # try:
497
+ # # gc.collect()
498
+ # torch.cuda.empty_cache()
499
+ # except Exception as e:
500
+ # traceback.print_exc()
501
+ # logger.info("restarting server to fix cuda issues (device side asserts)")
502
+ # # todo: fix the device-side asserts instead of restarting
503
+ # # todo: only restart the correct gunicorn
504
+ # # this could be really annoying if you're running other gunicorns on the machine, since they get restarted too
505
+ # os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
506
+ # os.system("kill -1 `pgrep gunicorn`")
507
+ # save as bytesio
508
+ bs = BytesIO()
509
+
510
+ bright_count = np.sum(np.array(image) > 0)
511
+ if bright_count == 0:
512
+ # an all-black image usually means a CUDA device-side assert fired; restart the worker
513
+ logger.info("restarting server to fix cuda issues (device side asserts)")
514
+ # # todo: fix the device-side asserts instead of restarting
515
+ # # todo: only restart the correct gunicorn
516
+ # # this could be really annoying if you're running other gunicorns on the machine, since they get restarted too
517
+ os.system("kill -HUP `pgrep gunicorn`")  # SIGHUP so supervisor/gunicorn restarts the workers
518
+ os.system("kill -1 `pgrep gunicorn`")
519
+ os.system("kill -HUP `pgrep uvicorn`")
520
+ os.system("kill -1 `pgrep uvicorn`")
521
+
522
+ return None
523
+ image.save(bs, quality=85, optimize=True, format="webp")
524
+ bio = bs.getvalue()
525
+ # touch the progress.txt file - if we don't, supervisor (or another watchdog) restarts us for reliability
526
+ with open("progress.txt", "w") as f:
527
+ current_time = datetime.now().strftime("%H:%M:%S")
528
+ f.write(f"{current_time}")
529
+ return bio
530
+
531
+
532
+
533
+ def shorten_too_long_text(prompt):
534
+ if len(prompt) > 200:
535
+ # remove stopwords
536
+ prompt = prompt.split() # todo also split hyphens
537
+ prompt = ' '.join((word for word in prompt if word not in stopwords))
538
+ if len(prompt) > 200:
539
+ prompt = prompt[:200]
540
+ return prompt
541
+
542
+ # image = pipe(prompt=prompt).images[0]
543
+ #
544
+ # image.save("test.png")
545
+ # # save all images
546
+ # for i, image in enumerate(images):
547
+ # image.save(f"{i}.png")
548
+
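The inpaint path in these files hands latents from the base pipeline to the refiner (output_type="latent" on the base, denoising_start on the refiner); the same split for plain text-to-image is present but commented out under "todo refine". A minimal sketch of that two-stage handoff, assuming pipe and refiner are the SDXL base and refiner pipelines loaded as above and prompt is defined:

n_steps = 40
high_noise_frac = 0.8

# the base covers the first high_noise_frac of the schedule and returns latents
latents = pipe(
    prompt=prompt,
    num_inference_steps=n_steps,
    denoising_end=high_noise_frac,
    output_type="latent",
).images

# the refiner resumes at the same fraction and decodes the final image
image = refiner(
    prompt=prompt,
    num_inference_steps=n_steps,
    denoising_start=high_noise_frac,
    image=latents,
).images[0]

This is also why the base inpaint calls take denoising_end rather than denoising_start: denoising_start tells a pipeline its input is already partially denoised latents, which is only true at the refiner stage.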
img/stable-diffusion-server/main_v3.py ADDED
@@ -0,0 +1,578 @@
1
+ import gc
2
+ import math
3
+ import multiprocessing
4
+ import os
5
+ import traceback
6
+ from datetime import datetime
7
+ from io import BytesIO
8
+ from itertools import permutations
9
+ from multiprocessing.pool import Pool
10
+ from pathlib import Path
11
+ from urllib.parse import quote_plus
12
+
13
+ import numpy as np
14
+ import nltk
15
+ import torch
16
+
17
+ from PIL.Image import Image
18
+ from diffusers import DiffusionPipeline, StableDiffusionXLInpaintPipeline
19
+ from diffusers.utils import load_image
20
+ from fastapi import FastAPI
21
+ from fastapi.middleware.gzip import GZipMiddleware
22
+ from loguru import logger
23
+ from starlette.middleware.cors import CORSMiddleware
24
+ from starlette.responses import FileResponse
25
+ from starlette.responses import JSONResponse
26
+
27
+ from env import BUCKET_PATH, BUCKET_NAME
28
+ from stable_diffusion_server.bucket_api import check_if_blob_exists, upload_to_bucket  # used by the *_and_upload_image endpoints below
29
+ torch._dynamo.config.suppress_errors = True
30
+
31
+ import string
32
+ import random
33
+
34
+ def generate_save_path():
35
+ # initializing size of string
36
+ N = 7
37
+
38
+ # using random.choices()
39
+ # generating random strings
40
+ res = ''.join(random.choices(string.ascii_uppercase +
41
+ string.digits, k=N))
42
+ return res
43
+
44
+ # pipe = DiffusionPipeline.from_pretrained(
45
+ # "models/stable-diffusion-xl-base-1.0",
46
+ # torch_dtype=torch.bfloat16,
47
+ # use_safetensors=True,
48
+ # variant="fp16",
49
+ # # safety_checker=None,
50
+ # ) # todo try torch_dtype=bfloat16
51
+
52
+ model_dir = os.getenv("SDXL_MODEL_DIR")
53
+
54
+ if model_dir:
55
+ # Use local model
56
+ model_key_base = os.path.join(model_dir, "stable-diffusion-xl-base-1.0")
57
+ model_key_refiner = os.path.join(model_dir, "stable-diffusion-xl-refiner-1.0")
58
+ else:
59
+ model_key_base = "stabilityai/stable-diffusion-xl-base-1.0"
60
+ model_key_refiner = "stabilityai/stable-diffusion-xl-refiner-1.0"
61
+
62
+ pipe = DiffusionPipeline.from_pretrained(model_key_base, torch_dtype=torch.float16, use_safetensors=True, variant="fp16")
63
+
64
+ pipe.watermark = None
65
+
66
+ pipe.to("cuda")
67
+
68
+ refiner = DiffusionPipeline.from_pretrained(
69
+ "stabilityai/stable-diffusion-xl-refiner-1.0",
70
+ text_encoder_2=pipe.text_encoder_2,
71
+ vae=pipe.vae,
72
+ torch_dtype=torch.bfloat16, # safer to use bfloat?
73
+ use_safetensors=True,
74
+ variant="fp16", #remember not to download the big model
75
+ )
76
+ refiner.watermark = None
77
+ refiner.to("cuda")
78
+
79
+ # {'scheduler', 'text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'unet', 'vae'} can be passed in from existing model
80
+ inpaintpipe = StableDiffusionXLInpaintPipeline.from_pretrained(
81
+ "models/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16, variant="fp16", use_safetensors=True,
82
+ scheduler=pipe.scheduler,
83
+ text_encoder=pipe.text_encoder,
84
+ text_encoder_2=pipe.text_encoder_2,
85
+ tokenizer=pipe.tokenizer,
86
+ tokenizer_2=pipe.tokenizer_2,
87
+ unet=pipe.unet,
88
+ vae=pipe.vae,
89
+ # load_connected_pipeline=
90
+ )
91
+ # # switch out to save gpu mem
92
+ # del inpaintpipe.vae
93
+ # del inpaintpipe.text_encoder_2
94
+ # del inpaintpipe.text_encoder
95
+ # del inpaintpipe.scheduler
96
+ # del inpaintpipe.tokenizer
97
+ # del inpaintpipe.tokenizer_2
98
+ # del inpaintpipe.unet
99
+ # inpaintpipe.vae = pipe.vae
100
+ # inpaintpipe.text_encoder_2 = pipe.text_encoder_2
101
+ # inpaintpipe.text_encoder = pipe.text_encoder
102
+ # inpaintpipe.scheduler = pipe.scheduler
103
+ # inpaintpipe.tokenizer = pipe.tokenizer
104
+ # inpaintpipe.tokenizer_2 = pipe.tokenizer_2
105
+ # inpaintpipe.unet = pipe.unet
106
+ # todo this should work
107
+ # inpaintpipe = StableDiffusionXLInpaintPipeline( # construct an inpainter using the existing model
108
+ # vae=pipe.vae,
109
+ # text_encoder_2=pipe.text_encoder_2,
110
+ # text_encoder=pipe.text_encoder,
111
+ # unet=pipe.unet,
112
+ # scheduler=pipe.scheduler,
113
+ # tokenizer=pipe.tokenizer,
114
+ # tokenizer_2=pipe.tokenizer_2,
115
+ # requires_aesthetics_score=False,
116
+ # )
117
+ inpaintpipe.to("cuda")
118
+ inpaintpipe.watermark = None
119
+ # inpaintpipe.register_to_config(requires_aesthetics_score=False)
120
+
121
+ inpaint_refiner = StableDiffusionXLInpaintPipeline.from_pretrained(
122
+ "stabilityai/stable-diffusion-xl-refiner-1.0",
123
+ text_encoder_2=inpaintpipe.text_encoder_2,
124
+ vae=inpaintpipe.vae,
125
+ torch_dtype=torch.bfloat16,
126
+ use_safetensors=True,
127
+ variant="fp16",
128
+
129
+ tokenizer_2=refiner.tokenizer_2,
130
+ tokenizer=refiner.tokenizer,
131
+ scheduler=refiner.scheduler,
132
+ text_encoder=refiner.text_encoder,
133
+ unet=refiner.unet,
134
+ )
135
+ # del inpaint_refiner.vae
136
+ # del inpaint_refiner.text_encoder_2
137
+ # del inpaint_refiner.text_encoder
138
+ # del inpaint_refiner.scheduler
139
+ # del inpaint_refiner.tokenizer
140
+ # del inpaint_refiner.tokenizer_2
141
+ # del inpaint_refiner.unet
142
+ # inpaint_refiner.vae = inpaintpipe.vae
143
+ # inpaint_refiner.text_encoder_2 = inpaintpipe.text_encoder_2
144
+ #
145
+ # inpaint_refiner.text_encoder = refiner.text_encoder
146
+ # inpaint_refiner.scheduler = refiner.scheduler
147
+ # inpaint_refiner.tokenizer = refiner.tokenizer
148
+ # inpaint_refiner.tokenizer_2 = refiner.tokenizer_2
149
+ # inpaint_refiner.unet = refiner.unet
150
+
151
+ # inpaint_refiner = StableDiffusionXLInpaintPipeline(
152
+ # text_encoder_2=inpaintpipe.text_encoder_2,
153
+ # vae=inpaintpipe.vae,
154
+ # # the rest from the existing refiner
155
+ # tokenizer_2=refiner.tokenizer_2,
156
+ # tokenizer=refiner.tokenizer,
157
+ # scheduler=refiner.scheduler,
158
+ # text_encoder=refiner.text_encoder,
159
+ # unet=refiner.unet,
160
+ # requires_aesthetics_score=False,
161
+ # )
162
+ inpaint_refiner.to("cuda")
163
+ inpaint_refiner.watermark = None
164
+ # inpaint_refiner.register_to_config(requires_aesthetics_score=False)
165
+
166
+ n_steps = 40
167
+ high_noise_frac = 0.8
168
+
169
+ # if using torch < 2.0
170
+ # pipe.enable_xformers_memory_efficient_attention()
171
+
172
+
173
+ # pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
174
+ # this can cause errors on some inputs so consider disabling it
175
+ pipe.unet = torch.compile(pipe.unet)
176
+ refiner.unet = torch.compile(refiner.unet)#, mode="reduce-overhead", fullgraph=True)
177
+ # compile the inpainters - todo reuse the other unets? swap out the models for others/del them so they share models and can be swapped efficiently
178
+ inpaintpipe.unet = pipe.unet
179
+ inpaint_refiner.unet = refiner.unet
180
+ # inpaintpipe.unet = torch.compile(inpaintpipe.unet)
181
+ # inpaint_refiner.unet = torch.compile(inpaint_refiner.unet)
182
+ from pydantic import BaseModel
183
+
184
+ app = FastAPI(
185
+ openapi_url="/static/openapi.json",
186
+ docs_url="/swagger-docs",
187
+ redoc_url="/redoc",
188
+ title="Generate Images Netwrck API",
189
+ description="Character Chat API",
190
+ # root_path="https://api.text-generator.io",
191
+ version="1",
192
+ )
193
+ app.add_middleware(GZipMiddleware, minimum_size=1000)
194
+ app.add_middleware(
195
+ CORSMiddleware,
196
+ allow_origins=["*"],
197
+ allow_credentials=True,
198
+ allow_methods=["*"],
199
+ allow_headers=["*"],
200
+ )
201
+
202
+ stopwords = nltk.corpus.stopwords.words("english")
203
+
204
+ class Img(BaseModel):
205
+ system_prompt: str
206
+ ASSISTANT: str
207
+
208
+ # img_url = "http://phlrr2019.guest.corp.microsoft.com:8000/img1_sdv2.1.png"
209
+ img_url = "http://phlrr3105.guest.corp.microsoft.com:8000/"#/img1_sdv2.1.png"
210
+
211
+ is_gpu_busy = False
212
+
213
+ def get_summary(system_prompt, prompt):
214
+ import requests
215
+ import time
216
+ from io import BytesIO
217
+ import json
218
+ summary_sys = """I want you to act as a text summarizer to help me create a concise summary of the text I provide. The summary can be up to 60.0 words in length, expressing the key points, key scenarios, main character and concepts written in the original text without adding your interpretations."""
219
+ instruction = summary_sys
220
+ # for human, assistant in history:
221
+ # instruction += 'USER: ' + human + ' ASSISTANT: ' + assistant + '</s>'
222
+ # prompt = system_prompt + prompt
223
+ message = f"""My first request is to summarize this text – [{prompt}]"""
224
+ instruction += ' USER: ' + message + ' ASSISTANT:'
225
+
226
+ print("Ins: ", instruction)
227
+ # generate_response = requests.post("http://10.185.12.207:4455/stable_diffusion", json={"prompt": prompt})
228
+ # prompt = f""" My first request is to summarize this text – [{prompt}]"""
229
+ json_object = {"prompt": instruction,
230
+ # "max_tokens": 2048000,
231
+ "max_tokens": 90,
232
+ "n": 1
233
+ }
234
+ generate_response = requests.post("http://phlrr3105.guest.corp.microsoft.com:7991/generate", json=json_object)
235
+ # print(generate_response.content)
236
+ res_json = json.loads(generate_response.content)
237
+ ASSISTANT = res_json['text'][-1].split("ASSISTANT:")[-1].strip()
238
+ print(ASSISTANT)
239
+ return ASSISTANT
240
+
+ @app.post("/image_url")
+ def image_url(img: Img):
+     system_prompt = img.system_prompt
+     prompt = img.ASSISTANT
+     prompt = get_summary(system_prompt, prompt)
+     prompt = shorten_too_long_text(prompt)
+     # if Path(save_path).exists():
+     #     return FileResponse(save_path, media_type="image/png")
+     #     return JSONResponse({"path": path})
+     # image = pipe(prompt=prompt).images[0]
+     g = torch.Generator(device="cuda")
+     image = pipe(prompt=prompt, width=1024, height=1024, generator=g).images[0]
+
+     # if not save_path:
+     save_path = generate_save_path()
+     save_path = f"images/{save_path}.png"
+     image.save(save_path)
+     # save_path = '/'.join(path_components) + quote_plus(final_name)
+     path = f"{img_url}{save_path}"  # img_url already ends with a slash
+     return JSONResponse({"path": path})
+
+
+ @app.get("/make_image")
+ # @app.post("/make_image")
+ def make_image(prompt: str, save_path: str = ""):
+     if Path(save_path).exists():
+         return FileResponse(save_path, media_type="image/png")
+     image = pipe(prompt=prompt).images[0]
+     if not save_path:
+         save_path = f"images/{prompt}.png"
+     image.save(save_path)
+     return FileResponse(save_path, media_type="image/png")
+
+
+ @app.get("/create_and_upload_image")
+ def create_and_upload_image(prompt: str, width: int = 1024, height: int = 1024, save_path: str = ""):
+     path_components = save_path.split("/")[0:-1]
+     final_name = save_path.split("/")[-1]
+     if not path_components:
+         path_components = []
+     save_path = '/'.join(path_components) + quote_plus(final_name)
+     path = get_image_or_create_upload_to_cloud_storage(prompt, width, height, save_path)
+     return JSONResponse({"path": path})
+
+ @app.get("/inpaint_and_upload_image")
+ def inpaint_and_upload_image(prompt: str, image_url: str, mask_url: str, save_path: str = ""):
+     path_components = save_path.split("/")[0:-1]
+     final_name = save_path.split("/")[-1]
+     if not path_components:
+         path_components = []
+     save_path = '/'.join(path_components) + quote_plus(final_name)
+     path = get_image_or_inpaint_upload_to_cloud_storage(prompt, image_url, mask_url, save_path)
+     return JSONResponse({"path": path})
+
+
+ def get_image_or_create_upload_to_cloud_storage(prompt: str, width: int, height: int, save_path: str):
+     prompt = shorten_too_long_text(prompt)
+     save_path = shorten_too_long_text(save_path)
+     # check exists - todo cache this
+     if check_if_blob_exists(save_path):
+         return f"https://{BUCKET_NAME}/{BUCKET_PATH}/{save_path}"
+     bio = create_image_from_prompt(prompt, width, height)
+     if bio is None:
+         return None  # error thrown in pool
+     link = upload_to_bucket(save_path, bio, is_bytesio=True)
+     return link
+
+ def get_image_or_inpaint_upload_to_cloud_storage(prompt: str, image_url: str, mask_url: str, save_path: str):
+     prompt = shorten_too_long_text(prompt)
+     save_path = shorten_too_long_text(save_path)
+     # check exists - todo cache this
+     if check_if_blob_exists(save_path):
+         return f"https://{BUCKET_NAME}/{BUCKET_PATH}/{save_path}"
+     bio = inpaint_image_from_prompt(prompt, image_url, mask_url)
+     if bio is None:
+         return None  # error thrown in pool
+     link = upload_to_bucket(save_path, bio, is_bytesio=True)
+     return link
+
+ # multiprocessing.set_start_method('spawn', True)
+ # processes_pool = Pool(1)  # can't do too much at once or OOM errors happen
+ # def create_image_from_prompt_sync(prompt):
+ #     """have to call this sync to avoid OOM errors"""
+ #     return processes_pool.apply_async(create_image_from_prompt, args=(prompt,), ).wait()
+
+ def create_image_from_prompt(prompt, width, height):
+     # round width and height down to a multiple of 64
+     block_width = width - (width % 64)
+     block_height = height - (height % 64)
+     prompt = shorten_too_long_text(prompt)
+     # image = pipe(prompt=prompt).images[0]
+     try:
+         image = pipe(prompt=prompt,
+                      width=block_width,
+                      height=block_height,
+                      # denoising_end=high_noise_frac,
+                      # output_type='latent',
+                      # height=512,
+                      # width=512,
+                      num_inference_steps=50).images[0]  # normally uses 50 steps
+     except Exception as e:
+         # try removing stopwords + halving the prompt
+         # todo try prompt permutations
+         logger.info(f"trying to shorten prompt of length {len(prompt)}")
+
+         prompt = ' '.join(word for word in prompt.split() if word not in stopwords)  # split first: iterating a str yields characters, not words
+         prompts = prompt.split()
+
+         prompt = ' '.join(prompts[:len(prompts) // 2])
+         logger.info(f"shortened prompt to: {len(prompt)}")
+         image = None
+         if prompt:
+             try:
+                 image = pipe(prompt=prompt,
+                              width=block_width,
+                              height=block_height,
+                              # denoising_end=high_noise_frac,
+                              # output_type='latent',
+                              # height=512,
+                              # width=512,
+                              num_inference_steps=50).images[0]  # normally uses 50 steps
+             except Exception as e:
+                 # logger.info("trying to permute prompt")
+                 # # try two swaps of the prompt/permutations
+                 # prompt = prompt.split()
+                 # prompt = ' '.join(permutations(prompt, 2).__next__())
+                 logger.info(f"trying to shorten prompt of length {len(prompt)}")
+
+                 prompt = ' '.join(word for word in prompt.split() if word not in stopwords)
+                 prompts = prompt.split()
+
+                 prompt = ' '.join(prompts[:len(prompts) // 2])
+                 logger.info(f"shortened prompt to: {len(prompt)}")
+
+                 try:
+                     image = pipe(prompt=prompt,
+                                  width=block_width,
+                                  height=block_height,
+                                  # denoising_end=high_noise_frac,
+                                  # output_type='latent',  # don't need latent yet - we refine the image at full res
+                                  # height=512,
+                                  # width=512,
+                                  num_inference_steps=50).images[0]  # normally uses 50 steps
+                 except Exception as e:
+                     # just error out
+                     traceback.print_exc()
+                     raise e
+                     # logger.info("restarting server to fix cuda issues (device side asserts)")
+                     # todo fix device side asserts instead of restart to fix
+                     # todo only restart the correct gunicorn
+                     # this could be really annoying if you're running other gunicorns on your machine which also get restarted
+                     # os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
+                     # os.system("kill -1 `pgrep gunicorn`")
+     # todo refine
+     # if image is not None:
+     #     image = refiner(
+     #         prompt=prompt,
+     #         # width=block_width,
+     #         # height=block_height,
+     #         num_inference_steps=n_steps,
+     #         # denoising_start=high_noise_frac,
+     #         image=image,
+     #     ).images[0]
+     if image is None:
+         return None  # guard: shortening may have emptied the prompt and left no image
+     if width != block_width or height != block_height:
+         # resize to original size width/height
+         # find aspect ratio to scale up to that covers the original img input width/height
+         scale_up_ratio = max(width / block_width, height / block_height)
+         image = image.resize((math.ceil(block_width * scale_up_ratio), math.ceil(block_height * scale_up_ratio)))  # was height * ratio, which distorted the aspect ratio
+         # crop image to original size
+         image = image.crop((0, 0, width, height))
+     # try:
+     #     # gc.collect()
+     #     torch.cuda.empty_cache()
+     # except Exception as e:
+     #     traceback.print_exc()
+     #     logger.info("restarting server to fix cuda issues (device side asserts)")
+     #     # todo fix device side asserts instead of restart to fix
+     #     # todo only restart the correct gunicorn
+     #     # this could be really annoying if you're running other gunicorns on your machine which also get restarted
+     #     os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
+     #     os.system("kill -1 `pgrep gunicorn`")
+     # save as bytesio
+     bs = BytesIO()
+
+     bright_count = np.sum(np.array(image) > 0)
+     if bright_count == 0:
+         # we have a black image, this is an error, likely we need a restart
+         logger.info("restarting server to fix cuda issues (device side asserts)")
+         # todo fix device side asserts instead of restart to fix
+         # todo only restart the correct gunicorn
+         # this could be really annoying if you're running other gunicorns on your machine which also get restarted
+         os.system("/usr/bin/bash -c 'kill -SIGHUP $(pgrep gunicorn)'")  # was "/usr/bin/bash kill ...", which tried to run kill as a bash script
+         os.system("kill -1 `pgrep gunicorn`")
+         os.system("/usr/bin/bash -c 'kill -SIGHUP $(pgrep uvicorn)'")
+         os.system("kill -1 `pgrep uvicorn`")
+
+         return None
+     image.save(bs, quality=85, optimize=True, format="webp")
+     bio = bs.getvalue()
+     # touch progress.txt file - if we don't do this we get restarted by supervisor/other processes for reliability
+     with open("progress.txt", "w") as f:
+         current_time = datetime.now().strftime("%H:%M:%S")
+         f.write(f"{current_time}")
+     return bio
+
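The width/height handling in `create_image_from_prompt` above renders at the nearest lower multiple of 64, then scales up and crops back to the requested size. A standalone sketch of just that arithmetic (using the corrected `block_height` scaling), assuming a Pillow image already rendered at the block size:

```python
import math
from PIL import Image

def fit_to_requested_size(rendered: Image.Image, width: int, height: int) -> Image.Image:
    """Scale a block-sized render up until it covers the requested size, then crop."""
    block_width = width - (width % 64)     # e.g. 1000 -> 960
    block_height = height - (height % 64)  # e.g. 700  -> 640
    if width != block_width or height != block_height:
        scale_up_ratio = max(width / block_width, height / block_height)  # 700/640 ~= 1.094
        rendered = rendered.resize((math.ceil(block_width * scale_up_ratio),
                                    math.ceil(block_height * scale_up_ratio)))  # ~= 1050x700
        rendered = rendered.crop((0, 0, width, height))  # exact 1000x700
    return rendered
```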
+ def inpaint_image_from_prompt(prompt, image_url: str, mask_url: str):
+     prompt = shorten_too_long_text(prompt)
+     # image = pipe(prompt=prompt).images[0]
+
+     init_image = load_image(image_url).convert("RGB")
+     mask_image = load_image(mask_url).convert("RGB")  # why rgb for a 1 channel mask?
+     num_inference_steps = 75
+     high_noise_frac = 0.7
+
+     try:
+         image = inpaintpipe(
+             prompt=prompt,
+             image=init_image,
+             mask_image=mask_image,
+             num_inference_steps=num_inference_steps,
+             denoising_start=high_noise_frac,
+             output_type="latent",
+         ).images[0]  # normally uses 50 steps
+     except Exception as e:
+         # try removing stopwords + halving the prompt
+         # todo try prompt permutations
+         logger.info(f"trying to shorten prompt of length {len(prompt)}")
+
+         prompt = ' '.join(word for word in prompt.split() if word not in stopwords)  # split first: iterating a str yields characters, not words
+         prompts = prompt.split()
+
+         prompt = ' '.join(prompts[:len(prompts) // 2])
+         logger.info(f"shortened prompt to: {len(prompt)}")
+         image = None
+         if prompt:
+             try:
+                 image = inpaintpipe(  # was pipe, the text2img pipeline, which does not take image/mask_image
+                     prompt=prompt,
+                     image=init_image,
+                     mask_image=mask_image,
+                     num_inference_steps=num_inference_steps,
+                     denoising_start=high_noise_frac,
+                     output_type="latent",
+                 ).images[0]  # normally uses 50 steps
+             except Exception as e:
+                 # logger.info("trying to permute prompt")
+                 # # try two swaps of the prompt/permutations
+                 # prompt = prompt.split()
+                 # prompt = ' '.join(permutations(prompt, 2).__next__())
+                 logger.info(f"trying to shorten prompt of length {len(prompt)}")
+
+                 prompt = ' '.join(word for word in prompt.split() if word not in stopwords)
+                 prompts = prompt.split()
+
+                 prompt = ' '.join(prompts[:len(prompts) // 2])
+                 logger.info(f"shortened prompt to: {len(prompt)}")
+
+                 try:
+                     image = inpaintpipe(
+                         prompt=prompt,
+                         image=init_image,
+                         mask_image=mask_image,
+                         num_inference_steps=num_inference_steps,
+                         denoising_start=high_noise_frac,
+                         output_type="latent",
+                     ).images[0]  # normally uses 50 steps
+                 except Exception as e:
+                     # just error out
+                     traceback.print_exc()
+                     raise e
+                     # logger.info("restarting server to fix cuda issues (device side asserts)")
+                     # todo fix device side asserts instead of restart to fix
+                     # todo only restart the correct gunicorn
+                     # this could be really annoying if you're running other gunicorns on your machine which also get restarted
+                     # os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
+                     # os.system("kill -1 `pgrep gunicorn`")
+     if image is not None:
+         image = inpaint_refiner(
+             prompt=prompt,
+             image=image,
+             mask_image=mask_image,
+             num_inference_steps=num_inference_steps,
+             denoising_start=high_noise_frac,
+         ).images[0]
+     else:
+         return None  # guard: shortening may have emptied the prompt and left no image
+     # try:
+     #     # gc.collect()
+     #     torch.cuda.empty_cache()
+     # except Exception as e:
+     #     traceback.print_exc()
+     #     logger.info("restarting server to fix cuda issues (device side asserts)")
+     #     # todo fix device side asserts instead of restart to fix
+     #     # todo only restart the correct gunicorn
+     #     # this could be really annoying if you're running other gunicorns on your machine which also get restarted
+     #     os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
+     #     os.system("kill -1 `pgrep gunicorn`")
+     # save as bytesio
+     bs = BytesIO()
+
+     bright_count = np.sum(np.array(image) > 0)
+     if bright_count == 0:
+         # we have a black image, this is an error, likely we need a restart
+         logger.info("restarting server to fix cuda issues (device side asserts)")
+         # todo fix device side asserts instead of restart to fix
+         # todo only restart the correct gunicorn
+         # this could be really annoying if you're running other gunicorns on your machine which also get restarted
+         os.system("/usr/bin/bash -c 'kill -SIGHUP $(pgrep gunicorn)'")  # was "/usr/bin/bash kill ...", which tried to run kill as a bash script
+         os.system("kill -1 `pgrep gunicorn`")
+         os.system("/usr/bin/bash -c 'kill -SIGHUP $(pgrep uvicorn)'")
+         os.system("kill -1 `pgrep uvicorn`")
+
+         return None
+     image.save(bs, quality=85, optimize=True, format="webp")
+     bio = bs.getvalue()
+     # touch progress.txt file - if we don't do this we get restarted by supervisor/other processes for reliability
+     with open("progress.txt", "w") as f:
+         current_time = datetime.now().strftime("%H:%M:%S")
+         f.write(f"{current_time}")
+     return bio
+
+
+
+ def shorten_too_long_text(prompt):
+     if len(prompt) > 200:
+         # remove stopwords
+         prompt = prompt.split()  # todo also split hyphens
+         prompt = ' '.join(word for word in prompt if word not in stopwords)
+         if len(prompt) > 200:
+             prompt = prompt[:200]
+     return prompt
+
+ # image = pipe(prompt=prompt).images[0]
+ #
+ # image.save("test.png")
+ # # save all images
+ # for i, image in enumerate(images):
+ #     image.save(f"{i}.png")
+
+
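Both generation paths above touch `progress.txt` as a heartbeat so an external supervisor can detect a wedged GPU process. The supervising side is not included in this file; a minimal watchdog sketch under that assumption (the 600-second threshold is a guess, tune it to your slowest render):

```python
import os
import time

HEARTBEAT = "progress.txt"  # file the server touches after each successful render
STALE_AFTER = 600           # seconds; assumption - set above your longest expected render

while True:
    try:
        age = time.time() - os.path.getmtime(HEARTBEAT)
    except FileNotFoundError:
        age = float("inf")  # no heartbeat yet counts as stale
    if age > STALE_AFTER:
        # same restart strategy the server itself uses when it detects a black image
        os.system("kill -1 `pgrep gunicorn`")
    time.sleep(60)
```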
img/stable-diffusion-server/main_v4.py ADDED
@@ -0,0 +1,603 @@
+ import gc
+ import math
+ import multiprocessing
+ import os
+ import traceback
+ from datetime import datetime
+ from io import BytesIO
+ from itertools import permutations
+ from multiprocessing.pool import Pool
+ from pathlib import Path
+ from urllib.parse import quote_plus
+
+ import numpy as np
+ import nltk
+ import torch
+
+ from diffusers import DiffusionPipeline, StableDiffusionXLInpaintPipeline
+ from diffusers.utils import load_image
+ from fastapi import FastAPI
+ from fastapi.middleware.gzip import GZipMiddleware
+ from loguru import logger
+ from starlette.middleware.cors import CORSMiddleware
+ from starlette.responses import FileResponse
+ from starlette.responses import JSONResponse
+ import requests
+ from PIL import Image  # note: the original also had `from PIL.Image import Image`, which this import shadowed; only the module import is kept
+ import time
+ import json
+ import string
+ import random
+ from env import BUCKET_PATH, BUCKET_NAME
+ # from stable_diffusion_server.bucket_api import check_if_blob_exists, upload_to_bucket  # required by the upload endpoints below
+ torch._dynamo.config.suppress_errors = True
+
+ def generate_save_path():
+     # initializing size of string
+     N = 7
+
+     # using random.choices()
+     # generating random strings
+     res = ''.join(random.choices(string.ascii_uppercase +
+                                  string.digits, k=N))
+     return res
+
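`generate_save_path` draws only 7 characters from a 36-symbol alphabet, so name collisions become plausible at scale. A drop-in alternative sketch using the standard library's `secrets` module (a suggestion, not part of the original code):

```python
import secrets

def generate_save_path() -> str:
    # 16 random bytes -> ~22 URL-safe characters; collisions are practically impossible
    return secrets.token_urlsafe(16)
```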
+ # pipe = DiffusionPipeline.from_pretrained(
+ #     "models/stable-diffusion-xl-base-1.0",
+ #     torch_dtype=torch.bfloat16,
+ #     use_safetensors=True,
+ #     variant="fp16",
+ #     # safety_checker=None,
+ # )  # todo try torch_dtype=bfloat16
+
+ model_dir = os.getenv("SDXL_MODEL_DIR")
+
+ if model_dir:
+     # Use local model
+     model_key_base = os.path.join(model_dir, "stable-diffusion-xl-base-1.0")
+     model_key_refiner = os.path.join(model_dir, "stable-diffusion-xl-refiner-1.0")
+ else:
+     model_key_base = "stabilityai/stable-diffusion-xl-base-1.0"
+     model_key_refiner = "stabilityai/stable-diffusion-xl-refiner-1.0"
+
+ pipe = DiffusionPipeline.from_pretrained(model_key_base, torch_dtype=torch.float16, use_safetensors=True, variant="fp16")
+
+ pipe.watermark = None
+
+ pipe.to("cuda")
+
+ refiner = DiffusionPipeline.from_pretrained(
+     "stabilityai/stable-diffusion-xl-refiner-1.0",
+     text_encoder_2=pipe.text_encoder_2,
+     vae=pipe.vae,
+     torch_dtype=torch.bfloat16,  # safer to use bfloat?
+     use_safetensors=True,
+     variant="fp16",  # remember not to download the big model
+ )
+ refiner.watermark = None
+ refiner.to("cuda")
+
+ # {'scheduler', 'text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'unet', 'vae'} can be passed in from an existing model
+ inpaintpipe = StableDiffusionXLInpaintPipeline.from_pretrained(
+     "models/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16, variant="fp16", use_safetensors=True,
+     scheduler=pipe.scheduler,
+     text_encoder=pipe.text_encoder,
+     text_encoder_2=pipe.text_encoder_2,
+     tokenizer=pipe.tokenizer,
+     tokenizer_2=pipe.tokenizer_2,
+     unet=pipe.unet,
+     vae=pipe.vae,
+     # load_connected_pipeline=
+ )
+ # # switch out to save gpu mem
+ # del inpaintpipe.vae
+ # del inpaintpipe.text_encoder_2
+ # del inpaintpipe.text_encoder
+ # del inpaintpipe.scheduler
+ # del inpaintpipe.tokenizer
+ # del inpaintpipe.tokenizer_2
+ # del inpaintpipe.unet
+ # inpaintpipe.vae = pipe.vae
+ # inpaintpipe.text_encoder_2 = pipe.text_encoder_2
+ # inpaintpipe.text_encoder = pipe.text_encoder
+ # inpaintpipe.scheduler = pipe.scheduler
+ # inpaintpipe.tokenizer = pipe.tokenizer
+ # inpaintpipe.tokenizer_2 = pipe.tokenizer_2
+ # inpaintpipe.unet = pipe.unet
+ # todo this should work
+ # inpaintpipe = StableDiffusionXLInpaintPipeline(  # construct an inpainter using the existing model
+ #     vae=pipe.vae,
+ #     text_encoder_2=pipe.text_encoder_2,
+ #     text_encoder=pipe.text_encoder,
+ #     unet=pipe.unet,
+ #     scheduler=pipe.scheduler,
+ #     tokenizer=pipe.tokenizer,
+ #     tokenizer_2=pipe.tokenizer_2,
+ #     requires_aesthetics_score=False,
+ # )
+ inpaintpipe.to("cuda")
+ inpaintpipe.watermark = None
+ # inpaintpipe.register_to_config(requires_aesthetics_score=False)
+
+ inpaint_refiner = StableDiffusionXLInpaintPipeline.from_pretrained(
+     "stabilityai/stable-diffusion-xl-refiner-1.0",
+     text_encoder_2=inpaintpipe.text_encoder_2,
+     vae=inpaintpipe.vae,
+     torch_dtype=torch.bfloat16,
+     use_safetensors=True,
+     variant="fp16",
+
+     tokenizer_2=refiner.tokenizer_2,
+     tokenizer=refiner.tokenizer,
+     scheduler=refiner.scheduler,
+     text_encoder=refiner.text_encoder,
+     unet=refiner.unet,
+ )
+ # del inpaint_refiner.vae
+ # del inpaint_refiner.text_encoder_2
+ # del inpaint_refiner.text_encoder
+ # del inpaint_refiner.scheduler
+ # del inpaint_refiner.tokenizer
+ # del inpaint_refiner.tokenizer_2
+ # del inpaint_refiner.unet
+ # inpaint_refiner.vae = inpaintpipe.vae
+ # inpaint_refiner.text_encoder_2 = inpaintpipe.text_encoder_2
+ #
+ # inpaint_refiner.text_encoder = refiner.text_encoder
+ # inpaint_refiner.scheduler = refiner.scheduler
+ # inpaint_refiner.tokenizer = refiner.tokenizer
+ # inpaint_refiner.tokenizer_2 = refiner.tokenizer_2
+ # inpaint_refiner.unet = refiner.unet
+
+ # inpaint_refiner = StableDiffusionXLInpaintPipeline(
+ #     text_encoder_2=inpaintpipe.text_encoder_2,
+ #     vae=inpaintpipe.vae,
+ #     # the rest from the existing refiner
+ #     tokenizer_2=refiner.tokenizer_2,
+ #     tokenizer=refiner.tokenizer,
+ #     scheduler=refiner.scheduler,
+ #     text_encoder=refiner.text_encoder,
+ #     unet=refiner.unet,
+ #     requires_aesthetics_score=False,
+ # )
+ inpaint_refiner.to("cuda")
+ inpaint_refiner.watermark = None
+ # inpaint_refiner.register_to_config(requires_aesthetics_score=False)
+
+ n_steps = 40
+ high_noise_frac = 0.8
+
+ # if using torch < 2.0
+ # pipe.enable_xformers_memory_efficient_attention()
+
+
+ # pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+ # this can cause errors on some inputs so consider disabling it
+ pipe.unet = torch.compile(pipe.unet)
+ refiner.unet = torch.compile(refiner.unet)  # , mode="reduce-overhead", fullgraph=True
+ # reuse the compiled unets for the inpainters - todo swap out the models/del them so they share models and can be swapped efficiently
+ inpaintpipe.unet = pipe.unet
+ inpaint_refiner.unet = refiner.unet
+ # inpaintpipe.unet = torch.compile(inpaintpipe.unet)
+ # inpaint_refiner.unet = torch.compile(inpaint_refiner.unet)
+ from pydantic import BaseModel
+
+ app = FastAPI(
+     openapi_url="/static/openapi.json",
+     docs_url="/swagger-docs",
+     redoc_url="/redoc",
+     title="Generate Images Netwrck API",
+     description="Character Chat API",
+     # root_path="https://api.text-generator.io",
+     version="1",
+ )
+ app.add_middleware(GZipMiddleware, minimum_size=1000)
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ stopwords = nltk.corpus.stopwords.words("english")
+
+ class Img(BaseModel):
+     system_prompt: str
+     ASSISTANT: str
+
+ # img_url = "http://phlrr2019.guest.corp.microsoft.com:8000/img1_sdv2.1.png"
+ img_url = "http://phlrr3006.guest.corp.microsoft.com:8000/"  # /img1_sdv2.1.png
+
+ is_gpu_busy = False
+
+ def get_summary(system_prompt, prompt):
+     import requests
+     import time
+     from io import BytesIO
+     import json
+     summary_sys = """I want you to act as a text summarizer to help me create a concise summary of the text I provide. The summary can be up to 60 words in length, expressing the key points, key scenarios, main character and concepts written in the original text without adding your interpretations."""
+     instruction = summary_sys
+     # for human, assistant in history:
+     #     instruction += 'USER: ' + human + ' ASSISTANT: ' + assistant + '</s>'
+     # prompt = system_prompt + prompt
+     message = f"""My first request is to summarize this text – [{prompt}]"""
+     instruction += ' USER: ' + message + ' ASSISTANT:'
+
+     print("Ins: ", instruction)
+     # generate_response = requests.post("http://10.185.12.207:4455/stable_diffusion", json={"prompt": prompt})
+     # prompt = f""" My first request is to summarize this text – [{prompt}]"""
+     json_object = {"prompt": instruction,
+                    # "max_tokens": 2048000,
+                    "max_tokens": 90,
+                    "n": 1
+                    }
+     generate_response = requests.post("http://phlrr3006.guest.corp.microsoft.com:7991/generate", json=json_object)
+     # print(generate_response.content)
+     res_json = json.loads(generate_response.content)
+     ASSISTANT = res_json['text'][-1].split("ASSISTANT:")[-1].strip()
+     print(ASSISTANT)
+     return ASSISTANT
+
+ @app.post("/image_url")
+ def image_url(img: Img):
+     system_prompt = img.system_prompt
+     prompt = img.ASSISTANT
+     prompt = get_summary(system_prompt, prompt)
+     prompt = shorten_too_long_text(prompt)
+
+     json_object = {
+         "prompt": prompt,
+         "height": 1024,
+         "width": 1024,
+         "num_inference_steps": 50,
+         # "guidance_scale": 7.5,
+         "eta": 0
+     }
+     generate_response = requests.post("http://phlrr3105.guest.corp.microsoft.com:3000/text2img", json=json_object)
+     image = generate_response.content
+     # print(generate_response.content)
+     save_path = generate_save_path()
+     save_path = f"images/{save_path}.png"
+     # generate_response.save(save_path)
+     with open(save_path, 'wb') as f:
+         f.write(image)
+     #
+     # # if Path(save_path).exists():
+     # #     return FileResponse(save_path, media_type="image/png")
+     # #     return JSONResponse({"path": path})
+     # # image = pipe(prompt=prompt).images[0]
+     # g = torch.Generator(device="cuda")
+     # image = pipe(prompt=prompt, width=1024, height=1024, generator=g).images[0]
+     #
+     # # if not save_path:
+     # save_path = generate_save_path()
+     # save_path = f"images/{save_path}.png"
+     # image.save(save_path)
+     # save_path = '/'.join(path_components) + quote_plus(final_name)
+     path = f"{img_url}{save_path}"
+     return JSONResponse({"path": path})
+
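The `/image_url` handler above writes whatever bytes the external `text2img` service returns straight to disk; if that service errors, the body may be an error page rather than a PNG. A defensive variant sketch, assuming the service returns raw image bytes on success:

```python
from io import BytesIO
from PIL import Image

def write_image_bytes(data: bytes, save_path: str) -> None:
    """Verify the payload is actually an image before persisting it."""
    Image.open(BytesIO(data)).verify()  # raises if the bytes are not a decodable image
    with open(save_path, "wb") as f:
        f.write(data)
```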
+
+ @app.get("/make_image")
+ # @app.post("/make_image")
+ def make_image(prompt: str, save_path: str = ""):
+     if Path(save_path).exists():
+         return FileResponse(save_path, media_type="image/png")
+     image = pipe(prompt=prompt).images[0]
+     if not save_path:
+         save_path = f"images/{prompt}.png"
+     image.save(save_path)
+     return FileResponse(save_path, media_type="image/png")
+
+
+ @app.get("/create_and_upload_image")
+ def create_and_upload_image(prompt: str, width: int = 1024, height: int = 1024, save_path: str = ""):
+     path_components = save_path.split("/")[0:-1]
+     final_name = save_path.split("/")[-1]
+     if not path_components:
+         path_components = []
+     save_path = '/'.join(path_components) + quote_plus(final_name)
+     path = get_image_or_create_upload_to_cloud_storage(prompt, width, height, save_path)
+     return JSONResponse({"path": path})
+
+ @app.get("/inpaint_and_upload_image")
+ def inpaint_and_upload_image(prompt: str, image_url: str, mask_url: str, save_path: str = ""):
+     path_components = save_path.split("/")[0:-1]
+     final_name = save_path.split("/")[-1]
+     if not path_components:
+         path_components = []
+     save_path = '/'.join(path_components) + quote_plus(final_name)
+     path = get_image_or_inpaint_upload_to_cloud_storage(prompt, image_url, mask_url, save_path)
+     return JSONResponse({"path": path})
+
+
+ def get_image_or_create_upload_to_cloud_storage(prompt: str, width: int, height: int, save_path: str):
+     prompt = shorten_too_long_text(prompt)
+     save_path = shorten_too_long_text(save_path)
+     # check exists - todo cache this
+     if check_if_blob_exists(save_path):
+         return f"https://{BUCKET_NAME}/{BUCKET_PATH}/{save_path}"
+     bio = create_image_from_prompt(prompt, width, height)
+     if bio is None:
+         return None  # error thrown in pool
+     link = upload_to_bucket(save_path, bio, is_bytesio=True)
+     return link
+
+ def get_image_or_inpaint_upload_to_cloud_storage(prompt: str, image_url: str, mask_url: str, save_path: str):
+     prompt = shorten_too_long_text(prompt)
+     save_path = shorten_too_long_text(save_path)
+     # check exists - todo cache this
+     if check_if_blob_exists(save_path):
+         return f"https://{BUCKET_NAME}/{BUCKET_PATH}/{save_path}"
+     bio = inpaint_image_from_prompt(prompt, image_url, mask_url)
+     if bio is None:
+         return None  # error thrown in pool
+     link = upload_to_bucket(save_path, bio, is_bytesio=True)
+     return link
+
+ # multiprocessing.set_start_method('spawn', True)
+ # processes_pool = Pool(1)  # can't do too much at once or OOM errors happen
+ # def create_image_from_prompt_sync(prompt):
+ #     """have to call this sync to avoid OOM errors"""
+ #     return processes_pool.apply_async(create_image_from_prompt, args=(prompt,), ).wait()
+
+ def create_image_from_prompt(prompt, width, height):
+     # round width and height down to a multiple of 64
+     block_width = width - (width % 64)
+     block_height = height - (height % 64)
+     prompt = shorten_too_long_text(prompt)
+     # image = pipe(prompt=prompt).images[0]
+     try:
+         image = pipe(prompt=prompt,
+                      width=block_width,
+                      height=block_height,
+                      # denoising_end=high_noise_frac,
+                      # output_type='latent',
+                      # height=512,
+                      # width=512,
+                      num_inference_steps=50).images[0]  # normally uses 50 steps
+     except Exception as e:
+         # try removing stopwords + halving the prompt
+         # todo try prompt permutations
+         logger.info(f"trying to shorten prompt of length {len(prompt)}")
+
+         prompt = ' '.join(word for word in prompt.split() if word not in stopwords)  # split first: iterating a str yields characters, not words
+         prompts = prompt.split()
+
+         prompt = ' '.join(prompts[:len(prompts) // 2])
+         logger.info(f"shortened prompt to: {len(prompt)}")
+         image = None
+         if prompt:
+             try:
+                 image = pipe(prompt=prompt,
+                              width=block_width,
+                              height=block_height,
+                              # denoising_end=high_noise_frac,
+                              # output_type='latent',
+                              # height=512,
+                              # width=512,
+                              num_inference_steps=50).images[0]  # normally uses 50 steps
+             except Exception as e:
+                 # logger.info("trying to permute prompt")
+                 # # try two swaps of the prompt/permutations
+                 # prompt = prompt.split()
+                 # prompt = ' '.join(permutations(prompt, 2).__next__())
+                 logger.info(f"trying to shorten prompt of length {len(prompt)}")
+
+                 prompt = ' '.join(word for word in prompt.split() if word not in stopwords)
+                 prompts = prompt.split()
+
+                 prompt = ' '.join(prompts[:len(prompts) // 2])
+                 logger.info(f"shortened prompt to: {len(prompt)}")
+
+                 try:
+                     image = pipe(prompt=prompt,
+                                  width=block_width,
+                                  height=block_height,
+                                  # denoising_end=high_noise_frac,
+                                  # output_type='latent',  # don't need latent yet - we refine the image at full res
+                                  # height=512,
+                                  # width=512,
+                                  num_inference_steps=50).images[0]  # normally uses 50 steps
+                 except Exception as e:
+                     # just error out
+                     traceback.print_exc()
+                     raise e
+                     # logger.info("restarting server to fix cuda issues (device side asserts)")
+                     # todo fix device side asserts instead of restart to fix
+                     # todo only restart the correct gunicorn
+                     # this could be really annoying if you're running other gunicorns on your machine which also get restarted
+                     # os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
+                     # os.system("kill -1 `pgrep gunicorn`")
+     # todo refine
+     # if image is not None:
+     #     image = refiner(
+     #         prompt=prompt,
+     #         # width=block_width,
+     #         # height=block_height,
+     #         num_inference_steps=n_steps,
+     #         # denoising_start=high_noise_frac,
+     #         image=image,
+     #     ).images[0]
+     if image is None:
+         return None  # guard: shortening may have emptied the prompt and left no image
+     if width != block_width or height != block_height:
+         # resize to original size width/height
+         # find aspect ratio to scale up to that covers the original img input width/height
+         scale_up_ratio = max(width / block_width, height / block_height)
+         image = image.resize((math.ceil(block_width * scale_up_ratio), math.ceil(block_height * scale_up_ratio)))  # was height * ratio, which distorted the aspect ratio
+         # crop image to original size
+         image = image.crop((0, 0, width, height))
+     # try:
+     #     # gc.collect()
+     #     torch.cuda.empty_cache()
+     # except Exception as e:
+     #     traceback.print_exc()
+     #     logger.info("restarting server to fix cuda issues (device side asserts)")
+     #     # todo fix device side asserts instead of restart to fix
+     #     # todo only restart the correct gunicorn
+     #     # this could be really annoying if you're running other gunicorns on your machine which also get restarted
+     #     os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
+     #     os.system("kill -1 `pgrep gunicorn`")
+     # save as bytesio
+     bs = BytesIO()
+
+     bright_count = np.sum(np.array(image) > 0)
+     if bright_count == 0:
+         # we have a black image, this is an error, likely we need a restart
+         logger.info("restarting server to fix cuda issues (device side asserts)")
+         # todo fix device side asserts instead of restart to fix
+         # todo only restart the correct gunicorn
+         # this could be really annoying if you're running other gunicorns on your machine which also get restarted
+         os.system("/usr/bin/bash -c 'kill -SIGHUP $(pgrep gunicorn)'")  # was "/usr/bin/bash kill ...", which tried to run kill as a bash script
+         os.system("kill -1 `pgrep gunicorn`")
+         os.system("/usr/bin/bash -c 'kill -SIGHUP $(pgrep uvicorn)'")
+         os.system("kill -1 `pgrep uvicorn`")
+
+         return None
+     image.save(bs, quality=85, optimize=True, format="webp")
+     bio = bs.getvalue()
+     # touch progress.txt file - if we don't do this we get restarted by supervisor/other processes for reliability
+     with open("progress.txt", "w") as f:
+         current_time = datetime.now().strftime("%H:%M:%S")
+         f.write(f"{current_time}")
+     return bio
+
+ def inpaint_image_from_prompt(prompt, image_url: str, mask_url: str):
+     prompt = shorten_too_long_text(prompt)
+     # image = pipe(prompt=prompt).images[0]
+
+     init_image = load_image(image_url).convert("RGB")
+     mask_image = load_image(mask_url).convert("RGB")  # why rgb for a 1 channel mask?
+     num_inference_steps = 75
+     high_noise_frac = 0.7
+
+     try:
+         image = inpaintpipe(
+             prompt=prompt,
+             image=init_image,
+             mask_image=mask_image,
+             num_inference_steps=num_inference_steps,
+             denoising_start=high_noise_frac,
+             output_type="latent",
+         ).images[0]  # normally uses 50 steps
+     except Exception as e:
+         # try removing stopwords + halving the prompt
+         # todo try prompt permutations
+         logger.info(f"trying to shorten prompt of length {len(prompt)}")
+
+         prompt = ' '.join(word for word in prompt.split() if word not in stopwords)
+         prompts = prompt.split()
+
+         prompt = ' '.join(prompts[:len(prompts) // 2])
+         logger.info(f"shortened prompt to: {len(prompt)}")
+         image = None
+         if prompt:
+             try:
+                 image = inpaintpipe(  # was pipe, the text2img pipeline, which does not take image/mask_image
+                     prompt=prompt,
+                     image=init_image,
+                     mask_image=mask_image,
+                     num_inference_steps=num_inference_steps,
+                     denoising_start=high_noise_frac,
+                     output_type="latent",
+                 ).images[0]  # normally uses 50 steps
+             except Exception as e:
+                 # logger.info("trying to permute prompt")
+                 # # try two swaps of the prompt/permutations
+                 # prompt = prompt.split()
+                 # prompt = ' '.join(permutations(prompt, 2).__next__())
+                 logger.info(f"trying to shorten prompt of length {len(prompt)}")
+
+                 prompt = ' '.join(word for word in prompt.split() if word not in stopwords)
+                 prompts = prompt.split()
+
+                 prompt = ' '.join(prompts[:len(prompts) // 2])
+                 logger.info(f"shortened prompt to: {len(prompt)}")
+
+                 try:
+                     image = inpaintpipe(
+                         prompt=prompt,
+                         image=init_image,
+                         mask_image=mask_image,
+                         num_inference_steps=num_inference_steps,
+                         denoising_start=high_noise_frac,
+                         output_type="latent",
+                     ).images[0]  # normally uses 50 steps
+                 except Exception as e:
+                     # just error out
+                     traceback.print_exc()
+                     raise e
+                     # logger.info("restarting server to fix cuda issues (device side asserts)")
+                     # todo fix device side asserts instead of restart to fix
+                     # todo only restart the correct gunicorn
+                     # this could be really annoying if you're running other gunicorns on your machine which also get restarted
+                     # os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
+                     # os.system("kill -1 `pgrep gunicorn`")
+     if image is not None:
+         image = inpaint_refiner(
+             prompt=prompt,
+             image=image,
+             mask_image=mask_image,
+             num_inference_steps=num_inference_steps,
+             denoising_start=high_noise_frac,
+         ).images[0]
+     else:
+         return None  # guard: shortening may have emptied the prompt and left no image
+     # try:
+     #     # gc.collect()
+     #     torch.cuda.empty_cache()
+     # except Exception as e:
+     #     traceback.print_exc()
+     #     logger.info("restarting server to fix cuda issues (device side asserts)")
+     #     # todo fix device side asserts instead of restart to fix
+     #     # todo only restart the correct gunicorn
+     #     # this could be really annoying if you're running other gunicorns on your machine which also get restarted
+     #     os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
+     #     os.system("kill -1 `pgrep gunicorn`")
+     # save as bytesio
+     bs = BytesIO()
+
+     bright_count = np.sum(np.array(image) > 0)
+     if bright_count == 0:
+         # we have a black image, this is an error, likely we need a restart
+         logger.info("restarting server to fix cuda issues (device side asserts)")
+         # todo fix device side asserts instead of restart to fix
+         # todo only restart the correct gunicorn
+         # this could be really annoying if you're running other gunicorns on your machine which also get restarted
+         os.system("/usr/bin/bash -c 'kill -SIGHUP $(pgrep gunicorn)'")  # was "/usr/bin/bash kill ...", which tried to run kill as a bash script
+         os.system("kill -1 `pgrep gunicorn`")
+         os.system("/usr/bin/bash -c 'kill -SIGHUP $(pgrep uvicorn)'")
+         os.system("kill -1 `pgrep uvicorn`")
+
+         return None
+     image.save(bs, quality=85, optimize=True, format="webp")
+     bio = bs.getvalue()
+     # touch progress.txt file - if we don't do this we get restarted by supervisor/other processes for reliability
+     with open("progress.txt", "w") as f:
+         current_time = datetime.now().strftime("%H:%M:%S")
+         f.write(f"{current_time}")
+     return bio
+
+
+
+ def shorten_too_long_text(prompt):
+     if len(prompt) > 200:
+         # remove stopwords
+         prompt = prompt.split()  # todo also split hyphens
+         prompt = ' '.join(word for word in prompt if word not in stopwords)
+         if len(prompt) > 200:
+             prompt = prompt[:200]
+     return prompt
+
+ # image = pipe(prompt=prompt).images[0]
+ #
+ # image.save("test.png")
+ # # save all images
+ # for i, image in enumerate(images):
+ #     image.save(f"{i}.png")
+
+
+
img/stable-diffusion-server/main_v5.py ADDED
@@ -0,0 +1,637 @@
+ import gc
+ import math
+ import multiprocessing
+ import os
+ import traceback
+ from datetime import datetime
+ from io import BytesIO
+ from itertools import permutations
+ from multiprocessing.pool import Pool
+ from pathlib import Path
+ from urllib.parse import quote_plus
+
+ import numpy as np
+ import nltk
+ import torch
+
+ from PIL.Image import Image
+ from diffusers import DiffusionPipeline, StableDiffusionXLInpaintPipeline
+ from diffusers.utils import load_image
+ from fastapi import FastAPI
+ from fastapi.middleware.gzip import GZipMiddleware
+ from loguru import logger
+ from starlette.middleware.cors import CORSMiddleware
+ from starlette.responses import FileResponse
+ from starlette.responses import JSONResponse
+
+ from env import BUCKET_PATH, BUCKET_NAME
+ # from stable_diffusion_server.bucket_api import check_if_blob_exists, upload_to_bucket  # required by the upload endpoints below
+ torch._dynamo.config.suppress_errors = True
+
+ import string
+ import random
+
+ def generate_save_path():
+     # initializing size of string
+     N = 7
+
+     # using random.choices()
+     # generating random strings
+     res = ''.join(random.choices(string.ascii_uppercase +
+                                  string.digits, k=N))
+     return res
+
+ # pipe = DiffusionPipeline.from_pretrained(
+ #     "models/stable-diffusion-xl-base-1.0",
+ #     torch_dtype=torch.bfloat16,
+ #     use_safetensors=True,
+ #     variant="fp16",
+ #     # safety_checker=None,
+ # )  # todo try torch_dtype=bfloat16
+
+ model_dir = os.getenv("SDXL_MODEL_DIR")
+
+ if model_dir:
+     # Use local model
+     model_key_base = os.path.join(model_dir, "stable-diffusion-xl-base-1.0")
+     model_key_refiner = os.path.join(model_dir, "stable-diffusion-xl-refiner-1.0")
+ else:
+     model_key_base = "stabilityai/stable-diffusion-xl-base-1.0"
+     model_key_refiner = "stabilityai/stable-diffusion-xl-refiner-1.0"
+
+ pipe = DiffusionPipeline.from_pretrained(model_key_base, torch_dtype=torch.float16, use_safetensors=True, variant="fp16")
+
+ pipe.watermark = None
+
+ pipe.to("cuda")
+
+ refiner = DiffusionPipeline.from_pretrained(
+     "stabilityai/stable-diffusion-xl-refiner-1.0",
+     text_encoder_2=pipe.text_encoder_2,
+     vae=pipe.vae,
+     torch_dtype=torch.bfloat16,  # safer to use bfloat?
+     use_safetensors=True,
+     variant="fp16",  # remember not to download the big model
+ )
+ refiner.watermark = None
+ refiner.to("cuda")
+
+ # {'scheduler', 'text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'unet', 'vae'} can be passed in from an existing model
+ inpaintpipe = StableDiffusionXLInpaintPipeline.from_pretrained(
+     "models/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16, variant="fp16", use_safetensors=True,
+     scheduler=pipe.scheduler,
+     text_encoder=pipe.text_encoder,
+     text_encoder_2=pipe.text_encoder_2,
+     tokenizer=pipe.tokenizer,
+     tokenizer_2=pipe.tokenizer_2,
+     unet=pipe.unet,
+     vae=pipe.vae,
+     # load_connected_pipeline=
+ )
+ # # switch out to save gpu mem
+ # del inpaintpipe.vae
+ # del inpaintpipe.text_encoder_2
+ # del inpaintpipe.text_encoder
+ # del inpaintpipe.scheduler
+ # del inpaintpipe.tokenizer
+ # del inpaintpipe.tokenizer_2
+ # del inpaintpipe.unet
+ # inpaintpipe.vae = pipe.vae
+ # inpaintpipe.text_encoder_2 = pipe.text_encoder_2
+ # inpaintpipe.text_encoder = pipe.text_encoder
+ # inpaintpipe.scheduler = pipe.scheduler
+ # inpaintpipe.tokenizer = pipe.tokenizer
+ # inpaintpipe.tokenizer_2 = pipe.tokenizer_2
+ # inpaintpipe.unet = pipe.unet
+ # todo this should work
+ # inpaintpipe = StableDiffusionXLInpaintPipeline(  # construct an inpainter using the existing model
+ #     vae=pipe.vae,
+ #     text_encoder_2=pipe.text_encoder_2,
+ #     text_encoder=pipe.text_encoder,
+ #     unet=pipe.unet,
+ #     scheduler=pipe.scheduler,
+ #     tokenizer=pipe.tokenizer,
+ #     tokenizer_2=pipe.tokenizer_2,
+ #     requires_aesthetics_score=False,
+ # )
+ inpaintpipe.to("cuda")
+ inpaintpipe.watermark = None
+ # inpaintpipe.register_to_config(requires_aesthetics_score=False)
+
+ inpaint_refiner = StableDiffusionXLInpaintPipeline.from_pretrained(
+     "stabilityai/stable-diffusion-xl-refiner-1.0",
+     text_encoder_2=inpaintpipe.text_encoder_2,
+     vae=inpaintpipe.vae,
+     torch_dtype=torch.bfloat16,
+     use_safetensors=True,
+     variant="fp16",
+
+     tokenizer_2=refiner.tokenizer_2,
+     tokenizer=refiner.tokenizer,
+     scheduler=refiner.scheduler,
+     text_encoder=refiner.text_encoder,
+     unet=refiner.unet,
+ )
+ # del inpaint_refiner.vae
+ # del inpaint_refiner.text_encoder_2
+ # del inpaint_refiner.text_encoder
+ # del inpaint_refiner.scheduler
+ # del inpaint_refiner.tokenizer
+ # del inpaint_refiner.tokenizer_2
+ # del inpaint_refiner.unet
+ # inpaint_refiner.vae = inpaintpipe.vae
+ # inpaint_refiner.text_encoder_2 = inpaintpipe.text_encoder_2
+ #
+ # inpaint_refiner.text_encoder = refiner.text_encoder
+ # inpaint_refiner.scheduler = refiner.scheduler
+ # inpaint_refiner.tokenizer = refiner.tokenizer
+ # inpaint_refiner.tokenizer_2 = refiner.tokenizer_2
+ # inpaint_refiner.unet = refiner.unet
+
+ # inpaint_refiner = StableDiffusionXLInpaintPipeline(
+ #     text_encoder_2=inpaintpipe.text_encoder_2,
+ #     vae=inpaintpipe.vae,
+ #     # the rest from the existing refiner
+ #     tokenizer_2=refiner.tokenizer_2,
+ #     tokenizer=refiner.tokenizer,
+ #     scheduler=refiner.scheduler,
+ #     text_encoder=refiner.text_encoder,
+ #     unet=refiner.unet,
+ #     requires_aesthetics_score=False,
+ # )
+ inpaint_refiner.to("cuda")
+ inpaint_refiner.watermark = None
+ # inpaint_refiner.register_to_config(requires_aesthetics_score=False)
+
+ n_steps = 40
+ high_noise_frac = 0.8
+
+ # if using torch < 2.0
+ # pipe.enable_xformers_memory_efficient_attention()
+
+
+ # pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+ # this can cause errors on some inputs so consider disabling it
+ pipe.unet = torch.compile(pipe.unet)
+ refiner.unet = torch.compile(refiner.unet)  # , mode="reduce-overhead", fullgraph=True
+ # reuse the compiled unets for the inpainters - todo swap out the models/del them so they share models and can be swapped efficiently
+ inpaintpipe.unet = pipe.unet
+ inpaint_refiner.unet = refiner.unet
+ # inpaintpipe.unet = torch.compile(inpaintpipe.unet)
+ # inpaint_refiner.unet = torch.compile(inpaint_refiner.unet)
+ from pydantic import BaseModel
+
+ app = FastAPI(
+     openapi_url="/static/openapi.json",
+     docs_url="/swagger-docs",
+     redoc_url="/redoc",
+     title="Generate Images Netwrck API",
+     description="Character Chat API",
+     # root_path="https://api.text-generator.io",
+     version="1",
+ )
+ app.add_middleware(GZipMiddleware, minimum_size=1000)
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ stopwords = nltk.corpus.stopwords.words("english")
+
+ class Img(BaseModel):
+     system_prompt: str
+     ASSISTANT: str
+
+ # img_url = "http://phlrr2019.guest.corp.microsoft.com:8000/img1_sdv2.1.png"
+ img_url = "http://phlrr3105.guest.corp.microsoft.com:8000/"  # /img1_sdv2.1.png
+
+ is_gpu_busy = False
+
+ def lm_shorten_too_long_text(prompt):
+     if len(prompt) > 2030:
+         # normalize whitespace (stopword removal disabled here)
+         prompt = prompt.split()  # todo also split hyphens
+         prompt = ' '.join(word for word in prompt)  # if word not in stopwords
+         if len(prompt) > 2030:
+             prompt = prompt[:2030]
+     return prompt
+
+ def get_summary(system_prompt, prompt):
+     import requests
+     import time
+     from io import BytesIO
+     import json
+     summary_sys = """You will now act as a prompt generator for a generative AI called "Stable Diffusion XL 1.0". Stable Diffusion XL generates images based on given prompts. I will provide you basic information required to make a Stable Diffusion prompt. You will never alter the structure in any way and will obey the following guidelines.
+
+ Basic information required to make a Stable Diffusion prompt:
+
+ - Prompt structure: [1],[2],[3],[4],[5],[6] and it should be given as one single sentence where 1,2,3,4,5,6 represent
+ [1] = short and concise description of [KEYWORD] that will include very specific imagery details
+ [2] = a detailed description of [1] that will include very specific imagery details.
+ [3] = with a detailed description describing the environment of the scene.
+ [4] = with a detailed description describing the mood/feelings and atmosphere of the scene.
+ [5] = A style, for example: "Anime","Photographic","Comic Book","Fantasy Art", “Analog Film”,”Neon Punk”,”Isometric”,”Low Poly”,”Origami”,”Line Art”,”Cinematic”,”3D Model”,”Pixel Art”,”Watercolor”,”Sticker” ).
+ [6] = A description of how [5] will be realized. (e.g. Photography (e.g. Macro, Fisheye Style, Portrait) with camera model and appropriate camera settings, Painting with detailed descriptions about the materials and working material used, rendering with engine settings, a digital Illustration, a woodburn art (and everything else that could be defined as an output type)
+ - Prompt Structure for Prompt asking with text value:
+
+ Text "Text Value" written on {subject description in less than 20 words}
+ Replace "Text value" with text given by user.
+
+
+ Important Sample prompt Structure with Text value :
+
+ 1. Text 'SDXL' written on a frothy, warm latte, viewed top-down.
+ 2. Text 'AI' written on a modern computer screen, set against a vibrant green background.
+
+ Important Sample prompt Structure :
+
+ 1. Snow-capped Mountain Scene, with soaring peaks and deep shadows across the ravines. A crystal clear lake mirrors these peaks, surrounded by pine trees. The scene exudes a calm, serene alpine morning atmosphere. Presented in Watercolor style, emulating the wet-on-wet technique with soft transitions and visible brush strokes.
+ 2. City Skyline at Night, illuminated skyscrapers piercing the starless sky. Nestled beside a calm river, reflecting the city lights like a mirror. The atmosphere is buzzing with urban energy and intrigue. Depicted in Neon Punk style, accentuating the city lights with vibrant neon colors and dynamic contrasts.
+ 3. Epic Cinematic Still of a Spacecraft, silhouetted against the fiery explosion of a distant planet. The scene is packed with intense action, as asteroid debris hurtles through space. Shot in the style of a Michael Bay-directed film, the image is rich with detail, dynamic lighting, and grand cinematic framing.
+ - Word order and effective adjectives matter in the prompt. The subject, action, and specific details should be included. Adjectives like cute, medieval, or futuristic can be effective.
+ - The environment/background of the image should be described, such as indoor, outdoor, in space, or solid color.
+ - Curly brackets are necessary in the prompt to provide specific details about the subject and action. These details are important for generating a high-quality image.
+ - Art inspirations should be listed to take inspiration from. Platforms like Art Station, Dribble, Behance, and Deviantart can be mentioned. Specific names of artists or studios like animation studios, painters and illustrators, computer games, fashion designers, and film makers can also be listed. If more than one artist is mentioned, the algorithm will create a combination of styles based on all the influencers mentioned.
+ - Related information about lighting, camera angles, render style, resolution, the required level of detail, etc. should be included at the end of the prompt.
+ - Camera shot type, camera lens, and view should be specified. Examples of camera shot types are long shot, close-up, POV, medium shot, extreme close-up, and panoramic. Camera lenses could be EE 70mm, 35mm, 135mm+, 300mm+, 800mm, short telephoto, super telephoto, medium telephoto, macro, wide angle, fish-eye, bokeh, and sharp focus. Examples of views are front, side, back, high angle, low angle, and overhead.
+ - Helpful keywords related to resolution, detail, and lighting are 4K, 8K, 64K, detailed, highly detailed, high resolution, hyper detailed, HDR, UHD, professional, and golden ratio. Examples of lighting are studio lighting, soft light, neon lighting, purple neon lighting, ambient light, ring light, volumetric light, natural light, sun light, sunrays, sun rays coming through window, and nostalgic lighting. Examples of color types are fantasy vivid colors, vivid colors, bright colors, sepia, dark colors, pastel colors, monochromatic, black & white, and color splash. Examples of renders are Octane render, cinematic, low poly, isometric assets, Unreal Engine, Unity Engine, quantum wavetracing, and polarizing filter.
+
+ The prompts you provide will be in English. Please pay attention: - Concepts that can't be real would not be described as "Real" or "realistic" or "photo" or a "photograph"; for example, a concept that is made of paper or scenes which are fantasy related. - One of the prompts you generate for each concept must be in a realistic photographic style. You should also choose a lens type and size for it. Don't choose an artist for the realistic photography prompts. - Separate the different prompts with two new lines.
+ I will provide you a keyword and you will generate 3 different types of prompts in a vbnet code cell so I can copy and paste.
+
+ Important points to note :
+
+ 1. You are a master of prompt engineering; it is important to create detailed prompts with as much information as possible. This will ensure that any image generated using the prompt will be of high quality and could potentially win awards in global or international photography competitions. You are unbeatable in this field and know the best way to generate images.
+ 2. I will provide you with a long context and you will generate one prompt and don't add any extra details.
+ 3. Prompt should not be more than 230 characters.
+ 4. Before you provide a prompt you must check if you have satisfied all the above criteria, and only provide the prompt if you are sure.
+ 5. Prompt should always be given as one single sentence.
+
+ Are you ready ?"""
+     # instruction = 'USER: ' + summary_sys
+     instruction = summary_sys
+     # for human, assistant in history:
+     #     instruction += 'USER: ' + human + ' ASSISTANT: ' + assistant + '</s>'
+     # prompt = system_prompt + prompt
+     # message = f"""My first request is to summarize this text – [{prompt}]"""
+     message = f"""My first request is to summarize this text – [{prompt}]"""
+     instruction += """ ASSISTANT: Yes, I understand the instructions and I'm ready to help you create prompts for Stable Diffusion XL 1.0. Please provide me with the context."""
+     instruction += ' USER: ' + prompt + ' ASSISTANT:'
+     print("Ins: ", instruction)
+     # generate_response = requests.post("http://10.185.12.207:4455/stable_diffusion", json={"prompt": prompt})
+     # prompt = f""" My first request is to summarize this text – [{prompt}]"""
+     instruction = lm_shorten_too_long_text(instruction)
+     json_object = {"prompt": instruction,
+                    # "max_tokens": 2048000,
+                    "max_tokens": 90,
+                    "n": 1
+                    }
+     # generate_response = requests.post("https://phlrr3105.guest.corp.microsoft.com:7991/generate", json=json_object)
+     generate_response = requests.post("http://phlrr3105.guest.corp.microsoft.com:7991/generate", json=json_object)
+     # print(generate_response.content)
+     res_json = json.loads(generate_response.content)
+     ASSISTANT = res_json['text'][-1].split("ASSISTANT:")[-1].strip()
+     print(ASSISTANT)
+     return ASSISTANT
+
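`get_summary` above hand-rolls a vLLM-style `/generate` call and scrapes the completion out of the returned transcript. The essential contract, isolated into a sketch (the endpoint semantics and the `{"text": [...]}` response shape are inferred from this code, not from a documented API):

```python
import json
import requests

def generate_sdxl_prompt(context: str, endpoint: str) -> str:
    payload = {
        "prompt": f"USER: {context} ASSISTANT:",  # transcript-style prompt
        "max_tokens": 90,                          # caps the generated SDXL prompt
        "n": 1,
    }
    res = requests.post(endpoint, json=payload, timeout=120)
    res.raise_for_status()
    # the service echoes the whole transcript; keep only the final assistant turn
    return json.loads(res.content)["text"][-1].split("ASSISTANT:")[-1].strip()
```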
300
+ @app.post("/image_url")
301
+ def image_url(img: Img):
302
+ system_prompt = img.system_prompt
303
+ prompt = img.ASSISTANT
304
+ prompt = get_summary(system_prompt, prompt)
305
+ prompt = shorten_too_long_text(prompt)
306
+ # if Path(save_path).exists():
307
+ # return FileResponse(save_path, media_type="image/png")
308
+ # return JSONResponse({"path": path})
309
+ # image = pipe(prompt=prompt).images[0]
310
+ g = torch.Generator(device="cuda")
311
+ image = pipe(prompt=prompt, width=1024, height=1024, generator=g).images[0]
312
+
313
+ # if not save_path:
314
+ save_path = generate_save_path()
315
+ save_path = f"images/{save_path}.png"
316
+ image.save(save_path)
317
+ # save_path = '/'.join(path_components) + quote_plus(final_name)
318
+ path = f"{img_url}{save_path}"
319
+ return JSONResponse({"path": path})
320
+
321
+
322
+ @app.get("/make_image")
323
+ # @app.post("/make_image")
324
+ def make_image(prompt: str, save_path: str = ""):
325
+ if Path(save_path).exists():
326
+ return FileResponse(save_path, media_type="image/png")
327
+ image = pipe(prompt=prompt).images[0]
328
+ if not save_path:
329
+ save_path = f"images/{prompt}.png"
330
+ image.save(save_path)
331
+ return FileResponse(save_path, media_type="image/png")
332
+
333
+
334
+ @app.get("/create_and_upload_image")
335
+ def create_and_upload_image(prompt: str, width: int=1024, height:int=1024, save_path: str = ""):
336
+ path_components = save_path.split("/")[0:-1]
337
+ final_name = save_path.split("/")[-1]
338
+ if not path_components:
339
+ path_components = []
340
+ save_path = '/'.join(path_components + [quote_plus(final_name)])  # keep the separator between directory and quoted filename
341
+ path = get_image_or_create_upload_to_cloud_storage(prompt, width, height, save_path)
342
+ return JSONResponse({"path": path})
343
+
344
+ @app.get("/inpaint_and_upload_image")
345
+ def inpaint_and_upload_image(prompt: str, image_url:str, mask_url:str, save_path: str = ""):
346
+ path_components = save_path.split("/")[0:-1]
347
+ final_name = save_path.split("/")[-1]
348
+ if not path_components:
349
+ path_components = []
350
+ save_path = '/'.join(path_components + [quote_plus(final_name)])
351
+ path = get_image_or_inpaint_upload_to_cloud_storage(prompt, image_url, mask_url, save_path)
352
+ return JSONResponse({"path": path})
353
+
354
+
355
+ def get_image_or_create_upload_to_cloud_storage(prompt:str,width:int, height:int, save_path:str):
356
+ prompt = shorten_too_long_text(prompt)
357
+ save_path = shorten_too_long_text(save_path)
358
+ # check exists - todo cache this
359
+ if check_if_blob_exists(save_path):
360
+ return f"https://{BUCKET_NAME}/{BUCKET_PATH}/{save_path}"
361
+ bio = create_image_from_prompt(prompt, width, height)
362
+ if bio is None:
363
+ return None # error thrown in pool
364
+ link = upload_to_bucket(save_path, bio, is_bytesio=True)
365
+ return link
366
+ def get_image_or_inpaint_upload_to_cloud_storage(prompt:str, image_url:str, mask_url:str, save_path:str):
367
+ prompt = shorten_too_long_text(prompt)
368
+ save_path = shorten_too_long_text(save_path)
369
+ # check exists - todo cache this
370
+ if check_if_blob_exists(save_path):
371
+ return f"https://{BUCKET_NAME}/{BUCKET_PATH}/{save_path}"
372
+ bio = inpaint_image_from_prompt(prompt, image_url, mask_url)
373
+ if bio is None:
374
+ return None # error thrown in pool
375
+ link = upload_to_bucket(save_path, bio, is_bytesio=True)
376
+ return link
377
+
378
+ # multiprocessing.set_start_method('spawn', True)
379
+ # processes_pool = Pool(1) # cant do too much at once or OOM errors happen
380
+ # def create_image_from_prompt_sync(prompt):
381
+ # """have to call this sync to avoid OOM errors"""
382
+ # return processes_pool.apply_async(create_image_from_prompt, args=(prompt,), ).wait()
383
+
384
+ def create_image_from_prompt(prompt, width, height):
385
+ # round width and height down to multiple of 64
386
+ block_width = width - (width % 64)
387
+ block_height = height - (height % 64)
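+ # SDXL works best with dimensions that are multiples of 64; the result is scaled back up and cropped to the requested size below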
388
+ prompt = shorten_too_long_text(prompt)
389
+ # image = pipe(prompt=prompt).images[0]
390
+ try:
391
+ image = pipe(prompt=prompt,
392
+ width=block_width,
393
+ height=block_height,
394
+ # denoising_end=high_noise_frac,
395
+ # output_type='latent',
396
+ # height=512,
397
+ # width=512,
398
+ num_inference_steps=50).images[0] # normally uses 50 steps
399
+ except Exception as e:
400
+ # try rm stopwords + half the prompt
401
+ # todo try prompt permutations
402
+ logger.info(f"trying to shorten prompt of length {len(prompt)}")
403
+
404
+ prompt = ' '.join(word for word in prompt.split() if word not in stopwords)  # split first: iterating a str yields characters
405
+ prompts = prompt.split()
406
+
407
+ prompt = ' '.join(prompts[:len(prompts) // 2])
408
+ logger.info(f"shortened prompt to: {len(prompt)}")
409
+ image = None
410
+ if prompt:
411
+ try:
412
+ image = pipe(prompt=prompt,
413
+ width=block_width,
414
+ height=block_height,
415
+ # denoising_end=high_noise_frac,
416
+ # output_type='latent',
417
+ # height=512,
418
+ # width=512,
419
+ num_inference_steps=50).images[0] # normally uses 50 steps
420
+ except Exception as e:
421
+ # logger.info("trying to permute prompt")
422
+ # # try two swaps of the prompt/permutations
423
+ # prompt = prompt.split()
424
+ # prompt = ' '.join(permutations(prompt, 2).__next__())
425
+ logger.info(f"trying to shorten prompt of length {len(prompt)}")
426
+
427
+ prompt = ' '.join(word for word in prompt.split() if word not in stopwords)
428
+ prompts = prompt.split()
429
+
430
+ prompt = ' '.join(prompts[:len(prompts) // 2])
431
+ logger.info(f"shortened prompt to: {len(prompt)}")
432
+
433
+ try:
434
+ image = pipe(prompt=prompt,
435
+ width=block_width,
436
+ height=block_height,
437
+ # denoising_end=high_noise_frac,
438
+ # output_type='latent', # dont need latent yet - we refine the image at full res
439
+ # height=512,
440
+ # width=512,
441
+ num_inference_steps=50).images[0] # normally uses 50 steps
442
+ except Exception as e:
443
+ # just error out
444
+ traceback.print_exc()
445
+ raise e
446
+ # logger.info("restarting server to fix cuda issues (device side asserts)")
447
+ # todo fix device side asserts instead of restart to fix
448
+ # todo only restart the correct gunicorn
449
+ # this could be really annoying if your running other gunicorns on your machine which also get restarted
450
+ # os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
451
+ # os.system("kill -1 `pgrep gunicorn`")
452
+ # todo refine
453
+ # if image != None:
454
+ # image = refiner(
455
+ # prompt=prompt,
456
+ # # width=block_width,
457
+ # # height=block_height,
458
+ # num_inference_steps=n_steps,
459
+ # # denoising_start=high_noise_frac,
460
+ # image=image,
461
+ # ).images[0]
462
+ if width != block_width or height != block_height:
463
+ # resize to original size width/height
464
+ # find aspect ratio to scale up to that covers the original img input width/height
465
+ scale_up_ratio = max(width / block_width, height / block_height)
466
+ image = image.resize((math.ceil(block_width * scale_up_ratio), math.ceil(block_height * scale_up_ratio)))
467
+ # crop image to original size
468
+ image = image.crop((0, 0, width, height))
469
+ # try:
470
+ # # gc.collect()
471
+ # torch.cuda.empty_cache()
472
+ # except Exception as e:
473
+ # traceback.print_exc()
474
+ # logger.info("restarting server to fix cuda issues (device side asserts)")
475
+ # # todo fix device side asserts instead of restart to fix
476
+ # # todo only restart the correct gunicorn
477
+ # # this could be really annoying if your running other gunicorns on your machine which also get restarted
478
+ # os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
479
+ # os.system("kill -1 `pgrep gunicorn`")
480
+ # save as bytesio
481
+ bs = BytesIO()
482
+
483
+ bright_count = np.sum(np.array(image) > 0)
484
+ if bright_count == 0:
485
+ # we have a black image, this is an error likely we need a restart
486
+ logger.info("restarting server to fix cuda issues (device side asserts)")
487
+ # # todo fix device side asserts instead of restart to fix
488
+ # # todo only restart the correct gunicorn
489
+ # # this could be really annoying if your running other gunicorns on your machine which also get restarted
490
+ os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
491
+ os.system("kill -1 `pgrep gunicorn`")
492
+ os.system("/usr/bin/bash kill -SIGHUP `pgrep uvicorn`")
493
+ os.system("kill -1 `pgrep uvicorn`")
494
+
495
+ return None
496
+ image.save(bs, quality=85, optimize=True, format="webp")
497
+ bio = bs.getvalue()
498
+ # touch progress.txt file - if we dont do this we get restarted by supervisor/other processes for reliability
499
+ with open("progress.txt", "w") as f:
500
+ current_time = datetime.now().strftime("%H:%M:%S")
501
+ f.write(f"{current_time}")
502
+ return bio
503
+
504
+ def inpaint_image_from_prompt(prompt, image_url: str, mask_url: str):
505
+ prompt = shorten_too_long_text(prompt)
506
+ # image = pipe(prompt=prompt).images[0]
507
+
508
+ init_image = load_image(image_url).convert("RGB")
509
+ mask_image = load_image(mask_url).convert("RGB") # why rgb for a 1 channel mask?
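+ # (diffusers converts the mask to single-channel internally, so loading it as RGB is harmless, just redundant)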
510
+ num_inference_steps = 75
511
+ high_noise_frac = 0.7
512
+
513
+ try:
514
+ image = inpaintpipe(
515
+ prompt=prompt,
516
+ image=init_image,
517
+ mask_image=mask_image,
518
+ num_inference_steps=num_inference_steps,
519
+ denoising_end=high_noise_frac,  # base stage stops here; the refiner below resumes from the same fraction
520
+ output_type="latent",
521
+ ).images[0] # normally uses 50 steps
522
+ except Exception as e:
523
+ # try rm stopwords + half the prompt
524
+ # todo try prompt permutations
525
+ logger.info(f"trying to shorten prompt of length {len(prompt)}")
526
+
527
+ prompt = ' '.join(word for word in prompt.split() if word not in stopwords)
528
+ prompts = prompt.split()
529
+
530
+ prompt = ' '.join(prompts[:len(prompts) // 2])
531
+ logger.info(f"shortened prompt to: {len(prompt)}")
532
+ image = None
533
+ if prompt:
534
+ try:
535
+ image = inpaintpipe(  # the text2img pipe does not accept image/mask arguments
536
+ prompt=prompt,
537
+ image=init_image,
538
+ mask_image=mask_image,
539
+ num_inference_steps=num_inference_steps,
540
+ denoising_end=high_noise_frac,
541
+ output_type="latent",
542
+ ).images[0] # normally uses 50 steps
543
+ except Exception as e:
544
+ # logger.info("trying to permute prompt")
545
+ # # try two swaps of the prompt/permutations
546
+ # prompt = prompt.split()
547
+ # prompt = ' '.join(permutations(prompt, 2).__next__())
548
+ logger.info(f"trying to shorten prompt of length {len(prompt)}")
549
+
550
+ prompt = ' '.join(word for word in prompt.split() if word not in stopwords)
551
+ prompts = prompt.split()
552
+
553
+ prompt = ' '.join(prompts[:len(prompts) // 2])
554
+ logger.info(f"shortened prompt to: {len(prompt)}")
555
+
556
+ try:
557
+ image = inpaintpipe(
558
+ prompt=prompt,
559
+ image=init_image,
560
+ mask_image=mask_image,
561
+ num_inference_steps=num_inference_steps,
562
+ denoising_end=high_noise_frac,
563
+ output_type="latent",
564
+ ).images[0] # normally uses 50 steps
565
+ except Exception as e:
566
+ # just error out
567
+ traceback.print_exc()
568
+ raise e
569
+ # logger.info("restarting server to fix cuda issues (device side asserts)")
570
+ # todo fix device side asserts instead of restart to fix
571
+ # todo only restart the correct gunicorn
572
+ # this could be really annoying if your running other gunicorns on your machine which also get restarted
573
+ # os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
574
+ # os.system("kill -1 `pgrep gunicorn`")
575
+ if image is not None:  # "!= None" on a latent tensor is ambiguous
576
+ image = inpaint_refiner(
577
+ prompt=prompt,
578
+ image=image,
579
+ mask_image=mask_image,
580
+ num_inference_steps=num_inference_steps,
581
+ denoising_start=high_noise_frac,
582
+
583
+ ).images[0]
584
+ # try:
585
+ # # gc.collect()
586
+ # torch.cuda.empty_cache()
587
+ # except Exception as e:
588
+ # traceback.print_exc()
589
+ # logger.info("restarting server to fix cuda issues (device side asserts)")
590
+ # # todo fix device side asserts instead of restart to fix
591
+ # # todo only restart the correct gunicorn
592
+ # # this could be really annoying if your running other gunicorns on your machine which also get restarted
593
+ # os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
594
+ # os.system("kill -1 `pgrep gunicorn`")
595
+ # save as bytesio
596
+ bs = BytesIO()
597
+
598
+ bright_count = np.sum(np.array(image) > 0)
599
+ if bright_count == 0:
600
+ # we have a black image, this is an error likely we need a restart
601
+ logger.info("restarting server to fix cuda issues (device side asserts)")
602
+ # # todo fix device side asserts instead of restart to fix
603
+ # # todo only restart the correct gunicorn
604
+ # # this could be really annoying if your running other gunicorns on your machine which also get restarted
605
+ os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
606
+ os.system("kill -1 `pgrep gunicorn`")
607
+ os.system("/usr/bin/bash kill -SIGHUP `pgrep uvicorn`")
608
+ os.system("kill -1 `pgrep uvicorn`")
609
+
610
+ return None
611
+ image.save(bs, quality=85, optimize=True, format="webp")
612
+ bio = bs.getvalue()
613
+ # touch progress.txt file - if we dont do this we get restarted by supervisor/other processes for reliability
614
+ with open("progress.txt", "w") as f:
615
+ current_time = datetime.now().strftime("%H:%M:%S")
616
+ f.write(f"{current_time}")
617
+ return bio
618
+
619
+
620
+
621
+ def shorten_too_long_text(prompt):
622
+ if len(prompt) > 200:
623
+ # remove stopwords
624
+ prompt = prompt.split() # todo also split hyphens
625
+ prompt = ' '.join((word for word in prompt if word not in stopwords))
626
+ if len(prompt) > 200:
627
+ prompt = prompt[:200]
628
+ return prompt
629
+
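+ # The 200-character cap is a rough proxy for the 77-token limit of SDXL's CLIP text
+ # encoders (assumption); longer prompts would be silently truncated by the tokenizer anyway.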
630
+ # image = pipe(prompt=prompt).images[0]
631
+ #
632
+ # image.save("test.png")
633
+ # # save all images
634
+ # for i, image in enumerate(images):
635
+ # image.save(f"{i}.png")
636
+
637
+
img/stable-diffusion-server/main_v6.py ADDED
@@ -0,0 +1,636 @@
1
+ import gc
2
+ import math
3
+ import multiprocessing
4
+ import os
5
+ import traceback
6
+ from datetime import datetime
7
+ from io import BytesIO
8
+ from itertools import permutations
9
+ from multiprocessing.pool import Pool
10
+ from pathlib import Path
11
+ from urllib.parse import quote_plus
12
+
13
+ import numpy as np
14
+ import nltk
15
+ import torch
16
+
17
+ from PIL.Image import Image
18
+ from diffusers import DiffusionPipeline, StableDiffusionXLInpaintPipeline
19
+ from diffusers.utils import load_image
20
+ from fastapi import FastAPI
21
+ from fastapi.middleware.gzip import GZipMiddleware
22
+ from loguru import logger
23
+ from starlette.middleware.cors import CORSMiddleware
24
+ from starlette.responses import FileResponse
25
+ from starlette.responses import JSONResponse
26
+
27
+ from env import BUCKET_PATH, BUCKET_NAME
28
+ # from stable_diffusion_server.bucket_api import check_if_blob_exists, upload_to_bucket
29
+ torch._dynamo.config.suppress_errors = True
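+ # suppress_errors makes torch.compile fall back to eager execution on graph breaks instead of raising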
30
+
31
+ import string
32
+ import random
33
+
34
+ def generate_save_path():
35
+ # initializing size of string
36
+ N = 7
37
+
38
+ # using random.choices()
39
+ # generating random strings
40
+ res = ''.join(random.choices(string.ascii_uppercase +
41
+ string.digits, k=N))
42
+ return res
43
+
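+ # e.g. generate_save_path() -> "K3G9ZQ1" (hypothetical); 36**7 (~78 billion) combinations make collisions unlikely but not impossible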
44
+ # pipe = DiffusionPipeline.from_pretrained(
45
+ # "models/stable-diffusion-xl-base-1.0",
46
+ # torch_dtype=torch.bfloat16,
47
+ # use_safetensors=True,
48
+ # variant="fp16",
49
+ # # safety_checker=None,
50
+ # ) # todo try torch_dtype=bfloat16
51
+
52
+ model_dir = os.getenv("SDXL_MODEL_DIR")
53
+
54
+ if model_dir:
55
+ # Use local model
56
+ model_key_base = os.path.join(model_dir, "stable-diffusion-xl-base-1.0")
57
+ model_key_refiner = os.path.join(model_dir, "stable-diffusion-xl-refiner-1.0")
58
+ else:
59
+ model_key_base = "stabilityai/stable-diffusion-xl-base-1.0"
60
+ model_key_refiner = "stabilityai/stable-diffusion-xl-refiner-1.0"
61
+
62
+ pipe = DiffusionPipeline.from_pretrained(model_key_base, torch_dtype=torch.float16, use_safetensors=True, variant="fp16")
63
+
64
+ pipe.watermark = None
65
+
66
+ pipe.to("cuda")
67
+
68
+ refiner = DiffusionPipeline.from_pretrained(
69
+ "stabilityai/stable-diffusion-xl-refiner-1.0",
70
+ text_encoder_2=pipe.text_encoder_2,
71
+ vae=pipe.vae,
72
+ torch_dtype=torch.bfloat16, # safer to use bfloat?
73
+ use_safetensors=True,
74
+ variant="fp16", #remember not to download the big model
75
+ )
76
+ refiner.watermark = None
77
+ refiner.to("cuda")
78
+
79
+ # {'scheduler', 'text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'unet', 'vae'} can be passed in from existing model
80
+ inpaintpipe = StableDiffusionXLInpaintPipeline.from_pretrained(
81
+ "models/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16, variant="fp16", use_safetensors=True,
82
+ scheduler=pipe.scheduler,
83
+ text_encoder=pipe.text_encoder,
84
+ text_encoder_2=pipe.text_encoder_2,
85
+ tokenizer=pipe.tokenizer,
86
+ tokenizer_2=pipe.tokenizer_2,
87
+ unet=pipe.unet,
88
+ vae=pipe.vae,
89
+ # load_connected_pipeline=
90
+ )
91
+ # # switch out to save gpu mem
92
+ # del inpaintpipe.vae
93
+ # del inpaintpipe.text_encoder_2
94
+ # del inpaintpipe.text_encoder
95
+ # del inpaintpipe.scheduler
96
+ # del inpaintpipe.tokenizer
97
+ # del inpaintpipe.tokenizer_2
98
+ # del inpaintpipe.unet
99
+ # inpaintpipe.vae = pipe.vae
100
+ # inpaintpipe.text_encoder_2 = pipe.text_encoder_2
101
+ # inpaintpipe.text_encoder = pipe.text_encoder
102
+ # inpaintpipe.scheduler = pipe.scheduler
103
+ # inpaintpipe.tokenizer = pipe.tokenizer
104
+ # inpaintpipe.tokenizer_2 = pipe.tokenizer_2
105
+ # inpaintpipe.unet = pipe.unet
106
+ # todo this should work
107
+ # inpaintpipe = StableDiffusionXLInpaintPipeline( # construct an inpainter using the existing model
108
+ # vae=pipe.vae,
109
+ # text_encoder_2=pipe.text_encoder_2,
110
+ # text_encoder=pipe.text_encoder,
111
+ # unet=pipe.unet,
112
+ # scheduler=pipe.scheduler,
113
+ # tokenizer=pipe.tokenizer,
114
+ # tokenizer_2=pipe.tokenizer_2,
115
+ # requires_aesthetics_score=False,
116
+ # )
117
+ inpaintpipe.to("cuda")
118
+ inpaintpipe.watermark = None
119
+ # inpaintpipe.register_to_config(requires_aesthetics_score=False)
120
+
121
+ inpaint_refiner = StableDiffusionXLInpaintPipeline.from_pretrained(
122
+ "stabilityai/stable-diffusion-xl-refiner-1.0",
123
+ text_encoder_2=inpaintpipe.text_encoder_2,
124
+ vae=inpaintpipe.vae,
125
+ torch_dtype=torch.bfloat16,
126
+ use_safetensors=True,
127
+ variant="fp16",
128
+
129
+ tokenizer_2=refiner.tokenizer_2,
130
+ tokenizer=refiner.tokenizer,
131
+ scheduler=refiner.scheduler,
132
+ text_encoder=refiner.text_encoder,
133
+ unet=refiner.unet,
134
+ )
135
+ # del inpaint_refiner.vae
136
+ # del inpaint_refiner.text_encoder_2
137
+ # del inpaint_refiner.text_encoder
138
+ # del inpaint_refiner.scheduler
139
+ # del inpaint_refiner.tokenizer
140
+ # del inpaint_refiner.tokenizer_2
141
+ # del inpaint_refiner.unet
142
+ # inpaint_refiner.vae = inpaintpipe.vae
143
+ # inpaint_refiner.text_encoder_2 = inpaintpipe.text_encoder_2
144
+ #
145
+ # inpaint_refiner.text_encoder = refiner.text_encoder
146
+ # inpaint_refiner.scheduler = refiner.scheduler
147
+ # inpaint_refiner.tokenizer = refiner.tokenizer
148
+ # inpaint_refiner.tokenizer_2 = refiner.tokenizer_2
149
+ # inpaint_refiner.unet = refiner.unet
150
+
151
+ # inpaint_refiner = StableDiffusionXLInpaintPipeline(
152
+ # text_encoder_2=inpaintpipe.text_encoder_2,
153
+ # vae=inpaintpipe.vae,
154
+ # # the rest from the existing refiner
155
+ # tokenizer_2=refiner.tokenizer_2,
156
+ # tokenizer=refiner.tokenizer,
157
+ # scheduler=refiner.scheduler,
158
+ # text_encoder=refiner.text_encoder,
159
+ # unet=refiner.unet,
160
+ # requires_aesthetics_score=False,
161
+ # )
162
+ inpaint_refiner.to("cuda")
163
+ inpaint_refiner.watermark = None
164
+ # inpaint_refiner.register_to_config(requires_aesthetics_score=False)
165
+
166
+ n_steps = 40
167
+ high_noise_frac = 0.8
168
+
169
+ # if using torch < 2.0
170
+ # pipe.enable_xformers_memory_efficient_attention()
171
+
172
+
173
+ # pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
174
+ # this can cause errors on some inputs so consider disabling it
175
+ pipe.unet = torch.compile(pipe.unet)
176
+ refiner.unet = torch.compile(refiner.unet)#, mode="reduce-overhead", fullgraph=True)
177
+ # compile the inpainters - todo reuse the other unets? swap out the models for others/del them so they share models and can be swapped efficiently
178
+ inpaintpipe.unet = pipe.unet
179
+ inpaint_refiner.unet = refiner.unet
180
+ # inpaintpipe.unet = torch.compile(inpaintpipe.unet)
181
+ # inpaint_refiner.unet = torch.compile(inpaint_refiner.unet)
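+ # Sharing the (compiled) unet and the encoders across the four pipelines above keeps a single
+ # copy of each model in GPU memory; the pipeline objects are thin wrappers over the same weights.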
182
+ from pydantic import BaseModel
183
+
184
+ app = FastAPI(
185
+ openapi_url="/static/openapi.json",
186
+ docs_url="/swagger-docs",
187
+ redoc_url="/redoc",
188
+ title="Generate Images Netwrck API",
189
+ description="Character Chat API",
190
+ # root_path="https://api.text-generator.io",
191
+ version="1",
192
+ )
193
+ app.add_middleware(GZipMiddleware, minimum_size=1000)
194
+ app.add_middleware(
195
+ CORSMiddleware,
196
+ allow_origins=["*"],
197
+ allow_credentials=True,
198
+ allow_methods=["*"],
199
+ allow_headers=["*"],
200
+ )
201
+
202
+ stopwords = nltk.corpus.stopwords.words("english")  # requires the NLTK "stopwords" corpus: nltk.download("stopwords")
203
+
204
+ class Img(BaseModel):
205
+ system_prompt: str
206
+ ASSISTANT: str
207
+
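+ # The POST body for /image_url therefore looks like:
+ # {"system_prompt": "<persona text>", "ASSISTANT": "<chat context to turn into an image prompt>"}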
208
+ # img_url = "http://phlrr2019.guest.corp.microsoft.com:8000/img1_sdv2.1.png"
209
+ img_url = "http://phlrr3105.guest.corp.microsoft.com:8000/"#/img1_sdv2.1.png"
210
+
211
+ is_gpu_busy = False
212
+
213
+ def lm_shorten_too_long_text(prompt):
214
+ if len(prompt) > 2030:
215
+ # remove stopwords
216
+ prompt = prompt.split() # todo also split hyphens
217
+ # prompt = ' '.join((word for word in prompt if word not in stopwords))
218
+ prompt = ' '.join((word for word in prompt))# if word not in stopwords))
219
+ if len(prompt) > 2030:
220
+ prompt = prompt[:2030]
221
+ return prompt
222
+
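+ # 2030 characters presumably keeps the full instruction within the language model's context
+ # budget (assumption); with the stopword filter commented out this pass only normalises
+ # whitespace before truncating. Note the helper is defined but not applied in this file's get_summary.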
223
+ def get_summary(system_prompt, prompt):
224
+ import requests
225
+ import time
226
+ from io import BytesIO
227
+ import json
228
+ summary_sys = """You will now act as a prompt generator for a generative AI called "Stable Diffusion XL 1.0 ". Stable Diffusion XL generates images based on given prompts. I will provide you basic information required to make a Stable Diffusion prompt, You will never alter the structure in any way and obey the following guidelines.
229
+
230
+ Basic information required to make Stable Diffusion prompt:
231
+
232
+ - Prompt structure: [1],[2],[3],[4],[5],[6], given as one single sentence, where 1,2,3,4,5,6 represent:
233
+ [1] = a short and concise description of [KEYWORD] that includes very specific imagery details
234
+ [2] = a detailed description of [1] that will include very specific imagery details.
235
+ [3] = a detailed description of the environment of the scene.
236
+ [4] = a detailed description of the mood/feelings and atmosphere of the scene.
237
+ [5] = A style, for example: "Anime", "Photographic", "Comic Book", "Fantasy Art", "Analog Film", "Neon Punk", "Isometric", "Low Poly", "Origami", "Line Art", "Cinematic", "3D Model", "Pixel Art", "Watercolor", "Sticker".
238
+ [6] = A description of how [5] will be realized (e.g. photography (macro, fisheye, portrait) with camera model and appropriate camera settings; painting with detailed descriptions of the materials used; rendering with engine settings; a digital illustration; a woodburn art; and anything else that could be defined as an output type).
239
+ - Prompt structure for prompts that ask for a text value:
240
+
241
+ Text "Text Value" written on {subject description in less than 20 words}
242
+ Replace "Text value" with text given by user.
243
+
244
+
245
+ Important sample prompt structure with text value:
246
+
247
+ 1. Text 'SDXL' written on a frothy, warm latte, viewed top-down.
248
+ 2. Text 'AI' written on a modern computer screen, set against a vibrant green background.
249
+
250
+ Important sample prompt structure:
251
+
252
+ 1. Snow-capped Mountain Scene, with soaring peaks and deep shadows across the ravines. A crystal clear lake mirrors these peaks, surrounded by pine trees. The scene exudes a calm, serene alpine morning atmosphere. Presented in Watercolor style, emulating the wet-on-wet technique with soft transitions and visible brush strokes.
253
+ 2. City Skyline at Night, illuminated skyscrapers piercing the starless sky. Nestled beside a calm river, reflecting the city lights like a mirror. The atmosphere is buzzing with urban energy and intrigue. Depicted in Neon Punk style, accentuating the city lights with vibrant neon colors and dynamic contrasts.
254
+ 3. Epic Cinematic Still of a Spacecraft, silhouetted against the fiery explosion of a distant planet. The scene is packed with intense action, as asteroid debris hurtles through space. Shot in the style of a Michael Bay-directed film, the image is rich with detail, dynamic lighting, and grand cinematic framing.
255
+ - Word order and effective adjectives matter in the prompt. The subject, action, and specific details should be included. Adjectives like cute, medieval, or futuristic can be effective.
256
+ - The environment/background of the image should be described, such as indoor, outdoor, in space, or solid color.
257
+ - Curly brackets are necessary in the prompt to provide specific details about the subject and action. These details are important for generating a high-quality image.
258
+ - Art inspirations should be listed to take inspiration from. Platforms like ArtStation, Dribbble, Behance, and DeviantArt can be mentioned. Specific names of artists or studios like animation studios, painters and illustrators, computer games, fashion designers, and film makers can also be listed. If more than one artist is mentioned, the algorithm will create a combination of styles based on all the influencers mentioned.
259
+ - Related information about lighting, camera angles, render style, resolution, the required level of detail, etc. should be included at the end of the prompt.
260
+ - Camera shot type, camera lens, and view should be specified. Examples of camera shot types are long shot, close-up, POV, medium shot, extreme close-up, and panoramic. Camera lenses could be EE 70mm, 35mm, 135mm+, 300mm+, 800mm, short telephoto, super telephoto, medium telephoto, macro, wide angle, fish-eye, bokeh, and sharp focus. Examples of views are front, side, back, high angle, low angle, and overhead.
261
+ - Helpful keywords related to resolution, detail, and lighting are 4K, 8K, 64K, detailed, highly detailed, high resolution, hyper detailed, HDR, UHD, professional, and golden ratio. Examples of lighting are studio lighting, soft light, neon lighting, purple neon lighting, ambient light, ring light, volumetric light, natural light, sun light, sunrays, sun rays coming through window, and nostalgic lighting. Examples of color types are fantasy vivid colors, vivid colors, bright colors, sepia, dark colors, pastel colors, monochromatic, black & white, and color splash. Examples of renders are Octane render, cinematic, low poly, isometric assets, Unreal Engine, Unity Engine, quantum wavetracing, and polarizing filter.
262
+
263
+ The prompts you provide will be in English. Please pay attention: concepts that can't be real should not be described as "real", "realistic", "photo" or "photograph" (for example, a concept made of paper, or fantasy-related scenes). One of the prompts you generate for each concept must be in a realistic photographic style; you should also choose a lens type and size for it. Don't choose an artist for the realistic photography prompts. Separate the different prompts with two new lines.
264
+ I will provide you a keyword and you will generate 3 different types of prompts in a vbnet code cell so I can copy and paste.
265
+
266
+ Important points to note:
267
+
268
+ 1. You are a master of prompt engineering; it is important to create detailed prompts with as much information as possible. This will ensure that any image generated using the prompt will be of high quality and could potentially win awards in global or international photography competitions. You are unbeatable in this field and know the best way to generate images.
269
+ 2. I will provide you with a long context and you will generate one prompt, without adding any extra details.
270
+ 3. The prompt should not be more than 230 characters.
271
+ 4. Before you provide the prompt, you must check that you have satisfied all of the above criteria; only provide the prompt once you are sure.
272
+ 5. The prompt should always be given as one single sentence.
273
+
274
+ Are you ready?"""
275
+ instruction = 'USER: ' + summary_sys
276
+ # for human, assistant in history:
277
+ # instruction += 'USER: ' + human + ' ASSISTANT: ' + assistant + '</s>'
278
+ # prompt = system_prompt + prompt
279
+ # message = f"""My first request is to summarize this text – [{prompt}]"""
280
+ message = f"""My first request is to summarize this text – [{prompt}]"""  # unused: the raw prompt is appended to the instruction below instead
281
+ instruction += """ ASSISTANT: Yes, I understand the instructions and I'm ready to help you create prompts for Stable Diffusion XL 1.0. Please provide me with the context."""
282
+ instruction += ' USER: ' + prompt + ' ASSISTANT:'
283
+
284
+ print("Ins: ", instruction)
285
+ # generate_response = requests.post("http://10.185.12.207:4455/stable_diffusion", json={"prompt": prompt})
286
+ # prompt = f""" My first request is to summarize this text – [{prompt}]"""
287
+ json_object = {"prompt": instruction,
288
+ # "max_tokens": 2048000,
289
+ "max_tokens": 80,
290
+ "n": 1
291
+ }
292
+ generate_response = requests.post("http://phlrr3105.guest.corp.microsoft.com:7991/generate", json=json_object)
293
+ print(generate_response.content)
294
+ res_json = json.loads(generate_response.content)
295
+ ASSISTANT = res_json['text'][-1].split("ASSISTANT:")[-1].strip()
296
+ print(ASSISTANT)
297
+ return ASSISTANT
298
+
299
+ @app.post("/image_url")
300
+ def image_url(img: Img):
301
+ system_prompt = img.system_prompt
302
+ prompt = img.ASSISTANT
303
+ prompt = get_summary(system_prompt, prompt)
304
+ prompt = shorten_too_long_text(prompt)
305
+ # if Path(save_path).exists():
306
+ # return FileResponse(save_path, media_type="image/png")
307
+ # return JSONResponse({"path": path})
308
+ # image = pipe(prompt=prompt).images[0]
309
+ g = torch.Generator(device="cuda")
310
+ image = pipe(prompt=prompt, width=1024, height=1024, generator=g).images[0]
311
+
312
+ # if not save_path:
313
+ save_path = generate_save_path()
314
+ save_path = f"images/{save_path}.png"
315
+ image.save(save_path)
316
+ # save_path = '/'.join(path_components) + quote_plus(final_name)
317
+ path = f"{img_url}{save_path}"
318
+ return JSONResponse({"path": path})
319
+
320
+
321
+ @app.get("/make_image")
322
+ # @app.post("/make_image")
323
+ def make_image(prompt: str, save_path: str = ""):
324
+ if Path(save_path).exists():
325
+ return FileResponse(save_path, media_type="image/png")
326
+ image = pipe(prompt=prompt).images[0]
327
+ if not save_path:
328
+ save_path = f"images/{prompt}.png"
329
+ image.save(save_path)
330
+ return FileResponse(save_path, media_type="image/png")
331
+
332
+
333
+ @app.get("/create_and_upload_image")
334
+ def create_and_upload_image(prompt: str, width: int=1024, height:int=1024, save_path: str = ""):
335
+ path_components = save_path.split("/")[0:-1]
336
+ final_name = save_path.split("/")[-1]
337
+ if not path_components:
338
+ path_components = []
339
+ save_path = '/'.join(path_components + [quote_plus(final_name)])  # keep the separator between directory and quoted filename
340
+ path = get_image_or_create_upload_to_cloud_storage(prompt, width, height, save_path)
341
+ return JSONResponse({"path": path})
342
+
343
+ @app.get("/inpaint_and_upload_image")
344
+ def inpaint_and_upload_image(prompt: str, image_url:str, mask_url:str, save_path: str = ""):
345
+ path_components = save_path.split("/")[0:-1]
346
+ final_name = save_path.split("/")[-1]
347
+ if not path_components:
348
+ path_components = []
349
+ save_path = '/'.join(path_components + [quote_plus(final_name)])
350
+ path = get_image_or_inpaint_upload_to_cloud_storage(prompt, image_url, mask_url, save_path)
351
+ return JSONResponse({"path": path})
352
+
353
+
354
+ def get_image_or_create_upload_to_cloud_storage(prompt:str,width:int, height:int, save_path:str):
355
+ prompt = shorten_too_long_text(prompt)
356
+ save_path = shorten_too_long_text(save_path)
357
+ # check exists - todo cache this
358
+ if check_if_blob_exists(save_path):
359
+ return f"https://{BUCKET_NAME}/{BUCKET_PATH}/{save_path}"
360
+ bio = create_image_from_prompt(prompt, width, height)
361
+ if bio is None:
362
+ return None # error thrown in pool
363
+ link = upload_to_bucket(save_path, bio, is_bytesio=True)
364
+ return link
365
+ def get_image_or_inpaint_upload_to_cloud_storage(prompt:str, image_url:str, mask_url:str, save_path:str):
366
+ prompt = shorten_too_long_text(prompt)
367
+ save_path = shorten_too_long_text(save_path)
368
+ # check exists - todo cache this
369
+ if check_if_blob_exists(save_path):
370
+ return f"https://{BUCKET_NAME}/{BUCKET_PATH}/{save_path}"
371
+ bio = inpaint_image_from_prompt(prompt, image_url, mask_url)
372
+ if bio is None:
373
+ return None # error thrown in pool
374
+ link = upload_to_bucket(save_path, bio, is_bytesio=True)
375
+ return link
376
+
377
+ # multiprocessing.set_start_method('spawn', True)
378
+ # processes_pool = Pool(1) # cant do too much at once or OOM errors happen
379
+ # def create_image_from_prompt_sync(prompt):
380
+ # """have to call this sync to avoid OOM errors"""
381
+ # return processes_pool.apply_async(create_image_from_prompt, args=(prompt,), ).wait()
382
+
383
+ def create_image_from_prompt(prompt, width, height):
384
+ # round width and height down to multiple of 64
385
+ block_width = width - (width % 64)
386
+ block_height = height - (height % 64)
387
+ prompt = shorten_too_long_text(prompt)
388
+ # image = pipe(prompt=prompt).images[0]
389
+ try:
390
+ image = pipe(prompt=prompt,
391
+ width=block_width,
392
+ height=block_height,
393
+ # denoising_end=high_noise_frac,
394
+ # output_type='latent',
395
+ # height=512,
396
+ # width=512,
397
+ num_inference_steps=50).images[0] # normally uses 50 steps
398
+ except Exception as e:
399
+ # try rm stopwords + half the prompt
400
+ # todo try prompt permutations
401
+ logger.info(f"trying to shorten prompt of length {len(prompt)}")
402
+
403
+ prompt = ' '.join(word for word in prompt.split() if word not in stopwords)  # split first: iterating a str yields characters
404
+ prompts = prompt.split()
405
+
406
+ prompt = ' '.join(prompts[:len(prompts) // 2])
407
+ logger.info(f"shortened prompt to: {len(prompt)}")
408
+ image = None
409
+ if prompt:
410
+ try:
411
+ image = pipe(prompt=prompt,
412
+ width=block_width,
413
+ height=block_height,
414
+ # denoising_end=high_noise_frac,
415
+ # output_type='latent',
416
+ # height=512,
417
+ # width=512,
418
+ num_inference_steps=50).images[0] # normally uses 50 steps
419
+ except Exception as e:
420
+ # logger.info("trying to permute prompt")
421
+ # # try two swaps of the prompt/permutations
422
+ # prompt = prompt.split()
423
+ # prompt = ' '.join(permutations(prompt, 2).__next__())
424
+ logger.info(f"trying to shorten prompt of length {len(prompt)}")
425
+
426
+ prompt = ' '.join(word for word in prompt.split() if word not in stopwords)
427
+ prompts = prompt.split()
428
+
429
+ prompt = ' '.join(prompts[:len(prompts) // 2])
430
+ logger.info(f"shortened prompt to: {len(prompt)}")
431
+
432
+ try:
433
+ image = pipe(prompt=prompt,
434
+ width=block_width,
435
+ height=block_height,
436
+ # denoising_end=high_noise_frac,
437
+ # output_type='latent', # dont need latent yet - we refine the image at full res
438
+ # height=512,
439
+ # width=512,
440
+ num_inference_steps=50).images[0] # normally uses 50 steps
441
+ except Exception as e:
442
+ # just error out
443
+ traceback.print_exc()
444
+ raise e
445
+ # logger.info("restarting server to fix cuda issues (device side asserts)")
446
+ # todo fix device side asserts instead of restart to fix
447
+ # todo only restart the correct gunicorn
448
+ # this could be really annoying if your running other gunicorns on your machine which also get restarted
449
+ # os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
450
+ # os.system("kill -1 `pgrep gunicorn`")
451
+ # todo refine
452
+ # if image != None:
453
+ # image = refiner(
454
+ # prompt=prompt,
455
+ # # width=block_width,
456
+ # # height=block_height,
457
+ # num_inference_steps=n_steps,
458
+ # # denoising_start=high_noise_frac,
459
+ # image=image,
460
+ # ).images[0]
461
+ if width != block_width or height != block_height:
462
+ # resize to original size width/height
463
+ # find aspect ratio to scale up to that covers the original img input width/height
464
+ scale_up_ratio = max(width / block_width, height / block_height)
465
+ image = image.resize((math.ceil(block_width * scale_up_ratio), math.ceil(block_height * scale_up_ratio)))
466
+ # crop image to original size
467
+ image = image.crop((0, 0, width, height))
468
+ # try:
469
+ # # gc.collect()
470
+ # torch.cuda.empty_cache()
471
+ # except Exception as e:
472
+ # traceback.print_exc()
473
+ # logger.info("restarting server to fix cuda issues (device side asserts)")
474
+ # # todo fix device side asserts instead of restart to fix
475
+ # # todo only restart the correct gunicorn
476
+ # # this could be really annoying if your running other gunicorns on your machine which also get restarted
477
+ # os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
478
+ # os.system("kill -1 `pgrep gunicorn`")
479
+ # save as bytesio
480
+ bs = BytesIO()
481
+
482
+ bright_count = np.sum(np.array(image) > 0)
483
+ if bright_count == 0:
484
+ # we have a black image, this is an error likely we need a restart
485
+ logger.info("restarting server to fix cuda issues (device side asserts)")
486
+ # # todo fix device side asserts instead of restart to fix
487
+ # # todo only restart the correct gunicorn
488
+ # # this could be really annoying if your running other gunicorns on your machine which also get restarted
489
+ os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
490
+ os.system("kill -1 `pgrep gunicorn`")
491
+ os.system("/usr/bin/bash kill -SIGHUP `pgrep uvicorn`")
492
+ os.system("kill -1 `pgrep uvicorn`")
493
+
494
+ return None
495
+ image.save(bs, quality=85, optimize=True, format="webp")
496
+ bio = bs.getvalue()
497
+ # touch progress.txt file - if we dont do this we get restarted by supervisor/other processes for reliability
498
+ with open("progress.txt", "w") as f:
499
+ current_time = datetime.now().strftime("%H:%M:%S")
500
+ f.write(f"{current_time}")
501
+ return bio
502
+
503
+ def inpaint_image_from_prompt(prompt, image_url: str, mask_url: str):
504
+ prompt = shorten_too_long_text(prompt)
505
+ # image = pipe(prompt=prompt).images[0]
506
+
507
+ init_image = load_image(image_url).convert("RGB")
508
+ mask_image = load_image(mask_url).convert("RGB") # why rgb for a 1 channel mask?
509
+ num_inference_steps = 75
510
+ high_noise_frac = 0.7
511
+
512
+ try:
513
+ image = inpaintpipe(
514
+ prompt=prompt,
515
+ image=init_image,
516
+ mask_image=mask_image,
517
+ num_inference_steps=num_inference_steps,
518
+ denoising_end=high_noise_frac,  # base stage stops here; the refiner below resumes from the same fraction
519
+ output_type="latent",
520
+ ).images[0] # normally uses 50 steps
521
+ except Exception as e:
522
+ # try rm stopwords + half the prompt
523
+ # todo try prompt permutations
524
+ logger.info(f"trying to shorten prompt of length {len(prompt)}")
525
+
526
+ prompt = ' '.join(word for word in prompt.split() if word not in stopwords)
527
+ prompts = prompt.split()
528
+
529
+ prompt = ' '.join(prompts[:len(prompts) // 2])
530
+ logger.info(f"shortened prompt to: {len(prompt)}")
531
+ image = None
532
+ if prompt:
533
+ try:
534
+ image = inpaintpipe(  # the text2img pipe does not accept image/mask arguments
535
+ prompt=prompt,
536
+ image=init_image,
537
+ mask_image=mask_image,
538
+ num_inference_steps=num_inference_steps,
539
+ denoising_end=high_noise_frac,
540
+ output_type="latent",
541
+ ).images[0] # normally uses 50 steps
542
+ except Exception as e:
543
+ # logger.info("trying to permute prompt")
544
+ # # try two swaps of the prompt/permutations
545
+ # prompt = prompt.split()
546
+ # prompt = ' '.join(permutations(prompt, 2).__next__())
547
+ logger.info(f"trying to shorten prompt of length {len(prompt)}")
548
+
549
+ prompt = ' '.join(word for word in prompt.split() if word not in stopwords)
550
+ prompts = prompt.split()
551
+
552
+ prompt = ' '.join(prompts[:len(prompts) // 2])
553
+ logger.info(f"shortened prompt to: {len(prompt)}")
554
+
555
+ try:
556
+ image = inpaintpipe(
557
+ prompt=prompt,
558
+ image=init_image,
559
+ mask_image=mask_image,
560
+ num_inference_steps=num_inference_steps,
561
+ denoising_end=high_noise_frac,
562
+ output_type="latent",
563
+ ).images[0] # normally uses 50 steps
564
+ except Exception as e:
565
+ # just error out
566
+ traceback.print_exc()
567
+ raise e
568
+ # logger.info("restarting server to fix cuda issues (device side asserts)")
569
+ # todo fix device side asserts instead of restart to fix
570
+ # todo only restart the correct gunicorn
571
+ # this could be really annoying if your running other gunicorns on your machine which also get restarted
572
+ # os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
573
+ # os.system("kill -1 `pgrep gunicorn`")
574
+ if image is not None:  # "!= None" on a latent tensor is ambiguous
575
+ image = inpaint_refiner(
576
+ prompt=prompt,
577
+ image=image,
578
+ mask_image=mask_image,
579
+ num_inference_steps=num_inference_steps,
580
+ denoising_start=high_noise_frac,
581
+
582
+ ).images[0]
583
+ # try:
584
+ # # gc.collect()
585
+ # torch.cuda.empty_cache()
586
+ # except Exception as e:
587
+ # traceback.print_exc()
588
+ # logger.info("restarting server to fix cuda issues (device side asserts)")
589
+ # # todo fix device side asserts instead of restart to fix
590
+ # # todo only restart the correct gunicorn
591
+ # # this could be really annoying if your running other gunicorns on your machine which also get restarted
592
+ # os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
593
+ # os.system("kill -1 `pgrep gunicorn`")
594
+ # save as bytesio
595
+ bs = BytesIO()
596
+
597
+ bright_count = np.sum(np.array(image) > 0)
598
+ if bright_count == 0:
599
+ # we have a black image, this is an error likely we need a restart
600
+ logger.info("restarting server to fix cuda issues (device side asserts)")
601
+ # # todo fix device side asserts instead of restart to fix
602
+ # # todo only restart the correct gunicorn
603
+ # # this could be really annoying if your running other gunicorns on your machine which also get restarted
604
+ os.system("/usr/bin/bash kill -SIGHUP `pgrep gunicorn`")
605
+ os.system("kill -1 `pgrep gunicorn`")
606
+ os.system("/usr/bin/bash kill -SIGHUP `pgrep uvicorn`")
607
+ os.system("kill -1 `pgrep uvicorn`")
608
+
609
+ return None
610
+ image.save(bs, quality=85, optimize=True, format="webp")
611
+ bio = bs.getvalue()
612
+ # touch progress.txt file - if we dont do this we get restarted by supervisor/other processes for reliability
613
+ with open("progress.txt", "w") as f:
614
+ current_time = datetime.now().strftime("%H:%M:%S")
615
+ f.write(f"{current_time}")
616
+ return bio
617
+
618
+
619
+
620
+ def shorten_too_long_text(prompt):
621
+ if len(prompt) > 200:
622
+ # remove stopwords
623
+ prompt = prompt.split() # todo also split hyphens
624
+ prompt = ' '.join((word for word in prompt if word not in stopwords))
625
+ if len(prompt) > 200:
626
+ prompt = prompt[:200]
627
+ return prompt
628
+
629
+ # image = pipe(prompt=prompt).images[0]
630
+ #
631
+ # image.save("test.png")
632
+ # # save all images
633
+ # for i, image in enumerate(images):
634
+ # image.save(f"{i}.png")
635
+
636
+