def load(cls, dirname, session, training=False):
    """
    Restore a model previously saved to `dirname`.

    The model is rebuilt from the saved parameters via
    `cls._init_from_load`, after which the trainable variables are restored
    from the TensorFlow checkpoint named 'model' inside `dirname`.

    :param dirname: directory containing the saved parameters and checkpoint
    :param session: tensorflow session in which the variables are restored
    :param training: whether to also create training tensors; when True,
        every global variable whose name starts with 'training' is
        (re)initialized after the restore
    :return: the reconstructed model instance
    :rtype: MultiFeedForwardClassifier
    """
    saved_params = utils.load_parameters(dirname)
    # The graph must exist before the Saver is built, so that
    # tf.trainable_variables() includes this model's variables.
    restored = cls._init_from_load(saved_params, training)

    checkpoint_path = os.path.join(dirname, 'model')
    # Only trainable variables are restored from the checkpoint.
    # NOTE(review): presumably this mirrors how the model was saved — confirm.
    checkpoint_saver = tf.train.Saver(tf.trainable_variables())
    checkpoint_saver.restore(session, checkpoint_path)

    if not training:
        return restored

    # Training-only state (presumably optimizer values) still needs
    # initialization before training can resume.
    training_state = [var for var in tf.global_variables()
                      if var.name.startswith('training')]
    session.run(tf.variables_initializer(training_state))
    return restored
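# [Web-page residue, translated from Chinese] Comment list
# [Web-page residue, translated from Chinese] Table of contents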