renet.py 文件源码

python
阅读 43 收藏 0 点赞 0 评论 0

项目:lazyprogrammer 作者: inhwane 项目源码 文件源码
def renet_layer_lr_allscan(X, rnn1, rnn2, w, h, wp, hp):
    # list_of_images = []
    C = X.shape[0]
    X = X.dimshuffle((1, 0, 2)).reshape((h/hp, hp*C*w)) # split the rows for the first scan
    def rnn_pass(x):
        x = x.reshape((hp, C, w)).dimshuffle((2, 1, 0)).reshape((w/wp, C*wp*hp))
        h1 = rnn1.output(x)
        h2 = rnn2.output(x, go_backwards=True)
        img = T.concatenate([h1.T, h2.T])
        # list_of_images.append(img)
        return img

    results, _ = theano.scan(
        fn=rnn_pass,
        sequences=X,
        outputs_info=None,
        n_steps=h/hp,
    )
    return results.dimshuffle((1, 0, 2))
    # return T.stacklists(list_of_images).dimshuffle((1, 0, 2))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号