MarkZakelj committed on
Commit
e85fe1a
·
1 Parent(s): c339505

loras serve

Browse files
Files changed (3) hide show
  1. gunicorn_config.py +12 -0
  2. serve_loras.py +25 -17
  3. serve_loras_prod.py +318 -0
gunicorn_config.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # gunicorn_config.py
2
+ import os
3
+
4
# Total number of workers ever forked by the master (kept for backward compatibility).
worker_id_counter = 0

# Number of live workers currently assigned to each GPU slot.
# Only meaningful in the gunicorn master process, where pre_fork/child_exit run.
_gpu_slot_load = {}

# Historical default: the original config always distributed workers over 4 GPUs.
_DEFAULT_GPU_COUNT = 4


def _gpu_count():
    """Return the number of GPU slots to distribute workers over.

    Read from the ``GPU_COUNT`` environment variable; falls back to 4 (the
    value that used to be hard-coded) when unset, non-numeric or non-positive.
    """
    try:
        count = int(os.environ.get('GPU_COUNT', _DEFAULT_GPU_COUNT))
    except ValueError:
        return _DEFAULT_GPU_COUNT
    return count if count > 0 else _DEFAULT_GPU_COUNT


def pre_fork(server, worker):
    """Gunicorn hook, called in the master just before a worker is forked.

    Picks the least-loaded GPU slot (lowest index wins ties) and records it on
    the worker object so the forked child can read it in ``post_fork``. For an
    initial start-up this yields 0, 1, 2, 3, 0, ... exactly like the previous
    ``counter % 4`` scheme, but a respawned worker now re-uses the slot freed
    by the worker that died instead of doubling up on another GPU.
    """
    global worker_id_counter
    worker_id_counter += 1

    slot = min(range(_gpu_count()), key=lambda s: (_gpu_slot_load.get(s, 0), s))
    _gpu_slot_load[slot] = _gpu_slot_load.get(slot, 0) + 1
    worker.gpu_slot = slot


def child_exit(server, worker):
    """Gunicorn hook, called in the master after a worker has exited.

    Releases the GPU slot that was reserved for the worker in ``pre_fork``.
    """
    slot = getattr(worker, 'gpu_slot', None)
    if slot is not None and _gpu_slot_load.get(slot, 0) > 0:
        _gpu_slot_load[slot] -= 1


def post_fork(server, worker):
    """Gunicorn hook, called in the worker right after the fork.

    Publishes the worker's GPU slot as the ``WORKER_ID`` environment variable,
    which the served application reads at import time.
    """
    slot = getattr(worker, 'gpu_slot', None)
    if slot is None:
        # No slot recorded (pre_fork was not run for this worker object):
        # fall back to the original counter-based assignment.
        slot = (worker_id_counter - 1) % _gpu_count()
    os.environ['WORKER_ID'] = str(slot)
serve_loras.py CHANGED
@@ -9,6 +9,7 @@ from diffusers import StableDiffusionXLPipeline, DiffusionPipeline
9
 
10
  import numpy as np
11
  import threading
 
12
 
13
  import base64
14
  from io import BytesIO
@@ -23,6 +24,7 @@ import os
23
  from sequential_timer import SequentialTimer
24
  from safetensors.torch import load_file
25
  import copy
 
26
 
27
  logger = logging.getLogger(__name__)
28
  logger.info("Diffusers version %s", diffusers.__version__)
@@ -36,23 +38,25 @@ sentry_sdk.init(
36
 
37
  LORAS_DIR = './safetensors'
38
 
 
 
39
  handler_lock = threading.Lock()
40
  handler_index = 0
41
 
42
- class LoraCache():
43
- def __init__(self, loras_dir: str = LORAS_DIR):
44
- self.loras_dir = loras_dir
45
- self.cache = {}
46
-
47
- def load_lora(self, lora_name: str):
48
- if lora_name.endswith('.safetensors'):
49
- lora_name = lora_name.rstrip('.safetensors')
50
- if lora_name not in self.cache:
51
- lora = load_file(os.path.join(self.loras_dir, lora_name+'.safetensors'))
52
- self.cache[lora_name] = lora
53
- return copy.deepcopy(self.cache[lora_name])
54
 
55
- lora_cache = LoraCache()
56
 
57
  class DiffusersHandler(ABC):
58
  """
@@ -134,7 +138,7 @@ class DiffusersHandler(ABC):
134
  "negative_prompt": raw_requests[0].get("negative_prompt"),
135
  "width": raw_requests[0].get("width"),
136
  "height": raw_requests[0].get("height"),
137
- "num_inference_steps": raw_requests[0].get("num_inference_steps", 25),
138
  "guidance_scale": raw_requests[0].get("guidance_scale", 8.5)
139
  # "lora_weights": raw_requests[0].get("lora_name", None)
140
  # "cross_attention_kwargs": {"scale": raw_requests[0].get("lora_scale", 0.0)}
@@ -167,6 +171,7 @@ class DiffusersHandler(ABC):
167
  # compel = Compel(tokenizer=[self.pipe.tokenizer, self.pipe.tokenizer_2] , text_encoder=[self.pipe.text_encoder, self.pipe.text_encoder_2], returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, requires_pooled=[False, True])
168
  st = SequentialTimer()
169
  model_args, extra_args = request
 
170
 
171
  use_char_lora = extra_args['char_lora'] is not None
172
  use_style_lora = extra_args['style_lora'] is not None
@@ -188,7 +193,8 @@ class DiffusersHandler(ABC):
188
  if use_style_lora:
189
  style_lora = os.path.join(LORAS_DIR, style_lora + '.safetensors')
190
  st.time("Load style lora")
191
- self.pipe.load_lora_weights(style_lora)
 
192
  if use_char_lora:
193
  st.time("Fuse style lora into model")
194
  self.pipe.fuse_lora(lora_scale=extra_args['style_scale'], fuse_text_encoder=False)
@@ -196,7 +202,8 @@ class DiffusersHandler(ABC):
196
  if use_char_lora:
197
  char_lora = os.path.join(LORAS_DIR, char_lora + '.safetensors')
198
  st.time('load character lora')
199
- self.pipe.load_lora_weights(char_lora)
 
200
 
201
  # lora_weights = model_args.pop("lora_weights")
202
  # if lora_weights is not None:
@@ -287,6 +294,8 @@ def generate_image():
287
  axiom_logger.info(message="Received request", request_id=req_id, **raw_requests)
288
 
289
  with handler_lock:
 
 
290
  selected_handler = handlers[handler_index]
291
  handler_index = (handler_index + 1) % gpu_count # Rotate to the next handler
292
  selected_handler.req_id = req_id
@@ -295,7 +304,6 @@ def generate_image():
295
  inferences = selected_handler.inference(processed_request)
296
  outputs = selected_handler.postprocess(inferences)
297
  selected_handler.req_id = None
298
-
299
  return jsonify({"image_urls": outputs})
300
  except Exception as e:
301
  logger.error("Error during image generation: %s", str(e))
 
9
 
10
  import numpy as np
11
  import threading
12
+ import mmap
13
 
14
  import base64
15
  from io import BytesIO
 
24
  from sequential_timer import SequentialTimer
25
  from safetensors.torch import load_file
26
  import copy
27
+ import gc
28
 
29
  logger = logging.getLogger(__name__)
30
  logger.info("Diffusers version %s", diffusers.__version__)
 
38
 
39
  LORAS_DIR = './safetensors'
40
 
41
+ lora_lock = threading.Lock()
42
+
43
  handler_lock = threading.Lock()
44
  handler_index = 0
45
 
46
+ # class LoraCache():
47
+ # def __init__(self, loras_dir: str = LORAS_DIR):
48
+ # self.loras_dir = loras_dir
49
+ # self.cache = {}
50
+
51
+ # def load_lora(self, lora_name: str):
52
+ # if lora_name.endswith('.safetensors'):
53
+ # lora_name = lora_name.rstrip('.safetensors')
54
+ # if lora_name not in self.cache:
55
+ # lora = load_file(os.path.join(self.loras_dir, lora_name+'.safetensors'))
56
+ # self.cache[lora_name] = lora
57
+ # return copy.deepcopy(self.cache[lora_name])
58
 
59
+ # lora_cache = LoraCache()
60
 
61
  class DiffusersHandler(ABC):
62
  """
 
138
  "negative_prompt": raw_requests[0].get("negative_prompt"),
139
  "width": raw_requests[0].get("width"),
140
  "height": raw_requests[0].get("height"),
141
+ "num_inference_steps": raw_requests[0].get("num_inference_steps", 30),
142
  "guidance_scale": raw_requests[0].get("guidance_scale", 8.5)
143
  # "lora_weights": raw_requests[0].get("lora_name", None)
144
  # "cross_attention_kwargs": {"scale": raw_requests[0].get("lora_scale", 0.0)}
 
171
  # compel = Compel(tokenizer=[self.pipe.tokenizer, self.pipe.tokenizer_2] , text_encoder=[self.pipe.text_encoder, self.pipe.text_encoder_2], returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, requires_pooled=[False, True])
172
  st = SequentialTimer()
173
  model_args, extra_args = request
174
+ global lora_cache
175
 
176
  use_char_lora = extra_args['char_lora'] is not None
177
  use_style_lora = extra_args['style_lora'] is not None
 
193
  if use_style_lora:
194
  style_lora = os.path.join(LORAS_DIR, style_lora + '.safetensors')
195
  st.time("Load style lora")
196
+ with lora_lock:
197
+ self.pipe.load_lora_weights(style_lora)
198
  if use_char_lora:
199
  st.time("Fuse style lora into model")
200
  self.pipe.fuse_lora(lora_scale=extra_args['style_scale'], fuse_text_encoder=False)
 
202
  if use_char_lora:
203
  char_lora = os.path.join(LORAS_DIR, char_lora + '.safetensors')
204
  st.time('load character lora')
205
+ with lora_lock:
206
+ self.pipe.load_lora_weights(char_lora)
207
 
208
  # lora_weights = model_args.pop("lora_weights")
209
  # if lora_weights is not None:
 
294
  axiom_logger.info(message="Received request", request_id=req_id, **raw_requests)
295
 
296
  with handler_lock:
297
+ if handler_index == 0:
298
+ gc.collect()
299
  selected_handler = handlers[handler_index]
300
  handler_index = (handler_index + 1) % gpu_count # Rotate to the next handler
301
  selected_handler.req_id = req_id
 
304
  inferences = selected_handler.inference(processed_request)
305
  outputs = selected_handler.postprocess(inferences)
306
  selected_handler.req_id = None
 
307
  return jsonify({"image_urls": outputs})
308
  except Exception as e:
309
  logger.error("Error during image generation: %s", str(e))
serve_loras_prod.py ADDED
@@ -0,0 +1,318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from compel import Compel, ReturnedEmbeddingsType
2
+ import logging
3
+ from abc import ABC
4
+ import uuid
5
+
6
+ import diffusers
7
+ import torch
8
+ from diffusers import StableDiffusionXLPipeline, DiffusionPipeline
9
+
10
+ import numpy as np
11
+ import threading
12
+
13
+ import base64
14
+ from io import BytesIO
15
+ from PIL import Image
16
+ import numpy as np
17
+ from tempfile import TemporaryFile
18
+ from google.cloud import storage
19
+ import sys
20
+ import sentry_sdk
21
+ from flask import Flask, request, jsonify, current_app
22
+ import os
23
+ from sequential_timer import SequentialTimer
24
+ from safetensors.torch import load_file
25
+ from dotenv import load_dotenv
26
+ import copy
27
+ import gc
28
+
29
+
30
# Load .env first so that everything initialised below (including the
# SENTRY_* variables that sentry_sdk.init reads from the environment) can
# see values defined there. Variables already set in the real environment
# are not overridden by load_dotenv().
load_dotenv()

logger = logging.getLogger(__name__)
logger.info("Diffusers version %s", diffusers.__version__)

from axiom_logger import AxiomLogger
axiom_logger = AxiomLogger()

# The DSN can be overridden with the SENTRY_DSN environment variable; the
# default is the DSN this service has always reported to.
sentry_sdk.init(
    dsn=os.environ.get(
        "SENTRY_DSN",
        "https://f750d1b039d66541f344ee6151d38166@o4505891057696768.ingest.sentry.io/4506071735205888",
    ),
)

# Directory containing the ``<name>.safetensors`` LoRA files.
LORAS_DIR = './safetensors'

# NOTE(review): this lock is defined but not acquired by any code in this
# file; presumably kept for parity with serve_loras.py - confirm.
lora_lock = threading.Lock()
45
+
46
+ # handler_lock = threading.Lock()
47
+ # handler_index = 0
48
+
49
+ # class LoraCache():
50
+ # def __init__(self, loras_dir: str = LORAS_DIR):
51
+ # self.loras_dir = loras_dir
52
+ # self.cache = {}
53
+
54
+ # def load_lora(self, lora_name: str):
55
+ # if lora_name.endswith('.safetensors'):
56
+ # lora_name = lora_name.rstrip('.safetensors')
57
+ # if lora_name not in self.cache:
58
+ # lora = load_file(os.path.join(self.loras_dir, lora_name+'.safetensors'))
59
+ # self.cache[lora_name] = lora
60
+ # return copy.deepcopy(self.cache[lora_name])
61
+
62
+ # lora_cache = LoraCache()
63
+
64
class DiffusersHandler(ABC):
    """
    Diffusers handler class for text to image generation.

    Wraps a Stable Diffusion XL pipeline pinned to one device. A request goes
    through ``preprocess`` -> ``inference`` -> ``postprocess``; optional style
    and character LoRAs are loaded per request and removed again afterwards.
    """

    def __init__(self):
        self.initialized = False
        self.req_id = None
        # Toggled by preprocess()/postprocess(); defined here so the attribute
        # exists even before the first request.
        self.working = False

    def initialize(self, properties):
        """Load the Stable Diffusion XL model and move it to the target device.

        Args:
            properties (dict): Configuration for this handler. The ``gpu_id``
                key selects the CUDA device; when it is missing/None or CUDA
                is unavailable the model is placed on the CPU.
        """

        logger.info("Loading diffusion model")
        logger.info("I'm totally new and updated")

        gpu_id = properties.get("gpu_id")
        if torch.cuda.is_available() and gpu_id is not None:
            device_str = "cuda:" + str(gpu_id)
        else:
            device_str = "cpu"
        self.device_str = device_str

        print("my device is " + device_str)
        self.device = torch.device(device_str)

        model_path = "./"
        self.pipe = StableDiffusionXLPipeline.from_pretrained(
            model_path,
            torch_dtype=torch.float16,
            use_safetensors=True,
        )
        # NOTE: no refiner pipeline is loaded; only the base SDXL pipeline is used.

        # Compel turns (weighted) prompt strings into embeddings for both
        # SDXL text encoders; only the second encoder yields pooled output.
        self.compel_base = Compel(
            tokenizer=[self.pipe.tokenizer, self.pipe.tokenizer_2],
            text_encoder=[self.pipe.text_encoder, self.pipe.text_encoder_2],
            returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
            requires_pooled=[False, True])
        logger.info("Compel initialized")

        logger.info("moving base model to device: %s", device_str)
        self.pipe.to(self.device)

        logger.info(self.device)
        # The original call had a "%s" placeholder but no argument.
        logger.info("Diffusion model from path %s loaded successfully", model_path)
        axiom_logger.info("Diffusion model initialized", device=self.device_str)

        self.initialized = True

    def preprocess(self, raw_requests):
        """Split the first raw request into pipeline arguments and extra options.

        Args:
            raw_requests (list): List of request dicts; only element 0 is used.
                ``prompt`` is required, every other key is optional.
        Returns:
            tuple: ``(model_args, extra_args)`` where ``model_args`` holds the
            prompt and the keyword arguments forwarded to the pipeline, and
            ``extra_args`` holds seed, LoRA names/scales and the scene prompt.
        """
        logger.info("Received requests: '%s'", raw_requests)
        self.working = True

        payload = raw_requests[0]

        model_args = {
            "prompt": payload["prompt"],
            "negative_prompt": payload.get("negative_prompt"),
            "width": payload.get("width"),
            "height": payload.get("height"),
            "num_inference_steps": payload.get("num_inference_steps", 30),
            "guidance_scale": payload.get("guidance_scale", 8.5),
        }

        extra_args = {
            "seed": payload.get("seed", None),
            "style_lora": payload.get("style_lora", None),
            "style_scale": payload.get("style_scale", 1.0),
            "char_lora": payload.get("char_lora", None),
            "char_scale": payload.get("char_scale", 1.0),
            "scene_prompt": payload.get("scene_prompt", None),
        }

        logger.info("Processed request: '%s'", model_args)
        axiom_logger.info("Processed request:" + str(model_args), request_id=self.req_id, device=self.device_str)
        return model_args, extra_args

    def _make_generator(self, seed):
        """Return a torch.Generator seeded with *seed* on this handler's device.

        Returns None when no seed was supplied (None or empty string). A seed
        of 0 is a valid seed and produces a generator.
        """
        if seed is None or seed == "":
            return None
        # Use the handler's own device: a generator on the default "cuda"
        # device does not match a pipeline living on another GPU (or the CPU).
        return torch.Generator(device=self.device).manual_seed(seed)

    @staticmethod
    def _lora_path(lora_name):
        """Return the path of ``<lora_name>.safetensors`` inside LORAS_DIR.

        Raises:
            ValueError: If the name would resolve to a file outside LORAS_DIR
                (the name comes from the request body and is untrusted).
        """
        path = os.path.join(LORAS_DIR, lora_name + '.safetensors')
        base = os.path.abspath(LORAS_DIR)
        if os.path.commonpath([base, os.path.abspath(path)]) != base:
            raise ValueError("Invalid LoRA name: " + str(lora_name))
        return path

    def inference(self, request):
        """Generate images for a preprocessed request.

        Args:
            request (tuple): ``(model_args, extra_args)`` as returned by
                ``preprocess``. The passed dicts are not modified.
        Returns:
            list: The images produced by the pipeline.
        Raises:
            ValueError: If a requested LoRA name points outside LORAS_DIR.
        """
        st = SequentialTimer()
        model_args, extra_args = request
        # Work on a copy so popping the prompts does not mutate the caller's dict.
        model_args = dict(model_args)

        style_lora = extra_args['style_lora']
        char_lora = extra_args['char_lora']
        use_style_lora = style_lora is not None
        use_char_lora = char_lora is not None

        # When both LoRAs are used the style LoRA is fused with its own scale,
        # so the runtime scale applies to the character LoRA.
        cross_attention_kwargs = {"scale": extra_args['char_scale'] if use_char_lora else extra_args['style_scale']}

        generator = self._make_generator(extra_args['seed'])

        prompt = model_args.pop("prompt")
        # Compel needs a string; treat a missing negative prompt as empty.
        negative_prompt = model_args.pop('negative_prompt') or ""
        scene_prompt = extra_args['scene_prompt']
        if scene_prompt:
            # Compel conjunction syntax: condition on both prompts.
            prompt = f'("{prompt}", "{scene_prompt}").and()'
        st.time("Base compel embedding")
        conditioning, pooled = self.compel_base(prompt)
        negative_conditioning, negative_pooled = self.compel_base(negative_prompt)

        [conditioning, negative_conditioning] = self.compel_base.pad_conditioning_tensors_to_same_length([conditioning, negative_conditioning])

        lora_loaded = False
        lora_fused = False
        try:
            if use_style_lora:
                style_path = self._lora_path(style_lora)
                st.time("Load style lora")
                self.pipe.load_lora_weights(style_path)
                lora_loaded = True
                if use_char_lora:
                    st.time("Fuse style lora into model")
                    self.pipe.fuse_lora(lora_scale=extra_args['style_scale'], fuse_text_encoder=False)
                    lora_fused = True

            if use_char_lora:
                char_path = self._lora_path(char_lora)
                st.time('load character lora')
                self.pipe.load_lora_weights(char_path)
                lora_loaded = True

            st.time("base model inference")
            inferences = self.pipe(
                prompt_embeds=conditioning,
                pooled_prompt_embeds=pooled,
                negative_prompt_embeds=negative_conditioning,
                negative_pooled_prompt_embeds=negative_pooled,
                generator=generator,
                cross_attention_kwargs=cross_attention_kwargs,
                **model_args
            ).images
        finally:
            # Always restore the base model, even when loading or inference
            # failed, so one bad request cannot leak LoRA weights into the next.
            if lora_fused:
                st.time("unfuse lora weights")
                self.pipe.unfuse_lora(unfuse_text_encoder=False)
            if lora_loaded:
                st.time("unload lora weights")
                self.pipe.unload_lora_weights()

        st.time('end')

        axiom_logger.info("Generated images", request_id=self.req_id, device=self.device_str, timings=st.to_str())
        return inferences

    def postprocess(self, inference_outputs):
        """Upload the generated images to Google Cloud Storage.

        Each image is saved as PNG under a random UUID name in the
        ``outputs-storage-prod`` bucket.

        Args:
            inference_outputs (list): Images returned by ``inference``; each
                must provide ``save(file, format=...)``.
        Returns:
            (list): Public URLs of the uploaded images, in input order.
        """
        # Cleared first so a storage failure does not leave the flag set.
        self.working = False

        bucket_name = "outputs-storage-prod"
        client = storage.Client()
        bucket = client.get_bucket(bucket_name)
        outputs = []
        for image in inference_outputs:
            blob_name = str(uuid.uuid4()) + '.png'
            blob = bucket.blob(blob_name)

            with TemporaryFile() as tmp:
                image.save(tmp, format="png")
                tmp.seek(0)  # rewind so the upload starts at the first byte
                blob.upload_from_file(tmp, content_type='image/png')

            url_name = 'https://storage.googleapis.com/' + bucket_name + '/' + blob_name
            outputs.append(url_name)
            axiom_logger.info("Pushed image to google cloud: "+ url_name, request_id=self.req_id, device=self.device_str)
        return outputs
275
+
276
+
277
app = Flask(__name__)

# Initialize the handler on startup: each worker process owns exactly one
# pipeline, placed on the GPU whose index is given by the WORKER_ID
# environment variable.
gpu_count = torch.cuda.device_count()
if gpu_count == 0:
    raise ValueError("No GPUs available!")

worker_id = os.environ.get('WORKER_ID', 'Unknown')
if worker_id == 'Unknown':
    # Log before raising; the original logged after the raise, which was
    # unreachable.
    logger.critical("cant get worker ID")
    raise ValueError("No worker id")
if not worker_id.isdigit() or int(worker_id) >= gpu_count:
    # Fail early with a clear message instead of an opaque CUDA error when
    # the model is later moved to a non-existent device.
    logger.critical("invalid worker ID %s for %s GPU(s)", worker_id, gpu_count)
    raise ValueError("Invalid worker id: " + worker_id)
logger.info("WORKER ID: %s", worker_id)
handler = DiffusersHandler()
handler.initialize({"gpu_id": worker_id})
291
+
292
+
293
@app.route('/generate', methods=['POST'])
def generate_image():
    """Handle ``POST /generate``: run one text-to-image request.

    The JSON body is passed through the handler's preprocess -> inference ->
    postprocess steps.

    Returns:
        ``{"image_urls": [...]}`` on success, or
        ``({"error": ..., "details": ...}, 500)`` when anything fails.
    """
    req_id = str(uuid.uuid4())
    selected_handler = None
    try:
        # Extract raw requests from HTTP POST body
        raw_requests = request.json
        if not isinstance(raw_requests, dict):
            raise ValueError("Request body must be a JSON object")

        # Drop keys that would collide with the explicit keyword arguments
        # below and make the logging call itself raise a TypeError.
        log_fields = {k: v for k, v in raw_requests.items() if k not in ("message", "request_id")}
        axiom_logger.info(message="Received request", request_id=req_id, **log_fields)

        # Release memory left over from the previous request before starting.
        gc.collect()
        torch.cuda.empty_cache()
        selected_handler = handler
        selected_handler.req_id = req_id

        processed_request = selected_handler.preprocess([raw_requests])
        inferences = selected_handler.inference(processed_request)
        outputs = selected_handler.postprocess(inferences)
        return jsonify({"image_urls": outputs})
    except Exception as e:
        logger.exception("Error during image generation: %s", str(e))
        # selected_handler is still None when the failure happened before a
        # handler was chosen (e.g. an invalid body), so do not dereference it.
        axiom_logger.critical(
            "Error during image generation: " + str(e),
            request_id=req_id,
            device=getattr(selected_handler, "device_str", None),
        )
        return jsonify({"error": "Failed to generate image", "details": str(e)}), 500
    finally:
        # Clear the request id on failure too, so later log lines from the
        # handler are not attributed to this request.
        if selected_handler is not None:
            selected_handler.req_id = None
316
+
317
if __name__ == '__main__':
    # Direct execution: start Flask's built-in server on all interfaces,
    # port 3000, with request threading disabled.
    run_options = {'host': '0.0.0.0', 'port': 3000, 'threaded': False}
    app.run(**run_options)