def combine(self, x):
    """Sum the expert outputs back into one result per position.

    If a single example was routed to several experts, the corresponding
    outputs are added together.

    Args:
        x: a Tensor with shape [batch, num_experts, expert_capacity, depth]

    Returns:
        a `Tensor` with shape `[batch, length, depth]`
    """
    # Dynamic size of the last axis; reused for the final reshape.
    depth = tf.shape(x)[-1]

    # Scale every slot by self._nonpadding, broadcast across the depth axis.
    # NOTE(review): presumably this zeroes unused capacity slots -- confirm
    # against where self._nonpadding is built.
    slot_weights = tf.expand_dims(self._nonpadding, -1)
    x *= slot_weights

    # One segment per (batch, position) pair; self._flat_indices supplies the
    # segment id for each slot of x.
    total_positions = self._batch * self._length
    summed = tf.unsorted_segment_sum(
        x, self._flat_indices, num_segments=total_positions)

    # Unflatten the segment axis back into [batch, length].
    return tf.reshape(summed, [self._batch, self._length, depth])
评论列表
文章目录