Python all_variables(): example source code

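All of the snippets collected below use the same TensorFlow 1.x graph-mode idiom: build a graph, initialize its variables, and hand the full variable list to a tf.train.Saver for checkpointing or restoring. Note that tf.all_variables() and tf.initialize_all_variables() were deprecated in TensorFlow 1.0 in favor of tf.global_variables() and tf.global_variables_initializer(); several snippets below guard for this with hasattr checks. As a quick orientation, here is a minimal self-contained sketch of the shared pattern (the checkpoint path is a placeholder):

import tensorflow as tf

# Create a variable, initialize everything, and save the full variable list.
w = tf.Variable(tf.zeros([3]), name="w")

init_op = (tf.global_variables_initializer()
           if hasattr(tf, 'global_variables_initializer')
           else tf.initialize_all_variables())
all_vars = (tf.global_variables()
            if hasattr(tf, 'global_variables')
            else tf.all_variables())

with tf.Session() as sess:
    sess.run(init_op)
    saver = tf.train.Saver(all_vars)
    # saver.save(sess, '/tmp/model.ckpt')  # placeholder checkpoint path
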
birds_skip_thought_demo.py (project: how_to_convert_text_to_images, author: llSourcell)
def build_model(sess, embedding_dim, batch_size):
    model = CondGAN(
        lr_imsize=cfg.TEST.LR_IMSIZE,
        hr_lr_ratio=int(cfg.TEST.HR_IMSIZE/cfg.TEST.LR_IMSIZE))

    embeddings = tf.placeholder(
        tf.float32, [batch_size, embedding_dim],
        name='conditional_embeddings')
    with pt.defaults_scope(phase=pt.Phase.test):
        with tf.variable_scope("g_net"):
            c = sample_encoded_context(embeddings, model)
            z = tf.random_normal([batch_size, cfg.Z_DIM])
            fake_images = model.get_generator(tf.concat(1, [c, z]))
        with tf.variable_scope("hr_g_net"):
            hr_c = sample_encoded_context(embeddings, model)
            hr_fake_images = model.hr_get_generator(fake_images, hr_c)

    ckt_path = cfg.TEST.PRETRAINED_MODEL
    if ckt_path.find('.ckpt') != -1:
        print("Reading model parameters from %s" % ckt_path)
        saver = tf.train.Saver(tf.all_variables())
        saver.restore(sess, ckt_path)
    else:
        print("Input a valid model path.")
    return embeddings, fake_images, hr_fake_images
trainer.py (project: how_to_convert_text_to_images, author: llSourcell)
def build_model(self, sess):
        self.init_opt()
        sess.run(tf.initialize_all_variables())

        if len(self.model_path) > 0:
            print("Reading model parameters from %s" % self.model_path)
            restore_vars = tf.all_variables()
            # all_vars = tf.all_variables()
            # restore_vars = [var for var in all_vars if
            #                 var.name.startswith('g_') or
            #                 var.name.startswith('d_')]
            saver = tf.train.Saver(restore_vars)
            saver.restore(sess, self.model_path)

            istart = self.model_path.rfind('_') + 1
            iend = self.model_path.rfind('.')
            counter = self.model_path[istart:iend]
            counter = int(counter)
        else:
            print("Created model with fresh parameters.")
            counter = 0
        return counter
trainer.py (project: how_to_convert_text_to_images, author: llSourcell)
def build_model(self, sess):
        self.init_opt()

        sess.run(tf.initialize_all_variables())
        if len(self.model_path) > 0:
            print("Reading model parameters from %s" % self.model_path)
            all_vars = tf.trainable_variables()
            # all_vars = tf.all_variables()
            restore_vars = []
            for var in all_vars:
                if var.name.startswith('g_') or var.name.startswith('d_'):
                    restore_vars.append(var)
                    # print(var.name)
            saver = tf.train.Saver(restore_vars)
            saver.restore(sess, self.model_path)

            istart = self.model_path.rfind('_') + 1
            iend = self.model_path.rfind('.')
            counter = self.model_path[istart:iend]
            counter = int(counter)
        else:
            print("Created model with fresh parameters.")
            counter = 0
        return counter
tflogger.py (project: PyTorchDemystified, author: hhsecond)
def __initialize(self):
        sess = tf.Session()
        loss = tf.Variable(0.0, name="loss", trainable=False)
        acc = tf.Variable(0.0, name="accuracy", trainable=False)
        loss_summary = tf.summary.scalar("loss", loss)
        acc_summary = tf.summary.scalar("accuracy", acc)
        summary_op = tf.summary.merge([loss_summary, acc_summary])
        summary_writer = tf.summary.FileWriter(self.summary_dir, sess.graph)
        tf.train.Saver(tf.all_variables())
        sess.run(tf.initialize_all_variables())

        self.sess = sess
        self.summary_op = summary_op
        self.summary_writer = summary_writer
        self.loss = loss
        self.acc = acc
tensorflow_backend.py (project: keras, author: GeekLiB)
def _initialize_variables():
    if hasattr(tf, 'global_variables'):
        variables = tf.global_variables()
    else:
        variables = tf.all_variables()

    uninitialized_variables = []
    for v in variables:
        if not hasattr(v, '_keras_initialized') or not v._keras_initialized:
            uninitialized_variables.append(v)
            v._keras_initialized = True
    if uninitialized_variables:
        sess = get_session()
        if hasattr(tf, 'variables_initializer'):
            sess.run(tf.variables_initializer(uninitialized_variables))
        else:
            sess.run(tf.initialize_variables(uninitialized_variables))
nn_q_table.py (project: drivebot, author: matpalm)
def copy_all_vars(from_namespace, to_namespace, affine_coefficient=1.0):
    assert affine_coefficient >= 0.0 and affine_coefficient <= 1.0
    copy_ops = []
    with tf.variable_scope("", reuse=True):  # for grabbing the targets by full namespace
        for src_var in tf.all_variables():
            # ignore any variable not in src namespace
            if not src_var.name.startswith(from_namespace):
                continue
            # fetch reference to target variable with the same name as the src variable
            assert src_var.name.endswith(":0")
            target_var_name = src_var.name.replace(from_namespace, to_namespace).replace(":0", "")
            target_var = tf.get_variable(target_var_name, src_var.get_shape())
            # create a copy op to clobber target with src
            # target = alpha * src + (1.0-alpha) * target
            copy_ops.append(target_var.assign_sub(affine_coefficient * (target_var - src_var)))
    single_copy_op = tf.group(*copy_ops)
    return single_copy_op
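
A usage sketch for copy_all_vars, with invented scope names ("q_net", "target_q_net") and shapes; it performs the soft update target = alpha * src + (1 - alpha) * target described in the comments above, as is common for target networks in Q-learning:

with tf.variable_scope("q_net"):
    tf.get_variable("w", shape=[4, 2])
with tf.variable_scope("target_q_net"):
    tf.get_variable("w", shape=[4, 2])

soft_update_op = copy_all_vars("q_net", "target_q_net", affine_coefficient=0.01)
with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    sess.run(soft_update_op)  # target <- 0.01 * src + 0.99 * target
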
test.py (project: automatic-portrait-tf, author: Corea)
def test(net, image_name):
    image = build_image(image_name)

    with tf.Session() as sess:
        saver = tf.train.Saver(tf.all_variables())
        model_file = tf.train.latest_checkpoint('./model/')
        if model_file:
            saver.restore(sess, model_file)
        else:
            raise Exception('Testing needs pre-trained model!')

        feed_dict = {
            net['image']: image,
            net['drop_rate']: 1
        }
        result = sess.run(tf.argmax(net['score'], dimension=3),
                          feed_dict=feed_dict)
    return result
trainer.py (project: StackGAN, author: hanzhanggit)
def build_model(self, sess):
        self.init_opt()
        sess.run(tf.initialize_all_variables())

        if len(self.model_path) > 0:
            print("Reading model parameters from %s" % self.model_path)
            restore_vars = tf.all_variables()
            # all_vars = tf.all_variables()
            # restore_vars = [var for var in all_vars if
            #                 var.name.startswith('g_') or
            #                 var.name.startswith('d_')]
            saver = tf.train.Saver(restore_vars)
            saver.restore(sess, self.model_path)

            istart = self.model_path.rfind('_') + 1
            iend = self.model_path.rfind('.')
            counter = self.model_path[istart:iend]
            counter = int(counter)
        else:
            print("Created model with fresh parameters.")
            counter = 0
        return counter
trainer.py (project: StackGAN, author: hanzhanggit)
def build_model(self, sess):
        self.init_opt()

        sess.run(tf.initialize_all_variables())
        if len(self.model_path) > 0:
            print("Reading model parameters from %s" % self.model_path)
            all_vars = tf.trainable_variables()
            # all_vars = tf.all_variables()
            restore_vars = []
            for var in all_vars:
                if var.name.startswith('g_') or var.name.startswith('d_'):
                    restore_vars.append(var)
                    # print(var.name)
            saver = tf.train.Saver(restore_vars)
            saver.restore(sess, self.model_path)

            istart = self.model_path.rfind('_') + 1
            iend = self.model_path.rfind('.')
            counter = self.model_path[istart:iend]
            counter = int(counter)
        else:
            print("Created model with fresh parameters.")
            counter = 0
        return counter
birds_skip_thought_demo.py (project: StackGAN, author: hanzhanggit)
def build_model(sess, embedding_dim, batch_size):
    model = CondGAN(
        lr_imsize=cfg.TEST.LR_IMSIZE,
        hr_lr_ratio=int(cfg.TEST.HR_IMSIZE/cfg.TEST.LR_IMSIZE))

    embeddings = tf.placeholder(
        tf.float32, [batch_size, embedding_dim],
        name='conditional_embeddings')
    with pt.defaults_scope(phase=pt.Phase.test):
        with tf.variable_scope("g_net"):
            c = sample_encoded_context(embeddings, model)
            z = tf.random_normal([batch_size, cfg.Z_DIM])
            fake_images = model.get_generator(tf.concat(1, [c, z]))
        with tf.variable_scope("hr_g_net"):
            hr_c = sample_encoded_context(embeddings, model)
            hr_fake_images = model.hr_get_generator(fake_images, hr_c)

    ckt_path = cfg.TEST.PRETRAINED_MODEL
    if ckt_path.find('.ckpt') != -1:
        print("Reading model parameters from %s" % ckt_path)
        saver = tf.train.Saver(tf.all_variables())
        saver.restore(sess, ckt_path)
    else:
        print("Input a valid model path.")
    return embeddings, fake_images, hr_fake_images
graph_handler.py (project: bi-att-flow, author: allenai)
def _load(self, sess):
        config = self.config
        vars_ = {var.name.split(":")[0]: var for var in tf.all_variables()}
        if config.load_ema:
            ema = self.model.var_ema
            for var in tf.trainable_variables():
                del vars_[var.name.split(":")[0]]
                vars_[ema.average_name(var)] = var
        saver = tf.train.Saver(vars_, max_to_keep=config.max_to_keep)

        if config.load_path:
            save_path = config.load_path
        elif config.load_step > 0:
            save_path = os.path.join(config.save_dir, "{}-{}".format(config.model_name, config.load_step))
        else:
            save_dir = config.save_dir
            checkpoint = tf.train.get_checkpoint_state(save_dir)
            assert checkpoint is not None, "cannot load checkpoint at {}".format(save_dir)
            save_path = checkpoint.model_checkpoint_path
        print("Loading saved model from {}".format(save_path))
        saver.restore(sess, save_path)
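
The loop above remaps checkpoint entries so that, for every trainable variable, the Saver reads the exponential-moving-average shadow value (ema.average_name(var)) into the raw variable, letting evaluation run with averaged weights. A minimal standalone sketch of the same remapping, with an invented variable and a placeholder checkpoint path:

import tensorflow as tf

w = tf.Variable(tf.zeros([3]), name="w")
ema = tf.train.ExponentialMovingAverage(decay=0.999)
maintain_averages_op = ema.apply([w])  # creates the shadow variable "w/ExponentialMovingAverage"

# Restore the shadow slot from the checkpoint into "w" itself.
saver = tf.train.Saver({ema.average_name(w): w})
# saver.restore(sess, save_path)  # placeholder session and checkpoint path
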
graph_handler.py (project: Chinese-QA, author: distantJing)
def _load(self, sess):
        config = self.config
        vars_ = {var.name.split(":")[0]: var for var in tf.all_variables()}
        if config.load_ema:
            ema = self.model.var_ema
            for var in tf.trainable_variables():
                del vars_[var.name.split(":")[0]]
                vars_[ema.average_name(var)] = var
        saver = tf.train.Saver(vars_, max_to_keep=config.max_to_keep)

        if config.load_path:
            save_path = config.load_path
        elif config.load_step > 0:
            save_path = os.path.join(config.save_dir, "{}-{}".format(config.model_name, config.load_step))
        else:
            save_dir = config.save_dir
            checkpoint = tf.train.get_checkpoint_state(save_dir)
            assert checkpoint is not None, "cannot load checkpoint at {}".format(save_dir)
            save_path = checkpoint.model_checkpoint_path
        print("Loading saved model from {}".format(save_path))
        saver.restore(sess, save_path)
tensorflow_backend.py (project: deep-learning-keras-projects, author: jasmeetsb)
def _initialize_variables():
    if hasattr(tf, 'global_variables'):
        variables = tf.global_variables()
    else:
        variables = tf.all_variables()

    uninitialized_variables = []
    for v in variables:
        if not hasattr(v, '_keras_initialized') or not v._keras_initialized:
            uninitialized_variables.append(v)
            v._keras_initialized = True
    if uninitialized_variables:
        sess = get_session()
        if hasattr(tf, 'variables_initializer'):
            sess.run(tf.variables_initializer(uninitialized_variables))
        else:
            sess.run(tf.initialize_variables(uninitialized_variables))
sample.py (project: sequelspeare, author: raidancampbell)
def __init__(self, save_dir=SAVE_DIR, prime_text=PRIME_TEXT, num_sample_symbols=NUM_SAMPLE_SYMBOLS):
        self.save_dir = save_dir
        self.prime_text = prime_text
        self.num_sample_symbols = num_sample_symbols
        with open(os.path.join(Sampler.SAVE_DIR, 'chars_vocab.pkl'), 'rb') as file:
            self.chars, self.vocab = cPickle.load(file)
            self.model = Model(len(self.chars), is_sampled=True)

            # polite GPU memory allocation: don't grab everything you can.
            config = tf.ConfigProto()
            config.gpu_options.allow_growth = True
            config.gpu_options.allocator_type = 'BFC'
            self.sess = tf.Session(config=config)

            tf.initialize_all_variables().run(session=self.sess)
            self.checkpoint = tf.train.get_checkpoint_state(self.save_dir)
            if self.checkpoint and self.checkpoint.model_checkpoint_path:
                tf.train.Saver(tf.all_variables()).restore(self.sess, self.checkpoint.model_checkpoint_path)
resnet.py (project: bone-age, author: radinformatics)
def __init__(self, checkpoint_path):
        layers = 50
        num_blocks = [3, 4, 6, 3]
        self.inference = lambda images, is_train : inference(images, 
                                                   is_training=is_train, 
                                                   num_classes=NUM_AGES*2,
                                                   num_blocks=num_blocks, 
                                                   bottleneck=True)

        self.x = tf.placeholder(tf.uint8, shape=(256,256,3), name='input_image')
        self.crops = fixed_crops(self.x)
        self.logits = self.inference(self.crops, is_train=False)
        self.pred = tf.nn.softmax(self.logits, name='prediction')

        # Restore saved weights
        restore_variables = tf.trainable_variables() \
                + tf.moving_average_variables()
        self.saver = tf.train.Saver(restore_variables)
        self.sess = tf.Session()
        self.saver.restore(self.sess, checkpoint_path)

        #self.sess.run(tf.initialize_variables([var for var \
        #        in tf.all_variables() if var not in restore_variables]))
layer.py (project: Dialog-System-with-GAN-model, author: drcut)
def print_all_variables(train_only=False):
    """Print all trainable and non-trainable variables
    without tl.layers.initialize_global_variables(sess)
    Parameters
    ----------
    train_only : boolean
        If True, only print the trainable variables, otherwise, print all variables.
    """
    if train_only:
        t_vars = tf.trainable_variables()
        print("  [*] printing trainable variables")
    else:
        try: # TF1.0
            t_vars = tf.global_variables()
        except: # TF0.12
            t_vars = tf.all_variables()
        print("  [*] printing global variables")
    for idx, v in enumerate(t_vars):
        print("  var {:3}: {:15}   {}".format(idx, str(v.get_shape()), v.name))
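
A brief usage sketch for print_all_variables; the variable names are invented. With train_only=True only variables created as trainable are listed:

w = tf.get_variable("g_example_w", shape=[2, 2])               # trainable
step = tf.Variable(0, name="g_example_step", trainable=False)  # non-trainable
print_all_variables(train_only=True)   # trainable variables only
print_all_variables(train_only=False)  # every global variable
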
layer.py (project: Dialog-System-with-GAN-model, author: drcut)
def get_variables_with_name(name, train_only=True, printable=False):
    """Get variable list by a given name scope.
    >>> dense_vars = tl.layers.get_variables_with_name('dense', True, True)
    """
    print("  [*] getting variables with %s" % name)
    # tvar = tf.trainable_variables() if train_only else tf.all_variables()
    if train_only:
        t_vars = tf.trainable_variables()
    else:
        try: # TF1.0
            t_vars = tf.global_variables()
        except: # TF0.12
            t_vars = tf.all_variables()

    d_vars = [var for var in t_vars if name in var.name]
    if printable:
        for idx, v in enumerate(d_vars):
            print("  got {:3}: {:15}   {}".format(idx, v.name, str(v.get_shape())))
    return d_vars
util.py (project: tefla, author: openAGI)
def dump_vars(sess):
    all_vars = set(tf.all_variables())
    trainable_vars = set(tf.trainable_variables())
    non_trainable_vars = all_vars.difference(trainable_vars)

    def _dump_set(var_set):
        names_vars = map(lambda v: (v.name, v), var_set)
        for n, v in sorted(names_vars, key=lambda nv: nv[0]):
            print("%s=%s" % (n, sess.run(v)))

    print("Variable values:")
    print("-----------")
    print("\n---Trainable vars:")
    _dump_set(trainable_vars)
    print("\n---Non Trainable vars:")
    _dump_set(non_trainable_vars)
    print("-----------")
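
A minimal usage sketch for dump_vars; the variable names are invented. Variables created with trainable=False end up in the non-trainable group:

weights = tf.Variable(tf.ones([2, 2]), name="weights")  # trainable
step = tf.Variable(0, name="step", trainable=False)     # non-trainable
with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    dump_vars(sess)  # prints every variable's value, grouped by trainability
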
sample.py (project: jaylyrics_generation_tensorflow, author: hundred06)
def sample(args):
    # import configuration
    with open(os.path.join(args.save_dir, 'config.pkl'), 'rb') as f:
        saved_args = cPickle.load(f)
    with open(os.path.join(args.save_dir, 'words_vocab.pkl'), 'rb') as f:
        words, vocab = cPickle.load(f)
    # import the trained model
    model = Model(saved_args, True)
    with tf.Session() as sess:
    # initialize the model
        tf.initialize_all_variables().run()
        saver = tf.train.Saver(tf.all_variables())
        ckpt = tf.train.get_checkpoint_state(args.save_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
        # sample the new sequence word by word
            literature = model.sample(sess, words, vocab, args.n, args.start, args.sample)
    with codecs.open('result/sequence.txt','a','utf-8') as f:
        f.write(literature+'\n\n')
    print(literature)
optimizers_test.py (project: lsdc, author: febert)
def testAdaptiveGradientClip(self):
    with self.test_session() as session:
      x, var, loss, global_step = _setup_model()
      clip_gradients = tf.contrib.layers.adaptive_clipping_fn()
      train = tf.contrib.layers.optimize_loss(loss,
                                              global_step,
                                              learning_rate=0.1,
                                              optimizer="SGD",
                                              clip_gradients=clip_gradients)
      tf.global_variables_initializer().run()
      session.run(train, feed_dict={x: 5})
      var_value, global_step_value = session.run([var, global_step])
      self.assertAlmostEqual(var_value, 9.8916, 4)
      self.assertEqual(global_step_value, 1)
      var_count = 0
      for var in tf.all_variables():
        if var.name.startswith("OptimizeLoss/AdaptiveMaxNorm"):
          var_count += 1
      self.assertEqual(2, var_count)
sample.py (project: char-rnn-tf, author: liusiqi43)
def main(unused_args):
  with open(os.path.join(FLAGS.session_dir, 'labels.pkl')) as f:
    char_to_id = pickle.load(f)
  with open(os.path.join(FLAGS.session_dir, 'config.pkl')) as f:
    config = pickle.load(f)
  with tf.variable_scope('model'):
    m = CharRNN('infer', config)
  with tf.Session() as sess:
    tf.initialize_all_variables().run()
    saver = tf.train.Saver(tf.all_variables())
    ckpt = tf.train.get_checkpoint_state(FLAGS.session_dir)
    if ckpt and ckpt.model_checkpoint_path:
      saver.restore(sess, ckpt.model_checkpoint_path)
      print(ckpt.model_checkpoint_path, 'restored')

      while True:
        seed = raw_input('seed:')
        start_time = time.time()
        print(m.sample(sess, char_to_id, FLAGS.num_steps, seed))
        print(FLAGS.num_steps / (time.time() - start_time), 'cps')
graph_handler.py (project: adversarial-squad, author: robinjia)
def _load(self, sess):
        config = self.config
        vars_ = {var.name.split(":")[0]: var for var in tf.all_variables()}
        if config.load_ema:
            ema = self.model.var_ema
            for var in tf.trainable_variables():
                del vars_[var.name.split(":")[0]]
                vars_[ema.average_name(var)] = var
        saver = tf.train.Saver(vars_, max_to_keep=config.max_to_keep)

        if config.load_path:
            save_path = config.load_path
        elif config.load_step > 0:
            save_path = os.path.join(config.save_dir, "{}-{}".format(config.model_name, config.load_step))
        else:
            save_dir = config.save_dir
            checkpoint = tf.train.get_checkpoint_state(save_dir)
            assert checkpoint is not None, "cannot load checkpoint at {}".format(save_dir)
            save_path = checkpoint.model_checkpoint_path
        print("Loading saved model from {}".format(save_path))
        saver.restore(sess, save_path)
sample.py (project: word-rnn-tf, author: jtoy)
def sample(args):
    with open(os.path.join(args.save_dir, 'config.pkl'), 'rb') as f:
        saved_args = cPickle.load(f)
    with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'rb') as f:
        chars, vocab = cPickle.load(f)
    model = Model(saved_args, True)
    val_loss_file = args.save_dir + '/val_loss.json'
    with tf.Session() as sess:
        saver = tf.train.Saver(tf.all_variables())
        if os.path.exists(val_loss_file):
            with open(val_loss_file, "r") as text_file:
                text = text_file.read()
                loss_json = json.loads(text)
                losses = loss_json.keys()
                losses.sort(key=lambda x: float(x))
                loss = losses[0]
                model_checkpoint_path =  loss_json[loss]['checkpoint_path']
                #print(model_checkpoint_path)
                saver.restore(sess, model_checkpoint_path)
                result = model.sample(sess, chars, vocab, args.n, args.prime, args.sample_rule, args.temperature)
                print(result) #add this back in later, not sure why its not working
                output = "/data/output/"+ str(int(time.time())) + ".txt"
                with open(output, "w") as text_file:
                    text_file.write(result)
                print(output)
tensorflow_backend.py (project: keras-customized, author: ambrite)
def _initialize_variables():
    if hasattr(tf, 'global_variables'):
        variables = tf.global_variables()
    else:
        variables = tf.all_variables()

    uninitialized_variables = []
    for v in variables:
        if not hasattr(v, '_keras_initialized') or not v._keras_initialized:
            uninitialized_variables.append(v)
            v._keras_initialized = True
    if uninitialized_variables:
        sess = get_session()
        if hasattr(tf, 'variables_initializer'):
            sess.run(tf.variables_initializer(uninitialized_variables))
        else:
            sess.run(tf.initialize_variables(uninitialized_variables))
core.py (project: tensorlight, author: bsautermeister)
def uninitialized_variables(session, var_list=None):
    """Gets the list of uninitialized variables.
       Note: this has to be evaluated on a session.
    Parameters
    ----------
    session: tf.Session
        The TensorFlow session to scan for uninitialized variables
    var_list: list(tf.Variable) or None
        The list of variables to filter for uninitialized ones.
        Defaults to tf.all_variables() when None.
    """
    if var_list is None:
        var_list = tf.all_variables()

    reported_var_names = session.run(tf.report_uninitialized_variables(var_list))
    uninit_vars = []
    for name in reported_var_names:
        try:
            uninit_vars.append(tf.get_variable(name))
        except ValueError:
            print("Failed to collect variable {}. Skipping.".format(name))

    return uninit_vars
base_optimizer.py (project: Sing_Par, author: wanghm92)
def variables_to_restore(self, moving_avg_variables=None):
    """"""

    name_map = {}
    if moving_avg_variables is None:
      moving_avg_variables = tf.trainable_variables()
      moving_avg_variables += tf.moving_average_variables()
    # Remove duplicates
    moving_avg_variables = set(moving_avg_variables)
    # Collect all the variables with moving average,
    for v in moving_avg_variables:
      name_map[self.average_name(v)] = v
    # Make sure we restore variables without moving average as well.
    for v in list(set(tf.all_variables()) - moving_avg_variables):
      if v.op.name not in name_map:
        name_map[v.op.name] = v
    return name_map

  #===============================================================
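
A hedged usage sketch: assuming opt is an instance of the optimizer class above (so it provides average_name and variables_to_restore), the returned name map goes straight into a Saver so that checkpointed moving averages are restored into the live variables:

saver = tf.train.Saver(opt.variables_to_restore())
# saver.restore(sess, save_path)  # placeholder session and checkpoint path
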
base_optimizer.py (project: Parser-v1, author: tdozat)
def variables_to_restore(self, moving_avg_variables=None):
    """"""

    name_map = {}
    if moving_avg_variables is None:
      moving_avg_variables = tf.trainable_variables()
      moving_avg_variables += tf.moving_average_variables()
    # Remove duplicates
    moving_avg_variables = set(moving_avg_variables)
    # Collect all the variables with moving average,
    for v in moving_avg_variables:
      name_map[self.average_name(v)] = v
    # Make sure we restore variables without moving average as well.
    for v in list(set(tf.all_variables()) - moving_avg_variables):
      if v.op.name not in name_map:
        name_map[v.op.name] = v
    return name_map

  #===============================================================
trajmodel.py (project: RNN-TrajModel, author: wuhao5688)
def build_LPIRNN_model(self, train_phase):
    config = self.config
    self.lpi_ = self.build_sharedTask_part(train_phase)
    loss_, loss_p_ = self.build_individualTask_part(train_phase, self.lpi_)
    if config.trace_hid_layer:
      self.trace_dict["lpi_"+str(config.trace_input_id)] = self.lpi_ # here you can collect the lpi w.r.t. a given state id
    self.loss_dict["loss"] = loss_
    self.loss_dict["loss_p"] = loss_p_
    # compute grads and update params
    self.build_trainer(self.loss_dict["loss"], tf.trainable_variables())
    if config.use_v2_saver:
      self.saver = tf.train.Saver(tf.all_variables(), max_to_keep=config.max_ckpt_to_keep,
                                  write_version=saver_pb2.SaverDef.V2)
    else:
      self.saver = tf.train.Saver(tf.all_variables(), max_to_keep=config.max_ckpt_to_keep,
                                  write_version=saver_pb2.SaverDef.V1)
network.py (project: cifar10-tensorflow, author: namakemono)
def __init__(self, image_size=24, num_classes=10, batch_size=50, channels=3):
        self._image_size = image_size
        self._num_classes = num_classes
        self._batch_size = batch_size
        self._channels = channels
        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.333)
        self._session = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
        self._images = tf.placeholder(tf.float32, shape=[None, self._image_size, self._image_size, self._channels])
        self._labels = tf.placeholder(tf.int64, shape=[None])
        self._keep_prob = tf.placeholder(tf.float32)
        self._global_step = tf.Variable(0, tf.int64, name="global_step") 
        self._logits = self._inference(self._images, self._keep_prob)
        self._avg_loss = self._loss(self._labels, self._logits)
        self._train_op = self._train(self._avg_loss)
        self._accuracy = F.accuracy_score(self._labels, self._logits)
        self._saver = tf.train.Saver(tf.all_variables())
        self._session.run(tf.initialize_all_variables())
layers.py (project: deepsleepnet, author: akaraspt)
def print_all_variables(train_only=False):
    """Print all trainable and non-trainable variables
    without tl.layers.initialize_global_variables(sess)

    Parameters
    ----------
    train_only : boolean
        If True, only print the trainable variables, otherwise, print all variables.
    """
    # tvar = tf.trainable_variables() if train_only else tf.all_variables()
    if train_only:
        t_vars = tf.trainable_variables()
        print("  [*] printing trainable variables")
    else:
        try: # TF1.0
            t_vars = tf.global_variables()
        except: # TF0.12
            t_vars = tf.all_variables()
        print("  [*] printing global variables")
    for idx, v in enumerate(t_vars):
        print("  var {:3}: {:15}   {}".format(idx, str(v.get_shape()), v.name))

