transforms.py 文件源码

python
阅读 35 收藏 0 点赞 0 评论 0

项目:wtte-rnn 作者: ragulpr 项目源码 文件源码
def get_padded_seq_lengths(padded):
    """Returns the number of (seq_len) non-nan elements per sequence.

    :param padded: 2d or 3d tensor with dim 2 the time dimension
    """
    if len(padded.shape) == 2:
        # (n_seqs,n_timesteps)
        seq_lengths = np.count_nonzero(~np.isnan(padded), axis=1)
    elif len(padded.shape) == 3:
        # (n_seqs,n_timesteps,n_features,..)
        seq_lengths = np.count_nonzero(~np.isnan(padded[:, :, 0]), axis=1)
    else:
        print('not yet implemented')
        # TODO

    return seq_lengths
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号