MarkZakelj commited on
Commit
dec981f
·
1 Parent(s): 341de7e

style lora fusion with character lora

Browse files
Files changed (2) hide show
  1. sequential_timer.py +25 -0
  2. serve_loras.py +123 -26
sequential_timer.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from time import perf_counter
2
+
3
+ class SequentialTimer:
4
+ def __init__(self, make_print=False):
5
+ self.timings = []
6
+ self.make_print = make_print
7
+
8
+ def time(self, message: str):
9
+ if self.make_print:
10
+ print(message)
11
+ self.timings.append((perf_counter(), message))
12
+
13
+ def to_str(self) -> str:
14
+ s = ""
15
+ if len(self.timings) <= 1:
16
+ s = "No timings"
17
+ return s
18
+ t0 = self.timings[0][0]
19
+ for ((t1, m1), (t2, _)) in zip(self.timings, self.timings[1:]):
20
+ s += f"TIME: step: {t2 - t1:06.3f} | cum {t2 - t0:06.3f} - {m1}\n"
21
+ s += f"ALL TIME: {self.timings[-1][0] - self.timings[0][0]:07.3f}\n"
22
+ return s
23
+
24
+ def printall(self):
25
+ print(self.to_str())
serve_loras.py CHANGED
@@ -5,7 +5,7 @@ import uuid
5
 
6
  import diffusers
7
  import torch
8
- from diffusers import StableDiffusionXLPipeline
9
 
10
  import numpy as np
11
  import threading
@@ -14,13 +14,15 @@ import base64
14
  from io import BytesIO
15
  from PIL import Image
16
  import numpy as np
17
- import uuid
18
  from tempfile import TemporaryFile
19
  from google.cloud import storage
20
  import sys
21
  import sentry_sdk
22
  from flask import Flask, request, jsonify
23
  import os
 
 
 
24
 
25
  logger = logging.getLogger(__name__)
26
  logger.info("Diffusers version %s", diffusers.__version__)
@@ -34,6 +36,24 @@ sentry_sdk.init(
34
 
35
  LORAS_DIR = './safetensors'
36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  class DiffusersHandler(ABC):
38
  """
39
  Diffusers handler class for text to image generation.
@@ -65,8 +85,31 @@ class DiffusersHandler(ABC):
65
  torch_dtype=torch.float16,
66
  use_safetensors=True,
67
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
- logger.info("moving model to device: %s", device_str)
70
  self.pipe.to(self.device)
71
 
72
  logger.info(self.device)
@@ -86,20 +129,30 @@ class DiffusersHandler(ABC):
86
  logger.info("Received requests: '%s'", raw_requests)
87
  self.working = True
88
 
89
- processed_request = {
90
  "prompt": raw_requests[0]["prompt"],
91
  "negative_prompt": raw_requests[0].get("negative_prompt"),
92
  "width": raw_requests[0].get("width"),
93
  "height": raw_requests[0].get("height"),
94
- "num_inference_steps": raw_requests[0].get("num_inference_steps", 30),
95
- "guidance_scale": raw_requests[0].get("guidance_scale", 7.5),
96
- "lora_weights": raw_requests[0].get("lora_name", None),
97
- "cross_attention_kwargs": {"scale": raw_requests[0].get("lora_scale", 0.6)}
98
  }
 
 
 
 
 
 
 
 
 
 
99
 
100
- logger.info("Processed request: '%s'", processed_request)
101
- axiom_logger.info("Processed request:" + str(processed_request), request_id=self.req_id, device=self.device_str)
102
- return processed_request
103
 
104
 
105
  def inference(self, request):
@@ -111,29 +164,70 @@ class DiffusersHandler(ABC):
111
  """
112
 
113
  # Handling inference for sequence_classification.
114
- compel = Compel(tokenizer=[self.pipe.tokenizer, self.pipe.tokenizer_2] , text_encoder=[self.pipe.text_encoder, self.pipe.text_encoder_2], returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, requires_pooled=[False, True])
 
 
 
 
 
 
115
 
116
- self.prompt = request.pop("prompt")
117
- conditioning, pooled = compel(self.prompt)
 
 
118
 
119
- lora_weights = request.pop("lora_weights")
120
- if lora_weights is not None:
121
- lora_path = os.path.join(LORAS_DIR, lora_weights + '.safetensors')
122
- logger.info('LOADING LORA FROM: ' + lora_path)
123
- self.pipe.load_lora_weights(lora_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
 
125
  # Handling inference for sequence_classification.
 
126
  inferences = self.pipe(
127
  prompt_embeds=conditioning,
128
  pooled_prompt_embeds=pooled,
129
- **request
 
 
130
  ).images
131
 
132
- if lora_weights is not None:
 
 
 
 
 
 
 
133
  self.pipe.unload_lora_weights()
134
 
135
- logger.info("Generated image: '%s'", inferences)
136
- axiom_logger.info("Generated images", request_id=self.req_id, device=self.device_str)
 
 
137
  return inferences
138
 
139
  def postprocess(self, inference_outputs):
@@ -178,16 +272,19 @@ handlers = [DiffusersHandler() for i in range(gpu_count)]
178
  for i in range(gpu_count):
179
  handlers[i].initialize({"gpu_id": i})
180
 
181
- handler_lock = threading.Lock()
182
- handler_index = 0
 
183
 
184
  @app.route('/generate', methods=['POST'])
185
  def generate_image():
186
  req_id = str(uuid.uuid4())
187
  global handler_index
 
188
  try:
189
  # Extract raw requests from HTTP POST body
190
  raw_requests = request.json
 
191
 
192
  with handler_lock:
193
  selected_handler = handlers[handler_index]
@@ -202,7 +299,7 @@ def generate_image():
202
  return jsonify({"image_urls": outputs})
203
  except Exception as e:
204
  logger.error("Error during image generation: %s", str(e))
205
- axiom_logger.critical("Error during image generation: " + str(e), request_id=req_id)
206
  return jsonify({"error": "Failed to generate image", "details": str(e)}), 500
207
 
208
  if __name__ == '__main__':
 
5
 
6
  import diffusers
7
  import torch
8
+ from diffusers import StableDiffusionXLPipeline, DiffusionPipeline
9
 
10
  import numpy as np
11
  import threading
 
14
  from io import BytesIO
15
  from PIL import Image
16
  import numpy as np
 
17
  from tempfile import TemporaryFile
18
  from google.cloud import storage
19
  import sys
20
  import sentry_sdk
21
  from flask import Flask, request, jsonify
22
  import os
23
+ from sequential_timer import SequentialTimer
24
+ from safetensors.torch import load_file
25
+ import copy
26
 
27
  logger = logging.getLogger(__name__)
28
  logger.info("Diffusers version %s", diffusers.__version__)
 
36
 
37
  LORAS_DIR = './safetensors'
38
 
39
+ handler_lock = threading.Lock()
40
+ handler_index = 0
41
+
42
+ class LoraCache():
43
+ def __init__(self, loras_dir: str = LORAS_DIR):
44
+ self.loras_dir = loras_dir
45
+ self.cache = {}
46
+
47
+ def load_lora(self, lora_name: str):
48
+ if lora_name.endswith('.safetensors'):
49
+ lora_name = lora_name.rstrip('.safetensors')
50
+ if lora_name not in self.cache:
51
+ lora = load_file(os.path.join(self.loras_dir, lora_name+'.safetensors'))
52
+ self.cache[lora_name] = lora
53
+ return copy.deepcopy(self.cache[lora_name])
54
+
55
+ lora_cache = LoraCache()
56
+
57
  class DiffusersHandler(ABC):
58
  """
59
  Diffusers handler class for text to image generation.
 
85
  torch_dtype=torch.float16,
86
  use_safetensors=True,
87
  )
88
+ # self.refiner = DiffusionPipeline.from_pretrained(
89
+ # "stabilityai/stable-diffusion-xl-refiner-1.0",
90
+ # text_encoder_2=self.pipe.text_encoder_2,
91
+ # vae=self.pipe.vae,
92
+ # torch_dtype=torch.float16,
93
+ # use_safetensors=True,
94
+ # variant="fp16",
95
+ # )
96
+ # self.refiner.enable_model_cpu_offload(properties.get("gpu_id"))
97
+ # logger.info("Refiner initialized and o")
98
+
99
+ self.compel_base = Compel(
100
+ tokenizer=[self.pipe.tokenizer, self.pipe.tokenizer_2],
101
+ text_encoder=[self.pipe.text_encoder, self.pipe.text_encoder_2],
102
+ returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
103
+ requires_pooled=[False, True])
104
+ logger.info("Compel initialized")
105
+
106
+ # self.compel_refiner = Compel(
107
+ # tokenizer=[self.refiner.tokenizer_2],
108
+ # text_encoder=[self.refiner.text_encoder_2],
109
+ # returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
110
+ # requires_pooled=[True])
111
 
112
+ logger.info("moving base model to device: %s", device_str)
113
  self.pipe.to(self.device)
114
 
115
  logger.info(self.device)
 
129
  logger.info("Received requests: '%s'", raw_requests)
130
  self.working = True
131
 
132
+ model_args = {
133
  "prompt": raw_requests[0]["prompt"],
134
  "negative_prompt": raw_requests[0].get("negative_prompt"),
135
  "width": raw_requests[0].get("width"),
136
  "height": raw_requests[0].get("height"),
137
+ "num_inference_steps": raw_requests[0].get("num_inference_steps", 25),
138
+ "guidance_scale": raw_requests[0].get("guidance_scale", 8.5)
139
+ # "lora_weights": raw_requests[0].get("lora_name", None)
140
+ # "cross_attention_kwargs": {"scale": raw_requests[0].get("lora_scale", 0.0)}
141
  }
142
+
143
+ extra_args = {
144
+ "seed": raw_requests[0].get("seed", None),
145
+ "style_lora": raw_requests[0].get("style_lora", None),
146
+ "style_scale": raw_requests[0].get("style_scale", 1.0),
147
+ "char_lora": raw_requests[0].get("char_lora", None),
148
+ "char_scale": raw_requests[0].get("char_scale", 1.0)
149
+ }
150
+
151
+
152
 
153
+ logger.info("Processed request: '%s'", model_args)
154
+ axiom_logger.info("Processed request:" + str(model_args), request_id=self.req_id, device=self.device_str)
155
+ return model_args, extra_args
156
 
157
 
158
  def inference(self, request):
 
164
  """
165
 
166
  # Handling inference for sequence_classification.
167
+ # compel = Compel(tokenizer=[self.pipe.tokenizer, self.pipe.tokenizer_2] , text_encoder=[self.pipe.text_encoder, self.pipe.text_encoder_2], returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, requires_pooled=[False, True])
168
+ st = SequentialTimer()
169
+ model_args, extra_args = request
170
+
171
+ use_char_lora = extra_args['char_lora'] is not None
172
+ use_style_lora = extra_args['style_lora'] is not None
173
+
174
 
175
+ style_lora = extra_args['style_lora']
176
+ char_lora = extra_args['char_lora']
177
+
178
+ cross_attention_kwargs = {"scale": extra_args['char_scale'] if use_char_lora else extra_args['style_scale']}
179
 
180
+ generator = torch.Generator(device="cuda").manual_seed(extra_args['seed']) if extra_args['seed'] else None
181
+
182
+
183
+ self.prompt = model_args.pop("prompt")
184
+
185
+ st.time("Base compel embedding")
186
+ conditioning, pooled = self.compel_base(self.prompt)
187
+
188
+ if use_style_lora:
189
+ style_lora = os.path.join(LORAS_DIR, style_lora + '.safetensors')
190
+ st.time("Load style lora")
191
+ self.pipe.load_lora_weights(style_lora)
192
+ if use_char_lora:
193
+ st.time("Fuse style lora into model")
194
+ self.pipe.fuse_lora(lora_scale=extra_args['style_scale'], fuse_text_encoder=False)
195
+
196
+ if use_char_lora:
197
+ char_lora = os.path.join(LORAS_DIR, char_lora + '.safetensors')
198
+ st.time('load character lora')
199
+ self.pipe.load_lora_weights(char_lora)
200
+
201
+ # lora_weights = model_args.pop("lora_weights")
202
+ # if lora_weights is not None:
203
+ # lora_path = os.path.join(LORAS_DIR, lora_weights + '.safetensors')
204
+ # logger.info('LOADING LORA FROM: ' + lora_path)
205
+ # self.pipe.load_lora_weights(lora_path)
206
 
207
  # Handling inference for sequence_classification.
208
+ st.time("base model inference")
209
  inferences = self.pipe(
210
  prompt_embeds=conditioning,
211
  pooled_prompt_embeds=pooled,
212
+ generator=generator,
213
+ cross_attention_kwargs=cross_attention_kwargs,
214
+ **model_args
215
  ).images
216
 
217
+ # if lora_weights is not None:
218
+ # self.pipe.unload_lora_weights()
219
+ if use_style_lora and use_char_lora:
220
+ st.time("unfuse lora weights")
221
+ self.pipe.unfuse_lora(unfuse_text_encoder=False)
222
+
223
+ if use_style_lora or use_char_lora:
224
+ st.time("unload lora weights")
225
  self.pipe.unload_lora_weights()
226
 
227
+ st.time('end')
228
+
229
+ # logger.info("Generated image: '%s'", inferences)
230
+ axiom_logger.info("Generated images", request_id=self.req_id, device=self.device_str, timings=st.to_str())
231
  return inferences
232
 
233
  def postprocess(self, inference_outputs):
 
272
  for i in range(gpu_count):
273
  handlers[i].initialize({"gpu_id": i})
274
 
275
+
276
+
277
+
278
 
279
  @app.route('/generate', methods=['POST'])
280
  def generate_image():
281
  req_id = str(uuid.uuid4())
282
  global handler_index
283
+ selected_handler = None
284
  try:
285
  # Extract raw requests from HTTP POST body
286
  raw_requests = request.json
287
+ axiom_logger.info(message="Received request", request_id=req_id, **raw_requests)
288
 
289
  with handler_lock:
290
  selected_handler = handlers[handler_index]
 
299
  return jsonify({"image_urls": outputs})
300
  except Exception as e:
301
  logger.error("Error during image generation: %s", str(e))
302
+ axiom_logger.critical("Error during image generation: " + str(e), request_id=req_id, device=selected_handler.device_str)
303
  return jsonify({"error": "Failed to generate image", "details": str(e)}), 500
304
 
305
  if __name__ == '__main__':