classifier.py source code

python

Project: SentEval  Author: facebookresearch
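import numpy as np
import torch


def trainepoch(self, X, y, epoch_size=1):
    # Put the model in training mode (affects dropout, batch norm, etc.).
    self.model.train()
    for _ in range(self.nepoch, self.nepoch + epoch_size):
        # Draw a fresh random sample order for every epoch.
        permutation = np.random.permutation(len(X))
        all_costs = []  # per-batch loss values (collected, not returned)
        for i in range(0, len(X), self.batch_size):
            # forward
            # Create the index tensor on X's device so that index_select
            # works for both CPU and CUDA inputs.
            idx = torch.as_tensor(permutation[i:i + self.batch_size],
                                  dtype=torch.long, device=X.device)
            Xbatch = X.index_select(0, idx)
            ybatch = y.index_select(0, idx)
            if self.cudaEfficient:
                # Move only the current batch to the GPU.
                Xbatch = Xbatch.cuda()
                ybatch = ybatch.cuda()
            output = self.model(Xbatch)
            # loss
            loss = self.loss_fn(output, ybatch)
            # Store the loss as a plain Python float.
            all_costs.append(loss.item())
            # backward
            self.optimizer.zero_grad()
            loss.backward()
            # Update parameters
            self.optimizer.step()
    # Keep track of how many epochs have been run in total.
    self.nepoch += epoch_size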
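
The batching pattern in the loop above (shuffle all sample indices once per epoch, then cut the shuffled order into consecutive slices of `batch_size`) can be looked at in isolation. The following is only a minimal sketch of that pattern, not part of SentEval: the helper name `minibatch_indices` and its `seed` parameter are hypothetical, and it uses the standard library's `random` module in place of `np.random.permutation` so that it runs without extra dependencies.

```python
import random


def minibatch_indices(n_samples, batch_size, seed=None):
    """Yield lists of shuffled sample indices, one list per mini-batch.

    Hypothetical helper for illustration only; it mirrors the
    permutation-and-slice logic of the training loop.
    """
    rng = random.Random(seed)
    permutation = list(range(n_samples))
    rng.shuffle(permutation)  # new random order, like np.random.permutation
    for start in range(0, n_samples, batch_size):
        # The last slice is shorter when batch_size does not divide n_samples.
        yield permutation[start:start + batch_size]


batches = list(minibatch_indices(n_samples=10, batch_size=4, seed=0))
print([len(b) for b in batches])  # [4, 4, 2]
```

Every index appears in exactly one batch per epoch, and the final batch is simply smaller when the sample count is not a multiple of the batch size, which is also how the slicing `permutation[i:i + self.batch_size]` behaves.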