Example source code for Python's Example() class (tf.train.Example)

example_proto_coder.py (project: transform, author: tensorflow)
def encode(self, instance):
    """Encode a tf.transform encoded dict as serialized tf.Example."""
    if self._encode_example_cache is None:
      # Initialize the encode Example cache (used by this and all subsequent
      # calls to encode).
      example = tf.train.Example()
      for feature_handler in self._feature_handlers:
        feature_handler.initialize_encode_cache(example)
      self._encode_example_cache = example

    # Encode and serialize using the Example cache.
    for feature_handler in self._feature_handlers:
      value = instance[feature_handler.name]
      try:
        feature_handler.encode_value(value)
      except TypeError as e:
        raise TypeError('%s while encoding feature "%s"' %
                        (e, feature_handler.name))

    return self._encode_example_cache.SerializeToString()
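
As a point of reference, a minimal standalone sketch of the same build-then-serialize pattern on the raw proto API (the feature names and values are hypothetical, not the coder's actual schema):

import tensorflow as tf

# Hypothetical instance; keys and value types are illustrative only.
instance = {"age": [29], "name": [b"alice"]}

example = tf.train.Example()
example.features.feature["age"].int64_list.value.extend(instance["age"])
example.features.feature["name"].bytes_list.value.extend(instance["name"])
serialized = example.SerializeToString()  # bytes, ready for a TFRecord file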
train_eval_base.py (project: easy-tensorflow, author: khanhptnk)
def __init__(self, model, loss_fn, data_path, log_dir, graph, input_reader):
    """Initialize a `TrainEvalBase` object.
      Args:
        model: an instance of a subclass of the `ModelBase` class (defined in
          `model_base.py`).
        loss_fn: a tensorflow op, a loss function for training a model. See:
            https://www.tensorflow.org/code/tensorflow/contrib/losses/python/losses/loss_ops.py
          for a list of available loss functions.
        data_path: a string, path to files of tf.Example protos containing data.
        log_dir: a string, logging directory.
        graph: a tensorflow computation graph.
        input_reader: an instance of a subclass of the `InputReaderBase` class
          (defined in `input_reader_base.py`).
    """
    self._data_path = data_path
    self._log_dir = log_dir
    self._train_log_dir = os.path.join(self._log_dir, "train")
    self._eval_log_dir = os.path.join(self._log_dir, "eval")

    self._model = model
    self._loss_fn = loss_fn
    self._graph = graph
    self._input_reader = input_reader

    self._summary_ops = []
dataset_schema.py (project: transform, author: tensorflow)
def as_feature_spec(self, column):
    ind = self.index_fields
    if len(ind) != 1 or len(column.axes) != 1:
      raise ValueError('tf.Example parser supports only 1-d sparse features.')
    index = ind[0]

    if column.domain.dtype not in _TF_EXAMPLE_ALLOWED_TYPES:
      raise ValueError('tf.Example parser supports only types {}, so it is '
                       'invalid to generate a feature_spec with type '
                       '{}.'.format(
                           _TF_EXAMPLE_ALLOWED_TYPES,
                           repr(column.domain.dtype)))

    return tf.SparseFeature(index.name,
                            self._value_field_name,
                            column.domain.dtype,
                            column.axes[0].size,
                            index.is_sorted)
mappers.py (project: transform, author: tensorflow)
def segment_indices(segment_ids, name=None):
  """Returns a `Tensor` of indices within each segment.

  segment_ids should be a sequence of non-decreasing non-negative integers that
  define a set of segments, e.g. [0, 0, 1, 2, 2, 2] defines 3 segments of length
  2, 1 and 3.  The return value is a `Tensor` containing the indices within each
  segment.

  Example input: [0, 0, 1, 2, 2, 2]
  Example output: [0, 1, 0, 0, 1, 2]

  Args:
    segment_ids: A 1-d `Tensor` containing a non-decreasing sequence of
        non-negative integers with type `tf.int32` or `tf.int64`.
    name: (Optional) A name for this operation.

  Returns:
    A `Tensor` containing the indices within each segment.
  """
  with tf.name_scope(name, 'segment_indices'):
    segment_lengths = tf.segment_sum(tf.ones_like(segment_ids), segment_ids)
    segment_starts = tf.gather(tf.concat([[0], tf.cumsum(segment_lengths)], 0),
                               segment_ids)
    return (tf.range(tf.size(segment_ids, out_type=segment_ids.dtype)) -
            segment_starts)
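
Assuming segment_indices above is in scope, a quick check of the documented example in a TF1 session:

import tensorflow as tf

segment_ids = tf.constant([0, 0, 1, 2, 2, 2], dtype=tf.int64)
indices = segment_indices(segment_ids)

with tf.Session() as sess:
  print(sess.run(indices))  # expected: [0 1 0 0 1 2]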
vocab_batcher.py (project: text2text, author: google)
def _GetExFeatureText(self, example, key):
    """Extracts text for a feature from tf.Example.

    Args:
      example: tf.Example.
      key: Key of the feature to be extracted.

    Returns:
      A list of UTF-8-decoded text values for the feature.
    """

    values = []
    for value in example.features.feature[key].bytes_list.value:
      values.append(value.decode("utf-8"))

    return values
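
For illustration, building a tf.Example by hand and reading the feature back with the same loop (the feature name is hypothetical):

import tensorflow as tf

example = tf.train.Example()
example.features.feature["title"].bytes_list.value.append(b"hello world")

values = [v.decode("utf-8")
          for v in example.features.feature["title"].bytes_list.value]
print(values)  # ['hello world']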
copynet_batcher.py (project: text2text, author: google)
def _GetExFeatureText(self, example, key):
    """Extracts text for a feature from tf.Example.

    Args:
      example: tf.Example.
      key: Key of the feature to be extracted.

    Returns:
      A list of UTF-8-decoded text values for the feature.
    """

    values = []
    for value in example.features.feature[key].bytes_list.value:
      values.append(value.decode("utf-8"))

    return values
texttfrecords.py (project: tefla, author: openAGI)
def to_example(self, dictionary):
        """Helper: build tf.Example from (string -> int/float/str list) dictionary."""
        features = {}
        for (k, v) in six.iteritems(dictionary):
            if not v:
                raise ValueError("Empty generated field: %s", str((k, v)))
            if isinstance(v[0], six.integer_types):
                features[k] = tf.train.Feature(
                    int64_list=tf.train.Int64List(value=v))
            elif isinstance(v[0], float):
                features[k] = tf.train.Feature(
                    float_list=tf.train.FloatList(value=v))
            elif isinstance(v[0], six.string_types):
                if not six.PY2:  # Convert in python 3.
                    v = [bytes(x, "utf-8") for x in v]
                features[k] = tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=v))
            elif isinstance(v[0], bytes):
                features[k] = tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=v))
            else:
                raise ValueError("Value for %s is not a recognized type; v: %s type: %s" %
                                 (k, str(v[0]), str(type(v[0]))))
        return tf.train.Example(features=tf.train.Features(feature=features))
task.py (project: kaggle-youtube-8m, author: liufuyang)
def get_placeholder_input_fn(config, model_type, vocab_sizes, use_crosses):
  """Wrap the get input features function to provide the metadata."""

  def get_input_features():
    """Read the input features from the given placeholder."""
    columns = feature_columns(config, model_type, vocab_sizes, use_crosses)
    feature_spec = tf.contrib.layers.create_feature_spec_for_parsing(columns)

    # Add a dense feature for the keys; use '' if not present in the tf.Example proto.
    feature_spec[KEY_FEATURE_COLUMN] = tf.FixedLenFeature(
        [1], dtype=tf.string, default_value='')

    # Add a placeholder for the serialized tf.Example proto input.
    examples = tf.placeholder(tf.string, shape=(None,))

    features = tf.parse_example(examples, feature_spec)
    # Pass the input tensor so it can be used for export.
    features[EXAMPLES_PLACEHOLDER_KEY] = examples
    return features, None

  # Return a function to input the features into the model from a placeholder.
  return get_input_features
02_tfrecord_example.py (project: tf_oreilly, author: chiphuyen)
def read_from_tfrecord(filenames):
    tfrecord_file_queue = tf.train.string_input_producer(filenames, name='queue')
    reader = tf.TFRecordReader()
    _, tfrecord_serialized = reader.read(tfrecord_file_queue)

    # label and image are stored as bytes but could be stored as 
    # int64 or float64 values in a serialized tf.Example protobuf.
    tfrecord_features = tf.parse_single_example(tfrecord_serialized,
                        features={
                            'label': tf.FixedLenFeature([], tf.int64),
                            'shape': tf.FixedLenFeature([], tf.string),
                            'image': tf.FixedLenFeature([], tf.string),
                        }, name='features')
    # image was saved as uint8, so we have to decode as uint8.
    image = tf.decode_raw(tfrecord_features['image'], tf.uint8)
    shape = tf.decode_raw(tfrecord_features['shape'], tf.int32)
    # the image tensor is flattened out, so we have to reconstruct the shape
    image = tf.reshape(image, shape)
    label = tfrecord_features['label']
    return label, shape, image
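
A usage sketch, assuming a TFRecord file written with the same 'label'/'shape'/'image' keys (the path is hypothetical); TF1 queue runners must be started explicitly:

import tensorflow as tf

label, shape, image = read_from_tfrecord(["data.tfrecord"])

with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    label_val, image_val = sess.run([label, image])
    coord.request_stop()
    coord.join(threads)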
generator_utils.py (project: tensor2tensor, author: tensorflow)
def to_example(dictionary):
  """Helper: build tf.Example from (string -> int/float/str list) dictionary."""
  features = {}
  for (k, v) in six.iteritems(dictionary):
    if not v:
      raise ValueError("Empty generated field: %s", str((k, v)))
    if isinstance(v[0], six.integer_types):
      features[k] = tf.train.Feature(int64_list=tf.train.Int64List(value=v))
    elif isinstance(v[0], float):
      features[k] = tf.train.Feature(float_list=tf.train.FloatList(value=v))
    elif isinstance(v[0], six.string_types):
      if not six.PY2:  # Convert in python 3.
        v = [bytes(x, "utf-8") for x in v]
      features[k] = tf.train.Feature(bytes_list=tf.train.BytesList(value=v))
    elif isinstance(v[0], bytes):
      features[k] = tf.train.Feature(bytes_list=tf.train.BytesList(value=v))
    else:
      raise ValueError("Value for %s is not a recognized type; v: %s type: %s" %
                       (k, str(v[0]), str(type(v[0]))))
  return tf.train.Example(features=tf.train.Features(feature=features))
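
Assuming the to_example above is in scope, a minimal write path (field names and output path are hypothetical):

import tensorflow as tf

example = to_example({"inputs": [1, 2, 3], "score": [0.5], "tag": ["news"]})
with tf.python_io.TFRecordWriter("out.tfrecord") as writer:
  writer.write(example.SerializeToString())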
cluster_measurements.py (project: scalable_analytics, author: broadinstitute)
def _predict_input_fn():
  """Supplies the input to the model.

  Returns:
    A tuple consisting of 1) a dictionary of tensors whose keys are
    the feature names, and 2) a tensor of target labels which for
    clustering must be 'None'.
  """

  # Add a placeholder for the serialized tf.Example proto input.
  examples = tf.placeholder(tf.string, shape=(None,), name="examples")

  raw_features = tf.parse_example(examples, _get_feature_columns())

  dense = _raw_features_to_dense_tensor(raw_features)

  return input_fn_utils.InputFnOps(
      features={DENSE_KEY: dense},
      labels=None,
      default_inputs={EXAMPLE_KEY: examples})
preprocess_measurements.py (project: scalable_analytics, author: broadinstitute)
def measurements_to_examples(input_data):
  """Converts sparse measurements to TensorFlow Example protos.

  Args:
    input_data: dictionary objects with keys from
      DATA_QUERY_REPLACEMENTS

  Returns:
    TensorFlow Example protos.
  """
  meas_kvs = input_data | 'BucketMeasurements' >> beam.Map(
      lambda row: (row[SAMPLE_COLUMN], row))

  sample_meas_kvs = meas_kvs | 'GroupBySample' >> beam.GroupByKey()

  examples = (
      sample_meas_kvs
      | 'SamplesToExamples' >> beam.Map(
          # Tuple-unpacking lambdas are Python 2 only; index the pair instead.
          lambda kv: sample_measurements_to_example(kv[0], kv[1])))

  return examples
variants_inference.py (project: cloudml-examples, author: googlegenomics)
def _predict_input_fn():
  """Supplies the input to the model.

  Returns:
    A tuple consisting of 1) a dictionary of tensors whose keys are
    the feature names, and 2) a tensor of target labels if the mode
    is not INFER (and None, otherwise).
  """
  feature_spec = tf.contrib.layers.create_feature_spec_for_parsing(
      feature_columns=_get_feature_columns(include_target_column=False))

  feature_spec[FLAGS.id_field] = tf.FixedLenFeature([], dtype=tf.string)
  feature_spec[FLAGS.target_field + "_string"] = tf.FixedLenFeature(
      [], dtype=tf.string)

  # Add a placeholder for the serialized tf.Example proto input.
  examples = tf.placeholder(tf.string, shape=(None,), name="examples")

  features = tf.parse_example(examples, feature_spec)
  features[PREDICTION_KEY] = features[FLAGS.id_field]

  inputs = {PREDICTION_EXAMPLES: examples}

  return input_fn_utils.InputFnOps(
      features=features, labels=None, default_inputs=inputs)
create_fashion_tf_record.py (project: CRF-image-segmentation, author: therealnidhin)
def get_class_name_from_filename(file_name):
  """Gets the class name from a file.

  Args:
    file_name: The file name to get the class name from.
               ie. "american_pit_bull_terrier_105.jpg"

  Returns:
    example: The converted tf.Example.
  """
  match = re.match(r'([A-Za-z_]+)(-[0-9]+\.jpg)', file_name, re.I)
  return match.groups()[0]
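
With pet-style file names (class name, underscore, index), the helper reduces to a regex group extraction:

print(get_class_name_from_filename("american_pit_bull_terrier_105.jpg"))
# -> american_pit_bull_terrier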
convert_to_tf_example_cifar10.py (project: easy-tensorflow, author: khanhptnk)
def convert(data, filename):
  images = data["data"]
  labels = data["labels"]
  num_examples = images.shape[0]
  with tf.python_io.TFRecordWriter(filename) as writer:
    for i in range(num_examples):
      logging.info("Writing example %d/%d", i, num_examples)
      image = [int(x) for x in images[i, :]]
      label = labels[i]
      example = tf.train.Example()
      features_map = example.features.feature
      features_map["image"].int64_list.value.extend(list(image))
      features_map["label"].int64_list.value.append(label)
      writer.write(example.SerializeToString())
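
To sanity-check the output, the records can be read back with the TF1 record iterator (the path is hypothetical):

import tensorflow as tf

for record in tf.python_io.tf_record_iterator("cifar10_train.tfrecord"):
  example = tf.train.Example()
  example.ParseFromString(record)
  label = example.features.feature["label"].int64_list.value[0]
  pixels = example.features.feature["image"].int64_list.value
  break  # inspect only the first record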
train_eval_base.py (project: easy-tensorflow, author: khanhptnk)
def _load_data(self):
    """Load data from files of tf.Example protos."""
    keys, examples = self._input_reader.read_input(
        self._data_path,
        self._config.batch_size,
        randomize_input=self._model.is_training,
        distort_inputs=self._model.is_training)

    self._observations = examples["decoded_observation"]
    self._labels = examples["decoded_label"]
reader.py (project: magenta, author: tensorflow)
def get_example(self, batch_size):
    """Get a single example from the tfrecord file.

    Args:
      batch_size: Int, minibatch size.

    Returns:
      tf.Example protobuf parsed from tfrecord.
    """
    reader = tf.TFRecordReader()
    num_epochs = None if self.is_training else 1
    capacity = batch_size
    path_queue = tf.train.input_producer(
        [self.record_path],
        num_epochs=num_epochs,
        shuffle=self.is_training,
        capacity=capacity)
    unused_key, serialized_example = reader.read(path_queue)
    features = {
        "note_str": tf.FixedLenFeature([], dtype=tf.string),
        "pitch": tf.FixedLenFeature([1], dtype=tf.int64),
        "velocity": tf.FixedLenFeature([1], dtype=tf.int64),
        "audio": tf.FixedLenFeature([64000], dtype=tf.float32),
        "qualities": tf.FixedLenFeature([10], dtype=tf.int64),
        "instrument_source": tf.FixedLenFeature([1], dtype=tf.int64),
        "instrument_family": tf.FixedLenFeature([1], dtype=tf.int64),
    }
    example = tf.parse_single_example(serialized_example, features)
    return example
example_proto_coder.py (project: transform, author: tensorflow)
def encode_value(self, values):
    """Encodes a feature into its Example proto representation."""
    del self._value[:]
    if self._rank == 0:
      self._value.append(self._cast_fn(values))
    else:
      flattened_values = (values if self._rank == 1 else
                          np.asarray(values).reshape(-1))
      if len(flattened_values) != self._size:
        raise ValueError('FixedLenFeature %r got wrong number of values. '
                         'Expected %d but got %d' %
                         (self._name, self._size, len(flattened_values)))
      self._value.extend(self._cast_fn(flattened_values))
example_proto_coder.py (project: transform, author: tensorflow)
def decode(self, serialized_example_proto):
    """Decode serialized tf.Example as a tf.transform encoded dict."""
    if self._decode_example_cache is None:
      # Initialize the decode Example cache (used by this and all subsequent
      # calls to decode).
      self._decode_example_cache = tf.train.Example()

    example = self._decode_example_cache
    example.ParseFromString(serialized_example_proto)
    feature_map = example.features.feature
    return {feature_handler.name: feature_handler.parse_value(feature_map)
            for feature_handler in self._feature_handlers}
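
The cache-and-ParseFromString pattern above boils down to the following round trip on the raw proto (the feature name is hypothetical):

import tensorflow as tf

original = tf.train.Example()
original.features.feature["x"].float_list.value.append(1.5)
serialized = original.SerializeToString()

reused = tf.train.Example()        # reused across calls, like the cache above
reused.ParseFromString(serialized)
print(reused.features.feature["x"].float_list.value[:])  # [1.5]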
dataset_schema.py (project: transform, author: tensorflow)
def as_feature_spec(self, column):
    if not column.is_fixed_size():
      raise ValueError('A column of unknown size cannot be represented as '
                       'fixed-size.')
    if column.domain.dtype not in _TF_EXAMPLE_ALLOWED_TYPES:
      raise ValueError('tf.Example parser supports only types {}, so it is '
                       'invalid to generate a feature_spec with type '
                       '{}.'.format(
                           _TF_EXAMPLE_ALLOWED_TYPES,
                           repr(column.domain.dtype)))
    return tf.FixedLenFeature(column.tf_shape().as_list(),
                              column.domain.dtype,
                              self.default_value)
dataset_schema_test.py (project: transform, author: tensorflow)
def test_feature_spec_unsupported_dtype(self):
    schema = sch.Schema()
    schema.column_schemas['fixed_float_with_default'] = (
        sch.ColumnSchema(tf.float64, [1], sch.FixedColumnRepresentation(0.0)))

    with self.assertRaisesRegexp(ValueError,
                                 'tf.Example parser supports only types '
                                 r'\[tf.string, tf.int64, tf.float32, tf.bool\]'
                                 ', so it is invalid to generate a feature_spec'
                                 ' with type tf.float64.'):
      schema.as_feature_spec()
create_coco_tf_record.py (project: tensorflow_object_detection_create_coco_tfrecord, author: MetaPeak)
def dict_to_coco_example(img_data):
    """Convert python dictionary formath data of one image to tf.Example proto.
    Args:
        img_data: infomation of one image, inclue bounding box, labels of bounding box,\
            height, width, encoded pixel data.
    Returns:
        example: The converted tf.Example
    """
    bboxes = img_data['bboxes']
    xmin, xmax, ymin, ymax = [], [], [], []
    for bbox in bboxes:
        xmin.append(bbox[0])
        xmax.append(bbox[0] + bbox[2])
        ymin.append(bbox[1])
        ymax.append(bbox[1] + bbox[3])

    example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': dataset_util.int64_feature(img_data['height']),
        'image/width': dataset_util.int64_feature(img_data['width']),
        'image/object/bbox/xmin': dataset_util.float_list_feature(xmin),
        'image/object/bbox/xmax': dataset_util.float_list_feature(xmax),
        'image/object/bbox/ymin': dataset_util.float_list_feature(ymin),
        'image/object/bbox/ymax': dataset_util.float_list_feature(ymax),
        'image/object/class/label': dataset_util.int64_list_feature(img_data['labels']),
        'image/encoded': dataset_util.bytes_feature(img_data['pixel_data']),
        'image/format': dataset_util.bytes_feature('jpeg'.encode('utf-8')),
    }))
    return example
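
A call sketch; dataset_util is assumed to be the TensorFlow Object Detection API helper module, and the dict contents are illustrative only:

img_data = {
    'height': 480,
    'width': 640,
    'bboxes': [[10, 20, 100, 50]],  # [x, y, width, height] per box
    'labels': [1],
    'pixel_data': open('example.jpg', 'rb').read(),  # hypothetical JPEG file
}
example = dict_to_coco_example(img_data)
serialized = example.SerializeToString()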
batch_reader.py (project: FYP-AutoTextSum, author: MrRexZ)
def _TextGenerator(self, example_gen):
    """Generates article and abstract text from tf.Example."""
    while True:
      e = next(example_gen)
      try:
        article_text = self._GetExFeatureText(e, self._article_key)
        abstract_text = self._GetExFeatureText(e, self._abstract_key)
      except ValueError:
        tf.logging.error('Failed to get article or abstract from example')
        continue

      yield (article_text, abstract_text)
batch_reader.py (project: FYP-AutoTextSum, author: MrRexZ)
def _GetExFeatureText(self, ex, key):
    """Extract text for a feature from td.Example.

    Args:
      ex: tf.Example.
      key: key of the feature to be extracted.
    Returns:
      feature: a feature text extracted.
    """
    return ex.features.feature[key].bytes_list.value[0]
vocab_batcher.py (project: text2text, author: google)
def __init__(self, data_path, config):
    """Batcher initializer.

    Args:
      data_path: tf.Example filepattern.
      config: model hyperparameters.
    """
    self._data_path = data_path
    self._config = config
    self._input_vocab = config.input_vocab
    self._output_vocab = config.output_vocab
    self._source_key = config.source_key
    self._target_key = config.target_key
    self.use_bucketing = config.use_bucketing
    self._truncate_input = config.truncate_input
    self._input_queue = queue.Queue(QUEUE_NUM_BATCH * config.batch_size)
    self._bucket_input_queue = queue.Queue(QUEUE_NUM_BATCH)
    self._input_threads = []
    for _ in range(DAEMON_READER_THREADS):
      self._input_threads.append(Thread(target=self._FillInputQueue))
      self._input_threads[-1].daemon = True
      self._input_threads[-1].start()
    self._bucketing_threads = []
    for _ in range(BUCKETING_THREADS):
      self._bucketing_threads.append(Thread(target=self._FillBucketInputQueue))
      self._bucketing_threads[-1].daemon = True
      self._bucketing_threads[-1].start()

    self._watch_thread = Thread(target=self._WatchThreads)
    self._watch_thread.daemon = True
    self._watch_thread.start()
copynet_batcher.py (project: text2text, author: google)
def __init__(self, data_path, config):
    """Batcher initializer.

    Args:
      data_path: tf.Example filepattern.
      config: model hyperparameters.
    """
    self._data_path = data_path
    self._config = config
    self._input_vocab = config.input_vocab
    self._output_vocab = config.output_vocab
    self._source_key = config.source_key
    self._target_key = config.target_key
    self.use_bucketing = config.use_bucketing
    self._truncate_input = config.truncate_input
    self._input_queue = queue.Queue(QUEUE_NUM_BATCH * config.batch_size)
    self._bucket_input_queue = queue.Queue(QUEUE_NUM_BATCH)
    self._input_threads = []
    for _ in range(DAEMON_READER_THREADS):
      self._input_threads.append(Thread(target=self._FillInputQueue))
      self._input_threads[-1].daemon = True
      self._input_threads[-1].start()
    self._bucketing_threads = []
    for _ in range(BUCKETING_THREADS):
      self._bucketing_threads.append(Thread(target=self._FillBucketInputQueue))
      self._bucketing_threads[-1].daemon = True
      self._bucketing_threads[-1].start()

    self._watch_thread = Thread(target=self._WatchThreads)
    self._watch_thread.daemon = True
    self._watch_thread.start()
utils.py (project: TensorflowFramework, author: vahidk)
def parallel_record_writer(iterator, create_example, path, num_threads=4):
  """Create a RecordIO file from data for efficient reading."""

  def _queue(inputs):
    for item in iterator:
      inputs.put(item)
    for _ in range(num_threads):
      inputs.put(None)

  def _map_fn(inputs, outputs):
    while True:
      item = inputs.get()
      if item is None:
        break
      example = create_example(item)
      outputs.put(example)
    outputs.put(None)

  # Read the inputs.
  inputs = mp.Queue()
  mp.Process(target=_queue, args=(inputs,)).start()

  # Convert to tf.Example
  outputs = mp.Queue()
  for _ in range(num_threads):
    mp.Process(target=_map_fn, args=(inputs, outputs)).start()

  # Write the output to file.
  writer = tf.python_io.TFRecordWriter(path)
  counter = 0
  while True:
    example = outputs.get()
    if example is None:
      counter += 1
      if counter == num_threads:
        break
      else:
        continue
    writer.write(example.SerializeToString())
  writer.close()
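
A usage sketch, assuming fork-based multiprocessing so the generator and converter survive in the child processes (the converter, data, and path are hypothetical):

import tensorflow as tf

def make_example(item):
  # Hypothetical converter: item is a dict of int lists.
  example = tf.train.Example()
  for key, values in item.items():
    example.features.feature[key].int64_list.value.extend(values)
  return example

data = ({"x": [i], "y": [i * i]} for i in range(100))
parallel_record_writer(data, make_example, "out.tfrecord", num_threads=2)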
texttfrecords.py (project: tefla, author: openAGI)
def generate_files(self, generator, output_filenames, max_cases=None):
        """Generate cases from a generator and save as TFRecord files.

        Generated cases are transformed to tf.Example protos and saved as TFRecords
        in sharded files named output_dir/output_name-00..N-of-00..M, where
        M = num_shards.

        Args:
          generator: a generator yielding (string -> int/float/str list) dictionaries.
          output_filenames: List of output file paths.
          max_cases: maximum number of cases to get from the generator;
            if None (default), we use the generator until StopIteration is raised.
        """
        num_shards = len(output_filenames)
        writers = [tf.python_io.TFRecordWriter(
            fname) for fname in output_filenames]
        counter, shard = 0, 0
        for case in generator:
            if counter > 0 and counter % 100000 == 0:
                tf.logging.info("Generating case %d." % counter)
            counter += 1
            if max_cases and counter > max_cases:
                break
            sequence_example = self.to_example(case)
            writers[shard].write(sequence_example.SerializeToString())
            shard = (shard + 1) % num_shards

        for writer in writers:
            writer.close()
batch_reader.py (project: savchenko, author: JuleLaryushina)
def _TextGenerator(self, example_gen):
    """Generates article and abstract text from tf.Example."""
    while True:
      e = next(example_gen)
      try:
        article_text = self._GetExFeatureText(e, self._article_key)
        abstract_text = self._GetExFeatureText(e, self._abstract_key)
      except ValueError:
        tf.logging.error('Failed to get article or abstract from example')
        continue

      yield (article_text, abstract_text)
batch_reader.py (project: savchenko, author: JuleLaryushina)
def _GetExFeatureText(self, ex, key):
    """Extract text for a feature from td.Example.

    Args:
      ex: tf.Example.
      key: key of the feature to be extracted.
    Returns:
      feature: a feature text extracted.
    """
    return ex.features.feature[key].bytes_list.value[0]

