nvdm_nobatch_new.py source file

python

Project: NVDM-For-Document-Classification  Author: cryanzpj
import datetime
import itertools

import numpy as np


# `sess`, `nvdm`, `FLAGS` and `f1_score_multiclass` are defined elsewhere in the training script.
def train_step(x_batch, y_batch, epoch, predicts, labels):
    """
    Run a single training step.

    Returns the (predicts, labels) accumulators. Both are reset to empty
    lists after the F1 score is reported every FLAGS.train_every steps.
    """
    y_batch = y_batch.reshape(1, -1)
    # Indices, within the first 10000 vocabulary entries, of the words with a positive count.
    x_batch_id = list(itertools.compress(range(10000), (count > 0 for count in x_batch[0])))
    feed_dict = {nvdm.input_x: x_batch,
                 nvdm.input_y: y_batch,
                 nvdm.x_id: x_batch_id}

    # Disabled debugging variant that also fetched the h1 weights and loss components:
    # h1b = [v for v in tf.all_variables() if v.name == "h1/b:0"][0]
    # h1w = [v for v in tf.all_variables() if v.name == "h1/w:0"][0]
    # _, step, summaries, loss, kl, rc, p_xi_h, R, hb, hw, e = sess.run(
    #     [nvdm.train_op, global_step, loss_summary, nvdm.loss, nvdm.KL, nvdm.recon_loss,
    #      nvdm.p_xi_h, nvdm.R, h1b, h1w, nvdm.e], feed_dict)

    _, step, loss, predict = sess.run(
        [nvdm.train_op, nvdm.global_step, nvdm.loss, nvdm.predicts], feed_dict)

    # Drop into the debugger if the loss becomes NaN.
    if np.isnan(loss):
        import pdb
        pdb.set_trace()

    # train_summary_writer.add_summary(summaries, step)

    time_str = datetime.datetime.now().isoformat()
    if step % FLAGS.train_every == 0:
        # Score the predictions accumulated since the last report, then reset the accumulators.
        score = f1_score_multiclass(np.array(predicts), np.array(labels))
        print("time: {},  epoch: {}, step: {}, loss: {:g}, score: {:g}".format(
            time_str, epoch, step, loss, score))
        return [], []

    predicts.append(predict)
    labels.append(y_batch[0].astype(int))
    return predicts, labels
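
To illustrate how `x_batch_id` is built in `train_step`, here is a minimal standalone sketch. The helper name `active_word_ids` and the toy bag-of-words row are hypothetical and not part of the project; only the `itertools.compress` pattern and the cap of 10000 entries come from the code above.

```python
import itertools


def active_word_ids(bow_row, vocab_size=10000):
    """Return the indices of vocabulary entries with a positive count.

    Hypothetical helper: `vocab_size` mirrors the hard-coded 10000
    used in train_step.
    """
    # compress() keeps the index i whenever the i-th selector is truthy,
    # and stops at the shorter of the two inputs.
    return list(itertools.compress(range(vocab_size),
                                   (count > 0 for count in bow_row)))


toy_row = [0, 2, 0, 1, 0]  # made-up counts for a 5-word vocabulary
print(active_word_ids(toy_row))  # [1, 3]
```

Because `compress` stops at the shorter input, a row longer than `vocab_size` is silently truncated, which is presumably why the original hard-codes the vocabulary size.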