def check_type_forward(self, in_types):
    """Validate the input types for the forward computation.

    Exactly three inputs ``(x, gamma, beta)`` are required, and the
    following must hold:

    * ``x`` has a floating-point dtype;
    * ``x`` has at least one more dimension than ``gamma``;
    * the ``gamma.ndim`` axes of ``x`` starting at axis 1 match
      ``gamma.shape`` (axis 0 of ``x`` is skipped -- presumably the batch
      axis; confirm against the forward implementation);
    * ``gamma`` and ``beta`` both have the same dtype as ``x``;
    * ``gamma`` and ``beta`` have identical shapes.

    Args:
        in_types: Type information for the inputs, as supplied by Chainer's
            type-checking machinery.

    Raises:
        type_check.InvalidType: If the number of inputs is not three, or if
            any of the constraints above is violated.
    """
    n_in = type_check.eval(in_types.size())
    if n_in != 3:
        # InvalidType's constructor takes both the expected condition and the
        # actual observation; passing only one argument would raise a
        # TypeError instead of the intended InvalidType.
        raise type_check.InvalidType(
            '%s == %s' % (in_types.size(), 3),
            '%s == %s' % (in_types.size(), n_in))
    x_type, gamma_type, beta_type = in_types[:3]
    # Rank of gamma: x's axes 1 .. M (inclusive of 1, exclusive of 1 + M)
    # must line up with gamma's shape.
    M = type_check.eval(gamma_type.ndim)
    type_check.expect(
        x_type.dtype.kind == 'f',
        x_type.ndim >= gamma_type.ndim + 1,
        x_type.shape[1:1 + M] == gamma_type.shape,
        gamma_type.dtype == x_type.dtype,
        beta_type.dtype == x_type.dtype,
        gamma_type.shape == beta_type.shape,
    )
function.py 文件源码
python
阅读 39
收藏 0
点赞 0
评论 0
评论列表
文章目录