def _build_input_fn(input_file_pattern, batch_size, mode):
  """Builds the input function.

  Args:
    input_file_pattern: The file pattern for the input examples.
    batch_size: Batch size.
    mode: The execution mode, as defined in tf.contrib.learn.ModeKeys.

  Returns:
    An input function that, when called, returns a tuple of 1) a dictionary
    mapping feature column names to tensors and 2) a tensor of target labels
    (None in INFER mode).
  """
  def _input_fn():
    """Supplies the input to the model.

    Returns:
      A tuple consisting of 1) a dictionary of tensors whose keys are
      the feature names, and 2) a tensor of target labels if the mode
      is not INFER (and None, otherwise).
    """
    logging.info("Reading files from %s", input_file_pattern)
    input_files = sorted(list(tf.gfile.Glob(input_file_pattern)))
    logging.info("Reading files from %s", input_files)
    include_target_column = (mode != tf.contrib.learn.ModeKeys.INFER)
    features_spec = tf.contrib.layers.create_feature_spec_for_parsing(
        feature_columns=_get_feature_columns(include_target_column))
    if FLAGS.use_gzip:
      # Read GZIP-compressed TFRecord files.
      def gzip_reader():
        return tf.TFRecordReader(
            options=tf.python_io.TFRecordOptions(
                compression_type=tf.python_io.TFRecordCompressionType.GZIP))
      reader_fn = gzip_reader
    else:
      reader_fn = tf.TFRecordReader
    features = tf.contrib.learn.io.read_batch_features(
        file_pattern=input_files,
        batch_size=batch_size,
        queue_capacity=3 * batch_size,
        randomize_input=(mode == tf.contrib.learn.ModeKeys.TRAIN),
        feature_queue_capacity=FLAGS.feature_queue_capacity,
        reader=reader_fn,
        features=features_spec)
    target = None
    if include_target_column:
      target = features.pop(FLAGS.target_field)
    return features, target

  return _input_fn
Source file: variants_inference.py
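Below is a minimal usage sketch, not taken from variants_inference.py: it wires the returned input_fn into a stock tf.contrib.learn estimator purely for illustration. The file pattern, batch size, hidden units, and the choice of DNNClassifier are assumptions; _get_feature_columns and the FLAGS referenced above are assumed to be defined elsewhere in the module.

import tensorflow as tf

# Hypothetical file pattern and batch size, for illustration only.
train_input_fn = _build_input_fn(
    input_file_pattern="gs://my-bucket/examples/train-*.tfrecord",
    batch_size=64,
    mode=tf.contrib.learn.ModeKeys.TRAIN)

# DNNClassifier stands in for whatever estimator the module actually builds.
estimator = tf.contrib.learn.DNNClassifier(
    feature_columns=_get_feature_columns(include_target_column=False),
    hidden_units=[128, 64])

# fit() calls the input_fn repeatedly to feed batches into the model.
estimator.fit(input_fn=train_input_fn, steps=1000)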