def package(data, volatile=False):
    """Package data for training / evaluation.

    Each element of ``data`` is parsed with ``json.loads`` and must yield a
    mapping with a ``'text'`` entry (an iterable of tokens) and a ``'label'``
    entry. Tokens are converted to indices through the module-level
    ``dictionary.word2idx``. Every sequence is then truncated or
    right-padded with the ``'<pad>'`` index so that all sequences share the
    same length: the longest sequence in the batch, capped at 500.

    Args:
        data: Iterable of JSON-encoded strings, one per example.
        volatile: Forwarded unchanged to ``Variable(..., volatile=...)``.

    Returns:
        A tuple ``(dat, targets)``. ``dat`` is the transposed index matrix,
        i.e. laid out as (sequence length, number of examples). ``targets``
        holds the labels in input order (presumably integer class ids --
        confirm against the data files).

    Raises:
        KeyError: If a token (or ``'<pad>'``, when padding is needed) is
            missing from ``dictionary.word2idx``, or if an example lacks
            ``'text'`` / ``'label'``.
    """
    # Build real lists instead of relying on map(): on Python 3 map() is a
    # one-shot iterator, which would be exhausted by the first pass and
    # supports neither len() nor indexing.
    records = [json.loads(raw) for raw in data]
    dat = [[dictionary.word2idx[token] for token in record['text']]
           for record in records]
    targets = [record['label'] for record in records]

    # Longest sequence in the batch, capped at 500; 0 for an empty batch.
    maxlen = min(max([len(seq) for seq in dat] or [0]), 500)

    for seq in dat:
        if len(seq) > maxlen:
            del seq[maxlen:]
        elif len(seq) < maxlen:
            seq.extend([dictionary.word2idx['<pad>']] * (maxlen - len(seq)))

    # NOTE(review): `volatile` is a legacy Variable flag; newer PyTorch
    # releases deprecate it in favour of `torch.no_grad()`. Kept as-is so
    # existing callers keep working.
    dat = Variable(torch.LongTensor(dat), volatile=volatile)
    targets = Variable(torch.LongTensor(targets), volatile=volatile)
    return dat.t(), targets
train.py 文件源码
python
阅读 23
收藏 0
点赞 0
评论 0
评论列表
文章目录