feature_column_ops_test.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:lsdc 作者: febert 项目源码 文件源码
def testMultivalentCrossUsageInPredictionsWithPartition(self):
    """Checks prediction through a crossed column whose weights are partitioned.

    Two hashed sparse columns and their cross are sized so that the min-max
    partitioner splits each weight variable into several shards. The test
    checks the shard counts, then bumps every shard of the crossed column's
    weights and verifies the prediction for a single two-valued example.
    """
    # Hash bucket sizes must be large enough to allow sharding.
    # NOTE(review): assuming 4-byte (float32) weights, 64 << 19 buckets is
    # 2 * (64 << 20) bytes and 64 << 18 buckets is exactly 64 << 20 bytes.
    # Against a minimum slice of (64 << 20) - 1 bytes, that presumably yields
    # the 3 and 2 shards asserted below — confirm with the partitioner docs.
    language = tf.contrib.layers.sparse_column_with_hash_bucket(
        "language", hash_bucket_size=64 << 19)
    country = tf.contrib.layers.sparse_column_with_hash_bucket(
        "country", hash_bucket_size=64 << 18)
    country_language = tf.contrib.layers.crossed_column(
        [language, country], hash_bucket_size=64 << 18)
    with tf.Graph().as_default():
      # A single example (row 0) that carries two values in each column.
      two_slot_indices = [[0, 0], [0, 1]]
      features = {
          "language": tf.SparseTensor(values=["english", "spanish"],
                                      indices=two_slot_indices,
                                      shape=[1, 2]),
          "country": tf.SparseTensor(values=["US", "SV"],
                                     indices=two_slot_indices,
                                     shape=[1, 2]),
      }
      partitioner = tf.min_max_variable_partitioner(
          max_partitions=10, min_slice_size=((64 << 20) - 1))
      with tf.variable_scope("weighted_sum_from_feature_columns",
                             features.values(),
                             partitioner=partitioner) as weighted_scope:
        prediction, weights_by_column, _ = (
            tf.contrib.layers.weighted_sum_from_feature_columns(
                features, [country, language, country_language],
                num_outputs=1,
                scope=weighted_scope))
      with self.test_session() as session:
        tf.initialize_all_variables().run()
        tf.initialize_all_tables().run()

        # Expected number of shards in each column's partitioned weights.
        expected_shards = ((country, 2), (language, 3), (country_language, 2))
        for column, shard_count in expected_shards:
          self.assertEqual(shard_count, len(weights_by_column[column]))

        # Add 0.4 to every shard of the crossed column's weights.
        for shard in weights_by_column[country_language]:
          session.run(shard.assign(shard + 0.4))
        # 2 languages x 2 countries = 4 crosses, each weighted 0.4:
        # score = 0.4 + 0.4 + 0.4 + 0.4
        self.assertAllClose(prediction.eval(), [[1.6]])
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号