def _full_batch_training_op(self, inputs, cluster_idx_list, cluster_centers):
    """Build the op that performs one full-batch k-means center update.

    Two stages:
      1. For every input shard, next to that shard, compute the per-cluster
         sum of its rows and the number of its rows assigned to each cluster.
      2. Next to ``cluster_centers``, add up the partial results and divide
         sums by (counts + 1e-6) to get the new centers.

    Because of the epsilon, a cluster with no assigned rows gets a zero
    center. If the model keeps its clusters L2-normalized, the new centers
    are row-normalized before being written.

    Args:
      inputs: list of input Tensors (shards).
      cluster_idx_list: list of cluster-id vectors, one per entry of
        ``inputs``. Element ``i`` of a vector gives the cluster id of row
        ``i`` of the matching input Tensor. ``zip`` pairs the two lists, so
        any surplus entries in the longer list are ignored.
      cluster_centers: Tensor Ref of cluster centers.

    Returns:
      The ``tf.assign`` op that writes the recomputed centers into
      ``cluster_centers``.
    """
    # Created before any colocation scope, and in the inputs' dtype, so it can
    # be added to the (cast) counts below.
    eps = tf.constant(1e-6, dtype=inputs[0].dtype)
    partial_sums, partial_counts = [], []
    for shard, shard_idx in zip(inputs, cluster_idx_list):
        # Reduce each shard where it lives; only the small per-cluster
        # results need to move.
        with ops.colocate_with(shard):
            partial_sums.append(
                tf.unsorted_segment_sum(shard, shard_idx, self._num_clusters))
            # A column of ones (float32, tf.ones' default dtype) with one row
            # per input row; segment-summing it yields per-cluster row counts.
            num_rows = tf.reshape(tf.shape(shard)[0], [-1])
            ones_column = tf.reshape(tf.ones(num_rows), [-1, 1])
            partial_counts.append(
                tf.unsorted_segment_sum(ones_column, shard_idx,
                                        self._num_clusters))
    with ops.colocate_with(cluster_centers):
        total_sum = tf.add_n(partial_sums)
        # Counts are float32; bring them to the sums' dtype before dividing.
        total_count = tf.cast(tf.add_n(partial_counts), partial_sums[0].dtype)
        updated = total_sum / (total_count + eps)
        if self._clusters_l2_normalized():
            updated = tf.nn.l2_normalize(updated, dim=1)
    return tf.assign(cluster_centers, updated)
评论列表
文章目录