def make_node(self, x, y, p_data, p_ind, p_ptr, p_ncols):
    """Build the ``gof.Apply`` node that ties this op to its inputs.

    Every argument is first converted with ``tensor.as_tensor_variable``.
    ``x`` and ``y`` are then cast to their common upcast dtype before
    being handed to the node.

    Parameters
    ----------
    x, y
        The two operands that are cast to a shared dtype for the BLAS
        ``?dot`` call.
    p_data, p_ind, p_ptr
        NOTE(review): presumably the data / indices / index-pointer
        arrays of a sparse pattern matrix -- confirm against the
        enclosing op's documentation.
    p_ncols
        Must convert to a variable whose dtype is ``'int32'``.

    Returns
    -------
    gof.Apply
        A node whose inputs are ``[x, y, p_data, p_ind, p_ptr, p_ncols]``
        (with ``x`` and ``y`` already cast) and whose three outputs are
        one-dimensional, non-broadcastable tensors with dtypes, in order:
        the upcast of the ``x``, ``y`` and ``p_data`` dtypes; the dtype of
        ``p_ind``; the dtype of ``p_ptr``.

    Raises
    ------
    TypeError
        If ``p_ncols`` does not have dtype ``'int32'``.
    """
    x = tensor.as_tensor_variable(x)
    y = tensor.as_tensor_variable(y)
    p_data = tensor.as_tensor_variable(p_data)
    p_ind = tensor.as_tensor_variable(p_ind)
    p_ptr = tensor.as_tensor_variable(p_ptr)
    p_ncols = tensor.as_tensor_variable(p_ncols)

    # Explicit check instead of ``assert``: assertions are removed when
    # Python runs with -O, which would let a bad dtype through unnoticed.
    if p_ncols.dtype != 'int32':
        raise TypeError(
            "p_ncols must have dtype 'int32', got %r" % (p_ncols.dtype,))

    # Output dtype is computed from the dtypes of x and y *before* the
    # cast below, together with the dtype of p_data.
    dtype_out = scalar.upcast(x.type.dtype, y.type.dtype,
                              p_data.type.dtype)
    dot_out = scalar.upcast(x.type.dtype, y.type.dtype)

    # The BLAS ?dot routine only accepts operands of the same type, so
    # bring x and y to their common dtype.
    x = tensor.cast(x, dot_out)
    y = tensor.cast(y, dot_out)

    return gof.Apply(self, [x, y, p_data, p_ind, p_ptr, p_ncols], [
        tensor.tensor(dtype=dtype_out, broadcastable=(False,)),
        tensor.tensor(dtype=p_ind.type.dtype, broadcastable=(False,)),
        tensor.tensor(dtype=p_ptr.type.dtype, broadcastable=(False,)),
    ])
评论列表
文章目录