sentiment_trainer.py source code

python

Project: treehopper  Author: tomekkorbak
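import torch
from tqdm import tqdm


def test(self, dataset):
    """Evaluate on `dataset`; return mean loss, per-example accuracies,
    softmax outputs and the annotated output trees."""
    # Put both networks in evaluation mode (disables dropout and similar layers).
    self.model.eval()
    self.embedding_model.eval()
    loss = 0.0
    accuracies = torch.zeros(len(dataset))

    output_trees = []
    outputs = []
    # torch.no_grad() takes the place of the removed `volatile=True` flag.
    with torch.no_grad():
        for idx in tqdm(range(len(dataset)), desc='Testing epoch ' + str(self.epoch)):
            tree, sent, label = dataset[idx]
            inputs = sent  # renamed from `input` to avoid shadowing the builtin
            target = torch.LongTensor([int(label)])
            if self.args.cuda:
                inputs = inputs.cuda()
                target = target.cuda()
            # Insert a singleton dimension at position 1 of the embedded sentence.
            emb = torch.unsqueeze(self.embedding_model(inputs), 1)
            output, _, acc, tree = self.model(tree, emb)
            err = self.criterion(output, target)
            loss += err.item()
            accuracies[idx] = acc
            output_trees.append(tree)
            # Copy to CPU first: .numpy() raises on CUDA tensors.
            outputs.append(tree.output_softmax.detach().cpu().numpy())
            # predictions[idx] = torch.dot(indices,torch.exp(output.data.cpu()))
    return loss / len(dataset), accuracies, outputs, output_trees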
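
The method returns a tuple whose first two entries are the mean loss and the per-example accuracies. A minimal sketch of how a caller might condense those into a short report is given here. `summarize_eval` is a hypothetical helper, not part of treehopper, and it assumes each entry of `accuracies` is a per-example score between 0 and 1 (for a torch tensor, pass `accuracies.tolist()`). The numbers used are stand-ins, not results from the model.

```python
def summarize_eval(mean_loss, accuracies):
    """Hypothetical helper: condense evaluation results into a small report.

    `accuracies` is any sequence of per-example scores (assumed to lie in
    [0, 1]); an empty sequence yields an accuracy of 0.0.
    """
    scores = [float(a) for a in accuracies]
    mean_acc = sum(scores) / len(scores) if scores else 0.0
    return {"loss": float(mean_loss), "accuracy": mean_acc, "n": len(scores)}


# Stand-in values for illustration only.
report = summarize_eval(0.5, [1.0, 0.0, 1.0, 1.0])
print(report)
```

With a trainer instance, the call would look roughly like `loss, accuracies, _, _ = trainer.test(dev_dataset)` followed by `summarize_eval(loss, accuracies.tolist())`, where `trainer` and `dev_dataset` are assumed names.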