def is_1pexp(t, only_process_constants=True):
    """
    Check whether `t` is of the form ``1 + exp(x)``.

    The test succeeds when `t` is produced by `tensor.add`, exactly one of
    its addends is not a scalar constant, that addend is produced by
    `tensor.exp`, and the scalar constants sum to 1 (compared with
    `numpy.allclose`).

    Parameters
    ----------
    t
        Variable to inspect. Only its ``owner`` attribute (and the owner's
        ``op`` and ``inputs``) is accessed here.
    only_process_constants : bool
        Forwarded unchanged to `opt.scalarconsts_rest`.

    Returns
    -------
    object
        If 't' is of the form (1+exp(x)), return (False, x).
        Else return None.

    Notes
    -----
    When `t` looks like ``c + exp(x)`` but the constants do not sum to 1 (or
    there are no constants at all), None is returned and, if
    `config.warn.identify_1pexp_bug` is set, a warning about the behavior of
    old versions is emitted.

    """
    # Guard: `t` must be the output of an addition.
    if not (t.owner and t.owner.op == tensor.add):
        return None

    # Split the addends into constant scalars and the rest. The middle
    # element (scalar inputs, potentially dimshuffled and filled with
    # scalars) is not needed here.
    scalars, _, nonconsts = opt.scalarconsts_rest(
        t.owner.inputs,
        only_process_constants=only_process_constants)

    # Guard: exactly one non-constant addend, and it must be exp(...).
    if len(nonconsts) != 1:
        return None
    maybe_exp = nonconsts[0]
    if not (maybe_exp.owner and maybe_exp.owner.op == tensor.exp):
        return None

    # Verify that the constant terms sum to 1.
    if scalars:
        scal_sum = scalars[0]
        for s in scalars[1:]:
            # Rebind instead of using `+=` so that scalars[0] can never be
            # modified in place.
            scal_sum = scal_sum + s
        if numpy.allclose(scal_sum, 1):
            return False, maybe_exp.owner.inputs[0]

    # Before 7987b51 there used to be a bug where *any* constant
    # was considered as if it was equal to 1, and thus this
    # function would incorrectly identify it as (1 + exp(x)).
    if config.warn.identify_1pexp_bug:
        warnings.warn(
            'Although your current code is fine, please note that '
            'Theano versions prior to 0.5 (more specifically, '
            'prior to commit 7987b51 on 2011-12-18) may have '
            'yielded an incorrect result. To remove this warning, '
            'either set the `warn.identify_1pexp_bug` config '
            'option to False, or `warn.ignore_bug_before` to at '
            'least \'0.4.1\'.')
    return None
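# NOTE(review): the two lines that stood here ("Comment list" / "Article
# contents") were web-page navigation residue, not part of this module.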