model_interpreter.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:TensorFlow_DCIGN 作者: yselivonchyk 项目源码 文件源码
def build_encoder(net, layer_config, i=1, reuse=False):
  """Recursively append the encoder layers described by ``layer_config[i:]`` to ``net``.

  Each entry of ``layer_config`` is a layer-config object whose ``type`` selects
  the op applied to the running tensor:

    * FC       -> ``slim.fully_connected`` (input is flattened first if rank > 2)
    * CONV     -> ``slim.conv2d`` with a square ``cfg.kernel`` and ``cfg.stride``
    * POOL_ARG -> ``nut.max_pool_with_argmax``; the argmax is stored in ``cfg.argmax``
    * POOL     -> ``slim.max_pool2d`` with kernel and stride both ``cfg.kernel``
    * DO       -> ``tf.nn.dropout`` with ``cfg.keep_prob``
    * LOSS     -> no op; the current tensor is stored in ``cfg.arg1``

  Config objects are mutated in place. ``cfg.shape`` and ``cfg.ein`` record the
  layer's input shape and input tensor. On a non-reuse pass, ``cfg.enc_op_name``
  records the top-level name scope of the layer's output tensor, so that a later
  ``reuse=True`` pass can rebuild the same layers inside those scopes.

  Args:
    net: tensor fed into layer ``i``.
    layer_config: indexable sequence of layer configs.
    i: index of the first layer to build. Defaults to 1; presumably entry 0 is
      the INPUT layer (TODO confirm against callers).
    reuse: if True, FC/CONV layers are built with ``scope=cfg.enc_op_name`` and
      ``reuse=True``; ``enc_op_name`` is not overwritten and nothing is printed.
      If False, ``scope=None`` is used and each layer is logged via
      ``ut.print_info``.

  Returns:
    The output tensor of the last layer, or ``net`` unchanged when
    ``i == len(layer_config)``.

  Raises:
    AssertionError: if a layer at index >= ``i`` has type INPUT.
  """
  if i == len(layer_config):
    return net

  cfg = layer_config[i]
  cfg.shape = net.get_shape().as_list()
  # On a reuse pass, target the scope recorded by the earlier building pass.
  name = cfg.enc_op_name if reuse else None
  cfg.ein = net
  if cfg.type == FC:
    if len(cfg.shape) > 2:
      net = slim.flatten(net)
    net = slim.fully_connected(net, cfg.size, activation_fn=cfg.activation,
                               scope=name, reuse=reuse)
  elif cfg.type == CONV:
    net = slim.conv2d(net, cfg.size, [cfg.kernel, cfg.kernel], stride=cfg.stride,
                      activation_fn=cfg.activation, padding=PADDING,
                      scope=name, reuse=reuse)
  elif cfg.type == POOL_ARG:
    net, cfg.argmax = nut.max_pool_with_argmax(net, cfg.kernel)
    # if not reuse:
    #   mask = nut.fake_arg_max_of_max_pool(cfg.shape, cfg.kernel)
    #   cfg.argmax_dummy = tf.constant(mask.flatten(), shape=mask.shape)
  elif cfg.type == POOL:
    net = slim.max_pool2d(net, kernel_size=[cfg.kernel, cfg.kernel], stride=cfg.kernel)
  elif cfg.type == DO:
    net = tf.nn.dropout(net, keep_prob=cfg.keep_prob)
  elif cfg.type == LOSS:
    cfg.arg1 = net
  elif cfg.type == INPUT:
    # Explicit raise instead of `assert False`: asserts are stripped under
    # `python -O`, which would silently skip the misplaced INPUT layer.
    raise AssertionError('INPUT layer is not allowed inside the encoder (index %d)' % i)

  if not reuse:
    # Remember the top-level scope of the op just created so a reuse pass can
    # rebuild this layer in the same scope; log only on the building pass.
    cfg.enc_op_name = net.name.split('/')[0]
    ut.print_info('\rencoder_%d\t%s\t%s' % (i, str(net), cfg.enc_op_name), color=CONFIG_COLOR)
  return build_encoder(net, layer_config, i + 1, reuse=reuse)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号