def train(self):
    """Run the training loop, periodically reporting progress and checkpointing.

    Calls ``self.loadData()``, then tries to resume from the checkpoint at
    ``self.model_path + 'params.pkl'``. For ``self.max_epoches`` iterations it
    draws one batch with ``self.next(1, shuffle=False)`` and runs
    ``self.step(inputs, targets, self.max_length)``. On every iteration where
    ``epoch % self.show_epoch == 0`` it prints the loss, the target and
    predicted ids and the elapsed time of the step, and saves
    ``self.state_dict()`` back to the checkpoint path.

    Returns:
        list: the ``loss`` value returned by ``self.step`` for each iteration,
        in order.

    NOTE(review): the source this was recovered from had its indentation
    stripped; the checkpoint save is placed inside the periodic-report
    branch, which is the most plausible reading -- confirm.
    NOTE(review): ``load_state_dict``/``state_dict`` suggest this class is a
    ``torch.nn.Module``; if so, this method shadows ``nn.Module.train(mode)``
    and ``self.eval()`` (which calls ``self.train(False)``) would raise a
    TypeError. Consider renaming -- TODO confirm the base class.
    """
    self.loadData()
    checkpoint = self.model_path + 'params.pkl'
    try:
        # map_location="cpu" lets a checkpoint written on a GPU be restored on
        # a CPU-only machine; for a torch.nn.Module, load_state_dict copies the
        # values into the existing parameters on whatever device they occupy.
        self.load_state_dict(torch.load(checkpoint, map_location="cpu"))
    except Exception as e:
        # Best effort: any failure means training starts from the current
        # parameters. NOTE(review): epoch 0 always triggers a save below, so a
        # failure other than a missing file (e.g. incompatible weights) leads
        # to the existing checkpoint being overwritten after the first step.
        print(e)
        print("No model!")
    loss_track = []
    for epoch in range(self.max_epoches):
        start = time.perf_counter()
        inputs, targets = self.next(1, shuffle=False)
        loss, logits = self.step(inputs, targets, self.max_length)
        loss_track.append(loss)
        stop = time.perf_counter()
        if epoch % self.show_epoch == 0:
            # Convert predictions to host-side lists only when they are
            # printed, avoiding a device-to-host copy on every iteration.
            _, v = torch.topk(logits, 1)
            pre = v.detach().cpu().numpy().T.tolist()[0][0]
            tar = targets.detach().cpu().numpy().T.tolist()[0]
            print("-"*50)
            print("epoch:", epoch)
            print(" loss:", loss)
            print(" target:%s\n output:%s" % (tar, pre))
            print(" per-time:", (stop-start))
            torch.save(self.state_dict(), checkpoint)
    return loss_track
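# NOTE: the two labels below appear to be web-page navigation text captured
# together with the code; kept as comments so the file parses.
# 评论列表 ("Comment list")
# 文章目录 ("Table of contents")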