vae_imdb.py source code

python

Project: NVDM-For-Document-Classification  Author: cryanzpj
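import numpy as np


def train(self, X_train, y_train):
    """Train on the sparse bag-of-words matrix X_train with labels y_train."""
    # Optionally resume from an earlier checkpoint:
    # self.saver.restore(self.sess, "./imdbmodel/model.ckpt")
    total_batch = X_train.shape[0] // self.batch_size  # a trailing partial batch is dropped
    for e in range(self.epoch):
        perplist = []  # per-step mean log-probabilities for the running perplexity
        for i in range(total_batch):
            start, end = i * self.batch_size, (i + 1) * self.batch_size
            X_batch = X_train[start:end]
            y_batch = y_train[start:end]
            # Vocabulary indices of the words with a positive count in the
            # FIRST document of the batch (covers the whole batch only when
            # batch_size == 1).
            x_batch_id = np.flatnonzero(X_batch[0].toarray()[0] > 0)
            feed_dict = {
                self.input_x: X_batch.toarray(),
                self.input_y: np.reshape(y_batch, [-1, 1]),
                self.x_id: x_batch_id,
            }
            _, loss = self.sess.run([self.train_op, self.loss], feed_dict)
            if np.isnan(loss):
                # Loss diverged: drop into the debugger to inspect the state.
                import pdb
                pdb.set_trace()
            if i % self.display_score == 0:
                p_xi_h = self.sess.run([self.p_xi_h], feed_dict)
                # Mean log-probability the model assigns to the observed words.
                valid_p = np.mean(np.log(p_xi_h[0][x_batch_id]))
                perplist.append(valid_p)
                # Running perplexity: exp of the negative mean log-probability.
                print("step: {}, perp: {:f}".format(i, np.exp(-np.mean(perplist))))
            # Save a checkpoint every 2000 steps within the epoch.
            if i > 0 and i % 2000 == 0:
                self.savemodel()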
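The monitoring step in the loop reports a running perplexity: the exponential of the negative mean log-probability that the model assigns to the words present in a document. The NumPy sketch below reproduces only that arithmetic, assuming the model output is a 1-D vector of word probabilities over the vocabulary. The function names `observed_word_ids` and `running_perplexity` and the toy numbers are hypothetical illustrations, not part of the project.

```python
import numpy as np


def observed_word_ids(counts_row):
    """Indices of vocabulary words with a positive count in one document.

    Same result as the itertools.compress / np.flatnonzero step in train().
    """
    return np.flatnonzero(np.asarray(counts_row) > 0)


def running_perplexity(step_log_probs):
    """exp(-mean) over the per-step mean log-probabilities collected so far."""
    return float(np.exp(-np.mean(step_log_probs)))


# Hypothetical toy vocabulary of 5 words; the document uses words 0, 2 and 3.
counts = np.array([2, 0, 1, 3, 0])
# Hypothetical model output p(w | h): one probability per vocabulary word.
p_w = np.array([0.25, 0.25, 0.25, 0.125, 0.125])

ids = observed_word_ids(counts)
step_log_prob = np.mean(np.log(p_w[ids]))  # mean log-prob of the observed words
perplist = [step_log_prob]
print(ids, running_perplexity(perplist))
```

Here the observed probabilities are 1/4, 1/4 and 1/8, so the perplexity is 2^(7/3), about 5.04. Each distinct word contributes once regardless of its count, so this figure is not weighted by word frequency.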