|
import logging |
|
from abc import ABC |
|
|
|
import diffusers |
|
import torch |
|
from diffusers import StableDiffusionXLPipeline |
|
|
|
from ts.torch_handler.base_handler import BaseHandler |
|
import numpy as np |
|
|
|
import base64 |
|
from io import BytesIO |
|
from PIL import Image |
|
import numpy as np |
|
import uuid |
|
from tempfile import TemporaryFile |
|
from google.cloud import storage |
|
|
|
# Module-level logger used by the handler below; the installed diffusers
# version is logged once, when this module is imported.
logger = logging.getLogger(__name__)
logger.info("Diffusers version %s", diffusers.__version__)
|
|
|
class DiffusersHandler(BaseHandler, ABC):
    """Diffusers handler class for text to image generation.

    The handler loads a ``StableDiffusionXLPipeline`` in ``initialize``, turns
    the first request of a batch into pipeline keyword arguments, runs the
    pipeline and uploads every generated image to a Google Cloud Storage
    bucket, answering with the URLs of the uploaded objects.
    """

    # Name of the bucket that receives the generated PNG files.
    OUTPUT_BUCKET_NAME = "outputs-storage-prod"

    def __init__(self):
        # Let the base handler set up its own state first; the inherited
        # request-handling code is expected to rely on it.
        super().__init__()
        self.initialized = False
        # Bucket handle, created lazily on the first upload and then reused.
        self._output_bucket = None

    def initialize(self, ctx):
        """Load the Stable Diffusion XL pipeline and move it to the device.

        Args:
            ctx (context): Serving context. ``ctx.manifest`` is stored on the
                handler and ``ctx.system_properties`` is read for the
                ``model_dir`` and ``gpu_id`` entries.
        """
        logger.info("Loading diffusion model")
        logger.info("I'm totally new and updated")

        self.manifest = ctx.manifest
        properties = ctx.system_properties
        model_dir = properties.get("model_dir")

        # Use the GPU assigned to this worker when there is one, else the CPU.
        use_cuda = torch.cuda.is_available() and properties.get("gpu_id") is not None
        device_str = "cuda:" + str(properties.get("gpu_id")) if use_cuda else "cpu"
        self.device = torch.device(device_str)

        # Half precision is only requested on GPU: Diffusers warns that
        # float16 pipelines cannot run on a CPU device, so fall back to
        # float32 when no GPU is used.
        torch_dtype = torch.float16 if use_cuda else torch.float32

        # NOTE(review): the weights are read from the current working
        # directory ("./"), not from model_dir, although the success message
        # below reports model_dir -- confirm the worker's cwd holds the model.
        self.pipe = StableDiffusionXLPipeline.from_pretrained(
            "./",
            torch_dtype=torch_dtype,
            use_safetensors=True,
        )

        logger.info("moving model to device: %s", device_str)
        self.pipe.to(self.device)

        logger.info(self.device)
        logger.info("Diffusion model from path %s loaded successfully", model_dir)

        self.initialized = True

    def preprocess(self, raw_requests):
        """Build the pipeline keyword arguments from the first request.

        Args:
            raw_requests (list): Batch of incoming requests. Only the first
                element is used; it must support ``["prompt"]`` and ``.get``
                for the optional generation settings.

        Returns:
            dict: Keyword arguments for the pipeline call, with the keys
            ``prompt``, ``negative_prompt``, ``width``, ``height``,
            ``num_inference_steps`` (default 30) and ``guidance_scale``
            (default 7.5). Absent optional entries are passed on as ``None``.

        Raises:
            KeyError: If the first request has no ``"prompt"`` entry.
        """
        logger.info("Received requests: '%s'", raw_requests)

        # Make the dropped requests visible instead of ignoring them silently.
        if len(raw_requests) > 1:
            logger.warning(
                "Received a batch of %d requests; only the first one is processed",
                len(raw_requests),
            )

        request = raw_requests[0]
        processed_request = {
            "prompt": request["prompt"],
            "negative_prompt": request.get("negative_prompt"),
            "width": request.get("width"),
            "height": request.get("height"),
            "num_inference_steps": request.get("num_inference_steps", 30),
            "guidance_scale": request.get("guidance_scale", 7.5),
        }

        logger.info("Processed request: '%s'", processed_request)
        return processed_request

    def inference(self, request):
        """Run the diffusion pipeline on a preprocessed request.

        Args:
            request (dict): Keyword arguments produced by ``preprocess``; they
                are forwarded unchanged to the pipeline call.

        Returns:
            list: The ``images`` attribute of the pipeline output (presumably
            PIL images, since ``postprocess`` calls ``image.save`` on each).
        """
        inferences = self.pipe(**request).images

        logger.info("Generated image: '%s'", inferences)
        return inferences

    def postprocess(self, inference_outputs):
        """Upload the generated images as PNG files and return their URLs.

        Args:
            inference_outputs (list): Generated images; each one must provide
                ``save(file, format=...)``.

        Returns:
            list: One ``https://storage.googleapis.com/<bucket>/<uuid>.png``
            URL per image, in the order of ``inference_outputs``.
        """
        bucket = self._get_output_bucket()
        outputs = []
        for image in inference_outputs:
            # Random object names keep concurrent requests from colliding.
            blob_name = str(uuid.uuid4()) + ".png"
            blob = bucket.blob(blob_name)

            with TemporaryFile() as tmp:
                image.save(tmp, format="png")
                # Rewind so the upload starts at the beginning of the file.
                tmp.seek(0)
                blob.upload_from_file(tmp, content_type="image/png")

            # NOTE(review): this URL is only readable by clients if the
            # bucket/object grants them access -- confirm the bucket policy.
            outputs.append(
                "https://storage.googleapis.com/"
                + self.OUTPUT_BUCKET_NAME
                + "/"
                + blob_name
            )
        return outputs

    def _get_output_bucket(self):
        """Return the output bucket, creating the storage client only once."""
        bucket = getattr(self, "_output_bucket", None)
        if bucket is None:
            # Building a client and fetching the bucket costs a network round
            # trip, so do it once per worker instead of once per request. The
            # handle is only cached after a successful lookup, so a failed
            # attempt is retried on the next request.
            bucket = storage.Client().get_bucket(self.OUTPUT_BUCKET_NAME)
            self._output_bucket = bucket
        return bucket
|
|