smile123456789 committed

Commit b65930c · 1 Parent(s): c7b92cf

reorganize code

.gitattributes CHANGED
@@ -1,3 +1,5 @@
+ __pycache__/
+
  *.7z filter=lfs diff=lfs merge=lfs -text
  *.arrow filter=lfs diff=lfs merge=lfs -text
  *.bin filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -1,37 +1,26 @@
  import os
  import math
  import time
- import numpy
  import spaces
  import random
  import threading
  import gradio as gr
- from PIL import Image, ImageOps
  from moviepy import VideoFileClip
  from datetime import datetime, timedelta
  from huggingface_hub import hf_hub_download, snapshot_download

- import insightface
- from insightface.app import FaceAnalysis
- from facexlib.parsing import init_parsing_model
- from facexlib.utils.face_restoration_helper import FaceRestoreHelper
-
  import torch
- from diffusers import CogVideoXDPMScheduler
- from diffusers.utils import load_image
  from diffusers.image_processor import VaeImageProcessor
  from diffusers.training_utils import free_memory

  from util.utils import *
  from util.rife_model import load_rife_model, rife_inference_with_latents
- from models.utils import process_face_embeddings
+ from models.utils import process_face_embeddings_infer, prepare_face_models
  from models.transformer_consisid import ConsisIDTransformer3DModel
  from models.pipeline_consisid import ConsisIDPipeline
- from models.eva_clip import create_model_and_transforms
- from models.eva_clip.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
- from models.eva_clip.utils_qformer import resize_numpy_image_long


+ # 0. Pre config
  model_path = "ckpts"

  lora_path = None
@@ -51,72 +40,30 @@ if os.path.exists(os.path.join(model_path, "transformer_ema")):
      subfolder = "transformer_ema"
  else:
      subfolder = "transformer"
-
- transformer = ConsisIDTransformer3DModel.from_pretrained_cus(model_path, subfolder=subfolder)
- scheduler = CogVideoXDPMScheduler.from_pretrained(model_path, subfolder="scheduler")

- try:
-     is_kps = transformer.config.is_kps
- except:
-     is_kps = False
-
- # 1. load face helper models
- face_helper = FaceRestoreHelper(
-     upscale_factor=1,
-     face_size=512,
-     crop_ratio=(1, 1),
-     det_model='retinaface_resnet50',
-     save_ext='png',
-     device=device,
-     model_rootpath=os.path.join(model_path, "face_encoder")
- )
- face_helper.face_parse = None
- face_helper.face_parse = init_parsing_model(model_name='bisenet', device=device, model_rootpath=os.path.join(model_path, "face_encoder"))
- face_helper.face_det.eval()
- face_helper.face_parse.eval()
-
- model, _, _ = create_model_and_transforms('EVA02-CLIP-L-14-336', os.path.join(model_path, "face_encoder", "EVA02_CLIP_L_336_psz14_s6B.pt"), force_custom_clip=True)
- face_clip_model = model.visual
- face_clip_model.eval()
-
- eva_transform_mean = getattr(face_clip_model, 'image_mean', OPENAI_DATASET_MEAN)
- eva_transform_std = getattr(face_clip_model, 'image_std', OPENAI_DATASET_STD)
- if not isinstance(eva_transform_mean, (list, tuple)):
-     eva_transform_mean = (eva_transform_mean,) * 3
- if not isinstance(eva_transform_std, (list, tuple)):
-     eva_transform_std = (eva_transform_std,) * 3
- eva_transform_mean = eva_transform_mean
- eva_transform_std = eva_transform_std
-
- face_main_model = FaceAnalysis(name='antelopev2', root=os.path.join(model_path, "face_encoder"), providers=['CUDAExecutionProvider'])
- handler_ante = insightface.model_zoo.get_model(f'{model_path}/face_encoder/models/antelopev2/glintr100.onnx', providers=['CUDAExecutionProvider'])
- face_main_model.prepare(ctx_id=0, det_size=(640, 640))
- handler_ante.prepare(ctx_id=0)
-
- face_clip_model.to(device, dtype=dtype)
- face_helper.face_det.to(device)
- face_helper.face_parse.to(device)
+
+ # 1. Prepare all the face models
+ face_helper_1, face_helper_2, face_clip_model, face_main_model, eva_transform_mean, eva_transform_std = prepare_face_models(model_path, device, dtype)
+
+
+ # 2. Load Pipeline.
+ transformer = ConsisIDTransformer3DModel.from_pretrained_cus(model_path, subfolder=subfolder)
  transformer.to(device, dtype=dtype)
- free_memory()
+ pipe = ConsisIDPipeline.from_pretrained(model_path, transformer=transformer, torch_dtype=dtype)

- pipe = ConsisIDPipeline.from_pretrained(model_path, transformer=transformer, scheduler=scheduler, torch_dtype=dtype)
  # If you're using with lora, add this code
  if lora_path:
      pipe.load_lora_weights(lora_path, weight_name="pytorch_lora_weights.safetensors", adapter_name="test_1")
      pipe.fuse_lora(lora_scale=1 / lora_rank)

- scheduler_args = {}
- if "variance_type" in pipe.scheduler.config:
-     variance_type = pipe.scheduler.config.variance_type
-     if variance_type in ["learned", "learned_range"]:
-         variance_type = "fixed_small"
-     scheduler_args["variance_type"] = variance_type

- pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, **scheduler_args)
+ # 3. Move to device.
+ face_helper_1.face_det.to(device)
+ face_helper_1.face_parse.to(device)
+ face_clip_model.to(device, dtype=dtype)
+ transformer.to(device, dtype=dtype)
  pipe.to(device)
-
- # Enable CPU offload for the model.
- # turn on if you don't have multiple GPUs or enough GPU memory(such as H100) and it will cost more time in inference, it may also reduce the quality
+ # Save Memory. Turn on if you don't have multiple GPUs or enough GPU memory(such as H100) and it will cost more time in inference, it may also reduce the quality
  pipe.enable_model_cpu_offload()
  pipe.enable_sequential_cpu_offload()
  # pipe.vae.enable_slicing()
@@ -125,6 +72,7 @@ pipe.enable_sequential_cpu_offload()
  os.makedirs("./output", exist_ok=True)
  os.makedirs("./gradio_tmp", exist_ok=True)

+ # load upscale and interpolation model
  upscale_model = load_sd_upscale(f"{model_path}/model_real_esran/RealESRGAN_x4.pth", device)
  frame_interpolation_model = load_rife_model(f"{model_path}/model_rife")

@@ -142,34 +90,21 @@ def generate(
      if seed == -1:
          seed = random.randint(0, 2**8 - 1)

-     id_image = np.array(ImageOps.exif_transpose(Image.fromarray(image_input)).convert("RGB"))
-     id_image = resize_numpy_image_long(id_image, 1024)
-     id_cond, id_vit_hidden, align_crop_face_image, face_kps = process_face_embeddings(face_helper, face_clip_model, handler_ante,
+     # 4. Prepare model input
+     id_cond, id_vit_hidden, image, face_kps = process_face_embeddings_infer(face_helper_1, face_clip_model, face_helper_2,
                                                                               eva_transform_mean, eva_transform_std,
-                                                                              face_main_model, device, dtype, id_image,
-                                                                              original_id_image=id_image, is_align_face=True,
-                                                                              cal_uncond=False)
-
-     if is_kps:
-         kps_cond = face_kps
-     else:
-         kps_cond = None
-
-     tensor = align_crop_face_image.cpu().detach()
-     tensor = tensor.squeeze()
-     tensor = tensor.permute(1, 2, 0)
-     tensor = tensor.numpy() * 255
-     tensor = tensor.astype(np.uint8)
-     image = ImageOps.exif_transpose(Image.fromarray(tensor))
+                                                                              face_main_model, device, dtype,
+                                                                              image_input, is_align_face=True)
+
+     is_kps = getattr(transformer.config, 'is_kps', False)
+     kps_cond = face_kps if is_kps else None

      prompt = prompt.strip('"')
-     if len(negative_prompt) == 0:
-         negative_prompt = None
      if negative_prompt:
          negative_prompt = negative_prompt.strip('"')

-     generator = torch.Generator(device).manual_seed(seed) if seed else None
-
+     # 5. Generate Identity-Preserving Video
+     generator = torch.Generator(device).manual_seed(seed) if seed else None
      video_pt = pipe(
          prompt=prompt,
          negative_prompt=negative_prompt,
@@ -388,8 +323,6 @@ with gr.Blocks() as demo:
              seed_update = gr.update(visible=True, value=seed)

              return video_path, video_update, gif_update, seed_update
-
-     run.zerogpu = True

      generate_button.click(
          fn=run,
@@ -400,4 +333,4 @@ with gr.Blocks() as demo:

  if __name__ == "__main__":
      demo.queue(max_size=15)
-     demo.launch()
+     demo.launch()
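Taken together, the app.py hunks above reduce startup and per-request work to the short, numbered sequence in the new comments (steps 0-5). The condensed sketch below restates that flow outside of Gradio. It is a sketch only: the helper signatures come from this diff (prepare_face_models, process_face_embeddings_infer), the reference image is the one added in this commit, while the prompt, the seed, and the exact keyword list of the pipe(...) call (which is truncated in the hunk above) are illustrative assumptions.

import numpy as np
import torch
from PIL import Image

from models.utils import prepare_face_models, process_face_embeddings_infer
from models.transformer_consisid import ConsisIDTransformer3DModel
from models.pipeline_consisid import ConsisIDPipeline

model_path, device, dtype = "ckpts", "cuda", torch.bfloat16

# 1. Face models (detection/parsing helpers, EVA-CLIP encoder, ArcFace model, normalization stats)
(face_helper_1, face_helper_2, face_clip_model, face_main_model,
 eva_transform_mean, eva_transform_std) = prepare_face_models(model_path, device, dtype)

# 2.-3. Load the pipeline and move it to the target device
transformer = ConsisIDTransformer3DModel.from_pretrained_cus(model_path, subfolder="transformer")
pipe = ConsisIDPipeline.from_pretrained(model_path, transformer=transformer, torch_dtype=dtype)
pipe.to(device)

# 4. Identity conditioning from one reference image (numpy RGB array, range [0, 255])
image_input = np.array(Image.open("asserts/example_images/4.png").convert("RGB"))
id_cond, id_vit_hidden, image, face_kps = process_face_embeddings_infer(
    face_helper_1, face_clip_model, face_helper_2, eva_transform_mean, eva_transform_std,
    face_main_model, device, dtype, image_input, is_align_face=True)
kps_cond = face_kps if getattr(transformer.config, "is_kps", False) else None

# 5. Identity-preserving video generation (keyword names for the image argument are assumed)
generator = torch.Generator(device).manual_seed(42)
output = pipe(prompt="An illustrative prompt describing the person in the reference image.",
              image=image, id_cond=id_cond, id_vit_hidden=id_vit_hidden, kps_cond=kps_cond,
              use_dynamic_cfg=True, generator=generator)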
asserts/example_images/4.png ADDED
models/local_facial_extractor.py CHANGED
@@ -4,7 +4,18 @@ import torch.nn as nn


  # FFN
- def FeedForward(dim, mult=4):
+ def ConsisIDFeedForward(dim, mult=4):
+     """
+     Creates a consistent ID feedforward block consisting of layer normalization,
+     two linear layers, and a GELU activation.
+
+     Args:
+         dim (int): The input dimension of the tensor.
+         mult (int, optional): Multiplier for the inner dimension. Default is 4.
+
+     Returns:
+         nn.Sequential: A sequence of layers comprising LayerNorm, Linear layers, and GELU.
+     """
      inner_dim = int(dim * mult)
      return nn.Sequential(
          nn.LayerNorm(dim),
@@ -15,20 +26,41 @@ def FeedForward(dim, mult=4):


  def reshape_tensor(x, heads):
+     """
+     Reshapes the input tensor for multi-head attention.
+
+     Args:
+         x (torch.Tensor): The input tensor with shape (batch_size, length, width).
+         heads (int): The number of attention heads.
+
+     Returns:
+         torch.Tensor: The reshaped tensor, with shape (batch_size, heads, length, width).
+     """
      bs, length, width = x.shape
-     # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
      x = x.view(bs, length, heads, -1)
-     # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
      x = x.transpose(1, 2)
-     # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
      x = x.reshape(bs, heads, length, -1)
      return x


  class PerceiverAttention(nn.Module):
+     """
+     Implements the Perceiver attention mechanism with multi-head attention.
+
+     This layer takes two inputs: 'x' (image features) and 'latents' (latent features),
+     applying multi-head attention to both and producing an output tensor with the same
+     dimension as the input tensor 'x'.
+
+     Args:
+         dim (int): The input dimension.
+         dim_head (int, optional): The dimension of each attention head. Default is 64.
+         heads (int, optional): The number of attention heads. Default is 8.
+         kv_dim (int, optional): The key-value dimension. If None, `dim` is used for both keys and values.
+     """
+
      def __init__(self, *, dim, dim_head=64, heads=8, kv_dim=None):
          super().__init__()
-         self.scale = dim_head ** -0.5
+         self.scale = dim_head**-0.5
          self.dim_head = dim_head
          self.heads = heads
          inner_dim = dim_head * heads
@@ -42,21 +74,27 @@ class PerceiverAttention(nn.Module):

      def forward(self, x, latents):
          """
+         Forward pass for Perceiver attention.
+
          Args:
-             x (torch.Tensor): image features
-                 shape (b, n1, D)
-             latent (torch.Tensor): latent features
-                 shape (b, n2, D)
+             x (torch.Tensor): Image features tensor with shape (batch_size, num_pixels, D).
+             latents (torch.Tensor): Latent features tensor with shape (batch_size, num_latents, D).
+
+         Returns:
+             torch.Tensor: Output tensor after applying attention and transformation.
          """
+         # Apply normalization
          x = self.norm1(x)
          latents = self.norm2(latents)

-         b, seq_len, _ = latents.shape
+         b, seq_len, _ = latents.shape  # Get batch size and sequence length

+         # Compute query, key, and value matrices
          q = self.to_q(latents)
          kv_input = torch.cat((x, latents), dim=-2)
          k, v = self.to_kv(kv_input).chunk(2, dim=-1)

+         # Reshape the tensors for multi-head attention
          q = reshape_tensor(q, self.heads)
          k = reshape_tensor(k, self.heads)
          v = reshape_tensor(v, self.heads)
@@ -67,6 +105,7 @@ class PerceiverAttention(nn.Module):
          weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
          out = weight @ v

+         # Reshape and return the final output
          out = out.permute(0, 2, 1, 3).reshape(b, seq_len, -1)

          return self.to_out(out)
@@ -74,22 +113,22 @@ class PerceiverAttention(nn.Module):

  class LocalFacialExtractor(nn.Module):
      def __init__(
-         self,
-         dim=1024,
-         depth=10,
-         dim_head=64,
-         heads=16,
-         num_id_token=5,
-         num_queries=32,
-         output_dim=2048,
-         ff_mult=4,
+         self,
+         dim=1024,
+         depth=10,
+         dim_head=64,
+         heads=16,
+         num_id_token=5,
+         num_queries=32,
+         output_dim=2048,
+         ff_mult=4,
      ):
          """
          Initializes the LocalFacialExtractor class.

          Parameters:
          - dim (int): The dimensionality of latent features.
-         - depth (int): Total number of PerceiverAttention and FeedForward layers.
+         - depth (int): Total number of PerceiverAttention and ConsisIDFeedForward layers.
          - dim_head (int): Dimensionality of each attention head.
          - heads (int): Number of attention heads.
          - num_id_token (int): Number of tokens used for identity features.
@@ -105,21 +144,21 @@ class LocalFacialExtractor(nn.Module):
          self.num_queries = num_queries
          assert depth % 5 == 0
          self.depth = depth // 5
-         scale = dim ** -0.5
+         scale = dim**-0.5

          # Learnable latent query embeddings
          self.latents = nn.Parameter(torch.randn(1, num_queries, dim) * scale)
          # Projection layer to map the latent output to the desired dimension
          self.proj_out = nn.Parameter(scale * torch.randn(dim, output_dim))

-         # Attention and FeedForward layer stack
+         # Attention and ConsisIDFeedForward layer stack
          self.layers = nn.ModuleList([])
          for _ in range(depth):
              self.layers.append(
                  nn.ModuleList(
                      [
                          PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),  # Perceiver Attention layer
-                         FeedForward(dim=dim, mult=ff_mult),  # FeedForward layer
+                         ConsisIDFeedForward(dim=dim, mult=ff_mult),  # ConsisIDFeedForward layer
                      ]
                  )
              )
@@ -128,7 +167,7 @@ class LocalFacialExtractor(nn.Module):
          for i in range(5):
              setattr(
                  self,
-                 f'mapping_{i}',
+                 f"mapping_{i}",
                  nn.Sequential(
                      nn.Linear(1024, 1024),
                      nn.LayerNorm(1024),
@@ -175,30 +214,30 @@ class LocalFacialExtractor(nn.Module):

          # Process each of the 5 visual feature inputs
          for i in range(5):
-             vit_feature = getattr(self, f'mapping_{i}')(y[i])
+             vit_feature = getattr(self, f"mapping_{i}")(y[i])
              ctx_feature = torch.cat((x, vit_feature), dim=1)

-             # Pass through the PerceiverAttention and FeedForward layers
-             for attn, ff in self.layers[i * self.depth: (i + 1) * self.depth]:
+             # Pass through the PerceiverAttention and ConsisIDFeedForward layers
+             for attn, ff in self.layers[i * self.depth : (i + 1) * self.depth]:
                  latents = attn(ctx_feature, latents) + latents
                  latents = ff(latents) + latents

          # Retain only the query latents
-         latents = latents[:, :self.num_queries]
+         latents = latents[:, : self.num_queries]
          # Project the latents to the output dimension
          latents = latents @ self.proj_out
          return latents
-
+

  class PerceiverCrossAttention(nn.Module):
      """
-
+
      Args:
          dim (int): Dimension of the input latent and output. Default is 3072.
          dim_head (int): Dimension of each attention head. Default is 128.
          heads (int): Number of attention heads. Default is 16.
          kv_dim (int): Dimension of the key/value input, allowing flexible cross-attention. Default is 2048.
-
+
      Attributes:
          scale (float): Scaling factor used in dot-product attention for numerical stability.
          norm1 (nn.LayerNorm): Layer normalization applied to the input image features.
@@ -208,9 +247,10 @@ class PerceiverCrossAttention(nn.Module):
          to_out (nn.Linear): Linear layer for outputting the final result after attention.

      """
+
      def __init__(self, *, dim=3072, dim_head=128, heads=16, kv_dim=2048):
          super().__init__()
-         self.scale = dim_head ** -0.5
+         self.scale = dim_head**-0.5
          self.dim_head = dim_head
          self.heads = heads
          inner_dim = dim_head * heads
@@ -232,13 +272,13 @@ class PerceiverCrossAttention(nn.Module):
              - batch_size (b): Number of samples in the batch.
              - n1: Sequence length (e.g., number of patches or tokens).
              - D: Feature dimension.
-
+
          latents (torch.Tensor): Latent feature representations with shape (batch_size, n2, D), where:
              - n2: Number of latent elements.
-
+
          Returns:
              torch.Tensor: Attention-modulated features with shape (batch_size, n2, D).
-
+
          """
          # Apply layer normalization to the input image and latent features
          x = self.norm1(x)
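As a quick orientation to the renamed modules above, the following smoke-test sketch runs LocalFacialExtractor on dummy inputs. The shapes are the ones quoted in the transformer_consisid.py comment (id_cond of shape [1, 1280] and five ViT feature maps of shape [1, 577, 1024], producing [1, 32, 2048]); it assumes the default constructor arguments are compatible with those shapes.

import torch
from models.local_facial_extractor import LocalFacialExtractor

# Default config per the diff: dim=1024, depth=10, num_id_token=5, num_queries=32, output_dim=2048
extractor = LocalFacialExtractor().eval()

id_cond = torch.randn(1, 1280)                                 # fused identity vector ("fuse clip and insightface")
id_vit_hidden = [torch.randn(1, 577, 1024) for _ in range(5)]  # five intermediate ViT feature maps

with torch.no_grad():
    face_tokens = extractor(id_cond, id_vit_hidden)
print(face_tokens.shape)  # expected: torch.Size([1, 32, 2048])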
models/pipeline_consisid.py CHANGED
@@ -1,8 +1,16 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
-
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
+ # Copyright 2024 ConsisID Authors and The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.

  import inspect
  import math
@@ -13,20 +21,19 @@ import sys
  import PIL
  import numpy as np
  import cv2
- from PIL import Image
  import torch
+ from dataclasses import dataclass
  from transformers import T5EncoderModel, T5Tokenizer

  from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
  from diffusers.image_processor import PipelineImageInput
- from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
+ from diffusers.models import AutoencoderKLCogVideoX
  from diffusers.models.embeddings import get_3d_rotary_pos_embed
  from diffusers.pipelines.pipeline_utils import DiffusionPipeline
  from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
- from diffusers.utils import logging, replace_example_docstring
+ from diffusers.utils import logging, replace_example_docstring, BaseOutput
  from diffusers.utils.torch_utils import randn_tensor
  from diffusers.video_processor import VideoProcessor
- from diffusers.pipelines.cogvideo.pipeline_output import CogVideoXPipelineOutput

  from models.transformer_consisid import ConsisIDTransformer3DModel

@@ -37,26 +44,28 @@ for project_root in project_roots:

  logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

+
  EXAMPLE_DOC_STRING = """
      Examples:
          ```py
          >>> import torch
-         >>> from diffusers import CogVideoXImageToVideoPipeline
+         >>> from diffusers import ConsisIDPipeline
          >>> from diffusers.utils import export_to_video, load_image

-         >>> pipe = CogVideoXImageToVideoPipeline.from_pretrained("THUDM/CogVideoX-5b-I2V", torch_dtype=torch.bfloat16)
+         >>> pipe = ConsisIDPipeline.from_pretrained("https://huggingface.co/BestWishYsh/ConsisID-preview", torch_dtype=torch.bfloat16)
          >>> pipe.to("cuda")

-         >>> prompt = "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
+         >>> prompt = "A woman adorned with a delicate flower crown, is standing amidst a field of gently swaying wildflowers. Her eyes sparkle with a serene gaze, and a faint smile graces her lips, suggesting a moment of peaceful contentment. The shot is framed from the waist up, highlighting the gentle breeze lightly tousling her hair. The background reveals an expansive meadow under a bright blue sky, capturing the tranquility of a sunny afternoon."
          >>> image = load_image(
-         ...     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
+         ...     "https://github.com/PKU-YuanGroup/ConsisID/blob/main/asserts/example_images/1.png?raw=true"
          ... )
          >>> video = pipe(image, prompt, use_dynamic_cfg=True)
          >>> export_to_video(video.frames[0], "output.mp4", fps=8)
          ```
  """

- def draw_kps(image_pil, kps, color_list=[(255,0,0), (0,255,0), (0,0,255), (255,255,0), (255,0,255)]):
+
+ def draw_kps(image_pil, kps, color_list=[(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255)]):
      stickwidth = 4
      limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]])
      kps = np.array(kps)
@@ -72,7 +81,9 @@ def draw_kps(image_pil, kps, color_list=[(255,0,0), (0,255,0), (0,0,255), (255,2
          y = kps[index][:, 1]
          length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5
          angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1]))
-         polygon = cv2.ellipse2Poly((int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
+         polygon = cv2.ellipse2Poly(
+             (int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0, 360, 1
+         )
          out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color)
          out_img = (out_img * 0.6).astype(np.uint8)

@@ -81,9 +92,10 @@ def draw_kps(image_pil, kps, color_list=[(255,0,0), (0,255,0), (0,0,255), (255,2
          x, y = kp
          out_img = cv2.circle(out_img.copy(), (int(x), int(y)), 10, color, -1)

-     out_img_pil = Image.fromarray(out_img.astype(np.uint8))
+     out_img_pil = PIL.Image.fromarray(out_img.astype(np.uint8))
      return out_img_pil

+
  def process_image(image, vae):
      image_noise_sigma = torch.normal(mean=-3.0, std=0.5, size=(1,), device=image.device)
      image_noise_sigma = torch.exp(image_noise_sigma).to(dtype=image.dtype)
@@ -92,6 +104,7 @@ def process_image(image, vae):
      image_latent_dist = vae.encode(input_image).latent_dist
      return image_latent_dist

+
  # Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
  def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
      tw = tgt_width
@@ -185,9 +198,24 @@ def retrieve_latents(
          raise AttributeError("Could not access latents of provided encoder_output")


+ @dataclass
+ class ConsisIDPipelineOutput(BaseOutput):
+     r"""
+     Output class for ConsisID pipelines.
+
+     Args:
+         frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
+             List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
+             denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
+             `(batch_size, num_frames, channels, height, width)`.
+     """
+
+     frames: torch.Tensor
+
+
  class ConsisIDPipeline(DiffusionPipeline):
      r"""
-     Pipeline for image-to-video generation using CogVideoX.
+     Pipeline for image-to-video generation using ConsisID.

      This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
      library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
@@ -196,7 +224,7 @@ class ConsisIDPipeline(DiffusionPipeline):
          vae ([`AutoencoderKL`]):
              Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
          text_encoder ([`T5EncoderModel`]):
-             Frozen text-encoder. CogVideoX uses
+             Frozen text-encoder. ConsisID uses
              [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
              [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
          tokenizer (`T5Tokenizer`):
@@ -222,7 +250,7 @@ class ConsisIDPipeline(DiffusionPipeline):
          tokenizer: T5Tokenizer,
          text_encoder: T5EncoderModel,
          vae: AutoencoderKLCogVideoX,
-         transformer: Union[ConsisIDTransformer3DModel, CogVideoXTransformer3DModel],
+         transformer: Union[ConsisIDTransformer3DModel],
          scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
      ):
          super().__init__()
@@ -246,7 +274,7 @@ class ConsisIDPipeline(DiffusionPipeline):

          self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)

-     # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._get_t5_prompt_embeds
+     # Copied from diffusers.pipelines.consisid.pipeline_consisID.ConsisIDPipeline._get_t5_prompt_embeds
      def _get_t5_prompt_embeds(
          self,
          prompt: Union[str, List[str]] = None,
@@ -289,7 +317,7 @@ class ConsisIDPipeline(DiffusionPipeline):

          return prompt_embeds

-     # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.encode_prompt
+     # Copied from diffusers.pipelines.consisid.pipeline_consisid.ConsisIDPipeline.encode_prompt
      def encode_prompt(
          self,
          prompt: Union[str, List[str]],
@@ -409,7 +437,8 @@ class ConsisIDPipeline(DiffusionPipeline):
          if kps_cond is not None:
              kps_cond = kps_cond.unsqueeze(2)
              kps_cond_latents = [
-                 retrieve_latents(self.vae.encode(kps_cond[i].unsqueeze(0)), generator[i]) for i in range(batch_size)
+                 retrieve_latents(self.vae.encode(kps_cond[i].unsqueeze(0)), generator[i])
+                 for i in range(batch_size)
              ]
          else:
              image_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator) for img in image]
@@ -455,7 +484,7 @@ class ConsisIDPipeline(DiffusionPipeline):
          latents = latents * self.scheduler.init_noise_sigma
          return latents, image_latents

-     # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.decode_latents
+     # Copied from diffusers.pipelines.consisid.pipeline_consisid.ConsisIDPipeline.decode_latents
      def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
          latents = latents.permute(0, 2, 1, 3, 4)  # [batch_size, num_channels, num_frames, height, width]
          latents = 1 / self.vae_scaling_factor_image * latents
@@ -554,13 +583,13 @@ class ConsisIDPipeline(DiffusionPipeline):
                  f" {negative_prompt_embeds.shape}."
              )

-     # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.fuse_qkv_projections
+     # Copied from diffusers.pipelines.consisid.pipeline_consisid.ConsisIDPipeline.fuse_qkv_projections
      def fuse_qkv_projections(self) -> None:
          r"""Enables fused QKV projections."""
          self.fusing_transformer = True
          self.transformer.fuse_qkv_projections()

-     # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.unfuse_qkv_projections
+     # Copied from diffusers.pipelines.consisid.pipeline_consisid.ConsisIDPipeline.unfuse_qkv_projections
      def unfuse_qkv_projections(self) -> None:
          r"""Disable QKV projection fusion if enabled."""
          if not self.fusing_transformer:
@@ -569,7 +598,7 @@ class ConsisIDPipeline(DiffusionPipeline):
          self.transformer.unfuse_qkv_projections()
          self.fusing_transformer = False

-     # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._prepare_rotary_positional_embeddings
+     # Copied from diffusers.pipelines.consisid.pipeline_consisid.ConsisIDPipeline._prepare_rotary_positional_embeddings
      def _prepare_rotary_positional_embeddings(
          self,
          height: int,
@@ -638,7 +667,7 @@ class ConsisIDPipeline(DiffusionPipeline):
          id_vit_hidden: Optional[torch.Tensor] = None,
          id_cond: Optional[torch.Tensor] = None,
          kps_cond: Optional[torch.Tensor] = None,
-     ) -> Union[CogVideoXPipelineOutput, Tuple]:
+     ) -> Union[ConsisIDPipelineOutput, Tuple]:
          """
          Function invoked when calling the pipeline for generation.

@@ -658,7 +687,7 @@ class ConsisIDPipeline(DiffusionPipeline):
                  The width in pixels of the generated image. This is set to 720 by default for the best results.
              num_frames (`int`, defaults to `48`):
                  Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
-                 contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where
+                 contain 1 extra frame because ConsisID is conditioned with (num_seconds * fps + 1) frames where
                  num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that
                  needs to be satisfied is that of divisibility mentioned above.
              num_inference_steps (`int`, *optional*, defaults to 50):
@@ -712,8 +741,8 @@ class ConsisIDPipeline(DiffusionPipeline):
          Examples:

          Returns:
-             [`~pipelines.cogvideo.pipeline_output.CogVideoXPipelineOutput`] or `tuple`:
-             [`~pipelines.cogvideo.pipeline_output.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a
+             [`~pipelines.consisid.pipeline_output.ConsisIDPipelineOutput`] or `tuple`:
+             [`~pipelines.consisid.pipeline_output.ConsisIDPipelineOutput`] if `return_dict` is True, otherwise a
              `tuple`. When returning a tuple, the first element is a list with the generated images.
          """
          if num_frames > 49:
@@ -784,7 +813,7 @@ class ConsisIDPipeline(DiffusionPipeline):
          image = self.video_processor.preprocess(image, height=height, width=width).to(
              device, dtype=prompt_embeds.dtype
          )
-
+
          latent_channels = self.transformer.config.in_channels // 2
          latents, image_latents = self.prepare_latents(
              image,
@@ -797,9 +826,9 @@ class ConsisIDPipeline(DiffusionPipeline):
              device,
              generator,
              latents,
-             kps_cond
+             kps_cond,
          )
-
+
          # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
          extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

@@ -836,8 +865,8 @@ class ConsisIDPipeline(DiffusionPipeline):
                      timestep=timestep,
                      image_rotary_emb=image_rotary_emb,
                      return_dict=False,
-                     id_vit_hidden = id_vit_hidden,
-                     id_cond = id_cond,
+                     id_vit_hidden=id_vit_hidden,
+                     id_cond=id_cond,
                  )[0]
                  noise_pred = noise_pred.float()

@@ -891,4 +920,4 @@ class ConsisIDPipeline(DiffusionPipeline):
          if not return_dict:
              return (video,)

-         return CogVideoXPipelineOutput(frames=video)
+         return ConsisIDPipelineOutput(frames=video)
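The main functional change in this file is the ConsisIDPipelineOutput dataclass defined above, which replaces the previously imported CogVideoXPipelineOutput. A minimal sketch of how that container behaves (dummy tensor only; the layout follows its docstring):

import torch
from models.pipeline_consisid import ConsisIDPipelineOutput

# Dummy frames with the documented layout (batch_size, num_frames, channels, height, width)
frames = torch.zeros(1, 49, 3, 480, 720)
out = ConsisIDPipelineOutput(frames=frames)

print(out.frames.shape)          # attribute access, as used in `video = output.frames[0]`
print(out.to_tuple()[0].shape)   # BaseOutput also exposes a tuple view, matching the `return (video,)` path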
models/transformer_consisid.py CHANGED
@@ -1,8 +1,16 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
-
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
 
 
 
 
 
 
 
 
6
 
7
  from typing import Any, Dict, Optional, Tuple, Union
8
  import os
@@ -38,9 +46,9 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
38
 
39
 
40
  @maybe_allow_in_graph
41
- class CogVideoXBlock(nn.Module):
42
  r"""
43
- Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model.
44
 
45
  Parameters:
46
  dim (`int`):
@@ -132,9 +140,6 @@ class CogVideoXBlock(nn.Module):
132
  hidden_states, encoder_hidden_states, temb
133
  )
134
 
135
- # insert here
136
- # pass
137
-
138
  # attention
139
  attn_hidden_states, attn_encoder_hidden_states = self.attn1(
140
  hidden_states=norm_hidden_states,
@@ -162,7 +167,7 @@ class CogVideoXBlock(nn.Module):
162
 
163
  class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
164
  """
165
- A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
166
 
167
  Parameters:
168
  num_attention_heads (`int`, defaults to `30`):
@@ -191,7 +196,7 @@ class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
191
  The height of the input latents.
192
  sample_frames (`int`, defaults to `49`):
193
  The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49
194
- instead of 13 because CogVideoX processed 13 latent frames at once in its default and recommended settings,
195
  but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with
196
  K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1).
197
  patch_size (`int`, defaults to `2`):
@@ -212,6 +217,32 @@ class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
212
  Scaling factor to apply in 3D positional embeddings across spatial dimensions.
213
  temporal_interpolation_scale (`float`, defaults to `1.0`):
214
  Scaling factor to apply in 3D positional embeddings across temporal dimensions.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
215
  """
216
 
217
  _supports_gradient_checkpointing = True
@@ -257,7 +288,7 @@ class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
257
 
258
  if not use_rotary_positional_embeddings and use_learned_positional_embeddings:
259
  raise ValueError(
260
- "There are no CogVideoX checkpoints available with disable rotary embeddings and learned positional "
261
  "embeddings. If you're using a custom model and/or believe this should be supported, please open an "
262
  "issue at https://github.com/huggingface/diffusers/issues."
263
  )
@@ -288,7 +319,7 @@ class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
288
  # 3. Define spatio-temporal transformers blocks
289
  self.transformer_blocks = nn.ModuleList(
290
  [
291
- CogVideoXBlock(
292
  dim=inner_dim,
293
  num_attention_heads=num_attention_heads,
294
  attention_head_dim=attention_head_dim,
@@ -319,6 +350,7 @@ class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
319
  self.is_train_face = is_train_face
320
  self.is_kps = is_kps
321
 
 
322
  if is_train_face:
323
  self.inner_dim = inner_dim
324
  self.cross_attn_interval = cross_attn_interval
@@ -338,21 +370,26 @@ class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
338
  weight_dtype = next(self.transformer_blocks.parameters()).dtype
339
  self.local_facial_extractor = LocalFacialExtractor()
340
  self.local_facial_extractor.to(device, dtype=weight_dtype)
341
- self.perceiver_cross_attention = nn.ModuleList([
342
- PerceiverCrossAttention(dim=self.inner_dim, dim_head=128, heads=16, kv_dim=self.LFE_final_output_dim).to(device, dtype=weight_dtype) for _ in range(self.num_ca)
343
- ])
 
 
 
 
 
344
 
345
  def save_face_modules(self, path: str):
346
  save_dict = {
347
- 'local_facial_extractor': self.local_facial_extractor.state_dict(),
348
- 'perceiver_cross_attention': [ca.state_dict() for ca in self.perceiver_cross_attention],
349
  }
350
  torch.save(save_dict, path)
351
 
352
  def load_face_modules(self, path: str):
353
  checkpoint = torch.load(path, map_location=self.device)
354
- self.local_facial_extractor.load_state_dict(checkpoint['local_facial_extractor'])
355
- for ca, state_dict in zip(self.perceiver_cross_attention, checkpoint['perceiver_cross_attention']):
356
  ca.load_state_dict(state_dict)
357
 
358
  @property
@@ -463,14 +500,16 @@ class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
463
  timestep_cond: Optional[torch.Tensor] = None,
464
  image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
465
  attention_kwargs: Optional[Dict[str, Any]] = None,
466
- id_cond: Optional[torch.Tensor] = None,
467
  id_vit_hidden: Optional[torch.Tensor] = None,
468
  return_dict: bool = True,
469
  ):
470
  # fuse clip and insightface
471
  if self.is_train_face:
472
  assert id_cond is not None and id_vit_hidden is not None
473
- valid_face_emb = self.local_facial_extractor(id_cond, id_vit_hidden) # torch.Size([1, 1280]), list[5](torch.Size([1, 577, 1024])) -> torch.Size([1, 32, 2048])
 
 
474
 
475
  if attention_kwargs is not None:
476
  attention_kwargs = attention_kwargs.copy()
@@ -506,7 +545,7 @@ class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
506
 
507
  text_seq_length = encoder_hidden_states.shape[1]
508
  encoder_hidden_states = hidden_states[:, :text_seq_length] # torch.Size([1, 226, 3072])
509
- hidden_states = hidden_states[:, text_seq_length:] # torch.Size([1, 17550, 3072])
510
 
511
  # 3. Transformer blocks
512
  ca_idx = 0
@@ -538,17 +577,14 @@ class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
538
 
539
  if self.is_train_face:
540
  if i % self.cross_attn_interval == 0 and valid_face_emb is not None:
541
- hidden_states = hidden_states + self.local_face_scale * self.perceiver_cross_attention[ca_idx](valid_face_emb, hidden_states) # torch.Size([2, 32, 2048]) torch.Size([2, 17550, 3072])
 
 
542
  ca_idx += 1
543
 
544
- if not self.config.use_rotary_positional_embeddings:
545
- # CogVideoX-2B
546
- hidden_states = self.norm_final(hidden_states)
547
- else:
548
- # CogVideoX-5B
549
- hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
550
- hidden_states = self.norm_final(hidden_states)
551
- hidden_states = hidden_states[:, text_seq_length:]
552
 
553
  # 4. Final block
554
  hidden_states = self.norm_out(hidden_states, temb=emb)
@@ -556,8 +592,7 @@ class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
556
 
557
  # 5. Unpatchify
558
  # Note: we use `-1` instead of `channels`:
559
- # - It is okay to `channels` use for CogVideoX-2b and CogVideoX-5b (number of input channels is equal to output channels)
560
- # - However, for CogVideoX-5b-I2V also takes concatenated input image latents (number of input channels is twice the output channels)
561
  p = self.config.patch_size
562
  output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
563
  output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
 
1
+ # Copyright 2024 ConsisID Authors and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
 
15
  from typing import Any, Dict, Optional, Tuple, Union
16
  import os
 
46
 
47
 
48
  @maybe_allow_in_graph
49
+ class ConsisIDBlock(nn.Module):
50
  r"""
51
+ Transformer block used in [ConsisID](https://github.com/PKU-YuanGroup/ConsisID) model.
52
 
53
  Parameters:
54
  dim (`int`):
 
140
  hidden_states, encoder_hidden_states, temb
141
  )
142
 
 
 
 
143
  # attention
144
  attn_hidden_states, attn_encoder_hidden_states = self.attn1(
145
  hidden_states=norm_hidden_states,
 
167
 
168
  class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
169
  """
170
+ A Transformer model for video-like data in [ConsisID](https://github.com/PKU-YuanGroup/ConsisID).
171
 
172
  Parameters:
173
  num_attention_heads (`int`, defaults to `30`):
 
196
  The height of the input latents.
197
  sample_frames (`int`, defaults to `49`):
198
  The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49
199
+ instead of 13 because ConsisID processed 13 latent frames at once in its default and recommended settings,
200
  but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with
201
  K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1).
202
  patch_size (`int`, defaults to `2`):
 
217
  Scaling factor to apply in 3D positional embeddings across spatial dimensions.
218
  temporal_interpolation_scale (`float`, defaults to `1.0`):
219
  Scaling factor to apply in 3D positional embeddings across temporal dimensions.
220
+ is_train_face (`bool`, defaults to `False`):
221
+ Whether to use enable the identity-preserving module during the training process.
222
+ When set to `True`, the model will focus on identity-preserving tasks.
223
+ is_kps (`bool`, defaults to `False`):
224
+ Whether to enable keypoint for global facial extractor.
225
+ If `True`, keypoints will be in the model.
226
+ cross_attn_interval (`int`, defaults to `1`):
227
+ The interval between cross-attention layers in the Transformer architecture.
228
+ A larger value may reduce the frequency of cross-attention computations,
229
+ which can help reduce computational overhead.
230
+ LFE_num_tokens (`int`, defaults to `32`):
231
+ The number of tokens to use in the Local Facial Extractor (LFE).
232
+ This module is responsible for capturing high frequency representations
233
+ of the face.
234
+ LFE_output_dim (`int`, defaults to `768`):
235
+ The output dimension of the Local Facial Extractor (LFE) module.
236
+ This dimension determines the size of the feature vectors produced
237
+ by the LFE module.
238
+ LFE_heads (`int`, defaults to `12`):
239
+ The number of attention heads used in the Local Facial Extractor (LFE) module.
240
+ More heads may improve the ability to capture diverse features, but
241
+ can also increase computational complexity.
242
+ local_face_scale (`float`, defaults to `1.0`):
243
+ A scaling factor used to adjust the importance of local facial features
244
+ in the model. This can influence how strongly the model focuses on
245
+ high frequency face-related content.
246
  """
247
 
248
  _supports_gradient_checkpointing = True
 
288
 
289
  if not use_rotary_positional_embeddings and use_learned_positional_embeddings:
290
  raise ValueError(
291
+ "There are no ConsisID checkpoints available with disable rotary embeddings and learned positional "
292
  "embeddings. If you're using a custom model and/or believe this should be supported, please open an "
293
  "issue at https://github.com/huggingface/diffusers/issues."
294
  )
 
319
  # 3. Define spatio-temporal transformers blocks
320
  self.transformer_blocks = nn.ModuleList(
321
  [
322
+ ConsisIDBlock(
323
  dim=inner_dim,
324
  num_attention_heads=num_attention_heads,
325
  attention_head_dim=attention_head_dim,
 
350
  self.is_train_face = is_train_face
351
  self.is_kps = is_kps
352
 
353
+ # 5. Define identity-preserving config
354
  if is_train_face:
355
  self.inner_dim = inner_dim
356
  self.cross_attn_interval = cross_attn_interval
 
370
  weight_dtype = next(self.transformer_blocks.parameters()).dtype
371
  self.local_facial_extractor = LocalFacialExtractor()
372
  self.local_facial_extractor.to(device, dtype=weight_dtype)
373
+ self.perceiver_cross_attention = nn.ModuleList(
374
+ [
375
+ PerceiverCrossAttention(
376
+ dim=self.inner_dim, dim_head=128, heads=16, kv_dim=self.LFE_final_output_dim
377
+ ).to(device, dtype=weight_dtype)
378
+ for _ in range(self.num_ca)
379
+ ]
380
+ )
381
 
382
  def save_face_modules(self, path: str):
383
  save_dict = {
384
+ "local_facial_extractor": self.local_facial_extractor.state_dict(),
385
+ "perceiver_cross_attention": [ca.state_dict() for ca in self.perceiver_cross_attention],
386
  }
387
  torch.save(save_dict, path)
388
 
389
  def load_face_modules(self, path: str):
390
  checkpoint = torch.load(path, map_location=self.device)
391
+ self.local_facial_extractor.load_state_dict(checkpoint["local_facial_extractor"])
392
+ for ca, state_dict in zip(self.perceiver_cross_attention, checkpoint["perceiver_cross_attention"]):
393
  ca.load_state_dict(state_dict)
394
 
395
  @property
 
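The save_face_modules / load_face_modules helpers above only (de)serialize the identity-preserving branch (the local facial extractor plus the perceiver cross-attention stack), so these weights can be shipped separately from the base transformer. A minimal round-trip sketch, assuming `transformer` is an already-constructed ConsisIDTransformer3DModel built with is_train_face=True and "face_modules.pt" is a hypothetical output path:

face_ckpt = "face_modules.pt"  # hypothetical path

# persist only the identity-preserving parameters
transformer.save_face_modules(face_ckpt)

# later: restore them into a model built with the same config
transformer.load_face_modules(face_ckpt)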
500
  timestep_cond: Optional[torch.Tensor] = None,
501
  image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
502
  attention_kwargs: Optional[Dict[str, Any]] = None,
503
+ id_cond: Optional[torch.Tensor] = None,
504
  id_vit_hidden: Optional[torch.Tensor] = None,
505
  return_dict: bool = True,
506
  ):
507
  # fuse clip and insightface
508
  if self.is_train_face:
509
  assert id_cond is not None and id_vit_hidden is not None
510
+ valid_face_emb = self.local_facial_extractor(
511
+ id_cond, id_vit_hidden
512
+ ) # torch.Size([1, 1280]), list[5](torch.Size([1, 577, 1024])) -> torch.Size([1, 32, 2048])
513
 
514
  if attention_kwargs is not None:
515
  attention_kwargs = attention_kwargs.copy()
 
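The inline comment in this hunk documents the tensor contract of the fused identity condition: id_cond is the 1280-d concatenation of the 512-d antelopev2 embedding and the 768-d CLIP feature, id_vit_hidden is a list of five EVA-CLIP hidden states of shape (1, 577, 1024), and the local facial extractor fuses them into 32 identity tokens of width 2048. A dummy-tensor sketch of those shapes (batch size 1 and the call through `transformer` are assumptions for illustration only):

import torch

batch = 1
id_cond = torch.randn(batch, 1280)                                  # 512-d antelopev2 + 768-d CLIP feature
id_vit_hidden = [torch.randn(batch, 577, 1024) for _ in range(5)]   # intermediate EVA-CLIP hidden states

# inside forward():
# valid_face_emb = transformer.local_facial_extractor(id_cond, id_vit_hidden)
# expected result: torch.Size([batch, 32, 2048])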
545
 
546
  text_seq_length = encoder_hidden_states.shape[1]
547
  encoder_hidden_states = hidden_states[:, :text_seq_length] # torch.Size([1, 226, 3072])
548
+ hidden_states = hidden_states[:, text_seq_length:] # torch.Size([1, 17550, 3072])
549
 
550
  # 3. Transformer blocks
551
  ca_idx = 0
 
577
 
578
  if self.is_train_face:
579
  if i % self.cross_attn_interval == 0 and valid_face_emb is not None:
580
+ hidden_states = hidden_states + self.local_face_scale * self.perceiver_cross_attention[ca_idx](
581
+ valid_face_emb, hidden_states
582
+ ) # torch.Size([2, 32, 2048]) torch.Size([2, 17550, 3072])
583
  ca_idx += 1
584
 
585
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
586
+ hidden_states = self.norm_final(hidden_states)
587
+ hidden_states = hidden_states[:, text_seq_length:]
588
 
589
  # 4. Final block
590
  hidden_states = self.norm_out(hidden_states, temb=emb)
 
592
 
593
  # 5. Unpatchify
594
  # Note: we use `-1` instead of `channels`:
595
+ # - It is okay to use `channels` for ConsisID (the number of input channels equals the number of output channels)
 
596
  p = self.config.patch_size
597
  output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
598
  output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
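Read together, these hunks inject identity information into every cross_attn_interval-th transformer block: whenever `i % cross_attn_interval == 0`, the block output receives `local_face_scale * perceiver_cross_attention[ca_idx](valid_face_emb, hidden_states)`. A pure-Python sketch of which block indices get the injection and how many PerceiverCrossAttention modules that implies (the layer count and interval below are illustrative assumptions, not values from this diff):

num_layers = 30            # illustrative; the real value comes from the model config
cross_attn_interval = 2    # illustrative

injected_blocks = [i for i in range(num_layers) if i % cross_attn_interval == 0]
num_ca = len(injected_blocks)   # number of PerceiverCrossAttention modules to build
print(injected_blocks)          # [0, 2, 4, ..., 28]
print(num_ca)                   # 15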
models/utils.py CHANGED
@@ -1,7 +1,8 @@
 
1
  import cv2
2
  import math
3
  import numpy as np
4
- from PIL import Image
5
 
6
  import torch
7
  from torchvision.transforms import InterpolationMode
@@ -10,7 +11,16 @@ from transformers import T5EncoderModel, T5Tokenizer
10
  from typing import List, Optional, Tuple, Union
11
  from diffusers.models.embeddings import get_3d_rotary_pos_embed
12
  from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
 
13
 
14
 
15
  def tensor_to_pil(src_img_tensor):
16
  img = src_img_tensor.clone().detach()
@@ -204,12 +214,12 @@ def draw_kps(image_pil, kps, color_list=[(255,0,0), (0,255,0), (0,0,255), (255,2
204
  return out_img_pil
205
 
206
 
207
- def process_face_embeddings(face_helper, clip_vision_model, handler_ante, eva_transform_mean, eva_transform_std, app, device, weight_dtype, image, original_id_image=None, is_align_face=True, cal_uncond=False):
208
  """
209
  Args:
210
  image: numpy rgb image, range [0, 255]
211
  """
212
- face_helper.clean_all()
213
  image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) # (724, 502, 3)
214
  # get antelopev2 embedding
215
  face_info = app.get(image_bgr)
@@ -224,19 +234,19 @@ def process_face_embeddings(face_helper, clip_vision_model, handler_ante, eva_tr
224
  face_kps = None
225
 
226
  # using facexlib to detect and align face
227
- face_helper.read_image(image_bgr)
228
- face_helper.get_face_landmarks_5(only_center_face=True)
229
  if face_kps is None:
230
- face_kps = face_helper.all_landmarks_5[0]
231
- face_helper.align_warp_face()
232
- if len(face_helper.cropped_faces) == 0:
233
  raise RuntimeError('facexlib align face fail')
234
- align_face = face_helper.cropped_faces[0] # (512, 512, 3) # RGB
235
 
236
  # in case insightface didn't detect a face
237
  if id_ante_embedding is None:
238
  print('fail to detect face using insightface, extract embedding on align face')
239
- id_ante_embedding = handler_ante.get_feat(align_face)
240
 
241
  id_ante_embedding = torch.from_numpy(id_ante_embedding).to(device, weight_dtype) # torch.Size([512])
242
  if id_ante_embedding.ndim == 1:
@@ -246,7 +256,7 @@ def process_face_embeddings(face_helper, clip_vision_model, handler_ante, eva_tr
246
  if is_align_face:
247
  input = img2tensor(align_face, bgr2rgb=True).unsqueeze(0) / 255.0 # torch.Size([1, 3, 512, 512])
248
  input = input.to(device)
249
- parsing_out = face_helper.face_parse(normalize(input, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))[0]
250
  parsing_out = parsing_out.argmax(dim=1, keepdim=True) # torch.Size([1, 1, 512, 512])
251
  bg_label = [0, 16, 18, 7, 8, 9, 14, 15]
252
  bg = sum(parsing_out == i for i in bg_label).bool()
@@ -270,4 +280,84 @@ def process_face_embeddings(face_helper, clip_vision_model, handler_ante, eva_tr
270
 
271
  id_cond = torch.cat([id_ante_embedding, id_cond_vit], dim=-1) # torch.Size([1, 512]), torch.Size([1, 768]) -> torch.Size([1, 1280])
272
 
273
- return id_cond, id_vit_hidden, return_face_features_image_2, face_kps # torch.Size([1, 1280]), list(torch.Size([1, 577, 1024]))
 
1
+ import os
2
  import cv2
3
  import math
4
  import numpy as np
5
+ from PIL import Image, ImageOps
6
 
7
  import torch
8
  from torchvision.transforms import InterpolationMode
 
11
  from typing import List, Optional, Tuple, Union
12
  from diffusers.models.embeddings import get_3d_rotary_pos_embed
13
  from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
14
+ from diffusers.utils import load_image
15
 
16
+ import insightface
17
+ from insightface.app import FaceAnalysis
18
+ from facexlib.parsing import init_parsing_model
19
+ from facexlib.utils.face_restoration_helper import FaceRestoreHelper
20
+
21
+ from models.eva_clip import create_model_and_transforms
22
+ from models.eva_clip.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
23
+ from models.eva_clip.utils_qformer import resize_numpy_image_long
24
 
25
  def tensor_to_pil(src_img_tensor):
26
  img = src_img_tensor.clone().detach()
 
214
  return out_img_pil
215
 
216
 
217
+ def process_face_embeddings(face_helper_1, clip_vision_model, face_helper_2, eva_transform_mean, eva_transform_std, app, device, weight_dtype, image, original_id_image=None, is_align_face=True):
218
  """
219
  Args:
220
  image: numpy rgb image, range [0, 255]
221
  """
222
+ face_helper_1.clean_all()
223
  image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) # (724, 502, 3)
224
  # get antelopev2 embedding
225
  face_info = app.get(image_bgr)
 
234
  face_kps = None
235
 
236
  # using facexlib to detect and align face
237
+ face_helper_1.read_image(image_bgr)
238
+ face_helper_1.get_face_landmarks_5(only_center_face=True)
239
  if face_kps is None:
240
+ face_kps = face_helper_1.all_landmarks_5[0]
241
+ face_helper_1.align_warp_face()
242
+ if len(face_helper_1.cropped_faces) == 0:
243
  raise RuntimeError('facexlib align face fail')
244
+ align_face = face_helper_1.cropped_faces[0] # (512, 512, 3) # RGB
245
 
246
  # in case insightface didn't detect a face
247
  if id_ante_embedding is None:
248
  print('fail to detect face using insightface, extract embedding on align face')
249
+ id_ante_embedding = face_helper_2.get_feat(align_face)
250
 
251
  id_ante_embedding = torch.from_numpy(id_ante_embedding).to(device, weight_dtype) # torch.Size([512])
252
  if id_ante_embedding.ndim == 1:
 
256
  if is_align_face:
257
  input = img2tensor(align_face, bgr2rgb=True).unsqueeze(0) / 255.0 # torch.Size([1, 3, 512, 512])
258
  input = input.to(device)
259
+ parsing_out = face_helper_1.face_parse(normalize(input, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))[0]
260
  parsing_out = parsing_out.argmax(dim=1, keepdim=True) # torch.Size([1, 1, 512, 512])
261
  bg_label = [0, 16, 18, 7, 8, 9, 14, 15]
262
  bg = sum(parsing_out == i for i in bg_label).bool()
 
280
 
281
  id_cond = torch.cat([id_ante_embedding, id_cond_vit], dim=-1) # torch.Size([1, 512]), torch.Size([1, 768]) -> torch.Size([1, 1280])
282
 
283
+ return id_cond, id_vit_hidden, return_face_features_image_2, face_kps # torch.Size([1, 1280]), list(torch.Size([1, 577, 1024]))
284
+
285
+
286
+ def process_face_embeddings_infer(face_helper_1, clip_vision_model, face_helper_2, eva_transform_mean, eva_transform_std, app, device, weight_dtype, img_file_path, is_align_face=True):
287
+ """
288
+ Args:
289
+ img_file_path: image file path (str) or numpy rgb image, range [0, 255]
290
+ """
291
+ if isinstance(img_file_path, str):
292
+ image = np.array(load_image(image=img_file_path).convert("RGB"))
293
+ else:
294
+ image = np.array(ImageOps.exif_transpose(Image.fromarray(img_file_path)).convert("RGB"))
295
+
296
+ image = resize_numpy_image_long(image, 1024)
297
+ original_id_image = image
298
+
299
+ id_cond, id_vit_hidden, align_crop_face_image, face_kps = process_face_embeddings(face_helper_1, clip_vision_model, face_helper_2, eva_transform_mean, eva_transform_std, app, device, weight_dtype, image, original_id_image, is_align_face)
300
+
301
+ tensor = align_crop_face_image.cpu().detach()
302
+ tensor = tensor.squeeze()
303
+ tensor = tensor.permute(1, 2, 0)
304
+ tensor = tensor.numpy() * 255
305
+ tensor = tensor.astype(np.uint8)
306
+ image = ImageOps.exif_transpose(Image.fromarray(tensor))
307
+
308
+ return id_cond, id_vit_hidden, image, face_kps
309
+
310
+ def prepare_face_models(model_path, device, dtype):
311
+ """
312
+ Prepare all face models for the facial recognition task.
313
+
314
+ Parameters:
315
+ - model_path: Path to the directory containing model files.
316
+ - device: The device (e.g., 'cuda', 'cpu') where models will be loaded.
317
+ - dtype: Data type (e.g., torch.float32) for model inference.
318
+
319
+ Returns:
320
+ - face_helper_1: facexlib face restoration helper (detection, alignment, parsing).
321
+ - face_helper_2: insightface recognition model (glintr100) used for identity embeddings.
322
+ - face_clip_model: EVA-CLIP visual encoder for face feature extraction.
323
+ - face_main_model: insightface FaceAnalysis (antelopev2) detection model.
324
+ - eva_transform_mean: Mean values for image normalization.
325
+ - eva_transform_std: Standard deviation values for image normalization.
326
+ """
327
+ # get helper model
328
+ face_helper_1 = FaceRestoreHelper(
329
+ upscale_factor=1,
330
+ face_size=512,
331
+ crop_ratio=(1, 1),
332
+ det_model='retinaface_resnet50',
333
+ save_ext='png',
334
+ device=device,
335
+ model_rootpath=os.path.join(model_path, "face_encoder")
336
+ )
337
+ face_helper_1.face_parse = None
338
+ face_helper_1.face_parse = init_parsing_model(model_name='bisenet', device=device, model_rootpath=os.path.join(model_path, "face_encoder"))
339
+ face_helper_2 = insightface.model_zoo.get_model(f'{model_path}/face_encoder/models/antelopev2/glintr100.onnx', providers=['CUDAExecutionProvider'])
340
+ face_helper_2.prepare(ctx_id=0)
341
+
342
+ # get local facial extractor part 1
343
+ model, _, _ = create_model_and_transforms('EVA02-CLIP-L-14-336', os.path.join(model_path, "face_encoder", "EVA02_CLIP_L_336_psz14_s6B.pt"), force_custom_clip=True)
344
+ face_clip_model = model.visual
345
+ eva_transform_mean = getattr(face_clip_model, 'image_mean', OPENAI_DATASET_MEAN)
346
+ eva_transform_std = getattr(face_clip_model, 'image_std', OPENAI_DATASET_STD)
347
+ if not isinstance(eva_transform_mean, (list, tuple)):
348
+ eva_transform_mean = (eva_transform_mean,) * 3
349
+ if not isinstance(eva_transform_std, (list, tuple)):
350
+ eva_transform_std = (eva_transform_std,) * 3
351
+ eva_transform_mean = eva_transform_mean
352
+ eva_transform_std = eva_transform_std
353
+
354
+ # get local facial extractor part 2
355
+ face_main_model = FaceAnalysis(name='antelopev2', root=os.path.join(model_path, "face_encoder"), providers=['CUDAExecutionProvider'])
356
+ face_main_model.prepare(ctx_id=0, det_size=(640, 640))
357
+
358
+ # move face models to device
359
+ face_helper_1.face_det.eval()
360
+ face_helper_1.face_parse.eval()
361
+ face_clip_model.eval()
362
+
363
+ return face_helper_1, face_helper_2, face_clip_model, face_main_model, eva_transform_mean, eva_transform_std
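prepare_face_models builds every face model in one call, and process_face_embeddings_infer turns a single reference image into the conditioning tensors the transformer consumes. A minimal inference-side sketch, assuming the weights live under "ckpts" and "face.png" is a hypothetical reference image (the bfloat16 dtype is also an assumption):

import torch
from models.utils import prepare_face_models, process_face_embeddings_infer

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16  # assumption; match the dtype used by the rest of the pipeline

(face_helper_1, face_helper_2, face_clip_model,
 face_main_model, eva_transform_mean, eva_transform_std) = prepare_face_models("ckpts", device, dtype)

id_cond, id_vit_hidden, face_image, face_kps = process_face_embeddings_infer(
    face_helper_1, face_clip_model, face_helper_2,
    eva_transform_mean, eva_transform_std,
    face_main_model, device, dtype,
    "face.png",           # hypothetical input
    is_align_face=True,
)
# id_cond / id_vit_hidden condition the transformer; face_image and face_kps
# can serve as the visual / keypoint conditions for the pipeline.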
requirements.txt CHANGED
@@ -6,7 +6,7 @@ onnx==1.17.0
6
  onnxruntime-gpu==1.19.2
7
  deepspeed==0.15.2
8
  accelerate==1.1.1
9
- diffusers==0.31.0
10
  transformers==4.46.3
11
  tokenizers==0.20.1
12
  peft==0.12.0
 
6
  onnxruntime-gpu==1.19.2
7
  deepspeed==0.15.2
8
  accelerate==1.1.1
9
+ git+https://github.com/SHYuanBest/ConsisID_diffusers.git
10
  transformers==4.46.3
11
  tokenizers==0.20.1
12
  peft==0.12.0
util/dataloader.py DELETED
@@ -1,1010 +0,0 @@
1
- import os
2
- import gc
3
- import cv2
4
- import json
5
- import math
6
- import decord
7
- import random
8
- import numpy as np
9
- from PIL import Image
10
- from tqdm import tqdm
11
- from decord import VideoReader
12
- from contextlib import contextmanager
13
- from func_timeout import FunctionTimedOut
14
- from typing import Optional, Sized, Iterator
15
-
16
- import torch
17
- from torch.utils.data import Dataset, Sampler
18
- import torch.nn.functional as F
19
- from torchvision.transforms import ToPILImage
20
- from torchvision import transforms
21
- from accelerate.logging import get_logger
22
-
23
- logger = get_logger(__name__)
24
-
25
- import threading
26
- log_lock = threading.Lock()
27
-
28
- def log_error_to_file(error_message, video_path):
29
- with log_lock:
30
- with open("error_log.txt", "a") as f:
31
- f.write(f"Error: {error_message}\n")
32
- f.write(f"Video Path: {video_path}\n")
33
- f.write("-" * 50 + "\n")
34
-
35
- def draw_kps(image_pil, kps, color_list=[(255,0,0), (0,255,0), (0,0,255), (255,255,0), (255,0,255)]):
36
- stickwidth = 4
37
- limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]])
38
- kps = np.array(kps)
39
-
40
- w, h = image_pil.size
41
- out_img = np.zeros([h, w, 3])
42
-
43
- for i in range(len(limbSeq)):
44
- index = limbSeq[i]
45
- color = color_list[index[0]]
46
-
47
- x = kps[index][:, 0]
48
- y = kps[index][:, 1]
49
- length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5
50
- angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1]))
51
- polygon = cv2.ellipse2Poly((int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
52
- out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color)
53
- out_img = (out_img * 0.6).astype(np.uint8)
54
-
55
- for idx_kp, kp in enumerate(kps):
56
- color = color_list[idx_kp]
57
- x, y = kp
58
- out_img = cv2.circle(out_img.copy(), (int(x), int(y)), 10, color, -1)
59
-
60
- out_img_pil = Image.fromarray(out_img.astype(np.uint8))
61
- return out_img_pil
62
-
63
- @contextmanager
64
- def VideoReader_contextmanager(*args, **kwargs):
65
- vr = VideoReader(*args, **kwargs)
66
- try:
67
- yield vr
68
- finally:
69
- del vr
70
- gc.collect()
71
-
72
- def get_valid_segments(valid_frame, tolerance=5):
73
- valid_positions = sorted(set(valid_frame['face']).union(set(valid_frame['head'])))
74
-
75
- valid_segments = []
76
- current_segment = [valid_positions[0]]
77
-
78
- for i in range(1, len(valid_positions)):
79
- if valid_positions[i] - valid_positions[i - 1] <= tolerance:
80
- current_segment.append(valid_positions[i])
81
- else:
82
- valid_segments.append(current_segment)
83
- current_segment = [valid_positions[i]]
84
-
85
- if current_segment:
86
- valid_segments.append(current_segment)
87
-
88
- return valid_segments
89
-
90
-
91
- def get_frame_indices_adjusted_for_face(valid_frames, n_frames):
92
- valid_length = len(valid_frames)
93
- if valid_length >= n_frames:
94
- return valid_frames[:n_frames]
95
-
96
- additional_frames_needed = n_frames - valid_length
97
- repeat_indices = []
98
-
99
- for i in range(additional_frames_needed):
100
- index_to_repeat = i % valid_length
101
- repeat_indices.append(valid_frames[index_to_repeat])
102
-
103
- all_indices = valid_frames + repeat_indices
104
- all_indices.sort()
105
-
106
- return all_indices
107
-
108
-
109
- def generate_frame_indices_for_face(n_frames, sample_stride, valid_frame, tolerance=7, skip_frames_start_percent=0.0, skip_frames_end_percent=1.0, skip_frames_start=0, skip_frames_end=0):
110
- valid_segments = get_valid_segments(valid_frame, tolerance)
111
- selected_segment = max(valid_segments, key=len)
112
-
113
- valid_length = len(selected_segment)
114
- if skip_frames_start_percent != 0.0 or skip_frames_end_percent != 1.0:
115
- # print("use skip frame percent")
116
- valid_start = int(valid_length * skip_frames_start_percent)
117
- valid_end = int(valid_length * skip_frames_end_percent)
118
- elif skip_frames_start != 0 or skip_frames_end != 0:
119
- # print("use skip frame")
120
- valid_start = skip_frames_start
121
- valid_end = valid_length - skip_frames_end
122
- else:
123
- # print("no use skip frame")
124
- valid_start = 0
125
- valid_end = valid_length
126
-
127
- if valid_length <= n_frames:
128
- return get_frame_indices_adjusted_for_face(selected_segment, n_frames), valid_length
129
- else:
130
- adjusted_length = valid_end - valid_start
131
- if adjusted_length <= 0:
132
- print(f"video_length: {valid_length}, adjusted_length: {adjusted_length}, valid_start:{valid_start}, skip_frames_end: {valid_end}")
133
- raise ValueError("Skipping too many frames results in no frames left to sample.")
134
-
135
- clip_length = min(adjusted_length, (n_frames - 1) * sample_stride + 1)
136
- start_idx_position = random.randint(valid_start, valid_end - clip_length)
137
- start_frame = selected_segment[start_idx_position]
138
-
139
- selected_frames = []
140
- for i in range(n_frames):
141
- next_frame = start_frame + i * sample_stride
142
- if next_frame in selected_segment:
143
- selected_frames.append(next_frame)
144
- else:
145
- break
146
-
147
- if len(selected_frames) < n_frames:
148
- return get_frame_indices_adjusted_for_face(selected_frames, n_frames), len(selected_frames)
149
-
150
- return selected_frames, len(selected_frames)
151
-
152
- def frame_has_required_confidence(bbox_data, frame, ID, conf_threshold=0.88):
153
- frame_str = str(frame)
154
- if frame_str not in bbox_data:
155
- return False
156
-
157
- frame_data = bbox_data[frame_str]
158
-
159
- face_conf = any(
160
- item['confidence'] > conf_threshold and item['new_track_id'] == ID
161
- for item in frame_data.get('face', [])
162
- )
163
-
164
- head_conf = any(
165
- item['confidence'] > conf_threshold and item['new_track_id'] == ID
166
- for item in frame_data.get('head', [])
167
- )
168
-
169
- return face_conf and head_conf
170
-
171
- def select_mask_frames_from_index(batch_frame, original_batch_frame, valid_id, corresponding_data, control_sam2_frame,
172
- valid_frame, bbox_data, base_dir, min_distance=3, min_frames=1, max_frames=5,
173
- mask_type='face', control_mask_type='head', dense_masks=False,
174
- ensure_control_frame=True):
175
- """
176
- Selects frames with corresponding mask images while ensuring a minimum distance constraint between frames,
177
- and that the frames exist in both batch_frame and valid_frame.
178
-
179
- Parameters:
180
- base_path (str): Base directory where the JSON files and mask results are located.
181
- min_distance (int): Minimum distance between selected frames.
182
- min_frames (int): Minimum number of frames to select.
183
- max_frames (int): Maximum number of frames to select.
184
- mask_type (str): Type of mask to select frames for ('face' or 'head').
185
- control_mask_type (str): Type of mask used for control frame selection ('face' or 'head').
186
-
187
- Returns:
188
- dict: A dictionary where keys are IDs and values are lists of selected mask PNG paths.
189
- """
190
- # Helper function to randomly select frames with at least X frames apart
191
- def select_frames_with_distance_constraint(frames, num_frames, min_distance, control_frame, bbox_data, ID,
192
- ensure_control_frame=True, fallback=True):
193
- """
194
- Selects frames with a minimum distance constraint. If not enough frames can be selected, a fallback plan is applied.
195
-
196
- Parameters:
197
- frames (list): List of frame indices to select from.
198
- num_frames (int): Number of frames to select.
199
- min_distance (int): Minimum distance between selected frames.
200
- control_frame (int): The control frame that must always be included.
201
- fallback (bool): Whether to apply a fallback strategy if not enough frames meet the distance constraint.
202
-
203
- Returns:
204
- list: List of selected frames.
205
- """
206
- conf_thresholds = [0.95, 0.94, 0.93, 0.92, 0.91, 0.90]
207
- if ensure_control_frame:
208
- selected_frames = [control_frame] # Ensure control frame is always included
209
- else:
210
- valid_initial_frames = []
211
- for conf_threshold in conf_thresholds:
212
- valid_initial_frames = [
213
- f for f in frames
214
- if frame_has_required_confidence(bbox_data, f, ID, conf_threshold=conf_threshold)
215
- ]
216
- if valid_initial_frames:
217
- break
218
- if valid_initial_frames:
219
- selected_frames = [random.choice(valid_initial_frames)]
220
- else:
221
- # If no frame meets the initial confidence, fall back to a random frame (or handle as per your preference)
222
- selected_frames = [random.choice(frames)]
223
-
224
- available_frames = [f for f in frames if f != selected_frames[0]] # Exclude control frame for random selection
225
-
226
- random.shuffle(available_frames) # Shuffle to introduce randomness
227
-
228
- while available_frames and len(selected_frames) < num_frames:
229
- last_selected_frame = selected_frames[-1]
230
-
231
- valid_choices = []
232
- for conf_threshold in conf_thresholds:
233
- valid_choices = [
234
- f for f in available_frames
235
- if abs(f - last_selected_frame) >= min_distance and
236
- frame_has_required_confidence(bbox_data, f, ID, conf_threshold=conf_threshold)
237
- ]
238
- if valid_choices:
239
- break
240
-
241
- if valid_choices:
242
- frame = random.choice(valid_choices)
243
- available_frames.remove(frame)
244
- selected_frames.append(frame)
245
- else:
246
- if fallback:
247
- # Fallback strategy: uniformly distribute remaining frames if distance constraint cannot be met
248
- remaining_needed = num_frames - len(selected_frames)
249
- remaining_frames = available_frames[:remaining_needed]
250
-
251
- # Distribute the remaining frames evenly if possible
252
- if remaining_frames:
253
- step = max(1, len(remaining_frames) // remaining_needed)
254
- evenly_selected = remaining_frames[::step][:remaining_needed]
255
- selected_frames.extend(evenly_selected)
256
- break
257
- else:
258
- break # No valid choices remain and no fallback strategy is allowed
259
-
260
- if len(selected_frames) < num_frames:
261
- return None
262
-
263
- return selected_frames
264
-
265
- # Convert batch_frame list to a set to remove duplicates
266
- batch_frame_set = set(batch_frame)
267
-
268
- # Dictionary to store selected mask PNGs
269
- selected_masks_dict = {}
270
- selected_bboxs_dict = {}
271
- dense_masks_dict = {}
272
- selected_frames_dict = {}
273
-
274
- # ID
275
- try:
276
- mask_valid_frames = valid_frame[mask_type] # Select frames based on the specified mask type
277
- control_valid_frames = valid_frame[control_mask_type] # Control frames for control_mask_type
278
- except KeyError:
279
- if mask_type not in valid_frame.keys():
280
- print(f"no valid {mask_type}")
281
- if control_mask_type not in valid_frame.keys():
282
- print(f"no valid {control_mask_type}")
283
-
284
- # Get the control frame for the control mask type
285
- control_frame = control_sam2_frame[valid_id][control_mask_type]
286
-
287
- # Filter frames to only those which are in both valid_frame and batch_frame_set
288
- valid_frames = []
289
- # valid_frames = [frame for frame in mask_valid_frames if frame in control_valid_frames and frame in batch_frame_set]
290
- for frame in mask_valid_frames:
291
- if frame in control_valid_frames and frame in batch_frame_set:
292
- # Check if bbox_data has 'head' or 'face' for the frame
293
- if str(frame) in bbox_data:
294
- frame_data = bbox_data[str(frame)]
295
- if 'head' in frame_data or 'face' in frame_data:
296
- valid_frames.append(frame)
297
-
298
- # Ensure the control frame is included in the valid frames
299
- if ensure_control_frame and (control_frame not in valid_frames):
300
- valid_frames.append(control_frame)
301
-
302
- # Select a random number of frames between min_frames and max_frames
303
- num_frames_to_select = random.randint(min_frames, max_frames)
304
- selected_frames = select_frames_with_distance_constraint(valid_frames, num_frames_to_select, min_distance,
305
- control_frame, bbox_data, valid_id, ensure_control_frame)
306
-
307
- # Store the selected frames as mask PNGs and bbox
308
- selected_masks_dict[valid_id] = []
309
- selected_bboxs_dict[valid_id] = []
310
-
311
- # Initialize the dense_masks_dict entry for the current ID
312
- dense_masks_dict[valid_id] = []
313
-
314
- # Store the selected frames in the dictionary
315
- selected_frames_dict[valid_id] = selected_frames
316
-
317
- if dense_masks:
318
- for frame in original_batch_frame:
319
- mask_data_path = f"{base_dir}/{valid_id}/annotated_frame_{int(frame):05d}.png"
320
- mask_array = np.array(Image.open(mask_data_path))
321
- binary_mask = np.where(mask_array > 0, 1, 0).astype(np.uint8)
322
- dense_masks_dict[valid_id].append(binary_mask)
323
-
324
- for frame in selected_frames:
325
- mask_data_path = f"{base_dir}/{valid_id}/annotated_frame_{frame:05d}.png"
326
- mask_array = np.array(Image.open(mask_data_path))
327
- binary_mask = np.where(mask_array > 0, 1, 0).astype(np.uint8)
328
- selected_masks_dict[valid_id].append(binary_mask)
329
-
330
- try:
331
- for item in bbox_data[f"{frame}"]["head"]:
332
- if item['new_track_id'] == int(valid_id):
333
- temp_bbox = item['box']
334
- break
335
- except (KeyError, StopIteration):
336
- try:
337
- for item in bbox_data[f"{frame}"]["face"]:
338
- if item['new_track_id'] == int(valid_id):
339
- temp_bbox = item['box']
340
- break
341
- except (KeyError, StopIteration):
342
- temp_bbox = None
343
-
344
- selected_bboxs_dict[valid_id].append(temp_bbox)
345
-
346
- return selected_frames_dict, selected_masks_dict, selected_bboxs_dict, dense_masks_dict
347
-
348
- def pad_tensor(tensor, target_size, dim=0):
349
- padding_size = target_size - tensor.size(dim)
350
- if padding_size > 0:
351
- pad_shape = list(tensor.shape)
352
- pad_shape[dim] = padding_size
353
- padding_tensor = torch.zeros(pad_shape, dtype=tensor.dtype, device=tensor.device)
354
- return torch.cat([tensor, padding_tensor], dim=dim)
355
- else:
356
- return tensor[:target_size]
357
-
358
- def crop_images(selected_frame_index, selected_bboxs_dict, video_reader, return_ori=False):
359
- """
360
- Crop images based on given bounding boxes and frame indices from a video.
361
-
362
- Args:
363
- selected_frame_index (list): List of frame indices to be cropped.
364
- selected_bboxs_dict (list of dict): List of dictionaries, each containing 'x1', 'y1', 'x2', 'y2' bounding box coordinates.
365
- video_reader (VideoReader or list of numpy arrays): Video frames accessible by index, where each frame is a numpy array (H, W, C).
366
-
367
- Returns:
368
- list: A list of cropped images in PIL Image format.
369
- """
370
- expanded_cropped_images = []
371
- original_cropped_images = []
372
- for frame_idx, bbox in zip(selected_frame_index, selected_bboxs_dict):
373
- # Get the specific frame from the video reader using the frame index
374
- frame = video_reader[frame_idx] # torch.tensor # (H, W, C)
375
-
376
- # Extract bounding box coordinates and convert them to integers
377
- x1, y1, x2, y2 = int(bbox['x1']), int(bbox['y1']), int(bbox['x2']), int(bbox['y2'])
378
- # Crop to minimize the bounding box to a square
379
- width = x2 - x1 # Calculate the width of the bounding box
380
- height = y2 - y1 # Calculate the height of the bounding box
381
- side_length = max(width, height) # Determine the side length of the square (max of width or height)
382
-
383
- # Calculate the center of the bounding box
384
- center_x = (x1 + x2) // 2
385
- center_y = (y1 + y2) // 2
386
-
387
- # Calculate new coordinates for the square region centered around the original bounding box
388
- new_x1 = max(0, center_x - side_length // 2) # Ensure x1 is within image bounds
389
- new_y1 = max(0, center_y - side_length // 2) # Ensure y1 is within image bounds
390
- new_x2 = min(frame.shape[1], new_x1 + side_length) # Ensure x2 does not exceed image width
391
- new_y2 = min(frame.shape[0], new_y1 + side_length) # Ensure y2 does not exceed image height
392
-
393
- # Adjust coordinates if the cropped area is smaller than the desired side length
394
- # Ensure final width and height are equal, keeping it a square
395
- actual_width = new_x2 - new_x1
396
- actual_height = new_y2 - new_y1
397
-
398
- if actual_width < side_length:
399
- # Adjust x1 or x2 to ensure the correct side length, while staying in bounds
400
- if new_x1 == 0:
401
- new_x2 = min(frame.shape[1], new_x1 + side_length)
402
- else:
403
- new_x1 = max(0, new_x2 - side_length)
404
-
405
- if actual_height < side_length:
406
- # Adjust y1 or y2 to ensure the correct side length, while staying in bounds
407
- if new_y1 == 0:
408
- new_y2 = min(frame.shape[0], new_y1 + side_length)
409
- else:
410
- new_y1 = max(0, new_y2 - side_length)
411
-
412
- # Expand the square by 20%
413
- expansion_ratio = 0.2 # Define the expansion ratio
414
- expansion_amount = int(side_length * expansion_ratio) # Calculate the number of pixels to expand by
415
-
416
- # Calculate expanded coordinates, ensuring they stay within image bounds
417
- expanded_x1 = max(0, new_x1 - expansion_amount) # Expand left, ensuring x1 is within bounds
418
- expanded_y1 = max(0, new_y1 - expansion_amount) # Expand up, ensuring y1 is within bounds
419
- expanded_x2 = min(frame.shape[1], new_x2 + expansion_amount) # Expand right, ensuring x2 does not exceed bounds
420
- expanded_y2 = min(frame.shape[0], new_y2 + expansion_amount) # Expand down, ensuring y2 does not exceed bounds
421
-
422
- # Ensure the expanded area is still a square
423
- expanded_width = expanded_x2 - expanded_x1
424
- expanded_height = expanded_y2 - expanded_y1
425
- final_side_length = min(expanded_width, expanded_height)
426
-
427
- # Adjust to ensure square shape if necessary
428
- if expanded_width != expanded_height:
429
- if expanded_width > expanded_height:
430
- expanded_x2 = expanded_x1 + final_side_length
431
- else:
432
- expanded_y2 = expanded_y1 + final_side_length
433
-
434
- expanded_cropped_rgb_tensor = frame[expanded_y1:expanded_y2, expanded_x1:expanded_x2, :]
435
- expanded_cropped_rgb = Image.fromarray(np.array(expanded_cropped_rgb_tensor)).convert('RGB')
436
- expanded_cropped_images.append(expanded_cropped_rgb)
437
-
438
- if return_ori:
439
- original_cropped_rgb_tensor = frame[new_y1:new_y2, new_x1:new_x2, :]
440
- original_cropped_rgb = Image.fromarray(np.array(original_cropped_rgb_tensor)).convert('RGB')
441
- original_cropped_images.append(original_cropped_rgb)
442
- return expanded_cropped_images, original_cropped_images
443
-
444
- return expanded_cropped_images, None
445
-
446
- def process_cropped_images(expand_images_pil, original_images_pil, target_size=(480, 480)):
447
- """
448
- Process a list of cropped images in PIL format.
449
-
450
- Parameters:
451
- expand_images_pil (list of PIL.Image): List of cropped images in PIL format.
452
- target_size (tuple of int): The target size for resizing images, default is (480, 480).
453
-
454
- Returns:
455
- torch.Tensor: A tensor containing the processed images.
456
- """
457
- expand_face_imgs = []
458
- original_face_imgs = []
459
- if len(original_images_pil) != 0:
460
- for expand_img, original_img in zip(expand_images_pil, original_images_pil):
461
- expand_resized_img = expand_img.resize(target_size, Image.LANCZOS)
462
- expand_src_img = np.array(expand_resized_img)
463
- expand_src_img = np.transpose(expand_src_img, (2, 0, 1))
464
- expand_src_img = torch.from_numpy(expand_src_img).unsqueeze(0).float()
465
- expand_face_imgs.append(expand_src_img)
466
-
467
- original_resized_img = original_img.resize(target_size, Image.LANCZOS)
468
- original_src_img = np.array(original_resized_img)
469
- original_src_img = np.transpose(original_src_img, (2, 0, 1))
470
- original_src_img = torch.from_numpy(original_src_img).unsqueeze(0).float()
471
- original_face_imgs.append(original_src_img)
472
-
473
- expand_face_imgs = torch.cat(expand_face_imgs, dim=0)
474
- original_face_imgs = torch.cat(original_face_imgs, dim=0)
475
- else:
476
- for expand_img in expand_images_pil:
477
- expand_resized_img = expand_img.resize(target_size, Image.LANCZOS)
478
- expand_src_img = np.array(expand_resized_img)
479
- expand_src_img = np.transpose(expand_src_img, (2, 0, 1))
480
- expand_src_img = torch.from_numpy(expand_src_img).unsqueeze(0).float()
481
- expand_face_imgs.append(expand_src_img)
482
- expand_face_imgs = torch.cat(expand_face_imgs, dim=0)
483
- original_face_imgs = None
484
-
485
- return expand_face_imgs, original_face_imgs
486
-
487
- class RandomSampler(Sampler[int]):
488
- r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.
489
-
490
- If with replacement, then user can specify :attr:`num_samples` to draw.
491
-
492
- Args:
493
- data_source (Dataset): dataset to sample from
494
- replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False``
495
- num_samples (int): number of samples to draw, default=`len(dataset)`.
496
- generator (Generator): Generator used in sampling.
497
- """
498
-
499
- data_source: Sized
500
- replacement: bool
501
-
502
- def __init__(self, data_source: Sized, replacement: bool = False,
503
- num_samples: Optional[int] = None, generator=None) -> None:
504
- self.data_source = data_source
505
- self.replacement = replacement
506
- self._num_samples = num_samples
507
- self.generator = generator
508
- self._pos_start = 0
509
-
510
- if not isinstance(self.replacement, bool):
511
- raise TypeError(f"replacement should be a boolean value, but got replacement={self.replacement}")
512
-
513
- if not isinstance(self.num_samples, int) or self.num_samples <= 0:
514
- raise ValueError(f"num_samples should be a positive integer value, but got num_samples={self.num_samples}")
515
-
516
- @property
517
- def num_samples(self) -> int:
518
- # dataset size might change at runtime
519
- if self._num_samples is None:
520
- return len(self.data_source)
521
- return self._num_samples
522
-
523
- def __iter__(self) -> Iterator[int]:
524
- n = len(self.data_source)
525
- if self.generator is None:
526
- seed = int(torch.empty((), dtype=torch.int64).random_().item())
527
- generator = torch.Generator()
528
- generator.manual_seed(seed)
529
- else:
530
- generator = self.generator
531
-
532
- if self.replacement:
533
- for _ in range(self.num_samples // 32):
534
- yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
535
- yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
536
- else:
537
- for _ in range(self.num_samples // n):
538
- xx = torch.randperm(n, generator=generator).tolist()
539
- if self._pos_start >= n:
540
- self._pos_start = 0
541
- print("xx top 10", xx[:10], self._pos_start)
542
- for idx in range(self._pos_start, n):
543
- yield xx[idx]
544
- self._pos_start = (self._pos_start + 1) % n
545
- self._pos_start = 0
546
- yield from torch.randperm(n, generator=generator).tolist()[:self.num_samples % n]
547
-
548
- def __len__(self) -> int:
549
- return self.num_samples
550
-
551
- class SequentialSampler(Sampler[int]):
552
- r"""Samples elements sequentially, always in the same order.
553
-
554
- Args:
555
- data_source (Dataset): dataset to sample from
556
- """
557
-
558
- data_source: Sized
559
-
560
- def __init__(self, data_source: Sized) -> None:
561
- self.data_source = data_source
562
- self._pos_start = 0
563
-
564
- def __iter__(self) -> Iterator[int]:
565
- n = len(self.data_source)
566
- for idx in range(self._pos_start, n):
567
- yield idx
568
- self._pos_start = (self._pos_start + 1) % n
569
- self._pos_start = 0
570
-
571
- def __len__(self) -> int:
572
- return len(self.data_source)
573
-
574
- class ConsisID_Dataset(Dataset):
575
- def __init__(
576
- self,
577
- instance_data_root: Optional[str] = None,
578
- id_token: Optional[str] = None,
579
- height=480,
580
- width=640,
581
- max_num_frames=49,
582
- sample_stride=3,
583
- skip_frames_start_percent=0.0,
584
- skip_frames_end_percent=1.0,
585
- skip_frames_start=0,
586
- skip_frames_end=0,
587
- text_drop_ratio=-1,
588
- is_train_face=False,
589
- is_single_face=False,
590
- miss_tolerance=6,
591
- min_distance=3,
592
- min_frames=1,
593
- max_frames=5,
594
- is_cross_face=False,
595
- is_reserve_face=False,
596
- ):
597
- self.id_token = id_token or ""
598
-
599
- # ConsisID
600
- self.skip_frames_start_percent = skip_frames_start_percent
601
- self.skip_frames_end_percent = skip_frames_end_percent
602
- self.skip_frames_start = skip_frames_start
603
- self.skip_frames_end = skip_frames_end
604
- self.is_train_face = is_train_face
605
- self.is_single_face = is_single_face
606
-
607
- if is_train_face:
608
- self.miss_tolerance = miss_tolerance
609
- self.min_distance = min_distance
610
- self.min_frames = min_frames
611
- self.max_frames = max_frames
612
- self.is_cross_face = is_cross_face
613
- self.is_reserve_face = is_reserve_face
614
-
615
- # Loading annotations from files
616
- print(f"loading annotations from {instance_data_root} ...")
617
- with open(instance_data_root, 'r') as f:
618
- folder_anno = [i.strip().split(',') for i in f.readlines() if len(i.strip()) > 0]
619
-
620
- self.instance_prompts = []
621
- self.instance_video_paths = []
622
- self.instance_annotation_base_paths = []
623
- for sub_root, anno, anno_base in tqdm(folder_anno):
624
- print(anno)
625
- self.instance_annotation_base_paths.append(anno_base)
626
- with open(anno, 'r') as f:
627
- sub_list = json.load(f)
628
- for i in tqdm(sub_list):
629
- path = os.path.join(sub_root, os.path.basename(i['path']))
630
- cap = i.get('cap', None)
631
- fps = i.get('fps', 0)
632
- duration = i.get('duration', 0)
633
-
634
- if fps * duration < 49.0:
635
- continue
636
-
637
- self.instance_prompts.append(cap)
638
- self.instance_video_paths.append(path)
639
-
640
- self.num_instance_videos = len(self.instance_video_paths)
641
-
642
- self.text_drop_ratio = text_drop_ratio
643
-
644
- # Video params
645
- self.sample_stride = sample_stride
646
- self.max_num_frames = max_num_frames
647
- self.height = height
648
- self.width = width
649
-
650
- def _get_frame_indices_adjusted(self, video_length, n_frames):
651
- indices = list(range(video_length))
652
- additional_frames_needed = n_frames - video_length
653
-
654
- repeat_indices = []
655
- for i in range(additional_frames_needed):
656
- index_to_repeat = i % video_length
657
- repeat_indices.append(indices[index_to_repeat])
658
-
659
- all_indices = indices + repeat_indices
660
- all_indices.sort()
661
-
662
- return all_indices
663
-
664
-
665
- def _generate_frame_indices(self, video_length, n_frames, sample_stride, skip_frames_start_percent=0.0, skip_frames_end_percent=1.0, skip_frames_start=0, skip_frames_end=0):
666
- if skip_frames_start_percent != 0.0 or skip_frames_end_percent != 1.0:
667
- print("use skip frame percent")
668
- valid_start = int(video_length * skip_frames_start_percent)
669
- valid_end = int(video_length * skip_frames_end_percent)
670
- elif skip_frames_start != 0 or skip_frames_end != 0:
671
- print("use skip frame")
672
- valid_start = skip_frames_start
673
- valid_end = video_length - skip_frames_end
674
- else:
675
- print("no use skip frame")
676
- valid_start = 0
677
- valid_end = video_length
678
-
679
- adjusted_length = valid_end - valid_start
680
-
681
- if adjusted_length <= 0:
682
- print(f"video_length: {video_length}, adjusted_length: {adjusted_length}, valid_start:{valid_start}, skip_frames_end: {valid_end}")
683
- raise ValueError("Skipping too many frames results in no frames left to sample.")
684
-
685
- if video_length <= n_frames:
686
- return self._get_frame_indices_adjusted(video_length, n_frames)
687
- else:
688
- # clip_length = min(video_length, (n_frames - 1) * sample_stride + 1)
689
- # start_idx = random.randint(0, video_length - clip_length)
690
- # frame_indices = np.linspace(start_idx, start_idx + clip_length - 1, n_frames, dtype=int).tolist()
691
-
692
- clip_length = min(adjusted_length, (n_frames - 1) * sample_stride + 1)
693
- start_idx = random.randint(valid_start, valid_end - clip_length)
694
- frame_indices = np.linspace(start_idx, start_idx + clip_length - 1, n_frames, dtype=int).tolist()
695
- return frame_indices
696
-
697
- def _short_resize_and_crop(self, frames, target_width, target_height):
698
- """
699
- Resize frames and crop to the specified size.
700
-
701
- Args:
702
- frames (torch.Tensor): Input frames of shape [T, H, W, C].
703
- target_width (int): Desired width.
704
- target_height (int): Desired height.
705
-
706
- Returns:
707
- torch.Tensor: Cropped frames of shape [T, target_height, target_width, C].
708
- """
709
- T, H, W, C = frames.shape
710
- aspect_ratio = W / H
711
-
712
- # Determine new dimensions ensuring they are at least target size
713
- if aspect_ratio > target_width / target_height:
714
- new_width = target_width
715
- new_height = int(target_width / aspect_ratio)
716
- if new_height < target_height:
717
- new_height = target_height
718
- new_width = int(target_height * aspect_ratio)
719
- else:
720
- new_height = target_height
721
- new_width = int(target_height * aspect_ratio)
722
- if new_width < target_width:
723
- new_width = target_width
724
- new_height = int(target_width / aspect_ratio)
725
-
726
- resize_transform = transforms.Resize((new_height, new_width))
727
- crop_transform = transforms.CenterCrop((target_height, target_width))
728
-
729
- frames_tensor = frames.permute(0, 3, 1, 2) # (T, H, W, C) -> (T, C, H, W)
730
- resized_frames = resize_transform(frames_tensor)
731
- cropped_frames = crop_transform(resized_frames)
732
- sample = cropped_frames.permute(0, 2, 3, 1)
733
-
734
- return sample
735
-
736
- def _resize_with_aspect_ratio(self, frames, target_width, target_height):
737
- """
738
- Resize frames while maintaining the aspect ratio by padding or cropping.
739
-
740
- Args:
741
- frames (torch.Tensor): Input frames of shape [T, H, W, C].
742
- target_width (int): Desired width.
743
- target_height (int): Desired height.
744
-
745
- Returns:
746
- torch.Tensor: Resized and padded frames of shape [T, target_height, target_width, C].
747
- """
748
- T, frame_height, frame_width, C = frames.shape
749
- aspect_ratio = frame_width / frame_height # 1.77, 1280 720 -> 720 406
750
- target_aspect_ratio = target_width / target_height # 1.50, 720 480 ->
751
-
752
- # If the frame is wider than the target, resize based on width
753
- if aspect_ratio > target_aspect_ratio:
754
- new_width = target_width
755
- new_height = int(target_width / aspect_ratio)
756
- else:
757
- new_height = target_height
758
- new_width = int(target_height * aspect_ratio)
759
-
760
- # Resize using batch processing
761
- frames = frames.permute(0, 3, 1, 2) # [T, C, H, W]
762
- frames = F.interpolate(frames, size=(new_height, new_width), mode='bilinear', align_corners=False)
763
-
764
- # Calculate padding
765
- pad_top = (target_height - new_height) // 2
766
- pad_bottom = target_height - new_height - pad_top
767
- pad_left = (target_width - new_width) // 2
768
- pad_right = target_width - new_width - pad_left
769
-
770
- # Apply padding
771
- frames = F.pad(frames, (pad_left, pad_right, pad_top, pad_bottom), mode='constant', value=0)
772
-
773
- frames = frames.permute(0, 2, 3, 1) # [T, H, W, C]
774
-
775
- return frames
776
-
777
-
778
- def _save_frame(self, frame, name="1.png"):
779
- # [H, W, C] -> [C, H, W]
780
- img = frame
781
- img = img.permute(2, 0, 1)
782
- to_pil = ToPILImage()
783
- img = to_pil(img)
784
- img.save(name)
785
-
786
-
787
- def _save_video(self, torch_frames, name="output.mp4"):
788
- from moviepy.editor import ImageSequenceClip
789
- frames_np = torch_frames.cpu().numpy()
790
- if frames_np.dtype != 'uint8':
791
- frames_np = frames_np.astype('uint8')
792
- frames_list = [frame for frame in frames_np]
793
- desired_fps = 24
794
- clip = ImageSequenceClip(frames_list, fps=desired_fps)
795
- clip.write_videofile(name, codec="libx264")
796
-
797
-
798
- def get_batch(self, idx):
799
- decord.bridge.set_bridge("torch")
800
-
801
- video_dir = self.instance_video_paths[idx]
802
- text = self.instance_prompts[idx]
803
-
804
- train_transforms = transforms.Compose(
805
- [
806
- transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0),
807
- ]
808
- )
809
-
810
- with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
811
- video_num_frames = len(video_reader)
812
-
813
- if self.is_train_face:
814
- reserve_face_imgs = None
815
- file_base_name = os.path.basename(video_dir.replace(".mp4", ""))
816
-
817
- anno_base_path = self.instance_annotation_base_paths[idx]
818
- valid_frame_path = os.path.join(anno_base_path, "track_masks_data", file_base_name, "valid_frame.json")
819
- control_sam2_frame_path = os.path.join(anno_base_path, "track_masks_data", file_base_name, "control_sam2_frame.json")
820
- corresponding_data_path = os.path.join(anno_base_path, "track_masks_data", file_base_name, "corresponding_data.json")
821
- masks_data_path = os.path.join(anno_base_path, "track_masks_data", file_base_name, "tracking_mask_results")
822
- bboxs_data_path = os.path.join(anno_base_path, "refine_bbox_jsons", f"{file_base_name}.json")
823
-
824
- with open(corresponding_data_path, 'r') as f:
825
- corresponding_data = json.load(f)
826
-
827
- with open(control_sam2_frame_path, 'r') as f:
828
- control_sam2_frame = json.load(f)
829
-
830
- with open(valid_frame_path, 'r') as f:
831
- valid_frame = json.load(f)
832
-
833
- with open(bboxs_data_path, 'r') as f:
834
- bbox_data = json.load(f)
835
-
836
- if self.is_single_face:
837
- if len(corresponding_data) != 1:
838
- raise ValueError(f"Using single face, but {idx} is multi person.")
839
-
840
- # get random valid id
841
- valid_ids = []
842
- backup_ids = []
843
- for id_key, data in corresponding_data.items():
844
- if 'face' in data and 'head' in data:
845
- valid_ids.append(id_key)
846
-
847
- valid_id = random.choice(valid_ids) if valid_ids else (random.choice(backup_ids) if backup_ids else None)
848
- if valid_id is None:
849
- raise ValueError("No valid ID found: both valid_ids and backup_ids are empty.")
850
-
851
- # get video
852
- total_index = list(range(video_num_frames))
853
- batch_index, _ = generate_frame_indices_for_face(self.max_num_frames, self.sample_stride, valid_frame[valid_id],
854
- self.miss_tolerance, self.skip_frames_start_percent, self.skip_frames_end_percent,
855
- self.skip_frames_start, self.skip_frames_end)
856
-
857
- if self.is_cross_face:
858
- remaining_batch_index_index = [i for i in total_index if i not in batch_index]
859
- try:
860
- selected_frame_index, selected_masks_dict, selected_bboxs_dict, dense_masks_dict = select_mask_frames_from_index(
861
- remaining_batch_index_index,
862
- batch_index, valid_id,
863
- corresponding_data, control_sam2_frame,
864
- valid_frame[valid_id], bbox_data, masks_data_path,
865
- min_distance=self.min_distance, min_frames=self.min_frames,
866
- max_frames=self.max_frames, dense_masks=True,
867
- ensure_control_frame=False,
868
- )
869
- except:
870
- selected_frame_index, selected_masks_dict, selected_bboxs_dict, dense_masks_dict = select_mask_frames_from_index(
871
- batch_index,
872
- batch_index, valid_id,
873
- corresponding_data, control_sam2_frame,
874
- valid_frame[valid_id], bbox_data, masks_data_path,
875
- min_distance=self.min_distance, min_frames=self.min_frames,
876
- max_frames=self.max_frames, dense_masks=True,
877
- ensure_control_frame=False,
878
- )
879
- else:
880
- selected_frame_index, selected_masks_dict, selected_bboxs_dict, dense_masks_dict = select_mask_frames_from_index(
881
- batch_index,
882
- batch_index, valid_id,
883
- corresponding_data, control_sam2_frame,
884
- valid_frame[valid_id], bbox_data, masks_data_path,
885
- min_distance=self.min_distance, min_frames=self.min_frames,
886
- max_frames=self.max_frames, dense_masks=True,
887
- ensure_control_frame=True,
888
- )
889
- if self.is_reserve_face:
890
- reserve_frame_index, _, reserve_bboxs_dict, _ = select_mask_frames_from_index(
891
- batch_index,
892
- batch_index, valid_id,
893
- corresponding_data, control_sam2_frame,
894
- valid_frame[valid_id], bbox_data, masks_data_path,
895
- min_distance=3, min_frames=4,
896
- max_frames=4, dense_masks=False,
897
- ensure_control_frame=False,
898
- )
899
-
900
- # get mask and aligned_face_img
901
- selected_frame_index = selected_frame_index[valid_id]
902
- valid_frame = valid_frame[valid_id]
903
- selected_masks_dict = selected_masks_dict[valid_id]
904
- selected_bboxs_dict = selected_bboxs_dict[valid_id]
905
- dense_masks_dict = dense_masks_dict[valid_id]
906
-
907
- if self.is_reserve_face:
908
- reserve_frame_index = reserve_frame_index[valid_id]
909
- reserve_bboxs_dict = reserve_bboxs_dict[valid_id]
910
-
911
- selected_masks_tensor = torch.stack([torch.tensor(mask) for mask in selected_masks_dict])
912
- temp_dense_masks_tensor = torch.stack([torch.tensor(mask) for mask in dense_masks_dict])
913
- dense_masks_tensor = self._short_resize_and_crop(temp_dense_masks_tensor.unsqueeze(-1), self.width, self.height).squeeze(-1) # [T, H, W] -> [T, H, W, 1] -> [T, H, W]
914
-
915
- expand_images_pil, original_images_pil = crop_images(selected_frame_index, selected_bboxs_dict, video_reader, return_ori=True)
916
- expand_face_imgs, original_face_imgs = process_cropped_images(expand_images_pil, original_images_pil, target_size=(480, 480))
917
- if self.is_reserve_face:
918
- reserve_images_pil, _ = crop_images(reserve_frame_index, reserve_bboxs_dict, video_reader, return_ori=False)
919
- reserve_face_imgs, _ = process_cropped_images(reserve_images_pil, [], target_size=(480, 480))
920
-
921
- if len(expand_face_imgs) == 0 or len(original_face_imgs) == 0:
922
- raise ValueError(f"No face detected in input image pool")
923
-
924
- # post process id related data
925
- expand_face_imgs = pad_tensor(expand_face_imgs, self.max_frames, dim=0)
926
- original_face_imgs = pad_tensor(original_face_imgs, self.max_frames, dim=0)
927
- selected_frame_index = torch.tensor(selected_frame_index) # torch.Size(([15, 13]) [N1]
928
- selected_frame_index = pad_tensor(selected_frame_index, self.max_frames, dim=0)
929
- else:
930
- batch_index = self._generate_frame_indices(video_num_frames, self.max_num_frames, self.sample_stride,
931
- self.skip_frames_start_percent, self.skip_frames_end_percent,
932
- self.skip_frames_start, self.skip_frames_end)
933
-
934
- try:
935
- frames = video_reader.get_batch(batch_index) # torch [T, H, W, C]
936
- frames = self._short_resize_and_crop(frames, self.width, self.height) # [T, H, W, C]
937
- except FunctionTimedOut:
938
- raise ValueError(f"Read {idx} timeout.")
939
- except Exception as e:
940
- raise ValueError(f"Failed to extract frames from video. Error is {e}.")
941
-
942
- # Apply training transforms in batch
943
- frames = frames.float()
944
- frames = train_transforms(frames)
945
- pixel_values = frames.permute(0, 3, 1, 2).contiguous() # [T, C, H, W]
946
- del video_reader
947
-
948
- # Random use no text generation
949
- if random.random() < self.text_drop_ratio:
950
- text = ''
951
-
952
- if self.is_train_face:
953
- return pixel_values, text, 'video', video_dir, expand_face_imgs, dense_masks_tensor, selected_frame_index, reserve_face_imgs, original_face_imgs
954
- else:
955
- return pixel_values, text, 'video', video_dir
956
-
957
- def __len__(self):
958
- return self.num_instance_videos
959
-
960
- def __getitem__(self, idx):
961
- sample = {}
962
- if self.is_train_face:
963
- pixel_values, cap, data_type, video_dir, expand_face_imgs, dense_masks_tensor, selected_frame_index, reserve_face_imgs, original_face_imgs = self.get_batch(idx)
964
- sample["instance_prompt"] = self.id_token + cap
965
- sample["instance_video"] = pixel_values
966
- sample["video_path"] = video_dir
967
- if self.is_train_face:
968
- sample["expand_face_imgs"] = expand_face_imgs
969
- sample["dense_masks_tensor"] = dense_masks_tensor
970
- sample["selected_frame_index"] = selected_frame_index
971
- if reserve_face_imgs is not None:
972
- sample["reserve_face_imgs"] = reserve_face_imgs
973
- if original_face_imgs is not None:
974
- sample["original_face_imgs"] = original_face_imgs
975
- else:
976
- pixel_values, cap, data_type, video_dir = self.get_batch(idx)
977
- sample["instance_prompt"] = self.id_token + cap
978
- sample["instance_video"] = pixel_values
979
- sample["video_path"] = video_dir
980
- return sample
981
-
982
- # while True:
983
- # sample = {}
984
- # try:
985
- # if self.is_train_face:
986
- # pixel_values, cap, data_type, video_dir, expand_face_imgs, dense_masks_tensor, selected_frame_index, reserve_face_imgs, original_face_imgs = self.get_batch(idx)
987
- # sample["instance_prompt"] = self.id_token + cap
988
- # sample["instance_video"] = pixel_values
989
- # sample["video_path"] = video_dir
990
- # if self.is_train_face:
991
- # sample["expand_face_imgs"] = expand_face_imgs
992
- # sample["dense_masks_tensor"] = dense_masks_tensor
993
- # sample["selected_frame_index"] = selected_frame_index
994
- # if reserve_face_imgs is not None:
995
- # sample["reserve_face_imgs"] = reserve_face_imgs
996
- # if original_face_imgs is not None:
997
- # sample["original_face_imgs"] = original_face_imgs
998
- # else:
999
- # pixel_values, cap, data_type, video_dir, = self.get_batch(idx)
1000
- # sample["instance_prompt"] = self.id_token + cap
1001
- # sample["instance_video"] = pixel_values
1002
- # sample["video_path"] = video_dir
1003
- # break
1004
- # except Exception as e:
1005
- # error_message = str(e)
1006
- # video_path = self.instance_video_paths[idx % len(self.instance_video_paths)]
1007
- # print(error_message, video_path)
1008
- # log_error_to_file(error_message, video_path)
1009
- # idx = random.randint(0, self.num_instance_videos - 1)
1010
- # return sample
 
util/deepspeed_configs/accelerate_config_machine_multi.yaml DELETED
@@ -1,18 +0,0 @@
1
- compute_environment: LOCAL_MACHINE
2
- distributed_type: DEEPSPEED
3
- deepspeed_config:
4
- deepspeed_config_file: util/deepspeed_configs/zero_stage2_config.json
5
- deepspeed_hostfile: util/deepspeed_configs/hostfile.txt
6
- fsdp_config: {}
7
- machine_rank: 0
8
- main_process_ip: 100.64.24.6
9
- main_process_port: 12343
10
- main_training_function: main
11
- num_machines: 2
12
- num_processes: 16
13
- rdzv_backend: static
14
- same_network: true
15
- tpu_env: []
16
- tpu_use_cluster: false
17
- tpu_use_sudo: false
18
- use_cpu: false
 
util/deepspeed_configs/accelerate_config_machine_single.yaml DELETED
@@ -1,13 +0,0 @@
1
- compute_environment: LOCAL_MACHINE
2
- distributed_type: DEEPSPEED
3
- deepspeed_config:
4
- deepspeed_config_file: util/deepspeed_configs/zero_stage2_config.json
5
- fsdp_config: {}
6
- machine_rank: 0
7
- main_process_ip: null
8
- main_process_port: 12345
9
- main_training_function: main
10
- num_machines: 1
11
- num_processes: 8
12
- gpu_ids: 0,1,2,3,4,5,6,7
13
- use_cpu: false
 
util/deepspeed_configs/hostfile.txt DELETED
@@ -1,2 +0,0 @@
1
- [email protected] slots=8
2
- [email protected] slots=8
 
util/deepspeed_configs/zero_stage2_config.json DELETED
@@ -1,17 +0,0 @@
1
- {
2
- "bf16": {
3
- "enabled": true
4
- },
5
- "train_micro_batch_size_per_gpu": "auto",
6
- "train_batch_size": "auto",
7
- "gradient_clipping": 1.0,
8
- "gradient_accumulation_steps": "auto",
9
- "dump_state": true,
10
- "zero_optimization": {
11
- "stage": 2,
12
- "overlap_comm": true,
13
- "contiguous_gradients": true,
14
- "sub_group_size": 1e9,
15
- "reduce_bucket_size": 5e8
16
- }
17
- }