def testMulticlassWithRealValuedColumn(self):
    """Weighted sum over one real-valued column with three outputs.

    Feeds four single-value "age" rows through
    `weighted_sum_from_feature_columns` with `num_outputs=3`, checks that the
    first variable created for the column has shape (1, 3), assigns known
    weights to it, and verifies that every output row equals the age value
    multiplied by those weights.
    """
    expected_weight_shape = (1, 3)
    assigned_weights = [[0.01, 0.03, 0.05]]
    # Row i is the i-th age (10, 20, 30, 40) times `assigned_weights`.
    expected_output = [
        [0.1, 0.3, 0.5],
        [0.2, 0.6, 1.0],
        [0.3, 0.9, 1.5],
        [0.4, 1.2, 2.0],
    ]

    with tf.Graph().as_default():
        age_column = tf.contrib.layers.real_valued_column("age")
        ages = tf.constant([[10.], [20.], [30.], [40.]])
        # Only the output tensor and the column -> variables mapping are
        # used; the third returned value is ignored.
        result = tf.contrib.layers.weighted_sum_from_feature_columns(
            {"age": ages}, [age_column], num_outputs=3)
        weighted_sum, variables_by_column, _ = result

        with self.test_session() as session:
            tf.initialize_all_variables().run()
            tf.initialize_all_tables().run()

            age_weights = variables_by_column[age_column][0]
            self.assertEqual(age_weights.get_shape(), expected_weight_shape)

            # Overwrite the initial weights so the output is deterministic.
            session.run(age_weights.assign(assigned_weights))
            self.assertAllClose(weighted_sum.eval(), expected_output)
# Comment list
# Article table of contents