nn_cell_lib.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:deep_parsimonious 作者: lrjconan 项目源码 文件源码
def run(self, x, eta, idx_center=None, idx_sample=None):
        """Build the forward graph of the conv stack plus clustering ops.

        For every layer ``ii`` this applies ``tf.nn.conv2d`` with weights
        ``self.w[ii]`` and stride ``self.conv_filters['filter_stride'][ii]``
        (SAME padding), optionally adds ``self.b[ii]``, then applies
        ``self.act_func[ii]`` and ``self.pool_func[ii]`` when they are not
        None. For layers whose ``self.clustering_shape[ii]`` is not None, the
        pre-activation output is reshaped into a 2-D embedding matrix and the
        corresponding clustering, regularization and reset ops are built.

        Args:
            x: input tensor; must be of size [B H W C].
            eta: forwarded unchanged to ``kmeans_clustering`` for every
                clustered layer (presumably a step size -- confirm against
                ``kmeans_clustering``).
            idx_center: optional per-layer sequence; ``idx_center[ii]`` are
                the rows of ``self.cluster_center[ii]`` overwritten by the
                reset op of layer ``ii``.
            idx_sample: optional per-layer sequence; ``idx_sample[ii]`` are
                the embedding rows copied into those centers.
                Reset ops are only built when both ``idx_center`` and
                ``idx_sample`` are given.

        Returns:
            Tuple ``(h, embeddings, clustering_ops, reg_ops, reset_ops)``.
            ``h`` holds the per-layer outputs after activation and pooling.
            ``embeddings``, ``clustering_ops`` and ``reg_ops`` hold one entry
            per clustered layer. ``reset_ops`` holds one entry per clustered
            layer when both index arguments are given, and is empty otherwise.

        Raises:
            ValueError: if a clustered layer has a ``clustering_type`` other
                than 'sample', 'spatial' or 'channel'.
        """
        h = [None] * self.num_layer
        embeddings = []
        reg_ops = []
        reset_ops = []
        clustering_ops = []
        # The reset ops index idx_center[ii] / idx_sample[ii]; with the
        # default None values that raised TypeError, so only build them when
        # both index sequences are actually supplied.
        build_reset = idx_center is not None and idx_sample is not None

        with tf.variable_scope(self.scope):
            for ii in range(self.num_layer):
                input_vec = x if ii == 0 else h[ii - 1]

                h[ii] = tf.nn.conv2d(
                    input_vec, self.w[ii],
                    self.conv_filters['filter_stride'][ii], padding='SAME')

                if self.add_bias:
                    h[ii] += self.b[ii]

                if self.clustering_shape[ii] is not None:
                    # The embedding is taken before activation and pooling.
                    # It is computed here so that an unknown clustering type
                    # can no longer silently reuse the previous layer's
                    # embedding.
                    if self.clustering_type[ii] in ('sample', 'spatial'):
                        embedding = h[ii]
                    elif self.clustering_type[ii] == 'channel':
                        # Reorder axes from [B H W C] to [B C H W] before
                        # flattening.
                        embedding = tf.transpose(h[ii], [0, 3, 1, 2])
                    else:
                        raise ValueError(
                            'unknown clustering_type %r for layer %d'
                            % (self.clustering_type[ii], ii))

                    embedding = tf.reshape(
                        embedding, [-1, self.clustering_shape[ii][1]])

                    embeddings += [embedding]
                    clustering_ops += [kmeans_clustering(
                        embedding, self.cluster_center[ii],
                        self.cluster_label[ii], self.num_cluster[ii], eta)]

                    # Centers are treated as constants for the regularizer,
                    # so gradients only flow into the embedding.
                    sample_center = tf.stop_gradient(
                        tf.gather(self.cluster_center[ii],
                                  self.cluster_label[ii]))
                    reg_ops += [tf.reduce_mean(
                        tf.square(embedding - sample_center))
                        * self.alpha[ii] / 2.0]

                    if build_reset:
                        # Overwrite the selected centers with the selected
                        # embedding rows.
                        reset_ops += [tf.scatter_update(
                            self.cluster_center[ii], idx_center[ii],
                            tf.gather(embedding, idx_sample[ii]))]

                if self.act_func[ii] is not None:
                    h[ii] = self.act_func[ii](h[ii])

                if self.pool_func[ii] is not None:
                    h[ii] = self.pool_func[ii](
                        h[ii], ksize=self.pooling['pool_size'][ii],
                        strides=self.pooling['pool_stride'][ii],
                        padding='SAME')

        return h, embeddings, clustering_ops, reg_ops, reset_ops
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号