python类ConfigProto()的实例源码

train_embedding.py 文件源码 项目:youtube-8m 作者: wangheda 项目源码 文件源码 阅读 54 收藏 0 点赞 0 评论 0
def __init__(self, cluster, task, train_dir, log_device_placement=True):
    """Creates a Trainer.

    Args:
      cluster: A tf.train.ClusterSpec if the execution is distributed.
        None otherwise.
      task: A TaskSpec describing the job type and the task index.
      train_dir: Stored unchanged on the instance as `self.train_dir`.
      log_device_placement: Forwarded to tf.ConfigProto.

    Raises:
      ValueError: If the task is a "master" replica whose index is not 0.
    """

    self.cluster = cluster
    self.task = task
    self.is_master = (task.type == "master" and task.index == 0)
    self.train_dir = train_dir
    # Session config with a fixed per-process GPU memory fraction of 0.2.
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.2)
    self.config = tf.ConfigProto(log_device_placement=log_device_placement,
                                 gpu_options=gpu_options)

    # Test the task type directly: `is_master` already implies index == 0,
    # so combining it with `index > 0` could never fire.
    if task.type == "master" and task.index > 0:
      # ValueError exists on Python 2 and 3 (StandardError is Python 2 only)
      # and is a StandardError subclass there, so existing handlers still
      # catch it. The message is %-formatted instead of passed as a tuple.
      raise ValueError("%s: Only one replica of master expected" %
                       task_as_string(self.task))
train.py 文件源码 项目:youtube-8m 作者: wangheda 项目源码 文件源码 阅读 25 收藏 0 点赞 0 评论 0
def __init__(self, cluster, task, train_dir, log_device_placement=True):
    """Creates a Trainer.

    Args:
      cluster: A tf.train.ClusterSpec if the execution is distributed.
        None otherwise.
      task: A TaskSpec describing the job type and the task index.
      train_dir: Stored unchanged on the instance as `self.train_dir`.
      log_device_placement: Forwarded to tf.ConfigProto.

    Raises:
      ValueError: If the task is a "master" replica whose index is not 0.
    """

    self.cluster = cluster
    self.task = task
    self.is_master = (task.type == "master" and task.index == 0)
    self.train_dir = train_dir
    # NOTE(review): gpu_options is built from FLAGS.gpu but never passed to
    # tf.ConfigProto below, so it currently has no effect on the session
    # config. Confirm whether `gpu_options=gpu_options` was intended.
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=FLAGS.gpu)
    self.config = tf.ConfigProto(log_device_placement=log_device_placement)

    # Test the task type directly: `is_master` already implies index == 0,
    # so combining it with `index > 0` could never fire.
    if task.type == "master" and task.index > 0:
      # ValueError exists on Python 2 and 3 (StandardError is Python 2 only)
      # and is a StandardError subclass there, so existing handlers still
      # catch it. The message is %-formatted instead of passed as a tuple.
      raise ValueError("%s: Only one replica of master expected" %
                       task_as_string(self.task))
train_autoencoder.py 文件源码 项目:youtube-8m 作者: wangheda 项目源码 文件源码 阅读 34 收藏 0 点赞 0 评论 0
def __init__(self, cluster, task, train_dir, log_device_placement=True):
    """Creates a Trainer.

    Args:
      cluster: A tf.train.ClusterSpec if the execution is distributed.
        None otherwise.
      task: A TaskSpec describing the job type and the task index.
      train_dir: Stored unchanged on the instance as `self.train_dir`.
      log_device_placement: Forwarded to tf.ConfigProto.

    Raises:
      ValueError: If the task is a "master" replica whose index is not 0.
    """

    self.cluster = cluster
    self.task = task
    self.is_master = (task.type == "master" and task.index == 0)
    self.train_dir = train_dir
    self.config = tf.ConfigProto(log_device_placement=log_device_placement)

    # Test the task type directly: `is_master` already implies index == 0,
    # so combining it with `index > 0` could never fire.
    if task.type == "master" and task.index > 0:
      # ValueError exists on Python 2 and 3 (StandardError is Python 2 only)
      # and is a StandardError subclass there, so existing handlers still
      # catch it. The message is %-formatted instead of passed as a tuple.
      raise ValueError("%s: Only one replica of master expected" %
                       task_as_string(self.task))
train-with-rebuild.py 文件源码 项目:youtube-8m 作者: wangheda 项目源码 文件源码 阅读 50 收藏 0 点赞 0 评论 0
def __init__(self, cluster, task, train_dir, log_device_placement=True):
        """Creates a Trainer.

        Args:
          cluster: A tf.train.ClusterSpec if the execution is distributed.
            None otherwise.
          task: A TaskSpec describing the job type and the task index.
          train_dir: Stored unchanged on the instance as `self.train_dir`.
          log_device_placement: Forwarded to tf.ConfigProto.

        Raises:
          ValueError: If the task is a "master" replica whose index is not 0.
        """

        self.cluster = cluster
        self.task = task
        self.is_master = (task.type == "master" and task.index == 0)
        self.train_dir = train_dir
        self.config = tf.ConfigProto(log_device_placement=log_device_placement)

        # Test the task type directly: `is_master` already implies
        # index == 0, so combining it with `index > 0` could never fire.
        if task.type == "master" and task.index > 0:
            # ValueError exists on Python 2 and 3 (StandardError is Python 2
            # only) and is a StandardError subclass there, so existing
            # handlers still catch it. The message is %-formatted instead of
            # passed as a tuple.
            raise ValueError("%s: Only one replica of master expected" %
                             task_as_string(self.task))
train.py 文件源码 项目:youtube-8m 作者: wangheda 项目源码 文件源码 阅读 35 收藏 0 点赞 0 评论 0
def __init__(self, cluster, task, train_dir, log_device_placement=True):
    """Creates a Trainer.

    Args:
      cluster: A tf.train.ClusterSpec if the execution is distributed.
        None otherwise.
      task: A TaskSpec describing the job type and the task index.
      train_dir: Stored unchanged on the instance as `self.train_dir`.
      log_device_placement: Forwarded to tf.ConfigProto.

    Raises:
      ValueError: If the task is a "master" replica whose index is not 0.
    """

    self.cluster = cluster
    self.task = task
    self.is_master = (task.type == "master" and task.index == 0)
    self.train_dir = train_dir
    self.config = tf.ConfigProto(log_device_placement=log_device_placement)

    # Test the task type directly: `is_master` already implies index == 0,
    # so combining it with `index > 0` could never fire.
    if task.type == "master" and task.index > 0:
      # ValueError exists on Python 2 and 3 (StandardError is Python 2 only)
      # and is a StandardError subclass there, so existing handlers still
      # catch it. The message is %-formatted instead of passed as a tuple.
      raise ValueError("%s: Only one replica of master expected" %
                       task_as_string(self.task))
train-with-rebuild.py 文件源码 项目:youtube-8m 作者: wangheda 项目源码 文件源码 阅读 35 收藏 0 点赞 0 评论 0
def __init__(self, cluster, task, train_dir, log_device_placement=True):
    """Creates a Trainer.

    Args:
      cluster: A tf.train.ClusterSpec if the execution is distributed.
        None otherwise.
      task: A TaskSpec describing the job type and the task index.
      train_dir: Stored unchanged on the instance as `self.train_dir`.
      log_device_placement: Forwarded to tf.ConfigProto.

    Raises:
      ValueError: If the task is a "master" replica whose index is not 0.
    """

    self.cluster = cluster
    self.task = task
    self.is_master = (task.type == "master" and task.index == 0)
    self.train_dir = train_dir
    self.config = tf.ConfigProto(log_device_placement=log_device_placement)

    # Test the task type directly: `is_master` already implies index == 0,
    # so combining it with `index > 0` could never fire.
    if task.type == "master" and task.index > 0:
      # ValueError exists on Python 2 and 3 (StandardError is Python 2 only)
      # and is a StandardError subclass there, so existing handlers still
      # catch it. The message is %-formatted instead of passed as a tuple.
      raise ValueError("%s: Only one replica of master expected" %
                       task_as_string(self.task))
train.py 文件源码 项目:youtube-8m 作者: wangheda 项目源码 文件源码 阅读 33 收藏 0 点赞 0 评论 0
def __init__(self, cluster, task, train_dir, log_device_placement=True):
    """Creates a Trainer.

    Args:
      cluster: A tf.train.ClusterSpec if the execution is distributed.
        None otherwise.
      task: A TaskSpec describing the job type and the task index.
      train_dir: Stored unchanged on the instance as `self.train_dir`.
      log_device_placement: Forwarded to tf.ConfigProto.

    Raises:
      ValueError: If the task is a "master" replica whose index is not 0.
    """

    self.cluster = cluster
    self.task = task
    self.is_master = (task.type == "master" and task.index == 0)
    self.train_dir = train_dir
    self.config = tf.ConfigProto(log_device_placement=log_device_placement)

    # Test the task type directly: `is_master` already implies index == 0,
    # so combining it with `index > 0` could never fire.
    if task.type == "master" and task.index > 0:
      # ValueError exists on Python 2 and 3 (StandardError is Python 2 only)
      # and is a StandardError subclass there, so existing handlers still
      # catch it. The message is %-formatted instead of passed as a tuple.
      raise ValueError("%s: Only one replica of master expected" %
                       task_as_string(self.task))
benchmark_cnn.py 文件源码 项目:benchmarks 作者: tensorflow 项目源码 文件源码 阅读 33 收藏 0 点赞 0 评论 0
def create_config_proto(params):
  """Builds the tf.ConfigProto used to configure a session.

  Args:
    params: Params tuple, typically created by make_params or
            make_params_from_flags.

  Returns:
    A tf.ConfigProto populated from the fields of `params` read below.
  """
  proto = tf.ConfigProto(
      allow_soft_placement=True,
      intra_op_parallelism_threads=params.num_intra_threads,
      inter_op_parallelism_threads=params.num_inter_threads)

  proto.gpu_options.force_gpu_compatible = params.force_gpu_compatible
  # Only override the GPU memory fraction when a positive value was given.
  memory_fraction = params.gpu_memory_frac_for_testing
  if memory_fraction > 0:
    proto.gpu_options.per_process_gpu_memory_fraction = memory_fraction

  graph_options = proto.graph_options
  if params.xla:
    graph_options.optimizer_options.global_jit_level = (
        tf.OptimizerOptions.ON_1)
  if params.enable_layout_optimizer:
    graph_options.rewrite_options.layout_optimizer = (
        rewriter_config_pb2.RewriterConfig.ON)

  return proto
wrapper.py 文件源码 项目:vae-npvc 作者: JeremyCCHsu 项目源码 文件源码 阅读 30 收藏 0 点赞 0 评论 0
def configure_gpu_settings(gpu_cfg=None):
    """Build a tf.ConfigProto from a JSON settings file.

    Args:
        gpu_cfg: Path to a JSON file providing the keys
            'per_process_gpu_memory_fraction', 'allow_soft_placement',
            'log_device_placement', 'inter_op_parallelism_threads' and
            'intra_op_parallelism_threads'. A falsy value skips loading.

    Returns:
        The tf.ConfigProto built from the file, or None when `gpu_cfg`
        is falsy.
    """
    if not gpu_cfg:
        return None

    with open(gpu_cfg) as cfg_file:
        settings = json.load(cfg_file)

    gpu_opts = tf.GPUOptions(
        per_process_gpu_memory_fraction=settings['per_process_gpu_memory_fraction'])
    return tf.ConfigProto(
        allow_soft_placement=settings['allow_soft_placement'],
        log_device_placement=settings['log_device_placement'],
        inter_op_parallelism_threads=settings['inter_op_parallelism_threads'],
        intra_op_parallelism_threads=settings['intra_op_parallelism_threads'],
        gpu_options=gpu_opts)
train_val.py 文件源码 项目:HandDetection 作者: YunqiuXu 项目源码 文件源码 阅读 33 收藏 0 点赞 0 评论 0
def train_net(network, imdb, roidb, valroidb, output_dir, tb_dir,
              pretrained_model=None,
              max_iters=40000):
  """Train a Faster R-CNN network.

  Both roidbs are passed through filter_roidb, then a SolverWrapper is run
  for `max_iters` iterations inside a session that uses soft placement and
  GPU memory growth.
  """
  roidb, valroidb = filter_roidb(roidb), filter_roidb(valroidb)

  session_config = tf.ConfigProto(
      allow_soft_placement=True,
      gpu_options=tf.GPUOptions(allow_growth=True))

  with tf.Session(config=session_config) as sess:
    solver = SolverWrapper(sess, network, imdb, roidb, valroidb,
                           output_dir, tb_dir,
                           pretrained_model=pretrained_model)
    print('Solving...')
    solver.train_model(sess, max_iters)
    print('done solving')
test_dbinterface.py 文件源码 项目:tfutils 作者: neuroailab 项目源码 文件源码 阅读 29 收藏 0 点赞 0 评论 0
def setUp(self):
        """Prepare the fixtures before each test method runs.

        Sets up the model, opens a tensorflow session, initializes the
        global variables and instantiates a dbinterface.

        """
        self.setup_model()

        session_config = tf.ConfigProto(
            allow_soft_placement=True,
            gpu_options=tf.GPUOptions(allow_growth=True),
            log_device_placement=self.params['log_device_placement'],
            inter_op_parallelism_threads=self.params['inter_op_parallelism_threads'])
        self.sess = tf.Session(config=session_config)

        # TODO: Determine whether this should be called here or
        # in dbinterface.initialize()
        init_op = tf.global_variables_initializer()
        self.sess.run(init_op)

        self.dbinterface = base.DBInterface(
            sess=self.sess,
            params=self.params,
            cache_dir=self.CACHE_DIR,
            save_params=self.save_params,
            load_params=self.load_params)

        self.step = 0
train.py 文件源码 项目:neural-fonts 作者: periannath 项目源码 文件源码 阅读 31 收藏 0 点赞 0 评论 0
def main(_):
    """Build a UNet from the module-level `args` and train it."""
    session_config = tf.ConfigProto(
        gpu_options=tf.GPUOptions(allow_growth=True))

    with tf.Session(config=session_config) as sess:
        model = UNet(
            args.experiment_dir,
            batch_size=args.batch_size,
            experiment_id=args.experiment_id,
            input_width=args.image_size,
            output_width=args.image_size,
            embedding_num=args.embedding_num,
            embedding_dim=args.embedding_dim,
            L1_penalty=args.L1_penalty,
            Lconst_penalty=args.Lconst_penalty,
            Ltv_penalty=args.Ltv_penalty,
            Lcategory_penalty=args.Lcategory_penalty)
        model.register_session(sess)

        # no_target_source is only passed when labels are flipped.
        if not args.flip_labels:
            model.build_model(is_training=True, inst_norm=args.inst_norm)
        else:
            model.build_model(is_training=True, inst_norm=args.inst_norm,
                              no_target_source=True)

        # --fine_tune is a comma-separated list of integer ids.
        fine_tune_ids = None
        if args.fine_tune:
            fine_tune_ids = {int(token) for token in args.fine_tune.split(",")}

        model.train(
            lr=args.lr,
            epoch=args.epoch,
            resume=args.resume,
            schedule=args.schedule,
            freeze_encoder=args.freeze_encoder,
            fine_tune=fine_tune_ids,
            sample_steps=args.sample_steps,
            checkpoint_steps=args.checkpoint_steps,
            flip_labels=args.flip_labels,
            no_val=args.no_val)
tracking.py 文件源码 项目:PyMDNet 作者: HungWei-Andy 项目源码 文件源码 阅读 26 收藏 0 点赞 0 评论 0
def tracking(dataset, seq, display, restore_path):
  """Restore an MDNet checkpoint and run it on one sequence.

  Args:
    dataset: Passed to reader.read_seq together with `seq`.
    seq: Key of the sequence inside the loaded data.
    display: Forwarded unchanged to mdnet_run.
    restore_path: Checkpoint path handed to the saver's restore call.
  """
  train_data = reader.read_seq(dataset, seq)
  first_frame = proc.load_image(train_data.data[seq].frames[0])
  config = Config(first_frame.shape[:2])

  # Session with soft placement enabled.
  sess = tf.InteractiveSession(
      config=tf.ConfigProto(allow_soft_placement=True))

  # Build the model graph and initialize every variable.
  model = MDNet(config)
  model.build_generator(config.batch_size, reuse=False, dropout=True)
  tf.global_variables_initializer().run()

  # Only conv / fc4 / fc5 variables are restored; lr_rate ones are skipped.
  restorable = []
  for var in tf.global_variables():
    wanted = 'conv' in var.name or 'fc4' in var.name or 'fc5' in var.name
    if wanted and 'lr_rate' not in var.name:
      restorable.append(var)
  saver = tf.train.Saver(restorable, max_to_keep=50)
  saver.restore(sess, restore_path)

  # Track, starting from the first ground-truth entry.
  mdnet_run(sess, model, train_data.data[seq].gts[0],
            train_data.data[seq].frames, config, display)
predict.py 文件源码 项目:tensorflow_seq2seq_chatbot 作者: higepon 项目源码 文件源码 阅读 38 收藏 0 点赞 0 评论 0
def predict():
    """Read lines from stdin and print the predictor's replies for each."""
    # Only allocate part of the gpu memory when predicting.
    session_config = tf.ConfigProto(
        gpu_options=tf.GPUOptions(per_process_gpu_memory_fraction=0.2))

    with tf.Session(config=session_config) as sess:
        predictor = EasyPredictor(sess)

        sys.stdout.write("> ")
        sys.stdout.flush()
        # readline() returns "" at EOF, which ends the loop.
        for line in iter(sys.stdin.readline, ""):
            for index, reply in enumerate(predictor.predict(line)):
                print(index, reply)
            print("> ", end="")
            sys.stdout.flush()
iclr_2017_benchmark.py 文件源码 项目:fold 作者: tensorflow 项目源码 文件源码 阅读 30 收藏 0 点赞 0 评论 0
def run(self):
    """Build the model in a fresh graph and run the timing test on it."""
    _logger.info("Creating graph.")
    graph = tf.Graph()
    with graph.as_default():
      _logger.info("Building model.")
      self.build_model_loss()

      _logger.info("Starting session.")
      session_config = tf.ConfigProto(
          log_device_placement=FLAGS.log_device_placement)
      with tf.Session(config=session_config) as session:
        _logger.info("Initializing variables.")
        init_op = tf.global_variables_initializer()
        session.run(init_op)
        _logger.info("Starting timing test.")
        self.evaluate(session)

      _logger.info("Ending session.")
distributed_sit.py 文件源码 项目:GoogleCloudSetup 作者: tmulc18 项目源码 文件源码 阅读 56 收藏 0 点赞 0 评论 0
def main():
    """Run 1000 gradient-descent steps on a two-element toy problem."""
    # Graph: everything below is pinned to the CPU.
    with tf.device('/cpu:0'):
        lhs = tf.Variable(tf.truncated_normal(shape=[2]), dtype=tf.float32)
        rhs = tf.Variable(tf.truncated_normal(shape=[2]), dtype=tf.float32)
        total = lhs + rhs

        target = tf.constant(100., shape=[2], dtype=tf.float32)
        loss = tf.reduce_mean(tf.square(total - target))

        train_op = tf.train.GradientDescentOptimizer(.0001).minimize(loss)

    # Session
    supervisor = tf.train.Supervisor(logdir='/tmp/mydir')
    gpu_opts = tf.GPUOptions(
        allow_growth=True,
        allocator_type="BFC",
        visible_device_list="%d" % FLAGS.gpu_id)
    session_config = tf.ConfigProto(
        gpu_options=gpu_opts,
        allow_soft_placement=False,
        device_count={'GPU': 1},
        log_device_placement=True)
    sess = supervisor.prepare_or_wait_for_session(config=session_config)

    for step in range(1000):
        sess.run(train_op)
        # Print the current sum every tenth step.
        if step % 10 == 0:
            print(sess.run(total))
        time.sleep(.1)
main.py 文件源码 项目:RFR-solution 作者: baoblackcoal 项目源码 文件源码 阅读 29 收藏 0 点赞 0 评论 0
def main(_):
  """Create the environment and agent, then train or play per FLAGS."""
  memory_fraction = calc_gpu_fraction(FLAGS.gpu_fraction)
  session_config = tf.ConfigProto(
      gpu_options=tf.GPUOptions(
          per_process_gpu_memory_fraction=memory_fraction))

  with tf.Session(config=session_config) as sess:
    # Fall back to FLAGS itself when get_config returns a falsy value.
    config = get_config(FLAGS) or FLAGS

    env_cls = (SimpleGymEnvironment if config.env_type == 'simple'
               else GymEnvironment)
    env = env_cls(config)

    if not tf.test.is_gpu_available() and FLAGS.use_gpu:
      raise Exception("use_gpu flag is true when no GPUs are available")

    # Without a GPU the config is switched to the NHWC format.
    if not FLAGS.use_gpu:
      config.cnn_format = 'NHWC'

    agent = Agent(config, env, sess)

    action = agent.train if FLAGS.is_train else agent.play
    action()
ocr_eval.py 文件源码 项目:tf-cnn-lstm-ocr-captcha 作者: Luonic 项目源码 文件源码 阅读 36 收藏 0 点赞 0 评论 0
def evaluate():
    """Eval ocr repeatedly, or just once when FLAGS.run_once is set."""
    with tf.Graph().as_default() as graph:
        images, labels, seq_lengths = ocr.inputs()
        logits, timesteps = ocr.inference(
            images, FLAGS.eval_batch_size, train=True)
        ler = ocr.create_label_error_rate(logits, labels, timesteps)

        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        # A device count of 0 for 'GPU' keeps this session off the GPUs.
        sess = tf.Session(config=tf.ConfigProto(device_count={'GPU': 0}))
        sess.run(init_op)

        saver = tf.train.Saver()

        summary_op = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter(FLAGS.eval_dir, graph)

        # Evaluate once, then keep re-evaluating after each sleep interval
        # for as long as run_once is unset.
        eval_once(saver, summary_writer, ler, summary_op)
        while not FLAGS.run_once:
            time.sleep(FLAGS.eval_interval_secs)
            eval_once(saver, summary_writer, ler, summary_op)
dense_net_3d.py 文件源码 项目:3d-DenseNet 作者: frankgu 项目源码 文件源码 阅读 28 收藏 0 点赞 0 评论 0
def _initialize_session(self):
    """Initialize session, variables, saver and summary writer.

    Sets `self.sess`, `self.saver` and `self.summary_writer`. The variable
    initializer and summary-writer class are chosen from the module-level
    TF_VERSION so that both old and new TensorFlow APIs are supported.
    """
    config = tf.ConfigProto()
    # restrict model GPU memory utilization to min required
    config.gpu_options.allow_growth = True
    self.sess = tf.Session(config=config)
    # The unused local parse of tf.__version__ was removed: the branch
    # below relies only on TF_VERSION, and the parse could raise.
    if TF_VERSION <= 0.10:
      self.sess.run(tf.initialize_all_variables())
      logswriter = tf.train.SummaryWriter
    else:
      self.sess.run(tf.global_variables_initializer())
      logswriter = tf.summary.FileWriter
    # max_to_keep=0 is passed straight to the Saver; see the tf.train.Saver
    # documentation for how it treats 0.
    self.saver = tf.train.Saver(tf.global_variables(), max_to_keep=0)
    self.summary_writer = logswriter(self.logs_path, self.sess.graph)

  # (Updated)
demo.py 文件源码 项目:blitznet 作者: dvornikita 项目源码 文件源码 阅读 29 收藏 0 点赞 0 评论 0
def main(argv=None):  # pylint: disable=unused-argument
    """Restore a checkpoint and run the detector over every loader image."""
    assert args.detect or args.segment, "Either detect or segment should be True"
    assert args.ckpt > 0, "Specify the number of checkpoint"
    net = ResNet(config=net_config, depth=50, training=False)
    loader = Loader()

    session_config = tf.ConfigProto(allow_soft_placement=True,
                                    log_device_placement=False)
    with tf.Session(config=session_config) as sess:
        output_dir = osp.join(loader.folder, 'output')
        detector = Detector(sess, net, loader, net_config,
                            no_gt=args.no_seg_gt, folder=output_dir)
        detector.restore_from_ckpt(args.ckpt)

        for name in loader.get_filenames():
            image = loader.load_image(name)
            height, width = image.shape[:2]
            print('Processing {}'.format(name + loader.data_format))
            # No ground truth is supplied; results are drawn (draw=True).
            detector.feed_forward(img=image, name=name, w=width, h=height,
                                  draw=True, seg_gt=None, gt_bboxes=None,
                                  gt_cats=None)
    print('Done')
run_mujoco.py 文件源码 项目:baselines 作者: openai 项目源码 文件源码 阅读 49 收藏 0 点赞 0 评论 0
def train(env_id, num_timesteps, seed):
    """Create a monitored gym environment and run `learn` on it.

    Args:
        env_id: Id passed to gym.make.
        num_timesteps: Forwarded to `learn`.
        seed: Used for both the global seeds and the environment seed.
    """
    # NOTE(review): `rank` is not defined in this function; it must come
    # from the enclosing module. Confirm it exists there.
    env = bench.Monitor(
        gym.make(env_id),
        logger.get_dir() and os.path.join(logger.get_dir(), str(rank)))
    set_global_seeds(seed)
    env.seed(seed)
    gym.logger.setLevel(logging.WARN)

    with tf.Session(config=tf.ConfigProto()):
        obs_size = env.observation_space.shape[0]
        act_size = env.action_space.shape[0]
        with tf.variable_scope("vf"):
            value_fn = NeuralNetValueFunction(obs_size, act_size)
        with tf.variable_scope("pi"):
            pi = GaussianMlpPolicy(obs_size, act_size)

        learn(env,
              policy=pi,
              vf=value_fn,
              gamma=0.99,
              lam=0.97,
              timesteps_per_batch=2500,
              desired_kl=0.002,
              num_timesteps=num_timesteps,
              animate=False)

        env.close()
run_dqn_atari.py 文件源码 项目:rl_algorithms 作者: DanielTakeshi 项目源码 文件源码 阅读 27 收藏 0 点赞 0 评论 0
def get_session():
    """Reset the default graph and return a memory-limited tf.Session."""
    tf.reset_default_graph()

    # Single-threaded config from the starter code. It is only consumed by
    # the disabled alternative below, not by the session that is returned.
    starter_config = tf.ConfigProto(
        inter_op_parallelism_threads=1,
        intra_op_parallelism_threads=1)
    # Starter-code default:
    #sess = tf.Session(config=starter_config)

    # To see what is placed on the GPU:
    #sess = tf.Session(config=tf.ConfigProto(log_device_placement=True))

    # Active choice: limit the memory allocated for the GPU.
    memory_cap = tf.GPUOptions(per_process_gpu_memory_fraction=0.5)
    sess = tf.Session(config=tf.ConfigProto(gpu_options=memory_cap))

    print("AVAILABLE GPUS: ", get_available_gpus())
    return sess
process.py 文件源码 项目:DocumentSegmentation 作者: SeguinBe 项目源码 文件源码 阅读 170 收藏 0 点赞 0 评论 0
def process(input_dir, output_dir, model_dir, resizing_size, gpu):
    """Predict labels for every image in `input_dir` and save them.

    Args:
        input_dir: Directory scanned for *.jpg, *.png, *.tif and *.jp2 files.
        output_dir: Directory (created if missing) receiving the predictions.
        model_dir: Passed to loader.LoadedModel.
        resizing_size: Size passed to PIL's resize for each image.
        gpu: Value used for the session's visible_device_list.
    """
    session_config = tf.ConfigProto(
        gpu_options=tf.GPUOptions(per_process_gpu_memory_fraction=0.3,
                                  visible_device_list=gpu))
    with tf.Session(config=session_config).as_default():
        model = loader.LoadedModel(model_dir)

    os.makedirs(output_dir, exist_ok=True)

    # Collected in the same extension order: jpg, png, tif, jp2.
    image_paths = []
    for pattern in ('*.jpg', '*.png', '*.tif', '*.jp2'):
        image_paths += glob(os.path.join(input_dir, pattern))

    for image_path in tqdm(image_paths):
        pixels = np.asarray(Image.open(image_path).resize(resizing_size))
        # Two-dimensional arrays are replicated into three channels.
        if len(pixels.shape) == 2:
            pixels = np.stack([pixels] * 3, axis=2)
        labels = model.predict(pixels[None], prediction_key='labels')[0]
        target = os.path.join(output_dir,
                              os.path.relpath(image_path, input_dir))
        plt.imsave(target, labels)
train_ensemble.py 文件源码 项目:youtube-8m 作者: wangheda 项目源码 文件源码 阅读 57 收藏 0 点赞 0 评论 0
def __init__(self, cluster, task, train_dir, log_device_placement=True):
    """Creates a Trainer.

    Args:
      cluster: A tf.train.ClusterSpec if the execution is distributed.
        None otherwise.
      task: A TaskSpec describing the job type and the task index.
      train_dir: Stored unchanged on the instance as `self.train_dir`.
      log_device_placement: Forwarded to tf.ConfigProto.

    Raises:
      ValueError: If the task is a "master" replica whose index is not 0.
    """

    self.cluster = cluster
    self.task = task
    self.is_master = (task.type == "master" and task.index == 0)
    self.train_dir = train_dir
    # NOTE(review): gpu_options is built from FLAGS.gpu but never passed to
    # tf.ConfigProto below, so it currently has no effect on the session
    # config. Confirm whether `gpu_options=gpu_options` was intended.
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=FLAGS.gpu)
    self.config = tf.ConfigProto(log_device_placement=log_device_placement)

    # Test the task type directly: `is_master` already implies index == 0,
    # so combining it with `index > 0` could never fire.
    if task.type == "master" and task.index > 0:
      # ValueError exists on Python 2 and 3 (StandardError is Python 2 only)
      # and is a StandardError subclass there, so existing handlers still
      # catch it. The message is %-formatted instead of passed as a tuple.
      raise ValueError("%s: Only one replica of master expected" %
                       task_as_string(self.task))
validation_check.py 文件源码 项目:neurobind 作者: Kyubyong 项目源码 文件源码 阅读 30 收藏 0 点赞 0 评论 0
def validation_check():
    """Score the latest checkpoint on the validation data.

    Restores the newest checkpoint under hp.logdir, runs the graph's logits
    over the validation set in full batches, and appends a
    "<model name>\\t<spearman score>" line to validation_results.txt inside
    hp.results.
    """
    # Load graph
    g = Graph(is_training=False)
    print("Graph loaded")

    # Load data
    X, Y = load_data(mode="val")

    with g.graph.as_default():
        sv = tf.train.Supervisor()
        with sv.managed_session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
            # Restore parameters
            sv.saver.restore(sess, tf.train.latest_checkpoint(hp.logdir))
            print("Restored!")

            # Get model name: the text between the first pair of double
            # quotes in the checkpoint file. Opened in a `with` block so the
            # handle is closed instead of leaked.
            with open(hp.logdir + '/checkpoint', 'r') as ckpt_file:
                mname = ckpt_file.read().split('"')[1]

            # Inference
            if not os.path.exists(hp.results): os.mkdir(hp.results)
            with open(os.path.join(hp.results, "validation_results.txt"), 'a') as fout:
                expected, predicted = [], []
                # Integer division: a trailing partial batch is not evaluated.
                for step in range(len(X) // hp.batch_size):
                    x = X[step * hp.batch_size: (step + 1) * hp.batch_size]
                    y = Y[step * hp.batch_size: (step + 1) * hp.batch_size]

                    # predict intensities
                    logits = sess.run(g.logits, {g.x: x})

                    expected.extend(list(y))
                    predicted.extend(list(logits))

                # Get spearman coefficients
                score, _ = spearmanr(expected, predicted)
                fout.write("{}\t{}\n".format(mname, score))
classifier_tf.py 文件源码 项目:human-rl 作者: gsastry 项目源码 文件源码 阅读 47 收藏 0 点赞 0 评论 0
def __init__(self, checkpoint_file):
        """Load a saved classifier graph and its hyperparameters.

        Hyperparameters are read from "hparams.txt" next to the checkpoint
        when that file exists; otherwise TensorflowClassifierHparams is built
        with no arguments. The meta graph is imported into a fresh tf.Graph
        and restored in a session configured with zero GPU devices.

        Args:
            checkpoint_file: Checkpoint path; ".meta" is appended to locate
                the meta graph.
        """
        hparams_path = os.path.join(os.path.dirname(checkpoint_file), "hparams.txt")
        overrides = {}
        if os.path.isfile(hparams_path):
            with open(hparams_path) as hparams_fh:
                overrides = ast.literal_eval(hparams_fh.read())
        self.hparams = TensorflowClassifierHparams(**overrides)

        self.graph = tf.Graph()
        with self.graph.as_default():
            print("loading from file {}".format(checkpoint_file))
            session_config = tf.ConfigProto(device_count={'GPU': 0})
            session_config.gpu_options.visible_device_list = ""
            self.session = tf.Session(config=session_config)
            restorer = tf.train.import_meta_graph(checkpoint_file + ".meta", clear_devices=True)
            restorer.restore(self.session, checkpoint_file)

            # Input tensors are looked up only for the enabled features.
            self.features = {}
            feature_specs = (
                ("use_image", "image", "image:0"),
                ("use_observation", "observation", "observation:0"),
                ("use_action", "action", "action:0"),
            )
            for flag_name, key, tensor_name in feature_specs:
                if getattr(self.hparams, flag_name):
                    self.features[key] = self.graph.get_tensor_by_name(tensor_name)

            # Each output is the first entry of its named collection.
            self.prediction = tf.get_collection('prediction')[0]
            self.loss = tf.get_collection('loss')[0]
            self.threshold = tf.get_collection('threshold')[0]
classifier_tf.py 文件源码 项目:human-rl 作者: gsastry 项目源码 文件源码 阅读 29 收藏 0 点赞 0 评论 0
def __init__(self, checkpoint_file):
        """Load a saved classifier graph and its hyperparameters.

        Hyperparameters are read from "hparams.txt" next to the checkpoint
        when that file exists; otherwise TensorflowClassifierHparams is built
        with no arguments. The meta graph is imported into a fresh tf.Graph
        and restored in a session configured with zero GPU devices.

        Args:
            checkpoint_file: Checkpoint path; ".meta" is appended to locate
                the meta graph.
        """
        hparams_path = os.path.join(os.path.dirname(checkpoint_file), "hparams.txt")
        overrides = {}
        if os.path.isfile(hparams_path):
            with open(hparams_path) as hparams_fh:
                overrides = ast.literal_eval(hparams_fh.read())
        self.hparams = TensorflowClassifierHparams(**overrides)

        self.graph = tf.Graph()
        with self.graph.as_default():
            print("loading from file {}".format(checkpoint_file))
            session_config = tf.ConfigProto(device_count={'GPU': 0})
            session_config.gpu_options.visible_device_list = ""
            self.session = tf.Session(config=session_config)
            restorer = tf.train.import_meta_graph(checkpoint_file + ".meta", clear_devices=True)
            restorer.restore(self.session, checkpoint_file)

            # Input tensors are looked up only for the enabled features.
            self.features = {}
            feature_specs = (
                ("use_image", "image", "image:0"),
                ("use_observation", "observation", "observation:0"),
                ("use_action", "action", "action:0"),
            )
            for flag_name, key, tensor_name in feature_specs:
                if getattr(self.hparams, flag_name):
                    self.features[key] = self.graph.get_tensor_by_name(tensor_name)

            # Each output is the first entry of its named collection.
            self.prediction = tf.get_collection('prediction')[0]
            self.loss = tf.get_collection('loss')[0]
            self.threshold = tf.get_collection('threshold')[0]
classifier_tf.py 文件源码 项目:human-rl 作者: gsastry 项目源码 文件源码 阅读 34 收藏 0 点赞 0 评论 0
def __init__(self, checkpoint_file):
        """Load a saved classifier graph and its hyperparameters.

        Hyperparameters are read from "hparams.txt" next to the checkpoint
        when that file exists; otherwise TensorflowClassifierHparams is built
        with no arguments. The meta graph is imported into a fresh tf.Graph
        and restored in a session configured with zero GPU devices.

        Args:
            checkpoint_file: Checkpoint path; ".meta" is appended to locate
                the meta graph.
        """
        hparams_path = os.path.join(os.path.dirname(checkpoint_file), "hparams.txt")
        overrides = {}
        if os.path.isfile(hparams_path):
            with open(hparams_path) as hparams_fh:
                overrides = ast.literal_eval(hparams_fh.read())
        self.hparams = TensorflowClassifierHparams(**overrides)

        self.graph = tf.Graph()
        with self.graph.as_default():
            print("loading from file {}".format(checkpoint_file))
            session_config = tf.ConfigProto(device_count={'GPU': 0})
            session_config.gpu_options.visible_device_list = ""
            self.session = tf.Session(config=session_config)
            restorer = tf.train.import_meta_graph(checkpoint_file + ".meta", clear_devices=True)
            restorer.restore(self.session, checkpoint_file)

            # Input tensors are looked up only for the enabled features.
            self.features = {}
            feature_specs = (
                ("use_image", "image", "image:0"),
                ("use_observation", "observation", "observation:0"),
                ("use_action", "action", "action:0"),
            )
            for flag_name, key, tensor_name in feature_specs:
                if getattr(self.hparams, flag_name):
                    self.features[key] = self.graph.get_tensor_by_name(tensor_name)

            # Each output is the first entry of its named collection.
            self.prediction = tf.get_collection('prediction')[0]
            self.loss = tf.get_collection('loss')[0]
            self.threshold = tf.get_collection('threshold')[0]
classifier_tf.py 文件源码 项目:human-rl 作者: gsastry 项目源码 文件源码 阅读 37 收藏 0 点赞 0 评论 0
def __init__(self, checkpoint_file):
        """Load a saved classifier graph and its hyperparameters.

        Hyperparameters are read from "hparams.txt" next to the checkpoint
        when that file exists; otherwise TensorflowClassifierHparams is built
        with no arguments. The meta graph is imported into a fresh tf.Graph
        and restored in a session configured with zero GPU devices.

        Args:
            checkpoint_file: Checkpoint path; ".meta" is appended to locate
                the meta graph.
        """
        hparams_path = os.path.join(os.path.dirname(checkpoint_file), "hparams.txt")
        overrides = {}
        if os.path.isfile(hparams_path):
            with open(hparams_path) as hparams_fh:
                overrides = ast.literal_eval(hparams_fh.read())
        self.hparams = TensorflowClassifierHparams(**overrides)

        self.graph = tf.Graph()
        with self.graph.as_default():
            print("loading from file {}".format(checkpoint_file))
            session_config = tf.ConfigProto(device_count={'GPU': 0})
            session_config.gpu_options.visible_device_list = ""
            self.session = tf.Session(config=session_config)
            restorer = tf.train.import_meta_graph(checkpoint_file + ".meta", clear_devices=True)
            restorer.restore(self.session, checkpoint_file)

            # Input tensors are looked up only for the enabled features.
            self.features = {}
            feature_specs = (
                ("use_image", "image", "image:0"),
                ("use_observation", "observation", "observation:0"),
                ("use_action", "action", "action:0"),
            )
            for flag_name, key, tensor_name in feature_specs:
                if getattr(self.hparams, flag_name):
                    self.features[key] = self.graph.get_tensor_by_name(tensor_name)

            # Each output is the first entry of its named collection.
            self.prediction = tf.get_collection('prediction')[0]
            self.loss = tf.get_collection('loss')[0]
            self.threshold = tf.get_collection('threshold')[0]
tf_util.py 文件源码 项目:distributional_perspective_on_RL 作者: Kiwoo 项目源码 文件源码 阅读 27 收藏 0 点赞 0 评论 0
def single_threaded_session():
    """Return a tf.Session limited to one inter-op and one intra-op thread."""
    return tf.Session(config=tf.ConfigProto(
        inter_op_parallelism_threads=1,
        intra_op_parallelism_threads=1))


问题


面经


文章

微信
公众号

扫码关注公众号