model.py source code

python

Project: the-wavenet-pianist  Author: 821760408-sp
import tensorflow as tf


def _enc_upsampling_conv(encoding,
                         audio_length,
                         filter_length=1024,
                         time_stride=512):
    """Upsample the local conditioning encoding to match the time dimension
    of the audio.

    :param encoding: [mb, timeframe, channels] local conditioning encoding
    :param audio_length: length of the time dimension of the audio
    :param filter_length: transposed-convolution filter length
    :param time_stride: stride along the time dimension (upsampling factor)
    :return: upsampled local conditioning encoding
    """
    with tf.variable_scope('upsampling_conv'):
        batch_size, _, enc_channels = encoding.get_shape().as_list()
        shape = tf.shape(encoding)
        strides = [1, 1, time_stride, 1]
        # Standard output length of a VALID transposed convolution:
        # (input_length - 1) * stride + filter_length
        output_length = (shape[1] - 1) * time_stride + filter_length
        output_shape = tf.stack(
            [batch_size, 1, output_length, enc_channels])

        kernel_shape = [1, filter_length, enc_channels, enc_channels]
        biases_shape = [enc_channels]

        upsamp_weights = tf.get_variable(
            'weights',
            kernel_shape,
            initializer=tf.uniform_unit_scaling_initializer(1.0))
        upsamp_biases = tf.get_variable(
            'biases',
            biases_shape,
            initializer=tf.constant_initializer(0.0))

        # Treat the sequence as a height-1 image so conv2d_transpose
        # upsamples along the time (width) axis only.
        encoding = tf.reshape(encoding,
                              [batch_size, 1, shape[1], enc_channels])
        upsamp_conv = tf.nn.conv2d_transpose(
            encoding,
            upsamp_weights, output_shape, strides, padding='VALID')
        output = tf.nn.bias_add(upsamp_conv, upsamp_biases)

        output = tf.reshape(output,
                            [batch_size, output_length, enc_channels])
        # Trim the transposed-convolution overhang so the result matches
        # the audio's time dimension exactly.
        output_sliced = tf.slice(
            output, [0, 0, 0],
            tf.stack([-1, audio_length, -1]))
        output_sliced.set_shape([batch_size, audio_length, enc_channels])
        return output_sliced
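
A minimal usage sketch in TF1 graph mode follows; the batch size, frame count, and channel count are illustrative placeholders, not values from the project.

# Usage sketch (illustrative shapes, not taken from the-wavenet-pianist):
encoding = tf.placeholder(tf.float32, [4, 100, 64])  # [mb, frames, channels]
audio_length = 100 * 512                             # frames * time_stride
upsampled = _enc_upsampling_conv(encoding,
                                 audio_length,
                                 filter_length=1024,
                                 time_stride=512)
# Internally, output_length = (100 - 1) * 512 + 1024 = 51712, which is then
# sliced down to audio_length = 51200, giving shape [4, 51200, 64].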

    # Global conditioning, in particular, doesn't align with the audio input
    # along the time dimension and needs its value broadcast across the input;
    # for local conditioning, we've already matched their sizes above.
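
As a sketch of that broadcasting (assuming global conditioning is a per-example embedding vector; all names and shapes here are hypothetical, not from the project):

# Hypothetical sketch of broadcasting a per-example global embedding
# [mb, channels] across the audio's time dimension.
global_cond = tf.placeholder(tf.float32, [4, 64])        # [mb, channels]
audio_input = tf.placeholder(tf.float32, [4, 51200, 1])  # [mb, time, 1]
time_len = tf.shape(audio_input)[1]
global_broadcast = tf.tile(global_cond[:, tf.newaxis, :],
                           tf.stack([1, time_len, 1]))
# global_broadcast: [mb, time, channels], aligned with audio_input in time.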