zww
committed on
Commit
·
4929721
1
Parent(s):
0728495
gcs upload in handler
Browse files- stable_diffusion_handler.py +50 -17
stable_diffusion_handler.py
CHANGED
@@ -8,6 +8,13 @@ from diffusers import StableDiffusionXLPipeline
|
|
8 |
from ts.torch_handler.base_handler import BaseHandler
|
9 |
import numpy as np
|
10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
logger = logging.getLogger(__name__)
|
13 |
logger.info("Diffusers version %s", diffusers.__version__)
|
@@ -37,17 +44,21 @@ class DiffusersHandler(BaseHandler, ABC):
|
|
37 |
device_str = "cuda:" + str(properties.get("gpu_id")) if torch.cuda.is_available() and properties.get("gpu_id") is not None else "cpu"
|
38 |
|
39 |
self.device = torch.device(device_str)
|
40 |
-
self.pipe = StableDiffusionXLPipeline.from_pretrained(
|
|
|
|
|
|
|
|
|
41 |
|
42 |
logger.info("moving model to device: %s", device_str)
|
43 |
self.pipe.to(self.device)
|
44 |
-
|
45 |
logger.info(self.device)
|
46 |
logger.info("Diffusion model from path %s loaded successfully", model_dir)
|
47 |
|
48 |
self.initialized = True
|
49 |
|
50 |
-
def preprocess(self,
|
51 |
"""Basic text preprocessing, of the user's prompt.
|
52 |
Args:
|
53 |
requests (str): The Input data in the form of text is passed on to the preprocess
|
@@ -55,16 +66,22 @@ class DiffusersHandler(BaseHandler, ABC):
|
|
55 |
Returns:
|
56 |
list : The preprocess function returns a list of prompts.
|
57 |
"""
|
58 |
-
logger.info("Received requests: '%s'",
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
|
67 |
-
def inference(self,
|
68 |
"""Generates the image relevant to the received text.
|
69 |
Args:
|
70 |
inputs (list): List of Text from the pre-process function is passed here
|
@@ -74,7 +91,7 @@ class DiffusersHandler(BaseHandler, ABC):
|
|
74 |
|
75 |
# Handling inference for sequence_classification.
|
76 |
inferences = self.pipe(
|
77 |
-
|
78 |
).images
|
79 |
|
80 |
logger.info("Generated image: '%s'", inferences)
|
@@ -87,7 +104,23 @@ class DiffusersHandler(BaseHandler, ABC):
|
|
87 |
Returns:
|
88 |
(list): Returns a list of the images.
|
89 |
"""
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
from ts.torch_handler.base_handler import BaseHandler
|
9 |
import numpy as np
|
10 |
|
11 |
+
import base64
|
12 |
+
from io import BytesIO
|
13 |
+
from PIL import Image
|
14 |
+
import numpy as np
|
15 |
+
import uuid
|
16 |
+
from tempfile import TemporaryFile
|
17 |
+
from google.cloud import storage
|
18 |
|
19 |
logger = logging.getLogger(__name__)
|
20 |
logger.info("Diffusers version %s", diffusers.__version__)
|
|
|
44 |
device_str = "cuda:" + str(properties.get("gpu_id")) if torch.cuda.is_available() and properties.get("gpu_id") is not None else "cpu"
|
45 |
|
46 |
self.device = torch.device(device_str)
|
47 |
+
self.pipe = StableDiffusionXLPipeline.from_pretrained(
|
48 |
+
"./",
|
49 |
+
torch_dtype=torch.float16,
|
50 |
+
use_safetensors=True,
|
51 |
+
)
|
52 |
|
53 |
logger.info("moving model to device: %s", device_str)
|
54 |
self.pipe.to(self.device)
|
55 |
+
|
56 |
logger.info(self.device)
|
57 |
logger.info("Diffusion model from path %s loaded successfully", model_dir)
|
58 |
|
59 |
self.initialized = True
|
60 |
|
61 |
+
def preprocess(self, raw_requests):
    """Build keyword arguments for the diffusion pipeline from a request batch.

    Only the first request in the batch is used (effective batch size 1).

    Args:
        raw_requests (list[dict]): Incoming request payloads; each payload
            must contain "prompt" and may contain "negative_prompt",
            "width", "height", "num_inference_steps" and "guidance_scale".

    Returns:
        dict: Keyword arguments to be splatted into the
        StableDiffusionXLPipeline call (fixed docstring: the original
        claimed a list of prompts was returned, but a dict is returned).
    """
    logger.info("Received requests: '%s'", raw_requests)

    # Hoist the repeated raw_requests[0] lookups.
    request = raw_requests[0]
    processed_request = {
        "prompt": request["prompt"],
        # .get() yields None for absent keys; the pipeline treats None
        # as "use its own default" for these arguments.
        "negative_prompt": request.get("negative_prompt"),
        "width": request.get("width"),
        "height": request.get("height"),
        # Defaults: 30 denoising steps, classifier-free-guidance scale 7.5.
        "num_inference_steps": request.get("num_inference_steps", 30),
        "guidance_scale": request.get("guidance_scale", 7.5),
    }

    logger.info("Processed request: '%s'", processed_request)
    return processed_request
|
82 |
+
|
83 |
|
84 |
+
def inference(self, request):
|
85 |
"""Generates the image relevant to the received text.
|
86 |
Args:
|
87 |
inputs (list): List of Text from the pre-process function is passed here
|
|
|
91 |
|
92 |
# Handling inference for sequence_classification.
|
93 |
inferences = self.pipe(
|
94 |
+
**request
|
95 |
).images
|
96 |
|
97 |
logger.info("Generated image: '%s'", inferences)
|
|
|
def postprocess(self, inference_output):
    """Upload each generated image to Google Cloud Storage and return URLs.

    NOTE(review): the original `def` line falls outside the visible diff
    hunk; the signature is reconstructed from the body's sole use of
    `inference_output` — confirm against the full file.

    Args:
        inference_output: Pipeline output — either an object exposing an
            `.images` attribute or a bare list of PIL images (the visible
            `inference()` already takes `.images`, so a bare list is the
            likely shape at runtime).

    Returns:
        (list): Returns a list of public image URLs, one per image.
    """
    bucket_name = "outputs-storage-prod"
    client = storage.Client()
    bucket = client.get_bucket(bucket_name)

    # Fix: the original did `inference_output.images`, which raises
    # AttributeError when inference() hands over the bare image list.
    # Accept both shapes.
    images = getattr(inference_output, "images", inference_output)

    outputs = []
    for image in images:
        image_name = str(uuid.uuid4())

        blob = bucket.blob(image_name + '.png')
        # Stage the PNG in a temp file so upload_from_file can stream it.
        with TemporaryFile() as tmp:
            image.save(tmp, format="png")
            tmp.seek(0)
            blob.upload_from_file(tmp, content_type='image/png')

        # generate txt file with the image name and the prompt inside
        # Fix: `self.prompt` is never assigned anywhere visible, so the
        # original would raise AttributeError here; fall back to "".
        # TODO(review): confirm where the prompt should be stashed
        # (e.g. saved on self in preprocess()).
        blob = bucket.blob(image_name + '.txt')
        blob.upload_from_string(getattr(self, "prompt", ""))

        outputs.append('https://storage.googleapis.com/' + bucket_name + '/' + image_name + '.png')
    return outputs
|