from torch.utils.data import Dataset
import torch
from torch import nn, Tensor
import torch.nn.functional as F
import torchaudio
import os
import logging
from torchvision.models import resnet50, ResNet50_Weights
from PIL import Image
from time import strftime
import math
import numpy as np
import moviepy.editor as mpe
class VideoDataset(Dataset):
def __init__(self, data_dir):
self.data_dir = data_dir
self.data_map = []
dir_map = os.listdir(data_dir)
for d in dir_map:
name, extension = os.path.splitext(d)
if extension == ".mp4":
self.data_map.append({"video": os.path.join(data_dir, d)})
def __len__(self):
return len(self.data_map)
def __getitem__(self, idx):
return self.data_map[idx]["video"]
# Input: a batch of video file paths; output: T5-like conditioning embeddings for the music generation model.
class VideoToT5(nn.Module):
def __init__(self,
device: str,
video_extraction_framerate: int,
encoder_input_dimension: int,
encoder_output_dimension: int,
encoder_heads: int,
encoder_dim_feedforward: int,
encoder_layers: int
):
super().__init__()
self.video_extraction_framerate = video_extraction_framerate
self.video_feature_extractor = VideoFeatureExtractor(video_extraction_framerate=video_extraction_framerate,
device=device)
self.video_encoder = VideoEncoder(
device,
encoder_input_dimension,
encoder_output_dimension,
encoder_heads,
encoder_dim_feedforward,
encoder_layers
)
    def forward(self, video_paths: list[str]):
        image_embeddings = []
        for video_path in video_paths:
            video = mpe.VideoFileClip(video_path)
            video_embedding = self.video_feature_extractor(video)
            image_embeddings.append(video_embedding)
        # resulting shape: [batch_size, video_extraction_framerate * 30, resnet_output_dimension]
        video_embeddings = torch.stack(image_embeddings)
        # averaging all frame embeddings into a single video embedding was tried, but gave worse results:
        # video_embeddings = torch.mean(video_embeddings, 0, True)
        # encoder output shape: [batch_size, num_tokens, t5_embedding_size]
        t5_embeddings = self.video_encoder(video_embeddings)
        return t5_embeddings
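# Usage sketch for VideoToT5 (illustrative only; the device, file name, and hyperparameter
# values below are assumptions, not values taken from a training configuration):
#
#   video_to_t5 = VideoToT5(
#       device="cuda",
#       video_extraction_framerate=1,
#       encoder_input_dimension=2048,    # ResNet-50 feature dimension
#       encoder_output_dimension=1024,   # conditioning dimension expected downstream
#       encoder_heads=8,
#       encoder_dim_feedforward=2048,
#       encoder_layers=4,
#   )
#   t5_embeddings = video_to_t5(["example.mp4"])
#   # shape: [1, 30 * video_extraction_framerate, encoder_output_dimension]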
class VideoEncoder(nn.Module):
def __init__(self,
device: str,
encoder_input_dimension: int,
encoder_output_dimension: int,
encoder_heads: int,
encoder_dim_feedforward: int,
encoder_layers: int
):
super().__init__()
self.device = device
self.encoder = (nn.TransformerEncoder(
nn.TransformerEncoderLayer(
d_model=encoder_input_dimension,
nhead=encoder_heads,
dim_feedforward=encoder_dim_feedforward
),
num_layers=encoder_layers,
)
).to(device)
# linear layer to match T5 embedding dimension
self.linear = (nn.Linear(
in_features=encoder_input_dimension,
out_features=encoder_output_dimension)
.to(device))
def forward(self, x):
assert x.dim() == 3
x = torch.transpose(x, 0, 1) # encoder expects [sequence_length, batch_size, embedding_dimension]
x = self.encoder(x) # encoder forward pass
x = self.linear(x) # forward pass through the linear layer
x = torch.transpose(x, 0, 1) # shape: [batch_size, sequence_length, embedding_dimension]
return x
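# Shape sketch for VideoEncoder (illustrative values; a batch of 2 videos with 30 frame
# embeddings of dimension 2048 is mapped to 30 tokens of dimension 1024):
#
#   encoder = VideoEncoder("cpu", encoder_input_dimension=2048, encoder_output_dimension=1024,
#                          encoder_heads=8, encoder_dim_feedforward=2048, encoder_layers=2)
#   y = encoder(torch.rand(2, 30, 2048))  # y.shape == torch.Size([2, 30, 1024])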
class VideoFeatureExtractor(nn.Module):
def __init__(self,
device: str,
video_extraction_framerate: int = 1,
resnet_output_dimension: int = 2048):
super().__init__()
self.device = device
        # ResNet-50 pretrained on ImageNet, used as a per-frame feature extractor
        self.resnet = resnet50(weights="IMAGENET1K_V2").eval()
        self.resnet = torch.nn.Sequential(*(list(self.resnet.children())[:-1])).to(device)  # drop the final classification layer
self.resnet_preprocessor = ResNet50_Weights.DEFAULT.transforms().to(device)
self.video_extraction_framerate = video_extraction_framerate # setting the fps at which the video is processed
self.positional_encoder = PositionalEncoding(resnet_output_dimension).to(device)
    def forward(self, video: mpe.VideoFileClip):
        embeddings = []
        # sample 30 seconds of video at the configured extraction framerate
        for frame_index in range(0, 30 * self.video_extraction_framerate):
            frame = video.get_frame(frame_index / self.video_extraction_framerate)  # frame at this timestamp (in seconds) as numpy array
            frame = Image.fromarray(frame)  # create PIL image from numpy array
            frame = self.resnet_preprocessor(frame)  # preprocess image
            frame = frame.to(self.device)
            frame = frame.unsqueeze(0)  # add a batch dimension
            frame = self.resnet(frame).squeeze()  # ResNet forward pass, squeeze away batch and spatial dimensions
            embeddings.append(frame)  # collect frame embeddings
        embeddings = torch.stack(embeddings)  # stack all frame embeddings into one video embedding
        embeddings = embeddings.unsqueeze(1)  # shape: [sequence_length, 1, embedding_dimension]
        embeddings = self.positional_encoder(embeddings)  # apply positional encoding over the frame sequence
        embeddings = embeddings.squeeze(1)  # remove the batch dimension again
        return embeddings
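# Usage sketch for VideoFeatureExtractor (illustrative; the clip path is a placeholder and the
# clip is assumed to be at least 30 seconds long):
#
#   extractor = VideoFeatureExtractor(device="cpu", video_extraction_framerate=1)
#   clip = mpe.VideoFileClip("example.mp4")
#   frame_embeddings = extractor(clip)  # shape: [30 * video_extraction_framerate, 2048]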
# from https://pytorch.org/tutorials/beginner/transformer_tutorial.html
class PositionalEncoding(nn.Module):
def __init__(self, d_model: int, dropout: float = 0.1, max_length: int = 5000):
super().__init__()
self.dropout = nn.Dropout(p=dropout)
position = torch.arange(max_length).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
pe = torch.zeros(max_length, 1, d_model)
pe[:, 0, 0::2] = torch.sin(position * div_term)
pe[:, 0, 1::2] = torch.cos(position * div_term)
self.register_buffer('pe', pe)
def forward(self, x: Tensor) -> Tensor:
x = x + self.pe[:x.size(0)]
return self.dropout(x)
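# The buffer registered above holds the standard sinusoidal encoding from "Attention Is All You Need":
#   PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
#   PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))
# Usage sketch (illustrative shapes): PositionalEncoding(d_model=2048)(torch.zeros(30, 1, 2048))
# returns a tensor of the same shape with the encoding added and dropout applied.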
def freeze_model(model: nn.Module):
for param in model.parameters():
param.requires_grad = False
model.eval()
def split_dataset_randomly(dataset, validation_split: float, test_split: float, seed: int = None):
    dataset_size = len(dataset)
    indices = list(range(dataset_size))
    datapoints_validation = int(np.floor(validation_split * dataset_size))
    datapoints_testing = int(np.floor(test_split * dataset_size))
    if seed is not None:
        np.random.seed(seed)
    np.random.shuffle(indices)  # in-place operation
    # partition the shuffled indices: [testing | validation | training]
    testing = indices[:datapoints_testing]
    validation = indices[datapoints_testing:datapoints_testing + datapoints_validation]
    training = indices[datapoints_testing + datapoints_validation:]
    assert len(validation) == datapoints_validation, "Validation set length incorrect"
    assert len(testing) == datapoints_testing, "Testing set length incorrect"
    assert len(training) == dataset_size - (datapoints_validation + datapoints_testing), "Training set length incorrect"
    assert not any(item in training for item in validation), "Training and Validation overlap"
    assert not any(item in training for item in testing), "Training and Testing overlap"
    assert not any(item in validation for item in testing), "Validation and Testing overlap"
    return training, validation, testing
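# Usage sketch (illustrative numbers): for a dataset of 100 items, validation_split=0.1 and
# test_split=0.1 yield index lists of length 80 / 10 / 10:
#
#   dataset = VideoDataset("path/to/videos")  # placeholder path
#   train_idx, val_idx, test_idx = split_dataset_randomly(dataset, 0.1, 0.1, seed=42)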
### adapted from the private helper _compute_cross_entropy in audiocraft's MusicGen solver
def compute_cross_entropy(logits: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor):
"""Compute cross entropy between multi-codebook targets and model's logits.
The cross entropy is computed per codebook to provide codebook-level cross entropy.
Valid timesteps for each of the codebook are pulled from the mask, where invalid
timesteps are set to 0.
Args:
logits (torch.Tensor): Model's logits of shape [B, K, T, card].
targets (torch.Tensor): Target codes, of shape [B, K, T].
mask (torch.Tensor): Mask for valid target codes, of shape [B, K, T].
Returns:
ce (torch.Tensor): Cross entropy averaged over the codebooks
ce_per_codebook (list of torch.Tensor): Cross entropy per codebook (detached).
"""
B, K, T = targets.shape
assert logits.shape[:-1] == targets.shape
assert mask.shape == targets.shape
ce = torch.zeros([], device=targets.device)
ce_per_codebook = []
for k in range(K):
logits_k = logits[:, k, ...].contiguous().view(-1, logits.size(-1)) # [B x T, card]
targets_k = targets[:, k, ...].contiguous().view(-1) # [B x T]
mask_k = mask[:, k, ...].contiguous().view(-1) # [B x T]
ce_targets = targets_k[mask_k]
ce_logits = logits_k[mask_k]
q_ce = F.cross_entropy(ce_logits, ce_targets)
ce += q_ce
ce_per_codebook.append(q_ce.detach())
# average cross entropy across codebooks
ce = ce / K
return ce, ce_per_codebook
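# Usage sketch for compute_cross_entropy (random tensors with purely illustrative shapes:
# batch B=2, K=4 codebooks, T=50 timesteps, codebook size card=2048):
#
#   logits = torch.randn(2, 4, 50, 2048)
#   targets = torch.randint(0, 2048, (2, 4, 50))
#   mask = torch.ones(2, 4, 50, dtype=torch.bool)
#   ce, ce_per_codebook = compute_cross_entropy(logits, targets, mask)
#   # ce is a scalar averaged over the 4 codebooks; ce_per_codebook holds 4 detached scalars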
def generate_audio_codes(audio_paths: list[str],
                         audiocraft_compression_model: torch.nn.Module,
                         device: str) -> torch.Tensor:
audio_duration = 30
encodec_sample_rate = audiocraft_compression_model.sample_rate
torch_audios = []
for audio_path in audio_paths:
wav, original_sample_rate = torchaudio.load(audio_path) # load audio from file
        wav = torchaudio.functional.resample(wav, original_sample_rate,
                                             encodec_sample_rate)  # resample audio to the compression model's sample rate
        wav = wav[:, :encodec_sample_rate * audio_duration]  # truncate to the first 30 seconds of audio
        assert len(wav.shape) == 2, "audio data is not of shape [channels, duration]"
        assert wav.shape[0] == 2, "audio data should be stereo, but does not have 2 channels"
torch_audios.append(wav)
torch_audios = torch.stack(torch_audios)
torch_audios = torch_audios.to(device)
with torch.no_grad():
gen_audio = audiocraft_compression_model.encode(torch_audios)
codes, scale = gen_audio
assert scale is None
return codes
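# Usage sketch (illustrative; `compression_model` is assumed to be a pretrained audiocraft
# EnCodec compression model that accepts stereo input, and the file path is a placeholder):
#
#   codes = generate_audio_codes(["example.wav"], compression_model, "cuda")
#   # codes shape: [batch_size, n_codebooks, frames] for the 30-second clip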
def create_condition_tensors(
video_embeddings: torch.Tensor,
batch_size: int,
video_extraction_framerate: int,
device: str
):
    # attention mask over the conditioning sequence (all positions valid), mirroring the mask a
    # T5 text encoder would return for a text prompt
    mask = torch.ones((batch_size, video_extraction_framerate * 30), dtype=torch.int).to(device)
condition_tensors = {
'description': (video_embeddings, mask)
}
return condition_tensors
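# Usage sketch (illustrative shapes, assuming video_extraction_framerate=1 and a conditioning
# dimension of 1024); the returned dict mimics the {'description': (embeddings, mask)} structure
# produced by MusicGen's T5 text conditioner:
#
#   video_embeddings = torch.rand(4, 30, 1024)
#   condition_tensors = create_condition_tensors(video_embeddings, batch_size=4,
#                                                video_extraction_framerate=1, device="cpu")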
def get_current_timestamp():
return strftime("%Y_%m_%d___%H_%M_%S")
def configure_logging(output_dir: str, filename: str, log_level):
# create logs folder, if not existing
os.makedirs(output_dir, exist_ok=True)
level = getattr(logging, log_level)
    file_path = os.path.join(output_dir, filename)
logging.basicConfig(filename=file_path, encoding='utf-8', level=level)
logger = logging.getLogger()
# only add a StreamHandler if it is not present yet
if len(logger.handlers) <= 1:
logger.addHandler(logging.StreamHandler())