estimator_utils_test.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:lsdc 作者: febert 项目源码 文件源码
def test_to_feature_columns_and_input_fn(self):
    """Check the feature columns and input_fn built from the 3-layer test frame.

    "a" and "b" are requested directly as features, "f" is a transformed
    series computed from base inputs "c" and "d", and "g" is the label key.
    The test verifies the returned feature columns, the base features
    produced by the input_fn, the labels it produces, and the column count.
    """
    frame = setup_test_df_3layer()
    defaults = {"a": 1, "b": 2, "c": 3, "d": 4}
    columns, make_inputs = estimator_utils.to_feature_columns_and_input_fn(
        frame,
        base_input_keys_with_defaults=defaults,
        label_keys=["g"],
        feature_keys=["a", "b", "f"])

    # Expected base series. The third FixedLenFeature argument (default
    # value) matches the corresponding entry in `defaults` for "a" and "c".
    series_a = learn.PredefinedSeries(
        "a", tf.FixedLenFeature(tensor_shape.unknown_shape(), tf.int32, 1))
    series_b = learn.PredefinedSeries("b", tf.VarLenFeature(tf.int32))
    series_c = learn.PredefinedSeries(
        "c", tf.FixedLenFeature(tensor_shape.unknown_shape(), tf.int32, 3))
    series_d = learn.PredefinedSeries("d", tf.VarLenFeature(tf.int32))
    # "f" is the "out2" output of the mock 2x2 transform applied to (c, d).
    series_f = learn.TransformedSeries(
        [series_c, series_d],
        mocks.Mock2x2Transform("iue", "eui", "snt"),
        "out2")

    expected_columns = [
        feature_column.DataFrameColumn(name, series)
        for name, series in (("a", series_a), ("b", series_b), ("f", series_f))
    ]
    # Order of the returned columns is not significant; compare sorted.
    self.assertEqual(sorted(expected_columns), sorted(columns))

    base_features, labels = make_inputs()
    self.assertEqual(
        {
            "a": mocks.MockTensor("Tensor a", tf.int32),
            "b": mocks.MockSparseTensor("SparseTensor b", tf.int32),
            "c": mocks.MockTensor("Tensor c", tf.int32),
            "d": mocks.MockSparseTensor("SparseTensor d", tf.int32),
        },
        base_features)
    self.assertEqual(mocks.MockTensor("Out iue", tf.int32), labels)

    self.assertEqual(3, len(columns))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号