anchorxia's picture
add mmcm
a57c6eb
import warnings
import logging
import os
import pickle
import copy
import time
import numpy as np
import pytorch_lightning as pl
import torch
from torch import nn
import torch.nn.functional as F
from .TransNetmodels import TransNetV2
# Silence all Python warnings process-wide (the warnings filter is global,
# so this also hides warnings raised by other modules and libraries).
warnings.filterwarnings("ignore")
import logging
# Module-level logger used by the functions and the Lightning module below.
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
## Utility functions
def complete_results_batch(
    mp4_ids,
    batch_mp4_scenes_index,
    fps_batch,
    single_frame_pred,
    class_threshold,
    cache_file="/data_share7/v_hyggewang/视频切换模型依赖数据/转场真实标签字典.pkl",
    merge_gap=0.035,
):
    """Convert per-frame logits of a batch into per-video transition timestamps.

    The batch concatenates the clips of several videos along dim 0: video ``k``
    owns the next ``batch_mp4_scenes_index[k]`` clips. For each video the first
    and last 15 frames of every clip are discarded (the same ``15:-15`` slice
    the losses use), the remaining frames are flattened into one sequence, and
    a frame is a transition when its class-1 softmax probability is at least
    ``class_threshold``. Frame positions are shifted by 15 (the dropped leading
    frames), converted to seconds with the video's fps, and runs of transition
    frames closer than ``merge_gap`` seconds are merged into their mean time.

    Args:
        mp4_ids: per-video ids; ``int(mp4_id)`` is the key into the GT cache.
        batch_mp4_scenes_index: number of clips belonging to each video.
        fps_batch: frame rate of each video.
        single_frame_pred: logits, documented as ``[num_clips, 100, 2]``.
        class_threshold: minimum class-1 probability for a transition frame.
        cache_file: pickle file holding a dict keyed by ``int(mp4_id)`` with
            the ground-truth labels of each video (presumably transition times
            in seconds, since they are compared with the predicted times).
        merge_gap: maximum gap (seconds) between consecutive transition frames
            that are merged into a single prediction. Defaults to 0.035.

    Returns:
        One dict per video: ``{"真实标签": ground-truth labels from the cache,
        "预测标签": list of predicted transition times in seconds}``.
    """
    # The cache path is a local, trusted file; never point this at
    # untrusted data, because pickle.load can execute arbitrary code.
    with open(cache_file, "rb") as cache_fh:
        cache = pickle.load(cache_fh)

    pre_index = 0
    result = []
    for mp4_id, index, fps in zip(mp4_ids, batch_mp4_scenes_index, fps_batch):
        num_clips = int(index)
        # Central frames (15:-15) of this video's clips, flattened to
        # [num_frames, 2]; see the validation set construction in the dataset.
        logits = single_frame_pred[pre_index : pre_index + num_clips, 15:-15, :]
        logits = logits.reshape(-1, 2)
        # Per-frame probability of class 1 (transition).
        transition_prob = F.softmax(logits, dim=-1)[:, -1]
        is_transition = transition_prob >= class_threshold
        # +15 restores the frame index dropped by the 15:-15 slice; / fps -> seconds.
        transition_times = ((torch.where(is_transition)[0] + 15) / fps).tolist()

        # Merge runs of (near-)adjacent transition frames into one event.
        groups = []
        for t in transition_times:
            if groups and abs(groups[-1][-1] - t) < merge_gap:
                groups[-1].append(t)
            else:
                groups.append([t])
        # float64, not float16: float16 spacing is 0.0625 s above 64 s and
        # 0.5 s above 512 s, coarser than the tolerances used in pr_call.
        predictions = [np.mean(group, dtype=np.float64) for group in groups]

        result.append({"真实标签": cache[int(mp4_id)], "预测标签": predictions})
        pre_index += num_clips
    return result
### 工具函数
def pr_call(label_list, thresholds=(0.1, 0.3, 0.5, 0.7)):
    """Compute precision and recall of transition predictions per time tolerance.

    For a tolerance ``th``, a prediction matches a ground-truth time ``t`` when
    it lies in ``[t - th, t + th]``. Each prediction is matched to at most one
    ground-truth time; matching is greedy in ground-truth order and assumes
    each video's ``"预测标签"`` is sorted ascending (the scan for ``t`` stops at
    the first prediction beyond ``t + th``).

    Args:
        label_list: dicts with keys ``"真实标签"`` (ground-truth times) and
            ``"预测标签"`` (predicted times), one dict per video.
        thresholds: time tolerances (same unit as the labels, seconds when fed
            from complete_results_batch).

    Returns:
        ``{threshold: {"precision": float, "recall": float}}``, counted over
        all videos in ``label_list``.
    """
    correct_num_dict = {threshold: 0 for threshold in thresholds}  # correct matches per tolerance
    pre_positive_num = 0  # total number of predictions
    GT_positive_num = 0  # total number of ground-truth transitions
    for label_dic in label_list:
        true_labels, pre_labels = label_dic["真实标签"], label_dic["预测标签"]
        pre_positive_num += len(pre_labels)
        GT_positive_num += len(true_labels)
        for threshold in thresholds:
            # Track matched predictions by position, not by value, so two
            # predictions with the same timestamp can each be matched once.
            used_positions = set()
            for true_label in true_labels:
                for position, pre_label in enumerate(pre_labels):
                    if pre_label > true_label + threshold:
                        break  # sorted predictions: nothing later can match
                    if position in used_positions:
                        continue
                    # The upper bound already holds because of the break above.
                    if true_label - threshold <= pre_label:
                        correct_num_dict[threshold] += 1
                        used_positions.add(position)
                        break  # this ground-truth time is matched
    return {
        threshold: {
            "precision": correct / (pre_positive_num + 1e-8),
            "recall": correct / (GT_positive_num + 1e-8),
        }
        for threshold, correct in correct_num_dict.items()
    }
class MInterface(pl.LightningModule):
    """PyTorch Lightning module that trains and evaluates TransNetV2.

    ``args`` must provide ``batch_size``, ``lr``, ``raw_transnet_weights``,
    ``optim`` and ``lr_scheduler``.
    """

    # Cross-entropy class weights for [class 0, class 1 (transition)].
    _CLASS_WEIGHTS = (0.15, 0.85)
    # Class-1 probability thresholds evaluated at the end of each validation epoch.
    _VAL_CLASS_THRESHOLDS = (0.1, 0.3, 0.5, 0.7)
    # (metric-name tag, matching tolerance in seconds) for validation logging.
    # NOTE(review): the "0.01s" tag is logged with a 0.05 s tolerance; the
    # original metric names are kept so existing monitors/dashboards still work.
    _PR_LOG_TOLERANCES = (("0.01s", 0.05), ("0.1s", 0.1), ("0.2s", 0.2), ("0.3s", 0.3))

    def __init__(self, args):
        super().__init__()
        logger.info("TransNetV2 模型初始化开始...")
        self.args = args
        self.batch_size = self.args.batch_size
        self.learning_rate = self.args.lr
        self.model = TransNetV2()
        # Xavier-uniform init for every Conv2d / Linear layer of the model.
        for m in self.model.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.xavier_uniform_(m.weight)
        # Optionally warm-start from the original TransNetV2 weights. The
        # cls_layer1/cls_layer2 entries are removed so those parameters keep
        # their current initialisation (strict=False tolerates the gaps).
        if self.args.raw_transnet_weights is not None:
            # map_location="cpu" lets GPU-saved weights load on CPU-only hosts;
            # load_state_dict copies the values into the model's own tensors.
            checkpoint = torch.load(self.args.raw_transnet_weights, map_location="cpu")
            for key in (
                "cls_layer1.weight",
                "cls_layer1.bias",
                "cls_layer2.weight",
                "cls_layer2.bias",
            ):
                del checkpoint[key]  # KeyError here signals an unexpected checkpoint layout
            self.model.load_state_dict(checkpoint, strict=False)
            print("载入原始模型权重")
        logger.info("TransNetV2 模型初始化结束")

    def _frame_ce(self, pred, gt):
        """Weighted 2-class cross-entropy over the central frames ``15:-15`` of each clip."""
        weight = torch.tensor(self._CLASS_WEIGHTS, device=pred.device).type_as(pred)
        return F.cross_entropy(
            pred[:, 15:-15, :].reshape(-1, 2),
            gt[:, 15:-15].reshape(-1),
            weight=weight,
        )

    def training_step(self, batch, batch_idx):
        frames, one_hot_gt, many_hot_gt = (
            batch["frames"],
            batch["one_hot"],
            batch["many_hot"],
        )
        single_frame_pred, all_frame_pred = self.model(frames)
        # The loss is computed in training_step_end.
        return single_frame_pred, all_frame_pred, one_hot_gt, many_hot_gt

    def training_step_end(self, output):
        # *_pred: logits of shape [num_clips, frames, 2];
        # *_gt: class indices of shape [num_clips, frames].
        single_frame_pred, all_frame_pred, one_hot_gt, many_hot_gt = output
        loss_one = self._frame_ce(single_frame_pred, one_hot_gt)
        loss_all = self._frame_ce(all_frame_pred, many_hot_gt)
        loss_total = loss_one * 0.9 + loss_all * 0.1
        self.log(
            "train_loss",
            loss_total,
            on_epoch=True,
            on_step=True,
            prog_bar=True,
            logger=True,
        )
        return loss_total

    def validation_step(self, batch, batch_idx):
        frames, one_hot_gt, many_hot_gt = (
            batch["frames"],
            batch["one_hot"],
            batch["many_hot"],
        )
        single_frame_pred, all_frame_pred = self.model(frames)
        mp4_ids = batch["mp4_ids"]
        batch_mp4_scenes_index = batch["batch_mp4_scenes_index"]
        fps_batch = batch["fps_batch"]
        return (
            single_frame_pred,
            all_frame_pred,
            one_hot_gt,
            many_hot_gt,
            mp4_ids,
            batch_mp4_scenes_index,
            fps_batch,
        )

    def validation_step_end(self, output):
        single_frame_pred, all_frame_pred, one_hot_gt, many_hot_gt, *_ = output
        loss_one = self._frame_ce(single_frame_pred, one_hot_gt)
        loss_all = self._frame_ce(all_frame_pred, many_hot_gt)
        # NOTE(review): validation mixes the losses 0.8/0.2 while training uses
        # 0.9/0.1, so val_loss and train_loss are not directly comparable.
        loss_total = loss_one * 0.8 + loss_all * 0.2
        self.log(
            "val_loss",
            loss_total,
            on_epoch=True,
            on_step=True,
            prog_bar=True,
            logger=True,
        )

    def validation_epoch_end(self, output):
        start = time.time()
        # Copy each batch's tensors to the CPU once, not once per threshold.
        cpu_batches = [
            (
                mp4_ids.cpu(),
                batch_mp4_scenes_index.cpu(),
                fps_batch.cpu(),
                single_frame_pred.cpu().float(),
            )
            for (
                single_frame_pred,
                _,
                _,
                _,
                mp4_ids,
                batch_mp4_scenes_index,
                fps_batch,
            ) in output
        ]
        tolerances = [tolerance for _, tolerance in self._PR_LOG_TOLERANCES]
        # Precision/recall for each class-probability threshold.
        for class_threshold in self._VAL_CLASS_THRESHOLDS:
            transition_label_list = []
            for mp4_ids, scenes_index, fps_batch, single_frame_pred in cpu_batches:
                transition_label_list.extend(
                    complete_results_batch(
                        mp4_ids,
                        scenes_index,
                        fps_batch,
                        single_frame_pred,
                        class_threshold,
                    )
                )
            custom_indicator = pr_call(transition_label_list, thresholds=tolerances)
            for tag, tolerance in self._PR_LOG_TOLERANCES:
                for suffix, metric in (("P", "precision"), ("R", "recall")):
                    self.log(
                        f"{class_threshold}_{tag}_{suffix}",
                        custom_indicator[tolerance][metric],
                        on_epoch=True,
                        on_step=False,
                        prog_bar=False,
                        logger=True,
                    )
        print("推理耗时:{}".format(time.time() - start))

    def configure_optimizers(self):
        """Build the optimizer and, optionally, a learning-rate scheduler.

        ``args.optim == "SGD"`` selects SGD with momentum 0.9; anything else
        selects Adam. ``args.lr_scheduler`` must be ``"OneCycleLR"``,
        ``"CosineAnnealingLR"``, or ``"None"``/``None`` (no scheduler).

        Raises:
            ValueError: for any other ``args.lr_scheduler`` value (previously
                this silently returned no optimizer at all).
        """
        logger.info("configure_optimizers 初始化开始...")
        if self.args.optim == "SGD":
            optimizer = torch.optim.SGD(
                self.parameters(), lr=self.learning_rate, momentum=0.9
            )
        else:
            optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)

        lr_scheduler = self.args.lr_scheduler
        if lr_scheduler == "OneCycleLR":
            # NOTE(review): max_lr/epochs/steps_per_epoch are hard-coded and
            # independent of self.learning_rate and the real dataloader length.
            scheduler = torch.optim.lr_scheduler.OneCycleLR(
                optimizer, max_lr=0.0002, verbose=True, epochs=500, steps_per_epoch=7
            )
        elif lr_scheduler == "CosineAnnealingLR":
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=200, eta_min=5e-7, verbose=True, last_epoch=-1
            )
        elif lr_scheduler in ("None", None):
            logger.info("configure_optimizers 初始化结束...")
            return optimizer
        else:
            raise ValueError(f"Unsupported lr_scheduler: {lr_scheduler!r}")
        logger.info("configure_optimizers 初始化结束...")
        return [optimizer], [scheduler]