import numpy as np  # np.mod below needs numpy at module level


def train_step(self, sess, counter):
    '''
    Generic hook that the Trainer class calls once per iteration.
    The simplest body would be just "sess.run(self.train_op)", but a model
    may need a more involved update schedule.
    Running self.summary_op is handled by Trainer.Supervisor and does not
    need to be addressed here.
    Only iteration counters, not epochs, are explicitly tracked.
    '''
    # Work can be gated on the counter. For example, pretrain the labeler
    # (d_label_optim) alone for the first pretrain_LabelerR_no_of_iters steps.
    if self.config.pretrain_LabelerR and counter < self.config.pretrain_LabelerR_no_of_iters:
        sess.run(self.d_label_optim)
    else:
        if np.mod(counter, 3) == 0:
            # Every third iteration: an extra generator step, then
            # train_op together with k_t_update and inc_step (all ops).
            sess.run(self.g_optim)
            sess.run([self.train_op, self.k_t_update, self.inc_step])
        else:
            # Other iterations: generator updates only (g_optim runs twice);
            # k_t and the global step still advance.
            sess.run([self.g_optim, self.k_t_update, self.inc_step])
            sess.run(self.g_optim)
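The branching in train_step is easier to see when traced without TensorFlow. The sketch below assumes a hypothetical helper, `ops_for_step`, which is not part of the original code. It returns the `sess.run` calls that train_step would issue at a given counter, with each op represented by its name as a string. It uses `counter % 3`, which matches `np.mod(counter, 3)` for integer counters.

```python
from typing import List


def ops_for_step(counter: int, pretrain_labeler: bool, pretrain_iters: int) -> List[List[str]]:
    """Hypothetical helper: list the sess.run calls train_step makes at `counter`.

    Each inner list is one sess.run call; strings stand in for the graph ops.
    """
    if pretrain_labeler and counter < pretrain_iters:
        # Pretraining phase: only the labeler optimizer runs.
        return [["d_label_optim"]]
    if counter % 3 == 0:
        # Every third step: an extra generator update, then the full train_op.
        return [["g_optim"], ["train_op", "k_t_update", "inc_step"]]
    # Remaining steps: two generator updates; k_t and the step counter still advance.
    return [["g_optim", "k_t_update", "inc_step"], ["g_optim"]]


if __name__ == "__main__":
    for c in range(8):
        print(c, ops_for_step(c, pretrain_labeler=True, pretrain_iters=3))
```

Suppose pretraining is enabled for 3 iterations. Steps 0 to 2 then run only d_label_optim, steps 3 and 6 run train_op, and the remaining steps run generator-only updates.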