|
import logging
import sys
import threading
import uuid
from abc import ABC
from tempfile import TemporaryFile

import diffusers
import torch
from compel import Compel, ReturnedEmbeddingsType
from diffusers import StableDiffusionXLPipeline
from flask import Flask, request, jsonify
from google.cloud import storage

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logger.info("Diffusers version %s", diffusers.__version__)
|
|
|
class DiffusersHandler(ABC):
    """
    Diffusers handler class for text-to-image generation.
    """

    def __init__(self):
        self.initialized = False
|
    def initialize(self, properties):
        """Load and initialize the Stable Diffusion XL model.

        Args:
            properties (dict): Configuration for the model, in particular
                the id of the GPU to load it onto.
        """
        logger.info("Loading diffusion model")

        # Use the GPU named in the properties, falling back to CPU.
        device_str = (
            "cuda:" + str(properties.get("gpu_id"))
            if torch.cuda.is_available() and properties.get("gpu_id") is not None
            else "cpu"
        )
        logger.info("Using device %s", device_str)

        self.device = torch.device(device_str)
        # The model path (a local directory or a Hugging Face model id) is
        # taken from the first command-line argument.
        self.pipe = StableDiffusionXLPipeline.from_pretrained(
            sys.argv[1],
            torch_dtype=torch.float16,
            use_safetensors=True,
        )

        logger.info("Moving model to device: %s", device_str)
        self.pipe.to(self.device)

        logger.info("Diffusion model from path %s loaded successfully", sys.argv[1])

        self.initialized = True
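        # Note: the weights are loaded in float16, which is intended for GPU
        # execution; the "cpu" fallback above would generally need float32
        # to run reliably.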
|
|
|
    def preprocess(self, raw_requests):
        """Basic preprocessing of the user's prompt and generation parameters.

        Args:
            raw_requests (list): A list of request dicts; only the first
                request is used.

        Returns:
            dict: The processed request with the generation parameters.
        """
        logger.info("Received requests: '%s'", raw_requests)
        self.working = True

        # Parameters left as None fall through to the pipeline's own defaults.
        processed_request = {
            "prompt": raw_requests[0]["prompt"],
            "negative_prompt": raw_requests[0].get("negative_prompt"),
            "width": raw_requests[0].get("width"),
            "height": raw_requests[0].get("height"),
            "num_inference_steps": raw_requests[0].get("num_inference_steps", 30),
            "guidance_scale": raw_requests[0].get("guidance_scale", 7.5),
        }

        logger.info("Processed request: '%s'", processed_request)
        return processed_request
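    # Example request body handled by preprocess (illustrative):
    #   {"prompt": "a castle at dusk", "negative_prompt": "blurry",
    #    "width": 1024, "height": 1024, "num_inference_steps": 30}
    # Only "prompt" is required; the other fields are optional.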
|
|
|
|
|
    def inference(self, request):
        """Generates the image(s) for the processed request.

        Args:
            request (dict): The processed request from preprocess.

        Returns:
            list: The generated PIL images.
        """
        # Compel turns the prompt into embeddings for both SDXL text
        # encoders; the second encoder also provides the pooled embedding
        # the pipeline expects.
        compel = Compel(
            tokenizer=[self.pipe.tokenizer, self.pipe.tokenizer_2],
            text_encoder=[self.pipe.text_encoder, self.pipe.text_encoder_2],
            returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
            requires_pooled=[False, True],
        )

        # Remove the prompt from the request; the remaining keys are passed
        # straight through to the pipeline.
        self.prompt = request.pop("prompt")
        conditioning, pooled = compel(self.prompt)

        inferences = self.pipe(
            prompt_embeds=conditioning,
            pooled_prompt_embeds=pooled,
            **request,
        ).images

        logger.info("Generated image: '%s'", inferences)
        return inferences
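    # Compel also accepts weighted prompt syntax such as "a forest+" or
    # "(night sky)1.3", so clients can weight terms directly in the
    # "prompt" field without changes to this handler.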
|
|
|
    def postprocess(self, inference_outputs):
        """Uploads the generated images to Google Cloud Storage.

        Args:
            inference_outputs (list): The generated PIL images.

        Returns:
            list: The public URLs of the uploaded images.
        """
        bucket_name = "outputs-storage-prod"
        client = storage.Client()
        self.working = False
        bucket = client.get_bucket(bucket_name)
        outputs = []
        for image in inference_outputs:
            image_name = str(uuid.uuid4())
            blob = bucket.blob(image_name + '.png')

            # Write the PNG to a temporary file, then stream it to the bucket.
            with TemporaryFile() as tmp:
                image.save(tmp, format="PNG")
                tmp.seek(0)
                blob.upload_from_file(tmp, content_type='image/png')

            outputs.append(
                'https://storage.googleapis.com/' + bucket_name + '/' + image_name + '.png'
            )
        return outputs
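    # Note: the URL form returned above resolves only when objects in the
    # bucket are publicly readable; a private bucket would need signed URLs
    # instead.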
|
|
|
|
|
app = Flask(__name__)

gpu_count = torch.cuda.device_count()
if gpu_count == 0:
    raise ValueError("No GPUs available!")

# One handler, each holding its own pipeline, per available GPU.
handlers = [DiffusersHandler() for _ in range(gpu_count)]
for i in range(gpu_count):
    handlers[i].initialize({"gpu_id": i})

# Lock protecting the round-robin index used to dispatch requests to GPUs.
handler_lock = threading.Lock()
handler_index = 0
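# Note: the lock only serializes handler selection; if more concurrent
# requests arrive than there are GPUs, several requests can still end up
# on the same handler at once.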
|
|
|
@app.route('/generate', methods=['POST'])
def generate_image():
    global handler_index
    try:
        raw_requests = request.json

        # Rotate through the handlers; only the index update needs the lock,
        # so image generation itself can proceed concurrently across GPUs.
        with handler_lock:
            selected_handler = handlers[handler_index]
            handler_index = (handler_index + 1) % gpu_count

        processed_request = selected_handler.preprocess([raw_requests])
        inferences = selected_handler.inference(processed_request)
        outputs = selected_handler.postprocess(inferences)

        return jsonify({"image_urls": outputs})
    except Exception as e:
        logger.error("Error during image generation: %s", str(e))
        return jsonify({"error": "Failed to generate image", "details": str(e)}), 500
|
|
|
if __name__ == '__main__':
    app.run(host='0.0.0.0', port=3000, threaded=True)
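
# Example usage (illustrative; the script name is a placeholder): start the
# server with the model path or Hugging Face model id as the first argument,
# then POST a prompt:
#
#   python handler.py stabilityai/stable-diffusion-xl-base-1.0
#
#   curl -X POST http://localhost:3000/generate \
#        -H 'Content-Type: application/json' \
#        -d '{"prompt": "a watercolor fox in a pine forest",
#             "num_inference_steps": 25, "guidance_scale": 7.0}'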