seq2seq.py source code

python
Project: seq2seq  Author: yanwii
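import time

import torch


# Method of the seq2seq model class; it relies on self.load_state_dict and
# self.state_dict, so the class is presumably a torch.nn.Module subclass.
def train(self):
    self.loadData()
    # Resume from a saved checkpoint if one can be loaded; otherwise start fresh.
    try:
        self.load_state_dict(torch.load(self.model_path + 'params.pkl'))
    except Exception as e:
        print(e)
        print("No model!")
    loss_track = []

    # One iteration = one call to self.step on a freshly drawn batch.
    for epoch in range(self.max_epoches):
        start = time.time()
        inputs, targets = self.next(1, shuffle=False)
        loss, logits = self.step(inputs, targets, self.max_length)
        loss_track.append(loss)
        # Greedy prediction: keep the index of the top-scoring entry.
        _, v = torch.topk(logits, 1)
        pre = v.cpu().data.numpy().T.tolist()[0][0]
        tar = targets.cpu().data.numpy().T.tolist()[0]
        stop = time.time()
        if epoch % self.show_epoch == 0:
            print("-" * 50)
            print("epoch:", epoch)
            print("    loss:", loss)
            print("    target:%s\n    output:%s" % (tar, pre))
            print("    per-time:", (stop - start))
            # Save the weights each time progress is reported.
            torch.save(self.state_dict(), self.model_path + 'params.pkl')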
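
The `torch.topk(logits, 1)` call in the loop above keeps only the single highest-scoring index at each position, which amounts to greedy decoding. The following is a minimal pure-Python sketch of that idea, not the project's code: `greedy_decode` and `example_logits` are hypothetical names, the scores are made up for illustration, and the input is assumed to be a plain list of time steps with the batch dimension already dropped.

```python
def greedy_decode(logits):
    """Return the index of the largest score at each time step.

    `logits` is a list of time steps; each time step is a list of
    per-token scores (hypothetical layout, batch dimension omitted).
    """
    return [max(range(len(step)), key=step.__getitem__) for step in logits]


# Made-up scores for 3 time steps over a 4-token vocabulary.
example_logits = [
    [0.1, 2.0, -1.0, 0.3],
    [1.5, 0.2, 0.9, -0.4],
    [-0.7, 0.0, 0.1, 3.2],
]
predicted = greedy_decode(example_logits)
print(predicted)  # [1, 0, 3]
```

In the training loop, the resulting index list (`pre`) is printed next to the target indices (`tar`) so the two can be compared by eye.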