nn_cell_lib.py source code

python

Project: deep_parsimonious Author: lrjconan
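def run(self, x, eta, idx_center=None, idx_sample=None):
    # Builds the forward pass with the TensorFlow 1.x graph API.
    # idx_center and idx_sample are indexed per layer below, so they must be
    # provided whenever any layer has clustering enabled.
    h = [None] * self.num_layer  # output of each layer
    embeddings = []
    reg_ops = []
    reset_ops = []
    clustering_ops = []

    with tf.variable_scope(self.scope):
        for ii in range(self.num_layer):
            with tf.variable_scope('layer_{}'.format(ii)):
                # The first layer reads x; later layers read the previous output.
                if ii == 0:
                    input_vec = x
                else:
                    input_vec = h[ii - 1]

                h[ii] = tf.matmul(input_vec, self.w[ii])

                if self.add_bias:
                    h[ii] += self.b[ii]

                if self.clustering_shape[ii] is not None:
                    # The pre-activation output is the embedding that gets clustered.
                    embedding = h[ii]
                    embeddings += [embedding]

                    clustering_ops += [kmeans_clustering(
                        embedding, self.cluster_center[ii],
                        self.cluster_label[ii], self.num_cluster[ii], eta)]

                    # Center assigned to each sample; no gradient flows into the centers.
                    sample_center = tf.stop_gradient(
                        tf.gather(self.cluster_center[ii], self.cluster_label[ii]))
                    # Regularizer: alpha / 2 times the mean squared distance to that center.
                    reg_ops += [tf.reduce_mean(
                        tf.square(embedding - sample_center)) * self.alpha[ii] / 2.0]

                    # Overwrite the selected centers with the selected sample embeddings.
                    reset_ops += [tf.scatter_update(
                        self.cluster_center[ii], idx_center[ii],
                        tf.gather(h[ii], idx_sample[ii]))]

                if self.act_func and self.act_func[ii] is not None:
                    h[ii] = self.act_func[ii](h[ii])

    return h, embeddings, clustering_ops, reg_ops, reset_ops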
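
The regularizer and the center reset in `run` can be checked by hand on a tiny example. The following is a minimal sketch in plain Python, not the project's code: the helper names `cluster_regularizer` and `reset_centers` are hypothetical, nested lists stand in for tensors, and the behavior of `kmeans_clustering` is not modeled because its body is not shown here.

```python
def cluster_regularizer(embedding, centers, labels, alpha):
    """alpha / 2 times the mean squared distance to the assigned center.

    Mirrors tf.gather + tf.square + tf.reduce_mean in run(); the mean is
    taken over every element, not per row. (Hypothetical helper.)
    """
    sample_center = [centers[k] for k in labels]  # role of tf.gather
    squared = [
        (e - c) ** 2
        for row, center_row in zip(embedding, sample_center)
        for e, c in zip(row, center_row)
    ]
    return sum(squared) / len(squared) * alpha / 2.0


def reset_centers(centers, idx_center, embedding, idx_sample):
    """Return a copy of centers with rows idx_center replaced by embedding rows idx_sample.

    Plays the role of tf.scatter_update(centers, idx_center,
    tf.gather(embedding, idx_sample)), but without mutating the input.
    (Hypothetical helper.)
    """
    new_centers = [list(row) for row in centers]
    for c, s in zip(idx_center, idx_sample):
        new_centers[c] = list(embedding[s])
    return new_centers


# Three 2-D embeddings, two cluster centers, one label per embedding.
embedding = [[0.0, 0.0], [2.0, 2.0], [4.0, 0.0]]
centers = [[1.0, 1.0], [4.0, 1.0]]
labels = [0, 0, 1]

reg = cluster_regularizer(embedding, centers, labels, alpha=0.5)
new_centers = reset_centers(centers, [1], embedding, [0])
```

Here the squared differences sum to 5 over 6 elements, so `reg` is 5/6 * 0.5 / 2 = 5/24. The reset replaces center 1 with embedding row 0. Unlike `tf.scatter_update`, which modifies the variable in place, this sketch returns a new list.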