def test_torch(self):
    """Check that SharedTable holds torch tensors and supports len/in/del.

    The same sequence of checks runs on CPU tensors and, when CUDA is
    available, on CUDA tensors. If torch is not installed the test returns
    early and therefore counts as passed.
    """
    try:
        import torch
    except ImportError:
        # pass by default if no torch available
        return

    # (float constructor, long constructor) pairs to exercise; CUDA types are
    # only looked up when a CUDA device is actually available.
    tensor_types = [(torch.FloatTensor, torch.LongTensor)]
    if torch.cuda.is_available():
        tensor_types.append((torch.cuda.FloatTensor, torch.cuda.LongTensor))

    for float_tensor, long_tensor in tensor_types:
        # With the legacy constructors, an int argument is a size:
        # `long_tensor(2)` is a 2-element tensor with unspecified contents.
        # Only its presence in the table is checked below.
        st = SharedTable({'a': float_tensor([1]), 'b': long_tensor(2)})
        assert st['a'][0] == 1.0
        assert len(st) == 2
        assert 'b' in st
        del st['b']
        assert 'b' not in st
        assert len(st) == 1
评论列表
文章目录