test_field.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:text 作者: pytorch 项目源码 文件源码
def test_process(self):
        """Exercise ``process`` on both ``RawField`` and ``Field``.

        Two scenarios are covered:

        * A nested list of integers is passed to both field types; the
          ``RawField`` result must compare equal to the input list and the
          ``Field`` result's ``.data`` must equal the matching ``LongTensor``.
        * A list of bare ``object()`` instances is passed to both; the
          ``RawField`` result must compare equal to the input, while
          ``Field.process`` is expected to raise ``TypeError``.
        """
        passthrough_field = data.RawField()
        numeric_field = data.Field(sequential=True, use_vocab=False, batch_first=True)

        # Scenario 1: tensor-like input that both field types accept.
        nested_ints = [[1, 2, 3], [2, 3, 4]]
        expected_tensor = torch.LongTensor(nested_ints)

        passthrough_result = passthrough_field.process(nested_ints)
        numeric_result = numeric_field.process(nested_ints, device=-1, train=False)

        assert passthrough_result == nested_ints
        assert numeric_result.data.equal(expected_tensor)

        # Scenario 2: arbitrary objects, which only RawField is expected to accept.
        opaque_items = [object() for _ in range(5)]

        passthrough_result = passthrough_field.process(opaque_items)
        assert opaque_items == passthrough_result

        # Field cannot turn arbitrary objects into a tensor, so it must fail.
        with pytest.raises(TypeError):
            numeric_field.process(opaque_items)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号