def testWeightedSparseColumnDtypes(self):
  """Checks the parsing config a weighted sparse column emits per weight dtype.

  The id column is built from string keys, so its feature spec is expected to
  be `tf.VarLenFeature(tf.string)` in every case; only the weight feature's
  dtype follows the `dtype` argument. A string weight dtype must be rejected
  with a ValueError.
  """
  keyed_ids = tf.contrib.layers.sparse_column_with_keys(
      "ids", ["marlo", "omar", "stringer"])
  id_spec = tf.VarLenFeature(tf.string)

  # No explicit dtype: the weight feature is expected to be float32.
  default_weighted = tf.contrib.layers.weighted_sparse_column(keyed_ids,
                                                              "weights")
  self.assertDictEqual(
      {"ids": id_spec, "weights": tf.VarLenFeature(tf.float32)},
      default_weighted.config)

  # An int32 dtype is accepted and shows up in the weight feature spec.
  int_weighted = tf.contrib.layers.weighted_sparse_column(
      keyed_ids, "weights", dtype=tf.int32)
  self.assertDictEqual(
      {"ids": id_spec, "weights": tf.VarLenFeature(tf.int32)},
      int_weighted.config)

  # A string dtype is not convertible to float and must raise.
  with self.assertRaisesRegexp(ValueError,
                               "dtype is not convertible to float"):
    tf.contrib.layers.weighted_sparse_column(keyed_ids, "weights",
                                             dtype=tf.string)