def load_new_model(sess, restore_path, nlayers, device='/cpu:0'):
    """Build a ResNet ImageNet model and restore its variables from disk.

    The model is configured with a fixed set of options (bottleneck blocks,
    projection shortcuts, non-trainable weights), built twice on the same
    input, and then populated from the checkpoint at ``restore_path``.

    Args:
        sess: Session handed to ``tf.train.Saver.restore``.
        restore_path: Checkpoint path handed to ``tf.train.Saver.restore``.
        nlayers: Value forwarded to ``get_layers`` to obtain the ``'layers'``
            option of the model.
        device: Device string passed to ``tf.device``. Defaults to
            ``'/cpu:0'``.

    Returns:
        A 4-tuple ``(resnet, inp_var, out_var, out_var2)``: the model object,
        the result of ``build_input()``, and the results of the two
        ``build(inp_var)`` calls.
    """
    # Imported lazily, as in the original, so the dependency is only needed
    # when this function is actually called.
    from resnet_imagenet_model import ResNetImageNetModel

    model_options = {
        'inp_depth': 3,
        'layers': get_layers(nlayers),
        'strides': [1, 2, 2, 2],
        'channels': [64, 256, 512, 1024, 2048],
        'bottleneck': True,
        'shortcut': 'projection',
        'compatible': True,
        'weight_decay': 1e-4,
        'subtract_mean': True,
        'trainable': False
    }

    # NOTE(review): the source this was recovered from had lost all
    # indentation; the extents of the three scopes below are a best-effort
    # reconstruction -- confirm against the upstream version.
    with tf.device(device):
        logger = tfplus.utils.logger.get()
        with logger.verbose_level(2):
            resnet = ResNetImageNetModel().set_all_options(model_options)
            inp_var = resnet.build_input()
            out_var = resnet.build(inp_var)
            # NOTE(review): second build on the same input is kept from the
            # original; presumably it exercises variable reuse -- confirm.
            out_var2 = resnet.build(inp_var)
        saver = tf.train.Saver(resnet.get_save_var_dict())
        saver.restore(sess, restore_path)
    return resnet, inp_var, out_var, out_var2
评论列表
文章目录