models.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:video_predict 作者: tencia 项目源码 文件源码
def encoder_decoder(paramsfile, specstr, channels=3, layersplit='encode', shape=(64,64),
        poolinv=False):
    inp = T.tensor4('inputs')
    w,h=shape
    build_fn = build_cae if poolinv else build_cae_nopoolinv
    network = build_fn(inp, shape=shape,channels=channels,specstr=specstr)
    u.load_params(network, paramsfile)
    laylist = nn.layers.get_all_layers(network)
    enc_layer_idx = next(i for i in xrange(len(laylist)) if laylist[i].name==layersplit)
    enc_layer = laylist[enc_layer_idx]
    return (lambda x: nn.layers.get_output(enc_layer, inputs=x,deterministic=True).eval(),
            lambda x: nn.layers.get_output(network,
                inputs={laylist[0]:np.zeros((x.shape[0],channels,w,h),
                    dtype=theano.config.floatX),
                    enc_layer:x}, deterministic=True).eval().reshape(-1,channels,w,h))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号