ops.py source code (Python)

Project: ethnicity-tensorflow  Author: jhyuklee
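```python
import tensorflow as tf


def mask_by_index(batch_size, input_len, max_time_step):
    """Flat [batch_size * max_time_step] float weights: 1.0 for steps within each
    sequence's length (input_len: int32, shape [batch_size]), 0.0 for padding."""
    with tf.name_scope('Masking'):
        # Last valid step of each row. It is compared with per-row positions, so it
        # must not carry a flat-index offset (tf.range(0, batch_size) * max_time_step),
        # which would make every row after the first all ones.
        last_index = input_len - 1
        lengths_transposed = tf.expand_dims(last_index, 1)
        lengths_tiled = tf.tile(lengths_transposed, [1, max_time_step])
        # Step positions 0 .. max_time_step - 1, repeated for every row.
        mask_range = tf.range(0, max_time_step)
        range_row = tf.expand_dims(mask_range, 0)
        range_tiled = tf.tile(range_row, [batch_size, 1])
        mask = tf.less_equal(range_tiled, lengths_tiled)
        # tf.where replaces tf.select, which was removed in TensorFlow 1.0.
        weight = tf.where(mask, tf.ones([batch_size, max_time_step]),
                          tf.zeros([batch_size, max_time_step]))
        # Flatten row-major: entry i * max_time_step + t holds row i, step t.
        weight = tf.reshape(weight, [-1])
        return weight
```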
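To make the flattened layout concrete, the same masking logic can be sketched in plain Python with no TensorFlow dependency. The helper names `flat_length_mask` and `last_step_flat_index` are hypothetical and exist only for illustration. The second helper reproduces the expression `tf.range(0, batch_size) * max_time_step + (input_len - 1)`: it gives the flat position of each sequence's last valid step, which is the last 1.0 in each row of the flattened mask.

```python
from typing import List


def flat_length_mask(input_len: List[int], max_time_step: int) -> List[float]:
    """Row-major [batch_size * max_time_step] weights: 1.0 where step < length, else 0.0."""
    weights: List[float] = []
    for length in input_len:
        for step in range(max_time_step):
            # Same test as tf.less_equal(step, length - 1) in mask_by_index.
            weights.append(1.0 if step <= length - 1 else 0.0)
    return weights


def last_step_flat_index(input_len: List[int], max_time_step: int) -> List[int]:
    """Flat position of each sequence's last valid step: i * max_time_step + (len_i - 1)."""
    return [i * max_time_step + (length - 1) for i, length in enumerate(input_len)]


if __name__ == "__main__":
    lengths = [2, 4, 1]   # hypothetical batch of 3 sequences
    max_steps = 4
    print(flat_length_mask(lengths, max_steps))
    # [1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0]
    print(last_step_flat_index(lengths, max_steps))
    # [1, 7, 8]
```

In this example, flat indices 1, 7 and 8 land on the last 1.0 of rows 0, 1 and 2. Offsets of that form are suited to indexing a tensor reshaped to one row per time step (for instance with `tf.gather`), not to building the per-row mask. For the mask itself, `tf.sequence_mask(input_len, max_time_step, dtype=tf.float32)` produces the same `[batch_size, max_time_step]` matrix directly.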