def test_join_matrix_ints(self):
    # Regression test for joining matrices of mixed integer dtypes
    # (int8 and int32). There was a bug that caused a crash in the past.
    if "float32" in self.shared.__name__:
        # The int8 input below needs a shared constructor that accepts
        # dtypes other than float32, so skip when it is float32-only.
        raise SkipTest(
            "The shared variable constructor"
            " need to support other dtype then float32")

    av = numpy.array([[1, 2, 3], [4, 5, 6]], dtype='int8')
    bv = numpy.array([[7], [8]], dtype='int32')
    a = self.shared(av)
    b = as_tensor_variable(bv)

    # Join along axis 1: the single column of b is appended after the
    # three columns of a (see the expected array below).
    s = join(1, a, b)
    # The check is an elementwise comparison by value, so the float32
    # dtype of the expected array does not constrain the output dtype.
    want = numpy.array([[1, 2, 3, 7], [4, 5, 6, 8]], dtype='float32')
    out = self.eval_outputs_and_check_join([s])
    self.assertTrue((out == want).all())

    # The gradient of the summed join with respect to each input must be
    # all zeros (presumably because both inputs are integer-typed --
    # TODO confirm). self.assertTrue is used rather than a bare assert so
    # these checks are not stripped when Python runs with -O.
    self.assertTrue((numpy.asarray(grad(s.sum(), b).eval()) == 0).all())
    self.assertTrue((numpy.asarray(grad(s.sum(), a).eval()) == 0).all())
# (Web-page residue, not part of the code: "Comment list" / "Table of contents")