def test_partial_function():
    """Check that ``output_subset`` calls agree with full evaluation.

    The same assertions are run once for each linker configuration passed
    to the nested helper at the bottom of this function.
    """
    import numpy as np
    from theano.tests import unittest_tools as utt

    def check_partial_function(linker):
        """Compile a three-output function with ``linker`` and verify it.

        ``linker`` is forwarded unchanged as ``Mode(linker=...)``. The
        callers below pass both a ``vm.VM_Linker`` instance and the string
        ``'cvm'``, so it is not necessarily a name.
        """
        x = tensor.scalar('input')
        y = x ** 2
        mode = Mode(optimizer=None, linker=linker)
        f = theano.function([x], [y + 7, y - 9, y / 14.], mode=mode)

        # Requesting every output explicitly must match a plain call.
        assert f(3, output_subset=[0, 1, 2]) == f(3)
        # Requesting outputs 0 and 2 must match those entries of a full call.
        assert f(4, output_subset=[0, 2]) == [f(4)[0], f(4)[2]]
        # For x = 5: y = 25, so the outputs are 25 + 7, 25 - 9 and 25 / 14.
        utt.assert_allclose(f(5), np.array([32., 16., 1.7857142857142858]))

    # NOTE(review): presumably exercises the pure-Python VM with partial
    # evaluation enabled, then the C VM via its registered name -- confirm
    # against the VM_Linker documentation.
    check_partial_function(
        vm.VM_Linker(allow_partial_eval=True, use_cloop=False))
    check_partial_function('cvm')
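# Web-page navigation residue (translated): "Comment list"
# Web-page navigation residue (translated): "Article table of contents"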