def check_two_tensor_operation(function_name, x_input_shape,
                               y_input_shape, **kwargs):
    """Assert that a two-operand function gives matching results on KTH and KTF.

    Two random arrays are generated, wrapped as variables on each backend,
    passed to the attribute named ``function_name`` looked up on ``KTH`` and
    on ``KTF``, and the evaluated results are compared.

    NOTE(review): ``KTH`` / ``KTF`` are presumably the Theano and TensorFlow
    Keras backend modules, and ``assert_allclose`` presumably comes from
    ``numpy.testing`` -- confirm against the file's imports.

    Args:
        function_name: Name of the attribute to look up on both backends.
        x_input_shape: Shape of the first random operand.
        y_input_shape: Shape of the second random operand.
        **kwargs: Extra keyword arguments forwarded unchanged to the
            function on both backends.

    Raises:
        AssertionError: If the evaluated results differ in shape, or if
            ``assert_allclose`` rejects them at ``atol=1e-05``.
    """
    # First operand: uniform samples shifted from [0, 1) to [-0.5, 0.5), so
    # both signs are exercised. The same array feeds both backends.
    lhs_data = np.random.random(x_input_shape) - 0.5
    lhs_th = KTH.variable(lhs_data)
    lhs_tf = KTF.variable(lhs_data)

    # Second operand, built the same way.
    rhs_data = np.random.random(y_input_shape) - 0.5
    rhs_th = KTH.variable(rhs_data)
    rhs_tf = KTF.variable(rhs_data)

    # Resolve the function by name on each backend, apply it, and evaluate
    # the outcome.
    th_op = getattr(KTH, function_name)
    result_th = KTH.eval(th_op(lhs_th, rhs_th, **kwargs))
    tf_op = getattr(KTF, function_name)
    result_tf = KTF.eval(tf_op(lhs_tf, rhs_tf, **kwargs))

    # Shapes must match exactly; values only up to the absolute tolerance.
    assert result_th.shape == result_tf.shape
    assert_allclose(result_th, result_tf, atol=1e-05)
评论列表
文章目录