def test_sort_schedule_fn():
    """Check that a custom comparator drives the linker's node schedule.

    Builds a small graph and compiles it with an ``OpWiseCLinker`` whose
    ``schedule`` comes from ``sort_schedule_fn`` with a lexicographic
    comparison of the nodes' string forms. It then walks adjacent pairs
    of the scheduled nodes. Wherever the second node is not reported as
    dependent on the first, the order must be ascending by ``str``.
    """
    import theano
    from theano.gof.sched import sort_schedule_fn, make_depends

    x = theano.tensor.matrix('x')
    y = theano.tensor.dot(x[:5] * 2, x.T + 1).T

    def str_cmp(a, b):
        # Lexicographic three-way comparison of the nodes' string forms.
        # The ``cmp`` builtin was removed in Python 3. ``(p > q) - (p < q)``
        # yields the same -1/0/1 result on both Python 2 and Python 3.
        sa, sb = str(a), str(b)
        return (sa > sb) - (sa < sb)

    linker = theano.OpWiseCLinker(schedule=sort_schedule_fn(str_cmp))
    mode = theano.Mode(linker=linker)
    f = theano.function((x,), (y,), mode=mode)

    # NOTE(review): assumes the last item returned by make_all() is the
    # node order the linker scheduled. Confirm against theano's linker API.
    nodes = f.maker.linker.make_all()[-1]
    # NOTE(review): presumably depends((p, q)) is True when p depends on q.
    # Verify against theano.gof.sched.make_depends.
    depends = make_depends()
    for a, b in zip(nodes[:-1], nodes[1:]):
        # Only pairs not ordered by a data dependency are free for the
        # comparator to order, so only those must be string-ascending.
        if not depends((b, a)):
            assert str(a) < str(b)
评论列表
文章目录