feature_column_test.py 文件源码

python
阅读 51 收藏 0 点赞 0 评论 0

项目:lsdc 作者: febert 项目源码 文件源码
def testCreateSequenceFeatureSpec(self):
    """Checks the sequence parsing spec built for a mix of column types.

    Plain hashed, embedded, keyed and weighted sparse columns are expected
    to map to `VarLenFeature`s (the weights as float32, the rest as
    string). Real-valued columns are expected to map to
    `FixedLenSequenceFeature`s, with `allow_missing=True` only for the
    column constructed with a `default_value`.
    """
    layers = tf.contrib.layers

    hashed = layers.sparse_column_with_hash_bucket(
        "sparse_column", hash_bucket_size=100)
    embedded_source = layers.sparse_column_with_hash_bucket(
        "sparse_column_for_embedding", hash_bucket_size=10)
    embedded = layers.embedding_column(embedded_source, dimension=4)
    keyed_ids = layers.sparse_column_with_keys(
        "id_column", ["marlo", "omar", "stringer"])
    weighted_ids = layers.weighted_sparse_column(keyed_ids, "id_weights_column")
    dense_2d = layers.real_valued_column("real_valued_column", dimension=2)
    dense_5d_with_default = layers.real_valued_column(
        "real_valued_default_column", dimension=5, default_value=3.0)

    columns = {hashed, embedded, weighted_ids, dense_2d, dense_5d_with_default}
    actual = fc._create_sequence_feature_spec_for_parsing(columns)

    # String-valued sparse inputs are all expected as variable-length features.
    expected = {
        name: tf.VarLenFeature(tf.string)
        for name in ("sparse_column", "sparse_column_for_embedding", "id_column")
    }
    # The weight column of a weighted sparse column carries float32 values.
    expected["id_weights_column"] = tf.VarLenFeature(tf.float32)
    # Real-valued columns become fixed-length sequences; only the column that
    # was given a default_value is expected to have allow_missing=True.
    expected["real_valued_column"] = tf.FixedLenSequenceFeature(
        shape=[2], dtype=tf.float32, allow_missing=False)
    expected["real_valued_default_column"] = tf.FixedLenSequenceFeature(
        shape=[5], dtype=tf.float32, allow_missing=True)

    self.assertDictEqual(expected, actual)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号