keras_utils.py source code

python

Project: KATE    Author: hugochan
def get_config(self):
        config = {'topk': self.topk, 'ctype': self.ctype}
        base_config = super(KCompetitive, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
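
    # get_config merges this layer's own hyper-parameters (topk, ctype) with the
    # base Layer config, which is what lets a saved model rebuild the layer via
    # from_config (the KCompetitive class still has to be supplied through
    # custom_objects when loading).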

    # def k_comp_sigm(self, x, topk):
    #     print('run k_comp_sigm')
    #     dim = int(x.get_shape()[1])
    #     if topk > dim:
    #         warnings.warn('topk should not be larger than dim: %s, found: %s, using %s' % (dim, topk, dim))
    #         topk = dim

    #     values, indices = tf.nn.top_k(x, topk) # indices will be [[0, 1], [2, 1]], values will be [[6., 2.], [5., 4.]]

    #     # We need to create full indices like [[0, 0], [0, 1], [1, 2], [1, 1]]
    #     my_range = tf.expand_dims(tf.range(0, K.shape(indices)[0]), 1)  # will be [[0], [1]]
    #     my_range_repeated = tf.tile(my_range, [1, topk])  # will be [[0, 0], [1, 1]]

    #     full_indices = tf.stack([my_range_repeated, indices], axis=2) # change shapes to [N, k, 1] and [N, k, 1], to concatenate into [N, k, 2]
    #     full_indices = tf.reshape(full_indices, [-1, 2])

    #     to_reset = tf.sparse_to_dense(full_indices, tf.shape(x), tf.reshape(values, [-1]), default_value=0., validate_indices=False)

    #     batch_size = tf.to_float(tf.shape(x)[0])
    #     tmp = 1 * batch_size * tf.reduce_sum(x - to_reset, 1, keep_dims=True) / topk

    #     res = tf.sparse_to_dense(full_indices, tf.shape(x), tf.reshape(tf.add(values, tmp), [-1]), default_value=0., validate_indices=False)

    #     return res
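
The commented-out k_comp_sigm above leans on APIs that have since been deprecated (tf.sparse_to_dense, keep_dims, tf.to_float). Below is a minimal sketch of the same top-k competition step rewritten with tf.scatter_nd; the function name k_competitive_sigmoid and the assumption that x has a statically known feature dimension are mine, not part of the KATE source.

import tensorflow as tf

def k_competitive_sigmoid(x, topk):
    """Sketch of the k-competitive step: keep the top-k activations per row,
    zero the rest, and add the (batch-scaled) energy of the zeroed units
    back onto the winners."""
    dim = int(x.shape[1])          # assumes a static feature dimension
    if topk > dim:
        topk = dim                 # the original warns and clamps the same way

    # Winning values and their column indices within each row, both [batch, topk].
    values, indices = tf.nn.top_k(x, topk)

    # Build [row, col] pairs so the winners can be scattered back into x's shape.
    rows = tf.tile(tf.expand_dims(tf.range(tf.shape(x)[0]), 1), [1, topk])
    full_indices = tf.reshape(tf.stack([rows, indices], axis=2), [-1, 2])

    # Dense tensor that keeps only the winning activations, zeros elsewhere.
    winners = tf.scatter_nd(full_indices, tf.reshape(values, [-1]), tf.shape(x))

    # Energy of the losing units, scaled by batch size and split across the k winners.
    batch_size = tf.cast(tf.shape(x)[0], x.dtype)
    boost = batch_size * tf.reduce_sum(x - winners, axis=1, keepdims=True) / topk

    # Winners keep their value plus the reallocated energy; losers stay at zero.
    return tf.scatter_nd(full_indices,
                         tf.reshape(values + boost, [-1]),
                         tf.shape(x))

As in the original, the energy removed from the losing units is scaled by the batch size and added back onto the k winners, so the step amplifies the winning activations rather than acting as a plain top-k mask.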