nnet.py source code

python

Project: NaiveEntailNet  Author: biswajitsc
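import numpy as np
import tensorflow as tf  # TensorFlow 1.x graph API (tf.Session, tf.train.Saver)
from sklearn.metrics import confusion_matrix
# utils (batch_iter) and Options are project-local NaiveEntailNet modules;
# their import paths are not shown in this excerpt.


def exploremodel(self, seq1, len1, seq2, len2, labels):
    """Restore the trained model from model.ckpt and print test-set metrics."""
    saver = tf.train.Saver()

    with tf.Session() as sess:
        # Load the weights saved during training into the current graph.
        saver.restore(sess, "model.ckpt")

        preds = []

        for d1, l1, d2, l2, l in utils.batch_iter(seq1, len1, seq2, len2, labels):
            val1 = sess.run([self.pred], feed_dict={
                self.input_seq1: d1,
                self.input_len1: l1,
                self.input_seq2: d2,
                self.input_len2: l2,
                self.labels: l,
                # Zero initial LSTM state, flattened: presumably cell + hidden
                # state (2 * lstm_dim) for each of the lstm_layers layers.
                self.initial_state: np.zeros((
                    Options.batch_size,
                    2 * Options.lstm_dim * Options.lstm_layers
                )),
                # Keep probability 1.0 disables dropout at evaluation time.
                self.lstm_keep_prob: 1.0,
                self.nnet_keep_prob: 1.0
            })

            preds.extend(val1[0])

        # batch_iter appears to yield only full batches, so preds can be
        # shorter than labels; align by length instead of hard-coding 4900.
        classes = np.argmax(labels[:len(preds)], axis=1)
        cm = confusion_matrix(classes, preds)
        print(cm)
        # Overall accuracy.
        print(np.mean(np.asarray(classes) == np.asarray(preds)))
        # Each row divided by its total gives per-class recall.
        for row in cm:
            print(row / np.sum(row))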
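The excerpt does not show utils.batch_iter, and the original hard-coded cutoff of 4900 suggests it yields only full batches, so some trailing labels never get a prediction. The minimal sketch below assumes that behaviour: `batch_iter` here is a hypothetical stand-in, not the project's code, and `confusion_matrix_np` / `evaluate` are illustrative names. It recomputes the same three reports with NumPy only: the confusion matrix (rows are true classes, columns are predicted classes, the same orientation as sklearn.metrics.confusion_matrix), overall accuracy, and each row normalized to per-class recall, with empty rows guarded against division by zero.

```python
import numpy as np


def batch_iter(*arrays, batch_size=100):
    """Hypothetical stand-in for utils.batch_iter: yields aligned slices of
    every array and silently drops the final partial batch."""
    n = len(arrays[0])
    for start in range(0, n - batch_size + 1, batch_size):
        yield tuple(a[start:start + batch_size] for a in arrays)


def confusion_matrix_np(y_true, y_pred, num_classes):
    """Count matrix: rows are true classes, columns are predicted classes."""
    cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    # np.add.at accumulates correctly even when (true, pred) pairs repeat.
    np.add.at(cm, (y_true, y_pred), 1)
    return cm


def evaluate(one_hot_labels, preds):
    """Return (confusion matrix, accuracy, row-normalized matrix)."""
    preds = np.asarray(preds)
    # Only the first len(preds) labels were actually predicted.
    classes = np.argmax(one_hot_labels[:len(preds)], axis=1)
    cm = confusion_matrix_np(classes, preds, one_hot_labels.shape[1])
    accuracy = np.mean(classes == preds)
    # Divide each row by its total; rows with no samples stay all zero.
    row_sums = cm.sum(axis=1, keepdims=True)
    normalized = np.divide(cm, row_sums,
                           out=np.zeros(cm.shape, dtype=float),
                           where=row_sums > 0)
    return cm, accuracy, normalized


if __name__ == "__main__":
    # 7 one-hot labels over 3 classes; batch_size=3 keeps only 2 full batches.
    labels = np.eye(3)[[0, 1, 2, 1, 0, 2, 2]]
    preds = []
    for (batch,) in batch_iter(labels, batch_size=3):
        preds.extend(np.argmax(batch, axis=1))  # stand-in for model output
    preds[1] = 2  # inject one wrong prediction
    cm, acc, norm = evaluate(labels, preds)
    print(cm)
    print(acc)
    print(norm)
```

Unlike sklearn's confusion_matrix called without a `labels` argument, which sizes the matrix from the classes that actually occur in `y_true` and `y_pred`, this sketch always uses the full number of label columns, so a class that never appears still gets a (zero) row.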