def check_type_forward(self, in_types):
    """Register type-check constraints for reshaping the input.

    Exactly one input is expected. The remaining constraint depends on
    how many unknown dimensions ``_count_unknown_dims`` reports for
    ``self.shape``:

    * none -- the number of elements of the input must equal the
      product of ``self.shape``;
    * otherwise -- the number of elements of the input must be a
      multiple of the product of the positive entries of ``self.shape``.

    Args:
        in_types: Type information of the inputs, as handed over by the
            type-checking machinery (must contain a single entry).
    """
    type_check.expect(in_types.size() == 1)
    (x_type,) = in_types

    if _count_unknown_dims(self.shape) == 0:
        # Fully specified target shape: element counts must match exactly.
        type_check.expect(
            type_check.prod(x_type.shape) == type_check.prod(self.shape)
        )
        return

    # Only the positive entries contribute to the known part of the size;
    # presumably the unknown dimension is marked by a non-positive
    # placeholder -- confirm against ``_count_unknown_dims``.
    known_size = 1
    for dim in (d for d in self.shape if d > 0):
        known_size *= dim

    # Wrap the constant in a named Variable so that it carries a readable
    # label inside the type-check expression.
    divisor = type_check.Variable(
        known_size, 'known_size(=%d)' % known_size)
    type_check.expect(type_check.prod(x_type.shape) % divisor == 0)
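# NOTE: stray web-page navigation text, not part of the program:
#   评论列表 ("comment list")
#   文章目录 ("article table of contents")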