models.py file source

python

Project: sdp    Author: tansey
def build(self, input_layer):
        if self._one_hot:
            split_indices = tf.to_int32(tf.argmax(self._labels, 1))
        else:
            split_indices = tf.to_int32(tf.reduce_sum(
                [self._labels[:, i] * int(np.prod(self._num_classes[i+1:]))
                 for i in xrange(len(self._num_classes))], 0))
        self.splits, self.masks = tf.gather(self._split_labels, split_indices), tf.gather(self._split_masks, split_indices)

        # q is the value of the tree nodes
        # m is the value of the multinomial bins
        # z is the log-space version of m
        self._q = tf.reciprocal(1 + tf.exp(-(tf.matmul(input_layer,self.W) + self.b)))
        r = self.splits * tf.log(tf.clip_by_value(self._q, 1e-10, 1.0))
        s = (1 - self.splits) * tf.log(tf.clip_by_value(1 - self._q, 1e-10, 1.0))
        self._multiscale_loss = tf.reduce_mean(-tf.reduce_sum(self.masks * (r+s),
                                        axis=[1]))

        # Convert from multiscale output to multinomial output
        L, R = self.multiscale_splits_masks()
        q_tiles = tf.constant([1, np.prod(self._num_classes)])

        m = tf.map_fn(lambda q_i: self.multiscale_to_multinomial(q_i, L, R, q_tiles), self._q)

        z = tf.log(tf.clip_by_value(m, 1e-10, 1.))

        # Get the trend filtering penalty
        fv = trend_filtering_penalty(z, self._num_classes, self._k, penalty=self._penalty)
        reg = tf.multiply(self._lam, fv)

        self._loss_function = tf.add(self._multiscale_loss, reg)

        # Reshape to the original dimensions of the density
        density_shape = tf.stack([tf.shape(self._q)[0]] + list(self._num_classes))
        self._density = tf.reshape(m, density_shape)
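
The snippet above relies on two ideas that are easy to miss in the TensorFlow graph code: turning the tree of Bernoulli split probabilities q into multinomial bin probabilities m, and regularizing the log-density z with a trend filtering penalty. The sketch below is a minimal NumPy illustration of both, written from scratch rather than taken from the repo: it assumes a balanced binary tree over a power-of-two number of bins with a left-branch-takes-q convention, and a 1-D l1 penalty on k-th order differences, whereas the actual code gathers precomputed split/mask matrices and calls the repo's trend_filtering_penalty over multi-dimensional grids. The function names here are illustrative only.

python
import numpy as np

def multiscale_to_multinomial(q, num_bins):
    """Turn one Bernoulli probability per internal tree node into bin
    probabilities by multiplying the branch probabilities along each
    root-to-leaf path. Assumes a balanced binary tree over a power-of-two
    number of bins, stored in breadth-first (heap) order, with q[node]
    read as the probability of the LEFT branch (an assumed convention)."""
    assert len(q) == num_bins - 1
    probs = np.zeros(num_bins)
    for leaf in range(num_bins):
        p, node, lo, hi = 1.0, 0, 0, num_bins
        while hi - lo > 1:                    # walk from the root to this leaf
            mid = (lo + hi) // 2
            if leaf < mid:                    # left branch contributes q
                p, hi, node = p * q[node], mid, 2 * node + 1
            else:                             # right branch contributes 1 - q
                p, lo, node = p * (1 - q[node]), mid, 2 * node + 2
        probs[leaf] = p
    return probs

def trend_filtering_penalty_1d(z, k=1):
    """Illustrative 1-D l1 trend filtering penalty: the sum of absolute
    k-th order discrete differences of the log-density z, which favors
    piecewise-smooth densities. Order and penalty conventions are
    assumptions; the repo's trend_filtering_penalty also handles
    multi-dimensional grids."""
    return np.abs(np.diff(z, n=k)).sum()

q = np.array([0.5, 0.25, 0.75])               # root, left child, right child
m = multiscale_to_multinomial(q, 4)
print(m, m.sum())                             # [0.125 0.375 0.375 0.125] 1.0
print(trend_filtering_penalty_1d(np.log(m), k=2))

By construction the bin probabilities sum to one, which is why the graph code can treat m as a multinomial density and only needs the clipping in tf.clip_by_value to keep the logs finite.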