Python usage examples of as_bytes()
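All of the snippets below call as_bytes from tensorflow.python.util.compat (also exposed as tf.compat.as_bytes), which encodes a text string to bytes using UTF-8 and passes bytes-like input through unchanged. A minimal sketch of that behaviour:

from tensorflow.python.util import compat

compat.as_bytes("hello")    # b'hello' -- a str is UTF-8 encoded
compat.as_bytes(b"hello")   # b'hello' -- bytes pass through unchanged
compat.as_bytes(u"路径")     # the UTF-8 encoding of the unicode string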

session_bundle_test.py (project: lsdc, author: febert)
def testBasic(self):
    base_path = tf.test.test_src_dir_path(
        "contrib/session_bundle/example/half_plus_two/00000123")
    tf.reset_default_graph()
    sess, meta_graph_def = session_bundle.load_session_bundle_from_path(
        base_path, target="", config=tf.ConfigProto(device_count={"CPU": 2}))

    self.assertTrue(sess)
    asset_path = os.path.join(base_path, constants.ASSETS_DIRECTORY)
    with sess.as_default():
      path1, path2 = sess.run(["filename1:0", "filename2:0"])
      self.assertEqual(
          compat.as_bytes(os.path.join(asset_path, "hello1.txt")), path1)
      self.assertEqual(
          compat.as_bytes(os.path.join(asset_path, "hello2.txt")), path2)

      collection_def = meta_graph_def.collection_def

      signatures_any = collection_def[constants.SIGNATURES_KEY].any_list.value
      self.assertEquals(len(signatures_any), 1)

      signatures = manifest_pb2.Signatures()
      signatures_any[0].Unpack(signatures)
      self._checkRegressionSignature(signatures, sess)
      self._checkNamedSigantures(signatures, sess)
async_adder.py (project: stuff, author: yaroslavvb)
def main():
  global writer
  config = load_config()

  # todo: factor out common logic
  logdir = os.environ["LOGDIR"]
  writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(logdir+'/events'))

  if config.task_type == 'worker':
    run_worker()
  elif config.task_type == 'ps':
    run_ps()
  else:
    assert False, "Unknown task type "+str(config.task_type)

  writer.Close()
saved_model_half_plus_two.py (project: taas-examples, author: caicloud)
def _write_assets(assets_directory, assets_filename):
  """Writes the asset file used by the half_plus_two example model.

  Args:
    - assets_directory: directory in which the asset file is written.
    - assets_filename: file name of the asset.
  Returns:
    The path of the written asset file.
  """
  if not file_io.file_exists(assets_directory):
    file_io.recursive_create_dir(assets_directory)

  path = os.path.join(
    compat.as_bytes(assets_directory),
    compat.as_bytes(assets_filename))
  file_io.write_string_to_file(path, "asset-file-contents")
  return path
experiment.py (project: DeepLearning_VirtualReality_BigData_Project, author: rashmitripathi)
def _maybe_export(self, eval_result):  # pylint: disable=unused-argument
    """Export the Estimator using export_fn, if defined."""
    export_dir_base = os.path.join(
        compat.as_bytes(self._estimator.model_dir),
        compat.as_bytes("export"))

    export_results = []
    for strategy in self._export_strategies:
      # TODO(soergel): possibly, allow users to decide whether to export here
      # based on the eval_result (e.g., to keep the best export).

      export_results.append(
          strategy.export(
              self._estimator,
              os.path.join(
                  compat.as_bytes(export_dir_base),
                  compat.as_bytes(strategy.name))))

    return export_results
saved_model_export_utils.py (project: DeepLearning_VirtualReality_BigData_Project, author: rashmitripathi)
def get_timestamped_export_dir(export_dir_base):
  """Builds a path to a new subdirectory within the base directory.

  Each export is written into a new subdirectory named using the
  current time.  This guarantees monotonically increasing version
  numbers even across multiple runs of the pipeline.
  The timestamp used is the number of seconds since epoch UTC.

  Args:
    export_dir_base: A string containing a directory to write the exported
        graph and checkpoints.
  Returns:
    The full path of the new subdirectory (which is not actually created yet).
  """
  export_timestamp = int(time.time())

  export_dir = os.path.join(
      compat.as_bytes(export_dir_base),
      compat.as_bytes(str(export_timestamp)))
  return export_dir


# create a simple parser that pulls the export_version from the directory.
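The closing comment hints at a helper that recovers the version from the timestamped directory name. A minimal sketch of such a parser (the function name is hypothetical, not part of the library):

def _parse_export_version(export_dir):
  # Hypothetical helper: the timestamped subdirectory name is the version number.
  return int(os.path.basename(compat.as_text(export_dir)))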
saved_model_export_utils_test.py (project: DeepLearning_VirtualReality_BigData_Project, author: rashmitripathi)
def test_get_most_recent_export(self):
    export_dir_base = tempfile.mkdtemp() + "export/"
    gfile.MkDir(export_dir_base)
    _create_test_export_dir(export_dir_base)
    _create_test_export_dir(export_dir_base)
    _create_test_export_dir(export_dir_base)
    export_dir_4 = _create_test_export_dir(export_dir_base)

    (most_recent_export_dir, most_recent_export_version) = (
        saved_model_export_utils.get_most_recent_export(export_dir_base))

    self.assertEqual(compat.as_bytes(export_dir_4),
                     compat.as_bytes(most_recent_export_dir))
    self.assertEqual(compat.as_bytes(export_dir_4),
                     os.path.join(compat.as_bytes(export_dir_base),
                                  compat.as_bytes(
                                      str(most_recent_export_version))))
visualize_graph.py (project: yolov2, author: datlife)
def visualize_graph_in_tfboard(filename, output='./log'):
    with tf.Session() as sess:
        model_filename = filename
        with gfile.FastGFile(model_filename, 'rb') as f:
            data = compat.as_bytes(f.read())
            sm = saved_model_pb2.SavedModel()
            sm.ParseFromString(data)
            if 1 != len(sm.meta_graphs):
                print('More than one graph found. Not sure which to write')
                sys.exit(1)

            g_in = tf.import_graph_def(sm.meta_graphs[0].graph_def)

        train_writer = tf.summary.FileWriter(output)
        train_writer.add_graph(sess.graph)
        print("Please execute `tensorboard --logdir {}` to view graph".format(output))
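A hedged invocation of the function above (the SavedModel file name is a placeholder):

# Write the graph of a SavedModel protobuf to ./log so TensorBoard can display it.
visualize_graph_in_tfboard('saved_model.pb', output='./log')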
logger.py (project: distributional_perspective_on_RL, author: Kiwoo)
def __init__(self, dir):
        os.makedirs(dir, exist_ok=True)
        self.dir = dir
        self.step = 1
        prefix = 'events'
        path = osp.join(osp.abspath(dir), prefix)
        import tensorflow as tf
        from tensorflow.python import pywrap_tensorflow        
        from tensorflow.core.util import event_pb2
        from tensorflow.python.util import compat
        self.tf = tf
        self.event_pb2 = event_pb2
        self.pywrap_tensorflow = pywrap_tensorflow
        self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path))
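These loggers only construct the EventsWriter; building and writing the event happens elsewhere. A hedged sketch of how such a writer is typically fed in TF 1.x (the log directory, tag and value are placeholders):

import time

from tensorflow.core.framework import summary_pb2
from tensorflow.core.util import event_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.util import compat

writer = pywrap_tensorflow.EventsWriter(compat.as_bytes('/tmp/logdir/events'))
summary = summary_pb2.Summary(
    value=[summary_pb2.Summary.Value(tag='loss', simple_value=0.5)])
event = event_pb2.Event(wall_time=time.time(), step=1, summary=summary)
writer.WriteEvent(event)  # append the event to the events file
writer.Flush()
writer.Close()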
logger.py (project: baselines, author: openai)
def __init__(self, dir):
        os.makedirs(dir, exist_ok=True)
        self.dir = dir
        self.step = 1
        prefix = 'events'
        path = osp.join(osp.abspath(dir), prefix)
        import tensorflow as tf
        from tensorflow.python import pywrap_tensorflow
        from tensorflow.core.util import event_pb2
        from tensorflow.python.util import compat
        self.tf = tf
        self.event_pb2 = event_pb2
        self.pywrap_tensorflow = pywrap_tensorflow
        self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path))
spark_session.py (project: tensoronspark, author: liangfengsid)
def reset(target, containers=None, config=None):
        if target is not None:
            target = compat.as_bytes(target)
        if containers is not None:
            containers = [compat.as_bytes(c) for c in containers]
        else:
            containers = []
        tf_session.TF_Reset(target, containers, config)
tf_Session.py (project: LIE, author: EmbraceLife)
def _name_list(tensor_list):
  """Utility function for transitioning to the new session API.

  Args:
    tensor_list: a list of `Tensor`s.

  Returns:
    A list of each `Tensor`s name (as byte arrays).
  """
  return [compat.as_bytes(t.name) for t in tensor_list]
tf_Session.py (project: LIE, author: EmbraceLife)
def reset(target, containers=None, config=None):
    """Resets resource containers on `target`, and close all connected sessions.

    A resource container is distributed across all workers in the
    same cluster as `target`.  When a resource container on `target`
    is reset, resources associated with that container will be cleared.
    In particular, all Variables in the container will become undefined:
    they lose their values and shapes.

    NOTE:
    (i) reset() is currently only implemented for distributed sessions.
    (ii) Any sessions on the master named by `target` will be closed.

    If no resource containers are provided, all containers are reset.

    Args:
      target: The execution engine to connect to.
      containers: A list of resource container name strings, or `None` if all of
        all the containers are to be reset.
      config: (Optional.) Protocol buffer with configuration options.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        resetting containers.
    """
    if target is not None:
      target = compat.as_bytes(target)
    if containers is not None:
      containers = [compat.as_bytes(c) for c in containers]
    else:
      containers = []
    tf_session.TF_Reset(target, containers, config)
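The same logic is exposed publicly as tf.Session.reset in TF 1.x; a hedged call for illustration (the master address and container name are placeholders):

# Clear one named resource container on a remote master.
tf.Session.reset("grpc://worker0.example.com:2222", containers=["experiment0"])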
session_bundle_test.py (project: lsdc, author: febert)
def testBasic(self):
    base_path = tf.test.test_src_dir_path(
        "contrib/session_bundle/example/half_plus_two/00000123")
    tf.reset_default_graph()
    sess, meta_graph_def = session_bundle.load_session_bundle_from_path(
        base_path, target="", config=tf.ConfigProto(device_count={"CPU": 2}))

    self.assertTrue(sess)
    asset_path = os.path.join(base_path, constants.ASSETS_DIRECTORY)
    with sess.as_default():
      path1, path2 = sess.run(["filename1:0", "filename2:0"])
      self.assertEqual(
          compat.as_bytes(os.path.join(asset_path, "hello1.txt")), path1)
      self.assertEqual(
          compat.as_bytes(os.path.join(asset_path, "hello2.txt")), path2)

      collection_def = meta_graph_def.collection_def

      signatures_any = collection_def[constants.SIGNATURES_KEY].any_list.value
      self.assertEquals(len(signatures_any), 1)

      signatures = manifest_pb2.Signatures()
      signatures_any[0].Unpack(signatures)
      default_signature = signatures.default_signature
      input_name = default_signature.regression_signature.input.tensor_name
      output_name = default_signature.regression_signature.output.tensor_name
      y = sess.run([output_name], {input_name: np.array([[0], [1], [2], [3]])})
      # The operation is y = 0.5 * x + 2
      self.assertEqual(y[0][0], 2)
      self.assertEqual(y[0][1], 2.5)
      self.assertEqual(y[0][2], 3)
      self.assertEqual(y[0][3], 3.5)
exporter.py (project: lsdc, author: febert)
def gfile_copy_callback(files_to_copy, export_dir_path):
  """Callback to copy files using `gfile.Copy` to an export directory.

  This method is used as the default `assets_callback` in `Exporter.init` to
  copy assets from the `assets_collection`. It can also be invoked directly to
  copy additional supplementary files into the export directory (in which case
  it is not a callback).

  Args:
    files_to_copy: A dictionary that maps original file paths to desired
      basename in the export directory.
    export_dir_path: Directory to copy the files to.
  """
  logging.info("Write assets into: %s using gfile_copy.", export_dir_path)
  gfile.MakeDirs(export_dir_path)
  for source_filepath, basename in files_to_copy.items():
    new_path = os.path.join(
        compat.as_bytes(export_dir_path), compat.as_bytes(basename))
    logging.info("Copying asset %s to path %s.", source_filepath, new_path)

    if gfile.Exists(new_path):
      # Guard against being restarted while copying assets, and the file
      # existing and being in an unknown state.
      # TODO(b/28676216): Do some file checks before deleting.
      logging.info("Removing file %s.", new_path)
      gfile.Remove(new_path)
    gfile.Copy(source_filepath, new_path)
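A hedged usage sketch of the callback above (both paths are placeholders):

# Copy one supplementary vocabulary file into the export directory's assets folder.
gfile_copy_callback({'/tmp/vocab.txt': 'vocab.txt'}, '/tmp/my_model/export/assets')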
exporter.py (project: lsdc, author: febert)
def gfile_copy_callback(files_to_copy, export_dir_path):
  """Callback to copy files using `gfile.Copy` to an export directory.

  This method is used as the default `assets_callback` in `Exporter.init` to
  copy assets from the `assets_collection`. It can also be invoked directly to
  copy additional supplementary files into the export directory (in which case
  it is not a callback).

  Args:
    files_to_copy: A dictionary that maps original file paths to desired
      basename in the export directory.
    export_dir_path: Directory to copy the files to.
  """
  logging.info("Write assets into: %s using gfile_copy.", export_dir_path)
  gfile.MakeDirs(export_dir_path)
  for source_filepath, basename in files_to_copy.items():
    new_path = os.path.join(
        compat.as_bytes(export_dir_path), compat.as_bytes(basename))
    logging.info("Copying asset %s to path %s.", source_filepath, new_path)

    if gfile.Exists(new_path):
      # Guard against being restarted while copying assets, and the file
      # existing and being in an unknown state.
      # TODO(b/28676216): Do some file checks before deleting.
      logging.info("Removing file %s.", new_path)
      gfile.Remove(new_path)
    gfile.Copy(source_filepath, new_path)
benchmark_grpc_recv.py (project: stuff, author: yaroslavvb)
def run_benchmark(sess, init_op, add_op):
  """Returns MB/s rate of addition."""


  logdir=FLAGS.logdir_prefix+'/'+FLAGS.name
  os.system('mkdir -p '+logdir)

  # TODO: make events follow same format as eager writer
  writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(logdir+'/events'))
  filename = compat.as_text(writer.FileName())
  training_util.get_or_create_global_step()

  sess.run(init_op)

  for step in range(FLAGS.iters):
    start_time = time.time()
    for i in range(FLAGS.iters_per_step):
      sess.run(add_op.op)

    elapsed_time = time.time() - start_time
    rate = float(FLAGS.iters)*FLAGS.data_mb/elapsed_time
    event = make_event('rate', rate, step)
    writer.WriteEvent(event)
    writer.Flush()
  writer.Close()
  # add event
tabular_logger.py (project: ray, author: ray-project)
def __init__(self, dir, prefix):
        self.dir = dir
        # Start at 1, because EvWriter automatically generates an object with
        # step = 0.
        self.step = 1
        self.evwriter = pywrap_tensorflow.EventsWriter(
            compat.as_bytes(os.path.join(dir, prefix)))
tabular_logger.py (project: evolution-strategies-starter, author: openai)
def __init__(self, dir, prefix):
        self.dir = dir
        self.step = 1 # Start at 1, because EvWriter automatically generates an object with step=0
        self.evwriter = pywrap_tensorflow.EventsWriter(compat.as_bytes(os.path.join(dir, prefix)))
logger.py (project: rl-teacher, author: nottombrown)
def __init__(self, dir):
        os.makedirs(dir, exist_ok=True)
        self.dir = dir
        self.step = 1
        prefix = 'events'
        path = osp.join(osp.abspath(dir), prefix)
        import tensorflow as tf
        from tensorflow.python import pywrap_tensorflow        
        from tensorflow.core.util import event_pb2
        from tensorflow.python.util import compat
        self.tf = tf
        self.event_pb2 = event_pb2
        self.pywrap_tensorflow = pywrap_tensorflow
        self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path))
logger_openai.py (project: gym-sandbox, author: suqi)
def __init__(self, dir):
        os.makedirs(dir, exist_ok=True)
        self.dir = dir
        self.step = 1
        prefix = 'events'
        path = osp.join(osp.abspath(dir), prefix)
        import tensorflow as tf
        from tensorflow.python import pywrap_tensorflow        
        from tensorflow.core.util import event_pb2
        from tensorflow.python.util import compat
        self.tf = tf
        self.event_pb2 = event_pb2
        self.pywrap_tensorflow = pywrap_tensorflow
        self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path))
experiment_test.py (project: DeepLearning_VirtualReality_BigData_Project, author: rashmitripathi)
def export_savedmodel(self, export_dir_base, serving_input_fn, **kwargs):
    tf_logging.info('export_savedmodel called with args: %s, %s, %s' %
                    (export_dir_base, serving_input_fn, kwargs))
    self.export_count += 1
    return os.path.join(
        compat.as_bytes(export_dir_base), compat.as_bytes('bogus_timestamp'))
saved_transform_io_test.py (project: transform, author: tensorflow)
def test_stale_asset_collections_are_cleaned(self):
    vocabulary_file = os.path.join(
        compat.as_bytes(test.get_temp_dir()), compat.as_bytes('asset'))
    file_io.write_string_to_file(vocabulary_file, 'foo bar baz')

    export_path = os.path.join(tempfile.mkdtemp(), 'export')

    # create a SavedModel including assets
    with tf.Graph().as_default():
      with tf.Session().as_default() as session:
        input_string = tf.placeholder(tf.string)
        # Map string through a table loaded from an asset file
        table = lookup.index_table_from_file(
            vocabulary_file, num_oov_buckets=12, default_value=12)
        output = table.lookup(input_string)
        inputs = {'input': input_string}
        outputs = {'output': output}
        saved_transform_io.write_saved_transform_from_session(
            session, inputs, outputs, export_path)

    # Load it and save it again repeatedly, verifying that the asset collections
    # remain valid.
    for _ in [1, 2, 3]:
      with tf.Graph().as_default() as g:
        with tf.Session().as_default() as session:
          input_string = tf.constant('dog')
          inputs = {'input': input_string}
          outputs = saved_transform_io.apply_saved_transform(export_path,
                                                             inputs)

          self.assertEqual(
              1, len(g.get_collection(ops.GraphKeys.ASSET_FILEPATHS)))
          self.assertEqual(
              0, len(g.get_collection(tf.saved_model.constants.ASSETS_KEY)))

          # Check that every ASSET_FILEPATHS refers to a Tensor in the graph.
          # If not, get_tensor_by_name() raises KeyError.
          for asset_path in g.get_collection(ops.GraphKeys.ASSET_FILEPATHS):
            tensor_name = asset_path.name
            g.get_tensor_by_name(tensor_name)

          export_path = os.path.join(tempfile.mkdtemp(), 'export')
          saved_transform_io.write_saved_transform_from_session(
              session, inputs, outputs, export_path)
tf_Session.py (project: LIE, author: EmbraceLife)
def _do_run(self, handle, target_list, fetch_list, feed_dict,
              options, run_metadata):
    """Runs a step based on the given fetches and feeds.

    Args:
      handle: a handle for partial_run. None if this is just a call to run().
      target_list: A list of operations to be run, but not fetched.
      fetch_list: A list of tensors to be fetched.
      feed_dict: A dictionary that maps tensors to numpy ndarrays.
      options: A (pointer to a) [`RunOptions`] protocol buffer, or None
      run_metadata: A (pointer to a) [`RunMetadata`] protocol buffer, or None

    Returns:
      A list of numpy ndarrays, corresponding to the elements of
      `fetch_list`.  If the ith element of `fetch_list` contains the
      name of an operation, the first Tensor output of that operation
      will be returned for that element.

    Raises:
      tf.errors.OpError: Or one of its subclasses on error.
    """
    if self._created_with_new_api:
      # pylint: disable=protected-access
      feeds = dict((t._as_tf_output(), v) for t, v in feed_dict.items())
      fetches = [t._as_tf_output() for t in fetch_list]
      targets = [op._c_op for op in target_list]
      # pylint: enable=protected-access
    else:
      feeds = dict((compat.as_bytes(t.name), v) for t, v in feed_dict.items())
      fetches = _name_list(fetch_list)
      targets = _name_list(target_list)

    def _run_fn(session, feed_dict, fetch_list, target_list, options,
                run_metadata):
      # Ensure any changes to the graph are reflected in the runtime.
      self._extend_graph()
      with errors.raise_exception_on_not_ok_status() as status:
        if self._created_with_new_api:
          return tf_session.TF_SessionRun_wrapper(
              session, options, feed_dict, fetch_list, target_list,
              run_metadata, status)
        else:
          return tf_session.TF_Run(session, options,
                                   feed_dict, fetch_list, target_list,
                                   status, run_metadata)

    def _prun_fn(session, handle, feed_dict, fetch_list):
      assert not self._created_with_new_api, ('Partial runs don\'t work with '
                                              'C API')
      if target_list:
        raise RuntimeError('partial_run() requires empty target_list.')
      with errors.raise_exception_on_not_ok_status() as status:
        return tf_session.TF_PRun(session, handle, feed_dict, fetch_list,
                                  status)

    if handle is None:
      return self._do_call(_run_fn, self._session, feeds, fetches, targets,
                           options, run_metadata)
    else:
      return self._do_call(_prun_fn, self._session, handle, feeds, fetches)
transferLearningV2.py (project: PlantImageRecognition, author: HeavenMin)
def createImageLists(imageDir, testingPercentage, validationPercventage):
    if not gfile.Exists(imageDir):
        print("Image dir '" + imageDir + "' not found.")
        return None
    result = {}
    subDirs = [x[0] for x in gfile.Walk(imageDir)]
    isRootDir = True
    for subDir in subDirs:
        if isRootDir:
            isRootDir = False
            continue
        extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']
        fileList = []
        dirName = os.path.basename(subDir)
        if dirName == imageDir:
            continue
        print("Looking for images in '" + dirName + "'")
        for extension in extensions:
            fileGlob = os.path.join(imageDir, dirName, '*.' + extension)
            fileList.extend(gfile.Glob(fileGlob))
        if not fileList:
            print('No file found')
            continue
        labelName = re.sub(r'[^a-z0-9]+', ' ', dirName.lower())
        trainingImages = []
        testingImages =[]
        validationImages = []
        for fileName in fileList:
            baseName = os.path.basename(fileName)
            hashName = re.sub(r'_nohash_.*$', '', fileName)
            hashNameHased = hashlib.sha1(compat.as_bytes(hashName)).hexdigest()
            percentHash = ((int(hashNameHased, 16) %
                            (MAX_NUM_IMAGES_PER_CLASS + 1)) *
                            (100.0 / MAX_NUM_IMAGES_PER_CLASS))
            if percentHash < validationPercventage:
                validationImages.append(baseName)
            elif percentHash < (testingPercentage + validationPercventage):
                testingImages.append(baseName)
            else:
                trainingImages.append(baseName)
        result[labelName] = {
            'dir': dirName,
            'training': trainingImages,
            'testing': testingImages,
            'validation': validationImages,
        }
    return result
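The split above is deterministic: the SHA-1 digest of the file name (after stripping any _nohash_ suffix) is mapped to a percentage, so the same image always lands in the same training, testing, or validation bucket across runs. A self-contained sketch of that mapping (the file name and the constant's value are assumptions for illustration):

import hashlib
import re

from tensorflow.python.util import compat

MAX_NUM_IMAGES_PER_CLASS = 2 ** 27 - 1  # assumed value, as in TensorFlow's retraining example

hash_name = re.sub(r'_nohash_.*$', '', 'rose_001.jpg')  # placeholder file name
digest = hashlib.sha1(compat.as_bytes(hash_name)).hexdigest()
percent = (int(digest, 16) % (MAX_NUM_IMAGES_PER_CLASS + 1)) * (100.0 / MAX_NUM_IMAGES_PER_CLASS)
# percent is stable across runs; it is compared against the validation and testing thresholds.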
transferLearningV3.py (project: PlantImageRecognition, author: HeavenMin)
def createImageLists(imageDir, testingPercentage, validationPercventage):
    if not gfile.Exists(imageDir):
        print("Image dir '" + imageDir + "' not found.")
        return None
    result = {}
    subDirs = [x[0] for x in gfile.Walk(imageDir)]
    isRootDir = True
    for subDir in subDirs:
        if isRootDir:
            isRootDir = False
            continue
        extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']
        fileList = []
        dirName = os.path.basename(subDir)
        if dirName == imageDir:
            continue
        print("Looking for images in '" + dirName + "'")
        for extension in extensions:
            fileGlob = os.path.join(imageDir, dirName, '*.' + extension)
            fileList.extend(gfile.Glob(fileGlob))
        if not fileList:
            print('No file found')
            continue
        labelName = re.sub(r'[^a-z0-9]+', ' ', dirName.lower())
        trainingImages = []
        testingImages =[]
        validationImages = []
        for fileName in fileList:
            baseName = os.path.basename(fileName)
            hashName = re.sub(r'_nohash_.*$', '', fileName)
            hashNameHased = hashlib.sha1(compat.as_bytes(hashName)).hexdigest()
            percentHash = ((int(hashNameHased, 16) %
                            (MAX_NUM_IMAGES_PER_CLASS + 1)) *
                            (100.0 / MAX_NUM_IMAGES_PER_CLASS))
            if percentHash < validationPercventage:
                validationImages.append(baseName)
            elif percentHash < (testingPercentage + validationPercventage):
                testingImages.append(baseName)
            else:
                trainingImages.append(baseName)
        result[labelName] = {
            'dir': dirName,
            'training': trainingImages,
            'testing': testingImages,
            'validation': validationImages,
        }
    return result
Second_Purification.py (project: PlantImageRecognition, author: HeavenMin)
def createImageLists(imageDir, testingPercentage, validationPercventage):
    if not gfile.Exists(imageDir):
        print("Image dir '" + imageDir + "' not found.")
        return None
    result = {}
    subDirs = [x[0] for x in gfile.Walk(imageDir)]
    isRootDir = True
    for subDir in subDirs:
        if isRootDir:
            isRootDir = False
            continue
        extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']
        fileList = []
        dirName = os.path.basename(subDir)
        if dirName == imageDir:
            continue
        print("Looking for images in '" + dirName + "'")
        for extension in extensions:
            fileGlob = os.path.join(imageDir, dirName, '*.' + extension)
            fileList.extend(gfile.Glob(fileGlob))
        if not fileList:
            print('No file found')
            continue
        labelName = re.sub(r'[^a-z0-9]+', ' ', dirName.lower())
        trainingImages = []
        testingImages =[]
        validationImages = []
        for fileName in fileList:
            baseName = os.path.basename(fileName)
            hashName = re.sub(r'_nohash_.*$', '', fileName)
            hashNameHased = hashlib.sha1(compat.as_bytes(hashName)).hexdigest()
            percentHash = ((int(hashNameHased, 16) %
                            (MAX_NUM_IMAGES_PER_CLASS + 1)) *
                            (100.0 / MAX_NUM_IMAGES_PER_CLASS))
            if percentHash < validationPercventage:
                validationImages.append(baseName)
            elif percentHash < (testingPercentage + validationPercventage):
                testingImages.append(baseName)
            else:
                trainingImages.append(baseName)
        result[labelName] = {
            'dir': dirName,
            'training': trainingImages,
            'testing': testingImages,
            'validation': validationImages,
        }
    return result
First_Purification.py (project: PlantImageRecognition, author: HeavenMin)
def createImageLists(imageDir, testingPercentage, validationPercventage):
    if not gfile.Exists(imageDir):
        print("Image dir '" + imageDir + "' not found.")
        return None
    result = {}
    subDirs = [x[0] for x in gfile.Walk(imageDir)]
    isRootDir = True
    for subDir in subDirs:
        if isRootDir:
            isRootDir = False
            continue
        extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']
        fileList = []
        dirName = os.path.basename(subDir)
        if dirName == imageDir:
            continue
        print("Looking for images in '" + dirName + "'")
        for extension in extensions:
            fileGlob = os.path.join(imageDir, dirName, '*.' + extension)
            fileList.extend(gfile.Glob(fileGlob))
        if not fileList:
            print('No file found')
            continue
        labelName = re.sub(r'[^a-z0-9]+', ' ', dirName.lower())
        trainingImages = []
        testingImages =[]
        validationImages = []
        for fileName in fileList:
            baseName = os.path.basename(fileName)
            hashName = re.sub(r'_nohash_.*$', '', fileName)
            hashNameHased = hashlib.sha1(compat.as_bytes(hashName)).hexdigest()
            percentHash = ((int(hashNameHased, 16) %
                            (MAX_NUM_IMAGES_PER_CLASS + 1)) *
                            (100.0 / MAX_NUM_IMAGES_PER_CLASS))
            if percentHash < validationPercventage:
                validationImages.append(baseName)
            elif percentHash < (testingPercentage + validationPercventage):
                testingImages.append(baseName)
            else:
                trainingImages.append(baseName)
        result[labelName] = {
            'dir': dirName,
            'training': trainingImages,
            'testing': testingImages,
            'validation': validationImages,
        }
    return result
Train_Test.py (project: PlantImageRecognition, author: HeavenMin)
def createImageLists(imageDir, testingPercentage, validationPercventage):
    if not gfile.Exists(imageDir):
        print("Image dir '" + imageDir + "' not found.")
        return None
    result = {}
    subDirs = [x[0] for x in gfile.Walk(imageDir)]
    isRootDir = True
    for subDir in subDirs:
        if isRootDir:
            isRootDir = False
            continue
        extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']
        fileList = []
        dirName = os.path.basename(subDir)
        if dirName == imageDir:
            continue
        print("Looking for images in '" + dirName + "'")
        for extension in extensions:
            fileGlob = os.path.join(imageDir, dirName, '*.' + extension)
            fileList.extend(gfile.Glob(fileGlob))
        if not fileList:
            print('No file found')
            continue
        labelName = re.sub(r'[^a-z0-9]+', ' ', dirName.lower())
        trainingImages = []
        testingImages =[]
        validationImages = []
        for fileName in fileList:
            baseName = os.path.basename(fileName)
            hashName = re.sub(r'_nohash_.*$', '', fileName)
            hashNameHased = hashlib.sha1(compat.as_bytes(hashName)).hexdigest()
            percentHash = ((int(hashNameHased, 16) %
                            (MAX_NUM_IMAGES_PER_CLASS + 1)) *
                            (100.0 / MAX_NUM_IMAGES_PER_CLASS))
            if percentHash < validationPercventage:
                validationImages.append(baseName)
            elif percentHash < (testingPercentage + validationPercventage):
                testingImages.append(baseName)
            else:
                trainingImages.append(baseName)
        result[labelName] = {
            'dir': dirName,
            'training': trainingImages,
            'testing': testingImages,
            'validation': validationImages,
        }
    return result
saver.py (project: Machine-Learning, author: sfeng15)
def _add_collection_def(meta_graph_def, key):
  """Adds a collection to MetaGraphDef protocol buffer.

  Args:
    meta_graph_def: MetaGraphDef protocol buffer.
    key: One of the GraphKeys or user-defined string.
  """
  if not isinstance(key, six.string_types) and not isinstance(key, bytes):
    logging.warning("Only collections with string type keys will be "
                    "serialized. This key has %s" % type(key))
    return
  collection_list = ops.get_collection(key)
  if not collection_list:
    return
  try:
    col_def = meta_graph_def.collection_def[key]
    to_proto = ops.get_to_proto_function(key)
    proto_type = ops.get_collection_proto_type(key)
    if to_proto:
      kind = "bytes_list"
      for x in collection_list:
        # Additional type check to make sure the returned proto is indeed
        # what we expect.
        proto = to_proto(x)
        assert isinstance(proto, proto_type)
        getattr(col_def, kind).value.append(proto.SerializeToString())
    else:
      kind = _get_kind_name(collection_list[0])
      if kind == "node_list":
        getattr(col_def, kind).value.extend([x.name for x in collection_list])
      elif kind == "bytes_list":
        # NOTE(opensource): This force conversion is to work around the fact
        # that Python3 distinguishes between bytes and strings.
        getattr(col_def, kind).value.extend(
            [compat.as_bytes(x) for x in collection_list])
      else:
        getattr(col_def, kind).value.extend([x for x in collection_list])
  except Exception as e:  # pylint: disable=broad-except
    logging.warning("Error encountered when serializing %s.\n"
                    "Type is unsupported, or the types of the items don't "
                    "match field type in CollectionDef.\n%s" % (key, str(e)))
    if key in meta_graph_def.collection_def:
      del meta_graph_def.collection_def[key]
    return

