def encoder_decoder(paramsfile, specstr, channels=3, layersplit='encode', shape=(64,64),
                    poolinv=False):
    """Build a pair of (encode, decode) callables from a saved autoencoder.

    The network is constructed with ``build_cae`` when ``poolinv`` is true and
    with ``build_cae_nopoolinv`` otherwise, its parameters are loaded from
    ``paramsfile``, and it is split at the layer whose ``name`` equals
    ``layersplit``.

    Parameters
    ----------
    paramsfile
        Passed to ``u.load_params`` together with the built network.
    specstr
        Architecture specification forwarded to the builder function.
    channels : int
        Number of image channels; forwarded to the builder and used for the
        decoder's output reshape.
    layersplit : str
        Name of the layer at which the network is split.
    shape : tuple
        Two-element spatial shape; forwarded to the builder and used for the
        decoder's output reshape.
    poolinv : bool
        Selects which builder function is used.

    Returns
    -------
    (encode, decode) : tuple of callables
        ``encode(x)`` evaluates the output of the split layer for input ``x``.
        ``decode(x)`` evaluates the network output with ``x`` substituted at
        the split layer, reshaped to ``(-1, channels, w, h)``.

    Raises
    ------
    ValueError
        If no layer named ``layersplit`` exists in the network.
    """
    inp = T.tensor4('inputs')
    w, h = shape
    build_fn = build_cae if poolinv else build_cae_nopoolinv
    network = build_fn(inp, shape=shape, channels=channels, specstr=specstr)
    u.load_params(network, paramsfile)

    laylist = nn.layers.get_all_layers(network)
    # Iterate over the layers directly; a missing name yields None rather than
    # leaking an uninformative StopIteration to the caller.
    enc_layer = next((layer for layer in laylist if layer.name == layersplit), None)
    if enc_layer is None:
        raise ValueError('no layer named %r found in network' % (layersplit,))
    # First layer in the list -- presumably the network's input layer
    # (assumes `nn` is Lasagne, whose get_all_layers is topologically ordered).
    first_layer = laylist[0]

    def encode(x):
        """Evaluate the split layer's deterministic output for input ``x``."""
        return nn.layers.get_output(enc_layer, inputs=x, deterministic=True).eval()

    def decode(x):
        """Evaluate the network output with ``x`` fed in at the split layer."""
        # The first layer still needs a value in the inputs mapping, so it is
        # given an all-zero batch with the same number of rows as ``x``.
        placeholder = np.zeros((x.shape[0], channels, w, h),
                               dtype=theano.config.floatX)
        output = nn.layers.get_output(
            network,
            inputs={first_layer: placeholder, enc_layer: x},
            deterministic=True,
        )
        return output.eval().reshape(-1, channels, w, h)

    return encode, decode
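# [page residue] "Comment list" -- navigation text from the source web page, not code.
# [page residue] "Table of contents" -- navigation text from the source web page, not code.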