def standard_split(cls, root):
    """
    Use standard train/dev/test/other splits 2-21/22/23/24, respectively.

    Scans the immediate subdirectories of ``root`` whose names parse as
    integer section numbers and assigns them to splits:

    * sections 2-21 -> train
    * section 22    -> dev
    * section 23    -> test
    * section 24    -> other

    Sections outside these ranges (e.g. 0 and 1) are not part of the
    standard split and are ignored, as are plain files and subdirectories
    whose names are not integers.

    :param root: directory containing one subdirectory per section.
    :returns: ``cls(train, dev, test, other)``; each argument is a sorted
        list of the ``path`` objects for the matching subdirectories.
    :raises ValueError: if the scan does not find exactly 20 train,
        1 dev and 1 test section directories.
    """
    train, dev, test, other = [], [], [], []
    for d in path(root).listdir():
        if not d.isdir():
            continue
        try:
            number = int(d.basename())
        except ValueError:
            # Not a numbered section directory (e.g. ".svn"); skip it
            # instead of aborting the whole scan. The count check below
            # still rejects a genuinely malformed layout.
            continue
        # Sections below 2 and above 24 are not part of the standard split.
        if 2 <= number <= 21:
            train.append(d)
        elif number == 22:
            dev.append(d)
        elif number == 23:
            test.append(d)
        elif number == 24:
            other.append(d)
    # listdir() order may be arbitrary; sort every split for determinism.
    for split in (train, dev, test, other):
        split.sort()
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not (len(train) == 20 and len(dev) == 1 and len(test) == 1):
        raise ValueError(
            "unexpected section layout under %r: found %d train, %d dev, "
            "%d test directories (expected 20/1/1)"
            % (root, len(train), len(dev), len(test))
        )
    return cls(train, dev, test, other)
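# Web-page residue captured with this snippet (not code):
# 评论列表 ("Comment list") / 文章目录 ("Table of contents")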