wavenet.py 文件源码

python
阅读 33 收藏 0 点赞 0 评论 0

项目:WaveNet-Theano 作者: huyouare 项目源码 文件源码
def sample_from_softmax(softmax_var):
    #softmax_var assumed to be of shape (batch_size, num_classes)
    old_shape = softmax_var.shape

    softmax_var_reshaped = softmax_var.reshape((-1,softmax_var.shape[softmax_var.ndim-1]))

    return T.argmax(
        T.cast(
            srng.multinomial(pvals=softmax_var_reshaped),
            theano.config.floatX
            ).reshape(old_shape),
        axis = softmax_var.ndim-1
        )

# inputs.shape: (batch size, length, input_dim)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号