def _shard_indices(self, keys):
    """Compute the shard index for each key.

    Args:
      keys: A tensor of keys. If its rank is greater than 1 (i.e. a single
        key is a vector), only the first element of each row, `keys[:, 0]`,
        is used to pick the shard.

    Returns:
      An int32 tensor holding `abs(key) mod self._num_shards` for each key.
      For matrix keys this is one entry per row, flattened to a vector.
    """
    key_shape = keys.get_shape()
    # NOTE(review): `ndims` is None for keys of unknown rank, which this
    # comparison does not handle, and the slice below only supports rank-2
    # keys (its begin/size lists have length 2). Confirm callers never pass
    # such keys.
    if key_shape.ndims > 1:
        # If keys are a matrix (i.e. a single key is a vector), we use the
        # first element of each key vector to determine the shard. A slice
        # size of -1 takes all rows, so this also works when the number of
        # rows is not known statically (where the static dimension is None).
        keys = array_ops.slice(keys, [0, 0], [-1, 1])
        keys = array_ops.reshape(keys, [-1])
    indices = math_ops.mod(math_ops.abs(keys), self._num_shards)
    return math_ops.cast(indices, dtypes.int32)
评论列表
文章目录