def load_model(name):
    '''Create and return an instance of the model class named `name`.

    Model classes are obtained from ``models.get_models()`` and matched by
    their ``__name__``. The created model has a single placeholder node for
    feeding images, passed to the model constructor under the key ``'data'``.
    The placeholder is float32 with shape
    ``(None, spec.crop_size, spec.crop_size, spec.channels)``, where ``spec``
    is the data spec returned by ``models.get_data_spec`` for the chosen class.

    Args:
        name: Class name of the model to build.

    Returns:
        An instance of the matching model class, or ``None`` when no model
        class has that name (the valid names are printed in that case).
    '''
    # Materialise once: the collection is iterated twice below (to build the
    # lookup table and, on failure, to list the valid names). A one-shot
    # iterator would otherwise print an empty list of options.
    all_models = list(models.get_models())
    lut = {model.__name__: model for model in all_models}
    if name not in lut:
        # Lookup is by class name, not by position.
        print('Invalid model name. Options are:')
        # Display a list of valid model names
        for model in all_models:
            print('\t* {}'.format(model.__name__))
        return None
    NetClass = lut[name]
    # Size the input placeholder from the model's data spec; the leading None
    # leaves the first (presumably batch) dimension unspecified.
    spec = models.get_data_spec(model_class=NetClass)
    data_node = tf.placeholder(tf.float32,
                               shape=(None, spec.crop_size, spec.crop_size, spec.channels))
    # Construct and return the model
    return NetClass({'data': data_node})
python类get_models()的实例源码
def main():
    '''Entry point: load a config file, then featurize and evaluate each dataset.

    The config file path is the first command-line argument; the process
    environment is passed to ConfigParser as its defaults. Datasets come from
    ``get_data(conf)`` and models from ``get_models(conf)``. For every dataset
    a fresh ``Regression(conf)`` is built, fed the data and models via
    ``featurize_data`` and then ``evaluate``d.

    Exits with a message (SystemExit) when no config path is given or the
    config file cannot be read.
    '''
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s : " +
        "%(module)s (%(lineno)s) - %(levelname)s - %(message)s")
    if len(sys.argv) < 2:
        sys.exit('usage: {0} <config_file>'.format(sys.argv[0]))
    conf_path = sys.argv[1]
    conf = ConfigParser(os.environ)
    # ConfigParser.read() silently skips files it cannot open and returns the
    # list of files it actually parsed; fail fast instead of continuing with
    # an empty configuration.
    if not conf.read(conf_path):
        sys.exit('could not read config file: {0}'.format(conf_path))
    logging.warning('loading datasets...')
    datasets = get_data(conf)
    logging.warning('loaded these: {0}'.format(datasets.keys()))
    logging.warning('loading models...')
    models = get_models(conf)
    logging.warning('evaluating...')
    # .items() works on both Python 2 and 3 (iteritems() is Python-2 only).
    for data_type, data in datasets.items():
        logging.warning('data: {0}'.format(data_type))
        r = Regression(conf)
        r.featurize_data(data, models)
        r.evaluate()
def load_model(name):
    '''Create and return an instance of the model class named `name`.

    Model classes are obtained from ``models.get_models()`` and matched by
    their ``__name__``. The created model has a single placeholder node for
    feeding images, passed to the model constructor under the key ``'data'``.
    The placeholder is float32 with shape
    ``(None, spec.crop_size, spec.crop_size, spec.channels)``, where ``spec``
    is the data spec returned by ``models.get_data_spec`` for the chosen class.

    Args:
        name: Class name of the model to build.

    Returns:
        An instance of the matching model class, or ``None`` when no model
        class has that name (the valid names are printed in that case).
    '''
    # Materialise once: the collection is iterated twice below (to build the
    # lookup table and, on failure, to list the valid names). A one-shot
    # iterator would otherwise print an empty list of options.
    all_models = list(models.get_models())
    lut = {model.__name__: model for model in all_models}
    if name not in lut:
        # Lookup is by class name, not by position.
        print('Invalid model name. Options are:')
        # Display a list of valid model names
        for model in all_models:
            print('\t* {}'.format(model.__name__))
        return None
    NetClass = lut[name]
    # Size the input placeholder from the model's data spec; the leading None
    # leaves the first (presumably batch) dimension unspecified.
    spec = models.get_data_spec(model_class=NetClass)
    data_node = tf.placeholder(tf.float32,
                               shape=(None, spec.crop_size, spec.crop_size, spec.channels))
    # Construct and return the model
    return NetClass({'data': data_node})