vrnn_ar.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:speech 作者: igul222 项目源码 文件源码
def Recurrence(processed_frames, h0, reset):
    """
    Feed `processed_frames` through a stack of N_GRUS chained GRU layers.

    Shapes (as stated by the original author):
        processed_frames: (batch size, n frames, DIM)
        h0:               (batch size, N_GRUS, DIM)
        reset:            () -- scalar condition

    When `reset` is true, the incoming `h0` is discarded and a learned
    initial state (parameter 'Recurrence.h0', tiled over the batch) is
    used in its place.

    Returns a pair `(output, last_hidden)`:
        output:      the top layer's output, (batch size, n frames, DIM)
        last_hidden: each layer's output at the final frame, stacked on
                     axis 1 -- presumably (batch size, N_GRUS, DIM), i.e.
                     the same layout as `h0`, so it can be fed back in.
    """
    # Learned starting state, shared by every sequence in the batch.
    h0_param = lib.param(
        'Recurrence.h0',
        numpy.zeros((N_GRUS, DIM), dtype=theano.config.floatX)
    )
    tiled_h0 = T.alloc(h0_param, h0.shape[0], N_GRUS, DIM)
    # Mark every axis as non-broadcastable -- presumably so that both
    # branches of the ifelse below have the same tensor type.
    tiled_h0 = T.patternbroadcast(tiled_h0, [False] * tiled_h0.ndim)
    h0 = theano.ifelse.ifelse(reset, tiled_h0, h0)

    # Bottom layer reads the input frames; each later layer reads the
    # output of the layer directly beneath it.
    layers = [
        lib.ops.LowMemGRU(
            'Recurrence.GRU0', DIM, DIM, processed_frames, h0=h0[:, 0]
        )
    ]
    for depth in xrange(1, N_GRUS):
        layer_name = 'Recurrence.GRU' + str(depth)
        below = layers[-1]
        layers.append(
            lib.ops.LowMemGRU(layer_name, DIM, DIM, below, h0=h0[:, depth])
        )

    # Collect every layer's state at the last frame.
    final_states = [layer_out[:, -1] for layer_out in layers]
    last_hidden = T.stack(final_states, axis=1)

    top_output = layers[-1]
    return (top_output, last_hidden)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号