def test_keyed_features_filter(self):
    """Tests that `filter_fn` removes examples from keyed example batches.

    Six JSON-encoded example lines with "age" values [2, 0, 1, 0, 3, 5] are
    read back for a single epoch in batches of `batch_size` (2) through a
    filter that keeps only ages < 2. The surviving lines are 2, 3 and 4
    (keys use 1-based line numbers, e.g. "<filename>:2" for `lines[1]`).
    The first batch yields two of them and the second batch yields the last
    one. The next read then raises `OutOfRangeError`.
    """
    # NOTE(review): presumably undoes a Glob stub installed elsewhere in this
    # test class so the real temp file below is matched -- confirm in setUp.
    gfile.Glob = self._orig_glob
    lines = [
        '{"features": {"feature": {"age": {"int64_list": {"value": [2]}}}}}',
        '{"features": {"feature": {"age": {"int64_list": {"value": [0]}}}}}',
        '{"features": {"feature": {"age": {"int64_list": {"value": [1]}}}}}',
        '{"features": {"feature": {"age": {"int64_list": {"value": [0]}}}}}',
        '{"features": {"feature": {"age": {"int64_list": {"value": [3]}}}}}',
        '{"features": {"feature": {"age": {"int64_list": {"value": [5]}}}}}'
    ]
    filename = self._create_temp_file("\n".join(lines))
    batch_size = 2
    queue_capacity = 4
    name = "my_batch"
    features = {"age": parsing_ops.FixedLenFeature([], dtypes_lib.int64)}

    def filter_fn(keys, examples_json):
        """Returns a boolean tensor that is True where an example's age < 2."""
        del keys  # The filter depends only on the example payloads.
        serialized = parsing_ops.decode_json_example(examples_json)
        examples = parsing_ops.parse_example(serialized, features)
        return math_ops.less(examples["age"], 2)

    with ops.Graph().as_default() as g, self.test_session(graph=g) as session:
        keys, inputs = graph_io._read_keyed_batch_examples_helper(
            filename,
            batch_size,
            reader=io_ops.TextLineReader,
            randomize_input=False,  # Fixed order so exact keys can be asserted.
            num_epochs=1,
            read_batch_size=batch_size,
            queue_capacity=queue_capacity,
            filter_fn=filter_fn,
            name=name)
        # The leading dimension is not known statically; the second batch
        # below contains only one element.
        self.assertAllEqual((None,), keys.get_shape().as_list())
        self.assertAllEqual((None,), inputs.get_shape().as_list())
        # NOTE(review): presumably initializes the local epoch counter used
        # by num_epochs=1 -- confirm.
        session.run(variables.local_variables_initializer())

        coord = coordinator.Coordinator()
        threads = queue_runner_impl.start_queue_runners(session, coord=coord)
        try:
            file_prefix = filename.encode("utf-8")

            # First batch of two filtered examples.
            out_keys, out_vals = session.run((keys, inputs))
            self.assertAllEqual([file_prefix + b":2", file_prefix + b":3"],
                                out_keys)
            self.assertAllEqual(
                [lines[1].encode("utf-8"), lines[2].encode("utf-8")],
                out_vals)

            # Second batch will only have one filtered example as that's the
            # only remaining example that satisfies the filtering criterion.
            out_keys, out_vals = session.run((keys, inputs))
            self.assertAllEqual([file_prefix + b":4"], out_keys)
            self.assertAllEqual([lines[3].encode("utf-8")], out_vals)

            # Exhausted input.
            with self.assertRaises(errors.OutOfRangeError):
                session.run((keys, inputs))
        finally:
            # Always stop and join the queue-runner threads, even when an
            # assertion above fails, so they are not left running after the
            # test body exits.
            coord.request_stop()
            coord.join(threads)
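# Scraped page residue (translated from Chinese):
# graph_io_test.py source code
# python
# Reads: 23
# Bookmarks: 0
# Likes: 0
# Comments: 0
# Comment list
# Table of contents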