feishen29 committed on
Commit
d46bdb7
1 Parent(s): a085dac

Upload 31 files

adapter/__pycache__/attention_processor.cpython-39.pyc ADDED
Binary file (14.4 kB).
 
adapter/__pycache__/resampler.cpython-39.pyc ADDED
Binary file (7.37 kB).
 
adapter/attention_processor.py ADDED
@@ -0,0 +1,828 @@
1
+ from typing import Callable, Optional, Union
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import nn
6
+ from diffusers.utils import USE_PEFT_BACKEND
7
+ from diffusers.models.lora import LoRALinearLayer
8
+
9
+
10
+
11
+
12
+
13
+ class CacheAttnProcessor2_0:
14
+ r"""
15
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). It also stores the incoming hidden states in `self.cache` so they can later be read out as reference features.
16
+ """
17
+
18
+ def __init__(self):
19
+ if not hasattr(F, "scaled_dot_product_attention"):
20
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
21
+
22
+ self.cache = {} # cache hidden states
23
+
24
+ def __call__(
25
+ self,
26
+ attn,
27
+ hidden_states: torch.FloatTensor,
28
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
29
+ attention_mask: Optional[torch.FloatTensor] = None,
30
+ temb: Optional[torch.FloatTensor] = None,
31
+ scale: float = 1.0,
32
+ ) -> torch.FloatTensor:
33
+
34
+ self.cache["hidden_states"] = hidden_states # cache hidden states
35
+
36
+ residual = hidden_states
37
+ if attn.spatial_norm is not None:
38
+ hidden_states = attn.spatial_norm(hidden_states, temb)
39
+
40
+ input_ndim = hidden_states.ndim
41
+
42
+ if input_ndim == 4:
43
+ batch_size, channel, height, width = hidden_states.shape
44
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
45
+
46
+ batch_size, sequence_length, _ = (
47
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
48
+ )
49
+
50
+ if attention_mask is not None:
51
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
52
+ # scaled_dot_product_attention expects attention_mask shape to be
53
+ # (batch, heads, source_length, target_length)
54
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
55
+
56
+ if attn.group_norm is not None:
57
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
58
+
59
+ args = () if USE_PEFT_BACKEND else (scale,)
60
+ query = attn.to_q(hidden_states, *args)
61
+
62
+ if encoder_hidden_states is None:
63
+ encoder_hidden_states = hidden_states
64
+ elif attn.norm_cross:
65
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
66
+
67
+ key = attn.to_k(encoder_hidden_states, *args)
68
+ value = attn.to_v(encoder_hidden_states, *args)
69
+
70
+ inner_dim = key.shape[-1]
71
+ head_dim = inner_dim // attn.heads
72
+
73
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
74
+
75
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
76
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
77
+
78
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
79
+ # TODO: add support for attn.scale when we move to Torch 2.1
80
+ hidden_states = F.scaled_dot_product_attention(
81
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
82
+ )
83
+
84
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
85
+ hidden_states = hidden_states.to(query.dtype)
86
+
87
+ # linear proj
88
+ hidden_states = attn.to_out[0](hidden_states, *args)
89
+ # dropout
90
+ hidden_states = attn.to_out[1](hidden_states)
91
+
92
+ if input_ndim == 4:
93
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
94
+
95
+ if attn.residual_connection:
96
+ hidden_states = hidden_states + residual
97
+
98
+ hidden_states = hidden_states / attn.rescale_output_factor
99
+
100
+ return hidden_states
101
+
102
+
103
+ class SAttnProcessor2_0(torch.nn.Module):
104
+ r"""
105
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). When `sa_hidden_states` is provided, the cached reference features for this layer are concatenated to the key/value sequence of the self-attention.
106
+ """
107
+
108
+ def __init__(self, name, hidden_size, cross_attention_dim=None):
109
+ if not hasattr(F, "scaled_dot_product_attention"):
110
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
111
+
112
+ super().__init__()
113
+
114
+ self.name = name
115
+ self.hidden_size = hidden_size
116
+ self.cross_attention_dim = cross_attention_dim
117
+
118
+ def __call__(
119
+ self,
120
+ attn,
121
+ hidden_states: torch.FloatTensor,
122
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
123
+ attention_mask: Optional[torch.FloatTensor] = None,
124
+ temb: Optional[torch.FloatTensor] = None,
125
+ scale: float = 1.0,
126
+ cond_hidden_states=None,
127
+ sa_hidden_states=None,
128
+ ) -> torch.FloatTensor:
129
+ residual = hidden_states
130
+ if attn.spatial_norm is not None:
131
+ hidden_states = attn.spatial_norm(hidden_states, temb)
132
+
133
+ input_ndim = hidden_states.ndim
134
+
135
+ if input_ndim == 4:
136
+ batch_size, channel, height, width = hidden_states.shape
137
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
138
+
139
+ batch_size, sequence_length, _ = (
140
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
141
+ )
142
+
143
+ if attention_mask is not None:
144
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
145
+ # scaled_dot_product_attention expects attention_mask shape to be
146
+ # (batch, heads, source_length, target_length)
147
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
148
+
149
+ if attn.group_norm is not None:
150
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
151
+
152
+ args = () if USE_PEFT_BACKEND else (scale,)
153
+ query = attn.to_q(hidden_states, *args)
154
+
155
+ if encoder_hidden_states is None:
156
+ # for reference adapter
157
+ if sa_hidden_states is not None:
158
+ ref_hidden_states = sa_hidden_states[self.name]
159
+ encoder_hidden_states = torch.cat([hidden_states, ref_hidden_states], dim=1)
160
+ else:
161
+ encoder_hidden_states = hidden_states
162
+
163
+ elif attn.norm_cross:
164
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
165
+
166
+ key = attn.to_k(encoder_hidden_states, *args)
167
+ value = attn.to_v(encoder_hidden_states, *args)
168
+
169
+ inner_dim = key.shape[-1]
170
+ head_dim = inner_dim // attn.heads
171
+
172
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
173
+
174
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
175
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
176
+
177
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
178
+ # TODO: add support for attn.scale when we move to Torch 2.1
179
+ hidden_states = F.scaled_dot_product_attention(
180
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
181
+ )
182
+
183
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
184
+ hidden_states = hidden_states.to(query.dtype)
185
+
186
+ # linear proj
187
+ hidden_states = attn.to_out[0](hidden_states, *args)
188
+ # dropout
189
+ hidden_states = attn.to_out[1](hidden_states)
190
+
191
+ if input_ndim == 4:
192
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
193
+
194
+ if attn.residual_connection:
195
+ hidden_states = hidden_states + residual
196
+
197
+ hidden_states = hidden_states / attn.rescale_output_factor
198
+
199
+ return hidden_states
200
+
201
+
202
+ class CAttnProcessor2_0(torch.nn.Module):
203
+ r"""
204
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
205
+ """
206
+
207
+ def __init__(self, name, hidden_size, cross_attention_dim=None):
208
+ if not hasattr(F, "scaled_dot_product_attention"):
209
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
210
+
211
+ super().__init__()
212
+
213
+ self.name = name
214
+ self.hidden_size = hidden_size
215
+ self.cross_attention_dim = cross_attention_dim
216
+
217
+ # self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
218
+ # self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
219
+
220
+ def __call__(
221
+ self,
222
+ attn,
223
+ hidden_states: torch.FloatTensor,
224
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
225
+ attention_mask: Optional[torch.FloatTensor] = None,
226
+ temb: Optional[torch.FloatTensor] = None,
227
+ scale: float = 1.0,
228
+ cond_hidden_states=None,
229
+ sa_hidden_states=None,
230
+ ) -> torch.FloatTensor:
231
+ residual = hidden_states
232
+ if attn.spatial_norm is not None:
233
+ hidden_states = attn.spatial_norm(hidden_states, temb)
234
+
235
+ input_ndim = hidden_states.ndim
236
+
237
+ if input_ndim == 4:
238
+ batch_size, channel, height, width = hidden_states.shape
239
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
240
+
241
+ batch_size, sequence_length, _ = (
242
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
243
+ )
244
+
245
+ if attention_mask is not None:
246
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
247
+ # scaled_dot_product_attention expects attention_mask shape to be
248
+ # (batch, heads, source_length, target_length)
249
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
250
+
251
+ if attn.group_norm is not None:
252
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
253
+
254
+ args = () if USE_PEFT_BACKEND else (scale,)
255
+ query = attn.to_q(hidden_states, *args)
256
+
257
+ if encoder_hidden_states is None:
258
+ encoder_hidden_states = hidden_states
259
+ elif attn.norm_cross:
260
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
261
+
262
+ key = attn.to_k(encoder_hidden_states, *args)
263
+ value = attn.to_v(encoder_hidden_states, *args)
264
+
265
+ inner_dim = key.shape[-1]
266
+ head_dim = inner_dim // attn.heads
267
+
268
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
269
+
270
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
271
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
272
+
273
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
274
+ # TODO: add support for attn.scale when we move to Torch 2.1
275
+ hidden_states = F.scaled_dot_product_attention(
276
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
277
+ )
278
+
279
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
280
+ hidden_states = hidden_states.to(query.dtype)
281
+
282
+ # for ip
283
+ # if cond_hidden_states:
284
+ # ip_hidden_states = cond_hidden_states
285
+ # ip_key = self.to_k_ip(ip_hidden_states)
286
+ # ip_value = self.to_v_ip(ip_hidden_states)
287
+ # ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
288
+ # ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
289
+ #
290
+ # # the output of sdp = (batch, num_heads, seq_len, head_dim)
291
+ # # TODO: add support for attn.scale when we move to Torch 2.1
292
+ # ip_hidden_states = F.scaled_dot_product_attention(
293
+ # query, ip_key, ip_value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
294
+ # )
295
+ # ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
296
+ # ip_hidden_states = ip_hidden_states.to(query.dtype)
297
+ # hidden_states = hidden_states + ip_hidden_states
298
+
299
+ # linear proj
300
+ hidden_states = attn.to_out[0](hidden_states, *args)
301
+ # dropout
302
+ hidden_states = attn.to_out[1](hidden_states)
303
+
304
+ if input_ndim == 4:
305
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
306
+
307
+ if attn.residual_connection:
308
+ hidden_states = hidden_states + residual
309
+
310
+ hidden_states = hidden_states / attn.rescale_output_factor
311
+
312
+ return hidden_states
313
+
314
+
315
+
316
+
317
+
318
+ class RefLoraSAttnProcessor2_0(torch.nn.Module):
319
+ r"""
320
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). Combines LoRA layers on the query/key/value/output projections with a reference branch (`to_k_ref`/`to_v_ref`) whose attention output is added to the self-attention output, weighted by `scale`.
321
+ """
322
+
323
+ def __init__(self, name, hidden_size, cross_attention_dim=None, scale=1.0, rank=128, network_alpha=None, lora_scale=1.0,):
324
+ if not hasattr(F, "scaled_dot_product_attention"):
325
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
326
+
327
+ super().__init__()
328
+
329
+ self.name = name
330
+ self.hidden_size = hidden_size
331
+ self.cross_attention_dim = cross_attention_dim
332
+ self.to_k_ref = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
333
+ self.to_v_ref = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
334
+ self.scale = scale
335
+
336
+ self.rank = rank
337
+ self.lora_scale = lora_scale
338
+ self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
339
+ self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
340
+ self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
341
+ self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
342
+
343
+ def __call__(
344
+ self,
345
+ attn,
346
+ hidden_states: torch.FloatTensor,
347
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
348
+ attention_mask: Optional[torch.FloatTensor] = None,
349
+ temb: Optional[torch.FloatTensor] = None,
350
+ scale: float = 1.0,
351
+ num_images_per_prompt=1,
352
+ cond_hidden_states=None,
353
+ sa_hidden_states=None,
354
+
355
+ ) -> torch.FloatTensor:
356
+ residual = hidden_states
357
+ if attn.spatial_norm is not None:
358
+ hidden_states = attn.spatial_norm(hidden_states, temb)
359
+
360
+ input_ndim = hidden_states.ndim
361
+
362
+ if input_ndim == 4:
363
+ batch_size, channel, height, width = hidden_states.shape
364
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
365
+
366
+ batch_size, sequence_length, _ = (
367
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
368
+ )
369
+
370
+ if attention_mask is not None:
371
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
372
+ # scaled_dot_product_attention expects attention_mask shape to be
373
+ # (batch, heads, source_length, target_length)
374
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
375
+
376
+ if attn.group_norm is not None:
377
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
378
+
379
+ args = () if USE_PEFT_BACKEND else (scale,)
380
+ query = attn.to_q(hidden_states, *args) + self.lora_scale * self.to_q_lora(hidden_states)
381
+
382
+ if encoder_hidden_states is None:
383
+ encoder_hidden_states = hidden_states
384
+
385
+ elif attn.norm_cross:
386
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
387
+
388
+ key = attn.to_k(encoder_hidden_states, *args) + self.lora_scale * self.to_k_lora(encoder_hidden_states)
389
+ value = attn.to_v(encoder_hidden_states, *args) + self.lora_scale * self.to_v_lora(encoder_hidden_states)
390
+
391
+ inner_dim = key.shape[-1]
392
+ head_dim = inner_dim // attn.heads
393
+
394
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
395
+
396
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
397
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
398
+
399
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
400
+ # TODO: add support for attn.scale when we move to Torch 2.1
401
+ hidden_states = F.scaled_dot_product_attention(
402
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
403
+ )
404
+
405
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
406
+ hidden_states = hidden_states.to(query.dtype)
407
+
408
+ # for ref adapter
409
+ if sa_hidden_states is not None:
410
+ ref_hidden_states = sa_hidden_states[self.name]
411
+ # for ref
412
+ ref_key = self.to_k_ref(ref_hidden_states)
413
+ ref_value = self.to_v_ref(ref_hidden_states)
414
+ ref_key = ref_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
415
+ ref_value = ref_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
416
+
417
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
418
+ # TODO: add support for attn.scale when we move to Torch 2.1
419
+ ref_hidden_states = F.scaled_dot_product_attention(
420
+ query, ref_key, ref_value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
421
+ )
422
+ ref_hidden_states = ref_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
423
+ ref_hidden_states = ref_hidden_states.to(query.dtype)
424
+ hidden_states = hidden_states + ref_hidden_states * self.scale
425
+
426
+ # linear proj
427
+ hidden_states = attn.to_out[0](hidden_states, *args) + self.lora_scale * self.to_out_lora(hidden_states)
428
+ # dropout
429
+ hidden_states = attn.to_out[1](hidden_states)
430
+
431
+ if input_ndim == 4:
432
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
433
+
434
+ if attn.residual_connection:
435
+ hidden_states = hidden_states + residual
436
+
437
+ hidden_states = hidden_states / attn.rescale_output_factor
438
+
439
+ return hidden_states
440
+
441
+ class RefSAttnProcessor2_0(torch.nn.Module):
442
+ r"""
443
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). Adds a reference branch (`to_k_ref`/`to_v_ref`) whose attention output is added to the self-attention output, weighted by `scale`.
444
+ """
445
+
446
+ def __init__(self, name, hidden_size, cross_attention_dim=None, scale=1.0):
447
+ if not hasattr(F, "scaled_dot_product_attention"):
448
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
449
+
450
+ super().__init__()
451
+
452
+ self.name = name
453
+ self.hidden_size = hidden_size
454
+ self.cross_attention_dim = cross_attention_dim
455
+ self.to_k_ref = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
456
+ self.to_v_ref = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
457
+ self.scale = scale
458
+
459
+ def __call__(
460
+ self,
461
+ attn,
462
+ hidden_states: torch.FloatTensor,
463
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
464
+ attention_mask: Optional[torch.FloatTensor] = None,
465
+ temb: Optional[torch.FloatTensor] = None,
466
+ scale: float = 1.0,
467
+ num_images_per_prompt=1,
468
+ cond_hidden_states=None,
469
+ sa_hidden_states=None,
470
+
471
+ ) -> torch.FloatTensor:
472
+ residual = hidden_states
473
+ if attn.spatial_norm is not None:
474
+ hidden_states = attn.spatial_norm(hidden_states, temb)
475
+
476
+ input_ndim = hidden_states.ndim
477
+
478
+ if input_ndim == 4:
479
+ batch_size, channel, height, width = hidden_states.shape
480
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
481
+
482
+ batch_size, sequence_length, _ = (
483
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
484
+ )
485
+
486
+ if attention_mask is not None:
487
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
488
+ # scaled_dot_product_attention expects attention_mask shape to be
489
+ # (batch, heads, source_length, target_length)
490
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
491
+
492
+ if attn.group_norm is not None:
493
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
494
+
495
+ args = () if USE_PEFT_BACKEND else (scale,)
496
+ query = attn.to_q(hidden_states, *args)
497
+
498
+ if encoder_hidden_states is None:
499
+ encoder_hidden_states = hidden_states
500
+
501
+ elif attn.norm_cross:
502
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
503
+
504
+ key = attn.to_k(encoder_hidden_states, *args)
505
+ value = attn.to_v(encoder_hidden_states, *args)
506
+
507
+ inner_dim = key.shape[-1]
508
+ head_dim = inner_dim // attn.heads
509
+
510
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
511
+
512
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
513
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
514
+
515
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
516
+ # TODO: add support for attn.scale when we move to Torch 2.1
517
+ hidden_states = F.scaled_dot_product_attention(
518
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
519
+ )
520
+
521
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
522
+ hidden_states = hidden_states.to(query.dtype)
523
+
524
+ # for ref adapter
525
+ if sa_hidden_states is not None:
526
+ ref_hidden_states = sa_hidden_states[self.name]
527
+ # for ref
528
+ ref_key = self.to_k_ref(ref_hidden_states)
529
+ ref_value = self.to_v_ref(ref_hidden_states)
530
+ ref_key = ref_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
531
+ ref_value = ref_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
532
+
533
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
534
+ # TODO: add support for attn.scale when we move to Torch 2.1
535
+ ref_hidden_states = F.scaled_dot_product_attention(
536
+ query, ref_key, ref_value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
537
+ )
538
+ ref_hidden_states = ref_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
539
+ ref_hidden_states = ref_hidden_states.to(query.dtype)
540
+ hidden_states = hidden_states + ref_hidden_states * self.scale
541
+
542
+ # linear proj
543
+ hidden_states = attn.to_out[0](hidden_states, *args)
544
+ # dropout
545
+ hidden_states = attn.to_out[1](hidden_states)
546
+
547
+ if input_ndim == 4:
548
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
549
+
550
+ if attn.residual_connection:
551
+ hidden_states = hidden_states + residual
552
+
553
+ hidden_states = hidden_states / attn.rescale_output_factor
554
+
555
+ return hidden_states
556
+
557
+
558
+
559
+ class IPAttnProcessor2_0(torch.nn.Module):
560
+ r"""
561
+ Attention processor for IP-Adapter for PyTorch 2.0.
562
+ Args:
563
+ hidden_size (`int`):
564
+ The hidden size of the attention layer.
565
+ cross_attention_dim (`int`):
566
+ The number of channels in the `encoder_hidden_states`.
567
+ scale (`float`, defaults to 1.0):
568
+ the weight scale of image prompt.
569
+ num_tokens (`int`, defaults to 4; should be 16 when using ip_adapter_plus):
570
+ The context length of the image features.
571
+ """
572
+
573
+ def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
574
+ super().__init__()
575
+
576
+ if not hasattr(F, "scaled_dot_product_attention"):
577
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
578
+
579
+ self.hidden_size = hidden_size
580
+ self.cross_attention_dim = cross_attention_dim
581
+ self.scale = scale
582
+ self.num_tokens = num_tokens
583
+
584
+ self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
585
+ self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
586
+
587
+ def __call__(
588
+ self,
589
+ attn,
590
+ hidden_states,
591
+ encoder_hidden_states=None,
592
+ attention_mask=None,
593
+ temb=None,
594
+ sa_hidden_states=None,
595
+ scale: float = 1.0,
596
+ ):
597
+ # `attn` is the original attention module
598
+ residual = hidden_states
599
+
600
+ if attn.spatial_norm is not None:
601
+ hidden_states = attn.spatial_norm(hidden_states, temb)
602
+
603
+ input_ndim = hidden_states.ndim
604
+
605
+ if input_ndim == 4:
606
+ batch_size, channel, height, width = hidden_states.shape
607
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
608
+
609
+ batch_size, sequence_length, _ = (
610
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
611
+ )
612
+
613
+ if attention_mask is not None:
614
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
615
+ # scaled_dot_product_attention expects attention_mask shape to be
616
+ # (batch, heads, source_length, target_length)
617
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
618
+
619
+ if attn.group_norm is not None:
620
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
621
+
622
+ query = attn.to_q(hidden_states)
623
+
624
+ ip_hidden_states = None  # default when no image-prompt tokens are passed, so the later `is not None` check is always defined
+ if encoder_hidden_states is None:
625
+ if sa_hidden_states is not None:
626
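+ # NOTE: `self.name` is never assigned in IPAttnProcessor2_0.__init__; this branch assumes the attribute is attached to the processor externally before inference.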
+ ref_hidden_states = sa_hidden_states[self.name]
627
+ # print(ref_hidden_states.shape, hidden_states.shape)
628
+ encoder_hidden_states = torch.cat([hidden_states, ref_hidden_states], dim=1)
629
+ else:
630
+ encoder_hidden_states = hidden_states
631
+ else:
632
+ # get encoder_hidden_states, ip_hidden_states
633
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens
634
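+ # the image-prompt tokens are only split off when the remaining (text) sequence length is exactly 89, a hard-coded value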
+ if end_pos != 89:
635
+ encoder_hidden_states = encoder_hidden_states
636
+ ip_hidden_states = None
637
+ else:
638
+ encoder_hidden_states, ip_hidden_states = (
639
+ encoder_hidden_states[:, :end_pos, :],
640
+ encoder_hidden_states[:, end_pos:, :],
641
+ )
642
+ if attn.norm_cross:
643
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
644
+
645
+ key = attn.to_k(encoder_hidden_states)
646
+ value = attn.to_v(encoder_hidden_states)
647
+
648
+ inner_dim = key.shape[-1]
649
+ head_dim = inner_dim // attn.heads
650
+
651
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
652
+
653
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
654
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
655
+
656
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
657
+ # TODO: add support for attn.scale when we move to Torch 2.1
658
+ hidden_states = F.scaled_dot_product_attention(
659
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
660
+ )
661
+
662
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
663
+ hidden_states = hidden_states.to(query.dtype)
664
+
665
+ # only run the IP-Adapter branch when image-prompt tokens are present (inference stage)
666
+ if ip_hidden_states is not None:
667
+ # for ip-adapter
668
+ ip_key = self.to_k_ip(ip_hidden_states)
669
+ ip_value = self.to_v_ip(ip_hidden_states)
670
+
671
+ ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
672
+ ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
673
+
674
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
675
+ # TODO: add support for attn.scale when we move to Torch 2.1
676
+ ip_hidden_states = F.scaled_dot_product_attention(
677
+ query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
678
+ )
679
+ with torch.no_grad():
680
+ self.attn_map = query @ ip_key.transpose(-2, -1).softmax(dim=-1)
681
+ # print(self.attn_map.shape)
682
+
683
+ ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
684
+ ip_hidden_states = ip_hidden_states.to(query.dtype)
685
+
686
+ hidden_states = hidden_states + self.scale * ip_hidden_states
687
+
688
+ # linear proj
689
+ hidden_states = attn.to_out[0](hidden_states)
690
+ # dropout
691
+ hidden_states = attn.to_out[1](hidden_states)
692
+
693
+ if input_ndim == 4:
694
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
695
+
696
+ if attn.residual_connection:
697
+ hidden_states = hidden_states + residual
698
+
699
+ hidden_states = hidden_states / attn.rescale_output_factor
700
+
701
+ return hidden_states
702
+
703
+ class LoRAIPAttnProcessor2_0(nn.Module):
704
+ r"""
705
+ Processor for implementing the LoRA attention mechanism.
706
+
707
+ Args:
708
+ hidden_size (`int`, *optional*):
709
+ The hidden size of the attention layer.
710
+ cross_attention_dim (`int`, *optional*):
711
+ The number of channels in the `encoder_hidden_states`.
712
+ rank (`int`, defaults to 4):
713
+ The dimension of the LoRA update matrices.
714
+ network_alpha (`int`, *optional*):
715
+ Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.
716
+ """
717
+
718
+ def __init__(self, hidden_size, cross_attention_dim=None, rank=128, network_alpha=None, lora_scale=1.0, scale=1.0,
719
+ num_tokens=4):
720
+ super().__init__()
721
+
722
+ self.rank = rank
723
+ self.lora_scale = lora_scale
724
+ self.num_tokens = num_tokens
725
+
726
+ self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
727
+ self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
728
+ self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
729
+ self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
730
+
731
+ self.hidden_size = hidden_size
732
+ self.cross_attention_dim = cross_attention_dim
733
+ self.scale = scale
734
+
735
+ self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
736
+ self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
737
+
738
+ def __call__(
739
+ self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None, *args,
740
+ **kwargs,
741
+ ):
742
+ residual = hidden_states
743
+
744
+ if attn.spatial_norm is not None:
745
+ hidden_states = attn.spatial_norm(hidden_states, temb)
746
+
747
+ input_ndim = hidden_states.ndim
748
+
749
+ if input_ndim == 4:
750
+ batch_size, channel, height, width = hidden_states.shape
751
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
752
+
753
+ batch_size, sequence_length, _ = (
754
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
755
+ )
756
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
757
+
758
+ if attn.group_norm is not None:
759
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
760
+
761
+ query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states)
762
+ # query = attn.head_to_batch_dim(query)
763
+
764
+ if encoder_hidden_states is None:
765
+ encoder_hidden_states = hidden_states
766
+ else:
767
+ # get encoder_hidden_states, ip_hidden_states
768
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens
769
+ encoder_hidden_states, ip_hidden_states = (
770
+ encoder_hidden_states[:, :end_pos, :],
771
+ encoder_hidden_states[:, end_pos:, :],
772
+ )
773
+ if attn.norm_cross:
774
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
775
+
776
+ # for text
777
+ key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states)
778
+ value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states)
779
+
780
+ inner_dim = key.shape[-1]
781
+ head_dim = inner_dim // attn.heads
782
+
783
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
784
+
785
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
786
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
787
+
788
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
789
+ # TODO: add support for attn.scale when we move to Torch 2.1
790
+ hidden_states = F.scaled_dot_product_attention(
791
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
792
+ )
793
+
794
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
795
+ hidden_states = hidden_states.to(query.dtype)
796
+
797
+ # for ip
798
+ ip_key = self.to_k_ip(ip_hidden_states)
799
+ ip_value = self.to_v_ip(ip_hidden_states)
800
+
801
+ ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
802
+ ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
803
+
804
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
805
+ # TODO: add support for attn.scale when we move to Torch 2.1
806
+ ip_hidden_states = F.scaled_dot_product_attention(
807
+ query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
808
+ )
809
+
810
+ ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
811
+ ip_hidden_states = ip_hidden_states.to(query.dtype)
812
+
813
+ hidden_states = hidden_states + self.scale * ip_hidden_states
814
+
815
+ # linear proj
816
+ hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states)
817
+ # dropout
818
+ hidden_states = attn.to_out[1](hidden_states)
819
+
820
+ if input_ndim == 4:
821
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
822
+
823
+ if attn.residual_connection:
824
+ hidden_states = hidden_states + residual
825
+
826
+ hidden_states = hidden_states / attn.rescale_output_factor
827
+
828
+ return hidden_states
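For orientation only (not part of the uploaded file): a minimal sketch of how processors like these are typically wired into a diffusers `UNet2DConditionModel`, assuming the standard `attn_processors` / `set_attn_processor` API and the block-naming convention used by IP-Adapter-style code. The layer-name matching, the hidden-size derivation, and the choice of which processor goes on which UNet are assumptions, not taken from this repository.

import torch
from adapter.attention_processor import CacheAttnProcessor2_0, RefSAttnProcessor2_0

def install_reference_processors(reference_unet, unet, scale=1.0):
    # Reference UNet: cache the hidden states entering every attention layer.
    reference_unet.set_attn_processor(
        {name: CacheAttnProcessor2_0() for name in reference_unet.attn_processors.keys()}
    )
    # Main UNet: give self-attention ("attn1") layers a reference branch; leave the rest unchanged.
    procs = {}
    for name, proc in unet.attn_processors.items():
        if name.endswith("attn1.processor"):
            # Per-block hidden size, following the usual diffusers layer naming (assumption).
            if name.startswith("mid_block"):
                hidden_size = unet.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
            else:  # down_blocks
                block_id = int(name[len("down_blocks.")])
                hidden_size = unet.config.block_out_channels[block_id]
            procs[name] = RefSAttnProcessor2_0(name, hidden_size, scale=scale).to(
                unet.device, dtype=unet.dtype
            )
        else:
            procs[name] = proc
    unet.set_attn_processor(procs)

After running `reference_unet` on the reference latents, the cached features could be gathered as `{name: p.cache["hidden_states"] for name, p in reference_unet.attn_processors.items()}` and passed to the main UNet as the `sa_hidden_states` dictionary that `RefSAttnProcessor2_0.__call__` indexes by `self.name`; how the actual pipeline forwards that dictionary is likewise an assumption here.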
adapter/resampler.py ADDED
@@ -0,0 +1,302 @@
1
+ # -*- coding: utf-8 -*-
2
+ # @Time : 2024/5/13
3
+ # @Author : White Jiang
4
+ import math
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from einops import rearrange
9
+ from einops.layers.torch import Rearrange
10
+
11
+
12
+ # FFN
13
+ def FeedForward(dim, mult=4):
14
+ inner_dim = int(dim * mult)
15
+ return nn.Sequential(
16
+ nn.LayerNorm(dim),
17
+ nn.Linear(dim, inner_dim, bias=False),
18
+ nn.GELU(),
19
+ nn.Linear(inner_dim, dim, bias=False),
20
+ )
21
+
22
+
23
+ def reshape_tensor(x, heads):
24
+ bs, length, width = x.shape
25
+ # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
26
+ x = x.view(bs, length, heads, -1)
27
+ # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
28
+ x = x.transpose(1, 2)
29
+ # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
30
+ x = x.reshape(bs, heads, length, -1)
31
+ return x
32
+
33
+
34
+ class PerceiverAttention(nn.Module):
35
+ def __init__(self, *, dim, dim_head=64, heads=8):
36
+ super().__init__()
37
+ self.scale = dim_head ** -0.5
38
+ self.dim_head = dim_head
39
+ self.heads = heads
40
+ inner_dim = dim_head * heads
41
+
42
+ self.norm1 = nn.LayerNorm(dim)
43
+ self.norm2 = nn.LayerNorm(dim)
44
+
45
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
46
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
47
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
48
+
49
+ def forward(self, x, latents):
50
+ """
51
+ Args:
52
+ x (torch.Tensor): image features
53
+ shape (b, n1, D)
54
+ latent (torch.Tensor): latent features
55
+ shape (b, n2, D)
56
+ """
57
+ x = self.norm1(x)
58
+ latents = self.norm2(latents)
59
+
60
+ b, l, _ = latents.shape
61
+
62
+ q = self.to_q(latents)
63
+ kv_input = torch.cat((x, latents), dim=-2)
64
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
65
+
66
+ q = reshape_tensor(q, self.heads) # [b, h, n, c]
67
+ k = reshape_tensor(k, self.heads)
68
+ v = reshape_tensor(v, self.heads)
69
+
70
+ # attention
71
+ scale = 1 / math.sqrt(math.sqrt(self.dim_head))
72
+ weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
73
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
74
+ out = weight @ v # [b, h, n, n] @ [b, h, n, c] = [b, h, n, c]
75
+
76
+ out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
77
+
78
+ return self.to_out(out)
79
+
80
+
81
+ class PerceiverResampler(nn.Module):
82
+ def __init__(
83
+ self,
84
+ *,
85
+ dim=1024,
86
+ depth=8,
87
+ dim_head=64,
88
+ heads=16,
89
+ num_latents=8,
90
+ embedding_dim=768,
91
+ output_dim=1024,
92
+ ff_mult=4,
93
+ ):
94
+ super().__init__()
95
+
96
+ self.latents = nn.Parameter(torch.randn(1, num_latents, dim) / dim ** 0.5)
97
+
98
+ self.proj_in = nn.Linear(embedding_dim, dim)
99
+
100
+ self.proj_out = nn.Linear(dim, output_dim)
101
+ self.norm_out = nn.LayerNorm(output_dim)
102
+
103
+ self.layers = nn.ModuleList([])
104
+ for _ in range(depth):
105
+ self.layers.append(
106
+ nn.ModuleList(
107
+ [
108
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
109
+ FeedForward(dim=dim, mult=ff_mult),
110
+ ]
111
+ )
112
+ )
113
+
114
+ def forward(self, x):
115
+
116
+ latents = self.latents.repeat(x.size(0), 1, 1)
117
+
118
+ x = self.proj_in(x)
119
+
120
+ for attn, ff in self.layers:
121
+ latents = attn(x, latents) + latents
122
+ latents = ff(latents) + latents
123
+
124
+ latents = self.proj_out(latents)
125
+ return self.norm_out(latents)
126
+
127
+
128
+ class FacePerceiverResampler(nn.Module):
129
+ def __init__(
130
+ self,
131
+ *,
132
+ dim=768,
133
+ depth=4,
134
+ dim_head=64,
135
+ heads=16,
136
+ embedding_dim=1280,
137
+ output_dim=768,
138
+ ff_mult=4,
139
+ ):
140
+ super().__init__()
141
+
142
+ self.proj_in = nn.Linear(embedding_dim, dim)
143
+
144
+ self.proj_out = nn.Linear(dim, output_dim)
145
+ self.norm_out = nn.LayerNorm(output_dim)
146
+
147
+ self.layers = nn.ModuleList([])
148
+ for _ in range(depth):
149
+ self.layers.append(
150
+ nn.ModuleList(
151
+ [
152
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
153
+ FeedForward(dim=dim, mult=ff_mult),
154
+ ]
155
+ )
156
+ )
157
+
158
+ def forward(self, latents, x):
159
+
160
+ x = self.proj_in(x)
161
+
162
+ for attn, ff in self.layers:
163
+ latents = attn(x, latents) + latents
164
+ latents = ff(latents) + latents
165
+
166
+ latents = self.proj_out(latents)
167
+ return self.norm_out(latents)
168
+
169
+
170
+ class Resampler(nn.Module):
171
+ def __init__(
172
+ self,
173
+ dim=1024,
174
+ depth=8,
175
+ dim_head=64,
176
+ heads=16,
177
+ num_queries=8,
178
+ embedding_dim=768,
179
+ output_dim=1024,
180
+ ff_mult=4,
181
+ max_seq_len: int = 257, # CLIP tokens + CLS token
182
+ apply_pos_emb: bool = False,
183
+ num_latents_mean_pooled: int = 0, # number of latents derived from mean pooled representation of the sequence
184
+ ):
185
+ super().__init__()
186
+ self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None
187
+
188
+ self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
189
+
190
+ self.proj_in = nn.Linear(embedding_dim, dim)
191
+
192
+ self.proj_out = nn.Linear(dim, output_dim)
193
+ self.norm_out = nn.LayerNorm(output_dim)
194
+
195
+ self.to_latents_from_mean_pooled_seq = (
196
+ nn.Sequential(
197
+ nn.LayerNorm(dim),
198
+ nn.Linear(dim, dim * num_latents_mean_pooled),
199
+ Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled),
200
+ )
201
+ if num_latents_mean_pooled > 0
202
+ else None
203
+ )
204
+
205
+ self.layers = nn.ModuleList([])
206
+ for _ in range(depth):
207
+ self.layers.append(
208
+ nn.ModuleList(
209
+ [
210
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
211
+ FeedForward(dim=dim, mult=ff_mult),
212
+ ]
213
+ )
214
+ )
215
+
216
+ def forward(self, x):
217
+ if self.pos_emb is not None:
218
+ n, device = x.shape[1], x.device
219
+ pos_emb = self.pos_emb(torch.arange(n, device=device))
220
+ x = x + pos_emb
221
+
222
+ latents = self.latents.repeat(x.size(0), 1, 1)
223
+
224
+ x = self.proj_in(x)
225
+
226
+ if self.to_latents_from_mean_pooled_seq:
227
+ meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool))
228
+ meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
229
+ latents = torch.cat((meanpooled_latents, latents), dim=-2)
230
+
231
+ for attn, ff in self.layers:
232
+ latents = attn(x, latents) + latents
233
+ latents = ff(latents) + latents
234
+
235
+ latents = self.proj_out(latents)
236
+ return self.norm_out(latents)
237
+
238
+
239
+ def masked_mean(t, *, dim, mask=None):
240
+ if mask is None:
241
+ return t.mean(dim=dim)
242
+
243
+ denom = mask.sum(dim=dim, keepdim=True)
244
+ mask = rearrange(mask, "b n -> b n 1")
245
+ masked_t = t.masked_fill(~mask, 0.0)
246
+
247
+ return masked_t.sum(dim=dim) / denom.clamp(min=1e-5)
248
+
249
+
250
+ class ProjPlusModel(torch.nn.Module):
251
+ def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, clip_embeddings_dim=1280, num_tokens=4):
252
+ super().__init__()
253
+
254
+ self.cross_attention_dim = cross_attention_dim
255
+ self.num_tokens = num_tokens
256
+
257
+ self.proj = torch.nn.Sequential(
258
+ torch.nn.Linear(id_embeddings_dim, id_embeddings_dim * 2),
259
+ torch.nn.GELU(),
260
+ torch.nn.Linear(id_embeddings_dim * 2, cross_attention_dim * num_tokens),
261
+ )
262
+ self.norm = torch.nn.LayerNorm(cross_attention_dim)
263
+
264
+ self.perceiver_resampler = FacePerceiverResampler(
265
+ dim=cross_attention_dim,
266
+ depth=4,
267
+ dim_head=64,
268
+ heads=cross_attention_dim // 64,
269
+ embedding_dim=clip_embeddings_dim,
270
+ output_dim=cross_attention_dim,
271
+ ff_mult=4,
272
+ )
273
+
274
+ def forward(self, id_embeds, clip_embeds, shortcut=False, scale=1.0):
275
+ x = self.proj(id_embeds)
276
+ x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
277
+ x = self.norm(x)
278
+ out = self.perceiver_resampler(x, clip_embeds)
279
+ if shortcut:
280
+ out = x + scale * out
281
+ return out
282
+
283
+
284
+ if __name__ == "__main__":
285
+ model = PerceiverResampler(
286
+ dim=1024,
287
+ depth=8,
288
+ dim_head=64,
289
+ heads=16,
290
+ num_latents=8,
291
+ embedding_dim=4096,
292
+ output_dim=1024,
293
+ ff_mult=4,
294
+ )
295
+
296
+ x = torch.rand(2, 77, 4096)
297
+
298
+ with torch.no_grad():
299
+ out = model(x)
300
+ print(out.shape)
301
+
302
+ print(sum([p.numel() for p in model.parameters()]) / 1e6)
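For orientation only (not part of the uploaded file): with the defaults above, `ProjPlusModel` maps a (B, 512) ID embedding to (B, num_tokens, cross_attention_dim) prompt tokens, which `FacePerceiverResampler` then refines against the CLIP image hidden states. A minimal sketch with dummy inputs; the 257-token CLIP sequence length is an assumption:

import torch
from adapter.resampler import ProjPlusModel

proj = ProjPlusModel(cross_attention_dim=768, id_embeddings_dim=512, clip_embeddings_dim=1280, num_tokens=4)

id_embeds = torch.randn(2, 512)          # e.g. face-ID embeddings
clip_embeds = torch.randn(2, 257, 1280)  # e.g. penultimate hidden states of a CLIP vision tower

with torch.no_grad():
    tokens = proj(id_embeds, clip_embeds, shortcut=True, scale=1.0)

print(tokens.shape)  # torch.Size([2, 4, 768])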
dressing_sd/pipelines/__pycache__/pipeline_sd.cpython-39.pyc ADDED
Binary file (15.4 kB).
 
dressing_sd/pipelines/pipeline_sd.py ADDED
@@ -0,0 +1,748 @@
1
+ # -*- coding: utf-8 -*-
2
+ # @Time : 2024/5/31
3
+ # @Author : White Jiang
4
+ from diffusers.schedulers import (
5
+ DDIMScheduler,
6
+ DPMSolverMultistepScheduler,
7
+ EulerAncestralDiscreteScheduler,
8
+ EulerDiscreteScheduler,
9
+ LMSDiscreteScheduler,
10
+ PNDMScheduler,
11
+ )
12
+ from diffusers.utils import is_accelerate_available
13
+ from diffusers.pipelines.controlnet.pipeline_controlnet import *
14
+
15
+ import os
16
+ import sys
17
+ from safetensors import safe_open
18
+ BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
19
+ sys.path.append(BASE_DIR)
20
+
21
+ from adapter.resampler import ProjPlusModel
22
+ from adapter.attention_processor import RefSAttnProcessor2_0, RefLoraSAttnProcessor2_0, IPAttnProcessor2_0, LoRAIPAttnProcessor2_0
23
+
24
+
25
+ class PipIpaControlNet(StableDiffusionControlNetPipeline):
26
+ _optional_components = []
27
+
28
+ def __init__(
29
+ self,
30
+ vae,
31
+ reference_unet,
32
+ unet,
33
+ tokenizer,
34
+ text_encoder,
35
+ controlnet,
36
+ image_encoder,
37
+ ImgProj,
38
+ ip_ckpt,
39
+ scheduler: Union[
40
+ DDIMScheduler,
41
+ PNDMScheduler,
42
+ LMSDiscreteScheduler,
43
+ EulerDiscreteScheduler,
44
+ EulerAncestralDiscreteScheduler,
45
+ DPMSolverMultistepScheduler,
46
+ ],
47
+ safety_checker: StableDiffusionSafetyChecker,
48
+ feature_extractor: CLIPImageProcessor,
49
+ ):
50
+ super().__init__(vae, text_encoder, tokenizer, unet, controlnet, scheduler, safety_checker, feature_extractor)
51
+
52
+ self.register_modules(
53
+ vae=vae,
54
+ reference_unet=reference_unet,
55
+ unet=unet,
56
+ controlnet=controlnet,
57
+ scheduler=scheduler,
58
+ tokenizer=tokenizer,
59
+ text_encoder=text_encoder,
60
+ image_encoder=image_encoder,
61
+ ImgProj=ImgProj,
62
+ safety_checker=safety_checker,
63
+ feature_extractor=feature_extractor
64
+ )
65
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
66
+ self.clip_image_processor = CLIPImageProcessor()
67
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
68
+ self.ref_image_processor = VaeImageProcessor(
69
+ vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False,
70
+ )
71
+ self.cond_image_processor = VaeImageProcessor(
72
+ vae_scale_factor=self.vae_scale_factor,
73
+ do_convert_rgb=True,
74
+ do_normalize=False,
75
+ )
76
+ self.ip_ckpt = ip_ckpt
77
+ self.num_tokens = 4
78
+ # image proj model
79
+ self.image_proj_model = self.init_proj()
80
+ self.load_ip_adapter()
81
+
82
+ def init_proj(self):
83
+ image_proj_model = ProjPlusModel(
84
+ cross_attention_dim=self.unet.config.cross_attention_dim,
85
+ id_embeddings_dim=512,
86
+ clip_embeddings_dim=self.image_encoder.config.hidden_size,
87
+ num_tokens=self.num_tokens,
88
+ ).to(self.unet.device, dtype=torch.float16)
89
+ return image_proj_model
90
+
91
+ def load_ip_adapter(self):
92
+ if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors":
93
+ state_dict = {"image_proj": {}, "ip_adapter": {}}
94
+ with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f:
95
+ for key in f.keys():
96
+ if key.startswith("image_proj."):
97
+ state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
98
+ elif key.startswith("ip_adapter."):
99
+ state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
100
+ else:
101
+ state_dict = torch.load(self.ip_ckpt, map_location="cpu")
102
+ self.image_proj_model.load_state_dict(state_dict["image_proj"])
103
+ ip_layers = torch.nn.ModuleList(self.unet.attn_processors.values())
104
+ ip_layers.load_state_dict(state_dict["ip_adapter"], strict=False)
105
+
106
+ @property
107
+ def cross_attention_kwargs(self):
108
+ return self._cross_attention_kwargs
109
+
110
+ def enable_vae_slicing(self):
111
+ self.vae.enable_slicing()
112
+
113
+ def disable_vae_slicing(self):
114
+ self.vae.disable_slicing()
115
+
116
+ def enable_sequential_cpu_offload(self, gpu_id=0):
117
+ if is_accelerate_available():
118
+ from accelerate import cpu_offload
119
+ else:
120
+ raise ImportError("Please install accelerate via `pip install accelerate`")
121
+
122
+ device = torch.device(f"cuda:{gpu_id}")
123
+
124
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
125
+ if cpu_offloaded_model is not None:
126
+ cpu_offload(cpu_offloaded_model, device)
127
+
128
+ @property
129
+ def _execution_device(self):
130
+ if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
131
+ return self.device
132
+ for module in self.unet.modules():
133
+ if (
134
+ hasattr(module, "_hf_hook")
135
+ and hasattr(module._hf_hook, "execution_device")
136
+ and module._hf_hook.execution_device is not None
137
+ ):
138
+ return torch.device(module._hf_hook.execution_device)
139
+ return self.device
140
+
141
+ def prepare_extra_step_kwargs(self, generator, eta):
142
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
143
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
144
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
145
+ # and should be between [0, 1]
146
+
147
+ accepts_eta = "eta" in set(
148
+ inspect.signature(self.scheduler.step).parameters.keys()
149
+ )
150
+ extra_step_kwargs = {}
151
+ if accepts_eta:
152
+ extra_step_kwargs["eta"] = eta
153
+
154
+ # check if the scheduler accepts generator
155
+ accepts_generator = "generator" in set(
156
+ inspect.signature(self.scheduler.step).parameters.keys()
157
+ )
158
+ if accepts_generator:
159
+ extra_step_kwargs["generator"] = generator
160
+ return extra_step_kwargs
161
+
162
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
163
+
164
+ def encode_prompt(
165
+ self,
166
+ prompt,
167
+ device,
168
+ num_images_per_prompt,
169
+ do_classifier_free_guidance,
170
+ negative_prompt=None,
171
+ prompt_embeds: Optional[torch.FloatTensor] = None,
172
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
173
+ lora_scale: Optional[float] = None,
174
+ clip_skip: Optional[int] = None,
175
+ ):
176
+
177
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
178
+ self._lora_scale = lora_scale
179
+
180
+ # dynamically adjust the LoRA scale
181
+ if not USE_PEFT_BACKEND:
182
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
183
+ else:
184
+ scale_lora_layers(self.text_encoder, lora_scale)
185
+
186
+ if prompt is not None and isinstance(prompt, str):
187
+ batch_size = 1
188
+ elif prompt is not None and isinstance(prompt, list):
189
+ batch_size = len(prompt)
190
+ else:
191
+ batch_size = prompt_embeds.shape[0]
192
+
193
+ if prompt_embeds is None:
194
+ # textual inversion: procecss multi-vector tokens if necessary
195
+ if isinstance(self, TextualInversionLoaderMixin):
196
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
197
+
198
+ text_inputs = self.tokenizer(
199
+ prompt,
200
+ padding="max_length",
201
+ max_length=self.tokenizer.model_max_length,
202
+ truncation=True,
203
+ return_tensors="pt",
204
+ )
205
+ text_input_ids = text_inputs.input_ids
206
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
207
+
208
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
209
+ text_input_ids, untruncated_ids
210
+ ):
211
+ removed_text = self.tokenizer.batch_decode(
212
+ untruncated_ids[:, self.tokenizer.model_max_length - 1: -1]
213
+ )
214
+ logger.warning(
215
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
216
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
217
+ )
218
+
219
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
220
+ attention_mask = text_inputs.attention_mask.to(device)
221
+ else:
222
+ attention_mask = None
223
+
224
+ if clip_skip is None:
225
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
226
+ prompt_embeds = prompt_embeds[0]
227
+ else:
228
+ prompt_embeds = self.text_encoder(
229
+ text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
230
+ )
231
+ # Access the `hidden_states` first, that contains a tuple of
232
+ # all the hidden states from the encoder layers. Then index into
233
+ # the tuple to access the hidden states from the desired layer.
234
+ prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
235
+ # We also need to apply the final LayerNorm here to not mess with the
236
+ # representations. The `last_hidden_states` that we typically use for
237
+ # obtaining the final prompt representations passes through the LayerNorm
238
+ # layer.
239
+ prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
240
+
241
+ if self.text_encoder is not None:
242
+ prompt_embeds_dtype = self.text_encoder.dtype
243
+ elif self.unet is not None:
244
+ prompt_embeds_dtype = self.unet.dtype
245
+ else:
246
+ prompt_embeds_dtype = prompt_embeds.dtype
247
+
248
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
249
+
250
+ bs_embed, seq_len, _ = prompt_embeds.shape
251
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
252
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
253
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
254
+
255
+ # get unconditional embeddings for classifier free guidance
256
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
257
+ uncond_tokens: List[str]
258
+ if negative_prompt is None:
259
+ uncond_tokens = [""] * batch_size
260
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
261
+ raise TypeError(
262
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
263
+ f" {type(prompt)}."
264
+ )
265
+ elif isinstance(negative_prompt, str):
266
+ uncond_tokens = [negative_prompt]
267
+ elif batch_size != len(negative_prompt):
268
+ raise ValueError(
269
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
270
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
271
+ " the batch size of `prompt`."
272
+ )
273
+ else:
274
+ uncond_tokens = negative_prompt
275
+
276
+ # textual inversion: procecss multi-vector tokens if necessary
277
+ if isinstance(self, TextualInversionLoaderMixin):
278
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
279
+
280
+ max_length = prompt_embeds.shape[1]
281
+ uncond_input = self.tokenizer(
282
+ uncond_tokens,
283
+ padding="max_length",
284
+ max_length=max_length,
285
+ truncation=True,
286
+ return_tensors="pt",
287
+ )
288
+
289
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
290
+ attention_mask = uncond_input.attention_mask.to(device)
291
+ else:
292
+ attention_mask = None
293
+
294
+ negative_prompt_embeds = self.text_encoder(
295
+ uncond_input.input_ids.to(device),
296
+ attention_mask=attention_mask,
297
+ )
298
+ negative_prompt_embeds = negative_prompt_embeds[0]
299
+
300
+ if do_classifier_free_guidance:
301
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
302
+ seq_len = negative_prompt_embeds.shape[1]
303
+
304
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
305
+
306
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
307
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
308
+
309
+ if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
310
+ # Retrieve the original scale by scaling back the LoRA layers
311
+ unscale_lora_layers(self.text_encoder, lora_scale)
312
+
313
+ return prompt_embeds, negative_prompt_embeds
314
+
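A minimal sketch of the `clip_skip` indexing performed above, assuming the standard `transformers` CLIP classes (the checkpoint name and the `clip_skip` value are illustrative, not taken from this repo):

import torch
from transformers import CLIPTokenizer, CLIPTextModel

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

prompt = "a woman wearing a floral dress"
clip_skip = 1  # hypothetical value: take the hidden states one layer before the last

inputs = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length,
                   truncation=True, return_tensors="pt")
with torch.no_grad():
    out = text_encoder(inputs.input_ids, output_hidden_states=True)

# hidden_states[-1] is the last layer, so -(clip_skip + 1) skips `clip_skip` layers;
# the final LayerNorm is applied manually because the usual output path was bypassed
embeds = text_encoder.text_model.final_layer_norm(out.hidden_states[-(clip_skip + 1)])
print(embeds.shape)  # torch.Size([1, 77, 768])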
315
+ def prepare_latents(
316
+ self,
317
+ batch_size,
318
+ num_channels_latents,
319
+ width,
320
+ height,
321
+ dtype,
322
+ device,
323
+ generator,
324
+ latents=None,
325
+ ):
326
+ shape = (
327
+ batch_size,
328
+ num_channels_latents,
329
+ height // self.vae_scale_factor,
330
+ width // self.vae_scale_factor,
331
+ )
332
+ if isinstance(generator, list) and len(generator) != batch_size:
333
+ raise ValueError(
334
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
335
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
336
+ )
337
+
338
+ if latents is None:
339
+ latents = randn_tensor(
340
+ shape, generator=generator, device=device, dtype=dtype
341
+ )
342
+ else:
343
+ latents = latents.to(device)
344
+
345
+ # scale the initial noise by the standard deviation required by the scheduler
346
+ latents = latents * self.scheduler.init_noise_sigma
347
+ return latents
348
+
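For reference, a small sketch of the latent shape and noise scaling that `prepare_latents` performs, assuming diffusers' `DDIMScheduler` and a VAE scale factor of 8 (both illustrative choices):

import torch
from diffusers import DDIMScheduler

scheduler = DDIMScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(30)

batch_size, num_channels_latents, vae_scale_factor = 1, 4, 8
height, width = 768, 512
shape = (batch_size, num_channels_latents, height // vae_scale_factor, width // vae_scale_factor)

latents = torch.randn(shape, generator=torch.Generator().manual_seed(0))
latents = latents * scheduler.init_noise_sigma  # 1.0 for DDIM; larger for e.g. Euler-style schedulers
print(latents.shape)  # torch.Size([1, 4, 96, 64])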
349
+ def prepare_condition(
350
+ self,
351
+ cond_image,
352
+ width,
353
+ height,
354
+ device,
355
+ dtype,
356
+ do_classifier_free_guidance=False,
357
+ ):
358
+ image = self.cond_image_processor.preprocess(
359
+ cond_image, height=height, width=width
360
+ ).to(dtype=torch.float32)
361
+
362
+ image = image.to(device=device, dtype=dtype)
363
+
364
+ if do_classifier_free_guidance:
365
+ image = torch.cat([image] * 2)
366
+
367
+ return image
368
+
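A short sketch of what `prepare_condition` produces, assuming diffusers' `VaeImageProcessor` as the `cond_image_processor` (the pose image here is a blank placeholder):

import torch
from PIL import Image
from diffusers.image_processor import VaeImageProcessor

cond_image_processor = VaeImageProcessor(vae_scale_factor=8, do_normalize=False)
pose = Image.new("RGB", (512, 768))  # placeholder pose map

image = cond_image_processor.preprocess(pose, height=768, width=512).to(dtype=torch.float32)
# with classifier-free guidance the condition is duplicated so that the unconditional
# and conditional halves of the batch see the same pose
image = torch.cat([image] * 2)
print(image.shape)  # torch.Size([2, 3, 768, 512])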
369
+ def get_image_embeds(self, clip_image=None, faceid_embeds=None):
370
+ with torch.no_grad():
371
+ clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16),
372
+ output_hidden_states=True).hidden_states[-2]
373
+ uncond_clip_image_embeds = self.image_encoder(
374
+ torch.zeros_like(clip_image).to(self.device, dtype=torch.float16), output_hidden_states=True
375
+ ).hidden_states[-2]
376
+
377
+ faceid_embeds = faceid_embeds.to(self.device, dtype=torch.float16)
378
+ image_prompt_embeds = self.image_proj_model(faceid_embeds, clip_image_embeds)
379
+ uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(faceid_embeds), uncond_clip_image_embeds)
380
+ return image_prompt_embeds, uncond_image_prompt_embeds
381
+
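The unconditional image embeddings above come from zeroing the inputs rather than from a learned null token. A minimal stand-in for `image_proj_model` to illustrate that pattern (the real module in this repo is a resampler; this linear fusion is purely hypothetical):

import torch
import torch.nn as nn

class ToyFaceProj(nn.Module):
    """Illustrative stand-in: fuses an identity vector with pooled CLIP features into 4 tokens."""
    def __init__(self, id_dim=512, clip_dim=1280, cross_dim=768, num_tokens=4):
        super().__init__()
        self.proj = nn.Linear(id_dim + clip_dim, cross_dim * num_tokens)
        self.num_tokens, self.cross_dim = num_tokens, cross_dim

    def forward(self, faceid_embeds, clip_embeds):
        fused = torch.cat([faceid_embeds, clip_embeds.mean(dim=1)], dim=-1)
        return self.proj(fused).view(-1, self.num_tokens, self.cross_dim)

proj = ToyFaceProj()
faceid = torch.randn(1, 512)             # e.g. an InsightFace identity embedding
clip_hidden = torch.randn(1, 257, 1280)  # e.g. penultimate CLIP-ViT hidden states

cond = proj(faceid, clip_hidden)
uncond = proj(torch.zeros_like(faceid), torch.zeros_like(clip_hidden))  # the "null" branch for CFG
print(cond.shape, uncond.shape)  # torch.Size([1, 4, 768]) twice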
382
+ def set_scale(self, scale, lora_scale):
383
+ for attn_processor in self.unet.attn_processors.values():
384
+ if isinstance(attn_processor, RefLoraSAttnProcessor2_0):
385
+ attn_processor.scale = scale
386
+ attn_processor.lora_scale = lora_scale
387
+ # elif isinstance(attn_processor, RefCAttnProcessor2_0):
388
+ # attn_processor.scale = scale
389
+
390
+ def set_ipa_scale(self, ipa_scale, lora_scale):
391
+ for attn_processor in self.unet.attn_processors.values():
392
+ if isinstance(attn_processor, LoRAIPAttnProcessor2_0):
393
+ attn_processor.scale = ipa_scale
394
+ attn_processor.lora_scale = lora_scale
395
+ elif isinstance(attn_processor, IPAttnProcessor2_0):
396
+ attn_processor.scale = ipa_scale
397
+ attn_processor.lora_scale = lora_scale
398
+
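Hypothetical usage of the two setters (the `pipe` object and the numeric values are assumptions; the attention-processor classes are the ones checked above):

# strong garment transfer, moderately weighted identity branch
pipe.set_scale(scale=1.0, lora_scale=0.2)          # RefLoraSAttnProcessor2_0 layers
pipe.set_ipa_scale(ipa_scale=0.6, lora_scale=0.2)  # LoRAIPAttnProcessor2_0 / IPAttnProcessor2_0 layers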
399
+ @torch.no_grad()
400
+ def __call__(
401
+ self,
402
+ prompt,
403
+ null_prompt,
404
+ negative_prompt,
405
+ ref_image,
406
+ width,
407
+ height,
408
+ num_inference_steps,
409
+ guidance_scale,
410
+ pose_image=None,
411
+ ref_clip_image=None,
412
+ face_clip_image=None,
413
+ faceid_embeds=None,
414
+ num_images_per_prompt=1,
415
+ image_scale=1.0,
416
+ ipa_scale=0.0,
417
+ s_lora_scale=0.0,
418
+ c_lora_scale=0.0,
419
+ num_samples=1,
420
+ eta: float = 0.0,
421
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
422
+ output_type: Optional[str] = "pil",
423
+ return_dict: bool = True,
424
+ clip_skip: Optional[int] = None,
425
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
426
+ callback_steps: Optional[int] = 1,
427
+ prompt_embeds: Optional[torch.FloatTensor] = None,
428
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
429
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
430
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
431
+ guess_mode: bool = False,
432
+ control_guidance_start: Union[float, List[float]] = 0.0,
433
+ control_guidance_end: Union[float, List[float]] = 1.0,
434
+ **kwargs,
435
+ ):
436
+
437
+ if face_clip_image is None:
438
+ self.set_scale(image_scale, lora_scale=0.0)
439
+ self.set_ipa_scale(ipa_scale=0.0, lora_scale=0.0)
440
+ else:
441
+ self.set_scale(image_scale, lora_scale=s_lora_scale)
442
+ self.set_ipa_scale(ipa_scale, lora_scale=c_lora_scale)
443
+
444
+ # controlnet
445
+ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
446
+ # align format for control guidance
447
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
448
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
449
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
450
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
451
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
452
+ mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
453
+ control_guidance_start, control_guidance_end = (
454
+ mult * [control_guidance_start],
455
+ mult * [control_guidance_end],
456
+ )
457
+ if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
458
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
459
+
460
+ global_pool_conditions = (
461
+ controlnet.config.global_pool_conditions
462
+ if isinstance(controlnet, ControlNetModel)
463
+ else controlnet.nets[0].config.global_pool_conditions
464
+ )
465
+ guess_mode = guess_mode or global_pool_conditions
466
+ # Default height and width to unet
467
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
468
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
469
+
470
+ device = self._execution_device
471
+ self._cross_attention_kwargs = cross_attention_kwargs
472
+ self._clip_skip = clip_skip
473
+ do_classifier_free_guidance = guidance_scale > 1.0
474
+
475
+ # Prepare timesteps
476
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
477
+ timesteps = self.scheduler.timesteps
478
+
479
+ batch_size = 1
480
+ if pose_image is not None:
481
+ # Prepare control image
482
+ if isinstance(controlnet, ControlNetModel):
483
+ image = self.prepare_image(
484
+ image=pose_image,
485
+ width=width,
486
+ height=height,
487
+ batch_size=batch_size * num_images_per_prompt,
488
+ num_images_per_prompt=num_images_per_prompt,
489
+ device=device,
490
+ dtype=controlnet.dtype,
491
+ do_classifier_free_guidance=do_classifier_free_guidance,
492
+ guess_mode=guess_mode,
493
+ )
494
+ if do_classifier_free_guidance and not guess_mode:
495
+ image = image.chunk(2)[0]
496
+ height, width = image.shape[-2:]
497
+ else:
498
+ assert False, "pose_image conditioning currently requires a single ControlNetModel"
499
+ # print(image.shape)
500
+
501
+ # 3. Encode input prompt
502
+ text_encoder_lora_scale = (
503
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
504
+ )
505
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
506
+ prompt,
507
+ device,
508
+ num_images_per_prompt,
509
+ do_classifier_free_guidance,
510
+ negative_prompt,
511
+ prompt_embeds=prompt_embeds,
512
+ negative_prompt_embeds=negative_prompt_embeds,
513
+ lora_scale=text_encoder_lora_scale,
514
+ clip_skip=self.clip_skip,
515
+ )
516
+
517
+ if face_clip_image is not None:
518
+ # for face image condition
519
+ image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(face_clip_image, faceid_embeds)
520
+
521
+ bs_embed, seq_len, _ = image_prompt_embeds.shape
522
+ image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
523
+ image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
524
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
525
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
526
+
527
+ if ref_clip_image is not None:
528
+ with torch.no_grad():
529
+ image_embeds = self.image_encoder(ref_clip_image.to(device, dtype=prompt_embeds.dtype),
530
+ output_hidden_states=True).hidden_states[-2]
531
+ image_null_embeds = \
532
+ self.image_encoder(torch.zeros_like(ref_clip_image).to(device, dtype=prompt_embeds.dtype),
533
+ output_hidden_states=True).hidden_states[-2]
534
+ cloth_proj_embed = self.ImgProj(image_embeds)
535
+ cloth_null_embeds = self.ImgProj(image_null_embeds)
536
+ # cloth_null_embeds = self.ImgProj(torch.zeros_like(image_embeds))
537
+ else:
538
+ null_prompt_embeds, _ = self.encode_prompt(
539
+ null_prompt,
540
+ device,
541
+ num_images_per_prompt,
542
+ do_classifier_free_guidance,
543
+ negative_prompt,
544
+ prompt_embeds=prompt_embeds,
545
+ negative_prompt_embeds=negative_prompt_embeds,
546
+ lora_scale=text_encoder_lora_scale,
547
+ clip_skip=self.clip_skip,
548
+ )
549
+
550
+ # For classifier free guidance, we need to do two forward passes.
551
+ # Here we concatenate the unconditional and text embeddings into a single batch
552
+ # to avoid doing two forward passes
553
+ if do_classifier_free_guidance:
554
+ prompt_embeds_control = torch.cat([negative_prompt_embeds, prompt_embeds])
555
+ if ref_clip_image is not None:
556
+ null_prompt_embeds = torch.cat([cloth_null_embeds, cloth_proj_embed])
557
+ else:
558
+ null_prompt_embeds = torch.cat([negative_prompt_embeds, null_prompt_embeds])
559
+ if face_clip_image is not None:
560
+ prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
561
+ negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
562
+ else:
563
+ prompt_embeds = prompt_embeds
564
+ negative_prompt_embeds = negative_prompt_embeds
565
+
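+ # Note on the layout built above: `prompt_embeds_control` stacks the unconditional and
+ # conditional text embeddings along the batch dim for the ControlNet, `null_prompt_embeds`
+ # is what the reference UNet consumes (garment projection or null text), while the face
+ # tokens are appended along the sequence dim (dim=1) so the main UNet sees text tokens
+ # followed by identity tokens in a single context.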
566
+ num_channels_latents = self.unet.in_channels
567
+ latents = self.prepare_latents(
568
+ batch_size * num_images_per_prompt,
569
+ num_channels_latents,
570
+ width,
571
+ height,
572
+ prompt_embeds.dtype,
573
+ device,
574
+ generator,
575
+ )
576
+
577
+ # Prepare extra step kwargs.
578
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
579
+
580
+ # Prepare ref image latents
581
+ ref_image_tensor = ref_image.to(
582
+ dtype=self.vae.dtype, device=self.vae.device
583
+ )
584
+ ref_image_latents = self.vae.encode(ref_image_tensor).latent_dist.mean
585
+ ref_image_latents = ref_image_latents * self.vae.config.scaling_factor # (b, 4, h, w)
586
+ if pose_image is not None:
587
+ # Create tensor stating which controlnets to keep
588
+ controlnet_keep = []
589
+ for i in range(len(timesteps)):
590
+ keeps = [
591
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
592
+ for s, e in zip(control_guidance_start, control_guidance_end)
593
+ ]
594
+ controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
595
+
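+ # `controlnet_keep[i]` is 1.0 while step i falls inside [control_guidance_start,
+ # control_guidance_end] (as fractions of the schedule) and 0.0 outside it; e.g. with
+ # start=0.0 and end=0.8 the ControlNet residuals are disabled for the last 20% of steps.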
596
+ # denoising loop
597
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
598
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
599
+ for i, t in enumerate(timesteps):
600
+ # 1. Forward reference image
601
+ if i == 0:
602
+ _ = self.reference_unet(
603
+ ref_image_latents.repeat(
604
+ (2 if do_classifier_free_guidance else 1), 1, 1, 1
605
+ ),
606
+ torch.zeros_like(t),
607
+ encoder_hidden_states=null_prompt_embeds,
608
+ return_dict=False,
609
+ )
610
+
611
+ # get cache tensors
612
+ sa_hidden_states = {}
613
+ for name in self.reference_unet.attn_processors.keys():
614
+ sa_hidden_states[name] = self.reference_unet.attn_processors[name].cache["hidden_states"][
615
+ 1].unsqueeze(0)
616
+ # sa_hidden_states[name][0, :, :] = 0
617
+
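+ # The reference UNet runs once, on the first denoising step, with timestep 0 and the
+ # garment latents; its attention processors cache their input hidden states, and index [1]
+ # selects the conditional half of the CFG-doubled batch. The cached per-layer states are
+ # then handed to the main UNet below through cross_attention_kwargs["sa_hidden_states"].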
618
+ # 3.1 expand the latents if we are doing classifier free guidance
619
+ latent_model_input = (
620
+ torch.cat([latents] * 2) if do_classifier_free_guidance else latents
621
+ )
622
+ latent_model_input = self.scheduler.scale_model_input(
623
+ latent_model_input, t
624
+ )
625
+
626
+ # Optionally get Guidance Scale Embedding
627
+ timestep_cond = None
628
+ if self.unet.config.time_cond_proj_dim is not None:
629
+ guidance_scale_tensor = torch.tensor(guidance_scale - 1).repeat(
630
+ batch_size * num_images_per_prompt)
631
+ timestep_cond = self.get_guidance_scale_embedding(
632
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
633
+ ).to(device=device, dtype=latents.dtype)
634
+
635
+ # for control
636
+ if pose_image is not None:
637
+ # controlnet(s) inference
638
+ if guess_mode and do_classifier_free_guidance:
639
+ # Infer ControlNet only for the conditional batch.
640
+ control_model_input = latents
641
+ control_model_input = self.scheduler.scale_model_input(control_model_input, t)
642
+ controlnet_prompt_embeds = prompt_embeds_control.chunk(2)[1]
643
+ # controlnet_prompt_embeds = prompt_embeds
644
+ else:
645
+ control_model_input = latent_model_input
646
+ controlnet_prompt_embeds = prompt_embeds_control
647
+ if isinstance(controlnet_keep[i], list):
648
+ cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
649
+ else:
650
+ controlnet_cond_scale = controlnet_conditioning_scale
651
+ if isinstance(controlnet_cond_scale, list):
652
+ controlnet_cond_scale = controlnet_cond_scale[0]
653
+ cond_scale = controlnet_cond_scale * controlnet_keep[i]
654
+
655
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
656
+ control_model_input,
657
+ t,
658
+ encoder_hidden_states=controlnet_prompt_embeds,
659
+ controlnet_cond=image,
660
+ conditioning_scale=cond_scale,
661
+ guess_mode=guess_mode,
662
+ return_dict=False,
663
+ )
664
+
665
+ # if do_classifier_free_guidance:
666
+ down_block_res_samples_con = []
667
+ down_block_res_samples_uncon = []
668
+ for down_block in down_block_res_samples:
669
+ down_block_res_samples_con.append(down_block[1])
670
+ down_block_res_samples_uncon.append(down_block[0])
671
+ # for prompt_embeds ref + text
672
+ noise_pred = self.unet(
673
+ latent_model_input[0].unsqueeze(0),
674
+ t,
675
+ encoder_hidden_states=prompt_embeds,
676
+ cross_attention_kwargs={
677
+ "sa_hidden_states": sa_hidden_states,
678
+ },
679
+ timestep_cond=timestep_cond,
680
+ down_block_additional_residuals=down_block_res_samples_con,
681
+ mid_block_additional_residual=mid_block_res_sample[1],
682
+ added_cond_kwargs=None,
683
+ return_dict=False,
684
+ )[0]
685
+ # for negative_prompt_embeds non text
686
+ unc_noise_pred = self.unet(
687
+ latent_model_input[1].unsqueeze(0),
688
+ t,
689
+ encoder_hidden_states=negative_prompt_embeds,
690
+ timestep_cond=timestep_cond,
691
+ down_block_additional_residuals=down_block_res_samples_uncon,
692
+ mid_block_additional_residual=mid_block_res_sample[0],
693
+ added_cond_kwargs=None,
694
+ return_dict=False,
695
+ )[0]
696
+ # for no control
697
+ else:
698
+ noise_pred = self.unet(
699
+ latent_model_input[1].unsqueeze(0),
700
+ t,
701
+ encoder_hidden_states=prompt_embeds,
702
+ cross_attention_kwargs={
703
+ "sa_hidden_states": sa_hidden_states,
704
+ },
705
+ timestep_cond=timestep_cond,
706
+ added_cond_kwargs=None,
707
+ return_dict=False,
708
+ )[0]
709
+ # for negative_prompt_embeds non text
710
+ unc_noise_pred = self.unet(
711
+ latent_model_input[0].unsqueeze(0),
712
+ t,
713
+ encoder_hidden_states=negative_prompt_embeds,
714
+ timestep_cond=timestep_cond,
715
+ added_cond_kwargs=None,
716
+ return_dict=False,
717
+ )[0]
718
+
719
+ # perform guidance
720
+ if do_classifier_free_guidance:
721
+ # noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
722
+ noise_pred_uncond, noise_pred_text = unc_noise_pred, noise_pred
723
+
724
+ noise_pred = noise_pred_uncond + guidance_scale * (
725
+ noise_pred_text - noise_pred_uncond
726
+ )
727
+
728
+ # compute the previous noisy sample x_t -> x_t-1
729
+ latents = self.scheduler.step(
730
+ noise_pred, t, latents, **extra_step_kwargs, return_dict=False
731
+ )[0]
732
+
733
+ # call the callback, if provided
734
+ if i == len(timesteps) - 1 or (
735
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
736
+ ):
737
+ progress_bar.update()
738
+ if callback is not None and i % callback_steps == 0:
739
+ step_idx = i // getattr(self.scheduler, "order", 1)
740
+ callback(step_idx, t, latents)
741
+
742
+ # Post-processing
743
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[0]
744
+ do_denormalize = [True] * image.shape[0]
745
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
746
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None)
747
+
748
+
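Taken together, a hypothetical end-to-end call could look as follows. The `pipe` object, the checkpoint wiring (UNet, reference UNet, ControlNet, image encoder, projection modules, scheduler) and all file names and numeric values are assumptions for illustration; only the argument names come from the `__call__` signature above.

import torch
from PIL import Image

ref_image = torch.zeros(1, 3, 768, 512)  # garment image already normalised to [-1, 1]
result = pipe(
    prompt="a woman wearing the garment, studio lighting",
    null_prompt="",
    negative_prompt="lowres, bad anatomy, worst quality",
    ref_image=ref_image,
    width=512,
    height=768,
    num_inference_steps=30,
    guidance_scale=7.5,
    pose_image=Image.open("pose.png"),  # optional: routes through the ControlNet branch
    ref_clip_image=None,                # optional CLIP crop of the garment
    face_clip_image=None,               # optional face crop + faceid_embeds for identity control
    image_scale=1.0,
    ipa_scale=0.0,
    generator=torch.Generator("cpu").manual_seed(42),
)
result.images[0].save("result.png")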
requirements.txt CHANGED
@@ -4,4 +4,5 @@ invisible_watermark
  torch
  transformers
  xformers
- modelscope
+ addict
+ insightface
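The new dependencies line up with the identity branch above: `insightface` supplies the face detector/recogniser typically used to produce `faceid_embeds`, while `addict` provides attribute-style dictionaries. A minimal sanity check, assuming both packages are installed:

from addict import Dict
from insightface.app import FaceAnalysis

cfg = Dict(det_size=(640, 640))               # addict allows attribute-style access: cfg.det_size
app = FaceAnalysis(name="buffalo_l")          # downloads detection/recognition models on first use
app.prepare(ctx_id=0, det_size=cfg.det_size)
# app.get(bgr_image) would then return faces whose `normed_embedding` can serve as faceid_embeds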