Example source code for Python get_variables()
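
The snippets below are TensorFlow 1.x code built on tf.contrib.slim. For orientation: slim.get_variables(scope=None, suffix=None, collection=...) returns the variables registered in a collection (GLOBAL_VARIABLES by default), optionally filtered by variable scope and by a name suffix. A minimal sketch, assuming TF 1.x with tf.contrib.slim available; the 'demo' scope and layer sizes are illustrative:

import tensorflow as tf
slim = tf.contrib.slim

# Build a tiny graph so there are variables to query (names are illustrative).
inputs = tf.placeholder(tf.float32, [None, 224, 224, 3])
net = slim.conv2d(inputs, 32, [3, 3], scope='demo/conv1')
net = slim.fully_connected(slim.flatten(net), 10, scope='demo/fc1')

# All global variables under the 'demo' scope: conv1/weights, conv1/biases, ...
print([v.op.name for v in slim.get_variables(scope='demo')])

# Only the bias variables, filtered by name suffix.
print([v.op.name for v in slim.get_variables(scope='demo', suffix='biases')])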

dm_learner.py — project: deepmodels, author: learningsociety
def save_model_for_prediction(self, save_ckpt_fn, vars_to_save=None):
    """Save model data only needed for prediction.

    Args:
      save_ckpt_fn: checkpoint file to save.
      vars_to_save: a list of variables to save.
    """
    if vars_to_save is None:
      vars_to_save = slim.get_model_variables()
      vars_restore_to_exclude = []
      for scope in self.dm_model.restore_scope_exclude:
        vars_restore_to_exclude.extend(slim.get_variables(scope))
      # remove not restored variables.
      vars_to_save = [
          v for v in vars_to_save if v not in vars_restore_to_exclude
      ]
    base_model.save_model(save_ckpt_fn, self.sess, vars_to_save)
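
A side note on the snippet above: slim.get_model_variables() returns only the variables that slim layers register in the MODEL_VARIABLES collection (layer weights, biases, batch-norm moving statistics), while plain slim.get_variables() defaults to GLOBAL_VARIABLES and therefore also picks up bookkeeping variables such as the global step. A minimal sketch, assuming TF 1.x; the 'toy' scope is illustrative:

import tensorflow as tf
slim = tf.contrib.slim

x = tf.placeholder(tf.float32, [None, 8])
net = slim.fully_connected(x, 4, scope='toy/fc')
step = tf.train.get_or_create_global_step()

# Model variables only: toy/fc/weights and toy/fc/biases.
print([v.op.name for v in slim.get_model_variables()])

# Global variables: the layer parameters plus global_step.
print([v.op.name for v in slim.get_variables()])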
inception_v1.py — project: Deep_Learning_In_Action, author: SunnyMarkLiu
def load_pretrained_model(self):
        """
        Load the pretrained weights into the non-trainable layer
        :return:
        """
        print('Load the pretrained weights into the non-trainable layer...')
        from tensorflow.python.framework import ops
        trainable_variables = slim.get_variables(None, None, ops.GraphKeys.TRAINABLE_VARIABLES)

        reader = pywrap_tensorflow.NewCheckpointReader(self.pre_trained_model_cpkt)
        pretrained_model_variables = reader.get_variable_to_shape_map()
        for variable in trainable_variables:
            variable_name = variable.name.split(':')[0]
            if variable_name in self.skip_layer:
                continue
            if variable_name not in pretrained_model_variables:
                continue
            print('load ' + variable_name)
            with tf.variable_scope('', reuse=True):
                var = tf.get_variable(variable_name, trainable=False)
                data = reader.get_tensor(variable_name)
                self.sess.run(var.assign(data))
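
In the snippet above the collection is passed explicitly: slim.get_variables(None, None, ops.GraphKeys.TRAINABLE_VARIABLES) reads the TRAINABLE_VARIABLES collection instead of the default GLOBAL_VARIABLES. slim also ships a shorthand, get_trainable_variables(). A minimal sketch under TF 1.x, with illustrative scope names and a deliberately frozen layer:

import tensorflow as tf
from tensorflow.python.framework import ops
slim = tf.contrib.slim

x = tf.placeholder(tf.float32, [None, 8])
frozen = slim.fully_connected(x, 4, scope='net/frozen', trainable=False)
head = slim.fully_connected(frozen, 2, scope='net/head')

# Explicit collection, as in the snippet above: only net/head shows up.
print([v.op.name
       for v in slim.get_variables(None, None, ops.GraphKeys.TRAINABLE_VARIABLES)])

# Shorthand for the same collection.
print([v.op.name for v in slim.get_trainable_variables()])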
dm_learner.py — project: deepmodels, author: learningsociety
def load_model_from_checkpoint_fn(self, model_fn):
    """Load weights from file and keep in memory.

    Args:
      model_fn: saved model file.
    """
    # self.dm_model.use_graph()
    print "start loading from checkpoint file..."
    if self.vars_to_restore is None:
      self.vars_to_restore = slim.get_variables()
    restore_fn = slim.assign_from_checkpoint_fn(model_fn, self.vars_to_restore)
    print "restoring model from {}".format(model_fn)
    restore_fn(self.sess)
    print "model restored."
vgg.py — project: cnn-visualizer, author: penny4860
def load_ckpt(self, sess, ckpt='ckpts/vgg_16.ckpt'):
        variables = slim.get_variables(scope='vgg_16')
        init_assign_op, init_feed_dict = slim.assign_from_checkpoint(ckpt, variables)
        sess.run(init_assign_op, init_feed_dict)
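
The two restore helpers seen in this section do the same job in different styles: slim.assign_from_checkpoint(ckpt_path, var_list) returns an assign op plus a feed dict that you run yourself (as in vgg.py above), while slim.assign_from_checkpoint_fn(ckpt_path, var_list) wraps that into a callable that takes a session (as in dm_learner.py). A minimal sketch, assuming a vgg_16 graph has already been built in the default graph and that ckpts/vgg_16.ckpt exists:

import tensorflow as tf
slim = tf.contrib.slim

variables = slim.get_variables(scope='vgg_16')

# Variant 1: assign op + feed dict, run explicitly.
init_assign_op, init_feed_dict = slim.assign_from_checkpoint('ckpts/vgg_16.ckpt', variables)

# Variant 2: a callable that performs the restore when given a session.
restore_fn = slim.assign_from_checkpoint_fn('ckpts/vgg_16.ckpt', variables)

with tf.Session() as sess:
    sess.run(init_assign_op, init_feed_dict)  # variant 1
    restore_fn(sess)                          # variant 2 (either one alone is enough)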
q_network.py — project: RL-Universe, author: Bifrost-Research
def build_network(self):

        state = tf.placeholder(tf.float32, [None, 84, 84, 4])

        cnn_1 = slim.conv2d(state, 16, [8,8], stride=4, scope=self.name + '/cnn_1', activation_fn=nn.relu)

        cnn_2 = slim.conv2d(cnn_1, 32, [4,4], stride=2, scope=self.name + '/cnn_2', activation_fn=nn.relu)

        flatten = slim.flatten(cnn_2)

        fcc_1 = slim.fully_connected(flatten, 256, scope=self.name + '/fcc_1', activation_fn=nn.relu)

        adv_probas = slim.fully_connected(fcc_1, self.nb_actions, scope=self.name + '/adv_probas', activation_fn=nn.softmax)

        value_state = slim.fully_connected(fcc_1, 1, scope=self.name + '/value_state', activation_fn=None)

        tf.summary.scalar("model/cnn1_global_norm", tf.global_norm(slim.get_variables(scope=self.name + '/cnn_1')))
        tf.summary.scalar("model/cnn2_global_norm", tf.global_norm(slim.get_variables(scope=self.name + '/cnn_2')))
        tf.summary.scalar("model/fcc1_global_norm", tf.global_norm(slim.get_variables(scope=self.name + '/fcc_1')))
        tf.summary.scalar("model/adv_probas_global_norm", tf.global_norm(slim.get_variables(scope=self.name + '/adv_probas')))
        tf.summary.scalar("model/value_state_global_norm", tf.global_norm(slim.get_variables(scope=self.name + '/value_state')))

        #Input
        self._tf_state = state

        #Output
        self._tf_adv_probas = adv_probas
        self._tf_value_state = value_state
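
The per-layer summaries above work because slim.get_variables(scope=...) returns every variable created under that scope, so tf.global_norm covers both the weights and the biases of the layer. A small sketch of that behaviour, rebuilding only the first conv layer with an illustrative 'q' name:

import tensorflow as tf
slim = tf.contrib.slim

state = tf.placeholder(tf.float32, [None, 84, 84, 4])
cnn_1 = slim.conv2d(state, 16, [8, 8], stride=4, scope='q/cnn_1')

layer_vars = slim.get_variables(scope='q/cnn_1')
print([v.op.name for v in layer_vars])   # ['q/cnn_1/weights', 'q/cnn_1/biases']
print(tf.global_norm(layer_vars))        # scalar Tensor; evaluate it in a Session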

