def read_and_decode(filename_queue, feature_columns):
    """Read and decode one serialized example from a TFRecords file.

    :param filename_queue: filename queue containing the TFRecords filenames
        (as produced by ``tf.train.string_input_producer``-style queues)
    :param feature_columns: list of feature column names; each name ``col``
        is read from the sequential feature keyed ``"x_att_<col>"``
    :return: list of tensors representing one example, in order:
        [x_id (string), x_length (int32 scalar), x_tokens (int32),
         x_chars (int32, reshaped to [sequence_length, -1]),
         x_chars_len (int32), y (int32), then one int32 tensor per
         entry of ``feature_columns``]
    """
    # Pin reading/parsing to the CPU so the input pipeline does not
    # compete with model computation for accelerator memory.
    with tf.device('/cpu:0'):
        # New TFRecord reader attached to the shared filename queue
        reader = tf.TFRecordReader()
        _, serialized_example = reader.read(filename_queue)

        # Contextual (per-example, fixed-length) TFRecords features
        context_features = {
            "x_length": tf.FixedLenFeature([], dtype=tf.int64),
            "x_id": tf.FixedLenFeature([], dtype=tf.string)
        }

        # Sequential (per-timestep) TFRecords features
        sequence_features = {
            "x_tokens": tf.FixedLenSequenceFeature([], dtype=tf.int64),
            "x_chars": tf.FixedLenSequenceFeature([], dtype=tf.int64),
            "x_chars_len": tf.FixedLenSequenceFeature([], dtype=tf.int64),
            "y": tf.FixedLenSequenceFeature([], dtype=tf.int64),
        }
        for col in feature_columns:
            sequence_features["x_att_{}".format(col)] = tf.FixedLenSequenceFeature([], dtype=tf.int64)

        # Parsing contextual and sequential features from the raw record
        context_parsed, sequence_parsed = tf.parse_single_sequence_example(
            serialized=serialized_example,
            context_features=context_features,
            sequence_features=sequence_features
        )

        # Cast the sequence length once and reuse it: it is needed both as
        # an output tensor and to reshape the flattened character tensor
        # (the original code cast context_parsed["x_length"] twice).
        sequence_length = tf.cast(context_parsed["x_length"], tf.int32)
        # x_chars is stored flattened; restore its per-token 2-D layout.
        # Presumably the trailing dim is chars-per-token — inferred from the
        # [sequence_length, -1] reshape; confirm against the record writer.
        chars = tf.reshape(sequence_parsed["x_chars"], tf.stack([sequence_length, -1]))

        # Preparing tensor list, casting int64 values down to 32 bits
        tensor_list = [
            context_parsed["x_id"],
            sequence_length,
            tf.cast(sequence_parsed["x_tokens"], dtype=tf.int32),
            tf.cast(chars, dtype=tf.int32),
            tf.cast(sequence_parsed["x_chars_len"], dtype=tf.int32),
            tf.cast(sequence_parsed["y"], dtype=tf.int32)
        ]
        for col in feature_columns:
            tensor_list.append(tf.cast(sequence_parsed["x_att_{}".format(col)], dtype=tf.int32))

        return tensor_list
# (non-code page residue removed from executable scope: "评论列表" / "文章目录"
#  — "comment list" / "table of contents" scraped from a webpage)