def test_process(self):
    """Check how RawField.process and Field.process handle different batch contents.

    - RawField.process returns its input unchanged, both for a list-of-int
      batch and for a list of arbitrary objects.
    - Field (sequential, no vocab, batch_first) turns a list-of-int batch into
      a tensor equal to ``torch.LongTensor(batch)``.
    - Field raises ``TypeError`` when given a batch of arbitrary objects.
    """
    raw_field = data.RawField()
    field = data.Field(sequential=True, use_vocab=False, batch_first=True)

    # Tensor-like batch data, which both RawField and Field accept.
    batch = [[1, 2, 3], [2, 3, 4]]
    batch_tensor = torch.LongTensor(batch)

    raw_field_processed = raw_field.process(batch)
    field_processed = field.process(batch, device=-1, train=False)

    assert raw_field_processed == batch
    assert field_processed.data.equal(batch_tensor)

    # Non-tensor data, which only RawField accepts.
    any_obj = [object() for _ in range(5)]

    raw_field_processed = raw_field.process(any_obj)
    assert any_obj == raw_field_processed

    # Pass the same device/train arguments as the call above. Where these
    # parameters are required, omitting them raises a TypeError for the
    # missing arguments, and the check below would pass without testing how
    # Field handles non-tensor data.
    with pytest.raises(TypeError):
        field.process(any_obj, device=-1, train=False)
评论列表
文章目录