Python examples of tf.global_variables()
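The snippets below are collected from open-source projects and show how tf.global_variables() is typically combined with initializers and tf.train.Saver. As a quick orientation, here is a minimal sketch using the TensorFlow 1.x API; the variable name and checkpoint path are illustrative only and are not taken from any of the projects below.

import tensorflow as tf  # TensorFlow 1.x API

# Build a trivial graph with one variable in the GLOBAL_VARIABLES collection.
w = tf.Variable(tf.zeros([2, 2]), name='w')

with tf.Session() as sess:
    # Initialize every global variable, then hand the same list to a Saver
    # so the checkpoint covers exactly the variables that were initialized.
    sess.run(tf.variables_initializer(tf.global_variables()))
    saver = tf.train.Saver(tf.global_variables())
    save_path = saver.save(sess, './example_ckpt')  # illustrative path
    print('Checkpoint written to', save_path)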

sequenceNet.py (project: deep-summarization, author: harpribot)
def _start_session(self):
        """
        Starts the Tensorflow Session

        :return: None
        """
        self.sess.run(tf.global_variables_initializer())
        # initialize the saver node
        # print tf.GraphKeys.GLOBAL_VARIABLES
        self.saver = tf.train.Saver(tf.global_variables())
        # get the latest checkpoint
        last_checkpoint_path = self.checkpointer.get_last_checkpoint()
        if last_checkpoint_path is not None:
            print('Previous saved tensorflow objects found... Extracting...')
            # restore the tensorflow variables
            self.saver.restore(self.sess, last_checkpoint_path)
            print('Extraction Complete. Moving Forward....')
variable_mgr.py (project: benchmarks, author: tensorflow)
def savable_variables(self):
    """Returns a list/dict of savable variables to pass to tf.train.Saver."""
    params = {}
    for v in tf.global_variables():
      assert (v.name.startswith(variable_mgr_util.PS_SHADOW_VAR_PREFIX + '/v0/')
              or v.name in ('global_step:0', 'loss_scale:0',
                            'loss_scale_normal_steps:0')), (
                                'Invalid global variable: %s' % v)
      # We store variables in the checkpoint with the shadow variable prefix
      # removed so we can evaluate checkpoints in non-distributed replicated
      # mode. The checkpoints can also be loaded for training in
      # distributed_replicated mode.
      name = self._strip_port(self._remove_shadow_var_prefix_if_present(v.name))
      params[name] = v
    for v in tf.local_variables():
      # Non-trainable variables, such as batch norm moving averages, do not have
      # corresponding global shadow variables, so we add them here. Trainable
      # local variables have corresponding global shadow variables, which were
      # added in the global variable loop above.
      if v.name.startswith('v0/') and v not in tf.trainable_variables():
        params[self._strip_port(v.name)] = v
    return params
test_dbinterface.py (project: tfutils, author: neuroailab)
def test_remap_var_list(self):

        # Get a test `var_list` {var.name: var}
        var_list = {var.op.name: var for var in tf.global_variables()}

        # Specify mapping from old var names to new ones.
        mapping = {'model_0/Weights': 'model_0/Filters'}
        self.dbinterface.load_param_dict = mapping

        # Perform the mapping.
        mapped_vars = self.dbinterface.remap_var_list(var_list)

        # Confirm that the mapping has been done correctly.
        for name, var in mapped_vars.items():
            self.log.info('{} mapped to {}'.format(name, var.op.name))
            if name == 'model_0/Filters':
                self.assertEqual(name, mapping[var.op.name])
cnnpredictor.py (project: DmsMsgRcg, author: bshao001)
def __init__(self, session, model_scope, result_dir, result_file, k=1):
        """
        Args:
            model_scope: The variable_scope used for the trained model to be restored.
            session: The TensorFlow session used to run the prediction.
            result_dir: The full path to the folder in which the result file locates.
            result_file: The file that saves the training results.
            k: Optional. Number of elements to be predicted.
        """
        tf.train.import_meta_graph(os.path.join(result_dir, result_file + ".meta"))
        all_vars = tf.global_variables()
        model_vars = [var for var in all_vars if var.name.startswith(model_scope)]
        saver = tf.train.Saver(model_vars)
        saver.restore(session, os.path.join(result_dir, result_file))

        # Retrieve the Ops we 'remembered'.
        logits = tf.get_collection(model_scope+"logits")[0]
        self.images_placeholder = tf.get_collection(model_scope+"images")[0]
        self.keep_prob_placeholder = tf.get_collection(model_scope+"keep_prob")[0]

        # Add an Op that chooses the top k predictions. Apply softmax so that
        # we can have the probabilities (percentage) in the output.
        self.eval_op = tf.nn.top_k(tf.nn.softmax(logits), k=k)
        self.session = session
tracking.py (project: PyMDNet, author: HungWei-Andy)
def tracking(dataset, seq, display, restore_path):
  train_data = reader.read_seq(dataset, seq)
  im_size = proc.load_image(train_data.data[seq].frames[0]).shape[:2]
  config = Config(im_size)

  # create session and saver
  gpu_config = tf.ConfigProto(allow_soft_placement=True)
  sess = tf.InteractiveSession(config=gpu_config)

  # load model, weights
  model = MDNet(config)
  model.build_generator(config.batch_size, reuse=False, dropout=True)
  tf.global_variables_initializer().run()

  # create saver
  saver = tf.train.Saver([v for v in tf.global_variables() if ('conv' in v.name or 'fc4' in v.name or 'fc5' in v.name) \
                          and 'lr_rate' not in v.name], max_to_keep=50)

  # restore from model
  saver.restore(sess, restore_path)

  # run mdnet
  mdnet_run(sess, model, train_data.data[seq].gts[0], train_data.data[seq].frames, config, display)
shalo_base.py (project: shalo, author: henryre)
def load(self, model_name, verbose=True):
        """Load TensorFlow model from file
            @model_name: save file names
            @verbose: be talkative?
        """
        self.load_info(model_name)
        self._build()
        load_dict = self.save_dict or tf.global_variables()
        saver = tf.train.Saver(load_dict)
        ckpt = tf.train.get_checkpoint_state('./')
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(self.session, ckpt.model_checkpoint_path)
            if verbose:
                print("[{0}] Loaded model <{1}>".format(self.name, model_name))
        else:
            raise Exception("[{0}] No model found at <{1}>".format(
                self.name, model_name
            ))
utils_tf.py (project: cleverhans, author: tensorflow)
def initialize_uninitialized_global_variables(sess):
    """
    Only initializes the variables of a TensorFlow session that were not
    already initialized.
    :param sess: the TensorFlow session
    :return:
    """
    # List all global variables
    global_vars = tf.global_variables()

    # Find initialized status for all variables
    is_var_init = [tf.is_variable_initialized(var) for var in global_vars]
    is_initialized = sess.run(is_var_init)

    # List all variables that were not initialized previously
    not_initialized_vars = [var for (var, init) in
                            zip(global_vars, is_initialized) if not init]

    # Initialize all uninitialized variables found, if any
    if len(not_initialized_vars):
        sess.run(tf.variables_initializer(not_initialized_vars))
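A minimal usage sketch for the helper above (the variable names here are hypothetical, not from the cleverhans source): variables that were already initialized keep their values, and only variables created afterwards get initialized.

sess = tf.Session()
w_old = tf.Variable(tf.zeros([3]), name='w_old')
sess.run(tf.global_variables_initializer())        # w_old is initialized here
w_new = tf.Variable(tf.ones([3]), name='w_new')    # created after the init run
initialize_uninitialized_global_variables(sess)    # initializes only w_new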
caption_gen.py (project: Caption-Generation, author: m516825)
def build_model(self):
        self.model = classmap[FLAGS.model_type](hidden_size=FLAGS.hidden, 
                                    vocab_size=self.vocab_size, 
                                    encoder_in_size=self.data.feats.shape[-1], 
                                    encoder_in_length=self.data.feats.shape[1],
                                    decoder_in_length=self.data.decoder_in.shape[-1] - 1, 
                                    word2vec_weight=self.w2v_W,
                                    embedding_size=FLAGS.embedding_dim,
                                    neg_sample_num=self.sample_num,
                                    start_id=self.vocab_processor._mapping['<BOS>'],
                                    end_id=self.vocab_processor._mapping['<EOS>'],
                                    Bk=FLAGS.K)
        self.global_step = tf.Variable(0, name='global_step', trainable=False)

        self.optimizer = tf.train.RMSPropOptimizer(FLAGS.lr)

        tvars = tf.trainable_variables()

        grads, _ = tf.clip_by_global_norm(tf.gradients(self.model.cost, tvars), 5)

        self.updates = self.optimizer.apply_gradients(
                        zip(grads, tvars), global_step=self.global_step)
        self.saver = tf.train.Saver(tf.global_variables())
demo.py (project: tf-sr-zoo, author: MLJejuCamp2017)
def demo(lr_image, hr_image):
    model_sr = LapSRN(mode = 'demo')
    hr_images_fake, residuals = model_sr.construct_net(lr_image, hr_image)
    ckpt_path = tf.train.latest_checkpoint('checkpoint')
    print(ckpt_path)
    restorer = tf.train.Saver(tf.global_variables())
    with tf.Session() as sess:
        restorer.restore(sess, ckpt_path)
        hr_image_fake_level_2 = hr_images_fake['hr_image_fake_level_1']+residuals['residual_level_1']
        hr_image_fake_level_2 = tf.clip_by_value(hr_image_fake_level_2, 0, 1)
        hr_image_fake_level_2 = sess.run(hr_image_fake_level_2)
        hr_image_fake_level_2 = hr_image_fake_level_2.squeeze()
        lr_image = sess.run(lr_image)
        lr_image = lr_image.squeeze()
        hr_image = sess.run(hr_image)
    psnr_value = psnr(hr_image.squeeze(), hr_image_fake_level_2.squeeze())
    print(psnr_value)
    imshow(hr_image.squeeze())
    imshow(hr_image_fake_level_2)
demo.py (project: tf-sr-zoo, author: MLJejuCamp2017)
def demo(img_path):
    lr_img, hr_img = imgread(img_path)
    model = pix2pix_model(cfg)
    model.test_model(lr_img, hr_img)
    ckpt_path = tf.train.latest_checkpoint('checkpoint')
    restorer = tf.train.Saver(tf.global_variables())
    with tf.Session() as sess:
        restorer.restore(sess, ckpt_path)
        hr_image_fake = model.fake_hr_image
        hr_image_fake = tf.clip_by_value(hr_image_fake, 0, 1)
        hr_image_fake = sess.run(hr_image_fake)
        hr_image_fake = hr_image_fake.squeeze()
        hr_image = sess.run(hr_img)
    psnr_value = psnr(hr_image.squeeze(), hr_image_fake.squeeze())
    print(psnr_value)
    imshow(hr_image_fake)
    imshow(hr_image.squeeze())
dense_net_3d.py (project: 3d-DenseNet, author: frankgu)
def _initialize_session(self):
    """Initialize session, variables, saver"""
    config = tf.ConfigProto()
    # restrict model GPU memory utilization to min required
    config.gpu_options.allow_growth = True
    self.sess = tf.Session(config=config)
    # Parse the <major>.<minor> part of the version string to pick the right
    # initializer / summary-writer API for this TensorFlow release.
    tf_version = float('.'.join(tf.__version__.split('.')[:2]))
    if tf_version <= 0.10:
      self.sess.run(tf.initialize_all_variables())
      logswriter = tf.train.SummaryWriter
    else:
      self.sess.run(tf.global_variables_initializer())
      logswriter = tf.summary.FileWriter
    self.saver = tf.train.Saver(tf.global_variables(), max_to_keep=0)
    self.summary_writer = logswriter(self.logs_path, self.sess.graph)

tensorflow_backend.py (project: keras, author: GeekLiB)
def _initialize_variables():
    if hasattr(tf, 'global_variables'):
        variables = tf.global_variables()
    else:
        variables = tf.all_variables()

    uninitialized_variables = []
    for v in variables:
        if not hasattr(v, '_keras_initialized') or not v._keras_initialized:
            uninitialized_variables.append(v)
            v._keras_initialized = True
    if uninitialized_variables:
        sess = get_session()
        if hasattr(tf, 'variables_initializer'):
            sess.run(tf.variables_initializer(uninitialized_variables))
        else:
            sess.run(tf.initialize_variables(uninitialized_variables))
help.py (project: tensorflow-yolo, author: hjimce)
def to_darknet(self):
    darknet_ckpt = self.darknet

    with self.graph.as_default() as g:
        for var in tf.global_variables():
            name = var.name.split(':')[0]
            var_name = name.split('-')
            l_idx = int(var_name[0])
            w_sig = var_name[1].split('/')[-1]
            l = darknet_ckpt.layers[l_idx]
            l.w[w_sig] = var.eval(self.sess)

    for layer in darknet_ckpt.layers:
        for ph in layer.h:
            layer.h[ph] = None

    return darknet_ckpt
importer.py (project: ngraph, author: NervanaSystems)
def get_restore_op(self):
        """
        Get variable restoring ngraph op from TF model checkpoint

        Returns:
            A `ng.doall` op that restores the stored weights in TF model
            checkpoint
        """
        if self._graph is None:
            raise ValueError("self._graph is None, import meta_graph first.")
        if self._checkpoint_path is None:
            raise ValueError("self._checkpoint_path is None, please specify"
                             "checkpoint_path while importing meta_graph.")
        with self._graph.as_default():
            tf_variables = tf.global_variables()
            ng_variables = self.get_op_handle(tf_variables)
            ng_restore_ops = []
            with tf.Session() as sess:
                checkpoint_path = os.path.join(os.getcwd(),
                                               self._checkpoint_path)
                self.saver.restore(sess, checkpoint_path)
                for tf_variable, ng_variable in zip(tf_variables, ng_variables):
                    val = sess.run(tf_variable)
                    ng_restore_ops.append(ng.assign(ng_variable, val))
            return ng.doall(ng_restore_ops)
basic_test.py (project: sonnet, author: deepmind)
def testCustomGetter(self):
    """Check that custom getters work appropriately."""

    def custom_getter(getter, *args, **kwargs):
      kwargs["trainable"] = False
      return getter(*args, **kwargs)

    inputs = tf.placeholder(tf.float32, shape=[self.batch_size, self.in_size])

    # Make w and b non-trainable.
    lin1 = snt.Linear(output_size=self.out_size,
                      custom_getter=custom_getter)
    lin1(inputs)
    self.assertEqual(0, len(tf.trainable_variables()))
    self.assertEqual(2, len(tf.global_variables()))

    # Make w non-trainable.
    lin2 = snt.Linear(output_size=self.out_size,
                      custom_getter={"w": custom_getter})
    lin2(inputs)
    self.assertEqual(1, len(tf.trainable_variables()))
    self.assertEqual(4, len(tf.global_variables()))
util.py (project: sonnet, author: deepmind)
def _get_vars_to_collections(variables):
  """Returns a dict mapping variables to the collections they appear in."""
  var_to_collections = collections.defaultdict(lambda: [])
  if isinstance(variables, dict):
    variables = list(v for _, v in variable_map_items(variables))
  for graph in set(v.graph for v in variables):
    for collection_name in list(graph.collections):
      entries = set(entry for entry in graph.get_collection(collection_name)
                    if isinstance(entry, tf.Variable))
      # For legacy reasons, tf.GraphKeys.GLOBAL_VARIABLES == "variables".
      # Correcting for this here, to avoid confusion.
      if collection_name == tf.GraphKeys.GLOBAL_VARIABLES:
        collection_name = "global_variables"
      for var in entries.intersection(variables):
        var_to_collections[var].append(collection_name)
  return var_to_collections
policy.py (project: MuGo, author: brilee)
def initialize_variables(self, save_file=None):
        self.session.run(tf.global_variables_initializer())
        if save_file is not None:
            try:
                self.saver.restore(self.session, save_file)
            except:
                # some wizardry here... basically, only restore variables
                # that are in the save file; otherwise, initialize them normally.
                from tensorflow.python.framework import meta_graph
                meta_graph_def = meta_graph.read_meta_graph_file(save_file + '.meta')
                stored_var_names = set([n.name
                    for n in meta_graph_def.graph_def.node
                    if n.op == 'VariableV2'])
                print(stored_var_names)
                var_list = [v for v in tf.global_variables()
                    if v.op.name in stored_var_names]
                # initialize all of the variables
                self.session.run(tf.global_variables_initializer())
                # then overwrite the ones we have in the save file
                # by using a throwaway saver, saved models are automatically
                # "upgraded" to the latest graph definition.
                throwaway_saver = tf.train.Saver(var_list=var_list)
                throwaway_saver.restore(self.session, save_file)
util.py (project: tefla, author: litan)
def dump_vars(sess, trainable_scopes=None):
    all_vars = set(tf.global_variables())
    trainable_vars = set(trainable_variables(trainable_scopes))
    non_trainable_vars = all_vars.difference(trainable_vars)

    def _dump_set(var_set):
        names_vars = map(lambda v: (v.name, v), var_set)
        for n, v in sorted(names_vars, key=lambda nv: nv[0]):
            print("%s=%s" % (n, sess.run(v)))

    print("Variable values:")
    print("-----------")
    print("\n---Trainable vars:")
    _dump_set(trainable_vars)
    print("\n---Non Trainable vars:")
    _dump_set(non_trainable_vars)
    print("-----------")
util.py (project: tefla, author: litan)
def show_vars(logger=None, trainable_scopes=None):
    printer = logger.info if logger is not None else print
    all_vars = set(tf.global_variables())
    trainable_vars = set(trainable_variables(trainable_scopes))
    non_trainable_vars = all_vars.difference(trainable_vars)
    local_vars = set(tf.local_variables())

    # Small namespace class used as a mutable counter holder ('nonlocal' is a
    # reserved keyword in Python 3, so a neutral class name is used instead).
    class counters: pass

    counters.num_params = {}

    def show_var_info(vars, var_type):
        printer('\n---%s vars in model:' % var_type)
        name_shapes = map(lambda v: (v.name, v.get_shape()), vars)
        total_params = 0
        for n, s in sorted(name_shapes, key=lambda ns: ns[0]):
            printer('%s %s' % (n, s))
            total_params += np.prod(s.as_list())
        counters.num_params[var_type] = total_params

    show_var_info(trainable_vars, 'Trainable')
    show_var_info(non_trainable_vars, 'Non Trainable')
    show_var_info(local_vars, 'Local')
    printer('Total number of params:')
    printer(pprint.pformat(counters.num_params))
rainbow.py (project: SRLF, author: Fritz449)
def save(self, name):
        directory = 'saves/' + name + '/'
        if not os.path.exists(directory):
            os.makedirs(directory)
        directory += 'iteration_{}'.format(self.timestep) + '/'
        if not os.path.exists(directory):
            os.makedirs(directory)

        for i, tensor in enumerate(tf.global_variables()):
            value = self.sess.run(tensor)
            np.save(directory + 'weight_{}'.format(i), value)

        if self.scale:
            np.save(directory + 'sums', self.sums)
            np.save(directory + 'sumsquares', self.sumsqrs)
            np.save(directory + 'sumtime', self.sumtime)

        np.save(directory + 'timestep', np.array([self.timestep]))
        np.save(directory + 'train_scores', np.array(self.train_scores))
        np.save(directory + 'test_scores', np.array(self.test_scores))
        print("Agent successfully saved in folder {}".format(directory))
rainbow.py (project: SRLF, author: Fritz449)
def load(self, name, iteration=None):
        try:
            directory = 'saves/' + name + '/'
            if not os.path.exists(directory):
                print('That directory does not exist!')
                raise Exception
            if iteration is None:
                iteration = np.max([int(x[10:]) for x in [dir for dir in os.walk(directory)][0][1]])
            directory += 'iteration_{}'.format(iteration) + '/'

            for i, tensor in enumerate(tf.global_variables()):
                arr = np.load(directory + 'weight_{}.npy'.format(i))
                self.sess.run(tensor.assign(arr))

            if self.scale:
                self.sums = np.load(directory + 'sums.npy')
                self.sumsqrs = np.load(directory + 'sumsquares.npy')
                self.sumtime = np.load(directory + 'sumtime.npy')

            self.timestep = np.load(directory + 'timestep.npy')[0]
            self.train_scores = np.load(directory + 'train_scores.npy').tolist()
            self.test_scores = np.load(directory + 'test_scores.npy').tolist()
            print("Agent successfully loaded from folder {}".format(directory))
        except:
            print("Something is wrong, loading failed")
ddpg_distributed.py (project: SRLF, author: Fritz449)
def save(self, name):
        directory = 'saves/' + name + '/'
        if not os.path.exists(directory):
            os.makedirs(directory)
        directory += 'iteration_{}'.format(self.timestep) + '/'
        if not os.path.exists(directory):
            os.makedirs(directory)

        for i, w in enumerate(tf.global_variables()):
            np.save(directory + 'weight_{}'.format(i), self.sess.run(w))

        if self.scale:
            np.save(directory + 'sums', self.sums)
            np.save(directory + 'sumsquares', self.sumsqrs)
            np.save(directory + 'sumtime', self.sumtime)

        np.save(directory + 'timestep', np.array([self.timestep]))
        np.save(directory + 'train_scores', np.array(self.train_scores))
        np.save(directory + 'test_scores', np.array(self.test_scores))

        print("Agent successfully saved in folder {}".format(directory))
ddpg_single.py (project: SRLF, author: Fritz449)
def save(self, name):
        directory = 'saves/' + name + '/'
        if not os.path.exists(directory):
            os.makedirs(directory)
        directory += 'iteration_{}'.format(self.timestep) + '/'
        if not os.path.exists(directory):
            os.makedirs(directory)

        for i, w in enumerate(tf.global_variables()):
            np.save(directory + 'weight_{}'.format(i), self.sess.run(w))

        if self.scale != 'off':
            np.save(directory + 'sums', self.sums)
            np.save(directory + 'sumsquares', self.sumsqrs)
            np.save(directory + 'sumtime', self.sumtime)

        np.save(directory + 'timestep', np.array([self.timestep]))
        np.save(directory + 'train_scores', np.array(self.train_scores))
        np.save(directory + 'test_scores', np.array(self.test_scores))

        print("Agent successfully saved in folder {}".format(directory))
a3c_continuous.py (project: SRLF, author: Fritz449)
def save(self, name):
        directory = 'saves/' + name + '/'
        if not os.path.exists(directory):
            os.makedirs(directory)
        directory += 'iteration_{}'.format(self.timestep) + '/'
        if not os.path.exists(directory):
            os.makedirs(directory)

        for i, tensor in enumerate(tf.global_variables()):
            value = self.sess.run(tensor)
            np.save(directory + 'weight_{}'.format(i), value)

        if self.scale != 'off':
            np.save(directory + 'sums', self.sums)
            np.save(directory + 'sumsquares', self.sumsqrs)
            np.save(directory + 'sumtime', self.sumtime)

        np.save(directory + 'timestep', np.array([self.timestep]))
        np.save(directory + 'train_scores', np.array(self.train_scores))
        np.save(directory + 'test_scores', np.array(self.test_scores))
        print("Agent successfully saved in folder {}".format(directory))
trpo_continuous.py (project: SRLF, author: Fritz449)
def save(self, name):
        directory = 'saves/' + name + '/'
        if not os.path.exists(directory):
            os.makedirs(directory)
        directory += 'iteration_{}'.format(self.timestep) + '/'
        if not os.path.exists(directory):
            os.makedirs(directory)

        for i, tensor in enumerate(tf.global_variables()):
            value = self.sess.run(tensor)
            np.save(directory + 'weight_{}'.format(i), value)

        if self.scale != 'off':
            np.save(directory + 'sums', self.sums)
            np.save(directory + 'sumsquares', self.sumsqrs)
            np.save(directory + 'sumtime', self.sumtime)

        np.save(directory + 'timestep', np.array([self.timestep]))
        np.save(directory + 'train_scores', np.array(self.train_scores))
        np.save(directory + 'test_scores', np.array(self.test_scores))
        print("Agent successfully saved in folder {}".format(directory))
trpo_continuous.py (project: SRLF, author: Fritz449)
def load(self, name, iteration=None):
        try:
            directory = 'saves/' + name + '/'
            if not os.path.exists(directory):
                print('That directory does not exist!')
                raise Exception
            if iteration is None:
                iteration = np.max([int(x[10:]) for x in [dir for dir in os.walk(directory)][0][1]])
            directory += 'iteration_{}'.format(iteration) + '/'

            for i, tensor in enumerate(tf.global_variables()):
                arr = np.load(directory + 'weight_{}.npy'.format(i))
                self.sess.run(tensor.assign(arr))

            if self.scale != 'off':
                self.sums = np.load(directory + 'sums.npy')
                self.sumsqrs = np.load(directory + 'sumsquares.npy')
                self.sumtime = np.load(directory + 'sumtime.npy')

            self.timestep = np.load(directory + 'timestep.npy')[0]
            self.train_scores = np.load(directory + 'train_scores.npy').tolist()
            self.test_scores = np.load(directory + 'test_scores.npy').tolist()
            print("Agent successfully loaded from folder {}".format(directory))
        except:
            print("Something is wrong, loading failed")
a3c_discrete.py (project: SRLF, author: Fritz449)
def save(self, name):
        directory = 'saves/' + name + '/'
        if not os.path.exists(directory):
            os.makedirs(directory)
        directory += 'iteration_{}'.format(self.timestep) + '/'
        if not os.path.exists(directory):
            os.makedirs(directory)

        for i, tensor in enumerate(tf.global_variables()):
            value = self.sess.run(tensor)
            np.save(directory + 'weight_{}'.format(i), value)

        if self.scale != 'off':
            np.save(directory + 'sums', self.sums)
            np.save(directory + 'sumsquares', self.sumsqrs)
            np.save(directory + 'sumtime', self.sumtime)

        np.save(directory + 'timestep', np.array([self.timestep]))
        np.save(directory + 'train_scores', np.array(self.train_scores))
        np.save(directory + 'test_scores', np.array(self.test_scores))
        print("Agent successfully saved in folder {}".format(directory))
trpo_discrete.py (project: SRLF, author: Fritz449)
def save(self, name):
        directory = 'saves/' + name + '/'
        if not os.path.exists(directory):
            os.makedirs(directory)
        directory += 'iteration_{}'.format(self.timestep) + '/'
        if not os.path.exists(directory):
            os.makedirs(directory)

        for i, tensor in enumerate(tf.global_variables()):
            value = self.sess.run(tensor)
            np.save(directory + 'weight_{}'.format(i), value)

        if self.scale != 'off':
            np.save(directory + 'sums', self.sums)
            np.save(directory + 'sumsquares', self.sumsqrs)
            np.save(directory + 'sumtime', self.sumtime)

        np.save(directory + 'timestep', np.array([self.timestep]))
        np.save(directory + 'train_scores', np.array(self.train_scores))
        np.save(directory + 'test_scores', np.array(self.test_scores))
        print("Agent successfully saved in folder {}".format(directory))
model.py (project: AM-GAN, author: ZhimingZhou)
def model_initilization(self, cfg):

        ############################################################################################################################################
        def initialization():
            var_list = tf.global_variables()
            for var in var_list:
                self.sess.run(tf.variables_initializer([var]), feed_dict={self.z: self.sample_z[:cfg.iBatchSize], self.images_lab: self.sample_images[:cfg.iBatchSize], self.fInputNoise: cfg.fInputNoise})
                print(var.op.name)

            #self.sess.run(tf.initialize_all_tables(), feed_dict={self.z: self.sample_z[:cfg.iBatchSize], self.images_lab: self.sample_images[:cfg.iBatchSize], self.fInputNoise: cfg.fInputNoiseBiG})

        print('optimizor initialization')

        if cfg.bLoadCheckpoint:
            if self.load(cfg):
                print(" [*] Load SUCCESS")
            else:
                print(" [!] Load failed...")
                initialization()
        else:
            initialization()
tensorflow_backend.py (project: deep-learning-keras-projects, author: jasmeetsb)
def _initialize_variables():
    if hasattr(tf, 'global_variables'):
        variables = tf.global_variables()
    else:
        variables = tf.all_variables()

    uninitialized_variables = []
    for v in variables:
        if not hasattr(v, '_keras_initialized') or not v._keras_initialized:
            uninitialized_variables.append(v)
            v._keras_initialized = True
    if uninitialized_variables:
        sess = get_session()
        if hasattr(tf, 'variables_initializer'):
            sess.run(tf.variables_initializer(uninitialized_variables))
        else:
            sess.run(tf.initialize_variables(uninitialized_variables))

