nnet.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:NaiveEntailNet 作者: biswajitsc 项目源码 文件源码
def test(self, sess, mode, seq1, len1, seq2, len2, labels):
        acc = 0
        loss = 0
        cnt = 0
        for d1, l1, d2, l2, l in utils.batch_iter(seq1, len1, seq2, len2, labels):
            tacc, tloss, tpred = sess.run(
                [self.accuracy, self.tot_loss, self.pred],
                feed_dict={
                    self.input_seq1: d1,
                    self.input_len1: l1,
                    self.input_seq2: d2,
                    self.input_len2: l2,
                    self.labels: l,
                    self.initial_state: np.zeros((
                        Options.batch_size,
                        2 * Options.lstm_dim * Options.lstm_layers
                    )),
                    self.lstm_keep_prob: 1.0,
                    self.nnet_keep_prob: 1.0
                })
            cnt += 1

            acc += tacc
            loss += tloss

        acc /= cnt
        loss /= cnt

        print("{0} Accuracy = {1}\t {0} Loss = {2}".format(mode, acc, loss))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号