zww commited on
Commit
4929721
·
1 Parent(s): 0728495

gcs upload in handler

Browse files
Files changed (1) hide show
  1. stable_diffusion_handler.py +50 -17
stable_diffusion_handler.py CHANGED
@@ -8,6 +8,13 @@ from diffusers import StableDiffusionXLPipeline
8
  from ts.torch_handler.base_handler import BaseHandler
9
  import numpy as np
10
 
 
 
 
 
 
 
 
11
 
12
  logger = logging.getLogger(__name__)
13
  logger.info("Diffusers version %s", diffusers.__version__)
@@ -37,17 +44,21 @@ class DiffusersHandler(BaseHandler, ABC):
37
  device_str = "cuda:" + str(properties.get("gpu_id")) if torch.cuda.is_available() and properties.get("gpu_id") is not None else "cpu"
38
 
39
  self.device = torch.device(device_str)
40
- self.pipe = StableDiffusionXLPipeline.from_pretrained("./")
 
 
 
 
41
 
42
  logger.info("moving model to device: %s", device_str)
43
  self.pipe.to(self.device)
44
-
45
  logger.info(self.device)
46
  logger.info("Diffusion model from path %s loaded successfully", model_dir)
47
 
48
  self.initialized = True
49
 
50
- def preprocess(self, requests):
51
  """Basic text preprocessing, of the user's prompt.
52
  Args:
53
  requests (str): The Input data in the form of text is passed on to the preprocess
@@ -55,16 +66,22 @@ class DiffusersHandler(BaseHandler, ABC):
55
  Returns:
56
  list : The preprocess function returns a list of prompts.
57
  """
58
- logger.info("Received requests: '%s'", requests)
59
-
60
- text = requests[0]["prompt"]
61
-
62
- logger.info("pre-processed text: '%s'", text)
63
-
64
- return [text]
65
-
 
 
 
 
 
 
66
 
67
- def inference(self, inputs):
68
  """Generates the image relevant to the received text.
69
  Args:
70
  inputs (list): List of Text from the pre-process function is passed here
@@ -74,7 +91,7 @@ class DiffusersHandler(BaseHandler, ABC):
74
 
75
  # Handling inference for sequence_classification.
76
  inferences = self.pipe(
77
- inputs, guidance_scale=7.5, num_inference_steps=50
78
  ).images
79
 
80
  logger.info("Generated image: '%s'", inferences)
@@ -87,7 +104,23 @@ class DiffusersHandler(BaseHandler, ABC):
87
  Returns:
88
  (list): Returns a list of the images.
89
  """
90
- images = []
91
- for image in inference_output:
92
- images.append(np.array(image).tolist())
93
- return images
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  from ts.torch_handler.base_handler import BaseHandler
9
  import numpy as np
10
 
11
+ import base64
12
+ from io import BytesIO
13
+ from PIL import Image
14
+ import numpy as np
15
+ import uuid
16
+ from tempfile import TemporaryFile
17
+ from google.cloud import storage
18
 
19
  logger = logging.getLogger(__name__)
20
  logger.info("Diffusers version %s", diffusers.__version__)
 
44
  device_str = "cuda:" + str(properties.get("gpu_id")) if torch.cuda.is_available() and properties.get("gpu_id") is not None else "cpu"
45
 
46
  self.device = torch.device(device_str)
47
+ self.pipe = StableDiffusionXLPipeline.from_pretrained(
48
+ "./",
49
+ torch_dtype=torch.float16,
50
+ use_safetensors=True,
51
+ )
52
 
53
  logger.info("moving model to device: %s", device_str)
54
  self.pipe.to(self.device)
55
+
56
  logger.info(self.device)
57
  logger.info("Diffusion model from path %s loaded successfully", model_dir)
58
 
59
  self.initialized = True
60
 
61
+ def preprocess(self, raw_requests):
62
  """Basic text preprocessing, of the user's prompt.
63
  Args:
64
  requests (str): The Input data in the form of text is passed on to the preprocess
 
66
  Returns:
67
  list : The preprocess function returns a list of prompts.
68
  """
69
+ logger.info("Received requests: '%s'", raw_requests)
70
+
71
+ processed_request = {
72
+ "prompt": raw_requests[0]["prompt"],
73
+ "negative_prompt": raw_requests[0].get("negative_prompt"),
74
+ "width": raw_requests[0].get("width"),
75
+ "height": raw_requests[0].get("height"),
76
+ "num_inference_steps": raw_requests[0].get("num_inference_steps", 30),
77
+ "guidance_scale": raw_requests[0].get("guidance_scale", 7.5),
78
+ }
79
+
80
+ logger.info("Processed request: '%s'", processed_request)
81
+ return processed_request
82
+
83
 
84
+ def inference(self, request):
85
  """Generates the image relevant to the received text.
86
  Args:
87
  inputs (list): List of Text from the pre-process function is passed here
 
91
 
92
  # Handling inference for sequence_classification.
93
  inferences = self.pipe(
94
+ **request
95
  ).images
96
 
97
  logger.info("Generated image: '%s'", inferences)
 
104
  Returns:
105
  (list): Returns a list of the images.
106
  """
107
+ bucket_name = "outputs-storage-prod"
108
+ client = storage.Client()
109
+ bucket = client.get_bucket(bucket_name)
110
+ outputs = []
111
+ for image in inference_output.images:
112
+ image_name = str(uuid.uuid4())
113
+
114
+ blob = bucket.blob(image_name + '.png')
115
+
116
+ with TemporaryFile() as tmp:
117
+ image.save(tmp, format="png")
118
+ tmp.seek(0)
119
+ blob.upload_from_file(tmp, content_type='image/png')
120
+
121
+ # generate txt file with the image name and the prompt inside
122
+ blob = bucket.blob(image_name + '.txt')
123
+ blob.upload_from_string(self.prompt)
124
+
125
+ outputs.append('https://storage.googleapis.com/' + bucket_name + '/' + image_name + '.png')
126
+ return outputs