def do_eval(sess, model, evalX, evalY, batch_size, vocabulary_index2word_label,
            eval_decoder_input=None, *, multi_label_flag=None):
    """Evaluate ``model`` on (evalX, evalY) in mini-batches; return mean loss and accuracy.

    Every example is evaluated exactly once: the final batch may be shorter
    than ``batch_size``. Per-batch loss and accuracy are weighted by the number
    of examples in the batch, so the result is a per-example mean. This equals
    the plain per-batch mean whenever all batches are full.

    Args:
        sess: object whose ``run(fetches, feed_dict)`` returns one value per
            fetch (e.g. a TensorFlow session).
        model: object exposing the feed keys ``query``, ``story``,
            ``dropout_keep_prob`` and ``answer_single`` / ``answer_multilabel``,
            plus the fetches ``loss_val`` and ``accuracy``.
        evalX: sliceable sequence of inputs. Each slice is fed as ``query``, and
            with a new axis 1 inserted (``np.expand_dims``) as ``story``.
        evalY: labels aligned index-for-index with ``evalX``.
        batch_size: number of examples per batch; must be positive.
        vocabulary_index2word_label: unused here; kept for interface compatibility.
        eval_decoder_input: unused here; kept for interface compatibility.
        multi_label_flag: if true, labels are fed to ``model.answer_multilabel``;
            otherwise to ``model.answer_single``. ``None`` (the default) means
            use ``FLAGS.multi_label_flag``.

    Returns:
        Tuple ``(mean_loss, mean_accuracy)`` over all examples.

    Raises:
        ValueError: if ``batch_size`` is not positive or ``evalX`` is empty.
    """
    if batch_size <= 0:
        raise ValueError("do_eval: batch_size must be positive, got %r" % (batch_size,))
    number_examples = len(evalX)
    if number_examples == 0:
        raise ValueError("do_eval: evalX is empty, nothing to evaluate")

    # Resolve the label placeholder once instead of re-reading FLAGS per batch.
    if multi_label_flag is None:
        multi_label_flag = FLAGS.multi_label_flag
    label_key = model.answer_multilabel if multi_label_flag else model.answer_single

    total_loss, total_acc = 0.0, 0.0
    # Stepping `start` alone (with `end` clamped) includes the trailing partial
    # batch. The old zip of two ranges silently dropped the last batch.
    for start in range(0, number_examples, batch_size):
        end = min(start + batch_size, number_examples)
        batch_x = evalX[start:end]
        feed_dict = {
            model.query: batch_x,
            model.story: np.expand_dims(batch_x, axis=1),
            model.dropout_keep_prob: 1,  # disable dropout during evaluation
            label_key: evalY[start:end],
        }
        curr_loss, curr_acc = sess.run([model.loss_val, model.accuracy], feed_dict)
        # Weight by batch size so a short final batch does not skew the mean.
        batch_count = end - start
        total_loss += curr_loss * batch_count
        total_acc += curr_acc * batch_count

    return total_loss / float(number_examples), total_acc / float(number_examples)
# 评论列表 (web-page residue, not code: "comment list")
# 文章目录 (web-page residue, not code: "table of contents")