def train_step(x_batch, y_batch, epoch):
    """Run a single optimization step of the NVDM model on one batch.

    Args:
        x_batch: batch of bag-of-words vectors; x_batch[0] is the first
            document's term-count vector (length = vocabulary size).
        y_batch: unused here; kept so the caller-facing signature is unchanged.
        epoch: current epoch number, used only for progress logging.

    Side effects:
        Runs `nvdm.train_op` in the module-level `sess`, prints a progress
        line every `FLAGS.train_every` steps, and drops into pdb if the
        loss becomes NaN.
    """
    # Indices of vocabulary terms actually present (count > 0) in the first
    # document. The vocabulary size is taken from the input itself instead
    # of the previous hard-coded range(10000).
    x_batch_id = [i for i, count in enumerate(x_batch[0]) if count > 0]
    feed_dict = {nvdm.input_x: x_batch, nvdm.x_id: x_batch_id}
    _, step, loss = sess.run([nvdm.train_op, nvdm.global_step, nvdm.loss], feed_dict)
    time_str = datetime.datetime.now().isoformat()
    if step % FLAGS.train_every == 0:
        print("time: {}, epoch: {}, step: {}, loss: {:g}".format(time_str, epoch, step, loss))
    if np.isnan(loss):
        # Training has diverged — drop into the debugger so the session
        # state can be inspected interactively.
        import pdb
        pdb.set_trace()
# Source: nvdm_nobatch.py (scraped page metadata — reads: 24, favorites: 0,
# likes: 0, comments: 0; "comment list" / "article contents" navigation labels)