timit_for_srnn.py 文件源码

python
阅读 34 收藏 0 点赞 0 评论 0

项目:srnn 作者: marcofraccaro 项目源码 文件源码
def create_test_set(x_lst):
    n = len(x_lst)
    x_lens = np.array(map(len, x_lst))
    max_len = max(map(len, x_lst)) - 1
    u_out = np.zeros((n, max_len, OUTDIM), dtype='float32')*np.nan
    x_out = np.zeros((n, max_len, OUTDIM), dtype='float32')*np.nan
    for row, vec in enumerate(x_lst):
        l = len(vec) - 1
        u = vec[:-1]  # all but last element
        x = vec[1:]   # all but first element

        x_out[row, :l] = x
        u_out[row, :l] = u

    mask = np.invert(np.isnan(x_out))
    x_out[np.isnan(x_out)] = 0
    u_out[np.isnan(u_out)] = 0
    mask = mask[:, :, 0]
    assert np.all((mask.sum(axis=1)+1) == x_lens)
    return u_out, x_out, mask.astype('float32')
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号