def check_type_forward(self, in_types):
    """Validate the number, dtypes, ranks and shapes of the forward inputs.

    Inputs are read positionally from ``in_types``:

    0. ``x``  -- floating-point, 3-D. (A comment elsewhere in this file
       describes it as ``(batchsize, feature_dimension, seq_length)``.)
    1. ``w``  -- floating-point, 2-D.
    2. ``b``  -- floating-point, 1-D, with ``b.shape[0] * 3 == w.shape[0] * 2``.
    3. ``ct`` -- same dtype as ``x``, 2-D, ``ct.shape[1] == x.shape[1]``
       (presumably the initial cell state -- confirm against forward()).
    4. ``mask_x`` (optional) -- same dtype as ``x``, 2-D,
       ``mask_x.shape[1] == x.shape[1]`` (presumably a mask applied to
       ``x`` -- confirm).

    Args:
        in_types: Type descriptors for the inputs; must support ``size()``
            and integer indexing (presumably Chainer's
            ``type_check.TypeInfoTuple`` -- confirm).

    Raises:
        Whatever ``type_check.expect`` raises when a condition does not
        hold (presumably ``type_check.InvalidType``).
    """
    input_count = in_types.size()
    # Four inputs are mandatory; a fifth one is optional.
    type_check.expect(4 <= input_count, input_count <= 5)

    x_info, w_info, b_info, ct_info = (
        in_types[0], in_types[1], in_types[2], in_types[3])
    type_check.expect(
        x_info.dtype.kind == "f",
        w_info.dtype.kind == "f",
        b_info.dtype.kind == "f",
        x_info.ndim == 3,
        w_info.ndim == 2,
        b_info.ndim == 1,
        # b must have exactly 2/3 as many rows as w (presumably two bias
        # vectors for three stacked weight matrices -- confirm).
        b_info.shape[0] * 3 == w_info.shape[0] * 2,
        ct_info.dtype == x_info.dtype,
        ct_info.ndim == 2,
        ct_info.shape[1] == x_info.shape[1],
    )

    # Nothing more to check unless the optional fifth input was supplied.
    if type_check.eval(input_count) != 5:
        return
    mask_info = in_types[4]
    type_check.expect(
        mask_info.dtype == x_info.dtype,
        mask_info.ndim == 2,
        mask_info.shape[1] == x_info.shape[1],
    )
# x: (batchsize, feature_dimension, seq_length)
评论列表
文章目录