ctc_cost.py source file

python

Project: KGP-ASR  Author: KGPML
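# Module-level names used below. They are not shown in this excerpt, so
# these imports are an assumption about the rest of ctc_cost.py.
import theano
from theano import tensor

floatX = theano.config.floatX


def _add_blanks(y, blank_symbol, y_mask=None):
    """Insert blanks around every label of y and extend the mask to match.

    Input shape: output_seq_len x num_batch
    Output shape: 2*output_seq_len+1 x num_batch
    """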
    # Labels: append a blank after every label (y1, blank, y2, blank, ...)
    y_extended = y.T.dimshuffle(0, 1, 'x')
    blanks = tensor.zeros_like(y_extended) + blank_symbol
    concat = tensor.concatenate([y_extended, blanks], axis=2)
    res = concat.reshape((concat.shape[0],
                          concat.shape[1] * concat.shape[2])).T
    # Prepend one row of blanks so every sequence also starts with a blank
    beginning_blanks = tensor.zeros((1, res.shape[1])) + blank_symbol
    blanked_y = tensor.concatenate([beginning_blanks, res], axis=0)
    # Mask: repeat each entry twice (one for the label, one for its blank)
    if y_mask is not None:
        y_mask_extended = y_mask.T.dimshuffle(0, 1, 'x')
        concat = tensor.concatenate([y_mask_extended,
                                     y_mask_extended], axis=2)
        res = concat.reshape((concat.shape[0],
                              concat.shape[1] * concat.shape[2])).T
        # The leading blank is always unmasked
        beginning_blanks = tensor.ones((1, res.shape[1]), dtype=floatX)
        blanked_y_mask = tensor.concatenate([beginning_blanks, res], axis=0)
    else:
        blanked_y_mask = None
    return blanked_y.astype('int32'), blanked_y_mask
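
The dimshuffle/concatenate/reshape steps above can be hard to follow from the symbolic code alone. The block below is a minimal NumPy sketch of the same interleaving, written only to show the resulting layout. The function name `add_blanks_numpy`, the example labels, and the blank symbol `9` are illustrative assumptions and are not part of KGP-ASR.

```python
import numpy as np


def add_blanks_numpy(y, blank_symbol, y_mask=None):
    """NumPy sketch of the layout produced by _add_blanks (hypothetical helper).

    y:      output_seq_len x num_batch integer labels
    y_mask: optional output_seq_len x num_batch mask (1 = valid, 0 = padding)
    """
    y = np.asarray(y)
    seq_len, num_batch = y.shape

    # Start from an all-blank matrix and place the labels on the odd rows,
    # giving (blank, y1, blank, y2, ..., yL, blank) for every batch column.
    blanked_y = np.full((2 * seq_len + 1, num_batch), blank_symbol,
                        dtype="int32")
    blanked_y[1::2] = y

    if y_mask is None:
        return blanked_y, None

    y_mask = np.asarray(y_mask, dtype="float32")
    # The leading blank is always valid; each label and the blank that
    # follows it share the mask value of that label.
    blanked_y_mask = np.ones((2 * seq_len + 1, num_batch), dtype="float32")
    blanked_y_mask[1::2] = y_mask
    blanked_y_mask[2::2] = y_mask
    return blanked_y, blanked_y_mask


# Two sequences of maximum length 3; the second one has only 2 valid labels.
y = [[1, 4],
     [2, 5],
     [3, 0]]
y_mask = [[1, 1],
          [1, 1],
          [1, 0]]
blanked, mask = add_blanks_numpy(y, blank_symbol=9, y_mask=y_mask)
print(blanked[:, 0].tolist())  # [9, 1, 9, 2, 9, 3, 9]
print(mask[:, 1].tolist())     # [1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0]
```

In this layout a padded label also masks out the blank that follows it, which matches the way the Theano code duplicates each mask entry before prepending a row of ones.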