smile123456789 committed
Commit b65930c · Parent(s): c7b92cf

reorganize code
- .gitattributes +2 -0
- app.py +27 -94
- asserts/example_images/4.png +0 -0
- models/local_facial_extractor.py +75 -35
- models/pipeline_consisid.py +65 -36
- models/transformer_consisid.py +70 -35
- models/utils.py +102 -12
- requirements.txt +1 -1
- util/dataloader.py +0 -1010
- util/deepspeed_configs/accelerate_config_machine_multi.yaml +0 -18
- util/deepspeed_configs/accelerate_config_machine_single.yaml +0 -13
- util/deepspeed_configs/hostfile.txt +0 -2
- util/deepspeed_configs/zero_stage2_config.json +0 -17
.gitattributes
CHANGED
@@ -1,3 +1,5 @@
+__pycache__/
+
 *.7z filter=lfs diff=lfs merge=lfs -text
 *.arrow filter=lfs diff=lfs merge=lfs -text
 *.bin filter=lfs diff=lfs merge=lfs -text
app.py
CHANGED
@@ -1,37 +1,26 @@
 import os
 import math
 import time
-import numpy
 import spaces
 import random
 import threading
 import gradio as gr
-from PIL import Image, ImageOps
 from moviepy import VideoFileClip
 from datetime import datetime, timedelta
 from huggingface_hub import hf_hub_download, snapshot_download
 
-import insightface
-from insightface.app import FaceAnalysis
-from facexlib.parsing import init_parsing_model
-from facexlib.utils.face_restoration_helper import FaceRestoreHelper
-
 import torch
-from diffusers import CogVideoXDPMScheduler
-from diffusers.utils import load_image
 from diffusers.image_processor import VaeImageProcessor
 from diffusers.training_utils import free_memory
 
 from util.utils import *
 from util.rife_model import load_rife_model, rife_inference_with_latents
-from models.utils import
+from models.utils import process_face_embeddings_infer, prepare_face_models
 from models.transformer_consisid import ConsisIDTransformer3DModel
 from models.pipeline_consisid import ConsisIDPipeline
-from models.eva_clip import create_model_and_transforms
-from models.eva_clip.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
-from models.eva_clip.utils_qformer import resize_numpy_image_long
 
 
+# 0. Pre config
 model_path = "ckpts"
 
 lora_path = None
@@ -51,72 +40,30 @@ if os.path.exists(os.path.join(model_path, "transformer_ema")):
     subfolder = "transformer_ema"
 else:
     subfolder = "transformer"
-
-transformer = ConsisIDTransformer3DModel.from_pretrained_cus(model_path, subfolder=subfolder)
-scheduler = CogVideoXDPMScheduler.from_pretrained(model_path, subfolder="scheduler")
 
-
-
-
-
-
-#
-
-    upscale_factor=1,
-    face_size=512,
-    crop_ratio=(1, 1),
-    det_model='retinaface_resnet50',
-    save_ext='png',
-    device=device,
-    model_rootpath=os.path.join(model_path, "face_encoder")
-)
-face_helper.face_parse = None
-face_helper.face_parse = init_parsing_model(model_name='bisenet', device=device, model_rootpath=os.path.join(model_path, "face_encoder"))
-face_helper.face_det.eval()
-face_helper.face_parse.eval()
-
-model, _, _ = create_model_and_transforms('EVA02-CLIP-L-14-336', os.path.join(model_path, "face_encoder", "EVA02_CLIP_L_336_psz14_s6B.pt"), force_custom_clip=True)
-face_clip_model = model.visual
-face_clip_model.eval()
-
-eva_transform_mean = getattr(face_clip_model, 'image_mean', OPENAI_DATASET_MEAN)
-eva_transform_std = getattr(face_clip_model, 'image_std', OPENAI_DATASET_STD)
-if not isinstance(eva_transform_mean, (list, tuple)):
-    eva_transform_mean = (eva_transform_mean,) * 3
-if not isinstance(eva_transform_std, (list, tuple)):
-    eva_transform_std = (eva_transform_std,) * 3
-eva_transform_mean = eva_transform_mean
-eva_transform_std = eva_transform_std
-
-face_main_model = FaceAnalysis(name='antelopev2', root=os.path.join(model_path, "face_encoder"), providers=['CUDAExecutionProvider'])
-handler_ante = insightface.model_zoo.get_model(f'{model_path}/face_encoder/models/antelopev2/glintr100.onnx', providers=['CUDAExecutionProvider'])
-face_main_model.prepare(ctx_id=0, det_size=(640, 640))
-handler_ante.prepare(ctx_id=0)
-
-face_clip_model.to(device, dtype=dtype)
-face_helper.face_det.to(device)
-face_helper.face_parse.to(device)
+
+# 1. Prepare all the face models
+face_helper_1, face_helper_2, face_clip_model, face_main_model, eva_transform_mean, eva_transform_std = prepare_face_models(model_path, device, dtype)
+
+
+# 2. Load Pipeline.
+transformer = ConsisIDTransformer3DModel.from_pretrained_cus(model_path, subfolder=subfolder)
 transformer.to(device, dtype=dtype)
-
+pipe = ConsisIDPipeline.from_pretrained(model_path, transformer=transformer, torch_dtype=dtype)
 
-pipe = ConsisIDPipeline.from_pretrained(model_path, transformer=transformer, scheduler=scheduler, torch_dtype=dtype)
 # If you're using with lora, add this code
 if lora_path:
     pipe.load_lora_weights(lora_path, weight_name="pytorch_lora_weights.safetensors", adapter_name="test_1")
     pipe.fuse_lora(lora_scale=1 / lora_rank)
 
-scheduler_args = {}
-if "variance_type" in pipe.scheduler.config:
-    variance_type = pipe.scheduler.config.variance_type
-    if variance_type in ["learned", "learned_range"]:
-        variance_type = "fixed_small"
-    scheduler_args["variance_type"] = variance_type
 
-
+# 3. Move to device.
+face_helper_1.face_det.to(device)
+face_helper_1.face_parse.to(device)
+face_clip_model.to(device, dtype=dtype)
+transformer.to(device, dtype=dtype)
 pipe.to(device)
-
-# Enable CPU offload for the model.
-# turn on if you don't have multiple GPUs or enough GPU memory(such as H100) and it will cost more time in inference, it may also reduce the quality
+# Save Memory. Turn on if you don't have multiple GPUs or enough GPU memory(such as H100) and it will cost more time in inference, it may also reduce the quality
 pipe.enable_model_cpu_offload()
 pipe.enable_sequential_cpu_offload()
 # pipe.vae.enable_slicing()
@@ -125,6 +72,7 @@ pipe.enable_sequential_cpu_offload()
 os.makedirs("./output", exist_ok=True)
 os.makedirs("./gradio_tmp", exist_ok=True)
 
+# load upscale and interpolation model
 upscale_model = load_sd_upscale(f"{model_path}/model_real_esran/RealESRGAN_x4.pth", device)
 frame_interpolation_model = load_rife_model(f"{model_path}/model_rife")
 
@@ -142,34 +90,21 @@ def generate(
     if seed == -1:
         seed = random.randint(0, 2**8 - 1)
 
-
-
-    id_cond, id_vit_hidden, align_crop_face_image, face_kps = process_face_embeddings(face_helper, face_clip_model, handler_ante,
+    # 4. Prepare model input
+    id_cond, id_vit_hidden, image, face_kps = process_face_embeddings_infer(face_helper_1, face_clip_model, face_helper_2,
                                                                             eva_transform_mean, eva_transform_std,
-                                                                            face_main_model, device, dtype,
-
-
-
-    if is_kps
-        kps_cond = face_kps
-    else:
-        kps_cond = None
-
-    tensor = align_crop_face_image.cpu().detach()
-    tensor = tensor.squeeze()
-    tensor = tensor.permute(1, 2, 0)
-    tensor = tensor.numpy() * 255
-    tensor = tensor.astype(np.uint8)
-    image = ImageOps.exif_transpose(Image.fromarray(tensor))
+                                                                            face_main_model, device, dtype,
+                                                                            image_input, is_align_face=True)
+
+    is_kps = getattr(transformer.config, 'is_kps', False)
+    kps_cond = face_kps if is_kps else None
 
     prompt = prompt.strip('"')
-    if len(negative_prompt) == 0:
-        negative_prompt = None
     if negative_prompt:
         negative_prompt = negative_prompt.strip('"')
 
-
-
+    # 5. Generate Identity-Preserving Video
+    generator = torch.Generator(device).manual_seed(seed) if seed else None
     video_pt = pipe(
         prompt=prompt,
         negative_prompt=negative_prompt,
@@ -388,8 +323,6 @@ with gr.Blocks() as demo:
         seed_update = gr.update(visible=True, value=seed)
 
         return video_path, video_update, gif_update, seed_update
-
-    run.zerogpu = True
 
     generate_button.click(
         fn=run,
@@ -400,4 +333,4 @@ with gr.Blocks() as demo:
 
 if __name__ == "__main__":
     demo.queue(max_size=15)
-    demo.launch()
+    demo.launch()
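The reorganized app.py collapses all face-model setup into prepare_face_models and the per-request embedding extraction into process_face_embeddings_infer. Below is a minimal sketch (not part of the commit) of how those pieces fit together, using only names that appear in the diff above; the checkpoint path, device, prompt, image path, and the exact pipe() keyword usage are illustrative assumptions.

```py
import torch

from models.utils import prepare_face_models, process_face_embeddings_infer
from models.transformer_consisid import ConsisIDTransformer3DModel
from models.pipeline_consisid import ConsisIDPipeline

model_path = "ckpts"        # assumed checkpoint layout, as in app.py
device = "cuda"
dtype = torch.bfloat16

# 1. Face models (detector, parser, CLIP encoder, ArcFace) now come from one helper.
face_helper_1, face_helper_2, face_clip_model, face_main_model, eva_transform_mean, eva_transform_std = \
    prepare_face_models(model_path, device, dtype)

# 2. Pipeline: the scheduler is no longer loaded separately; from_pretrained picks it up.
transformer = ConsisIDTransformer3DModel.from_pretrained_cus(model_path, subfolder="transformer")
transformer.to(device, dtype=dtype)
pipe = ConsisIDPipeline.from_pretrained(model_path, transformer=transformer, torch_dtype=dtype)
pipe.to(device)

# 3. Per request: extract identity embeddings from a reference image, then sample.
id_cond, id_vit_hidden, image, face_kps = process_face_embeddings_infer(
    face_helper_1, face_clip_model, face_helper_2,
    eva_transform_mean, eva_transform_std,
    face_main_model, device, dtype,
    "asserts/example_images/1.png",   # hypothetical reference image
    is_align_face=True,
)
kps_cond = face_kps if getattr(transformer.config, "is_kps", False) else None
video = pipe(image, "A person smiles at the camera.",
             id_vit_hidden=id_vit_hidden, id_cond=id_cond, kps_cond=kps_cond)
```

The Space additionally loads a RealESRGAN upscaler and a RIFE frame-interpolation model for post-processing the generated frames, as shown in the diff.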
asserts/example_images/4.png
ADDED
models/local_facial_extractor.py
CHANGED
@@ -4,7 +4,18 @@ import torch.nn as nn
 
 
 # FFN
-def FeedForward(dim, mult=4):
+def ConsisIDFeedForward(dim, mult=4):
+    """
+    Creates a consistent ID feedforward block consisting of layer normalization,
+    two linear layers, and a GELU activation.
+
+    Args:
+        dim (int): The input dimension of the tensor.
+        mult (int, optional): Multiplier for the inner dimension. Default is 4.
+
+    Returns:
+        nn.Sequential: A sequence of layers comprising LayerNorm, Linear layers, and GELU.
+    """
     inner_dim = int(dim * mult)
     return nn.Sequential(
         nn.LayerNorm(dim),
@@ -15,20 +26,41 @@ def FeedForward(dim, mult=4):
 
 
 def reshape_tensor(x, heads):
+    """
+    Reshapes the input tensor for multi-head attention.
+
+    Args:
+        x (torch.Tensor): The input tensor with shape (batch_size, length, width).
+        heads (int): The number of attention heads.
+
+    Returns:
+        torch.Tensor: The reshaped tensor, with shape (batch_size, heads, length, width).
+    """
     bs, length, width = x.shape
-    # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
     x = x.view(bs, length, heads, -1)
-    # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
     x = x.transpose(1, 2)
-    # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
    x = x.reshape(bs, heads, length, -1)
     return x
 
 
 class PerceiverAttention(nn.Module):
+    """
+    Implements the Perceiver attention mechanism with multi-head attention.
+
+    This layer takes two inputs: 'x' (image features) and 'latents' (latent features),
+    applying multi-head attention to both and producing an output tensor with the same
+    dimension as the input tensor 'x'.
+
+    Args:
+        dim (int): The input dimension.
+        dim_head (int, optional): The dimension of each attention head. Default is 64.
+        heads (int, optional): The number of attention heads. Default is 8.
+        kv_dim (int, optional): The key-value dimension. If None, `dim` is used for both keys and values.
+    """
+
     def __init__(self, *, dim, dim_head=64, heads=8, kv_dim=None):
         super().__init__()
-        self.scale = dim_head
+        self.scale = dim_head**-0.5
         self.dim_head = dim_head
         self.heads = heads
         inner_dim = dim_head * heads
@@ -42,21 +74,27 @@ class PerceiverAttention(nn.Module):
 
     def forward(self, x, latents):
         """
+        Forward pass for Perceiver attention.
+
         Args:
-            x (torch.Tensor):
-
-
-
+            x (torch.Tensor): Image features tensor with shape (batch_size, num_pixels, D).
+            latents (torch.Tensor): Latent features tensor with shape (batch_size, num_latents, D).
+
+        Returns:
+            torch.Tensor: Output tensor after applying attention and transformation.
         """
+        # Apply normalization
         x = self.norm1(x)
         latents = self.norm2(latents)
 
-        b, seq_len, _ = latents.shape
+        b, seq_len, _ = latents.shape  # Get batch size and sequence length
 
+        # Compute query, key, and value matrices
         q = self.to_q(latents)
         kv_input = torch.cat((x, latents), dim=-2)
         k, v = self.to_kv(kv_input).chunk(2, dim=-1)
 
+        # Reshape the tensors for multi-head attention
        q = reshape_tensor(q, self.heads)
         k = reshape_tensor(k, self.heads)
         v = reshape_tensor(v, self.heads)
@@ -67,6 +105,7 @@ class PerceiverAttention(nn.Module):
         weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
         out = weight @ v
 
+        # Reshape and return the final output
         out = out.permute(0, 2, 1, 3).reshape(b, seq_len, -1)
 
         return self.to_out(out)
@@ -74,22 +113,22 @@ class PerceiverAttention(nn.Module):
 
 class LocalFacialExtractor(nn.Module):
     def __init__(
-
-
-
-
-
-
-
-
-
+        self,
+        dim=1024,
+        depth=10,
+        dim_head=64,
+        heads=16,
+        num_id_token=5,
+        num_queries=32,
+        output_dim=2048,
+        ff_mult=4,
     ):
         """
         Initializes the LocalFacialExtractor class.
 
         Parameters:
         - dim (int): The dimensionality of latent features.
-        - depth (int): Total number of PerceiverAttention and
+        - depth (int): Total number of PerceiverAttention and ConsisIDFeedForward layers.
         - dim_head (int): Dimensionality of each attention head.
         - heads (int): Number of attention heads.
        - num_id_token (int): Number of tokens used for identity features.
@@ -105,21 +144,21 @@ class LocalFacialExtractor(nn.Module):
         self.num_queries = num_queries
         assert depth % 5 == 0
         self.depth = depth // 5
-        scale = dim
+        scale = dim**-0.5
 
         # Learnable latent query embeddings
         self.latents = nn.Parameter(torch.randn(1, num_queries, dim) * scale)
         # Projection layer to map the latent output to the desired dimension
         self.proj_out = nn.Parameter(scale * torch.randn(dim, output_dim))
 
-        # Attention and
+        # Attention and ConsisIDFeedForward layer stack
         self.layers = nn.ModuleList([])
         for _ in range(depth):
             self.layers.append(
                 nn.ModuleList(
                     [
                         PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),  # Perceiver Attention layer
-
+                        ConsisIDFeedForward(dim=dim, mult=ff_mult),  # ConsisIDFeedForward layer
                     ]
                 )
             )
@@ -128,7 +167,7 @@ class LocalFacialExtractor(nn.Module):
         for i in range(5):
             setattr(
                 self,
-                f
+                f"mapping_{i}",
                 nn.Sequential(
                     nn.Linear(1024, 1024),
                     nn.LayerNorm(1024),
@@ -175,30 +214,30 @@ class LocalFacialExtractor(nn.Module):
 
         # Process each of the 5 visual feature inputs
         for i in range(5):
-            vit_feature = getattr(self, f
+            vit_feature = getattr(self, f"mapping_{i}")(y[i])
             ctx_feature = torch.cat((x, vit_feature), dim=1)
 
-            # Pass through the PerceiverAttention and
-            for attn, ff in self.layers[i * self.depth: (i + 1) * self.depth]:
+            # Pass through the PerceiverAttention and ConsisIDFeedForward layers
+            for attn, ff in self.layers[i * self.depth : (i + 1) * self.depth]:
                 latents = attn(ctx_feature, latents) + latents
                 latents = ff(latents) + latents
 
         # Retain only the query latents
-        latents = latents[:, :self.num_queries]
+        latents = latents[:, : self.num_queries]
         # Project the latents to the output dimension
         latents = latents @ self.proj_out
         return latents
-
+
 
 class PerceiverCrossAttention(nn.Module):
     """
-
+
     Args:
         dim (int): Dimension of the input latent and output. Default is 3072.
         dim_head (int): Dimension of each attention head. Default is 128.
         heads (int): Number of attention heads. Default is 16.
         kv_dim (int): Dimension of the key/value input, allowing flexible cross-attention. Default is 2048.
-
+
     Attributes:
         scale (float): Scaling factor used in dot-product attention for numerical stability.
         norm1 (nn.LayerNorm): Layer normalization applied to the input image features.
@@ -208,9 +247,10 @@ class PerceiverCrossAttention(nn.Module):
        to_out (nn.Linear): Linear layer for outputting the final result after attention.
 
    """
+
    def __init__(self, *, dim=3072, dim_head=128, heads=16, kv_dim=2048):
        super().__init__()
-        self.scale = dim_head
+        self.scale = dim_head**-0.5
        self.dim_head = dim_head
        self.heads = heads
        inner_dim = dim_head * heads
@@ -232,13 +272,13 @@ class PerceiverCrossAttention(nn.Module):
                - batch_size (b): Number of samples in the batch.
                - n1: Sequence length (e.g., number of patches or tokens).
                - D: Feature dimension.
-
+
            latents (torch.Tensor): Latent feature representations with shape (batch_size, n2, D), where:
                - n2: Number of latent elements.
-
+
        Returns:
            torch.Tensor: Attention-modulated features with shape (batch_size, n2, D).
-
+
        """
        # Apply layer normalization to the input image and latent features
        x = self.norm1(x)
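The new docstrings in local_facial_extractor.py describe how tokens are reshaped for multi-head attention and how the attention scale is set. Here is a small, self-contained sketch (not from the repository) of that reshape logic together with the 1/sqrt(dim_head) scale the new code uses; the tensor sizes are arbitrary.

```py
import torch


def reshape_tensor(x, heads):
    # Same logic as in models/local_facial_extractor.py: split the channel dim into
    # heads and move the head dim ahead of the sequence dim.
    bs, length, width = x.shape
    x = x.view(bs, length, heads, -1)
    x = x.transpose(1, 2)
    x = x.reshape(bs, heads, length, -1)
    return x


q = torch.randn(2, 32, 1024)           # (batch, num_latents, dim); illustrative sizes
q_heads = reshape_tensor(q, heads=16)  # -> torch.Size([2, 16, 32, 64])
print(q_heads.shape)

dim_head = q_heads.shape[-1]
scale = dim_head**-0.5                 # standard 1/sqrt(d) attention scaling used by the new code
print(scale)
```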
models/pipeline_consisid.py
CHANGED
@@ -1,8 +1,16 @@
-# Copyright
-#
-
-#
-#
+# Copyright 2024 ConsisID Authors and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 
 import inspect
 import math
@@ -13,20 +21,19 @@ import sys
 import PIL
 import numpy as np
 import cv2
-from PIL import Image
 import torch
+from dataclasses import dataclass
 from transformers import T5EncoderModel, T5Tokenizer
 
 from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
 from diffusers.image_processor import PipelineImageInput
-from diffusers.models import AutoencoderKLCogVideoX
+from diffusers.models import AutoencoderKLCogVideoX
 from diffusers.models.embeddings import get_3d_rotary_pos_embed
 from diffusers.pipelines.pipeline_utils import DiffusionPipeline
 from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
-from diffusers.utils import logging, replace_example_docstring
+from diffusers.utils import logging, replace_example_docstring, BaseOutput
 from diffusers.utils.torch_utils import randn_tensor
 from diffusers.video_processor import VideoProcessor
-from diffusers.pipelines.cogvideo.pipeline_output import CogVideoXPipelineOutput
 
 from models.transformer_consisid import ConsisIDTransformer3DModel
 
@@ -37,26 +44,28 @@ for project_root in project_roots:
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
+
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
         >>> import torch
-        >>> from diffusers import
+        >>> from diffusers import ConsisIDPipeline
         >>> from diffusers.utils import export_to_video, load_image
 
-        >>> pipe =
+        >>> pipe = ConsisIDPipeline.from_pretrained("https://huggingface.co/BestWishYsh/ConsisID-preview", torch_dtype=torch.bfloat16)
         >>> pipe.to("cuda")
 
-        >>> prompt = "
+        >>> prompt = "A woman adorned with a delicate flower crown, is standing amidst a field of gently swaying wildflowers. Her eyes sparkle with a serene gaze, and a faint smile graces her lips, suggesting a moment of peaceful contentment. The shot is framed from the waist up, highlighting the gentle breeze lightly tousling her hair. The background reveals an expansive meadow under a bright blue sky, capturing the tranquility of a sunny afternoon."
         >>> image = load_image(
-        ...     "https://
+        ...     "https://github.com/PKU-YuanGroup/ConsisID/blob/main/asserts/example_images/1.png?raw=true"
         ... )
         >>> video = pipe(image, prompt, use_dynamic_cfg=True)
         >>> export_to_video(video.frames[0], "output.mp4", fps=8)
         ```
 """
 
-
+
+def draw_kps(image_pil, kps, color_list=[(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255)]):
     stickwidth = 4
     limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]])
     kps = np.array(kps)
@@ -72,7 +81,9 @@ def draw_kps(image_pil, kps, color_list=[(255,0,0), (0,255,0), (0,0,255), (255,2
         y = kps[index][:, 1]
         length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5
         angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1]))
-        polygon = cv2.ellipse2Poly(
+        polygon = cv2.ellipse2Poly(
+            (int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0, 360, 1
+        )
         out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color)
         out_img = (out_img * 0.6).astype(np.uint8)
 
@@ -81,9 +92,10 @@ def draw_kps(image_pil, kps, color_list=[(255,0,0), (0,255,0), (0,0,255), (255,2
         x, y = kp
         out_img = cv2.circle(out_img.copy(), (int(x), int(y)), 10, color, -1)
 
-    out_img_pil = Image.fromarray(out_img.astype(np.uint8))
+    out_img_pil = PIL.Image.fromarray(out_img.astype(np.uint8))
     return out_img_pil
 
+
 def process_image(image, vae):
     image_noise_sigma = torch.normal(mean=-3.0, std=0.5, size=(1,), device=image.device)
     image_noise_sigma = torch.exp(image_noise_sigma).to(dtype=image.dtype)
@@ -92,6 +104,7 @@ def process_image(image, vae):
     image_latent_dist = vae.encode(input_image).latent_dist
     return image_latent_dist
 
+
 # Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
 def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
     tw = tgt_width
@@ -185,9 +198,24 @@ def retrieve_latents(
         raise AttributeError("Could not access latents of provided encoder_output")
 
 
+@dataclass
+class ConsisIDPipelineOutput(BaseOutput):
+    r"""
+    Output class for ConsisID pipelines.
+
+    Args:
+        frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
+            List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
+            denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
+            `(batch_size, num_frames, channels, height, width)`.
+    """
+
+    frames: torch.Tensor
+
+
 class ConsisIDPipeline(DiffusionPipeline):
     r"""
-    Pipeline for image-to-video generation using
+    Pipeline for image-to-video generation using ConsisID.
 
     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
     library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
@@ -196,7 +224,7 @@ class ConsisIDPipeline(DiffusionPipeline):
         vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
        text_encoder ([`T5EncoderModel`]):
-            Frozen text-encoder.
+            Frozen text-encoder. ConsisID uses
            [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
            [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
        tokenizer (`T5Tokenizer`):
@@ -222,7 +250,7 @@ class ConsisIDPipeline(DiffusionPipeline):
         tokenizer: T5Tokenizer,
         text_encoder: T5EncoderModel,
         vae: AutoencoderKLCogVideoX,
-        transformer: Union[ConsisIDTransformer3DModel
+        transformer: Union[ConsisIDTransformer3DModel],
         scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
     ):
         super().__init__()
@@ -246,7 +274,7 @@ class ConsisIDPipeline(DiffusionPipeline):
 
         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
 
-    # Copied from diffusers.pipelines.
+    # Copied from diffusers.pipelines.consisid.pipeline_consisID.ConsisIDPipeline._get_t5_prompt_embeds
     def _get_t5_prompt_embeds(
         self,
         prompt: Union[str, List[str]] = None,
@@ -289,7 +317,7 @@ class ConsisIDPipeline(DiffusionPipeline):
 
         return prompt_embeds
 
-    # Copied from diffusers.pipelines.
+    # Copied from diffusers.pipelines.consisid.pipeline_consisid.ConsisIDPipeline.encode_prompt
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],
@@ -409,7 +437,8 @@ class ConsisIDPipeline(DiffusionPipeline):
         if kps_cond is not None:
             kps_cond = kps_cond.unsqueeze(2)
             kps_cond_latents = [
-                retrieve_latents(self.vae.encode(kps_cond[i].unsqueeze(0)), generator[i])
+                retrieve_latents(self.vae.encode(kps_cond[i].unsqueeze(0)), generator[i])
+                for i in range(batch_size)
             ]
         else:
             image_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator) for img in image]
@@ -455,7 +484,7 @@ class ConsisIDPipeline(DiffusionPipeline):
         latents = latents * self.scheduler.init_noise_sigma
         return latents, image_latents
 
-    # Copied from diffusers.pipelines.
+    # Copied from diffusers.pipelines.consisid.pipeline_consisid.ConsisIDPipeline.decode_latents
     def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
         latents = latents.permute(0, 2, 1, 3, 4)  # [batch_size, num_channels, num_frames, height, width]
         latents = 1 / self.vae_scaling_factor_image * latents
@@ -554,13 +583,13 @@ class ConsisIDPipeline(DiffusionPipeline):
                 f" {negative_prompt_embeds.shape}."
             )
 
-    # Copied from diffusers.pipelines.
+    # Copied from diffusers.pipelines.consisid.pipeline_consisid.ConsisIDPipeline.fuse_qkv_projections
     def fuse_qkv_projections(self) -> None:
         r"""Enables fused QKV projections."""
         self.fusing_transformer = True
         self.transformer.fuse_qkv_projections()
 
-    # Copied from diffusers.pipelines.
+    # Copied from diffusers.pipelines.consisid.pipeline_consisid.ConsisIDPipeline.unfuse_qkv_projections
     def unfuse_qkv_projections(self) -> None:
         r"""Disable QKV projection fusion if enabled."""
         if not self.fusing_transformer:
@@ -569,7 +598,7 @@ class ConsisIDPipeline(DiffusionPipeline):
         self.transformer.unfuse_qkv_projections()
         self.fusing_transformer = False
 
-    # Copied from diffusers.pipelines.
+    # Copied from diffusers.pipelines.consisid.pipeline_consisid.ConsisIDPipeline._prepare_rotary_positional_embeddings
     def _prepare_rotary_positional_embeddings(
         self,
         height: int,
@@ -638,7 +667,7 @@ class ConsisIDPipeline(DiffusionPipeline):
         id_vit_hidden: Optional[torch.Tensor] = None,
         id_cond: Optional[torch.Tensor] = None,
         kps_cond: Optional[torch.Tensor] = None,
-    ) -> Union[
+    ) -> Union[ConsisIDPipelineOutput, Tuple]:
         """
         Function invoked when calling the pipeline for generation.
 
@@ -658,7 +687,7 @@ class ConsisIDPipeline(DiffusionPipeline):
                The width in pixels of the generated image. This is set to 720 by default for the best results.
            num_frames (`int`, defaults to `48`):
                Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
-                contain 1 extra frame because
+                contain 1 extra frame because ConsisID is conditioned with (num_seconds * fps + 1) frames where
                num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that
                needs to be satisfied is that of divisibility mentioned above.
            num_inference_steps (`int`, *optional*, defaults to 50):
@@ -712,8 +741,8 @@ class ConsisIDPipeline(DiffusionPipeline):
        Examples:
 
        Returns:
-            [`~pipelines.
-            [`~pipelines.
+            [`~pipelines.consisid.pipeline_output.ConsisIDPipelineOutput`] or `tuple`:
+            [`~pipelines.consisid.pipeline_output.ConsisIDPipelineOutput`] if `return_dict` is True, otherwise a
            `tuple`. When returning a tuple, the first element is a list with the generated images.
        """
        if num_frames > 49:
@@ -784,7 +813,7 @@ class ConsisIDPipeline(DiffusionPipeline):
         image = self.video_processor.preprocess(image, height=height, width=width).to(
             device, dtype=prompt_embeds.dtype
         )
-
+
         latent_channels = self.transformer.config.in_channels // 2
         latents, image_latents = self.prepare_latents(
             image,
@@ -797,9 +826,9 @@ class ConsisIDPipeline(DiffusionPipeline):
             device,
             generator,
             latents,
-            kps_cond
+            kps_cond,
         )
-
+
         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
@@ -836,8 +865,8 @@ class ConsisIDPipeline(DiffusionPipeline):
                     timestep=timestep,
                     image_rotary_emb=image_rotary_emb,
                     return_dict=False,
-                    id_vit_hidden
-                    id_cond
+                    id_vit_hidden=id_vit_hidden,
+                    id_cond=id_cond,
                 )[0]
                 noise_pred = noise_pred.float()
 
@@ -891,4 +920,4 @@ class ConsisIDPipeline(DiffusionPipeline):
         if not return_dict:
             return (video,)
 
-        return
+        return ConsisIDPipelineOutput(frames=video)
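pipeline_consisid.py keeps the draw_kps helper that rasterizes the five facial keypoints used as kps_cond. Below is a hedged usage sketch (not part of the commit), assuming the module path introduced by this reorganization; the keypoint coordinates and canvas size are made up.

```py
import numpy as np
from PIL import Image

# Assumes the repo layout from this commit; draw_kps lives in models/pipeline_consisid.py.
from models.pipeline_consisid import draw_kps

# Five facial keypoints (eyes, nose, mouth corners) in pixel coordinates; values are illustrative.
kps = np.array([[180.0, 200.0], [300.0, 200.0], [240.0, 260.0], [200.0, 320.0], [280.0, 320.0]])
canvas = Image.new("RGB", (480, 480), color=(0, 0, 0))

# Returns a PIL image with the limb/keypoint overlay that the pipeline encodes as kps_cond.
kps_image = draw_kps(canvas, kps)
kps_image.save("kps_overlay.png")
```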
models/transformer_consisid.py
CHANGED
@@ -1,8 +1,16 @@
|
|
1 |
-
# Copyright
|
2 |
-
#
|
3 |
-
|
4 |
-
#
|
5 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
|
7 |
from typing import Any, Dict, Optional, Tuple, Union
|
8 |
import os
|
@@ -38,9 +46,9 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
|
38 |
|
39 |
|
40 |
@maybe_allow_in_graph
|
41 |
-
class
|
42 |
r"""
|
43 |
-
Transformer block used in [
|
44 |
|
45 |
Parameters:
|
46 |
dim (`int`):
|
@@ -132,9 +140,6 @@ class CogVideoXBlock(nn.Module):
|
|
132 |
hidden_states, encoder_hidden_states, temb
|
133 |
)
|
134 |
|
135 |
-
# insert here
|
136 |
-
# pass
|
137 |
-
|
138 |
# attention
|
139 |
attn_hidden_states, attn_encoder_hidden_states = self.attn1(
|
140 |
hidden_states=norm_hidden_states,
|
@@ -162,7 +167,7 @@ class CogVideoXBlock(nn.Module):
|
|
162 |
|
163 |
class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
164 |
"""
|
165 |
-
A Transformer model for video-like data in [
|
166 |
|
167 |
Parameters:
|
168 |
num_attention_heads (`int`, defaults to `30`):
|
@@ -191,7 +196,7 @@ class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
|
191 |
The height of the input latents.
|
192 |
sample_frames (`int`, defaults to `49`):
|
193 |
The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49
|
194 |
-
instead of 13 because
|
195 |
but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with
|
196 |
K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1).
|
197 |
patch_size (`int`, defaults to `2`):
|
@@ -212,6 +217,32 @@ class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
|
212 |
Scaling factor to apply in 3D positional embeddings across spatial dimensions.
|
213 |
temporal_interpolation_scale (`float`, defaults to `1.0`):
|
214 |
Scaling factor to apply in 3D positional embeddings across temporal dimensions.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
215 |
"""
|
216 |
|
217 |
_supports_gradient_checkpointing = True
|
@@ -257,7 +288,7 @@ class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
|
257 |
|
258 |
if not use_rotary_positional_embeddings and use_learned_positional_embeddings:
|
259 |
raise ValueError(
|
260 |
-
"There are no
|
261 |
"embeddings. If you're using a custom model and/or believe this should be supported, please open an "
|
262 |
"issue at https://github.com/huggingface/diffusers/issues."
|
263 |
)
|
@@ -288,7 +319,7 @@ class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
|
288 |
# 3. Define spatio-temporal transformers blocks
|
289 |
self.transformer_blocks = nn.ModuleList(
|
290 |
[
|
291 |
-
|
292 |
dim=inner_dim,
|
293 |
num_attention_heads=num_attention_heads,
|
294 |
attention_head_dim=attention_head_dim,
|
@@ -319,6 +350,7 @@ class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
|
319 |
self.is_train_face = is_train_face
|
320 |
self.is_kps = is_kps
|
321 |
|
|
|
322 |
if is_train_face:
|
323 |
self.inner_dim = inner_dim
|
324 |
self.cross_attn_interval = cross_attn_interval
|
@@ -338,21 +370,26 @@ class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
|
338 |
weight_dtype = next(self.transformer_blocks.parameters()).dtype
|
339 |
self.local_facial_extractor = LocalFacialExtractor()
|
340 |
self.local_facial_extractor.to(device, dtype=weight_dtype)
|
341 |
-
self.perceiver_cross_attention = nn.ModuleList(
|
342 |
-
|
343 |
-
|
|
|
|
|
|
|
|
|
|
|
344 |
|
345 |
def save_face_modules(self, path: str):
|
346 |
save_dict = {
|
347 |
-
|
348 |
-
|
349 |
}
|
350 |
torch.save(save_dict, path)
|
351 |
|
352 |
def load_face_modules(self, path: str):
|
353 |
checkpoint = torch.load(path, map_location=self.device)
|
354 |
-
self.local_facial_extractor.load_state_dict(checkpoint[
|
355 |
-
for ca, state_dict in zip(self.perceiver_cross_attention, checkpoint[
|
356 |
ca.load_state_dict(state_dict)
|
357 |
|
358 |
@property
|
@@ -463,14 +500,16 @@ class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
|
463 |
timestep_cond: Optional[torch.Tensor] = None,
|
464 |
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
465 |
attention_kwargs: Optional[Dict[str, Any]] = None,
|
466 |
-
id_cond: Optional[torch.Tensor] = None,
|
467 |
id_vit_hidden: Optional[torch.Tensor] = None,
|
468 |
return_dict: bool = True,
|
469 |
):
|
470 |
# fuse clip and insightface
|
471 |
if self.is_train_face:
|
472 |
assert id_cond is not None and id_vit_hidden is not None
|
473 |
-
valid_face_emb = self.local_facial_extractor(
|
|
|
|
|
474 |
|
475 |
if attention_kwargs is not None:
|
476 |
attention_kwargs = attention_kwargs.copy()
|
@@ -506,7 +545,7 @@ class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
|
506 |
|
507 |
text_seq_length = encoder_hidden_states.shape[1]
|
508 |
encoder_hidden_states = hidden_states[:, :text_seq_length] # torch.Size([1, 226, 3072])
|
509 |
-
hidden_states = hidden_states[:, text_seq_length:]
|
510 |
|
511 |
# 3. Transformer blocks
|
512 |
ca_idx = 0
|
@@ -538,17 +577,14 @@ class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
|
538 |
|
539 |
if self.is_train_face:
|
540 |
if i % self.cross_attn_interval == 0 and valid_face_emb is not None:
|
541 |
-
hidden_states = hidden_states + self.local_face_scale * self.perceiver_cross_attention[ca_idx](
|
|
|
|
|
542 |
ca_idx += 1
|
543 |
|
544 |
-
|
545 |
-
|
546 |
-
|
547 |
-
else:
|
548 |
-
# CogVideoX-5B
|
549 |
-
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
550 |
-
hidden_states = self.norm_final(hidden_states)
|
551 |
-
hidden_states = hidden_states[:, text_seq_length:]
|
552 |
|
553 |
# 4. Final block
|
554 |
hidden_states = self.norm_out(hidden_states, temb=emb)
|
@@ -556,8 +592,7 @@ class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
|
556 |
|
557 |
# 5. Unpatchify
|
558 |
# Note: we use `-1` instead of `channels`:
|
559 |
-
# - It is okay to `channels` use for
|
560 |
-
# - However, for CogVideoX-5b-I2V also takes concatenated input image latents (number of input channels is twice the output channels)
|
561 |
p = self.config.patch_size
|
562 |
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
|
563 |
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
|
|
|
1 |
+
# Copyright 2024 ConsisID Authors and The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
|
15 |
from typing import Any, Dict, Optional, Tuple, Union
|
16 |
import os
|
|
|
46 |
|
47 |
|
48 |
@maybe_allow_in_graph
|
49 |
+
class ConsisIDBlock(nn.Module):
|
50 |
r"""
|
51 |
+
Transformer block used in [ConsisID](https://github.com/PKU-YuanGroup/ConsisID) model.
|
52 |
|
53 |
Parameters:
|
54 |
dim (`int`):
|
|
|
140 |
hidden_states, encoder_hidden_states, temb
|
141 |
)
|
142 |
|
|
|
|
|
|
|
143 |
# attention
|
144 |
attn_hidden_states, attn_encoder_hidden_states = self.attn1(
|
145 |
hidden_states=norm_hidden_states,
|
|
|
167 |
|
168 |
class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
169 |
"""
|
170 |
+
A Transformer model for video-like data in [ConsisID](https://github.com/PKU-YuanGroup/ConsisID).
|
171 |
|
172 |
Parameters:
|
173 |
num_attention_heads (`int`, defaults to `30`):
|
|
|
196 |
The height of the input latents.
|
197 |
sample_frames (`int`, defaults to `49`):
|
198 |
The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49
|
199 |
+
instead of 13 because ConsisID processed 13 latent frames at once in its default and recommended settings,
|
200 |
but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with
|
201 |
K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1).
|
202 |
patch_size (`int`, defaults to `2`):
|
|
|
217 |
Scaling factor to apply in 3D positional embeddings across spatial dimensions.
|
218 |
temporal_interpolation_scale (`float`, defaults to `1.0`):
|
219 |
Scaling factor to apply in 3D positional embeddings across temporal dimensions.
|
220 |
+
is_train_face (`bool`, defaults to `False`):
|
221 |
+
Whether to use enable the identity-preserving module during the training process.
|
222 |
+
When set to `True`, the model will focus on identity-preserving tasks.
|
223 |
+
is_kps (`bool`, defaults to `False`):
|
224 |
+
Whether to enable keypoint for global facial extractor.
|
225 |
+
If `True`, keypoints will be in the model.
|
226 |
+
cross_attn_interval (`int`, defaults to `1`):
|
227 |
+
The interval between cross-attention layers in the Transformer architecture.
|
228 |
+
A larger value may reduce the frequency of cross-attention computations,
|
229 |
+
which can help reduce computational overhead.
|
230 |
+
LFE_num_tokens (`int`, defaults to `32`):
|
231 |
+
The number of tokens to use in the Local Facial Extractor (LFE).
|
232 |
+
This module is responsible for capturing high frequency representations
|
233 |
+
of the face.
|
234 |
+
LFE_output_dim (`int`, defaults to `768`):
|
235 |
+
The output dimension of the Local Facial Extractor (LFE) module.
|
236 |
+
This dimension determines the size of the feature vectors produced
|
237 |
+
by the LFE module.
|
238 |
+
LFE_heads (`int`, defaults to `12`):
|
239 |
+
The number of attention heads used in the Local Facial Extractor (LFE) module.
|
240 |
+
More heads may improve the ability to capture diverse features, but
|
241 |
+
can also increase computational complexity.
|
242 |
+
local_face_scale (`float`, defaults to `1.0`):
|
243 |
+
A scaling factor used to adjust the importance of local facial features
|
244 |
+
in the model. This can influence how strongly the model focuses on
|
245 |
+
high frequency face-related content.
|
246 |
"""
|
247 |
|
248 |
_supports_gradient_checkpointing = True
|
|
|
288 |
|
289 |
if not use_rotary_positional_embeddings and use_learned_positional_embeddings:
|
290 |
raise ValueError(
|
291 |
+
"There are no ConsisID checkpoints available with disable rotary embeddings and learned positional "
|
292 |
"embeddings. If you're using a custom model and/or believe this should be supported, please open an "
|
293 |
"issue at https://github.com/huggingface/diffusers/issues."
|
294 |
)
|
|
|
319 |
# 3. Define spatio-temporal transformers blocks
|
320 |
self.transformer_blocks = nn.ModuleList(
|
321 |
[
|
322 |
+
ConsisIDBlock(
|
323 |
dim=inner_dim,
|
324 |
num_attention_heads=num_attention_heads,
|
325 |
attention_head_dim=attention_head_dim,
|
|
|
350 |
self.is_train_face = is_train_face
|
351 |
self.is_kps = is_kps
|
352 |
|
353 |
+
# 5. Define identity-preserving config
|
354 |
if is_train_face:
|
355 |
self.inner_dim = inner_dim
|
356 |
self.cross_attn_interval = cross_attn_interval
|
|
|
370 |
weight_dtype = next(self.transformer_blocks.parameters()).dtype
|
371 |
self.local_facial_extractor = LocalFacialExtractor()
|
372 |
self.local_facial_extractor.to(device, dtype=weight_dtype)
|
373 |
+
self.perceiver_cross_attention = nn.ModuleList(
|
374 |
+
[
|
375 |
+
PerceiverCrossAttention(
|
376 |
+
dim=self.inner_dim, dim_head=128, heads=16, kv_dim=self.LFE_final_output_dim
|
377 |
+
).to(device, dtype=weight_dtype)
|
378 |
+
for _ in range(self.num_ca)
|
379 |
+
]
|
380 |
+
)
|
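One `PerceiverCrossAttention` module is created per injection point, i.e. for every `cross_attn_interval`-th transformer block (see the forward pass further below). As a rough, hypothetical illustration of the dimensions involved, the shape comments in this file imply `inner_dim` is 3072 and `LFE_final_output_dim` is 2048 in the released configuration, so a single injection module would be parameterized roughly as follows (the import location is an assumption):

```python
# Hypothetical standalone instantiation of one injection module, using the dimensions
# implied by the shape comments in this file (inner_dim 3072, LFE output dim 2048).
from models.local_facial_extractor import PerceiverCrossAttention  # assumed import location

ca = PerceiverCrossAttention(dim=3072, dim_head=128, heads=16, kv_dim=2048)
```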
381 |
|
382 |
def save_face_modules(self, path: str):
|
383 |
save_dict = {
|
384 |
+
"local_facial_extractor": self.local_facial_extractor.state_dict(),
|
385 |
+
"perceiver_cross_attention": [ca.state_dict() for ca in self.perceiver_cross_attention],
|
386 |
}
|
387 |
torch.save(save_dict, path)
|
388 |
|
389 |
def load_face_modules(self, path: str):
|
390 |
checkpoint = torch.load(path, map_location=self.device)
|
391 |
+
self.local_facial_extractor.load_state_dict(checkpoint["local_facial_extractor"])
|
392 |
+
for ca, state_dict in zip(self.perceiver_cross_attention, checkpoint["perceiver_cross_attention"]):
|
393 |
ca.load_state_dict(state_dict)
|
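A quick usage note for the two helpers above (illustrative only; `transformer` stands for an instance of this class and the filename is hypothetical):

```python
# Persist only the identity-preserving weights (Local Facial Extractor + perceiver
# cross-attention modules), separate from the main transformer checkpoint...
transformer.save_face_modules("consisid_face_modules.pt")

# ...and restore them later into another instance with the same configuration.
transformer.load_face_modules("consisid_face_modules.pt")
```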
394 |
|
395 |
@property
|
|
|
500 |
timestep_cond: Optional[torch.Tensor] = None,
|
501 |
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
502 |
attention_kwargs: Optional[Dict[str, Any]] = None,
|
503 |
+
id_cond: Optional[torch.Tensor] = None,
|
504 |
id_vit_hidden: Optional[torch.Tensor] = None,
|
505 |
return_dict: bool = True,
|
506 |
):
|
507 |
# fuse clip and insightface
|
508 |
if self.is_train_face:
|
509 |
assert id_cond is not None and id_vit_hidden is not None
|
510 |
+
valid_face_emb = self.local_facial_extractor(
|
511 |
+
id_cond, id_vit_hidden
|
512 |
+
) # torch.Size([1, 1280]), list[5](torch.Size([1, 577, 1024])) -> torch.Size([1, 32, 2048])
|
513 |
|
514 |
if attention_kwargs is not None:
|
515 |
attention_kwargs = attention_kwargs.copy()
|
|
|
545 |
|
546 |
text_seq_length = encoder_hidden_states.shape[1]
|
547 |
encoder_hidden_states = hidden_states[:, :text_seq_length] # torch.Size([1, 226, 3072])
|
548 |
+
hidden_states = hidden_states[:, text_seq_length:] # torch.Size([1, 17550, 3072])
|
549 |
|
550 |
# 3. Transformer blocks
|
551 |
ca_idx = 0
|
|
|
577 |
|
578 |
if self.is_train_face:
|
579 |
if i % self.cross_attn_interval == 0 and valid_face_emb is not None:
|
580 |
+
hidden_states = hidden_states + self.local_face_scale * self.perceiver_cross_attention[ca_idx](
|
581 |
+
valid_face_emb, hidden_states
|
582 |
+
) # torch.Size([2, 32, 2048]) torch.Size([2, 17550, 3072])
|
583 |
ca_idx += 1
|
584 |
|
585 |
+
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
586 |
+
hidden_states = self.norm_final(hidden_states)
|
587 |
+
hidden_states = hidden_states[:, text_seq_length:]
|
|
588 |
|
589 |
# 4. Final block
|
590 |
hidden_states = self.norm_out(hidden_states, temb=emb)
|
|
|
592 |
|
593 |
# 5. Unpatchify
|
594 |
# Note: we use `-1` instead of `channels`:
|
595 |
+
# - It is okay to use `channels` for ConsisID (the number of input channels equals the number of output channels)
|
|
|
596 |
p = self.config.patch_size
|
597 |
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
|
598 |
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
|
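To make the identity inputs concrete: the inline comments above give `id_cond` the shape [1, 1280] (a 512-d antelopev2 embedding concatenated with a 768-d CLIP-derived feature) and `id_vit_hidden` a list of five [1, 577, 1024] ViT feature maps. Below is a hedged sketch of a forward call with dummy tensors of those shapes; the latent, prompt-embedding, and timestep arguments are placeholders for whatever the surrounding pipeline normally supplies, and their exact shapes here are assumptions.

```python
import torch

# Dummy identity inputs with the shapes noted in the comments above (batch size 1).
id_cond = torch.randn(1, 1280)                                 # antelopev2 (512) + CLIP feature (768)
id_vit_hidden = [torch.randn(1, 577, 1024) for _ in range(5)]  # intermediate EVA-CLIP ViT states

# Placeholder denoising inputs; in practice these come from the ConsisID pipeline
# (shapes below are assumptions, not taken from this file).
latents = torch.randn(1, 13, 16, 60, 90)       # (B, F, C, H // 8, W // 8)
prompt_embeds = torch.randn(1, 226, 4096)      # T5 prompt embeddings
timestep = torch.tensor([999])

output = transformer(
    hidden_states=latents,
    encoder_hidden_states=prompt_embeds,
    timestep=timestep,
    id_cond=id_cond,
    id_vit_hidden=id_vit_hidden,
    return_dict=False,
)
```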
models/utils.py
CHANGED
@@ -1,7 +1,8 @@
|
|
|
|
1 |
import cv2
|
2 |
import math
|
3 |
import numpy as np
|
4 |
-
from PIL import Image
|
5 |
|
6 |
import torch
|
7 |
from torchvision.transforms import InterpolationMode
|
@@ -10,7 +11,16 @@ from transformers import T5EncoderModel, T5Tokenizer
|
|
10 |
from typing import List, Optional, Tuple, Union
|
11 |
from diffusers.models.embeddings import get_3d_rotary_pos_embed
|
12 |
from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
|
|
|
13 |
14 |
|
15 |
def tensor_to_pil(src_img_tensor):
|
16 |
img = src_img_tensor.clone().detach()
|
@@ -204,12 +214,12 @@ def draw_kps(image_pil, kps, color_list=[(255,0,0), (0,255,0), (0,0,255), (255,2
|
|
204 |
return out_img_pil
|
205 |
|
206 |
|
207 |
-
def process_face_embeddings(
|
208 |
"""
|
209 |
Args:
|
210 |
image: numpy rgb image, range [0, 255]
|
211 |
"""
|
212 |
-
|
213 |
image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) # (724, 502, 3)
|
214 |
# get antelopev2 embedding
|
215 |
face_info = app.get(image_bgr)
|
@@ -224,19 +234,19 @@ def process_face_embeddings(face_helper, clip_vision_model, handler_ante, eva_tr
|
|
224 |
face_kps = None
|
225 |
|
226 |
# using facexlib to detect and align face
|
227 |
-
|
228 |
-
|
229 |
if face_kps is None:
|
230 |
-
face_kps =
|
231 |
-
|
232 |
-
if len(
|
233 |
raise RuntimeError('facexlib align face fail')
|
234 |
-
align_face =
|
235 |
|
236 |
# in case insightface didn't detect a face
|
237 |
if id_ante_embedding is None:
|
238 |
print('fail to detect face using insightface, extract embedding on align face')
|
239 |
-
id_ante_embedding =
|
240 |
|
241 |
id_ante_embedding = torch.from_numpy(id_ante_embedding).to(device, weight_dtype) # torch.Size([512])
|
242 |
if id_ante_embedding.ndim == 1:
|
@@ -246,7 +256,7 @@ def process_face_embeddings(face_helper, clip_vision_model, handler_ante, eva_tr
|
|
246 |
if is_align_face:
|
247 |
input = img2tensor(align_face, bgr2rgb=True).unsqueeze(0) / 255.0 # torch.Size([1, 3, 512, 512])
|
248 |
input = input.to(device)
|
249 |
-
parsing_out =
|
250 |
parsing_out = parsing_out.argmax(dim=1, keepdim=True) # torch.Size([1, 1, 512, 512])
|
251 |
bg_label = [0, 16, 18, 7, 8, 9, 14, 15]
|
252 |
bg = sum(parsing_out == i for i in bg_label).bool()
|
@@ -270,4 +280,84 @@ def process_face_embeddings(face_helper, clip_vision_model, handler_ante, eva_tr
|
|
270 |
|
271 |
id_cond = torch.cat([id_ante_embedding, id_cond_vit], dim=-1) # torch.Size([1, 512]), torch.Size([1, 768]) -> torch.Size([1, 1280])
|
272 |
|
273 |
-
return id_cond, id_vit_hidden, return_face_features_image_2, face_kps # torch.Size([1, 1280]), list(torch.Size([1, 577, 1024]))
|
1 |
+
import os
|
2 |
import cv2
|
3 |
import math
|
4 |
import numpy as np
|
5 |
+
from PIL import Image, ImageOps
|
6 |
|
7 |
import torch
|
8 |
from torchvision.transforms import InterpolationMode
|
|
|
11 |
from typing import List, Optional, Tuple, Union
|
12 |
from diffusers.models.embeddings import get_3d_rotary_pos_embed
|
13 |
from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
|
14 |
+
from diffusers.utils import load_image
|
15 |
|
16 |
+
import insightface
|
17 |
+
from insightface.app import FaceAnalysis
|
18 |
+
from facexlib.parsing import init_parsing_model
|
19 |
+
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
|
20 |
+
|
21 |
+
from models.eva_clip import create_model_and_transforms
|
22 |
+
from models.eva_clip.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
|
23 |
+
from models.eva_clip.utils_qformer import resize_numpy_image_long
|
24 |
|
25 |
def tensor_to_pil(src_img_tensor):
|
26 |
img = src_img_tensor.clone().detach()
|
|
|
214 |
return out_img_pil
|
215 |
|
216 |
|
217 |
+
def process_face_embeddings(face_helper_1, clip_vision_model, face_helper_2, eva_transform_mean, eva_transform_std, app, device, weight_dtype, image, original_id_image=None, is_align_face=True):
|
218 |
"""
|
219 |
Args:
|
220 |
image: numpy rgb image, range [0, 255]
|
221 |
"""
|
222 |
+
face_helper_1.clean_all()
|
223 |
image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) # (724, 502, 3)
|
224 |
# get antelopev2 embedding
|
225 |
face_info = app.get(image_bgr)
|
|
|
234 |
face_kps = None
|
235 |
|
236 |
# using facexlib to detect and align face
|
237 |
+
face_helper_1.read_image(image_bgr)
|
238 |
+
face_helper_1.get_face_landmarks_5(only_center_face=True)
|
239 |
if face_kps is None:
|
240 |
+
face_kps = face_helper_1.all_landmarks_5[0]
|
241 |
+
face_helper_1.align_warp_face()
|
242 |
+
if len(face_helper_1.cropped_faces) == 0:
|
243 |
raise RuntimeError('facexlib align face fail')
|
244 |
+
align_face = face_helper_1.cropped_faces[0] # (512, 512, 3) # RGB
|
245 |
|
246 |
# in case insightface didn't detect a face
|
247 |
if id_ante_embedding is None:
|
248 |
print('fail to detect face using insightface, extract embedding on align face')
|
249 |
+
id_ante_embedding = face_helper_2.get_feat(align_face)
|
250 |
|
251 |
id_ante_embedding = torch.from_numpy(id_ante_embedding).to(device, weight_dtype) # torch.Size([512])
|
252 |
if id_ante_embedding.ndim == 1:
|
|
|
256 |
if is_align_face:
|
257 |
input = img2tensor(align_face, bgr2rgb=True).unsqueeze(0) / 255.0 # torch.Size([1, 3, 512, 512])
|
258 |
input = input.to(device)
|
259 |
+
parsing_out = face_helper_1.face_parse(normalize(input, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))[0]
|
260 |
parsing_out = parsing_out.argmax(dim=1, keepdim=True) # torch.Size([1, 1, 512, 512])
|
261 |
bg_label = [0, 16, 18, 7, 8, 9, 14, 15]
|
262 |
bg = sum(parsing_out == i for i in bg_label).bool()
|
|
|
280 |
|
281 |
id_cond = torch.cat([id_ante_embedding, id_cond_vit], dim=-1) # torch.Size([1, 512]), torch.Size([1, 768]) -> torch.Size([1, 1280])
|
282 |
|
283 |
+
return id_cond, id_vit_hidden, return_face_features_image_2, face_kps # torch.Size([1, 1280]), list(torch.Size([1, 577, 1024]))
|
284 |
+
|
285 |
+
|
286 |
+
def process_face_embeddings_infer(face_helper_1, clip_vision_model, face_helper_2, eva_transform_mean, eva_transform_std, app, device, weight_dtype, img_file_path, is_align_face=True):
|
287 |
+
"""
|
288 |
+
Args:
|
289 |
+
image: numpy rgb image, range [0, 255]
|
290 |
+
"""
|
291 |
+
if isinstance(img_file_path, str):
|
292 |
+
image = np.array(load_image(image=img_file_path).convert("RGB"))
|
293 |
+
else:
|
294 |
+
image = np.array(ImageOps.exif_transpose(Image.fromarray(img_file_path)).convert("RGB"))
|
295 |
+
|
296 |
+
image = resize_numpy_image_long(image, 1024)
|
297 |
+
original_id_image = image
|
298 |
+
|
299 |
+
id_cond, id_vit_hidden, align_crop_face_image, face_kps = process_face_embeddings(face_helper_1, clip_vision_model, face_helper_2, eva_transform_mean, eva_transform_std, app, device, weight_dtype, image, original_id_image, is_align_face)
|
300 |
+
|
301 |
+
tensor = align_crop_face_image.cpu().detach()
|
302 |
+
tensor = tensor.squeeze()
|
303 |
+
tensor = tensor.permute(1, 2, 0)
|
304 |
+
tensor = tensor.numpy() * 255
|
305 |
+
tensor = tensor.astype(np.uint8)
|
306 |
+
image = ImageOps.exif_transpose(Image.fromarray(tensor))
|
307 |
+
|
308 |
+
return id_cond, id_vit_hidden, image, face_kps
|
309 |
+
|
310 |
+
def prepare_face_models(model_path, device, dtype):
|
311 |
+
"""
|
312 |
+
Prepare all face models for the facial recognition task.
|
313 |
+
|
314 |
+
Parameters:
|
315 |
+
- model_path: Path to the directory containing model files.
|
316 |
+
- device: The device (e.g., 'cuda', 'cpu') where models will be loaded.
|
317 |
+
- dtype: Data type (e.g., torch.float32) for model inference.
|
318 |
+
|
319 |
+
Returns:
|
320 |
+
- face_helper_1: First face restoration helper.
|
321 |
+
- face_helper_2: Face recognition model (insightface glintr100) used to extract identity embeddings.
|
322 |
+
- face_clip_model: CLIP model for face extraction.
|
323 |
+
- face_main_model: Main face analysis model.
|
324 |
+
- eva_transform_mean: Mean value for image normalization.
|
325 |
+
- eva_transform_std: Standard deviation value for image normalization.
|
326 |
+
"""
|
327 |
+
# get helper model
|
328 |
+
face_helper_1 = FaceRestoreHelper(
|
329 |
+
upscale_factor=1,
|
330 |
+
face_size=512,
|
331 |
+
crop_ratio=(1, 1),
|
332 |
+
det_model='retinaface_resnet50',
|
333 |
+
save_ext='png',
|
334 |
+
device=device,
|
335 |
+
model_rootpath=os.path.join(model_path, "face_encoder")
|
336 |
+
)
|
337 |
+
face_helper_1.face_parse = None
|
338 |
+
face_helper_1.face_parse = init_parsing_model(model_name='bisenet', device=device, model_rootpath=os.path.join(model_path, "face_encoder"))
|
339 |
+
face_helper_2 = insightface.model_zoo.get_model(f'{model_path}/face_encoder/models/antelopev2/glintr100.onnx', providers=['CUDAExecutionProvider'])
|
340 |
+
face_helper_2.prepare(ctx_id=0)
|
341 |
+
|
342 |
+
# get local facial extractor part 1
|
343 |
+
model, _, _ = create_model_and_transforms('EVA02-CLIP-L-14-336', os.path.join(model_path, "face_encoder", "EVA02_CLIP_L_336_psz14_s6B.pt"), force_custom_clip=True)
|
344 |
+
face_clip_model = model.visual
|
345 |
+
eva_transform_mean = getattr(face_clip_model, 'image_mean', OPENAI_DATASET_MEAN)
|
346 |
+
eva_transform_std = getattr(face_clip_model, 'image_std', OPENAI_DATASET_STD)
|
347 |
+
if not isinstance(eva_transform_mean, (list, tuple)):
|
348 |
+
eva_transform_mean = (eva_transform_mean,) * 3
|
349 |
+
if not isinstance(eva_transform_std, (list, tuple)):
|
350 |
+
eva_transform_std = (eva_transform_std,) * 3
|
351 |
+
eva_transform_mean = eva_transform_mean
|
352 |
+
eva_transform_std = eva_transform_std
|
353 |
+
|
354 |
+
# get local facial extractor part 2
|
355 |
+
face_main_model = FaceAnalysis(name='antelopev2', root=os.path.join(model_path, "face_encoder"), providers=['CUDAExecutionProvider'])
|
356 |
+
face_main_model.prepare(ctx_id=0, det_size=(640, 640))
|
357 |
+
|
358 |
+
# move face models to device
|
359 |
+
face_helper_1.face_det.eval()
|
360 |
+
face_helper_1.face_parse.eval()
|
361 |
+
face_clip_model.eval()
|
362 |
+
|
363 |
+
return face_helper_1, face_helper_2, face_clip_model, face_main_model, eva_transform_mean, eva_transform_std
|
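Putting the helpers in this file together, a minimal inference-time sketch might look like the following; the checkpoint directory and the reference-image path are hypothetical placeholders, and the argument order simply follows the signatures defined above.

```python
import torch
from models.utils import prepare_face_models, process_face_embeddings_infer

device = "cuda"
dtype = torch.bfloat16

# Load the face detection / parsing / recognition / CLIP models once.
(face_helper_1, face_helper_2, face_clip_model,
 face_main_model, eva_transform_mean, eva_transform_std) = prepare_face_models("ckpts", device, dtype)

# Extract identity conditioning from a single reference image (hypothetical path).
id_cond, id_vit_hidden, align_crop_face_image, face_kps = process_face_embeddings_infer(
    face_helper_1, face_clip_model, face_helper_2,
    eva_transform_mean, eva_transform_std,
    face_main_model, device, dtype,
    "path/to/reference_face.png",
    is_align_face=True,
)
```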
requirements.txt
CHANGED
@@ -6,7 +6,7 @@ onnx==1.17.0
|
|
6 |
onnxruntime-gpu==1.19.2
|
7 |
deepspeed==0.15.2
|
8 |
accelerate==1.1.1
|
9 |
-
|
10 |
transformers==4.46.3
|
11 |
tokenizers==0.20.1
|
12 |
peft==0.12.0
|
|
|
6 |
onnxruntime-gpu==1.19.2
|
7 |
deepspeed==0.15.2
|
8 |
accelerate==1.1.1
|
9 |
+
git+https://github.com/SHYuanBest/ConsisID_diffusers.git
|
10 |
transformers==4.46.3
|
11 |
tokenizers==0.20.1
|
12 |
peft==0.12.0
|
util/dataloader.py
DELETED
@@ -1,1010 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import gc
|
3 |
-
import cv2
|
4 |
-
import json
|
5 |
-
import math
|
6 |
-
import decord
|
7 |
-
import random
|
8 |
-
import numpy as np
|
9 |
-
from PIL import Image
|
10 |
-
from tqdm import tqdm
|
11 |
-
from decord import VideoReader
|
12 |
-
from contextlib import contextmanager
|
13 |
-
from func_timeout import FunctionTimedOut
|
14 |
-
from typing import Optional, Sized, Iterator
|
15 |
-
|
16 |
-
import torch
|
17 |
-
from torch.utils.data import Dataset, Sampler
|
18 |
-
import torch.nn.functional as F
|
19 |
-
from torchvision.transforms import ToPILImage
|
20 |
-
from torchvision import transforms
|
21 |
-
from accelerate.logging import get_logger
|
22 |
-
|
23 |
-
logger = get_logger(__name__)
|
24 |
-
|
25 |
-
import threading
|
26 |
-
log_lock = threading.Lock()
|
27 |
-
|
28 |
-
def log_error_to_file(error_message, video_path):
|
29 |
-
with log_lock:
|
30 |
-
with open("error_log.txt", "a") as f:
|
31 |
-
f.write(f"Error: {error_message}\n")
|
32 |
-
f.write(f"Video Path: {video_path}\n")
|
33 |
-
f.write("-" * 50 + "\n")
|
34 |
-
|
35 |
-
def draw_kps(image_pil, kps, color_list=[(255,0,0), (0,255,0), (0,0,255), (255,255,0), (255,0,255)]):
|
36 |
-
stickwidth = 4
|
37 |
-
limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]])
|
38 |
-
kps = np.array(kps)
|
39 |
-
|
40 |
-
w, h = image_pil.size
|
41 |
-
out_img = np.zeros([h, w, 3])
|
42 |
-
|
43 |
-
for i in range(len(limbSeq)):
|
44 |
-
index = limbSeq[i]
|
45 |
-
color = color_list[index[0]]
|
46 |
-
|
47 |
-
x = kps[index][:, 0]
|
48 |
-
y = kps[index][:, 1]
|
49 |
-
length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5
|
50 |
-
angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1]))
|
51 |
-
polygon = cv2.ellipse2Poly((int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
|
52 |
-
out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color)
|
53 |
-
out_img = (out_img * 0.6).astype(np.uint8)
|
54 |
-
|
55 |
-
for idx_kp, kp in enumerate(kps):
|
56 |
-
color = color_list[idx_kp]
|
57 |
-
x, y = kp
|
58 |
-
out_img = cv2.circle(out_img.copy(), (int(x), int(y)), 10, color, -1)
|
59 |
-
|
60 |
-
out_img_pil = Image.fromarray(out_img.astype(np.uint8))
|
61 |
-
return out_img_pil
|
62 |
-
|
63 |
-
@contextmanager
|
64 |
-
def VideoReader_contextmanager(*args, **kwargs):
|
65 |
-
vr = VideoReader(*args, **kwargs)
|
66 |
-
try:
|
67 |
-
yield vr
|
68 |
-
finally:
|
69 |
-
del vr
|
70 |
-
gc.collect()
|
71 |
-
|
72 |
-
def get_valid_segments(valid_frame, tolerance=5):
|
73 |
-
valid_positions = sorted(set(valid_frame['face']).union(set(valid_frame['head'])))
|
74 |
-
|
75 |
-
valid_segments = []
|
76 |
-
current_segment = [valid_positions[0]]
|
77 |
-
|
78 |
-
for i in range(1, len(valid_positions)):
|
79 |
-
if valid_positions[i] - valid_positions[i - 1] <= tolerance:
|
80 |
-
current_segment.append(valid_positions[i])
|
81 |
-
else:
|
82 |
-
valid_segments.append(current_segment)
|
83 |
-
current_segment = [valid_positions[i]]
|
84 |
-
|
85 |
-
if current_segment:
|
86 |
-
valid_segments.append(current_segment)
|
87 |
-
|
88 |
-
return valid_segments
|
89 |
-
|
90 |
-
|
91 |
-
def get_frame_indices_adjusted_for_face(valid_frames, n_frames):
|
92 |
-
valid_length = len(valid_frames)
|
93 |
-
if valid_length >= n_frames:
|
94 |
-
return valid_frames[:n_frames]
|
95 |
-
|
96 |
-
additional_frames_needed = n_frames - valid_length
|
97 |
-
repeat_indices = []
|
98 |
-
|
99 |
-
for i in range(additional_frames_needed):
|
100 |
-
index_to_repeat = i % valid_length
|
101 |
-
repeat_indices.append(valid_frames[index_to_repeat])
|
102 |
-
|
103 |
-
all_indices = valid_frames + repeat_indices
|
104 |
-
all_indices.sort()
|
105 |
-
|
106 |
-
return all_indices
|
107 |
-
|
108 |
-
|
109 |
-
def generate_frame_indices_for_face(n_frames, sample_stride, valid_frame, tolerance=7, skip_frames_start_percent=0.0, skip_frames_end_percent=1.0, skip_frames_start=0, skip_frames_end=0):
|
110 |
-
valid_segments = get_valid_segments(valid_frame, tolerance)
|
111 |
-
selected_segment = max(valid_segments, key=len)
|
112 |
-
|
113 |
-
valid_length = len(selected_segment)
|
114 |
-
if skip_frames_start_percent != 0.0 or skip_frames_end_percent != 1.0:
|
115 |
-
# print("use skip frame percent")
|
116 |
-
valid_start = int(valid_length * skip_frames_start_percent)
|
117 |
-
valid_end = int(valid_length * skip_frames_end_percent)
|
118 |
-
elif skip_frames_start != 0 or skip_frames_end != 0:
|
119 |
-
# print("use skip frame")
|
120 |
-
valid_start = skip_frames_start
|
121 |
-
valid_end = valid_length - skip_frames_end
|
122 |
-
else:
|
123 |
-
# print("no use skip frame")
|
124 |
-
valid_start = 0
|
125 |
-
valid_end = valid_length
|
126 |
-
|
127 |
-
if valid_length <= n_frames:
|
128 |
-
return get_frame_indices_adjusted_for_face(selected_segment, n_frames), valid_length
|
129 |
-
else:
|
130 |
-
adjusted_length = valid_end - valid_start
|
131 |
-
if adjusted_length <= 0:
|
132 |
-
print(f"video_length: {valid_length}, adjusted_length: {adjusted_length}, valid_start:{valid_start}, skip_frames_end: {valid_end}")
|
133 |
-
raise ValueError("Skipping too many frames results in no frames left to sample.")
|
134 |
-
|
135 |
-
clip_length = min(adjusted_length, (n_frames - 1) * sample_stride + 1)
|
136 |
-
start_idx_position = random.randint(valid_start, valid_end - clip_length)
|
137 |
-
start_frame = selected_segment[start_idx_position]
|
138 |
-
|
139 |
-
selected_frames = []
|
140 |
-
for i in range(n_frames):
|
141 |
-
next_frame = start_frame + i * sample_stride
|
142 |
-
if next_frame in selected_segment:
|
143 |
-
selected_frames.append(next_frame)
|
144 |
-
else:
|
145 |
-
break
|
146 |
-
|
147 |
-
if len(selected_frames) < n_frames:
|
148 |
-
return get_frame_indices_adjusted_for_face(selected_frames, n_frames), len(selected_frames)
|
149 |
-
|
150 |
-
return selected_frames, len(selected_frames)
|
151 |
-
|
152 |
-
def frame_has_required_confidence(bbox_data, frame, ID, conf_threshold=0.88):
|
153 |
-
frame_str = str(frame)
|
154 |
-
if frame_str not in bbox_data:
|
155 |
-
return False
|
156 |
-
|
157 |
-
frame_data = bbox_data[frame_str]
|
158 |
-
|
159 |
-
face_conf = any(
|
160 |
-
item['confidence'] > conf_threshold and item['new_track_id'] == ID
|
161 |
-
for item in frame_data.get('face', [])
|
162 |
-
)
|
163 |
-
|
164 |
-
head_conf = any(
|
165 |
-
item['confidence'] > conf_threshold and item['new_track_id'] == ID
|
166 |
-
for item in frame_data.get('head', [])
|
167 |
-
)
|
168 |
-
|
169 |
-
return face_conf and head_conf
|
170 |
-
|
171 |
-
def select_mask_frames_from_index(batch_frame, original_batch_frame, valid_id, corresponding_data, control_sam2_frame,
|
172 |
-
valid_frame, bbox_data, base_dir, min_distance=3, min_frames=1, max_frames=5,
|
173 |
-
mask_type='face', control_mask_type='head', dense_masks=False,
|
174 |
-
ensure_control_frame=True):
|
175 |
-
"""
|
176 |
-
Selects frames with corresponding mask images while ensuring a minimum distance constraint between frames,
|
177 |
-
and that the frames exist in both batch_frame and valid_frame.
|
178 |
-
|
179 |
-
Parameters:
|
180 |
-
base_path (str): Base directory where the JSON files and mask results are located.
|
181 |
-
min_distance (int): Minimum distance between selected frames.
|
182 |
-
min_frames (int): Minimum number of frames to select.
|
183 |
-
max_frames (int): Maximum number of frames to select.
|
184 |
-
mask_type (str): Type of mask to select frames for ('face' or 'head').
|
185 |
-
control_mask_type (str): Type of mask used for control frame selection ('face' or 'head').
|
186 |
-
|
187 |
-
Returns:
|
188 |
-
dict: A dictionary where keys are IDs and values are lists of selected mask PNG paths.
|
189 |
-
"""
|
190 |
-
# Helper function to randomly select frames with at least X frames apart
|
191 |
-
def select_frames_with_distance_constraint(frames, num_frames, min_distance, control_frame, bbox_data, ID,
|
192 |
-
ensure_control_frame=True, fallback=True):
|
193 |
-
"""
|
194 |
-
Selects frames with a minimum distance constraint. If not enough frames can be selected, a fallback plan is applied.
|
195 |
-
|
196 |
-
Parameters:
|
197 |
-
frames (list): List of frame indices to select from.
|
198 |
-
num_frames (int): Number of frames to select.
|
199 |
-
min_distance (int): Minimum distance between selected frames.
|
200 |
-
control_frame (int): The control frame that must always be included.
|
201 |
-
fallback (bool): Whether to apply a fallback strategy if not enough frames meet the distance constraint.
|
202 |
-
|
203 |
-
Returns:
|
204 |
-
list: List of selected frames.
|
205 |
-
"""
|
206 |
-
conf_thresholds = [0.95, 0.94, 0.93, 0.92, 0.91, 0.90]
|
207 |
-
if ensure_control_frame:
|
208 |
-
selected_frames = [control_frame] # Ensure control frame is always included
|
209 |
-
else:
|
210 |
-
valid_initial_frames = []
|
211 |
-
for conf_threshold in conf_thresholds:
|
212 |
-
valid_initial_frames = [
|
213 |
-
f for f in frames
|
214 |
-
if frame_has_required_confidence(bbox_data, f, ID, conf_threshold=conf_threshold)
|
215 |
-
]
|
216 |
-
if valid_initial_frames:
|
217 |
-
break
|
218 |
-
if valid_initial_frames:
|
219 |
-
selected_frames = [random.choice(valid_initial_frames)]
|
220 |
-
else:
|
221 |
-
# If no frame meets the initial confidence, fall back to a random frame (or handle as per your preference)
|
222 |
-
selected_frames = [random.choice(frames)]
|
223 |
-
|
224 |
-
available_frames = [f for f in frames if f != selected_frames[0]] # Exclude control frame for random selection
|
225 |
-
|
226 |
-
random.shuffle(available_frames) # Shuffle to introduce randomness
|
227 |
-
|
228 |
-
while available_frames and len(selected_frames) < num_frames:
|
229 |
-
last_selected_frame = selected_frames[-1]
|
230 |
-
|
231 |
-
valid_choices = []
|
232 |
-
for conf_threshold in conf_thresholds:
|
233 |
-
valid_choices = [
|
234 |
-
f for f in available_frames
|
235 |
-
if abs(f - last_selected_frame) >= min_distance and
|
236 |
-
frame_has_required_confidence(bbox_data, f, ID, conf_threshold=conf_threshold)
|
237 |
-
]
|
238 |
-
if valid_choices:
|
239 |
-
break
|
240 |
-
|
241 |
-
if valid_choices:
|
242 |
-
frame = random.choice(valid_choices)
|
243 |
-
available_frames.remove(frame)
|
244 |
-
selected_frames.append(frame)
|
245 |
-
else:
|
246 |
-
if fallback:
|
247 |
-
# Fallback strategy: uniformly distribute remaining frames if distance constraint cannot be met
|
248 |
-
remaining_needed = num_frames - len(selected_frames)
|
249 |
-
remaining_frames = available_frames[:remaining_needed]
|
250 |
-
|
251 |
-
# Distribute the remaining frames evenly if possible
|
252 |
-
if remaining_frames:
|
253 |
-
step = max(1, len(remaining_frames) // remaining_needed)
|
254 |
-
evenly_selected = remaining_frames[::step][:remaining_needed]
|
255 |
-
selected_frames.extend(evenly_selected)
|
256 |
-
break
|
257 |
-
else:
|
258 |
-
break # No valid choices remain and no fallback strategy is allowed
|
259 |
-
|
260 |
-
if len(selected_frames) < num_frames:
|
261 |
-
return None
|
262 |
-
|
263 |
-
return selected_frames
|
264 |
-
|
265 |
-
# Convert batch_frame list to a set to remove duplicates
|
266 |
-
batch_frame_set = set(batch_frame)
|
267 |
-
|
268 |
-
# Dictionary to store selected mask PNGs
|
269 |
-
selected_masks_dict = {}
|
270 |
-
selected_bboxs_dict = {}
|
271 |
-
dense_masks_dict = {}
|
272 |
-
selected_frames_dict = {}
|
273 |
-
|
274 |
-
# ID
|
275 |
-
try:
|
276 |
-
mask_valid_frames = valid_frame[mask_type] # Select frames based on the specified mask type
|
277 |
-
control_valid_frames = valid_frame[control_mask_type] # Control frames for control_mask_type
|
278 |
-
except KeyError:
|
279 |
-
if mask_type not in valid_frame.keys():
|
280 |
-
print(f"no valid {mask_type}")
|
281 |
-
if control_mask_type not in valid_frame.keys():
|
282 |
-
print(f"no valid {control_mask_type}")
|
283 |
-
|
284 |
-
# Get the control frame for the control mask type
|
285 |
-
control_frame = control_sam2_frame[valid_id][control_mask_type]
|
286 |
-
|
287 |
-
# Filter frames to only those which are in both valid_frame and batch_frame_set
|
288 |
-
valid_frames = []
|
289 |
-
# valid_frames = [frame for frame in mask_valid_frames if frame in control_valid_frames and frame in batch_frame_set]
|
290 |
-
for frame in mask_valid_frames:
|
291 |
-
if frame in control_valid_frames and frame in batch_frame_set:
|
292 |
-
# Check if bbox_data has 'head' or 'face' for the frame
|
293 |
-
if str(frame) in bbox_data:
|
294 |
-
frame_data = bbox_data[str(frame)]
|
295 |
-
if 'head' in frame_data or 'face' in frame_data:
|
296 |
-
valid_frames.append(frame)
|
297 |
-
|
298 |
-
# Ensure the control frame is included in the valid frames
|
299 |
-
if ensure_control_frame and (control_frame not in valid_frames):
|
300 |
-
valid_frames.append(control_frame)
|
301 |
-
|
302 |
-
# Select a random number of frames between min_frames and max_frames
|
303 |
-
num_frames_to_select = random.randint(min_frames, max_frames)
|
304 |
-
selected_frames = select_frames_with_distance_constraint(valid_frames, num_frames_to_select, min_distance,
|
305 |
-
control_frame, bbox_data, valid_id, ensure_control_frame)
|
306 |
-
|
307 |
-
# Store the selected frames as mask PNGs and bbox
|
308 |
-
selected_masks_dict[valid_id] = []
|
309 |
-
selected_bboxs_dict[valid_id] = []
|
310 |
-
|
311 |
-
# Initialize the dense_masks_dict entry for the current ID
|
312 |
-
dense_masks_dict[valid_id] = []
|
313 |
-
|
314 |
-
# Store the selected frames in the dictionary
|
315 |
-
selected_frames_dict[valid_id] = selected_frames
|
316 |
-
|
317 |
-
if dense_masks:
|
318 |
-
for frame in original_batch_frame:
|
319 |
-
mask_data_path = f"{base_dir}/{valid_id}/annotated_frame_{int(frame):05d}.png"
|
320 |
-
mask_array = np.array(Image.open(mask_data_path))
|
321 |
-
binary_mask = np.where(mask_array > 0, 1, 0).astype(np.uint8)
|
322 |
-
dense_masks_dict[valid_id].append(binary_mask)
|
323 |
-
|
324 |
-
for frame in selected_frames:
|
325 |
-
mask_data_path = f"{base_dir}/{valid_id}/annotated_frame_{frame:05d}.png"
|
326 |
-
mask_array = np.array(Image.open(mask_data_path))
|
327 |
-
binary_mask = np.where(mask_array > 0, 1, 0).astype(np.uint8)
|
328 |
-
selected_masks_dict[valid_id].append(binary_mask)
|
329 |
-
|
330 |
-
try:
|
331 |
-
for item in bbox_data[f"{frame}"]["head"]:
|
332 |
-
if item['new_track_id'] == int(valid_id):
|
333 |
-
temp_bbox = item['box']
|
334 |
-
break
|
335 |
-
except (KeyError, StopIteration):
|
336 |
-
try:
|
337 |
-
for item in bbox_data[f"{frame}"]["face"]:
|
338 |
-
if item['new_track_id'] == int(valid_id):
|
339 |
-
temp_bbox = item['box']
|
340 |
-
break
|
341 |
-
except (KeyError, StopIteration):
|
342 |
-
temp_bbox = None
|
343 |
-
|
344 |
-
selected_bboxs_dict[valid_id].append(temp_bbox)
|
345 |
-
|
346 |
-
return selected_frames_dict, selected_masks_dict, selected_bboxs_dict, dense_masks_dict
|
347 |
-
|
348 |
-
def pad_tensor(tensor, target_size, dim=0):
|
349 |
-
padding_size = target_size - tensor.size(dim)
|
350 |
-
if padding_size > 0:
|
351 |
-
pad_shape = list(tensor.shape)
|
352 |
-
pad_shape[dim] = padding_size
|
353 |
-
padding_tensor = torch.zeros(pad_shape, dtype=tensor.dtype, device=tensor.device)
|
354 |
-
return torch.cat([tensor, padding_tensor], dim=dim)
|
355 |
-
else:
|
356 |
-
return tensor[:target_size]
|
357 |
-
|
358 |
-
def crop_images(selected_frame_index, selected_bboxs_dict, video_reader, return_ori=False):
|
359 |
-
"""
|
360 |
-
Crop images based on given bounding boxes and frame indices from a video.
|
361 |
-
|
362 |
-
Args:
|
363 |
-
selected_frame_index (list): List of frame indices to be cropped.
|
364 |
-
selected_bboxs_dict (list of dict): List of dictionaries, each containing 'x1', 'y1', 'x2', 'y2' bounding box coordinates.
|
365 |
-
video_reader (VideoReader or list of numpy arrays): Video frames accessible by index, where each frame is a numpy array (H, W, C).
|
366 |
-
|
367 |
-
Returns:
|
368 |
-
list: A list of cropped images in PIL Image format.
|
369 |
-
"""
|
370 |
-
expanded_cropped_images = []
|
371 |
-
original_cropped_images = []
|
372 |
-
for frame_idx, bbox in zip(selected_frame_index, selected_bboxs_dict):
|
373 |
-
# Get the specific frame from the video reader using the frame index
|
374 |
-
frame = video_reader[frame_idx] # torch.tensor # (H, W, C)
|
375 |
-
|
376 |
-
# Extract bounding box coordinates and convert them to integers
|
377 |
-
x1, y1, x2, y2 = int(bbox['x1']), int(bbox['y1']), int(bbox['x2']), int(bbox['y2'])
|
378 |
-
# Crop to minimize the bounding box to a square
|
379 |
-
width = x2 - x1 # Calculate the width of the bounding box
|
380 |
-
height = y2 - y1 # Calculate the height of the bounding box
|
381 |
-
side_length = max(width, height) # Determine the side length of the square (max of width or height)
|
382 |
-
|
383 |
-
# Calculate the center of the bounding box
|
384 |
-
center_x = (x1 + x2) // 2
|
385 |
-
center_y = (y1 + y2) // 2
|
386 |
-
|
387 |
-
# Calculate new coordinates for the square region centered around the original bounding box
|
388 |
-
new_x1 = max(0, center_x - side_length // 2) # Ensure x1 is within image bounds
|
389 |
-
new_y1 = max(0, center_y - side_length // 2) # Ensure y1 is within image bounds
|
390 |
-
new_x2 = min(frame.shape[1], new_x1 + side_length) # Ensure x2 does not exceed image width
|
391 |
-
new_y2 = min(frame.shape[0], new_y1 + side_length) # Ensure y2 does not exceed image height
|
392 |
-
|
393 |
-
# Adjust coordinates if the cropped area is smaller than the desired side length
|
394 |
-
# Ensure final width and height are equal, keeping it a square
|
395 |
-
actual_width = new_x2 - new_x1
|
396 |
-
actual_height = new_y2 - new_y1
|
397 |
-
|
398 |
-
if actual_width < side_length:
|
399 |
-
# Adjust x1 or x2 to ensure the correct side length, while staying in bounds
|
400 |
-
if new_x1 == 0:
|
401 |
-
new_x2 = min(frame.shape[1], new_x1 + side_length)
|
402 |
-
else:
|
403 |
-
new_x1 = max(0, new_x2 - side_length)
|
404 |
-
|
405 |
-
if actual_height < side_length:
|
406 |
-
# Adjust y1 or y2 to ensure the correct side length, while staying in bounds
|
407 |
-
if new_y1 == 0:
|
408 |
-
new_y2 = min(frame.shape[0], new_y1 + side_length)
|
409 |
-
else:
|
410 |
-
new_y1 = max(0, new_y2 - side_length)
|
411 |
-
|
412 |
-
# Expand the square by 20%
|
413 |
-
expansion_ratio = 0.2 # Define the expansion ratio
|
414 |
-
expansion_amount = int(side_length * expansion_ratio) # Calculate the number of pixels to expand by
|
415 |
-
|
416 |
-
# Calculate expanded coordinates, ensuring they stay within image bounds
|
417 |
-
expanded_x1 = max(0, new_x1 - expansion_amount) # Expand left, ensuring x1 is within bounds
|
418 |
-
expanded_y1 = max(0, new_y1 - expansion_amount) # Expand up, ensuring y1 is within bounds
|
419 |
-
expanded_x2 = min(frame.shape[1], new_x2 + expansion_amount) # Expand right, ensuring x2 does not exceed bounds
|
420 |
-
expanded_y2 = min(frame.shape[0], new_y2 + expansion_amount) # Expand down, ensuring y2 does not exceed bounds
|
421 |
-
|
422 |
-
# Ensure the expanded area is still a square
|
423 |
-
expanded_width = expanded_x2 - expanded_x1
|
424 |
-
expanded_height = expanded_y2 - expanded_y1
|
425 |
-
final_side_length = min(expanded_width, expanded_height)
|
426 |
-
|
427 |
-
# Adjust to ensure square shape if necessary
|
428 |
-
if expanded_width != expanded_height:
|
429 |
-
if expanded_width > expanded_height:
|
430 |
-
expanded_x2 = expanded_x1 + final_side_length
|
431 |
-
else:
|
432 |
-
expanded_y2 = expanded_y1 + final_side_length
|
433 |
-
|
434 |
-
expanded_cropped_rgb_tensor = frame[expanded_y1:expanded_y2, expanded_x1:expanded_x2, :]
|
435 |
-
expanded_cropped_rgb = Image.fromarray(np.array(expanded_cropped_rgb_tensor)).convert('RGB')
|
436 |
-
expanded_cropped_images.append(expanded_cropped_rgb)
|
437 |
-
|
438 |
-
if return_ori:
|
439 |
-
original_cropped_rgb_tensor = frame[new_y1:new_y2, new_x1:new_x2, :]
|
440 |
-
original_cropped_rgb = Image.fromarray(np.array(original_cropped_rgb_tensor)).convert('RGB')
|
441 |
-
original_cropped_images.append(original_cropped_rgb)
|
442 |
-
return expanded_cropped_images, original_cropped_images
|
443 |
-
|
444 |
-
return expanded_cropped_images, None
|
445 |
-
|
446 |
-
def process_cropped_images(expand_images_pil, original_images_pil, target_size=(480, 480)):
|
447 |
-
"""
|
448 |
-
Process a list of cropped images in PIL format.
|
449 |
-
|
450 |
-
Parameters:
|
451 |
-
expand_images_pil (list of PIL.Image): List of cropped images in PIL format.
|
452 |
-
target_size (tuple of int): The target size for resizing images, default is (480, 480).
|
453 |
-
|
454 |
-
Returns:
|
455 |
-
torch.Tensor: A tensor containing the processed images.
|
456 |
-
"""
|
457 |
-
expand_face_imgs = []
|
458 |
-
original_face_imgs = []
|
459 |
-
if len(original_images_pil) != 0:
|
460 |
-
for expand_img, original_img in zip(expand_images_pil, original_images_pil):
|
461 |
-
expand_resized_img = expand_img.resize(target_size, Image.LANCZOS)
|
462 |
-
expand_src_img = np.array(expand_resized_img)
|
463 |
-
expand_src_img = np.transpose(expand_src_img, (2, 0, 1))
|
464 |
-
expand_src_img = torch.from_numpy(expand_src_img).unsqueeze(0).float()
|
465 |
-
expand_face_imgs.append(expand_src_img)
|
466 |
-
|
467 |
-
original_resized_img = original_img.resize(target_size, Image.LANCZOS)
|
468 |
-
original_src_img = np.array(original_resized_img)
|
469 |
-
original_src_img = np.transpose(original_src_img, (2, 0, 1))
|
470 |
-
original_src_img = torch.from_numpy(original_src_img).unsqueeze(0).float()
|
471 |
-
original_face_imgs.append(original_src_img)
|
472 |
-
|
473 |
-
expand_face_imgs = torch.cat(expand_face_imgs, dim=0)
|
474 |
-
original_face_imgs = torch.cat(original_face_imgs, dim=0)
|
475 |
-
else:
|
476 |
-
for expand_img in expand_images_pil:
|
477 |
-
expand_resized_img = expand_img.resize(target_size, Image.LANCZOS)
|
478 |
-
expand_src_img = np.array(expand_resized_img)
|
479 |
-
expand_src_img = np.transpose(expand_src_img, (2, 0, 1))
|
480 |
-
expand_src_img = torch.from_numpy(expand_src_img).unsqueeze(0).float()
|
481 |
-
expand_face_imgs.append(expand_src_img)
|
482 |
-
expand_face_imgs = torch.cat(expand_face_imgs, dim=0)
|
483 |
-
original_face_imgs = None
|
484 |
-
|
485 |
-
return expand_face_imgs, original_face_imgs
|
486 |
-
|
487 |
-
class RandomSampler(Sampler[int]):
|
488 |
-
r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.
|
489 |
-
|
490 |
-
If with replacement, then user can specify :attr:`num_samples` to draw.
|
491 |
-
|
492 |
-
Args:
|
493 |
-
data_source (Dataset): dataset to sample from
|
494 |
-
replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False``
|
495 |
-
num_samples (int): number of samples to draw, default=`len(dataset)`.
|
496 |
-
generator (Generator): Generator used in sampling.
|
497 |
-
"""
|
498 |
-
|
499 |
-
data_source: Sized
|
500 |
-
replacement: bool
|
501 |
-
|
502 |
-
def __init__(self, data_source: Sized, replacement: bool = False,
|
503 |
-
num_samples: Optional[int] = None, generator=None) -> None:
|
504 |
-
self.data_source = data_source
|
505 |
-
self.replacement = replacement
|
506 |
-
self._num_samples = num_samples
|
507 |
-
self.generator = generator
|
508 |
-
self._pos_start = 0
|
509 |
-
|
510 |
-
if not isinstance(self.replacement, bool):
|
511 |
-
raise TypeError(f"replacement should be a boolean value, but got replacement={self.replacement}")
|
512 |
-
|
513 |
-
if not isinstance(self.num_samples, int) or self.num_samples <= 0:
|
514 |
-
raise ValueError(f"num_samples should be a positive integer value, but got num_samples={self.num_samples}")
|
515 |
-
|
516 |
-
@property
|
517 |
-
def num_samples(self) -> int:
|
518 |
-
# dataset size might change at runtime
|
519 |
-
if self._num_samples is None:
|
520 |
-
return len(self.data_source)
|
521 |
-
return self._num_samples
|
522 |
-
|
523 |
-
def __iter__(self) -> Iterator[int]:
|
524 |
-
n = len(self.data_source)
|
525 |
-
if self.generator is None:
|
526 |
-
seed = int(torch.empty((), dtype=torch.int64).random_().item())
|
527 |
-
generator = torch.Generator()
|
528 |
-
generator.manual_seed(seed)
|
529 |
-
else:
|
530 |
-
generator = self.generator
|
531 |
-
|
532 |
-
if self.replacement:
|
533 |
-
for _ in range(self.num_samples // 32):
|
534 |
-
yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
|
535 |
-
yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
|
536 |
-
else:
|
537 |
-
for _ in range(self.num_samples // n):
|
538 |
-
xx = torch.randperm(n, generator=generator).tolist()
|
539 |
-
if self._pos_start >= n:
|
540 |
-
self._pos_start = 0
|
541 |
-
print("xx top 10", xx[:10], self._pos_start)
|
542 |
-
for idx in range(self._pos_start, n):
|
543 |
-
yield xx[idx]
|
544 |
-
self._pos_start = (self._pos_start + 1) % n
|
545 |
-
self._pos_start = 0
|
546 |
-
yield from torch.randperm(n, generator=generator).tolist()[:self.num_samples % n]
|
547 |
-
|
548 |
-
def __len__(self) -> int:
|
549 |
-
return self.num_samples
|
550 |
-
|
551 |
-
class SequentialSampler(Sampler[int]):
|
552 |
-
r"""Samples elements sequentially, always in the same order.
|
553 |
-
|
554 |
-
Args:
|
555 |
-
data_source (Dataset): dataset to sample from
|
556 |
-
"""
|
557 |
-
|
558 |
-
data_source: Sized
|
559 |
-
|
560 |
-
def __init__(self, data_source: Sized) -> None:
|
561 |
-
self.data_source = data_source
|
562 |
-
self._pos_start = 0
|
563 |
-
|
564 |
-
def __iter__(self) -> Iterator[int]:
|
565 |
-
n = len(self.data_source)
|
566 |
-
for idx in range(self._pos_start, n):
|
567 |
-
yield idx
|
568 |
-
self._pos_start = (self._pos_start + 1) % n
|
569 |
-
self._pos_start = 0
|
570 |
-
|
571 |
-
def __len__(self) -> int:
|
572 |
-
return len(self.data_source)
|
573 |
-
|
574 |
-
class ConsisID_Dataset(Dataset):
|
575 |
-
def __init__(
|
576 |
-
self,
|
577 |
-
instance_data_root: Optional[str] = None,
|
578 |
-
id_token: Optional[str] = None,
|
579 |
-
height=480,
|
580 |
-
width=640,
|
581 |
-
max_num_frames=49,
|
582 |
-
sample_stride=3,
|
583 |
-
skip_frames_start_percent=0.0,
|
584 |
-
skip_frames_end_percent=1.0,
|
585 |
-
skip_frames_start=0,
|
586 |
-
skip_frames_end=0,
|
587 |
-
text_drop_ratio=-1,
|
588 |
-
is_train_face=False,
|
589 |
-
is_single_face=False,
|
590 |
-
miss_tolerance=6,
|
591 |
-
min_distance=3,
|
592 |
-
min_frames=1,
|
593 |
-
max_frames=5,
|
594 |
-
is_cross_face=False,
|
595 |
-
is_reserve_face=False,
|
596 |
-
):
|
597 |
-
self.id_token = id_token or ""
|
598 |
-
|
599 |
-
# ConsisID
|
600 |
-
self.skip_frames_start_percent = skip_frames_start_percent
|
601 |
-
self.skip_frames_end_percent = skip_frames_end_percent
|
602 |
-
self.skip_frames_start = skip_frames_start
|
603 |
-
self.skip_frames_end = skip_frames_end
|
604 |
-
self.is_train_face = is_train_face
|
605 |
-
self.is_single_face = is_single_face
|
606 |
-
|
607 |
-
if is_train_face:
|
608 |
-
self.miss_tolerance = miss_tolerance
|
609 |
-
self.min_distance = min_distance
|
610 |
-
self.min_frames = min_frames
|
611 |
-
self.max_frames = max_frames
|
612 |
-
self.is_cross_face = is_cross_face
|
613 |
-
self.is_reserve_face = is_reserve_face
|
614 |
-
|
615 |
-
# Loading annotations from files
|
616 |
-
print(f"loading annotations from {instance_data_root} ...")
|
617 |
-
with open(instance_data_root, 'r') as f:
|
618 |
-
folder_anno = [i.strip().split(',') for i in f.readlines() if len(i.strip()) > 0]
|
619 |
-
|
620 |
-
self.instance_prompts = []
|
621 |
-
self.instance_video_paths = []
|
622 |
-
self.instance_annotation_base_paths = []
|
623 |
-
for sub_root, anno, anno_base in tqdm(folder_anno):
|
624 |
-
print(anno)
|
625 |
-
self.instance_annotation_base_paths.append(anno_base)
|
626 |
-
with open(anno, 'r') as f:
|
627 |
-
sub_list = json.load(f)
|
628 |
-
for i in tqdm(sub_list):
|
629 |
-
path = os.path.join(sub_root, os.path.basename(i['path']))
|
630 |
-
cap = i.get('cap', None)
|
631 |
-
fps = i.get('fps', 0)
|
632 |
-
duration = i.get('duration', 0)
|
633 |
-
|
634 |
-
if fps * duration < 49.0:
|
635 |
-
continue
|
636 |
-
|
637 |
-
self.instance_prompts.append(cap)
|
638 |
-
self.instance_video_paths.append(path)
|
639 |
-
|
640 |
-
self.num_instance_videos = len(self.instance_video_paths)
|
641 |
-
|
642 |
-
self.text_drop_ratio = text_drop_ratio
|
643 |
-
|
644 |
-
# Video params
|
645 |
-
self.sample_stride = sample_stride
|
646 |
-
self.max_num_frames = max_num_frames
|
647 |
-
self.height = height
|
648 |
-
self.width = width
|
649 |
-
|
650 |
-
def _get_frame_indices_adjusted(self, video_length, n_frames):
|
651 |
-
indices = list(range(video_length))
|
652 |
-
additional_frames_needed = n_frames - video_length
|
653 |
-
|
654 |
-
repeat_indices = []
|
655 |
-
for i in range(additional_frames_needed):
|
656 |
-
index_to_repeat = i % video_length
|
657 |
-
repeat_indices.append(indices[index_to_repeat])
|
658 |
-
|
659 |
-
all_indices = indices + repeat_indices
|
660 |
-
all_indices.sort()
|
661 |
-
|
662 |
-
return all_indices
|
663 |
-
|
664 |
-
|
665 |
-
def _generate_frame_indices(self, video_length, n_frames, sample_stride, skip_frames_start_percent=0.0, skip_frames_end_percent=1.0, skip_frames_start=0, skip_frames_end=0):
|
666 |
-
if skip_frames_start_percent != 0.0 or skip_frames_end_percent != 1.0:
|
667 |
-
print("use skip frame percent")
|
668 |
-
valid_start = int(video_length * skip_frames_start_percent)
|
669 |
-
valid_end = int(video_length * skip_frames_end_percent)
|
670 |
-
elif skip_frames_start != 0 or skip_frames_end != 0:
|
671 |
-
print("use skip frame")
|
672 |
-
valid_start = skip_frames_start
|
673 |
-
valid_end = video_length - skip_frames_end
|
674 |
-
else:
|
675 |
-
print("no use skip frame")
|
676 |
-
valid_start = 0
|
677 |
-
valid_end = video_length
|
678 |
-
|
679 |
-
adjusted_length = valid_end - valid_start
|
680 |
-
|
681 |
-
if adjusted_length <= 0:
|
682 |
-
print(f"video_length: {video_length}, adjusted_length: {adjusted_length}, valid_start:{valid_start}, skip_frames_end: {valid_end}")
|
683 |
-
raise ValueError("Skipping too many frames results in no frames left to sample.")
|
684 |
-
|
685 |
-
if video_length <= n_frames:
|
686 |
-
return self._get_frame_indices_adjusted(video_length, n_frames)
|
687 |
-
else:
|
688 |
-
# clip_length = min(video_length, (n_frames - 1) * sample_stride + 1)
|
689 |
-
# start_idx = random.randint(0, video_length - clip_length)
|
690 |
-
# frame_indices = np.linspace(start_idx, start_idx + clip_length - 1, n_frames, dtype=int).tolist()
|
691 |
-
|
692 |
-
clip_length = min(adjusted_length, (n_frames - 1) * sample_stride + 1)
|
693 |
-
start_idx = random.randint(valid_start, valid_end - clip_length)
|
694 |
-
frame_indices = np.linspace(start_idx, start_idx + clip_length - 1, n_frames, dtype=int).tolist()
|
695 |
-
return frame_indices
|
696 |
-
|
697 |
-
def _short_resize_and_crop(self, frames, target_width, target_height):
|
698 |
-
"""
|
699 |
-
Resize frames and crop to the specified size.
|
700 |
-
|
701 |
-
Args:
|
702 |
-
frames (torch.Tensor): Input frames of shape [T, H, W, C].
|
703 |
-
target_width (int): Desired width.
|
704 |
-
target_height (int): Desired height.
|
705 |
-
|
706 |
-
Returns:
|
707 |
-
torch.Tensor: Cropped frames of shape [T, target_height, target_width, C].
|
708 |
-
"""
|
709 |
-
T, H, W, C = frames.shape
|
710 |
-
aspect_ratio = W / H
|
711 |
-
|
712 |
-
# Determine new dimensions ensuring they are at least target size
|
713 |
-
if aspect_ratio > target_width / target_height:
|
714 |
-
new_width = target_width
|
715 |
-
new_height = int(target_width / aspect_ratio)
|
716 |
-
if new_height < target_height:
|
717 |
-
new_height = target_height
|
718 |
-
new_width = int(target_height * aspect_ratio)
|
719 |
-
else:
|
720 |
-
new_height = target_height
|
721 |
-
new_width = int(target_height * aspect_ratio)
|
722 |
-
if new_width < target_width:
|
723 |
-
new_width = target_width
|
724 |
-
new_height = int(target_width / aspect_ratio)
|
725 |
-
|
726 |
-
resize_transform = transforms.Resize((new_height, new_width))
|
727 |
-
crop_transform = transforms.CenterCrop((target_height, target_width))
|
728 |
-
|
729 |
-
frames_tensor = frames.permute(0, 3, 1, 2) # (T, H, W, C) -> (T, C, H, W)
|
730 |
-
resized_frames = resize_transform(frames_tensor)
|
731 |
-
cropped_frames = crop_transform(resized_frames)
|
732 |
-
sample = cropped_frames.permute(0, 2, 3, 1)
|
733 |
-
|
734 |
-
return sample
|
735 |
-
|
736 |
-
def _resize_with_aspect_ratio(self, frames, target_width, target_height):
|
737 |
-
"""
|
738 |
-
Resize frames while maintaining the aspect ratio by padding or cropping.
|
739 |
-
|
740 |
-
Args:
|
741 |
-
frames (torch.Tensor): Input frames of shape [T, H, W, C].
|
742 |
-
target_width (int): Desired width.
|
743 |
-
target_height (int): Desired height.
|
744 |
-
|
745 |
-
Returns:
|
746 |
-
torch.Tensor: Resized and padded frames of shape [T, target_height, target_width, C].
|
747 |
-
"""
|
748 |
-
T, frame_height, frame_width, C = frames.shape
|
749 |
-
aspect_ratio = frame_width / frame_height # 1.77, 1280 720 -> 720 406
|
750 |
-
target_aspect_ratio = target_width / target_height # 1.50, 720 480 ->
|
751 |
-
|
752 |
-
# If the frame is wider than the target, resize based on width
|
753 |
-
if aspect_ratio > target_aspect_ratio:
|
754 |
-
new_width = target_width
|
755 |
-
new_height = int(target_width / aspect_ratio)
|
756 |
-
else:
|
757 |
-
new_height = target_height
|
758 |
-
new_width = int(target_height * aspect_ratio)
|
759 |
-
|
760 |
-
# Resize using batch processing
|
761 |
-
frames = frames.permute(0, 3, 1, 2) # [T, C, H, W]
|
762 |
-
frames = F.interpolate(frames, size=(new_height, new_width), mode='bilinear', align_corners=False)
|
763 |
-
|
764 |
-
# Calculate padding
|
765 |
-
pad_top = (target_height - new_height) // 2
|
766 |
-
pad_bottom = target_height - new_height - pad_top
|
767 |
-
pad_left = (target_width - new_width) // 2
|
768 |
-
pad_right = target_width - new_width - pad_left
|
769 |
-
|
770 |
-
# Apply padding
|
771 |
-
frames = F.pad(frames, (pad_left, pad_right, pad_top, pad_bottom), mode='constant', value=0)
|
772 |
-
|
773 |
-
frames = frames.permute(0, 2, 3, 1) # [T, H, W, C]
|
774 |
-
|
775 |
-
return frames
|
776 |
-
|
777 |
-
|
778 |
-
def _save_frame(self, frame, name="1.png"):
|
779 |
-
# [H, W, C] -> [C, H, W]
|
780 |
-
img = frame
|
781 |
-
img = img.permute(2, 0, 1)
|
782 |
-
to_pil = ToPILImage()
|
783 |
-
img = to_pil(img)
|
784 |
-
img.save(name)
|
785 |
-
|
786 |
-
|
787 |
-
def _save_video(self, torch_frames, name="output.mp4"):
|
788 |
-
from moviepy.editor import ImageSequenceClip
|
789 |
-
frames_np = torch_frames.cpu().numpy()
|
790 |
-
if frames_np.dtype != 'uint8':
|
791 |
-
frames_np = frames_np.astype('uint8')
|
792 |
-
frames_list = [frame for frame in frames_np]
|
793 |
-
desired_fps = 24
|
794 |
-
clip = ImageSequenceClip(frames_list, fps=desired_fps)
|
795 |
-
clip.write_videofile(name, codec="libx264")
|
796 |
-
|
797 |
-
|
798 |
-
def get_batch(self, idx):
|
799 |
-
decord.bridge.set_bridge("torch")
|
800 |
-
|
801 |
-
video_dir = self.instance_video_paths[idx]
|
802 |
-
text = self.instance_prompts[idx]
|
803 |
-
|
804 |
-
train_transforms = transforms.Compose(
|
805 |
-
[
|
806 |
-
transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0),
|
807 |
-
]
|
808 |
-
)
|
809 |
-
|
810 |
-
with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
|
811 |
-
video_num_frames = len(video_reader)
|
812 |
-
|
813 |
-
if self.is_train_face:
|
814 |
-
reserve_face_imgs = None
|
815 |
-
file_base_name = os.path.basename(video_dir.replace(".mp4", ""))
|
816 |
-
|
817 |
-
anno_base_path = self.instance_annotation_base_paths[idx]
|
818 |
-
valid_frame_path = os.path.join(anno_base_path, "track_masks_data", file_base_name, "valid_frame.json")
|
819 |
-
control_sam2_frame_path = os.path.join(anno_base_path, "track_masks_data", file_base_name, "control_sam2_frame.json")
|
820 |
-
corresponding_data_path = os.path.join(anno_base_path, "track_masks_data", file_base_name, "corresponding_data.json")
|
821 |
-
masks_data_path = os.path.join(anno_base_path, "track_masks_data", file_base_name, "tracking_mask_results")
|
822 |
-
bboxs_data_path = os.path.join(anno_base_path, "refine_bbox_jsons", f"{file_base_name}.json")
|
823 |
-
|
824 |
-
with open(corresponding_data_path, 'r') as f:
|
825 |
-
corresponding_data = json.load(f)
|
826 |
-
|
827 |
-
with open(control_sam2_frame_path, 'r') as f:
|
828 |
-
control_sam2_frame = json.load(f)
|
829 |
-
|
830 |
-
with open(valid_frame_path, 'r') as f:
|
831 |
-
valid_frame = json.load(f)
|
832 |
-
|
833 |
-
with open(bboxs_data_path, 'r') as f:
|
834 |
-
bbox_data = json.load(f)
|
835 |
-
|
836 |
-
if self.is_single_face:
|
837 |
-
if len(corresponding_data) != 1:
|
838 |
-
raise ValueError(f"Using single face, but {idx} is multi person.")
|
839 |
-
|
840 |
-
# get random valid id
|
841 |
-
valid_ids = []
|
842 |
-
backup_ids = []
|
843 |
-
for id_key, data in corresponding_data.items():
|
844 |
-
if 'face' in data and 'head' in data:
|
845 |
-
valid_ids.append(id_key)
|
846 |
-
|
847 |
-
valid_id = random.choice(valid_ids) if valid_ids else (random.choice(backup_ids) if backup_ids else None)
|
848 |
-
if valid_id is None:
|
849 |
-
raise ValueError("No valid ID found: both valid_ids and backup_ids are empty.")
|
850 |
-
|
851 |
-
# get video
|
852 |
-
total_index = list(range(video_num_frames))
|
853 |
-
batch_index, _ = generate_frame_indices_for_face(self.max_num_frames, self.sample_stride, valid_frame[valid_id],
|
854 |
-
self.miss_tolerance, self.skip_frames_start_percent, self.skip_frames_end_percent,
|
855 |
-
self.skip_frames_start, self.skip_frames_end)
|
856 |
-
|
857 |
-
if self.is_cross_face:
|
858 |
-
remaining_batch_index_index = [i for i in total_index if i not in batch_index]
|
859 |
-
try:
|
860 |
-
selected_frame_index, selected_masks_dict, selected_bboxs_dict, dense_masks_dict = select_mask_frames_from_index(
|
861 |
-
remaining_batch_index_index,
|
862 |
-
batch_index, valid_id,
|
863 |
-
corresponding_data, control_sam2_frame,
|
864 |
-
valid_frame[valid_id], bbox_data, masks_data_path,
|
865 |
-
min_distance=self.min_distance, min_frames=self.min_frames,
|
866 |
-
max_frames=self.max_frames, dense_masks=True,
|
867 |
-
ensure_control_frame=False,
|
868 |
-
)
|
869 |
-
except:
|
870 |
-
selected_frame_index, selected_masks_dict, selected_bboxs_dict, dense_masks_dict = select_mask_frames_from_index(
|
871 |
-
batch_index,
|
872 |
-
batch_index, valid_id,
|
873 |
-
corresponding_data, control_sam2_frame,
|
874 |
-
valid_frame[valid_id], bbox_data, masks_data_path,
|
875 |
-
min_distance=self.min_distance, min_frames=self.min_frames,
|
876 |
-
max_frames=self.max_frames, dense_masks=True,
|
877 |
-
ensure_control_frame=False,
|
878 |
-                    )
-            else:
-                selected_frame_index, selected_masks_dict, selected_bboxs_dict, dense_masks_dict = select_mask_frames_from_index(
-                    batch_index,
-                    batch_index, valid_id,
-                    corresponding_data, control_sam2_frame,
-                    valid_frame[valid_id], bbox_data, masks_data_path,
-                    min_distance=self.min_distance, min_frames=self.min_frames,
-                    max_frames=self.max_frames, dense_masks=True,
-                    ensure_control_frame=True,
-                )
-                if self.is_reserve_face:
-                    reserve_frame_index, _, reserve_bboxs_dict, _ = select_mask_frames_from_index(
-                        batch_index,
-                        batch_index, valid_id,
-                        corresponding_data, control_sam2_frame,
-                        valid_frame[valid_id], bbox_data, masks_data_path,
-                        min_distance=3, min_frames=4,
-                        max_frames=4, dense_masks=False,
-                        ensure_control_frame=False,
-                    )
-
-            # get mask and aligned_face_img
-            selected_frame_index = selected_frame_index[valid_id]
-            valid_frame = valid_frame[valid_id]
-            selected_masks_dict = selected_masks_dict[valid_id]
-            selected_bboxs_dict = selected_bboxs_dict[valid_id]
-            dense_masks_dict = dense_masks_dict[valid_id]
-
-            if self.is_reserve_face:
-                reserve_frame_index = reserve_frame_index[valid_id]
-                reserve_bboxs_dict = reserve_bboxs_dict[valid_id]
-
-            selected_masks_tensor = torch.stack([torch.tensor(mask) for mask in selected_masks_dict])
-            temp_dense_masks_tensor = torch.stack([torch.tensor(mask) for mask in dense_masks_dict])
-            dense_masks_tensor = self._short_resize_and_crop(temp_dense_masks_tensor.unsqueeze(-1), self.width, self.height).squeeze(-1)  # [T, H, W] -> [T, H, W, 1] -> [T, H, W]
-
-            expand_images_pil, original_images_pil = crop_images(selected_frame_index, selected_bboxs_dict, video_reader, return_ori=True)
-            expand_face_imgs, original_face_imgs = process_cropped_images(expand_images_pil, original_images_pil, target_size=(480, 480))
-            if self.is_reserve_face:
-                reserve_images_pil, _ = crop_images(reserve_frame_index, reserve_bboxs_dict, video_reader, return_ori=False)
-                reserve_face_imgs, _ = process_cropped_images(reserve_images_pil, [], target_size=(480, 480))
-
-            if len(expand_face_imgs) == 0 or len(original_face_imgs) == 0:
-                raise ValueError("No face detected in input image pool")
-
-            # post process id related data
-            expand_face_imgs = pad_tensor(expand_face_imgs, self.max_frames, dim=0)
-            original_face_imgs = pad_tensor(original_face_imgs, self.max_frames, dim=0)
-            selected_frame_index = torch.tensor(selected_frame_index)  # torch.Size([15, 13]) [N1]
-            selected_frame_index = pad_tensor(selected_frame_index, self.max_frames, dim=0)
-        else:
-            batch_index = self._generate_frame_indices(video_num_frames, self.max_num_frames, self.sample_stride,
-                                                       self.skip_frames_start_percent, self.skip_frames_end_percent,
-                                                       self.skip_frames_start, self.skip_frames_end)
-
-        try:
-            frames = video_reader.get_batch(batch_index)  # torch [T, H, W, C]
-            frames = self._short_resize_and_crop(frames, self.width, self.height)  # [T, H, W, C]
-        except FunctionTimedOut:
-            raise ValueError(f"Read {idx} timeout.")
-        except Exception as e:
-            raise ValueError(f"Failed to extract frames from video. Error is {e}.")
-
-        # Apply training transforms in batch
-        frames = frames.float()
-        frames = train_transforms(frames)
-        pixel_values = frames.permute(0, 3, 1, 2).contiguous()  # [T, C, H, W]
-        del video_reader
-
-        # Randomly drop the text prompt
-        if random.random() < self.text_drop_ratio:
-            text = ''
-
-        if self.is_train_face:
-            return pixel_values, text, 'video', video_dir, expand_face_imgs, dense_masks_tensor, selected_frame_index, reserve_face_imgs, original_face_imgs
-        else:
-            return pixel_values, text, 'video', video_dir
-
-    def __len__(self):
-        return self.num_instance_videos
-
-    def __getitem__(self, idx):
-        sample = {}
-        if self.is_train_face:
-            pixel_values, cap, data_type, video_dir, expand_face_imgs, dense_masks_tensor, selected_frame_index, reserve_face_imgs, original_face_imgs = self.get_batch(idx)
-            sample["instance_prompt"] = self.id_token + cap
-            sample["instance_video"] = pixel_values
-            sample["video_path"] = video_dir
-            if self.is_train_face:
-                sample["expand_face_imgs"] = expand_face_imgs
-                sample["dense_masks_tensor"] = dense_masks_tensor
-                sample["selected_frame_index"] = selected_frame_index
-                if reserve_face_imgs is not None:
-                    sample["reserve_face_imgs"] = reserve_face_imgs
-                if original_face_imgs is not None:
-                    sample["original_face_imgs"] = original_face_imgs
-        else:
-            pixel_values, cap, data_type, video_dir = self.get_batch(idx)
-            sample["instance_prompt"] = self.id_token + cap
-            sample["instance_video"] = pixel_values
-            sample["video_path"] = video_dir
-        return sample
-
-        # while True:
-        #     sample = {}
-        #     try:
-        #         if self.is_train_face:
-        #             pixel_values, cap, data_type, video_dir, expand_face_imgs, dense_masks_tensor, selected_frame_index, reserve_face_imgs, original_face_imgs = self.get_batch(idx)
-        #             sample["instance_prompt"] = self.id_token + cap
-        #             sample["instance_video"] = pixel_values
-        #             sample["video_path"] = video_dir
-        #             if self.is_train_face:
-        #                 sample["expand_face_imgs"] = expand_face_imgs
-        #                 sample["dense_masks_tensor"] = dense_masks_tensor
-        #                 sample["selected_frame_index"] = selected_frame_index
-        #                 if reserve_face_imgs is not None:
-        #                     sample["reserve_face_imgs"] = reserve_face_imgs
-        #                 if original_face_imgs is not None:
-        #                     sample["original_face_imgs"] = original_face_imgs
-        #         else:
-        #             pixel_values, cap, data_type, video_dir, = self.get_batch(idx)
-        #             sample["instance_prompt"] = self.id_token + cap
-        #             sample["instance_video"] = pixel_values
-        #             sample["video_path"] = video_dir
-        #         break
-        #     except Exception as e:
-        #         error_message = str(e)
-        #         video_path = self.instance_video_paths[idx % len(self.instance_video_paths)]
-        #         print(error_message, video_path)
-        #         log_error_to_file(error_message, video_path)
-        #         idx = random.randint(0, self.num_instance_videos - 1)
-        # return sample
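For reference, the commented-out block at the end of the deleted dataloader is a retry-on-failure variant of __getitem__: if a sample fails to load, the error is logged and a random replacement index is drawn. A minimal standalone sketch of that pattern follows; the names RetryingDataset and load are hypothetical stand-ins and not part of this repository.

import random

class RetryingDataset:
    # Toy dataset illustrating the retry-on-error loading pattern from the deleted code.
    def __init__(self, samples):
        self.samples = samples  # stand-in for instance_video_paths

    def load(self, idx):
        # Stand-in for get_batch(idx); a real video loader may raise on corrupt files.
        return self.samples[idx]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        while True:
            try:
                return self.load(idx)
            except Exception as e:  # broad catch, mirroring the original loop
                print(f"Failed to load sample {idx}: {e}")
                idx = random.randint(0, len(self.samples) - 1)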
util/deepspeed_configs/accelerate_config_machine_multi.yaml
DELETED
@@ -1,18 +0,0 @@
-compute_environment: LOCAL_MACHINE
-distributed_type: DEEPSPEED
-deepspeed_config:
-  deepspeed_config_file: util/deepspeed_configs/zero_stage2_config.json
-  deepspeed_hostfile: util/deepspeed_configs/hostfile.txt
-fsdp_config: {}
-machine_rank: 0
-main_process_ip: 100.64.24.6
-main_process_port: 12343
-main_training_function: main
-num_machines: 2
-num_processes: 16
-rdzv_backend: static
-same_network: true
-tpu_env: []
-tpu_use_cluster: false
-tpu_use_sudo: false
-use_cpu: false
util/deepspeed_configs/accelerate_config_machine_single.yaml
DELETED
@@ -1,13 +0,0 @@
-compute_environment: LOCAL_MACHINE
-distributed_type: DEEPSPEED
-deepspeed_config:
-  deepspeed_config_file: util/deepspeed_configs/zero_stage2_config.json
-fsdp_config: {}
-machine_rank: 0
-main_process_ip: null
-main_process_port: 12345
-main_training_function: main
-num_machines: 1
-num_processes: 8
-gpu_ids: 0,1,2,3,4,5,6,7
-use_cpu: false
util/deepspeed_configs/hostfile.txt
DELETED
@@ -1,2 +0,0 @@
-[email protected] slots=8
-[email protected] slots=8
util/deepspeed_configs/zero_stage2_config.json
DELETED
@@ -1,17 +0,0 @@
-{
-    "bf16": {
-        "enabled": true
-    },
-    "train_micro_batch_size_per_gpu": "auto",
-    "train_batch_size": "auto",
-    "gradient_clipping": 1.0,
-    "gradient_accumulation_steps": "auto",
-    "dump_state": true,
-    "zero_optimization": {
-        "stage": 2,
-        "overlap_comm": true,
-        "contiguous_gradients": true,
-        "sub_group_size": 1e9,
-        "reduce_bucket_size": 5e8
-    }
-}
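For reference, a ZeRO stage-2 JSON like the one removed above can be sanity-checked before launching training; the snippet below is only a sketch, and the hard-coded path assumes the file is still present locally.

import json

# Load the DeepSpeed config and confirm the key settings that would be passed to the launcher.
with open("util/deepspeed_configs/zero_stage2_config.json") as f:
    ds_config = json.load(f)

assert ds_config["zero_optimization"]["stage"] == 2
print("bf16 enabled:", ds_config["bf16"]["enabled"])
print("reduce bucket size:", ds_config["zero_optimization"]["reduce_bucket_size"])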