ops.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:DaNet-Tensorflow 作者: khaotik 项目源码 文件源码
def batch_segment_mean(s_data, s_indices, n):
    s_data_shp = tf.shape(s_data)
    s_data_flat = tf.reshape(
        s_data, [tf.prod(s_data_shp[:-1]), s_data_shp[-1]])
    s_indices_flat = tf.reshape(s_indices, [-1])
    s_results = tf.unsorted_segment_sum(s_data_flat, s_indices_flat, n)
    s_weights = tf.unsorted_segment_sum(
        tf.ones_like(s_indices_flat),
        s_indices_flat, n)
    return s_results / tf.cast(tf.expand_dims(s_weights, -1), hparams.FLOATX)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号