# NOTE: besides the standard imports below, this function relies on names defined
# elsewhere in the repository: the layer-type constants (FC, CONV, POOL_ARG, POOL,
# DO, LOSS, INPUT), PADDING, CONFIG_COLOR, the network helpers `nut`, and the
# logging utilities `ut`.
import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim


def build_decoder(net, layer_config, i=None, reuse=False, masks=None):
    """Recursively build decoder ops that mirror the encoder layers in
    layer_config, walking from the deepest layer back to the INPUT layer."""
    # Start from the deepest layer when no index is given.
    i = i if i is not None else len(layer_config) - 1
    cfg = layer_config[i]
    # On a reuse pass, attach ops to the scope recorded during the first pass.
    name = cfg.dec_op_name if reuse else None

    # If the tensor's rank differs from the shape recorded for the next layer
    # (e.g. a flat FC output vs. a 4-D feature map), reshape it first.
    if len(layer_config) > i + 1:
        if len(layer_config[i + 1].shape) != len(net.get_shape().as_list()):
            net = tf.reshape(net, layer_config[i + 1].shape)

    # Stop once the INPUT layer (or the front of the list) is reached.
    if i < 0 or layer_config[i].type == INPUT:
        return net

    if cfg.type == FC:
        # Mirror a fully connected layer: project back to prod(cfg.shape[1:]) units.
        net = slim.fully_connected(net, int(np.prod(cfg.shape[1:])), scope=name,
                                   activation_fn=cfg.activation, reuse=reuse)
    elif cfg.type == CONV:
        # Mirror a convolution with a transposed convolution of the same kernel
        # and stride, producing cfg.shape[-1] channels.
        net = slim.conv2d_transpose(net, cfg.shape[-1], [cfg.kernel, cfg.kernel],
                                    stride=cfg.stride,
                                    activation_fn=cfg.activation, padding=PADDING,
                                    scope=name, reuse=reuse)
    elif cfg.type == POOL_ARG:
        # Invert argmax pooling with the stored argmax mask (from the config or
        # from the masks stack); without a mask, fall back to copy-upsampling.
        if cfg.argmax is not None or masks is not None:
            mask = cfg.argmax if cfg.argmax is not None else masks.pop()
            net = nut.unpool(net, mask=mask, stride=cfg.kernel)
        else:
            net = nut.upsample(net, stride=cfg.kernel, mode='COPY')
    elif cfg.type == POOL:
        # Plain pooling is inverted by upsampling with the same stride.
        net = nut.upsample(net, cfg.kernel)
    elif cfg.type == DO:
        # Dropout has no decoder counterpart.
        pass
    elif cfg.type == LOSS:
        # Stash the current tensor so the loss layer can reference it.
        cfg.arg2 = net
    elif cfg.type == INPUT:
        # INPUT is handled by the early return above.
        assert False

    if not reuse:
        # On the first pass, record the op's top-level scope for later reuse
        # and log the decoder layer.
        cfg.dec_op_name = net.name.split('/')[0]
        ut.print_info('\rdecoder_%d \t%s' % (i, str(net)), color=CONFIG_COLOR)
    cfg.dout = net
    # Continue with the next (shallower) layer.
    return build_decoder(net, layer_config, i - 1, reuse=reuse, masks=masks)
Source: model_interpreter.py
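The POOL_ARG branch above calls the repository's nut.unpool helper to undo tf.nn.max_pool_with_argmax using the stored argmax mask. That helper lives in the project's network utilities and may differ from what follows; this is only a minimal sketch of argmax unpooling, assuming masks produced with include_batch_in_index=True (TensorFlow 1.14+). The names unpool_with_argmax, x, pooled, mask, and restored are ours for illustration, not the repository's.

import tensorflow as tf

def unpool_with_argmax(pool, argmax, stride=2):
    # Scatter each pooled value back to the flat position recorded in `argmax`
    # (indices created with include_batch_in_index=True); every other position
    # of the upsampled map stays zero.
    with tf.name_scope('unpool'):
        in_shape = tf.shape(pool, out_type=tf.int64)
        out_shape = [in_shape[0], in_shape[1] * stride,
                     in_shape[2] * stride, in_shape[3]]
        flat_size = out_shape[0] * out_shape[1] * out_shape[2] * out_shape[3]
        flat = tf.scatter_nd(tf.reshape(argmax, [-1, 1]),
                             tf.reshape(pool, [-1]),
                             [flat_size])
        return tf.reshape(flat, out_shape)

# Round trip: argmax-pool in the "encoder", unpool with the mask in the "decoder",
# analogous to net = nut.unpool(net, mask=mask, stride=cfg.kernel) above.
x = tf.random.uniform([4, 8, 8, 3])
pooled, mask = tf.nn.max_pool_with_argmax(
    x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME',
    include_batch_in_index=True)
restored = unpool_with_argmax(pooled, mask, stride=2)  # back to shape [4, 8, 8, 3]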