base.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:tefla 作者: openAGI 项目源码 文件源码
def _compute_weights(self, labels):
        log.debug('Computing weights from batch labels')
        labels = tf.cast(labels, dtype=tf.float32)
        lshape = tf.cast(tf.shape(labels), dtype=tf.float32)
        weights = tf.divide(tf.reduce_sum(
            labels, axis=0, keep_dims=True), lshape[0])
        return tf.tile(weights, [tf.shape(labels)[0], 1])
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号