def setUp(self):
    """Set up broadcast reference tensors and the binary-op table.

    After the parent fixture has run, this:
      * reshapes the tensors behind `x_probs_lt` and `channel_probs_lt`
        to `[x_size, 1, probs_size]` and `[1, channel_size, probs_size]`;
      * fills `self.ops` with 4-tuples of
        (name, `operator` function or None, `tf` function, `core` function);
      * exposes the two labeled tensors, their reshaped counterparts and
        the axes `[a0, a1, a3]` under generic `test_lt_*` /
        `broadcast_axes` names.
    """
    super(CoreBinaryOpsTest, self).setUp()

    # Give each operand a size-1 dimension where the other operand has
    # its own dimension, keeping the trailing probs dimension.
    x_broadcast = tf.reshape(
        self.x_probs_lt.tensor, [self.x_size, 1, self.probs_size])
    self.x_probs_broadcast_tensor = x_broadcast
    channel_broadcast = tf.reshape(
        self.channel_probs_lt.tensor,
        [1, self.channel_size, self.probs_size])
    self.channel_probs_broadcast_tensor = channel_broadcast

    arithmetic_ops = [
        ('add', operator.add, tf.add, core.add),
        ('sub', operator.sub, tf.sub, core.sub),
        ('mul', operator.mul, tf.mul, core.mul),
        ('div', operator.truediv, tf.div, core.div),
        ('mod', operator.mod, tf.mod, core.mod),
        ('pow', operator.pow, tf.pow, core.pow_function),
    ]
    # == and != are not element-wise for tf.Tensor, so they shouldn't be
    # element-wise for LabeledTensor, either; their operator slot is None.
    comparison_ops = [
        ('equal', None, tf.equal, core.equal),
        ('less', operator.lt, tf.less, core.less),
        ('less_equal', operator.le, tf.less_equal, core.less_equal),
        ('not_equal', None, tf.not_equal, core.not_equal),
        ('greater', operator.gt, tf.greater, core.greater),
        ('greater_equal', operator.ge, tf.greater_equal,
         core.greater_equal),
    ]
    self.ops = arithmetic_ops + comparison_ops

    # Generic aliases for the operands and their reshaped counterparts.
    self.test_lt_1 = self.x_probs_lt
    self.test_lt_2 = self.channel_probs_lt
    self.test_lt_1_broadcast = x_broadcast
    self.test_lt_2_broadcast = channel_broadcast
    self.broadcast_axes = [self.a0, self.a1, self.a3]
# NOTE(review): web-page residue, not part of the test code:
# "Comment list" / "Table of contents".