relatedness.py source file

python

Project: SentEval  Author: facebookresearch
import numpy as np
import torch


def trainepoch(self, X, y, nepoches=1):
    # Method excerpt (it takes self): runs `nepoches` shuffled mini-batch
    # passes over the tensors X and y and updates self.model in place.
    self.model.train()
    for _ in range(self.nepoch, self.nepoch + nepoches):
        # draw a fresh ordering of the examples for every epoch
        permutation = np.random.permutation(len(X))
        all_costs = []
        for i in range(0, len(X), self.batch_size):
            # forward: index tensor must live on the same device as X
            idx = torch.from_numpy(permutation[i:i + self.batch_size]).long().to(X.device)
            Xbatch = X.index_select(0, idx)
            ybatch = y.index_select(0, idx)
            output = self.model(Xbatch)
            # loss
            loss = self.loss_fn(output, ybatch)
            # store the scalar loss as a Python float
            all_costs.append(loss.item())
            # backward
            self.optimizer.zero_grad()
            loss.backward()
            # update parameters
            self.optimizer.step()
    self.nepoch += nepoches
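
The batching pattern used in trainepoch (draw a fresh permutation each epoch, then slice it in steps of batch_size) can be checked in isolation. The following is a minimal, dependency-free sketch, not SentEval code: iter_minibatches is a hypothetical helper, and it assumes the standard random module as a stand-in for np.random.permutation.

```python
import random


def iter_minibatches(n_examples, batch_size, rng=None):
    """Yield lists of shuffled example indices, one list per mini-batch.

    Hypothetical helper (not part of SentEval): it mirrors the
    permutation[i:i + batch_size] slicing used in trainepoch.
    """
    if rng is None:
        rng = random.Random()
    permutation = list(range(n_examples))
    rng.shuffle(permutation)  # stand-in for np.random.permutation
    for i in range(0, n_examples, batch_size):
        # The last slice may be shorter than batch_size.
        yield permutation[i:i + batch_size]


# 10 examples in batches of 4 give slices starting at 0, 4 and 8.
batches = list(iter_minibatches(10, 4, rng=random.Random(0)))
batch_sizes = [len(b) for b in batches]
seen = sorted(idx for b in batches for idx in b)
```

Because the final slice is simply shorter than the others, no example is dropped: every index is visited exactly once per epoch, which is the same guarantee the loop in trainepoch relies on.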