def load_wrapper_model(sess, restore_path, nlayers, device='/cpu:0'):
    """Build a ResNet ImageNet wrapper graph and restore its weights.

    The network is configured with a fixed option set (bottleneck blocks,
    projection shortcuts, mean subtraction, ``trainable`` set to False) and
    its graph is constructed under ``tf.device(device)``. The variables
    reported by ``res_net.get_save_var_dict()`` are then restored from
    ``restore_path`` into ``sess`` via ``tf.train.Saver``.

    Args:
        sess: TensorFlow session the checkpoint values are restored into.
        restore_path: Checkpoint path passed to ``Saver.restore``.
        nlayers: Network depth, converted to the ``'layers'`` option by
            ``get_layers`` (presumably per-stage block counts -- confirm
            against ``get_layers``).
        device: Device string used as the ``tf.device`` scope while the
            graph is built. Defaults to ``'/cpu:0'``.

    Returns:
        Tuple ``(res_net, inp_var, y_out)``: the wrapper's ``res_net``
        attribute, the input variable from ``build_input()``, and the
        ``'y_out'`` entry of the mapping returned by ``build()``.
    """
    # Function-scope import: the wrapper module is only loaded when this
    # loader actually runs.
    from resnet_imagenet_model_wrapper import ResNetImageNetModelWrapper

    with tf.device(device):
        log = tfplus.utils.logger.get()
        # Graph construction runs under logger verbosity level 2.
        with log.verbose_level(2):
            wrapper = ResNetImageNetModelWrapper()
            # Fixed architecture settings; the meaning of each key (e.g.
            # 'compatible') is defined by the wrapper class, not shown here.
            options = {
                'inp_depth': 3,
                'layers': get_layers(nlayers),
                'strides': [1, 2, 2, 2],
                'channels': [64, 256, 512, 1024, 2048],
                'bottleneck': True,
                'shortcut': 'projection',
                'compatible': True,
                'wd': 1e-4,
                'subtract_mean': True,
                'trainable': False,
            }
            resnet = wrapper.set_all_options(options)
            inp_var = resnet.build_input()
            outputs = resnet.build(inp_var)

    # NOTE(review): the upstream indentation was lost; the Saver is assumed
    # to be created outside the tf.device scope -- confirm against upstream.
    checkpoint_saver = tf.train.Saver(resnet.res_net.get_save_var_dict())
    checkpoint_saver.restore(sess, restore_path)
    return resnet.res_net, inp_var, outputs['y_out']
评论列表
文章目录