def test_local_log_sum_exp2():
    # Tests if the optimization works (result is correct) around 1.0:
    # the compiled log(sum(exp(x), axis=1)) graph must agree with a naive
    # NumPy evaluation, both for the plain reduction and when a transpose
    # DimShuffle is applied to the reduced result.
    x = tensor3('x')
    # Input values lie in [1.0, 1.1). The data is unseeded, so each run
    # draws fresh values.
    # NOTE(review): consider a seeded RandomState for reproducible failures.
    x_val = 1.0 + numpy.random.rand(4, 3, 2).astype(config.floatX) / 10.0

    # Plain reduction over axis 1: (4, 3, 2) -> (4, 2).
    f = compile_graph_log_sum_exp(x, axis=(1,))
    naive_ret = numpy.log(numpy.sum(numpy.exp(x_val), axis=1))
    optimised_ret = f(x_val)
    assert numpy.allclose(naive_ret, optimised_ret)

    # If a transpose is applied: the 2-D DimShuffle (1, 0) swaps the two
    # remaining axes, so the naive reference is transposed to (2, 4) as well.
    transpose_op = DimShuffle((False, False), (1, 0))
    f = compile_graph_log_sum_exp(x, axis=(1,), dimshuffle_op=transpose_op)
    naive_ret = numpy.log(numpy.sum(numpy.exp(x_val), axis=1).T)
    optimised_ret = f(x_val)
    assert numpy.allclose(naive_ret, optimised_ret)
评论列表
文章目录