def check_type_forward(self, in_types):
    """Validate the type information of the forward inputs.

    Accepts three or four inputs, in positional order ``x``, ``v``, ``g``
    and an optional fourth input ``b``. (NOTE(review): ``v``/``g`` look like
    the direction and gain of a weight-normalized filter and ``b`` a bias —
    confirm against callers.)

    Constraints enforced, in this order:
      * the number of inputs is 3 or 4;
      * ``x``, ``v`` and ``g`` all have a floating-point dtype (kind ``"f"``)
        and are 4-dimensional;
      * ``x.shape[1] == v.shape[1]``;
      * when a fourth input is present: its dtype equals ``x``'s dtype, it is
        1-dimensional, and ``b.shape[0] == v.shape[0]``.

    Args:
        in_types: Positional type information for the inputs; must support
            indexing and ``size()``.

    Failures surface through ``type_check.expect``, which raises on the first
    unmet condition (presumably ``type_check.InvalidType`` — confirm).
    """
    num_inputs = in_types.size()
    type_check.expect(3 <= num_inputs, num_inputs <= 4)

    x_info = in_types[0]
    v_info = in_types[1]
    g_info = in_types[2]
    type_check.expect(
        x_info.dtype.kind == "f",
        v_info.dtype.kind == "f",
        g_info.dtype.kind == "f",
        x_info.ndim == 4,
        v_info.ndim == 4,
        g_info.ndim == 4,
        # Axis 1 of x must agree with axis 1 of v (presumably the
        # input-channel axis in an NCHW layout — confirm).
        x_info.shape[1] == v_info.shape[1],
    )

    # The fourth input is optional; with only three there is nothing left
    # to verify. ``type_check.eval`` turns the symbolic size into a value.
    with_fourth_input = type_check.eval(num_inputs) == 4
    if not with_fourth_input:
        return

    b_info = in_types[3]
    type_check.expect(
        b_info.dtype == x_info.dtype,
        b_info.ndim == 1,
        # Length must match axis 0 of v (presumably one entry per output
        # channel — confirm).
        b_info.shape[0] == v_info.shape[0],
    )
convolution_2d.py 文件源码
python
阅读 35
收藏 0
点赞 0
评论 0
评论列表
文章目录