import tensorflow as tf


def _enc_upsampling_conv(encoding,
                         audio_length,
                         filter_length=1024,
                         time_stride=512):
    """Upsample local conditioning encoding to match time dim. of audio.

    :param encoding: [mb, timeframe, channels] local conditioning encoding
    :param audio_length: length of the time dimension of the audio
    :param filter_length: transposed conv. filter length
    :param time_stride: stride along the time dimension (upsampling factor)
    :return: upsampled local conditioning encoding
    """
    with tf.variable_scope('upsampling_conv'):
        batch_size, _, enc_channels = encoding.get_shape().as_list()
        shape = tf.shape(encoding)
        strides = [1, 1, time_stride, 1]
        # Length produced by a 'VALID' transposed convolution.
        output_length = (shape[1] - 1) * time_stride + filter_length
        output_shape = tf.stack(
            [batch_size, 1, output_length, enc_channels])
        kernel_shape = [1, filter_length, enc_channels, enc_channels]
        biases_shape = [enc_channels]
        upsamp_weights = tf.get_variable(
            'weights',
            kernel_shape,
            initializer=tf.uniform_unit_scaling_initializer(1.0))
        upsamp_biases = tf.get_variable(
            'biases',
            biases_shape,
            initializer=tf.constant_initializer(0.0))
        # Treat the encoding as a height-1 image so conv2d_transpose can
        # upsample along the time (width) dimension.
        encoding = tf.reshape(encoding,
                              [batch_size, 1, shape[1], enc_channels])
        upsamp_conv = tf.nn.conv2d_transpose(
            encoding,
            upsamp_weights, output_shape, strides, padding='VALID')
        output = tf.nn.bias_add(upsamp_conv, upsamp_biases)
        output = tf.reshape(output,
                            [batch_size, output_length, enc_channels])
        # Trim the overhang so the encoding matches the audio length exactly.
        output_sliced = tf.slice(
            output, [0, 0, 0],
            tf.stack([-1, audio_length, -1]))
        output_sliced.set_shape([batch_size, audio_length, enc_channels])
        return output_sliced
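
# A minimal usage sketch, with illustrative shapes that are assumptions rather
# than from the original: the encoder emits one frame every `time_stride`
# samples, and the upsampled encoding lines up sample-by-sample with the audio.
encoding_frames = tf.placeholder(tf.float32, [4, 125, 16])  # [mb, timeframe, channels]
upsampled = _enc_upsampling_conv(encoding_frames, audio_length=64000)
# upsampled has static shape [4, 64000, 16]: (125 - 1) * 512 + 1024 = 64512 >= 64000,
# and the surplus samples are trimmed by the final slice.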
# Note: global conditioning does not align with the audio input along the time
# dimension and its value needs to be broadcast to the input; for local
# conditioning, the sizes have already been matched by the upsampling above.
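
# For contrast, a minimal sketch of broadcasting global conditioning; the
# `global_embedding` tensor and its shape are assumptions for illustration.
# One embedding per utterance is tiled (or implicitly broadcast by an add)
# across every time step instead of being upsampled with a transposed conv.
global_embedding = tf.placeholder(tf.float32, [4, 1, 16])   # [mb, 1, channels]
broadcast_cond = tf.tile(global_embedding, [1, 64000, 1])   # [mb, audio_length, channels]
# Equivalently, adding the [mb, 1, channels] tensor to a [mb, audio_length, channels]
# activation broadcasts it along the time axis automatically.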