test_nanguardmode.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:Theano-Deep-learning 作者: GeekLiB 项目源码 文件源码
def test_NanGuardMode():
    """Check that NanGuardMode rejects inf, NaN and very large inputs.

    Two graphs are compiled under ``NanGuardMode``: a matrix product with a
    shared weight matrix, and an advanced-indexing slice of a 3-D tensor.
    Each compiled function must accept ordinary random input and must raise
    ``AssertionError`` when fed an array filled with inf, NaN or 1e20.
    """
    floatX = theano.config.floatX
    guard_logger = logging.getLogger("theano.compile.nanguardmode")

    def make_mode():
        # Enable every check explicitly so the "big values" assertions do not
        # depend on the big_is_error default in effect at run time.
        return NanGuardMode(
            nan_is_error=True, inf_is_error=True, big_is_error=True)

    def filled(value, shape):
        # Array of the given shape with every entry equal to `value`.
        return numpy.tile(numpy.asarray(value).astype(floatX), shape)

    def check_guard(fun, shape):
        # Feed `fun` one well-behaved input and three abnormal ones.
        normal = numpy.random.randn(*shape).astype(floatX)
        # numpy.inf is used directly instead of overflowing 100. ** 1000000:
        # same value, without the floating-point overflow warning.
        abnormal = [
            filled(numpy.inf, shape),  # INFs
            filled(numpy.nan, shape),  # NANs
            filled(1e20, shape),       # big values
        ]

        fun(normal)  # normal values must pass through untouched

        # Temporarily silence the logger, then restore whatever propagate
        # setting was in place before (not unconditionally True).
        previous_propagate = guard_logger.propagate
        guard_logger.propagate = False
        try:
            for bad_input in abnormal:
                assert_raises(AssertionError, fun, bad_input)
        finally:
            guard_logger.propagate = previous_propagate

    # matrix product
    x = T.matrix()
    w = theano.shared(numpy.random.randn(5, 7).astype(floatX))
    y = T.dot(x, w)
    check_guard(theano.function([x], y, mode=make_mode()), (3, 5))

    # slices
    x = T.tensor3()
    y = x[:, T.arange(2), T.arange(2)]
    check_guard(theano.function([x], y, mode=make_mode()), (3, 4, 5))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号