resnet_test.py source file

python

Project: tfplus  Author: renmengye
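import tensorflow as tf
import tfplus


def load_wrapper_model(sess, restore_path, nlayers, device='/cpu:0'):
    """Build a ResNet ImageNet wrapper model and restore its weights.

    Returns the underlying ResNet, the input variable, and the 'y_out' output.
    """
    # Imported locally, as in the original snippet.
    from resnet_imagenet_model_wrapper import ResNetImageNetModelWrapper
    with tf.device(device):
        logger = tfplus.utils.logger.get()
        with logger.verbose_level(2):
            # get_layers is not shown in this snippet; it is expected to be
            # defined elsewhere in resnet_test.py.
            resnet = ResNetImageNetModelWrapper().set_all_options({
                'inp_depth': 3,
                'layers': get_layers(nlayers),
                'strides': [1, 2, 2, 2],
                'channels': [64, 256, 512, 1024, 2048],
                'bottleneck': True,
                'shortcut': 'projection',
                'compatible': True,
                'wd': 1e-4,
                'subtract_mean': True,
                'trainable': False  # inference only; weights are not trained
            })
            inp_var = resnet.build_input()
            out_var = resnet.build(inp_var)
    # Restore only the variables the ResNet exposes for saving.
    saver = tf.train.Saver(resnet.res_net.get_save_var_dict())
    saver.restore(sess, restore_path)
    return resnet.res_net, inp_var, out_var['y_out']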
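
The snippet calls `get_layers(nlayers)` without showing its definition. As a rough illustration only, the sketch below assumes it maps a network depth to the number of bottleneck blocks per stage, using the standard ResNet-50/101/152 configurations. The name `get_layers_sketch` and the error handling are hypothetical and are not taken from tfplus.

```python
# Hypothetical stand-in for the get_layers helper that the snippet calls.
# Assumption: depth -> blocks per stage for standard bottleneck ResNets.
_BOTTLENECK_LAYERS = {
    50: [3, 4, 6, 3],
    101: [3, 4, 23, 3],
    152: [3, 8, 36, 3],
}


def get_layers_sketch(nlayers):
    """Return the per-stage block counts for a bottleneck ResNet of given depth."""
    if nlayers not in _BOTTLENECK_LAYERS:
        raise ValueError('Unsupported ResNet depth: {}'.format(nlayers))
    layers = _BOTTLENECK_LAYERS[nlayers]
    # Sanity check: each bottleneck block has 3 conv layers, plus the first
    # conv layer and the final fully connected layer.
    assert 3 * sum(layers) + 2 == nlayers
    return list(layers)
```

With a helper of this shape, `get_layers_sketch(50)` would provide the `'layers'` option for a 50-layer model, which matches the four entries in `'strides'` above.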