def load_model(name):
    '''Create and return an instance of the model whose class name is `name`.

    The created model is fed through a single placeholder node, passed to the
    model constructor under the 'data' key. The placeholder is float32 with
    shape (None, crop_size, crop_size, channels), where crop_size and channels
    come from the model's data spec and the leading batch dimension is left
    unspecified.

    Args:
        name: The `__name__` of one of the classes returned by
            `models.get_models()`.

    Returns:
        The constructed model instance, or None if `name` does not match any
        known model class. In that case the valid names are printed.
    '''
    # Materialize once: the sequence is iterated twice below (to build the
    # lookup table and, on failure, to list the options). This guards against
    # get_models() returning a one-shot iterator.
    all_models = list(models.get_models())
    # Map each model class's name to the class itself.
    lut = {model.__name__: model for model in all_models}
    NetClass = lut.get(name)
    if NetClass is None:
        # The lookup key is a class name, not an index; report the rejected
        # value and list every accepted name so the caller can correct it.
        print('Invalid model name {!r}. Options are:'.format(name))
        for model in all_models:
            print('\t* {}'.format(model.__name__))
        return None
    # Build the input placeholder from the model's data spec.
    spec = models.get_data_spec(model_class=NetClass)
    # NOTE(review): tf.placeholder is the TF1 graph-mode API; under TF2 this
    # would need tf.compat.v1.placeholder with eager execution disabled.
    data_node = tf.placeholder(
        tf.float32,
        shape=(None, spec.crop_size, spec.crop_size, spec.channels))
    # Construct and return the model, wired to the placeholder.
    return NetClass({'data': data_node})
validate.py 文件源码
python
阅读 36
收藏 0
点赞 0
评论 0
评论列表
文章目录