def run(self, x, eta, idx_center=None, idx_sample=None):
    """Build the forward graph and the per-layer clustering ops.

    For every layer ``ii`` this applies ``tf.nn.conv2d`` (plus ``self.b[ii]``
    when ``self.add_bias``), derives a clustering embedding from that
    pre-activation output, and builds a ``kmeans_clustering`` op, a
    regularization term and (optionally) a cluster-center reset op for it.
    The activation and pooling functions are applied afterwards, so the
    embedding never sees them.

    Args:
        x: input tensor; x must be of size [B H W C].
        eta: passed unchanged to ``kmeans_clustering`` for every layer.
        idx_center: optional per-layer indices into
            ``self.cluster_center[ii]`` that the reset op overwrites.
            Must be given together with ``idx_sample``.
        idx_sample: optional per-layer row indices into the layer's
            embedding whose values are written into those centers.
            Must be given together with ``idx_center``.

    Returns:
        A tuple ``(h, embeddings, clustering_ops, reg_ops, reset_ops)``:
        ``h`` holds each layer's output after activation/pooling,
        ``embeddings`` the tensors that were clustered, ``clustering_ops``
        the ``kmeans_clustering`` results, ``reg_ops`` the scalar penalties
        ``mean((embedding - center) ** 2) * alpha[ii] / 2``, and
        ``reset_ops`` one ``tf.scatter_update`` op per layer. ``reset_ops``
        is empty when ``idx_center`` and ``idx_sample`` are both None.

    Raises:
        ValueError: if a layer's ``clustering_type`` is not one of
            'sample', 'spatial' or 'channel', or if only one of
            ``idx_center`` / ``idx_sample`` is given.
    """
    # Fail fast, before any graph ops are created. Otherwise an unknown type
    # leaves `embedding` unbound on the first layer, or silently reuses the
    # previous layer's tensor on later ones.
    valid_types = ('sample', 'spatial', 'channel')
    for ii in range(self.num_layer):
        if self.clustering_type[ii] not in valid_types:
            raise ValueError(
                'clustering_type[{}] must be one of {}, got {!r}'.format(
                    ii, valid_types, self.clustering_type[ii]))
    # The reset op indexes both arguments per layer, so they only make sense
    # together; with the None defaults no reset ops are built.
    if (idx_center is None) != (idx_sample is None):
        raise ValueError('idx_center and idx_sample must be given together')
    build_reset_ops = idx_center is not None

    h = [None] * self.num_layer
    embeddings = []
    reg_ops = []
    reset_ops = []
    clustering_ops = []
    with tf.variable_scope(self.scope):
        for ii in range(self.num_layer):
            input_vec = x if ii == 0 else h[ii - 1]
            h[ii] = tf.nn.conv2d(
                input_vec, self.w[ii],
                self.conv_filters['filter_stride'][ii], padding='SAME')
            if self.add_bias:
                h[ii] += self.b[ii]

            # The embedding is taken from the pre-activation, pre-pooling
            # output of this layer.
            if self.clustering_type[ii] == 'channel':
                # Moves axis 3 (C in the [B H W C] layout) to axis 1.
                embedding = tf.transpose(h[ii], [0, 3, 1, 2])
            else:  # 'sample' or 'spatial': use the layer output as-is.
                embedding = h[ii]
            if self.clustering_shape[ii] is not None:
                # Only the second entry of clustering_shape[ii] is used: it
                # fixes the row width; the row count is inferred.
                embedding = tf.reshape(
                    embedding, [-1, self.clustering_shape[ii][1]])
            embeddings.append(embedding)

            clustering_ops.append(kmeans_clustering(
                embedding, self.cluster_center[ii], self.cluster_label[ii],
                self.num_cluster[ii], eta))
            # Centers are treated as constants in the penalty, so its
            # gradient flows only into the embedding.
            sample_center = tf.stop_gradient(
                tf.gather(self.cluster_center[ii], self.cluster_label[ii]))
            reg_ops.append(
                tf.reduce_mean(tf.square(embedding - sample_center)) *
                self.alpha[ii] / 2.0)
            if build_reset_ops:
                # Overwrite the selected centers with the selected rows of
                # this layer's embedding.
                reset_ops.append(tf.scatter_update(
                    self.cluster_center[ii], idx_center[ii],
                    tf.gather(embedding, idx_sample[ii])))

            if self.act_func[ii] is not None:
                h[ii] = self.act_func[ii](h[ii])
            if self.pool_func[ii] is not None:
                h[ii] = self.pool_func[ii](
                    h[ii], ksize=self.pooling['pool_size'][ii],
                    strides=self.pooling['pool_stride'][ii], padding='SAME')
    return h, embeddings, clustering_ops, reg_ops, reset_ops
评论列表
文章目录