yrr commited on
Commit
7f48662
1 Parent(s): 2c2ec6c
OmniGen/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .model import OmniGen
2
+ from .processor import OmniGenProcessor
3
+ from .scheduler import OmniGenScheduler
4
+ from .pipeline import OmniGenPipeline
OmniGen/model.py ADDED
@@ -0,0 +1,402 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # The code is revised from DiT
2
+ import os
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+ import math
7
+ from typing import Dict
8
+ from timm.models.vision_transformer import PatchEmbed, Attention, Mlp
9
+
10
+ from OmniGen.transformer import Phi3Config, Phi3Transformer
11
+
12
+
13
+ def modulate(x, shift, scale):
14
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
15
+
16
+
17
+ class TimestepEmbedder(nn.Module):
18
+ """
19
+ Embeds scalar timesteps into vector representations.
20
+ """
21
+ def __init__(self, hidden_size, frequency_embedding_size=256):
22
+ super().__init__()
23
+ self.mlp = nn.Sequential(
24
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
25
+ nn.SiLU(),
26
+ nn.Linear(hidden_size, hidden_size, bias=True),
27
+ )
28
+ self.frequency_embedding_size = frequency_embedding_size
29
+
30
+ @staticmethod
31
+ def timestep_embedding(t, dim, max_period=10000):
32
+ """
33
+ Create sinusoidal timestep embeddings.
34
+ :param t: a 1-D Tensor of N indices, one per batch element.
35
+ These may be fractional.
36
+ :param dim: the dimension of the output.
37
+ :param max_period: controls the minimum frequency of the embeddings.
38
+ :return: an (N, D) Tensor of positional embeddings.
39
+ """
40
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
41
+ half = dim // 2
42
+ freqs = torch.exp(
43
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
44
+ ).to(device=t.device)
45
+ args = t[:, None].float() * freqs[None]
46
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
47
+ if dim % 2:
48
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
49
+ return embedding
50
+
51
+ def forward(self, t, dtype=torch.float32):
52
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype)
53
+ t_emb = self.mlp(t_freq)
54
+ return t_emb
55
+
56
+
57
+ class FinalLayer(nn.Module):
58
+ """
59
+ The final layer of DiT.
60
+ """
61
+ def __init__(self, hidden_size, patch_size, out_channels):
62
+ super().__init__()
63
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
64
+ self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
65
+ self.adaLN_modulation = nn.Sequential(
66
+ nn.SiLU(),
67
+ nn.Linear(hidden_size, 2 * hidden_size, bias=True)
68
+ )
69
+
70
+ def forward(self, x, c):
71
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
72
+ x = modulate(self.norm_final(x), shift, scale)
73
+ x = self.linear(x)
74
+ return x
75
+
76
+
77
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=1):
78
+ """
79
+ grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or
80
+ [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
81
+ """
82
+ if isinstance(grid_size, int):
83
+ grid_size = (grid_size, grid_size)
84
+
85
+ grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale
86
+ grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale
87
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
88
+ grid = np.stack(grid, axis=0)
89
+
90
+ grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
91
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
92
+ if cls_token and extra_tokens > 0:
93
+ pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
94
+ return pos_embed
95
+
96
+
97
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
98
+ assert embed_dim % 2 == 0
99
+
100
+ # use half of dimensions to encode grid_h
101
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
102
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
103
+
104
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
105
+ return emb
106
+
107
+
108
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
109
+ """
110
+ embed_dim: output dimension for each position
111
+ pos: a list of positions to be encoded: size (M,)
112
+ out: (M, D)
113
+ """
114
+ assert embed_dim % 2 == 0
115
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
116
+ omega /= embed_dim / 2.
117
+ omega = 1. / 10000**omega # (D/2,)
118
+
119
+ pos = pos.reshape(-1) # (M,)
120
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
121
+
122
+ emb_sin = np.sin(out) # (M, D/2)
123
+ emb_cos = np.cos(out) # (M, D/2)
124
+
125
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
126
+ return emb
127
+
128
+
129
+ class PatchEmbedMR(nn.Module):
130
+ """ 2D Image to Patch Embedding
131
+ """
132
+ def __init__(
133
+ self,
134
+ patch_size: int = 2,
135
+ in_chans: int = 4,
136
+ embed_dim: int = 768,
137
+ bias: bool = True,
138
+ ):
139
+ super().__init__()
140
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
141
+
142
+ def forward(self, x):
143
+ x = self.proj(x)
144
+ x = x.flatten(2).transpose(1, 2) # NCHW -> NLC
145
+ return x
146
+
147
+
148
+ class OmniGen(nn.Module):
149
+ """
150
+ Diffusion model with a Transformer backbone.
151
+ """
152
+ def __init__(
153
+ self,
154
+ transformer_config: Phi3Config,
155
+ patch_size=2,
156
+ in_channels=4,
157
+ pe_interpolation: float = 1.0,
158
+ pos_embed_max_size: int = 192,
159
+ ):
160
+ super().__init__()
161
+ self.in_channels = in_channels
162
+ self.out_channels = in_channels
163
+ self.patch_size = patch_size
164
+ self.pos_embed_max_size = pos_embed_max_size
165
+
166
+ hidden_size = transformer_config.hidden_size
167
+
168
+ self.x_embedder = PatchEmbedMR(patch_size, in_channels, hidden_size, bias=True)
169
+ self.input_x_embedder = PatchEmbedMR(patch_size, in_channels, hidden_size, bias=True)
170
+
171
+ self.time_token = TimestepEmbedder(hidden_size)
172
+ self.t_embedder = TimestepEmbedder(hidden_size)
173
+
174
+ self.pe_interpolation = pe_interpolation
175
+ pos_embed = get_2d_sincos_pos_embed(hidden_size, pos_embed_max_size, interpolation_scale=self.pe_interpolation, base_size=64)
176
+ self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=True)
177
+
178
+ self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
179
+
180
+ self.initialize_weights()
181
+
182
+ self.llm = Phi3Transformer(config=transformer_config)
183
+ self.llm.config.use_cache = False
184
+
185
+ @classmethod
186
+ def from_pretrained(cls, model_name):
187
+ if not os.path.exists(os.path.join(model_name, 'model.pt')):
188
+ cache_folder = os.getenv('HF_HUB_CACHE')
189
+ model_name = snapshot_download(repo_id=model_name,
190
+ cache_dir=cache_folder,
191
+ ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5'])
192
+ config = Phi3Config.from_pretrained(model_name)
193
+ model = cls(config)
194
+ ckpt = torch.load(os.path.join(model_name, 'model.pt'))
195
+ model.load_state_dict(ckpt)
196
+ return model
197
+
198
+ def initialize_weights(self):
199
+ assert not hasattr(self, "llama")
200
+
201
+ # Initialize transformer layers:
202
+ def _basic_init(module):
203
+ if isinstance(module, nn.Linear):
204
+ torch.nn.init.xavier_uniform_(module.weight)
205
+ if module.bias is not None:
206
+ nn.init.constant_(module.bias, 0)
207
+ self.apply(_basic_init)
208
+
209
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
210
+ w = self.x_embedder.proj.weight.data
211
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
212
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
213
+
214
+ w = self.input_x_embedder.proj.weight.data
215
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
216
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
217
+
218
+
219
+ # Initialize timestep embedding MLP:
220
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
221
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
222
+ nn.init.normal_(self.time_token.mlp[0].weight, std=0.02)
223
+ nn.init.normal_(self.time_token.mlp[2].weight, std=0.02)
224
+
225
+ # Zero-out output layers:
226
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
227
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
228
+ nn.init.constant_(self.final_layer.linear.weight, 0)
229
+ nn.init.constant_(self.final_layer.linear.bias, 0)
230
+
231
+ def unpatchify(self, x, h, w):
232
+ """
233
+ x: (N, T, patch_size**2 * C)
234
+ imgs: (N, H, W, C)
235
+ """
236
+ c = self.out_channels
237
+
238
+ x = x.reshape(shape=(x.shape[0], h//self.patch_size, w//self.patch_size, self.patch_size, self.patch_size, c))
239
+ x = torch.einsum('nhwpqc->nchpwq', x)
240
+ imgs = x.reshape(shape=(x.shape[0], c, h, w))
241
+ return imgs
242
+
243
+
244
+ def cropped_pos_embed(self, height, width):
245
+ """Crops positional embeddings for SD3 compatibility."""
246
+ if self.pos_embed_max_size is None:
247
+ raise ValueError("`pos_embed_max_size` must be set for cropping.")
248
+
249
+ height = height // self.patch_size
250
+ width = width // self.patch_size
251
+ if height > self.pos_embed_max_size:
252
+ raise ValueError(
253
+ f"Height ({height}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
254
+ )
255
+ if width > self.pos_embed_max_size:
256
+ raise ValueError(
257
+ f"Width ({width}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
258
+ )
259
+
260
+ top = (self.pos_embed_max_size - height) // 2
261
+ left = (self.pos_embed_max_size - width) // 2
262
+ spatial_pos_embed = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1)
263
+ spatial_pos_embed = spatial_pos_embed[:, top : top + height, left : left + width, :]
264
+ # print(top, top + height, left, left + width, spatial_pos_embed.size())
265
+ spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
266
+ return spatial_pos_embed
267
+
268
+
269
+ def patch_multiple_resolutions(self, latents, padding_latent=None, is_input_images:bool=False):
270
+ if isinstance(latents, list):
271
+ return_list = False
272
+ if padding_latent is None:
273
+ padding_latent = [None] * len(latents)
274
+ return_list = True
275
+ patched_latents, num_tokens, shapes = [], [], []
276
+ for latent, padding in zip(latents, padding_latent):
277
+ height, width = latent.shape[-2:]
278
+ if is_input_images:
279
+ latent = self.input_x_embedder(latent)
280
+ else:
281
+ latent = self.x_embedder(latent)
282
+ pos_embed = self.cropped_pos_embed(height, width)
283
+ latent = latent + pos_embed
284
+ if padding is not None:
285
+ latent = torch.cat([latent, padding], dim=-2)
286
+ patched_latents.append(latent)
287
+
288
+ num_tokens.append(pos_embed.size(1))
289
+ shapes.append([height, width])
290
+ if not return_list:
291
+ latents = torch.cat(patched_latents, dim=0)
292
+ else:
293
+ latents = patched_latents
294
+ else:
295
+ height, width = latents.shape[-2:]
296
+ if is_input_images:
297
+ latents = self.input_x_embedder(latents)
298
+ else:
299
+ latents = self.x_embedder(latents)
300
+ pos_embed = self.cropped_pos_embed(height, width)
301
+ latents = latents + pos_embed
302
+ num_tokens = latents.size(1)
303
+ shapes = [height, width]
304
+ return latents, num_tokens, shapes
305
+
306
+
307
+ def forward(self, x, timestep, text_ids, pixel_values, image_sizes, attention_mask, position_ids, padding_latent=None, past_key_values=None):
308
+ """
309
+
310
+ """
311
+ input_is_list = isinstance(x, list)
312
+ x, num_tokens, shapes = self.patch_multiple_resolutions(x, padding_latent)
313
+ time_token = self.time_token(timestep, dtype=x[0].dtype).unsqueeze(1)
314
+
315
+ if pixel_values is not None:
316
+ input_latents, _, _ = self.patch_multiple_resolutions(pixel_values, is_input_images=True)
317
+ if text_ids is not None:
318
+ condition_embeds = self.llm.embed_tokens(text_ids)
319
+ input_img_inx = 0
320
+ for b_inx in image_sizes.keys():
321
+ for start_inx, end_inx in image_sizes[b_inx]:
322
+ condition_embeds[b_inx, start_inx: end_inx] = input_latents[input_img_inx]
323
+ input_img_inx += 1
324
+ if pixel_values is not None:
325
+ assert input_img_inx == len(input_latents)
326
+
327
+ input_emb = torch.cat([condition_embeds, time_token, x], dim=1)
328
+ else:
329
+ input_emb = torch.cat([time_token, x], dim=1)
330
+ output = self.llm(inputs_embeds=input_emb, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values)
331
+ output, past_key_values = output.last_hidden_state, output.past_key_values
332
+ if input_is_list:
333
+ image_embedding = output[:, -max(num_tokens):]
334
+ time_emb = self.t_embedder(timestep, dtype=x.dtype)
335
+ x = self.final_layer(image_embedding, time_emb)
336
+ latents = []
337
+ for i in range(x.size(0)):
338
+ latent = x[i:i+1, :num_tokens[i]]
339
+ latent = self.unpatchify(latent, shapes[i][0], shapes[i][1])
340
+ latents.append(latent)
341
+ else:
342
+ image_embedding = output[:, -num_tokens:]
343
+ time_emb = self.t_embedder(timestep, dtype=x.dtype)
344
+ x = self.final_layer(image_embedding, time_emb)
345
+ latents = self.unpatchify(x, shapes[0], shapes[1])
346
+
347
+ return latents, past_key_values
348
+
349
+ @torch.no_grad()
350
+ def forward_with_cfg(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, cfg_scale, use_img_cfg, img_cfg_scale, past_key_values, use_kv_cache):
351
+ """
352
+ Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
353
+ """
354
+ self.llm.config.use_cache = use_kv_cache
355
+ model_out, past_key_values = self.forward(x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, past_key_values=past_key_values)
356
+ if use_img_cfg:
357
+ cond, uncond, img_cond = torch.split(model_out, len(model_out) // 3, dim=0)
358
+ cond = uncond + img_cfg_scale * (img_cond - uncond) + cfg_scale * (cond - img_cond)
359
+ model_out = [cond, cond, cond]
360
+ else:
361
+ cond, uncond = torch.split(model_out, len(model_out) // 2, dim=0)
362
+ cond = uncond + cfg_scale * (cond - uncond)
363
+ model_out = [cond, cond]
364
+
365
+ return torch.cat(model_out, dim=0), past_key_values
366
+
367
+
368
+ @torch.no_grad()
369
+ def forward_with_separate_cfg(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, cfg_scale, use_img_cfg, img_cfg_scale, past_key_values, use_kv_cache):
370
+ """
371
+ Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
372
+ """
373
+ self.llm.config.use_cache = use_kv_cache
374
+ if past_key_values is None:
375
+ past_key_values = [None] * len(attention_mask)
376
+
377
+ x = torch.split(x, len(x) // len(attention_mask), dim=0)
378
+ timestep = timestep.to(x[0].dtype)
379
+ timestep = torch.split(timestep, len(timestep) // len(input_ids), dim=0)
380
+
381
+ model_out, pask_key_values = [], []
382
+ for i in range(len(input_ids)):
383
+ temp_out, temp_pask_key_values = self.forward(x[i], timestep[i], input_ids[i], input_img_latents[i], input_image_sizes[i], attention_mask[i], position_ids[i], past_key_values[i])
384
+ model_out.append(temp_out)
385
+ pask_key_values.append(temp_pask_key_values)
386
+
387
+ if len(model_out) == 3:
388
+ cond, uncond, img_cond = model_out
389
+ cond = uncond + img_cfg_scale * (img_cond - uncond) + cfg_scale * (cond - img_cond)
390
+ model_out = [cond, cond, cond]
391
+ elif len(model_out) == 2:
392
+ cond, uncond = model_out
393
+ cond = uncond + cfg_scale * (cond - uncond)
394
+ model_out = [cond, cond]
395
+ else:
396
+ return model_out[0]
397
+
398
+ return torch.cat(model_out, dim=0), pask_key_values
399
+
400
+
401
+
402
+
OmniGen/pipeline.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import inspect
3
+ from typing import Any, Callable, Dict, List, Optional, Union
4
+
5
+ from PIL import Image
6
+ import numpy as np
7
+ import torch
8
+ from huggingface_hub import snapshot_download
9
+ from diffusers.models import AutoencoderKL
10
+ from diffusers.utils import (
11
+ USE_PEFT_BACKEND,
12
+ is_torch_xla_available,
13
+ logging,
14
+ replace_example_docstring,
15
+ scale_lora_layers,
16
+ unscale_lora_layers,
17
+ )
18
+
19
+ from OmniGen import OmniGen, OmniGenProcessor, OmniGenScheduler
20
+
21
+
22
+ logger = logging.get_logger(__name__)
23
+
24
+ EXAMPLE_DOC_STRING = """
25
+ Examples:
26
+ ```py
27
+ >>> from OmniGen import OmniGenPipeline
28
+ >>> pipe = FluxControlNetPipeline.from_pretrained(
29
+ ... base_model
30
+ ... )
31
+ >>> prompt = "A woman holds a bouquet of flowers and faces the camera"
32
+ >>> image = pipe(
33
+ ... prompt,
34
+ ... guidance_scale=1.0,
35
+ ... num_inference_steps=50,
36
+ ... ).images[0]
37
+ >>> image.save("t2i.png")
38
+ ```
39
+ """
40
+
41
+
42
+
43
+ class OmniGenPipeline:
44
+ def __init__(
45
+ self,
46
+ vae: AutoencoderKL,
47
+ model: OmniGen,
48
+ processor: OmniGenProcessor,
49
+ ):
50
+ self.vae = vae
51
+ self.model = model
52
+ self.processor = processor
53
+
54
+ self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
55
+ self.model.to(self.device)
56
+ self.vae.to(self.device)
57
+
58
+ @classmethod
59
+ def from_pretrained(cls, model_name):
60
+ if not os.path.exists(model_name):
61
+ cache_folder = os.getenv('HF_HUB_CACHE')
62
+ print(cache_folder)
63
+ model_name = snapshot_download(repo_id=model_name,
64
+ cache_dir=cache_folder,
65
+ ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5'])
66
+ logger.info(f"Downloaded model to {model_name}")
67
+ model = OmniGen.from_pretrained(model_name)
68
+ processor = OmniGenProcessor.from_pretrained(model_name)
69
+ vae = AutoencoderKL.from_pretrained(os.path.join(model_name, "vae"))
70
+
71
+ return cls(vae, model, processor)
72
+
73
+ def vae_encode(self, x, dtype):
74
+ if self.vae.config.shift_factor is not None:
75
+ x = self.vae.encode(x).latent_dist.sample()
76
+ x = (x - self.vae.config.shift_factor) * self.vae.config.scaling_factor
77
+ else:
78
+ x = self.vae.encode(x).latent_dist.sample().mul_(self.vae.config.scaling_factor)
79
+ x = x.to(dtype)
80
+ return x
81
+
82
+ def move_to_device(self, data):
83
+ if isinstance(data, list):
84
+ return [x.to(self.device) for x in data]
85
+ return data.to(self.device)
86
+
87
+
88
+ @torch.no_grad()
89
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
90
+ def __call__(
91
+ self,
92
+ prompt: Union[str, List[str]],
93
+ input_images: Union[List[str], List[List[str]]] = None,
94
+ height: int = 1024,
95
+ width: int = 1024,
96
+ num_inference_steps: int = 50,
97
+ guidance_scale: float = 3,
98
+ use_img_guidance: bool = True,
99
+ img_guidance_scale: float = 1.6,
100
+ separate_cfg_infer: bool = False,
101
+ use_kv_cache: bool = True,
102
+ dtype: torch.dtype = torch.bfloat16,
103
+ ):
104
+ r"""
105
+ Function invoked when calling the pipeline for generation.
106
+
107
+ Args:
108
+ prompt (`str` or `List[str]`):
109
+ The prompt or prompts to guide the image generation.
110
+ input_images (`List[str]` or `List[List[str]]`, *optional*):
111
+ The list of input images. We will replace the "<|image_i|>" in prompt with the 1-th image in list.
112
+ height (`int`, *optional*, defaults to 1024):
113
+ The height in pixels of the generated image. The number must be a multiple of 16.
114
+ width (`int`, *optional*, defaults to 1024):
115
+ The width in pixels of the generated image. The number must be a multiple of 16.
116
+ num_inference_steps (`int`, *optional*, defaults to 50):
117
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference.
118
+ guidance_scale (`float`, *optional*, defaults to 4.0):
119
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
120
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
121
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
122
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
123
+ usually at the expense of lower image quality.
124
+ use_img_guidance (`bool`, *optional*, defaults to True):
125
+ Defined as equation 3 in [Instrucpix2pix](https://arxiv.org/pdf/2211.09800).
126
+ img_guidance_scale (`float`, *optional*, defaults to 1.6):
127
+ Defined as equation 3 in [Instrucpix2pix](https://arxiv.org/pdf/2211.09800).
128
+ separate_cfg_infer (`bool`, *optional*, defaults to False):
129
+ Perform inference on images with different guidance separately; this can save memory when generating images of large size at the expense of slower inference.
130
+ use_kv_cache (`bool`, *optional*, defaults to True): enable kv cache to speed up the inference
131
+
132
+ Examples:
133
+
134
+ Returns:
135
+ A list with the generated images.
136
+ """
137
+ assert height%16 == 0 and width%16 == 0
138
+ if use_kv_cache and separate_cfg_infer:
139
+ raise "Currently, don't support both use_kv_cache and separate_cfg_infer"
140
+ if input_images is None:
141
+ use_img_guidance = False
142
+ if isinstance(prompt, str):
143
+ prompt = [prompt]
144
+ input_images = [input_images] if input_images is not None else None
145
+
146
+ input_data = self.processor(prompt, input_images, height=height, width=width, use_img_cfg=use_img_guidance, separate_cfg_input=separate_cfg_infer)
147
+
148
+ num_prompt = len(prompt)
149
+ num_cfg = 2 if use_img_guidance else 1
150
+ latent_size_h, latent_size_w = height//8, width//8
151
+
152
+ latents = torch.randn(num_prompt, 4, latent_size_h, latent_size_w, device=self.device)
153
+ latents = torch.cat([latents]*(1+num_cfg), 0).to(dtype)
154
+
155
+ input_img_latents = []
156
+ if separate_cfg_infer:
157
+ for temp_pixel_values in input_data['input_pixel_values']:
158
+ temp_input_latents = []
159
+ for img in temp_pixel_values:
160
+ img = self.vae_encode(img.to(self.device), dtype)
161
+ temp_input_latents.append(img)
162
+ input_img_latents.append(temp_input_latents)
163
+ else:
164
+ for img in input_data['input_pixel_values']:
165
+ img = self.vae_encode(img.to(self.device), dtype)
166
+ input_img_latents.append(img)
167
+
168
+ model_kwargs = dict(input_ids=self.move_to_device(input_data['input_ids']),
169
+ input_img_latents=input_img_latents,
170
+ input_image_sizes=input_data['input_image_sizes'],
171
+ attention_mask=self.move_to_device(input_data["attention_mask"]),
172
+ position_ids=self.move_to_device(input_data["position_ids"]),
173
+ cfg_scale=guidance_scale,
174
+ img_cfg_scale=img_guidance_scale,
175
+ use_img_cfg=use_img_guidance,
176
+ use_kv_cache=use_kv_cache)
177
+
178
+ if separate_cfg_infer:
179
+ func = self.model.forward_with_separate_cfg
180
+ else:
181
+ func = self.model.forward_with_cfg
182
+ self.model.to(dtype)
183
+
184
+ scheduler = OmniGenScheduler(num_steps=num_inference_steps)
185
+ samples = scheduler(latents, func, model_kwargs, use_kv_cache=use_kv_cache)
186
+ samples = samples.chunk((1+num_cfg), dim=0)[0]
187
+
188
+ samples = samples.to(torch.float32)
189
+ if self.vae.config.shift_factor is not None:
190
+ samples = samples / self.vae.config.scaling_factor + self.vae.config.shift_factor
191
+ else:
192
+ samples = samples / self.vae.config.scaling_factor
193
+ samples = self.vae.decode(samples).sample
194
+
195
+ output_samples = (samples * 0.5 + 0.5).clamp(0, 1)*255
196
+ output_samples = output_samples.permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()
197
+ output_images = []
198
+ for i, sample in enumerate(output_samples):
199
+ output_images.append(Image.fromarray(sample))
200
+
201
+ return output_images
OmniGen/processor.py ADDED
@@ -0,0 +1,349 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ from typing import Dict, List
4
+ import json
5
+
6
+ import torch
7
+ import numpy as np
8
+ import random
9
+ from PIL import Image
10
+ from torchvision import transforms
11
+ from transformers import AutoTokenizer
12
+ from huggingface_hub import snapshot_download
13
+
14
+
15
+ def crop_arr(pil_image, max_image_size):
16
+ while min(*pil_image.size) >= 2 * max_image_size:
17
+ pil_image = pil_image.resize(
18
+ tuple(x // 2 for x in pil_image.size), resample=Image.BOX
19
+ )
20
+
21
+ if max(*pil_image.size) > max_image_size:
22
+ scale = max_image_size / max(*pil_image.size)
23
+ pil_image = pil_image.resize(
24
+ tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
25
+ )
26
+
27
+ arr = np.array(pil_image)
28
+ crop_y1 = (arr.shape[0] % 16) // 2
29
+ crop_y2 = arr.shape[0] % 16 - crop_y1
30
+
31
+ crop_x1 = (arr.shape[1] % 16) // 2
32
+ crop_x2 = arr.shape[1] % 16 - crop_x1
33
+
34
+ arr = arr[crop_y1:arr.shape[0]-crop_y2, crop_x1:arr.shape[1]-crop_x2]
35
+ return Image.fromarray(arr)
36
+
37
+
38
+ class OmniGenProcessor:
39
+ def __init__(self,
40
+ text_tokenizer,
41
+ max_image_size: int=1024):
42
+ self.text_tokenizer = text_tokenizer
43
+ self.max_image_size = max_image_size
44
+
45
+ self.image_transform = transforms.Compose([
46
+ transforms.Lambda(lambda pil_image: crop_arr(pil_image, max_image_size)),
47
+ transforms.ToTensor(),
48
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
49
+ ])
50
+
51
+ self.collator = OmniGenCollator()
52
+ self.separate_collator = OmniGenSeparateCollator()
53
+
54
+ @classmethod
55
+ def from_pretrained(cls, model_name):
56
+ if not os.path.exists(model_name):
57
+ cache_folder = os.getenv('HF_HUB_CACHE')
58
+ model_name = snapshot_download(repo_id=model_name,
59
+ cache_dir=cache_folder,
60
+ allow_patterns="*.json")
61
+ text_tokenizer = AutoTokenizer.from_pretrained(model_name)
62
+
63
+ return cls(text_tokenizer)
64
+
65
+
66
+ def process_image(self, image):
67
+ image = Image.open(image).convert('RGB')
68
+ return self.image_transform(image)
69
+
70
+ def process_multi_modal_prompt(self, text, input_images):
71
+ if input_images is None or len(input_images) == 0:
72
+ model_inputs = self.text_tokenizer(text)
73
+ return {"input_ids": model_inputs.input_ids, "pixel_values": None, "image_sizes": None}
74
+
75
+ pattern = r"<\|image_\d+\|>"
76
+ prompt_chunks = [self.text_tokenizer(chunk).input_ids for chunk in re.split(pattern, text)]
77
+
78
+ for i in range(1, len(prompt_chunks)):
79
+ if prompt_chunks[i][0] == 1:
80
+ prompt_chunks[i] = prompt_chunks[i][1:]
81
+
82
+ image_tags = re.findall(pattern, text)
83
+ image_ids = [int(s.split("|")[1].split("_")[-1]) for s in image_tags]
84
+
85
+ unique_image_ids = sorted(list(set(image_ids)))
86
+ assert unique_image_ids == list(range(1, len(unique_image_ids)+1)), f"image_ids must start from 1, and must be continuous int, e.g. [1, 2, 3], cannot be {unique_image_ids}"
87
+ # total images must be the same as the number of image tags
88
+ assert len(unique_image_ids) == len(input_images), f"total images must be the same as the number of image tags, got {len(unique_image_ids)} image tags and {len(input_images)} images"
89
+
90
+ input_images = [input_images[x-1] for x in image_ids]
91
+
92
+ all_input_ids = []
93
+ img_inx = []
94
+ idx = 0
95
+ for i in range(len(prompt_chunks)):
96
+ all_input_ids.extend(prompt_chunks[i])
97
+ if i != len(prompt_chunks) -1:
98
+ start_inx = len(all_input_ids)
99
+ size = input_images[i].size(-2) * input_images[i].size(-1) // 16 // 16
100
+ img_inx.append([start_inx, start_inx+size])
101
+ all_input_ids.extend([0]*size)
102
+
103
+ return {"input_ids": all_input_ids, "pixel_values": input_images, "image_sizes": img_inx}
104
+
105
+
106
+ def add_prefix_instruction(self, prompt):
107
+ user_prompt = '<|user|>\n'
108
+ generation_prompt = 'Generate an image according to the following instructions\n'
109
+ assistant_prompt = '<|assistant|>\n<|diffusion|>'
110
+ prompt_suffix = "<|end|>\n"
111
+ prompt = f"{user_prompt}{generation_prompt}{prompt}{prompt_suffix}{assistant_prompt}"
112
+ return prompt
113
+
114
+
115
+ def __call__(self,
116
+ instructions: List[str],
117
+ input_images: List[List[str]] = None,
118
+ height: int = 1024,
119
+ width: int = 1024,
120
+ negative_prompt: str = "low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers.",
121
+ use_img_cfg: bool = True,
122
+ separate_cfg_input: bool = False,
123
+ ) -> Dict:
124
+
125
+ if input_images is None:
126
+ use_img_cfg = False
127
+ if isinstance(instructions, str):
128
+ instructions = [instructions]
129
+ input_images = [input_images]
130
+
131
+ input_data = []
132
+ for i in range(len(instructions)):
133
+ cur_instruction = instructions[i]
134
+ cur_input_images = None if input_images is None else input_images[i]
135
+ cur_instruction = self.add_prefix_instruction(cur_instruction)
136
+ if cur_input_images is not None and len(cur_input_images) > 0:
137
+ cur_input_images = [self.process_image(x) for x in cur_input_images]
138
+ else:
139
+ cur_input_images = None
140
+ assert "<img><|image_1|></img>" not in cur_instruction
141
+
142
+ mllm_input = self.process_multi_modal_prompt(cur_instruction, cur_input_images)
143
+
144
+
145
+ neg_mllm_input, img_cfg_mllm_input = None, None
146
+ neg_instruction = self.add_prefix_instruction(negative_prompt)
147
+ neg_mllm_input = self.process_multi_modal_prompt(neg_instruction, None)
148
+ if use_img_cfg:
149
+ if cur_input_images is not None and len(cur_input_images) >= 1:
150
+ img_cfg_prompt = [f"<img><|image_{i+1}|></img>" for i in range(len(cur_input_images))]
151
+ img_cfg_mllm_input = self.process_multi_modal_prompt(self.add_prefix_instruction(" ".join(img_cfg_prompt)), cur_input_images)
152
+ else:
153
+ img_cfg_mllm_input = neg_instruction
154
+
155
+ input_data.append((mllm_input, neg_mllm_input, img_cfg_mllm_input, [height, width]))
156
+
157
+ if separate_cfg_input:
158
+ return self.separate_collator(input_data)
159
+ return self.collator(input_data)
160
+
161
+
162
+
163
+
164
+ class OmniGenCollator:
165
+ def __init__(self, pad_token_id=2, hidden_size=3072):
166
+ self.pad_token_id = pad_token_id
167
+ self.hidden_size = hidden_size
168
+
169
+ def create_position(self, attention_mask, num_tokens_for_output_images):
170
+ position_ids = []
171
+ text_length = attention_mask.size(-1)
172
+ img_length = max(num_tokens_for_output_images)
173
+ for mask in attention_mask:
174
+ temp_l = torch.sum(mask)
175
+ temp_position = [0]*(text_length-temp_l) + [i for i in range(temp_l+img_length+1)] # we add a time embedding into the sequence, so add one more token
176
+ position_ids.append(temp_position)
177
+ return torch.LongTensor(position_ids)
178
+
179
+ def create_mask(self, attention_mask, num_tokens_for_output_images):
180
+ extended_mask = []
181
+ padding_images = []
182
+ text_length = attention_mask.size(-1)
183
+ img_length = max(num_tokens_for_output_images)
184
+ seq_len = text_length + img_length + 1 # we add a time embedding into the sequence, so add one more token
185
+ inx = 0
186
+ for mask in attention_mask:
187
+ temp_l = torch.sum(mask)
188
+ pad_l = text_length - temp_l
189
+
190
+ temp_mask = torch.tril(torch.ones(size=(temp_l+1, temp_l+1)))
191
+
192
+ image_mask = torch.zeros(size=(temp_l+1, img_length))
193
+ temp_mask = torch.cat([temp_mask, image_mask], dim=-1)
194
+
195
+ image_mask = torch.ones(size=(img_length, temp_l+img_length+1))
196
+ temp_mask = torch.cat([temp_mask, image_mask], dim=0)
197
+
198
+ if pad_l > 0:
199
+ pad_mask = torch.zeros(size=(temp_l+1+img_length, pad_l))
200
+ temp_mask = torch.cat([pad_mask, temp_mask], dim=-1)
201
+
202
+ pad_mask = torch.ones(size=(pad_l, seq_len))
203
+ temp_mask = torch.cat([pad_mask, temp_mask], dim=0)
204
+
205
+ true_img_length = num_tokens_for_output_images[inx]
206
+ pad_img_length = img_length - true_img_length
207
+ if pad_img_length > 0:
208
+ temp_mask[:, -pad_img_length:] = 0
209
+ temp_padding_imgs = torch.zeros(size=(1, pad_img_length, self.hidden_size))
210
+ else:
211
+ temp_padding_imgs = None
212
+
213
+ extended_mask.append(temp_mask.unsqueeze(0))
214
+ padding_images.append(temp_padding_imgs)
215
+ inx += 1
216
+ return torch.cat(extended_mask, dim=0), padding_images
217
+
218
+ def adjust_attention_for_input_images(self, attention_mask, image_sizes):
219
+ for b_inx in image_sizes.keys():
220
+ for start_inx, end_inx in image_sizes[b_inx]:
221
+ attention_mask[b_inx][start_inx:end_inx, start_inx:end_inx] = 1
222
+
223
+ return attention_mask
224
+
225
+ def pad_input_ids(self, input_ids, image_sizes):
226
+ max_l = max([len(x) for x in input_ids])
227
+ padded_ids = []
228
+ attention_mask = []
229
+ new_image_sizes = []
230
+
231
+ for i in range(len(input_ids)):
232
+ temp_ids = input_ids[i]
233
+ temp_l = len(temp_ids)
234
+ pad_l = max_l - temp_l
235
+ if pad_l == 0:
236
+ attention_mask.append([1]*max_l)
237
+ padded_ids.append(temp_ids)
238
+ else:
239
+ attention_mask.append([0]*pad_l+[1]*temp_l)
240
+ padded_ids.append([self.pad_token_id]*pad_l+temp_ids)
241
+
242
+ if i in image_sizes:
243
+ new_inx = []
244
+ for old_inx in image_sizes[i]:
245
+ new_inx.append([x+pad_l for x in old_inx])
246
+ image_sizes[i] = new_inx
247
+
248
+ return torch.LongTensor(padded_ids), torch.LongTensor(attention_mask), image_sizes
249
+
250
+
251
+ def process_mllm_input(self, mllm_inputs, target_img_size):
252
+ num_tokens_for_output_images = []
253
+ for img_size in target_img_size:
254
+ num_tokens_for_output_images.append(img_size[0]*img_size[1]//16//16)
255
+
256
+ pixel_values, image_sizes = [], {}
257
+ b_inx = 0
258
+ for x in mllm_inputs:
259
+ if x['pixel_values'] is not None:
260
+ pixel_values.extend(x['pixel_values'])
261
+ for size in x['image_sizes']:
262
+ if b_inx not in image_sizes:
263
+ image_sizes[b_inx] = [size]
264
+ else:
265
+ image_sizes[b_inx].append(size)
266
+ b_inx += 1
267
+ pixel_values = [x.unsqueeze(0) for x in pixel_values]
268
+
269
+
270
+ input_ids = [x['input_ids'] for x in mllm_inputs]
271
+ padded_input_ids, attention_mask, image_sizes = self.pad_input_ids(input_ids, image_sizes)
272
+ position_ids = self.create_position(attention_mask, num_tokens_for_output_images)
273
+ attention_mask, padding_images = self.create_mask(attention_mask, num_tokens_for_output_images)
274
+ attention_mask = self.adjust_attention_for_input_images(attention_mask, image_sizes)
275
+
276
+ return padded_input_ids, position_ids, attention_mask, padding_images, pixel_values, image_sizes
277
+
278
+
279
+ def __call__(self, features):
280
+ mllm_inputs = [f[0] for f in features]
281
+ cfg_mllm_inputs = [f[1] for f in features]
282
+ img_cfg_mllm_input = [f[2] for f in features]
283
+ target_img_size = [f[3] for f in features]
284
+
285
+
286
+ if img_cfg_mllm_input[0] is not None:
287
+ mllm_inputs = mllm_inputs + cfg_mllm_inputs + img_cfg_mllm_input
288
+ target_img_size = target_img_size + target_img_size + target_img_size
289
+ else:
290
+ mllm_inputs = mllm_inputs + cfg_mllm_inputs
291
+ target_img_size = target_img_size + target_img_size
292
+
293
+
294
+ all_padded_input_ids, all_position_ids, all_attention_mask, all_padding_images, all_pixel_values, all_image_sizes = self.process_mllm_input(mllm_inputs, target_img_size)
295
+
296
+ data = {"input_ids": all_padded_input_ids,
297
+ "attention_mask": all_attention_mask,
298
+ "position_ids": all_position_ids,
299
+ "input_pixel_values": all_pixel_values,
300
+ "input_image_sizes": all_image_sizes,
301
+ "padding_images": all_padding_images,
302
+ }
303
+ return data
304
+
305
+
306
+ class OmniGenSeparateCollator(OmniGenCollator):
307
+ def __call__(self, features):
308
+ mllm_inputs = [f[0] for f in features]
309
+ cfg_mllm_inputs = [f[1] for f in features]
310
+ img_cfg_mllm_input = [f[2] for f in features]
311
+ target_img_size = [f[3] for f in features]
312
+
313
+
314
+ all_padded_input_ids, all_attention_mask, all_position_ids, all_pixel_values, all_image_sizes, all_padding_images = [], [], [], [], [], []
315
+
316
+
317
+ padded_input_ids, position_ids, attention_mask, padding_images, pixel_values, image_sizes = self.process_mllm_input(mllm_inputs, target_img_size)
318
+ all_padded_input_ids.append(padded_input_ids)
319
+ all_attention_mask.append(attention_mask)
320
+ all_position_ids.append(position_ids)
321
+ all_pixel_values.append(pixel_values)
322
+ all_image_sizes.append(image_sizes)
323
+ all_padding_images.append(padding_images)
324
+
325
+ if cfg_mllm_inputs[0] is not None:
326
+ padded_input_ids, position_ids, attention_mask, padding_images, pixel_values, image_sizes = self.process_mllm_input(cfg_mllm_inputs, target_img_size)
327
+ all_padded_input_ids.append(padded_input_ids)
328
+ all_attention_mask.append(attention_mask)
329
+ all_position_ids.append(position_ids)
330
+ all_pixel_values.append(pixel_values)
331
+ all_image_sizes.append(image_sizes)
332
+ all_padding_images.append(padding_images)
333
+ if img_cfg_mllm_input[0] is not None:
334
+ padded_input_ids, position_ids, attention_mask, padding_images, pixel_values, image_sizes = self.process_mllm_input(img_cfg_mllm_input, target_img_size)
335
+ all_padded_input_ids.append(padded_input_ids)
336
+ all_attention_mask.append(attention_mask)
337
+ all_position_ids.append(position_ids)
338
+ all_pixel_values.append(pixel_values)
339
+ all_image_sizes.append(image_sizes)
340
+ all_padding_images.append(padding_images)
341
+
342
+ data = {"input_ids": all_padded_input_ids,
343
+ "attention_mask": all_attention_mask,
344
+ "position_ids": all_position_ids,
345
+ "input_pixel_values": all_pixel_values,
346
+ "input_image_sizes": all_image_sizes,
347
+ "padding_images": all_padding_images,
348
+ }
349
+ return data
OmniGen/scheduler.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from tqdm import tqdm
3
+ from transformers.cache_utils import Cache, DynamicCache, OffloadedCache
4
+
5
+ class OmniGenScheduler:
6
+ def __init__(self, num_steps: int=50, time_shifting_factor: int=1):
7
+ self.num_steps = num_steps
8
+ self.time_shift = time_shifting_factor
9
+
10
+ t = torch.linspace(0, 1, num_steps+1)
11
+ t = t / (t + time_shifting_factor - time_shifting_factor * t)
12
+ self.sigma = t
13
+
14
+ def crop_kv_cache(self, past_key_values, num_tokens_for_img):
15
+ crop_past_key_values = ()
16
+ for layer_idx in range(len(past_key_values)):
17
+ key_states, value_states = past_key_values[layer_idx][:2]
18
+ crop_past_key_values += ((key_states[..., :-(num_tokens_for_img+1), :], value_states[..., :-(num_tokens_for_img+1), :], ),)
19
+ return crop_past_key_values
20
+ # return DynamicCache.from_legacy_cache(crop_past_key_values)
21
+
22
+ def crop_position_ids_for_cache(self, position_ids, num_tokens_for_img):
23
+ if isinstance(position_ids, list):
24
+ for i in range(len(position_ids)):
25
+ position_ids[i] = position_ids[i][:, -(num_tokens_for_img+1):]
26
+ else:
27
+ position_ids = position_ids[:, -(num_tokens_for_img+1):]
28
+ return position_ids
29
+
30
+ def crop_attention_mask_for_cache(self, attention_mask, num_tokens_for_img):
31
+ if isinstance(attention_mask, list):
32
+ return [x[..., -(num_tokens_for_img+1):, :] for x in attention_mask]
33
+ return attention_mask[..., -(num_tokens_for_img+1):, :]
34
+
35
+ def __call__(self, z, func, model_kwargs, use_kv_cache: bool=True):
36
+ past_key_values = None
37
+ for i in tqdm(range(self.num_steps)):
38
+ timesteps = torch.zeros(size=(len(z), )).to(z.device) + self.sigma[i]
39
+ pred, temp_past_key_values = func(z, timesteps, past_key_values=past_key_values, **model_kwargs)
40
+ sigma_next = self.sigma[i+1]
41
+ sigma = self.sigma[i]
42
+ z = z + (sigma_next - sigma) * pred
43
+ if i == 0 and use_kv_cache:
44
+ num_tokens_for_img = z.size(-1)*z.size(-2) // 4
45
+ if isinstance(temp_past_key_values, list):
46
+ past_key_values = [self.crop_kv_cache(x, num_tokens_for_img) for x in temp_past_key_values]
47
+ model_kwargs['input_ids'] = [None] * len(temp_past_key_values)
48
+ else:
49
+ past_key_values = self.crop_kv_cache(temp_past_key_values, num_tokens_for_img)
50
+ model_kwargs['input_ids'] = None
51
+
52
+ model_kwargs['position_ids'] = self.crop_position_ids_for_cache(model_kwargs['position_ids'], num_tokens_for_img)
53
+ model_kwargs['attention_mask'] = self.crop_attention_mask_for_cache(model_kwargs['attention_mask'], num_tokens_for_img)
54
+ return z
55
+
OmniGen/train.py ADDED
File without changes
OmniGen/transformer.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import warnings
3
+ from typing import List, Optional, Tuple, Union
4
+
5
+ import torch
6
+ import torch.utils.checkpoint
7
+ from torch import nn
8
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
9
+ from huggingface_hub import snapshot_download
10
+
11
+ from transformers.modeling_outputs import (
12
+ BaseModelOutputWithPast,
13
+ CausalLMOutputWithPast,
14
+ SequenceClassifierOutputWithPast,
15
+ TokenClassifierOutput,
16
+ )
17
+ from transformers.modeling_utils import PreTrainedModel
18
+ from transformers import Phi3Config, Phi3Model
19
+ from transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache, OffloadedCache
20
+ from transformers.utils import logging
21
+
22
+ logger = logging.get_logger(__name__)
23
+
24
+
25
+ class Phi3Transformer(Phi3Model):
26
+ """
27
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Phi3DecoderLayer`]
28
+ We only modified the attention mask
29
+ Args:
30
+ config: Phi3Config
31
+ """
32
+
33
+ def forward(
34
+ self,
35
+ input_ids: torch.LongTensor = None,
36
+ attention_mask: Optional[torch.Tensor] = None,
37
+ position_ids: Optional[torch.LongTensor] = None,
38
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
39
+ inputs_embeds: Optional[torch.FloatTensor] = None,
40
+ use_cache: Optional[bool] = None,
41
+ output_attentions: Optional[bool] = None,
42
+ output_hidden_states: Optional[bool] = None,
43
+ return_dict: Optional[bool] = None,
44
+ cache_position: Optional[torch.LongTensor] = None,
45
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
46
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
47
+ output_hidden_states = (
48
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
49
+ )
50
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
51
+
52
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
53
+
54
+ if (input_ids is None) ^ (inputs_embeds is not None):
55
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
56
+
57
+ if self.gradient_checkpointing and self.training:
58
+ if use_cache:
59
+ logger.warning_once(
60
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
61
+ )
62
+ use_cache = False
63
+
64
+ # kept for BC (non `Cache` `past_key_values` inputs)
65
+ return_legacy_cache = False
66
+ if use_cache and not isinstance(past_key_values, Cache):
67
+ return_legacy_cache = True
68
+ if past_key_values is None:
69
+ past_key_values = DynamicCache()
70
+ else:
71
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
72
+ logger.warning_once(
73
+ "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
74
+ "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
75
+ "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
76
+ )
77
+
78
+ if inputs_embeds is None:
79
+ inputs_embeds = self.embed_tokens(input_ids)
80
+
81
+ if cache_position is None:
82
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
83
+ cache_position = torch.arange(
84
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
85
+ )
86
+ if position_ids is None:
87
+ position_ids = cache_position.unsqueeze(0)
88
+
89
+ if attention_mask is not None and attention_mask.dim() == 3:
90
+ dtype = inputs_embeds.dtype
91
+ min_dtype = torch.finfo(dtype).min
92
+ attention_mask = (1 - attention_mask) * min_dtype
93
+ attention_mask = attention_mask.unsqueeze(1).to(inputs_embeds.dtype)
94
+ else:
95
+ raise
96
+ # causal_mask = self._update_causal_mask(
97
+ # attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
98
+ # )
99
+
100
+ hidden_states = inputs_embeds
101
+
102
+ # decoder layers
103
+ all_hidden_states = () if output_hidden_states else None
104
+ all_self_attns = () if output_attentions else None
105
+ next_decoder_cache = None
106
+
107
+ for decoder_layer in self.layers:
108
+ if output_hidden_states:
109
+ all_hidden_states += (hidden_states,)
110
+
111
+ if self.gradient_checkpointing and self.training:
112
+ layer_outputs = self._gradient_checkpointing_func(
113
+ decoder_layer.__call__,
114
+ hidden_states,
115
+ attention_mask,
116
+ position_ids,
117
+ past_key_values,
118
+ output_attentions,
119
+ use_cache,
120
+ cache_position,
121
+ )
122
+ else:
123
+ layer_outputs = decoder_layer(
124
+ hidden_states,
125
+ attention_mask=attention_mask,
126
+ position_ids=position_ids,
127
+ past_key_value=past_key_values,
128
+ output_attentions=output_attentions,
129
+ use_cache=use_cache,
130
+ cache_position=cache_position,
131
+ )
132
+
133
+ hidden_states = layer_outputs[0]
134
+
135
+ if use_cache:
136
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
137
+
138
+ if output_attentions:
139
+ all_self_attns += (layer_outputs[1],)
140
+
141
+ hidden_states = self.norm(hidden_states)
142
+
143
+ # add hidden states from the last decoder layer
144
+ if output_hidden_states:
145
+ all_hidden_states += (hidden_states,)
146
+
147
+ next_cache = next_decoder_cache if use_cache else None
148
+ if return_legacy_cache:
149
+ next_cache = next_cache.to_legacy_cache()
150
+
151
+ if not return_dict:
152
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
153
+ return BaseModelOutputWithPast(
154
+ last_hidden_state=hidden_states,
155
+ past_key_values=next_cache,
156
+ hidden_states=all_hidden_states,
157
+ attentions=all_self_attns,
158
+ )
159
+
app.py CHANGED
@@ -1,154 +1,68 @@
1
  import gradio as gr
2
- import numpy as np
3
- import random
4
 
5
- # import spaces #[uncomment to use ZeroGPU]
6
- from diffusers import DiffusionPipeline
7
- import torch
8
 
9
- device = "cuda" if torch.cuda.is_available() else "cpu"
10
- model_repo_id = "stabilityai/sdxl-turbo" # Replace to the model you would like to use
11
 
12
- if torch.cuda.is_available():
13
- torch_dtype = torch.float16
14
- else:
15
- torch_dtype = torch.float32
16
 
17
- pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
18
- pipe = pipe.to(device)
 
 
 
 
 
19
 
20
- MAX_SEED = np.iinfo(np.int32).max
21
- MAX_IMAGE_SIZE = 1024
22
-
23
-
24
- # @spaces.GPU #[uncomment to use ZeroGPU]
25
- def infer(
26
- prompt,
27
- negative_prompt,
28
- seed,
29
- randomize_seed,
30
- width,
31
- height,
32
- guidance_scale,
33
- num_inference_steps,
34
- progress=gr.Progress(track_tqdm=True),
35
- ):
36
- if randomize_seed:
37
- seed = random.randint(0, MAX_SEED)
38
-
39
- generator = torch.Generator().manual_seed(seed)
40
-
41
- image = pipe(
42
- prompt=prompt,
43
- negative_prompt=negative_prompt,
44
- guidance_scale=guidance_scale,
45
- num_inference_steps=num_inference_steps,
46
- width=width,
47
  height=height,
48
- generator=generator,
49
- ).images[0]
50
-
51
- return image, seed
52
-
53
-
54
- examples = [
55
- "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
56
- "An astronaut riding a green horse",
57
- "A delicious ceviche cheesecake slice",
58
- ]
59
-
60
- css = """
61
- #col-container {
62
- margin: 0 auto;
63
- max-width: 640px;
64
- }
65
- """
66
-
67
- with gr.Blocks(css=css) as demo:
68
- with gr.Column(elem_id="col-container"):
69
- gr.Markdown(" # Text-to-Image Gradio Template")
70
-
71
- with gr.Row():
72
- prompt = gr.Text(
73
- label="Prompt",
74
- show_label=False,
75
- max_lines=1,
76
- placeholder="Enter your prompt",
77
- container=False,
78
- )
79
-
80
- run_button = gr.Button("Run", scale=0, variant="primary")
81
-
82
- result = gr.Image(label="Result", show_label=False)
83
-
84
- with gr.Accordion("Advanced Settings", open=False):
85
- negative_prompt = gr.Text(
86
- label="Negative prompt",
87
- max_lines=1,
88
- placeholder="Enter a negative prompt",
89
- visible=False,
90
- )
91
-
92
- seed = gr.Slider(
93
- label="Seed",
94
- minimum=0,
95
- maximum=MAX_SEED,
96
- step=1,
97
- value=0,
98
- )
99
-
100
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
101
-
102
- with gr.Row():
103
- width = gr.Slider(
104
- label="Width",
105
- minimum=256,
106
- maximum=MAX_IMAGE_SIZE,
107
- step=32,
108
- value=1024, # Replace with defaults that work for your model
109
- )
110
-
111
- height = gr.Slider(
112
- label="Height",
113
- minimum=256,
114
- maximum=MAX_IMAGE_SIZE,
115
- step=32,
116
- value=1024, # Replace with defaults that work for your model
117
- )
118
-
119
- with gr.Row():
120
- guidance_scale = gr.Slider(
121
- label="Guidance scale",
122
- minimum=0.0,
123
- maximum=10.0,
124
- step=0.1,
125
- value=0.0, # Replace with defaults that work for your model
126
- )
127
-
128
- num_inference_steps = gr.Slider(
129
- label="Number of inference steps",
130
- minimum=1,
131
- maximum=50,
132
- step=1,
133
- value=2, # Replace with defaults that work for your model
134
- )
135
-
136
- gr.Examples(examples=examples, inputs=[prompt])
137
- gr.on(
138
- triggers=[run_button.click, prompt.submit],
139
- fn=infer,
140
- inputs=[
141
- prompt,
142
- negative_prompt,
143
- seed,
144
- randomize_seed,
145
- width,
146
- height,
147
- guidance_scale,
148
- num_inference_steps,
149
- ],
150
- outputs=[result, seed],
151
  )
152
 
153
- if __name__ == "__main__":
154
- demo.launch()
 
1
  import gradio as gr
2
+ from PIL import Image
3
+ import os
4
 
5
+ os.environ['CUDA_VISIBLE_DEVICES'] = '7'
 
 
6
 
7
+ from OmniGen import OmniGenPipeline
 
8
 
9
+ pipe = OmniGenPipeline.from_pretrained("shitao/tmp-preview")
 
 
 
10
 
11
+ # 示例处理函数:生成图像
12
+ def generate_image(text, img1, img2, img3, height, width, guidance_scale):
13
+ input_images = [img1, img2, img3]
14
+ # 去除 None
15
+ input_images = [img for img in input_images if img is not None]
16
+ if len(input_images) == 0:
17
+ input_images = None
18
 
19
+ output = pipe(
20
+ prompt=text,
21
+ input_images=input_images,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  height=height,
23
+ width=width,
24
+ guidance_scale=guidance_scale,
25
+ img_guidance_scale=1.6,
26
+ separate_cfg_infer=True,
27
+ use_kv_cache=False
28
+ )
29
+ img = output[0]
30
+ return img
31
+
32
+ # Gradio 接口
33
+ with gr.Blocks() as demo:
34
+ gr.Markdown("## Text + Multiple Images to Image Generator")
35
+
36
+ with gr.Row():
37
+ with gr.Column():
38
+ # 文本输入框
39
+ prompt_input = gr.Textbox(label="Enter your prompt", placeholder="Type your prompt here...")
40
+
41
+ # 图片上传框
42
+ image_input_1 = gr.Image(label="<img><|image_1|></img>", type="filepath")
43
+ image_input_2 = gr.Image(label="<img><|image_2|></img>", type="filepath")
44
+ image_input_3 = gr.Image(label="<img><|image_3|></img>", type="filepath")
45
+
46
+ # 高度和宽度滑块
47
+ height_input = gr.Slider(label="Height", minimum=256, maximum=2048, value=1024, step=16)
48
+ width_input = gr.Slider(label="Width", minimum=256, maximum=2048, value=1024, step=16)
49
+
50
+ # 引导尺度输入
51
+ guidance_scale_input = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=10.0, value=3.0, step=0.1)
52
+
53
+ # 生成按钮
54
+ generate_button = gr.Button("Generate Image")
55
+
56
+ with gr.Column():
57
+ # 输出图像框
58
+ output_image = gr.Image(label="Output Image")
59
+
60
+ # 按钮点击事件
61
+ generate_button.click(
62
+ generate_image,
63
+ inputs=[prompt_input, image_input_1, image_input_2, image_input_3, height_input, width_input, guidance_scale_input],
64
+ outputs=output_image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  )
66
 
67
+ # 启动应用
68
+ demo.launch()
edit.png ADDED
imgs/.DS_Store ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d65165279105ca6773180500688df4bdc69a2c7b771752f0a46ef120b7fd8ec3
3
+ size 6148
imgs/test_cases/liuyifei.png ADDED
imgs/test_cases/taylor.png ADDED
imgs/test_cases/trump.png ADDED
imgs/test_cases/turing.png ADDED
inference.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
setup.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from setuptools import setup, find_packages
2
+
3
+ with open("README.md", mode="r", encoding="utf-8") as readme_file:
4
+ readme = readme_file.read()
5
+
6
+ setup(
7
+ name='OmniGen',
8
+ version='1.0.0',
9
+ description='OmniGen',
10
+ long_description=readme,
11
+ long_description_content_type="text/markdown",
12
+ author_email='[email protected]',
13
+ url='https://github.com/VectorSpaceLab/OmniGen',
14
+ packages=find_packages(),
15
+ include_package_data=True,
16
+ install_requires=[
17
+ 'torch>=1.6.0',
18
+ 'transformers>=4.41.0',
19
+ 'datasets',
20
+ 'accelerate>=0.20.1',
21
+ 'diffusers>=0.30.3'
22
+ ],
23
+ )