def build_training_process(self):
    """Assemble the training objective and the op that optimizes it.

    Sets two attributes on ``self``:

    * ``self.obj`` -- sum of the wider-side and deeper-side objectives.
      Each side is evaluated through ``tf.cond`` and contributes a
      constant 0.0 when its predicate is false.
    * ``self.train_step`` -- op returned by ``optimizer.minimize`` for the
      loss ``-self.obj - self.entropy_penalty * entropy_term``.

    ``entropy_term`` is a weighted average of the two entropies: the wider
    entropy is weighted by ``self.wider_seg_deeper`` and the deeper entropy
    by ``batch_size - self.wider_seg_deeper``, then the sum is divided by
    ``batch_size`` (the first dimension of ``self.reward``).

    NOTE(review): presumably ``self.wider_seg_deeper`` is the count of
    batch entries that belong to the wider side, with the remainder
    belonging to the deeper side -- confirm against the caller.
    """

    def zero_pair():
        # Stand-in (objective, entropy) pair for a branch that is not taken.
        obj_zero = tf.constant(0.0, dtype=tf.float32)
        entropy_zero = tf.constant(0.0, dtype=tf.float32)
        return obj_zero, entropy_zero

    # Wider side: only evaluated when wider_seg_deeper is positive.
    has_wider = tf.greater(self.wider_seg_deeper, 0)
    wider_side_obj, wider_entropy = tf.cond(
        has_wider,
        self.get_wider_side_obj,
        zero_pair,
    )

    num_samples = array_ops.shape(self.reward)[0]

    # Deeper side: gated by the has_deeper predicate.
    deeper_side_obj, deeper_entropy = tf.cond(
        self.has_deeper,
        self.get_deeper_side_obj,
        zero_pair,
    )

    self.obj = wider_side_obj + deeper_side_obj

    # Weight each entropy by the number of samples attributed to its side,
    # then normalise by the total number of samples.
    wider_weight = tf.cast(self.wider_seg_deeper, tf.float32)
    weighted_wider = wider_entropy * wider_weight
    deeper_weight = tf.cast(num_samples - self.wider_seg_deeper, tf.float32)
    weighted_deeper = deeper_entropy * deeper_weight
    entropy_term = (weighted_wider + weighted_deeper) / tf.cast(num_samples, tf.float32)

    optimizer = BasicModel.build_optimizer(
        self.learning_rate,
        self.opt_config[0],
        self.opt_config[1],
    )
    # Minimizing the negated objective maximizes obj plus the entropy bonus.
    loss = - self.obj - self.entropy_penalty * entropy_term
    self.train_step = optimizer.minimize(loss)
评论列表
文章目录