test_preprocessing.py 文件源码

python
阅读 36 收藏 0 点赞 0 评论 0

项目:nnmnkwii 作者: r9y9 项目源码 文件源码
def test_adjast_frame_length_divisible():
    """Check ``adjast_frame_length`` when only ``divisible_by`` is constrained.

    A random ``(T, D) = (10, 5)`` array is adjusted with ``divisible_by`` in
    1..4, both with and without padding, and the resulting number of frames
    (first axis) is compared against the expected value. The expected values
    are the nearest multiple of ``divisible_by`` at or above ``T`` when
    ``pad=True`` and at or below ``T`` when ``pad=False``.

    The test also verifies that the dtype of the input is preserved.
    """
    D = 5
    T = 10

    x = np.random.rand(T, D)

    # (pad, divisible_by, expected number of frames)
    cases = [
        (True, 1, T),
        (True, 2, T),
        (True, 3, T + 2),
        (True, 4, T + 2),
        (False, 1, T),
        (False, 2, T),
        (False, 3, T - 1),
        (False, 4, T - 2),
    ]
    for pad, divisible_by, expected in cases:
        y = adjast_frame_length(x, pad=pad, divisible_by=divisible_by)
        # The message identifies which case failed without needing a print.
        assert expected == y.shape[0], (pad, divisible_by, y.shape)

    # Should preserve dtype
    for dtype in (np.float32, np.float64):
        x = np.random.rand(T, D).astype(dtype)
        for pad in (True, False):
            y = adjast_frame_length(x, pad=pad, divisible_by=3)
            assert x.dtype == y.dtype, (dtype, pad, y.dtype)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号