def dot(x, y):
    """Multiply two rank-1/rank-2 tensors following NumPy ``dot`` semantics.

    Supported rank combinations (kernel used -> resulting shape):

    * 2-D ``(n, k)`` . 2-D ``(k, m)`` -> ``torch.mm``, shape ``(n, m)``
    * 2-D ``(n, k)`` . 1-D ``(k,)``   -> ``torch.mv``, shape ``(n,)``
    * 1-D ``(k,)``   . 2-D ``(k, m)`` -> ``torch.mv`` on ``y.t()``, shape ``(m,)``
    * 1-D ``(k,)``   . 1-D ``(k,)``   -> ``torch.dot``, 0-d scalar, shape ``()``

    Args:
        x: Left operand.
        y: Right operand.

    Returns:
        The result of applying the op built by ``get_op`` (from the kernel and
        the output-shape function below) to ``[x, y]``.

    Raises:
        ValueError: If either operand's rank is not 1 or 2. It is raised by the
            kernel and by the shape function, so when it surfaces depends on
            when ``get_op`` invokes them (not visible here).
    """

    def _unsupported(x_ndim, y_ndim):
        # Shared by the kernel and the shape function so both reject the same
        # rank combinations with the same message.
        return ValueError(
            f"Unsupported tensor ranks for dot operation : {x_ndim} and {y_ndim}."
        )

    def _dot(X):
        # X is the [x, y] pair handed to the op by get_op.
        x, y = X
        x_ndim = ndim(x)
        y_ndim = ndim(y)
        if x_ndim == 2 and y_ndim == 2:
            return torch.mm(x, y)
        if x_ndim == 2 and y_ndim == 1:
            return torch.mv(x, y)
        if x_ndim == 1 and y_ndim == 2:
            # For a vector x, x @ y == y^T @ x, so transpose y before the
            # matrix-vector product; mv(y, x) would compute y @ x instead.
            return torch.mv(y.t(), x)
        if x_ndim == 1 and y_ndim == 1:
            return torch.dot(x, y)
        raise _unsupported(x_ndim, y_ndim)

    def _compute_output_shape(X):
        # Mirrors _dot's dispatch so the declared shape matches the result.
        x_shape, y_shape = _get_shape(X[0]), _get_shape(X[1])
        x_ndim = len(x_shape)
        y_ndim = len(y_shape)
        if x_ndim == 2 and y_ndim == 2:
            return (x_shape[0], y_shape[1])
        if x_ndim == 2 and y_ndim == 1:
            return (x_shape[0],)
        if x_ndim == 1 and y_ndim == 2:
            # (k,) . (k, m) -> (m,)
            return (y_shape[1],)
        if x_ndim == 1 and y_ndim == 1:
            # torch.dot yields a 0-d tensor, whose shape is the empty tuple.
            return ()
        raise _unsupported(x_ndim, y_ndim)

    return get_op(_dot, output_shape=_compute_output_shape)([x, y])
评论列表
文章目录