def test_cse():
    """JIT-compile a cse()-reduced expression and compare it with sympy."""
    expr = a*a + b*b + sympy.exp(-a*a - b*b)
    reduced = sympy.cse(expr)
    jitted = g.llvm_callable([a, b], reduced)
    # Reference value: evaluate the unreduced expression symbolically.
    point = {a: 2.3, b: 0.1}
    expected = float(expr.subs(point).evalf())
    assert isclose(jitted(2.3, 0.1), expected)
# Python cse() usage examples (source code)
def test_cse_multiple():
    """A cse() result for several expressions gives one JIT output per expression."""
    exprs = [a*a, a*a + b*b]
    reduced = sympy.cse(exprs)
    # A multi-expression cse result is expected to be rejected for the
    # 'scipy.integrate' callback type.
    raises(NotImplementedError,
           lambda: g.llvm_callable([a, b], reduced,
                                   callback_type='scipy.integrate'))
    jitted = g.llvm_callable([a, b], reduced)
    outputs = jitted(0.1, 1.5)
    assert len(outputs) == 2
    expected = eval_cse(reduced, {a: 0.1, b: 1.5})
    for k in range(2):
        assert isclose(expected[k], outputs[k])
def test_callback_cubature_multiple():
    """Compile a three-output cse() expression as a 'cubature' callback.

    The callback is called as (ndim, input array, None, fdim, output array).
    It must return 0, and it must fill the output array with the values of
    the reduced expressions.
    """
    first = a*a
    second = a*a + b*b
    reduced = sympy.cse([first, second, 4*second])
    callback = g.llvm_callable([a, b], reduced, callback_type='cubature')
    # ndim: number of input variables; outdim: number of output values
    ndim, outdim = 2, 3
    point = {a: 0.2, b: 1.5}
    inputs = (ctypes.c_double * ndim)(point[a], point[b])
    outputs = (ctypes.c_double * outdim)()
    status = callback(ctypes.c_int(ndim), inputs, None,
                      ctypes.c_int(outdim), outputs)
    assert status == 0
    expected = eval_cse(reduced, point)
    for k in range(outdim):
        assert isclose(outputs[k], expected[k])
def genfcode(lambdastr, use_cse=False):
    """
    Python lambda string -> C function code.

    ``lambdastr`` is expected to look like ``"lambda x,y: <expr>"``. Every
    argument becomes a ``double`` parameter of the generated
    ``inline double f(...)``. The body is converted with ``cexpr`` and
    returned from that function.

    Optionally cse() is used to eliminate common subexpressions. Each
    extracted subexpression then becomes a local ``double`` temporary. It is
    declared and assigned before the ``return``.

    Parameters
    ----------
    lambdastr : str
        Lambda source text. The parameter list and body are separated at
        the first ``": "``.
    use_cse : bool, optional
        Run ``cse`` on the body first (default False).

    Returns
    -------
    str
        Source code of the C function.

    Raises
    ------
    ValueError
        If ``lambdastr`` contains no ``": "`` separator, or if ``cse``
        does not return exactly one reduced expression.
    """
    # TODO: verify lambda string
    # Split "lambda <args>: <body>" only at the first ": ", so that a body
    # which itself contains ": " is kept intact.
    head, sep, fstr = lambdastr.partition(': ')
    if not sep:
        raise ValueError(
            "expected 'lambda <args>: <expr>', got %r" % (lambdastr,))
    # Drop the "lambda" keyword as a prefix. Do not use str.lstrip('lambda '):
    # it removes any leading l/a/m/b/d/space characters. For example, the
    # argument list of "lambda a,b" would lose its leading "a".
    head = head.strip()
    if head.startswith('lambda'):
        head = head[len('lambda'):]
    # Generate the C parameter list: one "double" per non-empty argument name.
    # A lambda without arguments yields an empty list, i.e. "f()".
    cvars = [v.strip() for v in head.split(',')]
    cvarstr = ', '.join('double %s' % v for v in cvars if v)
    # convert function string to C syntax
    if not use_cse:
        cfstr = ''
        finalexpr = cexpr(fstr)
    else:
        # Eliminate common subexpressions; temporaries are named by
        # _gentmpvars() (defined elsewhere in this module).
        subs, reduced = cse(sympify(fstr), _gentmpvars())
        if len(reduced) != 1:
            raise ValueError("Length should be 1")
        # Declare every temporary first, then assign them in the order
        # cse() returned them. `dps` is a precision defined elsewhere in
        # this module.
        vardec = ''.join(' double %s;\n' % symbol.name for symbol, _ in subs)
        assigns = ''.join(
            ' %s = %s;\n' % (symbol.name, cexpr(str(expr.evalf(dps))))
            for symbol, expr in subs)
        cfstr = vardec + assigns
        finalexpr = cexpr(str(reduced[0].evalf(dps)))
    # generate C code
    code = """
inline double f(%s)
{
%s
return %s;
}
""" % (cvarstr, cfstr, finalexpr)
    return code
def _pow_by_mul(text):
    """Rewrite integer powers ``name**n`` in ``text`` as ``(name*...*name)``.

    Only a whole identifier or number raised to a positive integer literal is
    rewritten, e.g. ``x**3`` -> ``(x*x*x)``.

    The following are left untouched:

    - attribute bases (``p.x**2``);
    - non-integer exponents (``x**2.5``);
    - parenthesised bases (``(x+y)**2``);
    - ``**0``.

    All matches are replaced in one left-to-right pass. A short match
    therefore never rewrites part of a longer one (``x**2`` vs ``x**23``).
    """
    def _repl(match):
        base, power = match.group(1), int(match.group(2))
        if power == 0:
            return match.group(0)
        return '(' + '*'.join([base] * power) + ')'

    return re.sub(r'(?<![\w.])(\w+)\*\*(\d+)(?![\w.])', _repl, text)


def print_as_array(m, mname, sufix=None, use_cse=False, header=None,
                   print_file=True, collect_for=None, pow_by_mul=True,
                   order='C', op='+='):
    """Render the non-zero entries of ``m`` as array-update statements.

    Each truthy entry ``v`` at flat position ``i`` produces the line
    ``'<mname>[pos+<i>] <op> <v>'``. The lines are joined with newlines,
    optionally written to ``print_<mname>[_<sufix>].txt``, and returned.

    Parameters
    ----------
    m : matrix-like
        Entries to print, flattened with ``np.ravel``. With ``use_cse`` the
        entries of ``m`` are overwritten in place by the reduced expressions.
    mname : str
        Array name used on the left-hand side of every statement.
    sufix : str, optional
        If given, ``_<sufix>`` is appended to the name used in the comment
        headers and in the output file name.
    use_cse : bool
        Run ``sympy.cse`` on ``m`` first. Also emit ``cdef double``
        declarations (9 names per line) and the substitution assignments.
    header : str, optional
        Text placed verbatim as the first line of the output.
    print_file : bool
        Also write the result to ``print_<namesufix>.txt``.
    collect_for : optional
        If given, each entry is passed through
        ``collect(v, collect_for, evaluate=False)`` and emitted term by term.
    pow_by_mul : bool
        Rewrite integer powers ``name**n`` as repeated multiplication
        (only when ``collect_for`` is None).
    order : {'C', 'F'}
        ``'C'`` numbers entries row by row. ``'F'`` numbers them column by
        column (via ``m.T``).
    op : str
        Operator placed between the target and the expression.

    Returns
    -------
    str
        The generated text.

    Raises
    ------
    ValueError
        If ``order`` is neither ``'C'`` nor ``'F'``.
    """
    # Validate up front (previously an unknown order surfaced as a NameError).
    if order not in ('C', 'F'):
        raise ValueError("order must be 'C' or 'F', got {0!r}".format(order))
    ls = []
    subs = []
    if use_cse:
        subs, m_list = sympy.cse(m)
        # Overwrite m in place with the reduced expressions.
        # NOTE(review): assumes cse() returns one reduced expression per entry
        # of m; newer sympy versions may return a single Matrix for Matrix
        # input -- confirm against the sympy version in use.
        for i, v in enumerate(m_list):
            m[i] = v
    if sufix is None:
        namesufix = '{0}'.format(mname)
    else:
        namesufix = '{0}_{1}'.format(mname, sufix)
    filename = 'print_{0}.txt'.format(namesufix)
    if header:
        ls.append(header)
    if use_cse:
        ls.append('# cdefs')
        per_line = 9  # temporaries declared per "cdef double" line
        for start in range(0, len(subs), per_line):
            names = [str(sub[0]) for sub in subs[start:start + per_line]]
            ls.append('cdef double ' + ', '.join(names))
        ls.append('# subs')
        for sub in subs:
            ls.append('{0} = {1}'.format(*sub))
    ls.append('# {0}'.format(namesufix))
    flat = np.ravel(m) if order == 'C' else np.ravel(m.T)
    # Count exactly the entries emitted below (works for 1-D and 2-D input).
    num = sum(1 for v in flat if v)
    ls.append('# {0}_num={1}'.format(namesufix, num))
    for i, v in enumerate(flat):
        if not v:
            continue
        if collect_for is not None:
            terms = collect(v, collect_for, evaluate=False)
            ls.append('{0}[pos+{1}] {2}'.format(mname, i, op))
            for k, expr in terms.items():
                ls.append('# collected for {k}'.format(k=k))
                ls.append(' {expr}'.format(expr=k*expr))
        else:
            if pow_by_mul:
                v = _pow_by_mul(str(v))
            ls.append('{0}[pos+{1}] {2} {3}'.format(mname, i, op, v))
    string = '\n'.join(ls)
    if print_file:
        with open(filename, 'w') as f:
            f.write(string)
    return string