ops.py source code

python

Project: sampleRNN_ICLR2017  Author: soroushmehr
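import numpy
import theano
import theano.tensor as T

import lib  # project-local package (assumed importable as `lib`); provides lib.param


def __Recurrent(name, hidden_dims, step_fn, inputs, non_sequences=None, h0s=None):
    """Apply step_fn along the time axis of inputs with theano.scan."""
    # Accept a single tensor or a single int as shorthand for a one-element list.
    if not isinstance(inputs, list):
        inputs = [inputs]

    if not isinstance(hidden_dims, list):
        hidden_dims = [hidden_dims]

    # Avoid a mutable default argument shared between calls.
    if non_sequences is None:
        non_sequences = []

    # Copy the list so the caller's h0s is not modified below.
    if h0s is None:
        h0s = [None] * len(hidden_dims)
    else:
        h0s = list(h0s)

    for i in range(len(hidden_dims)):
        if h0s[i] is None:
            # Learnable initial state, stored once and repeated for every batch item.
            h0_unbatched = lib.param(
                name + '.h0_' + str(i),
                numpy.zeros((hidden_dims[i],), dtype=theano.config.floatX)
            )
            # Inputs are laid out as (time, batch, ...), so axis 1 is the batch size.
            num_batches = inputs[0].shape[1]
            h0s[i] = T.alloc(h0_unbatched, num_batches, hidden_dims[i])

        # Mark no dimension as broadcastable so the initial state's type
        # matches the state that step_fn returns inside scan.
        h0s[i] = T.patternbroadcast(h0s[i], [False] * h0s[i].ndim)

    # scan calls step_fn(*sequence_slices, *previous_states, *non_sequences).
    outputs, _ = theano.scan(
        step_fn,
        sequences=inputs,
        outputs_info=h0s,
        non_sequences=non_sequences
    )

    return outputs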
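
The recurrence that theano.scan performs here can be hard to see from the symbolic code alone. Below is a minimal NumPy sketch of the same loop for the single-input, single-hidden-state case. The names recurrent_numpy and tanh_step, the weight matrices W_x and W_h, and all shapes are hypothetical and chosen only for illustration; they are not part of the project. The argument order passed to the step function (sequence slice, previous state, then non-sequences) follows the convention of theano.scan.

```python
import numpy as np


def recurrent_numpy(step_fn, inputs, h0, non_sequences=()):
    """Loop over axis 0 (time) of inputs, threading the hidden state through step_fn."""
    h = h0
    outputs = []
    for x_t in inputs:                       # inputs: (seq_len, batch, input_dim)
        h = step_fn(x_t, h, *non_sequences)  # same argument order as theano.scan
        outputs.append(h)
    return np.stack(outputs)                 # (seq_len, batch, hidden_dim)


def tanh_step(x_t, h_prev, W_x, W_h):
    """A hypothetical vanilla-RNN step: h_t = tanh(x_t W_x + h_{t-1} W_h)."""
    return np.tanh(x_t @ W_x + h_prev @ W_h)


rng = np.random.default_rng(0)
seq_len, batch, input_dim, hidden_dim = 5, 3, 4, 2

x = rng.standard_normal((seq_len, batch, input_dim))
W_x = rng.standard_normal((input_dim, hidden_dim))
W_h = rng.standard_normal((hidden_dim, hidden_dim))

# One unbatched initial state, repeated for every batch item
# (the role T.alloc plays in the function above).
h0_unbatched = np.zeros(hidden_dim)
h0 = np.broadcast_to(h0_unbatched, (batch, hidden_dim))

hs = recurrent_numpy(tanh_step, x, h0, (W_x, W_h))
print(hs.shape)
```

Unlike this eager loop, theano.scan builds a symbolic graph, which is why the original function also has to fix the broadcast pattern of the initial state before passing it in.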