layers.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:fast-wavenet 作者: tomlepaine 项目源码 文件源码
def dilated_conv1d(inputs,
                   out_channels,
                   filter_width=2,
                   rate=1,
                   padding='VALID',
                   name=None,
                   gain=np.sqrt(2),
                   activation=tf.nn.relu):
    '''Run a 1-D convolution over `inputs` at dilation `rate`.

    The input is passed through `time_to_batch`, convolved with `conv1d`,
    and passed back through `batch_to_time`; all variables are created
    inside `tf.variable_scope(name)`.

    Args:
      inputs: (tensor) must have a static rank of 3, since its shape is
        unpacked into three values. Presumably (batch, width, channels)
        -- confirm against the callers.
      out_channels: forwarded to `conv1d`; also used as the last static
        dimension of the result.
      filter_width: forwarded to `conv1d`.
      rate: dilation factor, forwarded to `time_to_batch` and
        `batch_to_time`.
      padding: forwarded to `conv1d`.
      name: variable scope name; required (asserted to be truthy).
      gain: forwarded to `conv1d`.
      activation: forwarded to `conv1d`.

    Outputs:
      outputs: (tensor) whose static shape is set to
        (None, width of `inputs`, out_channels).
    '''
    assert name
    with tf.variable_scope(name):
        # Static shape of the input; only the middle (width) entry is used.
        _batch, in_width, _channels = inputs.get_shape().as_list()

        batched = time_to_batch(inputs, rate=rate)
        convolved = conv1d(
            batched,
            out_channels=out_channels,
            filter_width=filter_width,
            padding=padding,
            gain=gain,
            activation=activation,
        )

        # Width the convolved tensor would have after undoing the batching,
        # minus the original width, is what gets cropped from the left.
        _cb, conv_width, _cc = convolved.get_shape().as_list()
        left_crop = conv_width * rate - in_width
        outputs = batch_to_time(convolved, rate=rate, crop_left=left_crop)

        # Add additional shape information.
        static_dims = [tf.Dimension(size)
                       for size in (None, in_width, out_channels)]
        outputs.set_shape(tf.TensorShape(static_dims))

    return outputs
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号