def psnr_error(gen_frames, gt_frames):
    """
    Computes the mean Peak Signal to Noise Ratio (PSNR) between the generated frames and the
    ground-truth frames.

    For every frame the mean squared error (MSE) over axes 1, 2 and 3 is computed and converted
    to decibels as 10 * log10(1 / MSE). The numerator is the constant 1, i.e. the formula treats
    the peak signal value as 1 (NOTE(review): confirm the callers' pixel range matches this).
    Frames that are identical to their ground truth have an MSE of 0, so the ratio divides by
    zero for them.

    @param gen_frames: A tensor of shape [batch_size, height, width, 3]. The frames generated by
                       the generator model.
    @param gt_frames: A tensor of shape [batch_size, height, width, 3]. The ground-truth frames
                      for each frame in gen_frames.

    @return: A scalar tensor. The mean PSNR over each frame in the batch.
    """
    shape = tf.shape(gen_frames)
    # tf.cast(..., tf.float32) is the drop-in replacement for the deprecated tf.to_float, which
    # is no longer available in the TensorFlow 2.x top-level namespace.
    num_pixels = tf.cast(shape[1] * shape[2] * shape[3], tf.float32)
    square_diff = tf.square(gt_frames - gen_frames)

    # Per-frame mean squared error: sum over height, width and channels, divided by their count.
    mse = (1 / num_pixels) * tf.reduce_sum(square_diff, [1, 2, 3])
    # log10 is a helper defined outside this block.
    batch_errors = 10 * log10(1 / mse)
    return tf.reduce_mean(batch_errors)
# Comment list
# Article table of contents