cnn_classifier.py source code

python

Project: TextAsGraphClassification Author: NightmareNyx
def fit(self, X, y=None):
        """
        Fit the classifier. All the "work" should be done here.
        Note: assert is not a good choice here; you should rather
        use a try/except block with exceptions. This is just for short syntax.
        """

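        # y defaults to None only to mirror the scikit-learn fit() signature;
        # training a classifier still requires labels.
        if y is None:
            raise ValueError("fit() requires labels y")

        # Generate mini-batches of (x, y) pairs for num_epochs passes;
        # batch_iter is a helper not shown in this snippet
        batches = batch_iter(
            list(zip(X, y)), self.FLAGS.batch_size, self.FLAGS.num_epochs)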

        # Training loop: one optimizer step per mini-batch
        for step, batch in enumerate(batches):
            x_batch, y_batch = zip(*batch)
            feed_dict = {
                self.cnn.input_x: x_batch,
                self.cnn.input_y: y_batch,
                self.cnn.dropout_keep_prob: self.FLAGS.dropout_keep_prob
            }
            _, loss, accuracy = self.sess.run(
                [self.optimizer, self.cnn.loss, self.cnn.accuracy],
                feed_dict)
            # print("step {}, loss {:g}, acc {:g}".format(step, loss, accuracy))

        # scikit-learn convention: fit() returns the fitted estimator
        return self
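The training loop depends on a `batch_iter` helper whose definition is not part of this snippet. Below is a hypothetical, framework-free sketch of such a helper, assuming `batch_iter(data, batch_size, num_epochs)` yields one list of `(x, y)` pairs per training step and reshuffles the data at the start of every epoch. The `shuffle` and `seed` parameters are assumptions for illustration, not the project's actual API.

```python
import random
from typing import Iterator, List, Optional, Sequence, TypeVar

T = TypeVar("T")


def batch_iter(data: Sequence[T], batch_size: int, num_epochs: int,
               shuffle: bool = True, seed: Optional[int] = None) -> Iterator[List[T]]:
    """Yield mini-batches of `data` for `num_epochs` passes.

    Hypothetical sketch: the project's own batch_iter is not shown, so the
    `shuffle` and `seed` parameters are assumptions.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    rng = random.Random(seed)
    items = list(data)
    for _ in range(num_epochs):
        order = list(range(len(items)))
        if shuffle:
            rng.shuffle(order)  # reshuffle at the start of every epoch
        for start in range(0, len(order), batch_size):
            # the last batch of an epoch may be smaller than batch_size
            yield [items[i] for i in order[start:start + batch_size]]


if __name__ == "__main__":
    # Same unpacking pattern as in fit(): zip(*batch) splits pairs into x and y
    pairs = list(zip(["a", "b", "c", "d", "e"], [0, 1, 0, 1, 1]))
    for batch in batch_iter(pairs, batch_size=2, num_epochs=1, shuffle=False):
        x_batch, y_batch = zip(*batch)
        print(x_batch, y_batch)
    # ('a', 'b') (0, 1)
    # ('c', 'd') (0, 1)
    # ('e',) (1,)
```

Making the helper a generator keeps only one batch's worth of index slicing in flight at a time, and seeding a private `random.Random` keeps shuffling reproducible without touching global random state.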