|
import gc
import logging
import os
import sys
import threading
import uuid
from abc import ABC
from tempfile import TemporaryFile

import diffusers
import sentry_sdk
import torch
from compel import Compel, ReturnedEmbeddingsType
from diffusers import StableDiffusionXLPipeline
from flask import Flask, request, jsonify
from google.cloud import storage

from sequential_timer import SequentialTimer
|
|
|
logger = logging.getLogger(__name__) |
|
logger.info("Diffusers version %s", diffusers.__version__) |
|
|
|
from axiom_logger import AxiomLogger |
|
axiom_logger = AxiomLogger() |
|
|
|
sentry_sdk.init( |
|
dsn="https://f750d1b039d66541f344ee6151d38166@o4505891057696768.ingest.sentry.io/4506071735205888", |
|
) |
|
|
|
LORAS_DIR = './safetensors' |
|
|
|
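# Serializes LoRA load/fuse/unload calls, which mutate shared pipeline state.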
lora_lock = threading.Lock() |
|
|
|
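# Round-robin state for dispatching requests across the per-GPU handlers.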
handler_lock = threading.Lock() |
|
handler_index = 0 |
|
|
class DiffusersHandler(ABC): |
|
""" |
|
    Diffusers handler class for text-to-image generation.
|
""" |
|
|
|
def __init__(self): |
|
        self.initialized = False
        self.req_id = None
        self.working = False
|
|
|
def initialize(self, properties): |
|
"""In this initialize function, the Stable Diffusion model is loaded and |
|
initialized here. |
|
Args: |
|
ctx (context): It is a JSON Object containing information |
|
pertaining to the model artefacts parameters. |
|
""" |
|
|
|
logger.info("Loading diffusion model") |
|
logger.info("I'm totally new and updated") |
|
|
|
|
|
device_str = "cuda:" + str(properties.get("gpu_id")) if torch.cuda.is_available() and properties.get("gpu_id") is not None else "cpu" |
|
self.device_str = device_str |
|
|
|
print("my device is " + device_str) |
|
self.device = torch.device(device_str) |
|
self.pipe = StableDiffusionXLPipeline.from_pretrained( |
|
sys.argv[1], |
|
torch_dtype=torch.float16, |
|
use_safetensors=True, |
|
) |
|
|
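        # Compel builds weighted SDXL prompt embeddings from both text encoders;
        # requires_pooled=[False, True] returns pooled embeddings only for
        # text_encoder_2, which is what the SDXL pipeline expects.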
self.compel_base = Compel( |
|
tokenizer=[self.pipe.tokenizer, self.pipe.tokenizer_2], |
|
text_encoder=[self.pipe.text_encoder, self.pipe.text_encoder_2], |
|
returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, |
|
requires_pooled=[False, True]) |
|
logger.info("Compel initialized") |
|
|
logger.info("moving base model to device: %s", device_str) |
|
self.pipe.to(self.device) |
|
|
|
        logger.info("Diffusion model from path %s loaded successfully", sys.argv[1])
|
axiom_logger.info("Diffusion model initialized", device=self.device_str) |
|
|
|
self.initialized = True |
|
|
|
def preprocess(self, raw_requests): |
|
"""Basic text preprocessing, of the user's prompt. |
|
Args: |
|
requests (str): The Input data in the form of text is passed on to the preprocess |
|
function. |
|
Returns: |
|
list : The preprocess function returns a list of prompts. |
|
""" |
|
logger.info("Received requests: '%s'", raw_requests) |
|
self.working = True |
|
|
|
        model_args = {
            "prompt": raw_requests[0]["prompt"],
            "negative_prompt": raw_requests[0].get("negative_prompt"),
            "width": raw_requests[0].get("width"),
            "height": raw_requests[0].get("height"),
            "num_inference_steps": raw_requests[0].get("num_inference_steps", 30),
            "guidance_scale": raw_requests[0].get("guidance_scale", 8.5),
        }

        extra_args = {
            "seed": raw_requests[0].get("seed"),
            "style_lora": raw_requests[0].get("style_lora"),
            "style_scale": raw_requests[0].get("style_scale", 1.0),
            "char_lora": raw_requests[0].get("char_lora"),
            "char_scale": raw_requests[0].get("char_scale", 1.0),
        }
|
|
logger.info("Processed request: '%s'", model_args) |
|
axiom_logger.info("Processed request:" + str(model_args), request_id=self.req_id, device=self.device_str) |
|
return model_args, extra_args |
|
|
|
|
|
def inference(self, request): |
|
"""Generates the image relevant to the received text. |
|
Args: |
|
inputs (list): List of Text from the pre-process function is passed here |
|
Returns: |
|
list : It returns a list of the generate images for the input text |
|
""" |
|
|
st = SequentialTimer() |
|
model_args, extra_args = request |
|
|
|
|
use_char_lora = extra_args['char_lora'] is not None |
|
use_style_lora = extra_args['style_lora'] is not None |
|
|
|
|
|
style_lora = extra_args['style_lora'] |
|
char_lora = extra_args['char_lora'] |
|
|
|
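        # Runtime LoRA scale: when both LoRAs are in play, the style LoRA is
        # fused at its own scale below, so this scale applies to the character LoRA.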
cross_attention_kwargs = {"scale": extra_args['char_scale'] if use_char_lora else extra_args['style_scale']} |
|
|
|
        # Seed on this handler's own device (not the default cuda:0); a missing
        # seed means nondeterministic generation.
        generator = (
            torch.Generator(device=self.device).manual_seed(extra_args['seed'])
            if extra_args['seed'] is not None
            else None
        )
|
|
|
|
|
self.prompt = model_args.pop("prompt") |
|
|
|
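        # Compel turns the prompt text (including any attention-weighting
        # syntax) into embedding tensors for the pipeline.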
st.time("Base compel embedding") |
|
conditioning, pooled = self.compel_base(self.prompt) |
|
|
|
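        # LoRA strategy: load the style LoRA first; if a character LoRA is also
        # requested, fuse the style weights into the base model so the character
        # LoRA can then be loaded as the active adapter.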
if use_style_lora: |
|
style_lora = os.path.join(LORAS_DIR, style_lora + '.safetensors') |
|
st.time("Load style lora") |
|
with lora_lock: |
|
self.pipe.load_lora_weights(style_lora) |
|
if use_char_lora: |
|
st.time("Fuse style lora into model") |
|
self.pipe.fuse_lora(lora_scale=extra_args['style_scale'], fuse_text_encoder=False) |
|
|
|
if use_char_lora: |
|
char_lora = os.path.join(LORAS_DIR, char_lora + '.safetensors') |
|
st.time('load character lora') |
|
with lora_lock: |
|
self.pipe.load_lora_weights(char_lora) |
|
|
st.time("base model inference") |
|
inferences = self.pipe( |
|
prompt_embeds=conditioning, |
|
pooled_prompt_embeds=pooled, |
|
generator=generator, |
|
cross_attention_kwargs=cross_attention_kwargs, |
|
**model_args |
|
).images |
|
|
if use_style_lora and use_char_lora: |
|
st.time("unfuse lora weights") |
|
self.pipe.unfuse_lora(unfuse_text_encoder=False) |
|
|
|
if use_style_lora or use_char_lora: |
|
st.time("unload lora weights") |
|
self.pipe.unload_lora_weights() |
|
|
|
st.time('end') |
|
|
|
|
|
axiom_logger.info("Generated images", request_id=self.req_id, device=self.device_str, timings=st.to_str()) |
|
return inferences |
|
|
|
def postprocess(self, inference_outputs): |
|
"""Post Process Function converts the generated image into Torchserve readable format. |
|
Args: |
|
inference_outputs (list): It contains the generated image of the input text. |
|
Returns: |
|
(list): Returns a list of the images. |
|
""" |
|
bucket_name = "outputs-storage-prod" |
|
client = storage.Client() |
|
self.working = False |
|
bucket = client.get_bucket(bucket_name) |
|
outputs = [] |
|
for image in inference_outputs: |
|
image_name = str(uuid.uuid4()) |
|
|
|
blob = bucket.blob(image_name + '.png') |
|
|
|
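            # Write the PNG to a scratch file, then stream it to the bucket.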
with TemporaryFile() as tmp: |
|
                image.save(tmp, format="PNG")
|
tmp.seek(0) |
|
blob.upload_from_file(tmp, content_type='image/png') |
|
|
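            # Assumes the bucket's objects are publicly readable at this URL.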
url_name = 'https://storage.googleapis.com/' + bucket_name + '/' + image_name + '.png' |
|
outputs.append(url_name) |
|
axiom_logger.info("Pushed image to google cloud: "+ url_name, request_id=self.req_id, device=self.device_str) |
|
return outputs |
|
|
|
|
|
app = Flask(__name__) |
|
|
|
|
|
gpu_count = torch.cuda.device_count() |
|
if gpu_count == 0: |
|
raise ValueError("No GPUs available!") |
|
|
|
# One handler per GPU, each pinned to its own device at startup.
handlers = [DiffusersHandler() for _ in range(gpu_count)]
for i in range(gpu_count):
    handlers[i].initialize({"gpu_id": i})
|
|
@app.route('/generate', methods=['POST']) |
|
def generate_image(): |
|
req_id = str(uuid.uuid4()) |
|
global handler_index |
|
selected_handler = None |
|
try: |
|
|
|
raw_requests = request.json |
|
axiom_logger.info(message="Received request", request_id=req_id, **raw_requests) |
|
|
|
        # Rotate through the per-GPU handlers; run a GC pass once per full rotation.
        with handler_lock:
            if handler_index == 0:
                gc.collect()
            selected_handler = handlers[handler_index]
            handler_index = (handler_index + 1) % gpu_count
|
selected_handler.req_id = req_id |
|
|
|
processed_request = selected_handler.preprocess([raw_requests]) |
|
inferences = selected_handler.inference(processed_request) |
|
outputs = selected_handler.postprocess(inferences) |
|
selected_handler.req_id = None |
|
return jsonify({"image_urls": outputs}) |
|
except Exception as e: |
|
logger.error("Error during image generation: %s", str(e)) |
|
        axiom_logger.critical(
            "Error during image generation: " + str(e),
            request_id=req_id,
            device=selected_handler.device_str if selected_handler else None,
        )
|
return jsonify({"error": "Failed to generate image", "details": str(e)}), 500 |
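
# Example client call (a sketch: the payload fields mirror preprocess() above,
# the port matches app.run() below, and the LoRA names are hypothetical):
#
#   import requests
#
#   resp = requests.post("http://localhost:3000/generate", json={
#       "prompt": "a castle on a hill at sunset",
#       "negative_prompt": "blurry, low quality",
#       "width": 1024,
#       "height": 1024,
#       "num_inference_steps": 30,
#       "guidance_scale": 8.5,
#       "seed": 42,
#       "style_lora": "watercolor",  # hypothetical file: ./safetensors/watercolor.safetensors
#       "style_scale": 0.8,
#   })
#   print(resp.json()["image_urls"])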
|
|
|
if __name__ == '__main__': |
|
app.run(host='0.0.0.0', port=3000, threaded=True) |
|
|