def __init__(self, requests, expert_capacity):
    """Create a TruncatingDispatcher.

    For every (batch element, expert) pair, the first `expert_capacity`
    positions along the length axis that request that expert are assigned
    to slots 0 .. expert_capacity - 1; any later requests are dropped.

    Args:
      requests: a boolean `Tensor` of shape `[batch, length, num_experts]`.
        Alternatively, a float or int Tensor containing zeros and ones.
      expert_capacity: a Scalar - maximum number of examples per expert per
        batch element.

    Returns:
      a TruncatingDispatcher
    """
    self._requests = tf.to_float(requests)
    self._expert_capacity = expert_capacity
    expert_capacity_f = tf.to_float(expert_capacity)
    self._batch, self._length, self._num_experts = tf.unstack(
        tf.shape(self._requests), num=3)
    # [batch, length, num_experts]: number of earlier positions (along the
    # length axis, same batch element) that requested the same expert.
    position_in_expert = tf.cumsum(self._requests, axis=1, exclusive=True)
    # [batch, length, num_experts]: equals the request where it still fits
    # within the expert's capacity, 0.0 otherwise (over-capacity requests
    # are dropped).
    self._gates = self._requests * tf.to_float(
        tf.less(position_in_expert, expert_capacity_f))
    batch_index = tf.reshape(
        tf.to_float(tf.range(self._batch)), [self._batch, 1, 1])
    length_index = tf.reshape(
        tf.to_float(tf.range(self._length)), [1, self._length, 1])
    expert_index = tf.reshape(
        tf.to_float(tf.range(self._num_experts)), [1, 1, self._num_experts])
    # Clamp the slot into the expert's own range [0, expert_capacity).
    # Every entry with position_in_expert >= expert_capacity has gate 0 and
    # so contributes 0 to the segment sum below; clamping therefore leaves
    # the result unchanged. Without the clamp those entries spill into the
    # next expert's segments and, for the last expert of the last batch
    # element, produce ids >= num_segments, which unsorted_segment_sum
    # rejects with an error on CPU.
    slot_in_expert = tf.minimum(position_in_expert, expert_capacity_f - 1.0)
    # position in a Tensor with shape [batch * num_experts * expert_capacity]
    flat_position = (
        slot_in_expert +
        batch_index * (tf.to_float(self._num_experts) * expert_capacity_f) +
        expert_index * expert_capacity_f)
    # Tensor of shape [batch * num_experts * expert_capacity].
    # Each element is an integer in [0, length]: 0 marks an empty (padding)
    # slot and k + 1 marks the request coming from length position k. The
    # +1 keeps position 0 distinguishable from padding.
    self._indices = tf.unsorted_segment_sum(
        data=tf.reshape((length_index + 1.0) * self._gates, [-1]),
        segment_ids=tf.to_int32(tf.reshape(flat_position, [-1])),
        num_segments=self._batch * self._num_experts * expert_capacity)
    self._indices = tf.reshape(
        self._indices,
        [self._batch, self._num_experts, expert_capacity])
    # Tensors of shape [batch, num_experts, expert_capacity].
    # each element is 0.0 (padding slot) or 1.0 (filled slot)
    self._nonpadding = tf.minimum(self._indices, 1.0)
    # Undo the +1 shift; padding slots (0) stay at 0 thanks to the relu.
    # each element is an integer in [0, length)
    self._indices = tf.nn.relu(self._indices - 1.0)
    # self._flat_indices is [batch, num_experts, expert_capacity], with values
    # in [0, batch * length): the length index offset by batch_index * length,
    # i.e. an index into the requests flattened to [batch * length, ...].
    self._flat_indices = tf.to_int32(
        self._indices +
        (tf.reshape(tf.to_float(tf.range(self._batch)), [-1, 1, 1])
         * tf.to_float(self._length)))
    self._indices = tf.to_int32(self._indices)
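# (Residue from the web page this code was scraped from — "Comment list" /
# "Table of contents" navigation labels; not part of the program.)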