model_interpreter.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:TensorFlow_DCIGN 作者: yselivonchyk 项目源码 文件源码
def build_decoder(net, layer_config, i=None, reuse=False, masks=None):
  """Build the decoder by walking ``layer_config`` backwards from index ``i``.

  Each call decodes layer ``i`` and then recurses on layer ``i - 1``. It stops
  (returning ``net``) once the index drops below zero or an ``INPUT`` layer is
  reached.

  Args:
    net: tensor to decode, coming from the layer above (index ``i + 1``).
      Before decoding, it is reshaped to ``layer_config[i + 1].shape`` when
      its rank differs from that shape's length.
    layer_config: sequence of layer configs. Each entry exposes ``type`` and
      ``shape`` and, depending on the type, ``kernel``, ``stride``,
      ``activation`` and ``argmax``.
    i: index of the layer to decode; defaults to the last layer.
    reuse: when True, passes the scope name recorded in ``cfg.dec_op_name`` by
      a previous non-reuse build so that variables are shared.
    masks: optional list of unpooling masks. ``POOL_ARG`` layers without their
      own ``argmax`` consume it with ``masks.pop()`` (last element first).
      NOTE: the caller's list is mutated.

  Returns:
    The decoded tensor. Each decoded layer's output is also stored in
    ``cfg.dout``.
  """
  i = len(layer_config) - 1 if i is None else i

  # Match the rank of the layer above before decoding this one.
  if i + 1 < len(layer_config):
    target_shape = layer_config[i + 1].shape
    if len(target_shape) != len(net.get_shape().as_list()):
      net = tf.reshape(net, target_shape)

  # Check bounds BEFORE indexing: with i < 0, layer_config[i] would silently
  # wrap around to the end of the list (or raise IndexError if it is empty).
  if i < 0:
    return net
  cfg = layer_config[i]
  if cfg.type == INPUT:
    return net

  # Only read dec_op_name for layers that are actually decoded; the name is
  # assigned below on the first (non-reuse) build.
  name = cfg.dec_op_name if reuse else None

  if cfg.type == FC:
    net = slim.fully_connected(net, int(np.prod(cfg.shape[1:])), scope=name,
                               activation_fn=cfg.activation, reuse=reuse)
  elif cfg.type == CONV:
    net = slim.conv2d_transpose(net, cfg.shape[-1], [cfg.kernel, cfg.kernel], stride=cfg.stride,
                                activation_fn=cfg.activation, padding=PADDING,
                                scope=name, reuse=reuse)
  elif cfg.type == POOL_ARG:
    # Prefer the layer's own argmax; otherwise take the last mask in `masks`.
    if cfg.argmax is not None or masks is not None:
      mask = cfg.argmax if cfg.argmax is not None else masks.pop()
      net = nut.unpool(net, mask=mask, stride=cfg.kernel)
    else:
      net = nut.upsample(net, stride=cfg.kernel, mode='COPY')
  elif cfg.type == POOL:
    net = nut.upsample(net, cfg.kernel)
  elif cfg.type == DO:
    pass  # no decoder op for this layer type (presumably dropout)
  elif cfg.type == LOSS:
    cfg.arg2 = net  # store this layer's decoder tensor on the config

  if not reuse:
    # Record the top-level scope so a later reuse=True build can share it.
    cfg.dec_op_name = net.name.split('/')[0]
    ut.print_info('\rdecoder_%d \t%s' % (i, str(net)), color=CONFIG_COLOR)
  cfg.dout = net
  return build_decoder(net, layer_config, i - 1, reuse=reuse, masks=masks)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号