def train_step(x_batch, y_batch, epoch, predicts, labels):
    """Run a single training step of the NVDM model.

    Feeds one example (batch of size 1) through ``sess.run``, accumulates the
    prediction/label pair, and every ``FLAGS.train_every`` steps reports the
    multiclass F1 over the accumulated window and resets the accumulators.

    Args:
        x_batch: bag-of-words input row(s); ``x_batch[0]`` is scanned for
            non-zero vocabulary entries (vocab size assumed 10000 — TODO
            confirm against the model definition).
        y_batch: label array, reshaped to a single row before feeding.
        epoch: current epoch number, used only for logging.
        predicts: running list of predictions for the current report window.
        labels: running list of integer label rows for the window.

    Returns:
        ``(predicts, labels)`` with the current example appended, or
        ``([], [])`` when a report was just emitted (window reset).

    Raises:
        FloatingPointError: if the loss has diverged to NaN.  (The original
        NaN check sat after a ``return`` and was unreachable.)
    """
    y_batch = y_batch.reshape(1, -1)
    # Ids of the words present in the document: indices (< 10000) whose
    # count in the bag-of-words vector is positive.
    x_batch_id = [idx for idx in itertools.compress(range(10000),
                                                    map(lambda v: v > 0,
                                                        x_batch[0]))]
    feed_dict = {
        nvdm.input_x: x_batch,
        nvdm.input_y: y_batch,
        nvdm.x_id: x_batch_id,
    }
    _, step, loss, predict = sess.run(
        [nvdm.train_op, nvdm.global_step, nvdm.loss, nvdm.predicts],
        feed_dict)

    # Fail fast on divergence instead of dropping into a debugger (the
    # original left a pdb.set_trace() here, unreachable behind a return).
    if np.isnan(loss):
        raise FloatingPointError(
            "NaN loss encountered at step {}".format(step))

    if step % FLAGS.train_every == 0:
        # Score is over the window accumulated so far; the current step's
        # prediction is intentionally not included — it starts the next
        # window's accumulators via the reset below.
        time_str = datetime.datetime.now().isoformat()
        score = f1_score_multiclass(np.array(predicts), np.array(labels))
        print("time: {}, epoch: {}, step: {}, loss: {:g}, score: {:g}".format(
            time_str, epoch, step, loss, score))
        # Reset the accumulators for the next reporting window.
        return [], []

    predicts.append(predict)
    labels.append(y_batch[0].astype(int))
    return predicts, labels
nvdm_nobatch_new.py — source file
python
Reads: 18
Favorites: 0
Likes: 0
Comments: 0
Comment list
Table of contents