import random
import torch
from torch import nn
import torch.nn.functional as F
import ffmpeg
import numpy as np
import cv2
from moviepy.editor import VideoFileClip
from .utils import get_frames


class TransNetV2(nn.Module):
def __init__(self, F=16, L=3, S=2, D=1024):
super(TransNetV2, self).__init__()
self.SDDCNN = nn.ModuleList(
[
StackedDDCNNV2(
in_filters=3, n_blocks=S, filters=F, stochastic_depth_drop_prob=0.0
)
]
+ [
StackedDDCNNV2(
in_filters=(F * 2 ** (i - 1)) * 4, n_blocks=S, filters=F * 2**i
)
for i in range(1, L)
]
)
        # frame-similarity network
self.frame_sim_layer = FrameSimilarity(
sum([(F * 2**i) * 4 for i in range(L)]),
lookup_window=101,
output_dim=128,
similarity_dim=128,
use_bias=True,
)
        # color-histogram similarity network
self.color_hist_layer = ColorHistograms(lookup_window=101, output_dim=128)
# dropout
self.dropout = nn.Dropout(0.5)
        output_dim = ((F * 2 ** (L - 1)) * 4) * 3 * 6  # last block's channels times the 3x6 spatial grid left after pooling
        output_dim = output_dim + 128  # +128 for the frame-similarity branch
        output_dim = output_dim + 128  # +128 more for the color-histogram branch
self.fc1 = nn.Linear(output_dim, D)
self.cls_layer1 = nn.Linear(D, 1)
self.cls_layer2 = nn.Linear(D, 1)

    def forward(self, inputs):
        # the input must be a uint8 batch of frames with (h, w) = (27, 48)
        # assert isinstance(inputs, torch.Tensor) and list(inputs.shape[2:]) == [27, 48, 3] and inputs.dtype == torch.uint8, "incorrect input type and/or shape"
        # uint8 of shape [B, T, H, W, 3] to float of shape [B, 3, T, H, W]
        x = inputs.permute([0, 4, 1, 2, 3]).float()
        x = x.div_(255.0)
        # collect the feature map of every SDDCNN stage
        block_features = []
        for block in self.SDDCNN:
            x = block(x)
            block_features.append(x)
        x = x.permute(0, 2, 3, 4, 1)  # [B, channels, T, H, W] -> [B, T, H, W, channels]
        x = x.reshape(x.shape[0], x.shape[1], -1)
        x = torch.cat(
            [self.frame_sim_layer(block_features), x], 2
        )  # concatenate the frame-similarity features along the last dim
        x = torch.cat(
            [self.color_hist_layer(inputs), x], 2
        )  # concatenate the color-histogram features along the last dim
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        one_hot = self.cls_layer1(x)
        many_hot = self.cls_layer2(x)
        return one_hot, many_hot

    # predict the transition frames of an MP4 file and return their positions
def predict_video(
self,
mp4_file,
cache_path="",
c_box=None,
width=48,
height=27,
input_frames=100,
overlap=30,
sample_fps=30,
threshold=0.3,
):
"""
mp4_file: ~/6712566330782010632.mp4
cache_path: ~/视频单帧数据_h48_w27
return: [x,x,...] 点位时间
"""
assert overlap % 2 == 0
assert input_frames > overlap
        # fps = eval(ffmpeg.probe(mp4_file)['streams'][0]['r_frame_rate'])  # the video's frame rate
        # total_frames = int(ffmpeg.probe(mp4_file)['streams'][0]['nb_frames'])  # the video's total frame count
        # duration = float(ffmpeg.probe(mp4_file)['streams'][0]['duration'])  # the video's duration in seconds
video = VideoFileClip(mp4_file)
# video = video.subclip(0, 60 * 10)
fps = video.fps
duration = video.duration
total_frames = int(duration * fps)
w, h = video.size
print(fps, duration, total_frames, w, h)
        if c_box:
            video = video.crop(*c_box)  # crop returns a new clip, so keep the result
frame_iter = video.iter_frames(fps=sample_fps)
sample_total_frames = int(sample_fps * duration)
frame_list = []
for i in range(sample_total_frames // (input_frames - overlap) + 1):
# if i==1:
# break
frame_list = frame_list[-overlap:]
start_frame = i * (input_frames - overlap)
end_frame = min(start_frame + input_frames, sample_total_frames)
print("start_frame & end_frame: ", start_frame, end_frame)
for frame in frame_iter:
frame = cv2.resize(frame, (width, height))
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
frame_list.append(frame)
if len(frame_list) == end_frame - start_frame:
break
            frames = torch.Tensor(frame_list)  # the frames collected for this window
if frames.shape[0] < end_frame - start_frame:
                # the video track is shorter than the audio track: the tail of the video has sound but no picture
print(
"total_frames is wrong: ",
total_frames,
"-->",
start_frame + frames.shape[0],
)
# sample_total_frames = start_frame + frames.shape[0]
# fps = total_frames / duration
frames = frames.cuda()
            # single_frame_pred and all_frame_pred both hold per-frame transition probabilities for the window
single_frame_pred, all_frame_pred = self.forward(
frames.unsqueeze(0)
            )  # forward inference
            # single_frame_pred = F.softmax(single_frame_pred, dim=-1)  # per-frame class probabilities
# single_frame_pred = torch.argmax(single_frame_pred, dim=-1).reshape(-1)
single_frame_pred = torch.sigmoid(single_frame_pred).reshape(-1)
all_frame_pred = torch.sigmoid(all_frame_pred).reshape(-1)
# single_frame_pred = (single_frame_pred>threshold)*1
            # stitch overlapping windows: drop overlap//2 frames at each seam, keeping the
            # full head of the first window and the full tail of the last
            if sample_total_frames > end_frame:
if i == 0:
single_frame_pred_label = single_frame_pred[: -overlap // 2]
all_frame_pred_label = all_frame_pred[: -overlap // 2]
else:
single_frame_pred_label = torch.cat(
(
single_frame_pred_label,
single_frame_pred[overlap // 2 : -overlap // 2],
),
dim=0,
)
all_frame_pred_label = torch.cat(
(
all_frame_pred_label,
all_frame_pred[overlap // 2 : -overlap // 2],
),
dim=0,
)
else:
if i == 0:
single_frame_pred_label = single_frame_pred
all_frame_pred_label = all_frame_pred
else:
single_frame_pred_label = torch.cat(
(single_frame_pred_label, single_frame_pred[overlap // 2 :]),
dim=0,
)
all_frame_pred_label = torch.cat(
(all_frame_pred_label, all_frame_pred[overlap // 2 :]), dim=0
)
break
single_frame_pred_label = single_frame_pred_label.cpu().numpy()
all_frame_pred_label = all_frame_pred_label.cpu().numpy()
return (
single_frame_pred_label,
all_frame_pred_label,
fps,
total_frames,
duration,
h,
w,
)

        # transition_index = torch.where(pred_label == 1)[0].cpu().numpy()  # transition frame positions
        # transition_index = transition_index.astype(np.float64)
        # # post-process the result by merging adjacent transition frames
# result_transition = []
# for i, transition in enumerate(transition_index):
# if i == 0:
# result_transition.append([transition])
# else:
# if abs(result_transition[-1][-1]-transition) == 1:
# result_transition[-1].append(transition)
# else:
# result_transition.append([transition])
#
# result_transition = [[0]] + [[item[0], item[-1]] if len(item)>1 else [item[0]] for item in result_transition] + [[total_frames]]
#
# return result_transition, fps, total_frames, duration, h, w

    def predict_video_2(
self,
mp4_file,
cache_path="",
c_box=None,
width=48,
height=27,
input_frames=100,
overlap=30,
sample_fps=30,
threshold=0.3,
):
"""
mp4_file: ~/6712566330782010632.mp4
cache_path: ~/视频单帧数据_h48_w27
return: [x,x,...] 点位时间
"""
assert overlap % 2 == 0
assert input_frames > overlap
        # fps = eval(ffmpeg.probe(mp4_file)['streams'][0]['r_frame_rate'])  # the video's frame rate
        # total_frames = int(ffmpeg.probe(mp4_file)['streams'][0]['nb_frames'])  # the video's total frame count
        # duration = float(ffmpeg.probe(mp4_file)['streams'][0]['duration'])  # the video's duration in seconds
video = VideoFileClip(mp4_file)
# video = video.subclip(0, 60 * 10)
fps = video.fps
duration = video.duration
total_frames = int(duration * fps)
w, h = video.size
print(fps, duration, total_frames, w, h)
        if c_box:
            video = video.crop(*c_box)  # crop returns a new clip, so keep the result
frame_iter = video.iter_frames(fps=sample_fps)
sample_total_frames = int(sample_fps * duration)
frame_list = []
for i in range(sample_total_frames // (input_frames - overlap) + 1):
# if i==1:
# break
frame_list = frame_list[-overlap:]
start_frame = i * (input_frames - overlap)
end_frame = min(start_frame + input_frames, sample_total_frames)
print("start_frame & end_frame: ", start_frame, end_frame)
for frame in frame_iter:
frame = cv2.resize(frame, (width, height))
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
frame_list.append(frame)
if len(frame_list) == end_frame - start_frame:
break
            frames = torch.Tensor(frame_list)  # the frames collected for this window
if frames.shape[0] < end_frame - start_frame:
                # the video track is shorter than the audio track: the tail of the video has sound but no picture
print(
"total_frames is wrong: ",
total_frames,
"-->",
start_frame + frames.shape[0],
)
# sample_total_frames = start_frame + frames.shape[0]
# fps = total_frames / duration
frames = frames.cuda()
single_frame_pred, all_frame_pred = self.forward(
frames.unsqueeze(0)
            )  # forward inference
            # single_frame_pred = F.softmax(single_frame_pred, dim=-1)  # per-frame class probabilities
# single_frame_pred = torch.argmax(single_frame_pred, dim=-1).reshape(-1)
single_frame_pred = torch.sigmoid(single_frame_pred).reshape(-1)
all_frame_pred = torch.sigmoid(all_frame_pred).reshape(-1)
# single_frame_pred = (single_frame_pred>threshold)*1
            # stitch overlapping windows: drop overlap//2 frames at each seam, keeping the
            # full head of the first window and the full tail of the last
            if sample_total_frames > end_frame:
if i == 0:
single_frame_pred_label = single_frame_pred[: -overlap // 2]
all_frame_pred_label = all_frame_pred[: -overlap // 2]
else:
single_frame_pred_label = torch.cat(
(
single_frame_pred_label,
single_frame_pred[overlap // 2 : -overlap // 2],
),
dim=0,
)
all_frame_pred_label = torch.cat(
(
all_frame_pred_label,
all_frame_pred[overlap // 2 : -overlap // 2],
),
dim=0,
)
else:
if i == 0:
single_frame_pred_label = single_frame_pred
all_frame_pred_label = all_frame_pred
else:
single_frame_pred_label = torch.cat(
(single_frame_pred_label, single_frame_pred[overlap // 2 :]),
dim=0,
)
all_frame_pred_label = torch.cat(
(all_frame_pred_label, all_frame_pred[overlap // 2 :]), dim=0
)
break
single_frame_pred_label = single_frame_pred_label.cpu().numpy()
all_frame_pred_label = all_frame_pred_label.cpu().numpy()
return (
single_frame_pred_label,
all_frame_pred_label,
fps,
total_frames,
duration,
h,
w,
)
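

# A minimal standalone sketch (assumption: not part of the original API) of the
# post-processing hinted at in the commented-out block inside `predict_video`:
# merge runs of adjacent transition frames into [first, last] segments.
def predictions_to_scenes_sketch(pred_label, total_frames):
    """pred_label: thresholded 0/1 per-frame predictions as a numpy array."""
    transition_index = np.where(pred_label == 1)[0].astype(np.float64)
    result_transition = []
    for i, transition in enumerate(transition_index):
        if i == 0 or abs(result_transition[-1][-1] - transition) > 1:
            result_transition.append([transition])  # start a new run
        else:
            result_transition[-1].append(transition)  # extend the current run
    return (
        [[0]]
        + [[run[0], run[-1]] if len(run) > 1 else [run[0]] for run in result_transition]
        + [[total_frames]]
    )
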
class StackedDDCNNV2(nn.Module):
def __init__(
self,
in_filters,
n_blocks,
filters,
shortcut=True,
pool_type="avg",
stochastic_depth_drop_prob=0.0,
):
super(StackedDDCNNV2, self).__init__()
self.shortcut = shortcut
        # the stack of DDCNN blocks
self.DDCNN = nn.ModuleList(
[
DilatedDCNNV2(
in_filters if i == 1 else filters * 4,
filters,
activation=F.relu if i != n_blocks else None,
)
for i in range(1, n_blocks + 1)
]
        )  # n_blocks DilatedDCNNV2 modules
        # the pooling layer
self.pool = (
nn.MaxPool3d(kernel_size=(1, 2, 2))
if pool_type == "max"
else nn.AvgPool3d(kernel_size=(1, 2, 2))
)
self.stochastic_depth_drop_prob = stochastic_depth_drop_prob

    def forward(self, inputs):
        x = inputs
        shortcut = None
        # forward pass through the DDCNN blocks
        for block in self.DDCNN:
            x = block(x)
            if shortcut is None:  # keep the first block's output as the residual branch
                shortcut = x
        x = F.relu(x)
        if self.shortcut:  # `shortcut` is a bool flag, so test its truth value
            if self.stochastic_depth_drop_prob != 0.0:
                if self.training:
                    if random.random() < self.stochastic_depth_drop_prob:
                        x = shortcut
                    else:
                        x = x + shortcut
                else:
                    x = (1 - self.stochastic_depth_drop_prob) * x + shortcut
            else:
                x = x + shortcut
        x = self.pool(x)
return x
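

# Hypothetical shape check (illustration only, not in the original file): each
# StackedDDCNNV2 stage keeps the temporal length T, halves H and W through the
# (1, 2, 2) pooling, and outputs filters * 4 channels from the DDCNN concatenation.
def _sddcnn_shape_sketch():
    block = StackedDDCNNV2(in_filters=3, n_blocks=2, filters=16)
    x = torch.rand(1, 3, 8, 27, 48)  # [B, C, T, H, W]
    y = block(x)
    assert list(y.shape) == [1, 64, 8, 13, 24]  # 16 filters * 4 branches; H, W halved
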
class DilatedDCNNV2(nn.Module):
def __init__(self, in_filters, filters, batch_norm=True, activation=None):
super(DilatedDCNNV2, self).__init__()
self.Conv3D_1 = Conv3DConfigurable(
in_filters, filters, 1, use_bias=not batch_norm
)
self.Conv3D_2 = Conv3DConfigurable(
in_filters, filters, 2, use_bias=not batch_norm
)
self.Conv3D_4 = Conv3DConfigurable(
in_filters, filters, 4, use_bias=not batch_norm
)
self.Conv3D_8 = Conv3DConfigurable(
in_filters, filters, 8, use_bias=not batch_norm
)
self.bn = nn.BatchNorm3d(filters * 4, eps=1e-3) if batch_norm else None
        self.activation = activation  # optional activation applied after batch norm

    def forward(self, inputs):
conv1 = self.Conv3D_1(inputs)
conv2 = self.Conv3D_2(inputs)
conv3 = self.Conv3D_4(inputs)
conv4 = self.Conv3D_8(inputs)
x = torch.cat([conv1, conv2, conv3, conv4], dim=1)
if self.bn is not None:
x = self.bn(x)
if self.activation is not None:
x = self.activation(x)
return x
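

# Illustration (assumption: not in the original file): each DilatedDCNNV2 branch
# sees a different temporal extent. A (3, 1, 1) kernel with dilation d spans
# 2 * d + 1 frames, so the four branches cover 3, 5, 9 and 17 frames respectively.
def _temporal_span_sketch():
    for d in (1, 2, 4, 8):
        print(f"dilation {d}: temporal span of {2 * d + 1} frames")
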
class Conv3DConfigurable(nn.Module):
def __init__(
self, in_filters, filters, dilation_rate, separable=True, use_bias=True
):
super(Conv3DConfigurable, self).__init__()
if separable:
# (2+1)D convolution https://arxiv.org/pdf/1711.11248.pdf
conv1 = nn.Conv3d(
in_filters,
2 * filters,
kernel_size=(1, 3, 3),
dilation=(1, 1, 1),
padding=(0, 1, 1),
bias=False,
)
conv2 = nn.Conv3d(
2 * filters,
filters,
kernel_size=(3, 1, 1),
dilation=(dilation_rate, 1, 1),
padding=(dilation_rate, 0, 0),
bias=use_bias,
)
self.layers = nn.ModuleList([conv1, conv2])
else:
conv = nn.Conv3d(
in_filters,
filters,
kernel_size=3,
dilation=(dilation_rate, 1, 1),
padding=(dilation_rate, 1, 1),
bias=use_bias,
)
self.layers = nn.ModuleList([conv])

    def forward(self, inputs):
x = inputs
for layer in self.layers:
x = layer(x)
return x
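

# Illustration (assumption: not in the original file): the (2+1)D factorization
# above replaces one full 3x3x3 kernel with a spatial (1, 3, 3) convolution
# followed by a temporal (3, 1, 1) convolution, which cuts the parameter count.
def _param_count_sketch():
    separable = Conv3DConfigurable(64, 16, dilation_rate=1, separable=True)
    full = Conv3DConfigurable(64, 16, dilation_rate=1, separable=False)
    count = lambda m: sum(p.numel() for p in m.parameters())
    print("(2+1)D params:", count(separable), "| full 3D params:", count(full))
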

# frame-similarity network
class FrameSimilarity(nn.Module):
def __init__(
self,
in_filters,
similarity_dim=128,
lookup_window=101,
output_dim=128,
use_bias=False,
):
super(FrameSimilarity, self).__init__()
self.projection = nn.Linear(in_filters, similarity_dim, bias=use_bias)
self.fc = nn.Linear(lookup_window, output_dim)
self.lookup_window = lookup_window
assert lookup_window % 2 == 1, "`lookup_window` must be odd integer"

    def forward(self, inputs):
x = torch.cat([torch.mean(x, dim=[3, 4]) for x in inputs], dim=1)
x = torch.transpose(x, 1, 2)
x = self.projection(x)
x = F.normalize(x, p=2, dim=2)
batch_size, time_window = x.shape[0], x.shape[1]
similarities = torch.bmm(
x, x.transpose(1, 2)
        )  # [batch_size, time_window, time_window] cosine similarities
similarities_padded = F.pad(
similarities, [(self.lookup_window - 1) // 2, (self.lookup_window - 1) // 2]
)
batch_indices = (
torch.arange(0, batch_size, device=x.device)
.view([batch_size, 1, 1])
.repeat([1, time_window, self.lookup_window])
)
time_indices = (
torch.arange(0, time_window, device=x.device)
.view([1, time_window, 1])
.repeat([batch_size, 1, self.lookup_window])
)
lookup_indices = (
torch.arange(0, self.lookup_window, device=x.device)
.view([1, 1, self.lookup_window])
.repeat([batch_size, time_window, 1])
+ time_indices
)
similarities = similarities_padded[batch_indices, time_indices, lookup_indices]
return F.relu(self.fc(similarities))
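

# Hypothetical illustration (not in the original file): for every frame t the
# layer gathers a window of similarities to frames t - w//2 .. t + w//2 (zero
# padded at the borders), so its output depends only on a local temporal neighborhood.
def _frame_similarity_shape_sketch():
    sim = FrameSimilarity(in_filters=64, similarity_dim=32, lookup_window=5, output_dim=8)
    feats = [torch.rand(2, 64, 10, 4, 4)]  # one block's [B, C, T, H, W] feature map
    out = sim(feats)
    assert list(out.shape) == [2, 10, 8]  # [batch, time, output_dim]
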

# color-histogram similarity network
class ColorHistograms(nn.Module):
def __init__(self, lookup_window=101, output_dim=None):
super(ColorHistograms, self).__init__()
self.fc = (
nn.Linear(lookup_window, output_dim) if output_dim is not None else None
)
self.lookup_window = lookup_window
assert lookup_window % 2 == 1, "`lookup_window` must be odd integer"

    @staticmethod
def compute_color_histograms(frames):
frames = frames.int()
def get_bin(frames):
# returns 0 .. 511
R, G, B = frames[:, :, 0], frames[:, :, 1], frames[:, :, 2]
R, G, B = R >> 5, G >> 5, B >> 5
return (R << 6) + (G << 3) + B
batch_size, time_window, height, width, no_channels = frames.shape
assert no_channels == 3
frames_flatten = frames.view(batch_size * time_window, height * width, 3)
binned_values = get_bin(frames_flatten)
frame_bin_prefix = (
torch.arange(0, batch_size * time_window, device=frames.device) << 9
).view(-1, 1)
binned_values = (binned_values + frame_bin_prefix).view(-1)
histograms = torch.zeros(
batch_size * time_window * 512, dtype=torch.int32, device=frames.device
)
histograms.scatter_add_(
0,
binned_values,
torch.ones(len(binned_values), dtype=torch.int32, device=frames.device),
)
histograms = histograms.view(batch_size, time_window, 512).float()
histograms_normalized = F.normalize(histograms, p=2, dim=2)
return histograms_normalized

    def forward(self, inputs):
x = self.compute_color_histograms(inputs)
batch_size, time_window = x.shape[0], x.shape[1]
similarities = torch.bmm(
x, x.transpose(1, 2)
) # [batch_size, time_window, time_window]
similarities_padded = F.pad(
similarities, [(self.lookup_window - 1) // 2, (self.lookup_window - 1) // 2]
)
batch_indices = (
torch.arange(0, batch_size, device=x.device)
.view([batch_size, 1, 1])
.repeat([1, time_window, self.lookup_window])
)
time_indices = (
torch.arange(0, time_window, device=x.device)
.view([1, time_window, 1])
.repeat([batch_size, 1, self.lookup_window])
)
lookup_indices = (
torch.arange(0, self.lookup_window, device=x.device)
.view([1, 1, self.lookup_window])
.repeat([batch_size, time_window, 1])
+ time_indices
)
similarities = similarities_padded[batch_indices, time_indices, lookup_indices]
if self.fc is not None:
return F.relu(self.fc(similarities))
return similarities
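

# --- Minimal usage sketch (assumption: illustration only; the weight file
# path is hypothetical and no weights ship with this module) ---
if __name__ == "__main__":
    model = TransNetV2()
    # model.load_state_dict(torch.load("transnetv2-pytorch-weights.pth"))  # hypothetical path
    model.eval()
    # the forward pass expects RGB frames of shape [B, T, 27, 48, 3]
    frames = torch.randint(0, 256, (1, 100, 27, 48, 3), dtype=torch.uint8)
    with torch.no_grad():
        single_frame_pred, all_frame_pred = model(frames)
    probs = torch.sigmoid(single_frame_pred).reshape(-1)
    print("per-frame transition probabilities:", probs.shape)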