def concatenate(tensors, axis=-1):
    """Concatenate a list of tensors along an axis.

    If every tensor in ``tensors`` is sparse (per ``is_sparse``), the result
    is a sparse CSR matrix:
    - ``vstack`` is used for axis 0.
    - ``hstack`` is used for axis 1.

    Otherwise, every input is converted with ``to_dense`` and joined with
    ``T.concatenate``. If sparse and dense inputs are mixed, all of them are
    densified.

    Args:
        tensors: Sequence of tensors to join.
        axis: Axis to concatenate along. Negative values count from the
            end. On the sparse path the axis is normalised modulo the rank
            of the first tensor.

    Returns:
        A sparse CSR matrix if all inputs are sparse, otherwise a dense
        tensor.

    Raises:
        ValueError: If all inputs are sparse and the normalised axis is
            neither 0 nor 1.
    """
    if py_all([is_sparse(x) for x in tensors]):
        # Python's % with a positive divisor is non-negative, so e.g.
        # axis=-1 on a rank-2 input becomes 1.
        # NOTE(review): an empty ``tensors`` list also takes this branch
        # (all([]) is True) and fails on tensors[0]; confirm callers never
        # pass an empty list.
        axis = axis % ndim(tensors[0])
        if axis == 0:
            return th_sparse_module.basic.vstack(tensors, format='csr')
        elif axis == 1:
            return th_sparse_module.basic.hstack(tensors, format='csr')
        else:
            # str() is required: concatenating str + int raised a TypeError
            # here and hid the real problem.
            raise ValueError('Invalid concat axis for sparse matrix: ' +
                             str(axis))
    else:
        return T.concatenate([to_dense(x) for x in tensors], axis=axis)
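# NOTE(review): these two lines were scraped web-page residue
# ("comment list", "table of contents"), not code.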