a3_train.py source code

python

Project: text_classification  Author: brightmart
import numpy as np

# FLAGS is assumed to be the module-level flags object defined elsewhere in a3_train.py.
def do_eval(sess, model, evalX, evalY, batch_size, vocabulary_index2word_label, eval_decoder_input=None):
    """Run the model over the evaluation set in batches; return (mean loss, mean accuracy)."""
    number_examples = len(evalX)
    eval_loss, eval_acc, eval_counter = 0.0, 0.0, 0
    # Iterate over full batches only. The end range stops at number_examples + 1 so that
    # the last full batch is kept when number_examples is a multiple of batch_size.
    for start, end in zip(range(0, number_examples, batch_size),
                          range(batch_size, number_examples + 1, batch_size)):
        feed_dict = {
            model.query: evalX[start:end],
            # The story input is the query with an extra axis at position 1.
            model.story: np.expand_dims(evalX[start:end], axis=1),
            model.dropout_keep_prob: 1,  # no dropout during evaluation
        }
        if not FLAGS.multi_label_flag:
            feed_dict[model.answer_single] = evalY[start:end]
        else:
            feed_dict[model.answer_multilabel] = evalY[start:end]
        curr_eval_loss, logits, curr_eval_acc, pred = sess.run(
            [model.loss_val, model.logits, model.accuracy, model.predictions], feed_dict)
        eval_loss, eval_acc, eval_counter = eval_loss + curr_eval_loss, eval_acc + curr_eval_acc, eval_counter + 1
    if eval_counter == 0:
        # Fewer examples than one batch: nothing was evaluated.
        return 0.0, 0.0
    return eval_loss / float(eval_counter), eval_acc / float(eval_counter)
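
The loop in do_eval slices the evaluation set with zipped ranges, which yields only complete batches, so trailing examples that do not fill a batch are never scored. The following is a minimal sketch of one alternative, assuming the model can accept a shorter final batch (this is not confirmed by the source). The helper names `batch_bounds` and `weighted_mean` are hypothetical and do not appear in the original project.

```python
def batch_bounds(number_examples, batch_size):
    """Yield (start, end) slice bounds covering every example,
    including a shorter final batch (hypothetical helper)."""
    for start in range(0, number_examples, batch_size):
        yield start, min(start + batch_size, number_examples)


def weighted_mean(values, weights):
    """Average per-batch metrics, weighting each batch by its size
    so a short final batch does not count as much as a full one."""
    total = sum(weights)
    if total == 0:
        return 0.0
    return sum(v * w for v, w in zip(values, weights)) / float(total)


# Zipping two ranges yields complete batches only: examples 8 and 9 are skipped.
zipped = list(zip(range(0, 10, 4), range(4, 10, 4)))

# Bounds that cover all 10 examples, ending with a batch of size 2.
full = list(batch_bounds(10, 4))
```

With these bounds, the per-batch loss and accuracy would be combined with `weighted_mean`, passing `end - start` for each batch as the weight, rather than dividing by the number of batches.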

# get label using logits