def build_step(self, signals):
    """Compute the sparse ``A @ X`` product for this step and write it to ``Y``.

    If ``self.len_match`` is truthy, the parent class's ``build_step`` is
    called and nothing else happens here. Otherwise the gathered ``A``
    entries are used as the values of a sparse matrix whose coordinates
    are ``self.sparse_indices`` and whose dense shape is ``self.A_shape``.
    That matrix is multiplied by the gathered ``X`` signal, and the
    product is scattered into ``self.Y_data`` using ``self.mode``.

    Parameters
    ----------
    signals
        Signal container providing ``gather``, ``scatter`` and
        ``minibatch_size`` (presumably NengoDL's ``SignalDict`` -- confirm).

    Returns
    -------
    None
    """
    if self.len_match:
        # Delegate entirely to the parent builder's implementation.
        super(SparseDotIncBuilder, self).build_step(signals)
    else:
        a_values = signals.gather(self.A_data)
        x_values = signals.gather(self.X_data)

        # Every gathered A entry must have a matching coordinate row.
        assert a_values.get_shape()[0] == self.sparse_indices.get_shape()[0]

        # Sparse (A) times dense (X) matmul. The op's argument order is
        # (indices, values, dense_shape, dense_rhs).
        # Alternative (not used): densify A via
        # tf.scatter_nd(self.sparse_indices, A, self.A_shape) and call
        # tf.matmul(dense_A, X, a_is_sparse=self.is_sparse).
        product = gen_sparse_ops._sparse_tensor_dense_mat_mul(
            self.sparse_indices, a_values, self.A_shape, x_values)

        # Pin the static shape: Y's shape plus a trailing minibatch axis.
        product.set_shape(self.Y_data.shape + (signals.minibatch_size,))
        signals.scatter(self.Y_data, product, mode=self.mode)
# Web-page navigation residue (not part of the code):
# "评论列表" (comment list), "文章目录" (table of contents).