"""Flask service for SDXL text-to-image generation, one pipeline per GPU.

Each request is routed round-robin to a ``DiffusersHandler`` bound to one GPU.
Optional style / character LoRAs from ``LORAS_DIR`` are applied. Generated PNGs
are uploaded to Google Cloud Storage and their URLs are returned.
"""
import base64
import copy
import gc
import logging
import mmap
import os
import sys
import threading
import uuid
from abc import ABC
from io import BytesIO
from tempfile import TemporaryFile

import diffusers
import numpy as np
import sentry_sdk
import torch
from compel import Compel, ReturnedEmbeddingsType
from diffusers import StableDiffusionXLPipeline, DiffusionPipeline
from flask import Flask, request, jsonify
from google.cloud import storage
from PIL import Image
from safetensors.torch import load_file

from axiom_logger import AxiomLogger
from sequential_timer import SequentialTimer

logger = logging.getLogger(__name__)
logger.info("Diffusers version %s", diffusers.__version__)

axiom_logger = AxiomLogger()

# The DSN can be overridden through the environment. The historical value is
# kept as the default so existing deployments behave the same.
sentry_sdk.init(
    dsn=os.environ.get(
        "SENTRY_DSN",
        "https://f750d1b039d66541f344ee6151d38166@o4505891057696768.ingest.sentry.io/4506071735205888",
    ),
)

LORAS_DIR = './safetensors'

# Serializes load_lora_weights calls across all handlers.
lora_lock = threading.Lock()
# Guards the round-robin cursor `handler_index`.
handler_lock = threading.Lock()
handler_index = 0


def _resolve_lora_path(lora_name):
    """Map a request-supplied LoRA name to ``LORAS_DIR/<name>.safetensors``.

    The name comes straight from the HTTP request body. It is therefore
    checked (lexically, without following symlinks) so it cannot point outside
    ``LORAS_DIR``, e.g. via ``..`` components or an absolute path.

    Raises:
        ValueError: If the resulting path escapes ``LORAS_DIR``.
    """
    lora_path = os.path.join(LORAS_DIR, lora_name + '.safetensors')
    root = os.path.abspath(LORAS_DIR)
    if os.path.commonpath([root, os.path.abspath(lora_path)]) != root:
        raise ValueError("Invalid LoRA name: %r" % (lora_name,))
    return lora_path


class DiffusersHandler(ABC):
    """Diffusers handler class for text to image generation.

    One instance owns one SDXL pipeline on one device. Requests load LoRA
    weights into (and later remove them from) that shared pipeline, and also
    store per-request state (``req_id``, ``prompt``) on the instance. Callers
    must therefore hold ``self.lock`` for the whole
    preprocess -> inference -> postprocess sequence.
    """

    def __init__(self):
        self.initialized = False
        self.req_id = None
        # Serializes requests routed to this handler (see class docstring).
        self.lock = threading.Lock()

    def initialize(self, properties):
        """Load the SDXL pipeline and the Compel prompt encoder onto a device.

        The model is loaded from the path given as ``sys.argv[1]``.

        Args:
            properties (dict): Handler configuration. ``properties["gpu_id"]``
                selects ``cuda:<gpu_id>``. When the key is missing/None or CUDA
                is unavailable, the handler runs on the CPU.
        """
        logger.info("Loading diffusion model")
        logger.info("I'm totally new and updated")
        gpu_id = properties.get("gpu_id")
        if torch.cuda.is_available() and gpu_id is not None:
            device_str = "cuda:" + str(gpu_id)
        else:
            device_str = "cpu"
        self.device_str = device_str
        print("my device is " + device_str)
        self.device = torch.device(device_str)

        model_path = sys.argv[1]
        self.pipe = StableDiffusionXLPipeline.from_pretrained(
            model_path,
            torch_dtype=torch.float16,
            use_safetensors=True,
        )

        # Compel drives both SDXL tokenizer/text-encoder pairs. Pooled
        # embeddings are requested only from the second encoder.
        self.compel_base = Compel(
            tokenizer=[self.pipe.tokenizer, self.pipe.tokenizer_2],
            text_encoder=[self.pipe.text_encoder, self.pipe.text_encoder_2],
            returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
            requires_pooled=[False, True],
        )
        logger.info("Compel initialized")

        logger.info("moving base model to device: %s", device_str)
        self.pipe.to(self.device)
        logger.info(self.device)
        logger.info("Diffusion model from path %s loaded successfully", model_path)
        axiom_logger.info("Diffusion model initialized", device=self.device_str)
        self.initialized = True

    def preprocess(self, raw_requests):
        """Split the first raw request into pipeline kwargs and handler extras.

        Args:
            raw_requests (list[dict]): Request payloads. Only the first one is
                used. ``"prompt"`` is required; every other key is optional.

        Returns:
            tuple[dict, dict]: ``(model_args, extra_args)``.
            - ``model_args`` is forwarded to the pipeline.
            - ``extra_args`` holds the seed and the LoRA selection/scales.

        Raises:
            KeyError: If the payload has no ``"prompt"``.
        """
        logger.info("Received requests: '%s'", raw_requests)
        self.working = True
        req = raw_requests[0]
        model_args = {
            "prompt": req["prompt"],
            "negative_prompt": req.get("negative_prompt"),
            "width": req.get("width"),
            "height": req.get("height"),
            "num_inference_steps": req.get("num_inference_steps", 30),
            "guidance_scale": req.get("guidance_scale", 8.5),
        }
        extra_args = {
            "seed": req.get("seed", None),
            "style_lora": req.get("style_lora", None),
            "style_scale": req.get("style_scale", 1.0),
            "char_lora": req.get("char_lora", None),
            "char_scale": req.get("char_scale", 1.0),
        }
        logger.info("Processed request: '%s'", model_args)
        axiom_logger.info("Processed request:" + str(model_args), request_id=self.req_id, device=self.device_str)
        return model_args, extra_args

    def inference(self, request):
        """Run the SDXL pipeline, optionally with a style and/or character LoRA.

        LoRA behavior:
        - Both LoRAs given: the style LoRA is fused into the weights at
          ``style_scale``, and the character LoRA is applied at ``char_scale``
          through ``cross_attention_kwargs``.
        - One LoRA given: it is applied at its own scale.
        - All LoRA weights are removed from the pipeline afterwards, also when
          loading or generation fails. This prevents them from leaking into
          the next request served by this handler.

        Args:
            request (tuple[dict, dict]): ``(model_args, extra_args)`` as
                returned by :meth:`preprocess`. ``model_args["prompt"]`` is
                popped.

        Returns:
            list: The ``.images`` produced by the pipeline.

        Raises:
            ValueError: If a LoRA name resolves outside ``LORAS_DIR``.
        """
        st = SequentialTimer()
        model_args, extra_args = request
        style_lora = extra_args['style_lora']
        char_lora = extra_args['char_lora']
        use_style_lora = style_lora is not None
        use_char_lora = char_lora is not None
        cross_attention_kwargs = {
            "scale": extra_args['char_scale'] if use_char_lora else extra_args['style_scale'],
        }

        # `is not None` so that seed=0 is honored. The generator is created on
        # this handler's device so it matches the device the latents are on.
        seed = extra_args['seed']
        if seed is not None:
            generator = torch.Generator(device=self.device).manual_seed(int(seed))
        else:
            generator = None

        self.prompt = model_args.pop("prompt")
        st.time("Base compel embedding")
        conditioning, pooled = self.compel_base(self.prompt)

        # Flags are set *before* each mutating call, so a partially failed
        # load/fuse is still cleaned up in `finally`.
        lora_loaded = False
        lora_fused = False
        try:
            if use_style_lora:
                style_path = _resolve_lora_path(style_lora)
                st.time("Load style lora")
                lora_loaded = True
                with lora_lock:
                    self.pipe.load_lora_weights(style_path)
                if use_char_lora:
                    st.time("Fuse style lora into model")
                    lora_fused = True
                    self.pipe.fuse_lora(lora_scale=extra_args['style_scale'], fuse_text_encoder=False)

            if use_char_lora:
                char_path = _resolve_lora_path(char_lora)
                st.time('load character lora')
                lora_loaded = True
                with lora_lock:
                    self.pipe.load_lora_weights(char_path)

            st.time("base model inference")
            inferences = self.pipe(
                prompt_embeds=conditioning,
                pooled_prompt_embeds=pooled,
                generator=generator,
                cross_attention_kwargs=cross_attention_kwargs,
                **model_args
            ).images
        finally:
            if lora_fused:
                st.time("unfuse lora weights")
                self.pipe.unfuse_lora(unfuse_text_encoder=False)
            if lora_loaded:
                st.time("unload lora weights")
                self.pipe.unload_lora_weights()

        st.time('end')
        axiom_logger.info("Generated images", request_id=self.req_id, device=self.device_str, timings=st.to_str())
        return inferences

    def postprocess(self, inference_outputs):
        """Upload each generated image to GCS as a PNG and collect its URL.

        Args:
            inference_outputs (list): Images returned by :meth:`inference`
                (objects exposing ``.save(fp, format=...)``).

        Returns:
            list[str]: One ``https://storage.googleapis.com/<bucket>/<uuid>.png``
            URL per image, in input order.
        """
        bucket_name = "outputs-storage-prod"
        client = storage.Client()
        self.working = False
        bucket = client.get_bucket(bucket_name)
        outputs = []
        for image in inference_outputs:
            image_name = str(uuid.uuid4())
            blob = bucket.blob(image_name + '.png')
            # Encode to a temporary file, rewind, then stream it to the blob.
            with TemporaryFile() as tmp:
                image.save(tmp, format="png")
                tmp.seek(0)
                blob.upload_from_file(tmp, content_type='image/png')
            url_name = 'https://storage.googleapis.com/' + bucket_name + '/' + image_name + '.png'
            outputs.append(url_name)
            axiom_logger.info("Pushed image to google cloud: " + url_name, request_id=self.req_id, device=self.device_str)
        return outputs


app = Flask(__name__)

# Initialize one handler per visible GPU on startup.
gpu_count = torch.cuda.device_count()
if gpu_count == 0:
    raise ValueError("No GPUs available!")
handlers = [DiffusersHandler() for _ in range(gpu_count)]
for gpu_id, handler in enumerate(handlers):
    handler.initialize({"gpu_id": gpu_id})


@app.route('/generate', methods=['POST'])
def generate_image():
    """Generate images for the JSON body of a POST request.

    Handlers are picked round-robin. Each handler's lock is held for the whole
    request, because handlers are not safe for concurrent use.

    Returns:
        ``{"image_urls": [...]}`` on success; otherwise
        ``({"error": ..., "details": ...}, 500)``.
    """
    global handler_index
    req_id = str(uuid.uuid4())
    selected_handler = None
    try:
        # Extract raw requests from HTTP POST body.
        raw_requests = request.json
        # Drop keys that would collide with the logger's own keyword arguments.
        log_fields = {k: v for k, v in raw_requests.items() if k not in ("message", "request_id")}
        axiom_logger.info(message="Received request", request_id=req_id, **log_fields)

        with handler_lock:
            # Periodic GC: once per full rotation over all handlers.
            if handler_index == 0:
                gc.collect()
            selected_handler = handlers[handler_index]
            handler_index = (handler_index + 1) % gpu_count  # Rotate to the next handler

        with selected_handler.lock:
            selected_handler.req_id = req_id
            try:
                processed_request = selected_handler.preprocess([raw_requests])
                inferences = selected_handler.inference(processed_request)
                outputs = selected_handler.postprocess(inferences)
            finally:
                selected_handler.req_id = None
        return jsonify({"image_urls": outputs})
    except Exception as e:
        logger.exception("Error during image generation: %s", str(e))
        # The failure may happen before a handler was selected.
        device = selected_handler.device_str if selected_handler is not None else None
        axiom_logger.critical("Error during image generation: " + str(e), request_id=req_id, device=device)
        return jsonify({"error": "Failed to generate image", "details": str(e)}), 500


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=3000, threaded=True)