import warnings
import logging
import os
import pickle
import copy
import time
import functools
import numpy as np
import pytorch_lightning as pl
import torch
from torch import nn
import torch.nn.functional as F
from .TransNetmodels import TransNetV2

warnings.filterwarnings("ignore")
import logging

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name

# Frames trimmed from each side of every clip before computing losses and
# predictions: with 100-frame clips only the central 70 frames are scored.
_CONTEXT_FRAMES = 15
# Cross-entropy class weights (class 0 = no transition, class 1 = transition).
_CLASS_WEIGHTS = (0.15, 0.85)
# (time tolerance in seconds, tag used in the logged metric name).
# NOTE(review): the 0.05 s tolerance is logged under the tag "0.01s"; the tag is
# kept unchanged so existing monitors/dashboards keep working, but it is misleading.
_TOLERANCE_TAGS = ((0.05, "0.01s"), (0.1, "0.1s"), (0.2, "0.2s"), (0.3, "0.3s"))


## Utility functions
@functools.lru_cache(maxsize=4)
def _load_transition_labels(cache_file):
    """Load (once per path) the pickled ground-truth dict indexed by video id.

    The result is cached for the life of the process, so a file rewritten
    mid-run is not re-read. Callers must treat the returned dict as read-only.
    """
    # NOTE(security): pickle.load can execute arbitrary code; only use trusted files.
    with open(cache_file, "rb") as fh:
        return pickle.load(fh)


def _weighted_frame_ce(pred, target):
    """Class-weighted cross-entropy over the central frames of each clip.

    pred: logits whose last dim is 2 (indexed as [clips, frames, 2]);
    target: integer class labels indexed as [clips, frames].
    """
    weight = torch.tensor(_CLASS_WEIGHTS, device=pred.device).type_as(pred)
    return F.cross_entropy(
        pred[:, _CONTEXT_FRAMES:-_CONTEXT_FRAMES, :].reshape(-1, 2),
        target[:, _CONTEXT_FRAMES:-_CONTEXT_FRAMES].reshape(-1),
        weight=weight,
    )


def complete_results_batch(
    mp4_ids,
    batch_mp4_scenes_index,
    fps_batch,
    single_frame_pred,
    class_threshold,
    cache_file="/data_share7/v_hyggewang/视频切换模型依赖数据/转场真实标签字典.pkl",
):
    """Convert per-frame logits into per-video transition times paired with ground truth.

    Args:
        mp4_ids: one id per video; ``int(mp4_id)`` is the key into the label dict.
        batch_mp4_scenes_index: number of clips belonging to each video; the
            clips of consecutive videos are concatenated along dim 0 of
            ``single_frame_pred``.
        fps_batch: frame rate of each video, used to turn frame indices into seconds.
        single_frame_pred: logits of shape [clips, 100, 2].
        class_threshold: a frame counts as a transition when its class-1
            softmax probability is >= this value.
        cache_file: pickle file holding the ground-truth label dict.

    Returns:
        A list with one dict per video:
        ``{"真实标签": ground-truth labels, "预测标签": predicted times}``.
        Predicted times are in ascending order.
    """
    labels_by_id = _load_transition_labels(cache_file)
    edge = _CONTEXT_FRAMES
    clip_start = 0
    result = []
    for mp4_id, num_clips, fps in zip(mp4_ids, batch_mp4_scenes_index, fps_batch):
        num_clips = int(num_clips)
        # Keep frames 15-85, 85-155, ... of the video; see how the validation
        # clips are generated in the dataset.
        logits = single_frame_pred[clip_start : clip_start + num_clips, edge:-edge, :]
        probs = F.softmax(logits, dim=-1)
        # One boolean per frame: is the class-1 probability above the threshold?
        pred_label = (probs[..., -1] >= class_threshold).reshape(-1)
        frame_idx = torch.nonzero(pred_label, as_tuple=True)[0]
        # Add back the leading context frames, then convert frame index -> seconds.
        transition_times = ((frame_idx + edge) / fps).tolist()

        # Merge runs of adjacent transition frames (gap < 0.035 s, presumably just
        # over one frame interval at ~30 fps) into a single event.
        groups = []
        for t in transition_times:
            if groups and abs(groups[-1][-1] - t) < 0.035:
                groups[-1].append(t)
            else:
                groups.append([t])
        # Average in float64: float16 spacing is already 0.0625 s above 64 s,
        # coarser than the 0.05 s evaluation tolerance.
        predicted = [np.mean(g, dtype=np.float64) for g in groups]

        result.append({"真实标签": labels_by_id[int(mp4_id)], "预测标签": predicted})
        clip_start += num_clips
    return result


def pr_call(label_list, thresholds=(0.1, 0.3, 0.5, 0.7)):
    """Compute precision and recall of transition times at several time tolerances.

    Each ground-truth time is greedily matched to the first not-yet-matched
    prediction within ``±threshold``; each prediction is matched at most once.

    Args:
        label_list: dicts with keys "真实标签" (ground-truth times) and "预测标签"
            (predicted times). Predicted times must be in ascending order (the
            early ``break`` relies on it), as produced by ``complete_results_batch``.
        thresholds: iterable of time tolerances.

    Returns:
        ``{threshold: {"precision": float, "recall": float}}``.
    """
    correct_num_dict = {threshold: 0 for threshold in thresholds}
    pre_positive_num = 0  # total predicted positives over all samples
    GT_positive_num = 0  # total ground-truth positives over all samples
    for label_dic in label_list:
        true_labels, pre_labels = label_dic["真实标签"], label_dic["预测标签"]
        pre_positive_num += len(pre_labels)
        GT_positive_num += len(true_labels)
        for threshold in thresholds:
            # Track matched predictions by index so equal timestamps don't collide.
            used = set()
            for true_label in true_labels:
                for j, pre_label in enumerate(pre_labels):
                    if pre_label > true_label + threshold:
                        break  # sorted predictions: nothing later can match
                    if j in used:
                        continue
                    if pre_label >= true_label - threshold:
                        correct_num_dict[threshold] += 1
                        used.add(j)
                        break
    return {
        threshold: {
            "precision": correct / (pre_positive_num + 1e-8),
            "recall": correct / (GT_positive_num + 1e-8),
        }
        for threshold, correct in correct_num_dict.items()
    }


class MInterface(pl.LightningModule):
    """LightningModule wrapping TransNetV2 for per-frame transition classification."""

    def __init__(self, args):
        """Build the model from ``args`` (reads batch_size, lr, raw_transnet_weights)."""
        super().__init__()
        logger.info("TransNetV2 模型初始化开始...")
        self.args = args
        self.batch_size = self.args.batch_size
        self.learning_rate = self.args.lr
        self.model = TransNetV2()
        # Xavier-uniform init for conv/linear weights (biases keep their defaults).
        # NOTE(review): nn.Conv3d modules are not matched here — confirm that is intended.
        for m in self.model.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.xavier_uniform_(m.weight)
        # Optionally warm-start from the original TransNetV2 weights, dropping the
        # classifier layers so they keep the fresh initialisation above.
        if self.args.raw_transnet_weights is not None:
            # map_location="cpu" lets GPU-saved checkpoints load on CPU-only hosts;
            # load_state_dict copies the values into the model's own parameters.
            checkpoint = torch.load(self.args.raw_transnet_weights, map_location="cpu")
            for key in (
                "cls_layer1.weight",
                "cls_layer1.bias",
                "cls_layer2.weight",
                "cls_layer2.bias",
            ):
                checkpoint.pop(key, None)
            self.model.load_state_dict(checkpoint, strict=False)
            print("载入原始模型权重")
        logger.info("TransNetV2 模型初始化结束")

    def training_step(self, batch, batch_idx):
        """Forward pass only; the loss is computed in ``training_step_end``."""
        frames, one_hot_gt, many_hot_gt = (
            batch["frames"],
            batch["one_hot"],
            batch["many_hot"],
        )
        single_frame_pred, all_frame_pred = self.model(frames)
        return single_frame_pred, all_frame_pred, one_hot_gt, many_hot_gt

    def training_step_end(self, output):
        """Weighted sum of the single-frame and all-frame losses (0.9 / 0.1)."""
        single_frame_pred, all_frame_pred, one_hot_gt, many_hot_gt = output
        # Predictions: [clips, frames, 2] logits; labels: [clips, frames].
        loss_one = _weighted_frame_ce(single_frame_pred, one_hot_gt)
        loss_all = _weighted_frame_ce(all_frame_pred, many_hot_gt)
        loss_total = loss_one * 0.9 + loss_all * 0.1
        self.log(
            "train_loss",
            loss_total,
            on_epoch=True,
            on_step=True,
            prog_bar=True,
            logger=True,
        )
        return loss_total

    def validation_step(self, batch, batch_idx):
        """Forward pass plus the per-video metadata needed by the epoch-end metrics."""
        frames, one_hot_gt, many_hot_gt = (
            batch["frames"],
            batch["one_hot"],
            batch["many_hot"],
        )
        single_frame_pred, all_frame_pred = self.model(frames)
        return (
            single_frame_pred,
            all_frame_pred,
            one_hot_gt,
            many_hot_gt,
            batch["mp4_ids"],
            batch["batch_mp4_scenes_index"],
            batch["fps_batch"],
        )

    def validation_step_end(self, output):
        """Log the validation loss and pass ``output`` through to epoch end.

        The output must be returned: Lightning only collects non-None step-end
        outputs for ``validation_epoch_end``.
        """
        single_frame_pred, all_frame_pred, one_hot_gt, many_hot_gt = output[:4]
        loss_one = _weighted_frame_ce(single_frame_pred, one_hot_gt)
        loss_all = _weighted_frame_ce(all_frame_pred, many_hot_gt)
        # NOTE(review): validation mixes the losses 0.8 / 0.2 while training uses
        # 0.9 / 0.1, so train_loss and val_loss are not directly comparable.
        loss_total = loss_one * 0.8 + loss_all * 0.2
        self.log(
            "val_loss",
            loss_total,
            on_epoch=True,
            on_step=True,
            prog_bar=True,
            logger=True,
        )
        return output

    def validation_epoch_end(self, output):
        """Log precision/recall for each (class threshold, time tolerance) pair."""
        start = time.time()
        # Move each batch's tensors to CPU once instead of once per class threshold.
        cpu_batches = [
            (
                mp4_ids.cpu(),
                scenes_index.cpu(),
                fps_batch.cpu(),
                single_frame_pred.cpu().float(),
            )
            for (single_frame_pred, _, _, _, mp4_ids, scenes_index, fps_batch) in output
        ]
        tolerances = [tolerance for tolerance, _ in _TOLERANCE_TAGS]
        for class_threshold in (0.1, 0.3, 0.5, 0.7):
            transition_label_list = []
            for mp4_ids, scenes_index, fps_batch, single_frame_pred in cpu_batches:
                transition_label_list.extend(
                    complete_results_batch(
                        mp4_ids,
                        scenes_index,
                        fps_batch,
                        single_frame_pred,
                        class_threshold,
                    )
                )
            custom_indicator = pr_call(transition_label_list, thresholds=tolerances)
            for tolerance, tag in _TOLERANCE_TAGS:
                for suffix, metric in (("P", "precision"), ("R", "recall")):
                    self.log(
                        f"{class_threshold}_{tag}_{suffix}",
                        custom_indicator[tolerance][metric],
                        on_epoch=True,
                        on_step=False,
                        prog_bar=False,
                        logger=True,
                    )
        print("推理耗时:{}".format(time.time() - start))

    ## Optimizer configuration
    def configure_optimizers(self):
        """Build the optimizer (SGD or Adam) and the optional LR scheduler.

        Raises:
            ValueError: if ``args.lr_scheduler`` is not a supported value.
        """
        logger.info("configure_optimizers 初始化开始...")
        if self.args.optim == "SGD":
            optimizer = torch.optim.SGD(
                self.parameters(), lr=self.learning_rate, momentum=0.9
            )
        else:
            optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)

        if self.args.lr_scheduler == "OneCycleLR":
            # NOTE(review): max_lr / epochs / steps_per_epoch are hard-coded and
            # ignore args.lr and the real dataloader length; the scheduler is stepped
            # at Lightning's default interval — confirm this schedule is intended.
            scheduler = torch.optim.lr_scheduler.OneCycleLR(
                optimizer, max_lr=0.0002, verbose=True, epochs=500, steps_per_epoch=7
            )
        elif self.args.lr_scheduler == "CosineAnnealingLR":
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=200, eta_min=5e-7, verbose=True, last_epoch=-1
            )
        elif self.args.lr_scheduler in ("None", None):
            logger.info("configure_optimizers 初始化结束...")
            return optimizer
        else:
            raise ValueError(f"Unsupported lr_scheduler: {self.args.lr_scheduler!r}")
        logger.info("configure_optimizers 初始化结束...")
        return [optimizer], [scheduler]