resnet_test.py source code

python

Project: tfplus  Author: renmengye
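import tensorflow as tf
import tfplus

# get_layers(nlayers) is not part of this snippet; it is assumed to be defined
# elsewhere in resnet_test.py and to return the number of blocks per stage.


def load_new_model(sess, restore_path, nlayers, device='/cpu:0'):
    """Build a frozen ResNet for ImageNet-style input and restore its weights."""
    from resnet_imagenet_model import ResNetImageNetModel
    with tf.device(device):
        logger = tfplus.utils.logger.get()
        with logger.verbose_level(2):
            resnet = ResNetImageNetModel().set_all_options({
                'inp_depth': 3,                    # RGB input
                'layers': get_layers(nlayers),     # blocks per stage
                'strides': [1, 2, 2, 2],           # one stride per stage
                'channels': [64, 256, 512, 1024, 2048],
                'bottleneck': True,
                'shortcut': 'projection',
                'compatible': True,
                'weight_decay': 1e-4,
                'subtract_mean': True,
                'trainable': False                 # weights are frozen
            })
            inp_var = resnet.build_input()
            out_var = resnet.build(inp_var)
            # Second build over the same input; presumably used to check
            # that repeated builds reuse the same weights.
            out_var2 = resnet.build(inp_var)
    # Restore only the variables the model exposes for saving.
    saver = tf.train.Saver(resnet.get_save_var_dict())
    saver.restore(sess, restore_path)
    return resnet, inp_var, out_var, out_var2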
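The helper get_layers is not shown in this snippet. Below is a minimal, hypothetical sketch of what it might return, assuming it maps the network depth to the number of bottleneck blocks in each of the four stages, using the standard configurations from the ResNet paper (He et al., 2015, Table 1). The four entries line up with the four values in 'strides'; 'channels' has one extra leading value (64), presumably the width of the stem convolution. The names _BOTTLENECK_BLOCKS and depth_from_blocks are illustrative only and do not come from tfplus.

```python
from typing import List

# Hypothetical stand-in for get_layers(); the real helper in resnet_test.py is not shown.
# Bottleneck blocks per stage for ResNet-50/101/152 (He et al., 2015, Table 1).
_BOTTLENECK_BLOCKS = {
    50: [3, 4, 6, 3],
    101: [3, 4, 23, 3],
    152: [3, 8, 36, 3],
}


def get_layers(nlayers: int) -> List[int]:
    """Return the number of bottleneck blocks in each of the four stages."""
    if nlayers not in _BOTTLENECK_BLOCKS:
        raise ValueError(f"unsupported bottleneck ResNet depth: {nlayers}")
    return list(_BOTTLENECK_BLOCKS[nlayers])


def depth_from_blocks(blocks: List[int]) -> int:
    """Count weighted layers: 3 convs per bottleneck block, plus the stem conv and the final FC layer."""
    return 3 * sum(blocks) + 2


if __name__ == "__main__":
    for n in sorted(_BOTTLENECK_BLOCKS):
        blocks = get_layers(n)
        print(n, blocks, depth_from_blocks(blocks))
```

Running it prints each depth next to its block counts, and the recomputed depth matches the requested one (for example, 3 × (3 + 4 + 6 + 3) + 2 = 50), which is a quick sanity check that the per-stage counts are consistent with the nlayers argument.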