from theano.tensor import TensorType, tensor3
from theano.tensor.elemwise import DimShuffle


def test_local_log_sum_exp1():
    # Tests whether the optimization is applied by checking for the presence
    # of the maximum in the optimized graph.
    # check_max_log_sum_exp is a helper defined alongside this test (not shown here).
    x = tensor3('x')
    check_max_log_sum_exp(x, axis=(0,), dimshuffle_op=None)
    check_max_log_sum_exp(x, axis=(1,), dimshuffle_op=None)
    check_max_log_sum_exp(x, axis=(2,), dimshuffle_op=None)
    check_max_log_sum_exp(x, axis=(0, 1), dimshuffle_op=None)
    check_max_log_sum_exp(x, axis=(0, 1, 2), dimshuffle_op=None)

    # If a transpose is applied to the sum
    transpose_op = DimShuffle((False, False), (1, 0))
    check_max_log_sum_exp(x, axis=2, dimshuffle_op=transpose_op)

    # If the sum is performed with keepdims=True
    x = TensorType(dtype='floatX', broadcastable=(False, True, False))('x')
    sum_keepdims_op = x.sum(axis=(0, 1), keepdims=True).owner.op
    check_max_log_sum_exp(x, axis=(0, 1), dimshuffle_op=sum_keepdims_op)
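
The check_max_log_sum_exp helper itself is not part of this excerpt. For context, the rewrite rests on the identity log sum(exp(x)) = max(x) + log sum(exp(x - max(x))), so a graph in which the optimization fired must contain a maximum node. Below is a minimal sketch of what such a checking helper could look like, assuming Theano's standard graph-inspection API; the function name and the exact op checks are illustrative, not the test suite's actual implementation.

import theano
import theano.scalar
import theano.tensor as tt


def check_max_log_sum_exp_sketch(x, axis, dimshuffle_op=None):
    # Sketch only: build the naive, numerically unstable log(sum(exp(x))).
    y = tt.log(tt.sum(tt.exp(x), axis=axis))
    if dimshuffle_op is not None:
        y = dimshuffle_op(y)

    # Compile with the default optimizer so the log-sum-exp rewrite can fire.
    f = theano.function([x], y)

    # If the rewrite was applied, the optimized graph contains a maximum,
    # either as an Elemwise maximum or as a MaxAndArgmax node.
    for node in f.maker.fgraph.toposort():
        if (hasattr(node.op, 'scalar_op') and
                node.op.scalar_op == theano.scalar.basic.maximum):
            return
        if isinstance(node.op, tt.basic.MaxAndArgmax):
            return
    raise AssertionError('No maximum found after the log_sum_exp optimization')

The real helper may additionally adjust the compilation mode (for instance, to tolerate non-finite intermediate values), but inspecting the optimized graph for a maximum op is the essential part that the test above exercises.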