|
"""Utility functions of the UVQ model. |
|
|
|
Copyright 2022 Google LLC |
|
|
|
Licensed under the Apache License, Version 2.0 (the "License"); |
|
you may not use this file except in compliance with the License. |
|
You may obtain a copy of the License at |
|
|
|
https://www.apache.org/licenses/LICENSE-2.0 |
|
|
|
Unless required by applicable law or agreed to in writing, software |
|
distributed under the License is distributed on an "AS IS" BASIS, |
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
See the License for the specific language governing permissions and |
|
limitations under the License. |
|
""" |
|
|
|
|
|
import csv |
|
import os |
|
import subprocess |
|
import tempfile |
|
|
|
import numpy as np |
|
from PIL import Image |
|
import tensorflow as tf |
|
|
|
from tensorflow.compat.v1 import gfile, logging, saved_model |
|
from tensorflow.compat.v1 import ConfigProto, Session, Graph |
|
|
|
|
|
|
|
# Canonical resolution / frame rate the input video is normalized to by
# ffmpeg before feature extraction (see load_video).
VIDEO_HEIGHT = 720

VIDEO_WIDTH = 1280

VIDEO_FPS = 5

VIDEO_CHANNEL = 3


# Spatial and channel dimensions of the per-chunk feature maps that the
# subnets produce (and that load_features reshapes binaries into).
DIM_HEIGHT_FEATURE = 16

DIM_WIDTH_FEATURE = 16

DIM_CHANNEL_FEATURE = 100


# ContentNet input patch size and label vector length.
# NOTE(review): 3862 presumably matches the YouTube-8M label vocabulary
# size — confirm against the ContentNet model.
INPUT_HEIGHT_CONTENT = 496

INPUT_WIDTH_CONTENT = 496

INPUT_CHANNEL_CONTENT = 3

DIM_LABEL_CONTENT = 3862
|
|
|
def extend_array(rgb, total_len):
  """Returns *rgb* adjusted to exactly total_len bytes.

  A short array is zero-padded in place; a long one is truncated (as a new
  slice). Used to make a possibly-truncated raw decode reshapeable.
  """
  shortfall = total_len - len(rgb)
  if shortfall >= 0:
    rgb.extend(bytearray(shortfall))
    return rgb
  return rgb[0:total_len]
|
|
|
|
|
def load_video(filepath, video_length, transpose=False):
  """Decodes a video into raw RGB frame arrays at two resolutions.

  Runs ffmpeg once to emit both a padded 1280x720 rawvideo stream (for the
  compression/distortion subnets) and a 496x496 stream (for ContentNet),
  then loads both into arrays rescaled from uint8 [0, 255] to [-1, 1].

  Args:
    filepath: path to the input video.
    video_length: expected length in seconds; output is padded/truncated to
      exactly this many 1-second chunks of VIDEO_FPS frames.
    transpose: if True, apply ffmpeg transpose=1 (90-degree rotation) first.

  Returns:
    (video, video_resized): float arrays of shapes
    (video_length, VIDEO_FPS, VIDEO_HEIGHT, VIDEO_WIDTH, 3) and
    (video_length, VIDEO_FPS, INPUT_HEIGHT_CONTENT, INPUT_WIDTH_CONTENT, 3),
    with values in [-1, 1].

  Raises:
    subprocess.CalledProcessError: if ffmpeg fails.
  """
  transpose_param = 'transpose=1,' if transpose else ''

  fd, temp_filename = tempfile.mkstemp()
  fd_small, temp_filename_small = tempfile.mkstemp()
  # try/finally so the two staging files (and fds) are always reclaimed;
  # previously they leaked whenever ffmpeg or the read below failed.
  try:
    # NOTE(review): shell=True with an unquoted filepath is a shell-injection
    # hazard if filepath is ever attacker-controlled; consider
    # subprocess.run([...], shell=False) with an argument list.
    cmd = (
        'ffmpeg -i %s -filter_complex '
        ' "[0:v]%sscale=w=%d:h=%d:flags=bicubic:force_original_aspect_ratio=1,'
        'pad=%d:%d:(ow-iw)/2:(oh-ih)/2,format=rgb24,split=2[out1][tmp],[tmp]scale=%d:%d:flags=bilinear[out2]"'
        ' -map [out1] -r %d -f rawvideo -pix_fmt rgb24 -y %s'
        ' -map [out2] -r %d -f rawvideo -pix_fmt rgb24 -y %s'
    ) % (filepath, transpose_param, VIDEO_WIDTH, VIDEO_HEIGHT,
         VIDEO_WIDTH, VIDEO_HEIGHT, INPUT_WIDTH_CONTENT, INPUT_HEIGHT_CONTENT,
         VIDEO_FPS, temp_filename, VIDEO_FPS, temp_filename_small)

    try:
      logging.info('Run with cmd: %s\n', cmd)
      subprocess.check_output(cmd, stderr=subprocess.STDOUT, shell=True)
    except subprocess.CalledProcessError as error:
      logging.fatal('Run with cmd: %s \n terminated with return code %s\n%s',
                    cmd, str(error.returncode), error.output)
      raise error

    with gfile.Open(temp_filename, 'rb') as rgb_file, \
        gfile.Open(temp_filename_small, 'rb') as rgb_file_small:
      single_frame_size = VIDEO_WIDTH * VIDEO_HEIGHT * VIDEO_CHANNEL
      full_decode_size = video_length * VIDEO_FPS * single_frame_size
      assert rgb_file.size() >= single_frame_size, f'Decoding failed to output a single frame: {rgb_file.size()} < {single_frame_size}'
      if rgb_file.size() < full_decode_size:
        logging.warn('Decoding may be truncated: %d bytes (%d frames) < %d bytes (%d frames),'
                     ' or video length (%ds) may be too incorrect',
                     rgb_file.size(), rgb_file.size() / single_frame_size,
                     full_decode_size, full_decode_size / single_frame_size,
                     video_length)

      # Pad (or truncate) to the exact expected byte count so the reshapes
      # below cannot fail on a short decode.
      rgb = extend_array(bytearray(rgb_file.read()), full_decode_size)
      rgb_small = extend_array(
          bytearray(rgb_file_small.read()),
          video_length * VIDEO_FPS * INPUT_WIDTH_CONTENT *
          INPUT_HEIGHT_CONTENT * VIDEO_CHANNEL)
      # Rescale uint8 [0, 255] -> [-1, 1].
      video = (np.reshape(np.frombuffer(rgb, 'uint8'),
                          (video_length, int(VIDEO_FPS), VIDEO_HEIGHT,
                           VIDEO_WIDTH, 3)) / 255.0 - 0.5) * 2
      video_resized = (np.reshape(np.frombuffer(rgb_small, 'uint8'),
                                  (video_length, int(VIDEO_FPS),
                                   INPUT_HEIGHT_CONTENT, INPUT_WIDTH_CONTENT,
                                   3)) / 255.0 - 0.5) * 2
  finally:
    os.close(fd)
    os.remove(temp_filename)
    os.close(fd_small)
    os.remove(temp_filename_small)

  logging.info('Load %s done successfully.', filepath)
  return video, video_resized
|
|
|
|
|
def generate_content_feature(video, model_path, input_node, output_nodes):
  """Extracts ContentNet features and labels, one frame per 1-second chunk.

  Feeds frame 0 of every chunk of *video* through the SavedModel at
  model_path and collects the fetched feature map and label vector.

  Args:
    video: array shaped (chunks, fps, height, width, channels), as produced
      by load_video (resized stream).
    model_path: ContentNet SavedModel directory.
    input_node: name of the model's input tensor.
    output_nodes: [feature_tensor_name, label_tensor_name] to fetch.

  Returns:
    (feature, label) float32 arrays shaped
    (chunks, DIM_HEIGHT_FEATURE, DIM_WIDTH_FEATURE, DIM_CHANNEL_FEATURE)
    and (chunks, DIM_LABEL_CONTENT).
  """
  session_config = ConfigProto(allow_soft_placement=True,
                               log_device_placement=False)
  with Session(graph=Graph(), config=session_config) as sess:
    saved_model.loader.load(sess, [saved_model.tag_constants.SERVING],
                            model_path)

    num_chunks = video.shape[0]
    label = np.ndarray((num_chunks, DIM_LABEL_CONTENT), np.float32)
    feature = np.ndarray((num_chunks, DIM_HEIGHT_FEATURE,
                          DIM_WIDTH_FEATURE, DIM_CHANNEL_FEATURE), np.float32)
    # Reusable single-item batch buffer for the model input.
    patch = np.ndarray((1, INPUT_HEIGHT_CONTENT, INPUT_WIDTH_CONTENT,
                        INPUT_CHANNEL_CONTENT), np.float32)

    for chunk_idx in range(num_chunks):
      # Only the first frame of each chunk is scored by ContentNet.
      patch[0, :] = video[chunk_idx, 0, :, :, :]
      chunk_feature, chunk_label = sess.run(
          output_nodes, feed_dict={input_node: patch})
      feature[chunk_idx, :, :, :] = chunk_feature
      label[chunk_idx, :] = chunk_label
    return feature, label
|
|
|
|
|
def generate_subnet_feature(video, model_path, input_width, input_height,
                            input_fps, feature_width, feature_height,
                            feature_channel, label_dim, input_node,
                            output_nodes):
  """Extracts features from CompressionNet or DistortionNet.

  Tiles each 1280x720 chunk into non-overlapping input_width x input_height
  patches, runs every patch through the SavedModel independently, and
  stitches the per-patch feature maps back into one spatial grid.

  Args:
    video: array shaped (chunks, fps, height, width, channels), as produced
      by load_video (full-resolution stream).
    model_path: SavedModel directory for the subnet.
    input_width: patch width expected by the model.
    input_height: patch height expected by the model.
    input_fps: frames per chunk the model consumes; 1 means a single-frame
      (4-D) input, otherwise a clip (5-D) input.
    feature_width: per-patch feature-map width.
    feature_height: per-patch feature-map height.
    feature_channel: feature-map channel count.
    label_dim: per-patch label vector length.
    input_node: name of the model's input tensor.
    output_nodes: [feature_tensor_name, label_tensor_name] to fetch.

  Returns:
    (feature, label) float32 arrays shaped
    (chunks, num_h * feature_height, num_w * feature_width, feature_channel)
    and (chunks, num_h, num_w, label_dim), where num_h/num_w are the patch
    grid dimensions.
  """
  with Session(graph=Graph(), config=ConfigProto(
      allow_soft_placement=True, log_device_placement=False)) as sess:
    saved_model.loader.load(
        sess, [saved_model.tag_constants.SERVING], model_path)

    # Patch-grid dimensions; assumes the video dimensions divide evenly
    # (720/180=4 etc. for the parameters used by generate_features).
    num_h = int(VIDEO_HEIGHT / input_height)
    num_w = int(VIDEO_WIDTH / input_width)

    label = np.ndarray((video.shape[0], num_h, num_w, label_dim), np.float32)
    feature = np.ndarray((video.shape[0], num_h * feature_height,
                          num_w * feature_width, feature_channel), np.float32)

    # Reusable single-item batch buffer: 4-D for single-frame models,
    # 5-D for clip models.
    if input_fps == 1:
      patch = np.ndarray((1, input_height, input_width,
                          video.shape[-1]), np.float32)
    else:
      patch = np.ndarray((1, input_fps, input_height, input_width,
                          video.shape[-1]), np.float32)

    for k in range(video.shape[0]):
      for j in range(num_h):
        for i in range(num_w):
          # Crop tile (j, i): single-frame models see only frame 0 of the
          # chunk; clip models see all input_fps frames.
          if input_fps == 1:
            patch[0, :] = video[k, 0, j * input_height:(j + 1) * input_height,
                                i * input_width:(i + 1) * input_width, :]
          else:
            patch[0, :] = video[k, :, j * input_height:(j + 1) * input_height,
                                i * input_width:(i + 1) * input_width, :]

          patch_feature, patch_label = sess.run(output_nodes,
                                                feed_dict={input_node: patch})

          # Stitch the per-patch feature map back into its grid position.
          feature[k, j * feature_height:(j + 1) * feature_height, i *
                  feature_width:(i + 1) * feature_width, :] = patch_feature

          label[k, j, i, :] = patch_label

    return feature, label
|
|
|
|
|
def generate_features(video_id, video_length, filepath, model_dir, output_dir,
                      transpose=False):
  """Generates and writes all subnet features/labels for one video.

  Decodes the video, runs CompressionNet, ContentNet and DistortionNet, and
  writes each result to output_dir as
  {video_id}_feature_{name}.binary (raw float32) and
  {video_id}_label_{name}.csv (one row per 1-second chunk).

  Args:
    video_id: identifier used in the output file names.
    video_length: video length in seconds.
    filepath: path to the input video.
    model_dir: directory holding the *_baseline SavedModels.
    output_dir: directory the feature/label files are copied to.
    transpose: passed through to load_video (rotate 90 degrees if True).
  """
  video, video_resized = load_video(filepath, video_length, transpose)

  # CompressionNet: 320x180 patches (4x4 grid), 5-frame clips, scalar label.
  feature_compression, label_compression = generate_subnet_feature(
      video, '%s/compressionnet_baseline' % model_dir,
      320, 180, 5,
      4, 4, 100, 1,
      'input_orig:0',
      ['feature_layer_orig:0', 'compress_level_orig:0'])

  # ContentNet: whole resized frame, DIM_LABEL_CONTENT-way label.
  feature_content, label_content = generate_content_feature(
      video_resized, '%s/contentnet_baseline' % model_dir,
      'map/TensorArrayV2Stack/TensorListStack:0',
      ['final_conv2d/Conv2D:0', 'class_confidence:0'])

  # DistortionNet: 640x360 patches (2x2 grid), single frame, 26-way label.
  feature_distortion, label_distortion = generate_subnet_feature(
      video, '%s/distortionnet_baseline' % model_dir,
      640, 360, 1,
      8, 8, 100, 26,
      'input_images:0',
      ['feature_map:0', 'dist_type_prediction/dist_type_predictions:0'])

  # One local staging file is reused for every output: write locally, then
  # gfile.Copy so output_dir may be a remote filesystem.
  fd, temp = tempfile.mkstemp()
  try:
    def _save_binary(feature, name):
      """Writes a feature array as raw float32 and copies it to output_dir."""
      feature.astype('float32').tofile(temp)
      out_feature = '%s/%s_%s.binary' % (output_dir, video_id, name)
      gfile.Copy(temp, out_feature, overwrite=True)

    def _save_csv(label, name):
      """Writes a label array as CSV (one row per chunk) to output_dir."""
      np.savetxt(temp, label.reshape(label.shape[0], -1),
                 fmt='%0.3f', delimiter=',')
      out_feature = '%s/%s_%s.csv' % (output_dir, video_id, name)
      gfile.Copy(temp, out_feature, overwrite=True)

    _save_binary(feature_compression, 'feature_compression')
    _save_binary(feature_content, 'feature_content')
    _save_binary(feature_distortion, 'feature_distortion')

    _save_csv(label_compression, 'label_compression')
    _save_csv(label_content, 'label_content')
    _save_csv(label_distortion, 'label_distortion')
  finally:
    # Always reclaim the staging file, even if a model run or copy fails.
    os.close(fd)
    os.remove(temp)
|
|
|
|
|
def load_features(video_id, dim_time, feature_dir):
  """Loads the three pre-generated feature arrays for a video.

  Reads the raw-float32 binaries written by generate_features and reshapes
  each to (1, dim_time, 16, 16, 100).

  Args:
    video_id: identifier used in the feature file names.
    dim_time: number of 1-second chunks the binaries were generated with.
    feature_dir: directory containing {video_id}_feature_{name}.binary.

  Returns:
    (feature_compression, feature_content, feature_distortion), each a
    float32 array of shape (1, dim_time, DIM_HEIGHT_FEATURE,
    DIM_WIDTH_FEATURE, DIM_CHANNEL_FEATURE).
  """
  # The three binaries share one format, so a single helper replaces the
  # previous three copy-pasted load stanzas.
  def _load_feature(name):
    input_path = '%s/%s_feature_%s.binary' % (feature_dir, video_id, name)
    with gfile.Open(input_path, 'rb') as input_file:
      s = input_file.read()
    with Session() as sess:
      feature_1d = tf.io.decode_raw(s, out_type=tf.float32)
      feature = tf.reshape(feature_1d,
                           [1, dim_time, DIM_HEIGHT_FEATURE,
                            DIM_WIDTH_FEATURE, DIM_CHANNEL_FEATURE])
      return sess.run(feature)

  feature_compression = _load_feature('compression')
  feature_content = _load_feature('content')
  feature_distortion = _load_feature('distortion')

  return feature_compression, feature_content, feature_distortion
|
|
|
|
|
def prediction(video_id, video_length, model_dir, feature_dir, output_dir):
  """Predicts quality (MOS) and writes {video_id}_uvq.csv to output_dir.

  For each feature combination, averages the MOS predicted by the five
  cross-validation AggregationNet models trained on ytugc20s subsets.

  Args:
    video_id: identifier used for feature lookup and the output file name.
    video_length: number of 1-second chunks to score.
    model_dir: directory holding aggregationnet_baseline/<model_name>.
    feature_dir: directory with the pre-generated feature binaries.
    output_dir: directory the result CSV is copied to.
  """
  trainset = 'ytugc20s'
  all_trainset_subindex = ['0', '1', '2', '3', '4']

  all_feature = ['compression',
                 'content',
                 'distortion',
                 'compression_content',
                 'compression_distortion',
                 'content_distortion',
                 'compression_content_distortion',
                 ]
  aggregation_model = 'avgpool'

  # The features do not depend on the model being evaluated, so load and
  # trim them once here instead of re-reading the binaries and spinning up
  # extra TF sessions inside the 7x5 model loop (35 redundant loads before).
  [feature_compression, feature_content, feature_distortion
   ] = load_features(video_id, video_length, feature_dir)
  feature_compression = feature_compression[:, 0:video_length, :, :, :]
  feature_content = feature_content[:, 0:video_length, :, :, :]
  feature_distortion = feature_distortion[:, 0:video_length, :, :, :]

  all_outputs = []
  for feature in all_feature:
    aggregated_mos = 0
    for trainset_subindex in all_trainset_subindex:
      model_name = '%s_%s_%s_%s' % (
          trainset, trainset_subindex, aggregation_model, feature)

      with Session(graph=Graph(), config=ConfigProto(
          allow_soft_placement=True, log_device_placement=False)) as sess:
        saved_model.loader.load(
            sess, [saved_model.tag_constants.SERVING],
            '%s/aggregationnet_baseline/%s' % (
                model_dir, model_name))

        pred_mos = sess.run(
            'Model/mos:0',
            feed_dict={'feature_compression:0': feature_compression,
                       'feature_content:0': feature_content,
                       'feature_distortion:0': feature_distortion,
                       })
        pred_mos = pred_mos[0][0]
        aggregated_mos += pred_mos

    # Average over the cross-validation models.
    aggregated_mos /= len(all_trainset_subindex)
    all_outputs.append([video_id, feature, aggregated_mos])

  # Stage the CSV locally, then copy (output_dir may be remote); the
  # try/finally guarantees the temp file is reclaimed on any failure.
  fd, temp = tempfile.mkstemp()
  try:
    with gfile.Open(temp, 'w') as f:
      writer = csv.writer(f)
      writer.writerows(all_outputs)
    out_file = '%s/%s_uvq.csv' % (output_dir, video_id)
    gfile.Copy(temp, out_file, overwrite=True)
  finally:
    os.close(fd)
    os.remove(temp)
|
|