def build_encoder(net, layer_config, i=1, reuse=False):
    """Build the encoder half of the network described by ``layer_config``.

    Applies each layer in ``layer_config[i:]`` to ``net`` in order,
    recording per-layer bookkeeping on each cfg entry as it goes
    (input tensor, input shape, created op name, pooling argmax).

    Args:
      net: input tensor fed into the first constructed layer.
      layer_config: sequence of layer-description objects. Each entry
        exposes ``.type`` plus type-specific fields (``size``, ``kernel``,
        ``stride``, ``activation``, ``keep_prob``) and is mutated in place
        (``.shape``, ``.ein``, ``.enc_op_name``, ``.argmax``, ``.arg1``).
      i: index of the first layer to build. Defaults to 1 — presumably
        ``layer_config[0]`` is the INPUT entry, which is never built
        (the INPUT branch below asserts). TODO confirm against callers.
      reuse: when True, reuse previously created variables; scopes are
        looked up via the ``cfg.enc_op_name`` recorded on the first pass.

    Returns:
      The output tensor of the last constructed layer.
    """
    # Iterative rewrite of the original per-layer tail recursion: same
    # traversal order, but immune to Python's recursion limit for deep nets.
    for i in range(i, len(layer_config)):
        cfg = layer_config[i]
        cfg.shape = net.get_shape().as_list()
        cfg.ein = net
        # On the reuse pass, address the variable scope recorded earlier.
        name = cfg.enc_op_name if reuse else None

        if cfg.type == FC:
            if len(cfg.shape) > 2:
                # Collapse spatial dimensions before a dense layer.
                net = slim.flatten(net)
            net = slim.fully_connected(net, cfg.size, activation_fn=cfg.activation,
                                       scope=name, reuse=reuse)
        elif cfg.type == CONV:
            net = slim.conv2d(net, cfg.size, [cfg.kernel, cfg.kernel], stride=cfg.stride,
                              activation_fn=cfg.activation, padding=PADDING,
                              scope=name, reuse=reuse)
        elif cfg.type == POOL_ARG:
            # Keep the argmax indices so a decoder can unpool exactly.
            net, cfg.argmax = nut.max_pool_with_argmax(net, cfg.kernel)
        elif cfg.type == POOL:
            net = slim.max_pool2d(net, kernel_size=[cfg.kernel, cfg.kernel], stride=cfg.kernel)
        elif cfg.type == DO:
            net = tf.nn.dropout(net, keep_prob=cfg.keep_prob)
        elif cfg.type == LOSS:
            # Loss layers do not transform the tensor; just remember it.
            cfg.arg1 = net
        elif cfg.type == INPUT:
            # INPUT must only appear at index 0, which this walk skips.
            assert False

        if not reuse:
            # Record the scope name once so the reuse pass can find it,
            # then log the layer that was just constructed.
            cfg.enc_op_name = net.name.split('/')[0]
            ut.print_info('\rencoder_%d\t%s\t%s' % (i, str(net), cfg.enc_op_name),
                          color=CONFIG_COLOR)
    return net
# Source: model_interpreter.py (non-code page-footer text removed from scraped copy)