def batch_segment_mean(s_data, s_indices, n):
    """Average the rows of ``s_data`` that share a segment id.

    Every axis of ``s_data`` except the last is flattened into a single
    "rows" axis, giving a ``[rows, features]`` matrix. ``s_indices`` is
    flattened into a vector holding one segment id per row. Rows with the
    same id are summed and divided by how many rows carry that id.

    Args:
        s_data: Floating-point tensor of rank >= 1. The last axis is kept
            as the feature axis; all other axes are treated as rows.
        s_indices: Integer tensor with exactly one segment id per row of
            ``s_data`` (``prod(shape(s_data)[:-1])`` elements in total).
            Each id is expected to lie in ``[0, n)``.
        n: Number of segments in the output.

    Returns:
        Tensor of shape ``[n, shape(s_data)[-1]]`` in ``s_data``'s dtype.
        Row ``i`` is the mean of all rows whose id is ``i``, or zeros when
        no row maps to segment ``i``.
    """
    s_data_shp = tf.shape(s_data)
    # Collapse every leading axis into one "rows" axis. Using reduce_prod
    # (instead of reshaping to [-1, d]) also works when d == 0.
    s_data_flat = tf.reshape(
        s_data, [tf.reduce_prod(s_data_shp[:-1]), s_data_shp[-1]])
    s_indices_flat = tf.reshape(s_indices, [-1])
    s_results = tf.unsorted_segment_sum(s_data_flat, s_indices_flat, n)
    # Number of rows that fall into each segment.
    s_weights = tf.unsorted_segment_sum(
        tf.ones_like(s_indices_flat), s_indices_flat, n)
    # An empty segment would otherwise compute 0 / 0 = NaN. Clamping its
    # count to 1 makes its mean 0, because its summed row is all zeros.
    s_weights = tf.maximum(s_weights, 1)
    # Cast to the data's own dtype. TensorFlow does not promote dtypes in
    # division, so casting to a fixed float type fails for other inputs.
    s_counts = tf.cast(tf.expand_dims(s_weights, -1), s_results.dtype)
    return s_results / s_counts
评论列表
文章目录