import gc
import math
import multiprocessing
import os
import traceback
from datetime import datetime
from io import BytesIO
from itertools import permutations
from multiprocessing.pool import Pool
from pathlib import Path
from urllib.parse import quote_plus

import numpy as np
import nltk
import torch
from PIL.Image import Image
from diffusers import DiffusionPipeline, StableDiffusionXLInpaintPipeline
from diffusers.utils import load_image
from fastapi import FastAPI
from fastapi.middleware.gzip import GZipMiddleware
from loguru import logger
from starlette.middleware.cors import CORSMiddleware
from starlette.responses import FileResponse
from starlette.responses import JSONResponse

from env import BUCKET_PATH, BUCKET_NAME
from stable_diffusion_server.bucket_api import check_if_blob_exists, upload_to_bucket

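# Load the SDXL base pipeline from a local copy of stabilityai/stable-diffusion-xl-base-1.0.
# The fp16 weight variant is loaded and cast to bfloat16; the invisible watermarker is disabled.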
pipe = DiffusionPipeline.from_pretrained(
    "models/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.bfloat16,
    use_safetensors=True,
    variant="fp16",
)
pipe.watermark = None

pipe.to("cuda")

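# The refiner reuses the base pipeline's second text encoder and VAE so those weights
# are only held in memory once.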
refiner = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0",
    text_encoder_2=pipe.text_encoder_2,
    vae=pipe.vae,
    torch_dtype=torch.bfloat16,
    use_safetensors=True,
    variant="fp16",
)
refiner.watermark = None
refiner.to("cuda")

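# Build the inpainting pipeline from the same local SDXL base checkpoint, passing in every
# submodule from `pipe` so the weights are shared rather than loaded a second time.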
inpaintpipe = StableDiffusionXLInpaintPipeline.from_pretrained(
    "models/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.bfloat16,
    variant="fp16",
    use_safetensors=True,
    scheduler=pipe.scheduler,
    text_encoder=pipe.text_encoder,
    text_encoder_2=pipe.text_encoder_2,
    tokenizer=pipe.tokenizer,
    tokenizer_2=pipe.tokenizer_2,
    unet=pipe.unet,
    vae=pipe.vae,
)
inpaintpipe.to("cuda")
inpaintpipe.watermark = None

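# The inpainting refiner likewise shares all of its submodules with the txt2img refiner.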
inpaint_refiner = StableDiffusionXLInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0",
    text_encoder_2=inpaintpipe.text_encoder_2,
    vae=inpaintpipe.vae,
    torch_dtype=torch.bfloat16,
    use_safetensors=True,
    variant="fp16",
    tokenizer_2=refiner.tokenizer_2,
    tokenizer=refiner.tokenizer,
    scheduler=refiner.scheduler,
    text_encoder=refiner.text_encoder,
    unet=refiner.unet,
)
inpaint_refiner.to("cuda")
inpaint_refiner.watermark = None

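# Global defaults for a base + refiner denoising split; the request handlers below
# currently use their own hard-coded step counts and noise fractions.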
n_steps = 40
high_noise_frac = 0.8

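# Compile the UNets once and share the compiled modules with the inpainting pipelines,
# so the inpaint paths reuse the same compiled weights instead of duplicating them.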
pipe.unet = torch.compile(pipe.unet)
refiner.unet = torch.compile(refiner.unet)

inpaintpipe.unet = pipe.unet
inpaint_refiner.unet = refiner.unet

app = FastAPI(
    openapi_url="/static/openapi.json",
    docs_url="/swagger-docs",
    redoc_url="/redoc",
    title="Generate Images Netwrck API",
    description="Character Chat API",
    version="1",
)
app.add_middleware(GZipMiddleware, minimum_size=1000)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

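# Stopwords are used to trim over-long prompts; this requires the NLTK "stopwords" corpus
# to be available (e.g. via nltk.download("stopwords")).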
stopwords = nltk.corpus.stopwords.words("english")

@app.get("/make_image") |
|
def make_image(prompt: str, save_path: str = ""): |
|
if Path(save_path).exists(): |
|
return FileResponse(save_path, media_type="image/png") |
|
image = pipe(prompt=prompt).images[0] |
|
if not save_path: |
|
save_path = f"images/{prompt}.png" |
|
image.save(save_path) |
|
return FileResponse(save_path, media_type="image/png") |
|
|
|
|
|
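# Example for /make_image (assuming the app is served with e.g. `uvicorn <module>:app` on localhost:8000):
#   GET http://localhost:8000/make_image?prompt=a%20red%20fox
# responds with the generated PNG directly.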
@app.get("/create_and_upload_image") |
|
def create_and_upload_image(prompt: str, width: int=1024, height:int=1024, save_path: str = ""): |
|
path_components = save_path.split("/")[0:-1] |
|
final_name = save_path.split("/")[-1] |
|
if not path_components: |
|
path_components = [] |
|
save_path = '/'.join(path_components) + quote_plus(final_name) |
|
path = get_image_or_create_upload_to_cloud_storage(prompt, width, height, save_path) |
|
return JSONResponse({"path": path}) |
|
|
|
@app.get("/inpaint_and_upload_image") |
|
def inpaint_and_upload_image(prompt: str, image_url:str, mask_url:str, save_path: str = ""): |
|
path_components = save_path.split("/")[0:-1] |
|
final_name = save_path.split("/")[-1] |
|
if not path_components: |
|
path_components = [] |
|
save_path = '/'.join(path_components) + quote_plus(final_name) |
|
path = get_image_or_inpaint_upload_to_cloud_storage(prompt, image_url, mask_url, save_path) |
|
return JSONResponse({"path": path}) |
|
|
|
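# Example requests for the two endpoints above (hypothetical values):
#   GET /create_and_upload_image?prompt=a%20red%20fox&save_path=foxes/fox.webp
#   GET /inpaint_and_upload_image?prompt=a%20red%20fox&image_url=<url>&mask_url=<url>&save_path=foxes/fox.webp
# Both respond with JSON of the form {"path": "<public URL of the stored image>"} and reuse the
# stored image on later calls with the same save_path.
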
def get_image_or_create_upload_to_cloud_storage(prompt: str, width: int, height: int, save_path: str):
    prompt = shorten_too_long_text(prompt)
    save_path = shorten_too_long_text(save_path)
    if check_if_blob_exists(save_path):
        return f"https://{BUCKET_NAME}/{BUCKET_PATH}/{save_path}"
    bio = create_image_from_prompt(prompt, width, height)
    if bio is None:
        return None
    link = upload_to_bucket(save_path, bio, is_bytesio=True)
    return link

def get_image_or_inpaint_upload_to_cloud_storage(prompt: str, image_url: str, mask_url: str, save_path: str):
    prompt = shorten_too_long_text(prompt)
    save_path = shorten_too_long_text(save_path)
    if check_if_blob_exists(save_path):
        return f"https://{BUCKET_NAME}/{BUCKET_PATH}/{save_path}"
    bio = inpaint_image_from_prompt(prompt, image_url, mask_url)
    if bio is None:
        return None
    link = upload_to_bucket(save_path, bio, is_bytesio=True)
    return link

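# Text-to-image generation. Dimensions are rounded down to a multiple of 64 for the model,
# and if the pipeline raises (typically on over-long prompts) the prompt is progressively
# shortened (stopwords removed, then halved) and the call is retried.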
def create_image_from_prompt(prompt, width, height):
    block_width = width - (width % 64)
    block_height = height - (height % 64)
    prompt = shorten_too_long_text(prompt)
    try:
        image = pipe(prompt=prompt,
                     width=block_width,
                     height=block_height,
                     num_inference_steps=50).images[0]
    except Exception as e:
        logger.info(f"trying to shorten prompt of length {len(prompt)}")
        prompt = ' '.join((word for word in prompt.split() if word not in stopwords))
        prompts = prompt.split()
        prompt = ' '.join(prompts[:len(prompts) // 2])
        logger.info(f"shortened prompt to: {len(prompt)}")
        image = None
        if prompt:
            try:
                image = pipe(prompt=prompt,
                             width=block_width,
                             height=block_height,
                             num_inference_steps=50).images[0]
            except Exception as e:
                logger.info(f"trying to shorten prompt of length {len(prompt)}")
                prompt = ' '.join((word for word in prompt.split() if word not in stopwords))
                prompts = prompt.split()
                prompt = ' '.join(prompts[:len(prompts) // 2])
                logger.info(f"shortened prompt to: {len(prompt)}")
                try:
                    image = pipe(prompt=prompt,
                                 width=block_width,
                                 height=block_height,
                                 num_inference_steps=50).images[0]
                except Exception as e:
                    traceback.print_exc()
                    raise e

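    # The image was generated at the 64-aligned size; scale it up to cover the requested
    # dimensions, then crop to exactly (width, height).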
    if width != block_width or height != block_height:
        scale_up_ratio = max(width / block_width, height / block_height)
        image = image.resize((math.ceil(block_width * scale_up_ratio), math.ceil(block_height * scale_up_ratio)))
        image = image.crop((0, 0, width, height))

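    # An all-black image is treated as a sign of a broken CUDA context (device-side asserts);
    # signal the worker processes to restart rather than keep serving corrupted output.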
    bs = BytesIO()

    bright_count = np.sum(np.array(image) > 0)
    if bright_count == 0:
        logger.info("restarting server to fix cuda issues (device side asserts)")
        os.system("kill -1 `pgrep gunicorn`")
        os.system("kill -1 `pgrep uvicorn`")
        return None
    image.save(bs, quality=85, optimize=True, format="webp")
    bio = bs.getvalue()
    # record the time of the last successful generation
    with open("progress.txt", "w") as f:
        current_time = datetime.now().strftime("%H:%M:%S")
        f.write(f"{current_time}")
    return bio

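# SDXL inpainting: the base inpaint pipe denoises up to `high_noise_frac` and returns latents,
# which the inpaint refiner then finishes. Uses the same prompt-shortening retry strategy as
# create_image_from_prompt.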
def inpaint_image_from_prompt(prompt, image_url: str, mask_url: str):
    prompt = shorten_too_long_text(prompt)

    init_image = load_image(image_url).convert("RGB")
    mask_image = load_image(mask_url).convert("RGB")
    num_inference_steps = 75
    high_noise_frac = 0.7

    try:
        # the base inpaint pipe handles the first part of the schedule and hands latents to the refiner
        image = inpaintpipe(
            prompt=prompt,
            image=init_image,
            mask_image=mask_image,
            num_inference_steps=num_inference_steps,
            denoising_end=high_noise_frac,
            output_type="latent",
        ).images[0]
    except Exception as e:
        logger.info(f"trying to shorten prompt of length {len(prompt)}")
        prompt = ' '.join((word for word in prompt.split() if word not in stopwords))
        prompts = prompt.split()
        prompt = ' '.join(prompts[:len(prompts) // 2])
        logger.info(f"shortened prompt to: {len(prompt)}")
        image = None
        if prompt:
            try:
                image = inpaintpipe(
                    prompt=prompt,
                    image=init_image,
                    mask_image=mask_image,
                    num_inference_steps=num_inference_steps,
                    denoising_end=high_noise_frac,
                    output_type="latent",
                ).images[0]
            except Exception as e:
                logger.info(f"trying to shorten prompt of length {len(prompt)}")
                prompt = ' '.join((word for word in prompt.split() if word not in stopwords))
                prompts = prompt.split()
                prompt = ' '.join(prompts[:len(prompts) // 2])
                logger.info(f"shortened prompt to: {len(prompt)}")
                try:
                    image = inpaintpipe(
                        prompt=prompt,
                        image=init_image,
                        mask_image=mask_image,
                        num_inference_steps=num_inference_steps,
                        denoising_end=high_noise_frac,
                        output_type="latent",
                    ).images[0]
                except Exception as e:
                    traceback.print_exc()
                    raise e

    if image is not None:
        # refine the latents produced by the base inpainting pipe
        image = inpaint_refiner(
            prompt=prompt,
            image=image,
            mask_image=mask_image,
            num_inference_steps=num_inference_steps,
            denoising_start=high_noise_frac,
        ).images[0]

    bs = BytesIO()

    bright_count = np.sum(np.array(image) > 0)
    if bright_count == 0:
        # all-black output: restart the worker processes (see note above)
        logger.info("restarting server to fix cuda issues (device side asserts)")
        os.system("kill -1 `pgrep gunicorn`")
        os.system("kill -1 `pgrep uvicorn`")
        return None
    image.save(bs, quality=85, optimize=True, format="webp")
    bio = bs.getvalue()
    # record the time of the last successful generation
    with open("progress.txt", "w") as f:
        current_time = datetime.now().strftime("%H:%M:%S")
        f.write(f"{current_time}")
    return bio

def shorten_too_long_text(prompt):
    if len(prompt) > 200:
        prompt = prompt.split()
        prompt = ' '.join((word for word in prompt if word not in stopwords))
        if len(prompt) > 200:
            prompt = prompt[:200]
    return prompt