def save_model_for_prediction(self, save_ckpt_fn, vars_to_save=None):
    """Save only the model data needed for prediction.

    Variables that live under any scope listed in
    ``self.dm_model.restore_scope_exclude`` are removed from the set of
    variables to save before the checkpoint is written.

    Args:
        save_ckpt_fn: checkpoint file to save.
        vars_to_save: optional list of variables to save. Defaults to
            ``slim.get_model_variables()``.
    """
    if vars_to_save is None:
        vars_to_save = slim.get_model_variables()
    # Track excluded variables by object identity. This makes each membership
    # test an O(1) set lookup instead of a linear scan of a list, and it does
    # not depend on tf.Variable's __eq__/__hash__ behaviour.
    excluded_ids = {
        id(var)
        for scope in self.dm_model.restore_scope_exclude
        for var in slim.get_variables(scope)
    }
    # Remove variables that are not restored.
    vars_to_save = [var for var in vars_to_save if id(var) not in excluded_ids]
    base_model.save_model(save_ckpt_fn, self.sess, vars_to_save)
# Python example source code for get_variables()
def load_pretrained_model(self):
    """Copy matching weights from the pretrained checkpoint into the graph.

    Every variable in the TRAINABLE_VARIABLES collection is considered. Its
    value is overwritten in ``self.sess`` with the value stored in
    ``self.pre_trained_model_cpkt``, unless either of these holds:

    - its name (without the ``:N`` output suffix) is in ``self.skip_layer``;
    - its name does not exist in the checkpoint.

    Skipped variables keep their current values.

    :return: None
    """
    print('Load the pretrained weights into the non-trainable layer...')
    trainable_variables = slim.get_variables(
        None, None, tf.GraphKeys.TRAINABLE_VARIABLES)
    reader = pywrap_tensorflow.NewCheckpointReader(self.pre_trained_model_cpkt)
    pretrained_model_variables = reader.get_variable_to_shape_map()
    for variable in trainable_variables:
        variable_name = variable.name.split(':')[0]
        if variable_name in self.skip_layer:
            continue
        if variable_name not in pretrained_model_variables:
            continue
        print('load ' + variable_name)
        data = reader.get_tensor(variable_name)
        # Use the variable object we already hold rather than re-fetching it
        # through tf.get_variable(reuse=True). That lookup fails for variables
        # not created by get_variable, and its trainable flag was ignored.
        # Variable.load feeds the value through the existing initializer, so
        # unlike sess.run(var.assign(ndarray)) it adds no new op to the graph
        # and does not embed the weights as a constant node.
        variable.load(data, self.sess)
def load_model_from_checkpoint_fn(self, model_fn):
    """Load weights from a checkpoint file into ``self.sess``.

    If ``self.vars_to_restore`` is None, it is first set to every variable
    returned by ``slim.get_variables()``, and that list is cached on the
    instance for later calls.

    Args:
        model_fn: saved model (checkpoint) file to restore from.
    """
    # print(...) with a single string argument produces the same output on
    # Python 2 and Python 3, matching the other methods in this file.
    print("start loading from checkpoint file...")
    if self.vars_to_restore is None:
        self.vars_to_restore = slim.get_variables()
    restore_fn = slim.assign_from_checkpoint_fn(model_fn, self.vars_to_restore)
    print("restoring model from {}".format(model_fn))
    restore_fn(self.sess)
    print("model restored.")
def load_ckpt(self, sess, ckpt='ckpts/vgg_16.ckpt', scope='vgg_16'):
    """Assign the variables under ``scope`` from a checkpoint.

    Args:
        sess: session in which the assignment op is run.
        ckpt: checkpoint path to read values from.
        scope: variable scope whose variables are loaded. Defaults to
            'vgg_16', the value previously hard-coded here.
    """
    variables = slim.get_variables(scope=scope)
    # assign_from_checkpoint returns an assign op together with the feed dict
    # that op must be run with.
    init_assign_op, init_feed_dict = slim.assign_from_checkpoint(ckpt, variables)
    sess.run(init_assign_op, init_feed_dict)
def build_network(self):
    """Build the network graph and record per-layer weight-norm summaries.

    The input is a float32 placeholder of shape [None, 84, 84, 4]. It passes
    through these layers, each created under ``self.name + '/<layer>'``:

    - two ReLU conv layers (cnn_1, cnn_2);
    - a flatten;
    - a 256-unit ReLU fully connected layer (fcc_1).

    Two heads read from fcc_1:

    - adv_probas: ``self.nb_actions`` units with softmax activation;
    - value_state: a single unit with no activation.

    A scalar summary of the global norm of each layer's variables is added.
    The placeholder and both heads are stored on the instance.
    """
    prefix = self.name + '/'

    state = tf.placeholder(tf.float32, [None, 84, 84, 4])
    hidden = slim.conv2d(state, 16, [8, 8], stride=4,
                         scope=prefix + 'cnn_1', activation_fn=nn.relu)
    hidden = slim.conv2d(hidden, 32, [4, 4], stride=2,
                         scope=prefix + 'cnn_2', activation_fn=nn.relu)
    hidden = slim.flatten(hidden)
    hidden = slim.fully_connected(hidden, 256,
                                  scope=prefix + 'fcc_1',
                                  activation_fn=nn.relu)
    adv_probas = slim.fully_connected(hidden, self.nb_actions,
                                      scope=prefix + 'adv_probas',
                                      activation_fn=nn.softmax)
    value_state = slim.fully_connected(hidden, 1,
                                       scope=prefix + 'value_state',
                                       activation_fn=None)

    # Summary tags do not always match the scope names (e.g. "cnn1" vs
    # "cnn_1"), so each tag is paired explicitly with its layer scope.
    norm_summaries = (
        ("model/cnn1_global_norm", 'cnn_1'),
        ("model/cnn2_global_norm", 'cnn_2'),
        ("model/fcc1_global_norm", 'fcc_1'),
        ("model/adv_probas_global_norm", 'adv_probas'),
        ("model/value_state_global_norm", 'value_state'),
    )
    for tag, layer in norm_summaries:
        layer_vars = slim.get_variables(scope=prefix + layer)
        tf.summary.scalar(tag, tf.global_norm(layer_vars))

    # Network input.
    self._tf_state = state
    # Network outputs.
    self._tf_adv_probas = adv_probas
    self._tf_value_state = value_state