def _compute_weights(self, labels):
    """Return per-example weights equal to the column-wise batch mean of labels.

    Args:
        labels: Rank-2 tensor (the ``tf.tile`` multiples below have length 2,
            which requires rank 2). Presumably ``(batch, num_classes)``
            one-/multi-hot labels -- TODO confirm with callers. It is cast to
            float32 before use.

    Returns:
        A float32 tensor with the same shape as ``labels`` in which every row
        equals ``sum(labels, axis=0) / batch_size``. An empty batch yields an
        empty ``(0, C)`` result, so the intermediate 0/0 never surfaces.

    ``self`` is not used.
    """
    log.debug('Computing weights from batch labels')
    labels = tf.cast(labels, dtype=tf.float32)
    batch_size = tf.shape(labels)[0]
    # Reduce without keep_dims and re-add the leading axis with expand_dims.
    # The keep_dims keyword is deprecated since TF 1.5 and rejected by TF 2's
    # tf.reduce_sum, while keepdims does not exist before TF 1.5. This form
    # works on both.
    class_means = tf.divide(tf.reduce_sum(labels, axis=0),
                            tf.cast(batch_size, dtype=tf.float32))
    # Repeat the (1, C) mean row once per example in the batch.
    return tf.tile(tf.expand_dims(class_means, 0), [batch_size, 1])
评论列表
文章目录