test.py 文件源码

python
阅读 18 收藏 0 点赞 0 评论 0

项目:yaset 作者: jtourille 项目源码 文件源码
def read_and_decode_test(filename_queue, feature_columns):
    """
    Read and decode one example from a TFRecords file

    :param filename_queue: filename queue containing the TFRecords filenames
    :param feature_columns: iterable of feature column names; column ``col`` is
        read from the int64 sequence feature ``x_att_<col>``. Any iterable is
        accepted (it is materialized once internally).
    :return: list of tensors representing one example, in this order:
        the ``x_id`` context feature (string), the ``x_length`` context
        feature (int32), ``x_tokens`` (int32), ``x_chars`` reshaped to
        ``[x_length, -1]`` (int32), ``x_chars_len`` (int32), then one int32
        tensor per entry of ``feature_columns``, in iteration order
    """

    # The columns are walked twice (feature spec, then output list); freeze
    # them so a one-shot iterable such as a generator is not exhausted by the
    # first pass, which would silently drop the attribute tensors.
    feature_columns = list(feature_columns)

    with tf.device('/cpu:0'):

        # New TFRecord file
        reader = tf.TFRecordReader()
        _, serialized_example = reader.read(filename_queue)

        # Contextual TFRecords features (one value per example)
        context_features = {
            "x_length": tf.FixedLenFeature([], dtype=tf.int64),
            "x_id": tf.FixedLenFeature([], dtype=tf.string)
        }

        # Sequential TFRecords features
        sequence_features = {
            "x_tokens": tf.FixedLenSequenceFeature([], dtype=tf.int64),
            "x_chars": tf.FixedLenSequenceFeature([], dtype=tf.int64),
            "x_chars_len": tf.FixedLenSequenceFeature([], dtype=tf.int64),
        }

        for col in feature_columns:
            sequence_features["x_att_{}".format(col)] = tf.FixedLenSequenceFeature([], dtype=tf.int64)

        # Parsing contextual and sequential features
        context_parsed, sequence_parsed = tf.parse_single_sequence_example(
            serialized=serialized_example,
            context_features=context_features,
            sequence_features=sequence_features
        )

        # Cast once; the same int32 tensor drives the reshape below and is
        # returned as the example length.
        sequence_length = tf.cast(context_parsed["x_length"], tf.int32)

        # x_chars is parsed as a flat vector; reshape it to x_length rows with
        # an inferred row width. This presumably relies on every token's
        # characters being padded to one common width when the file is
        # written — TODO confirm against the TFRecords writer.
        chars = tf.reshape(sequence_parsed["x_chars"], tf.stack([sequence_length, -1]))

        # Preparing tensor list, casting values to 32 bits when necessary
        tensor_list = [
            context_parsed["x_id"],
            sequence_length,
            tf.cast(sequence_parsed["x_tokens"], dtype=tf.int32),
            tf.cast(chars, dtype=tf.int32),
            tf.cast(sequence_parsed["x_chars_len"], dtype=tf.int32),
        ]

        # Attribute columns follow, in the same order as feature_columns.
        tensor_list.extend(
            tf.cast(sequence_parsed["x_att_{}".format(col)], dtype=tf.int32)
            for col in feature_columns
        )

        return tensor_list
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号