def init_training_mode():
    """ init_training_mode.

    Ensures the `is_training` variable and its update ops exist, creating
    them only if they have not been created yet. This op is required if you
    are using layers such as dropout or batch normalization independently
    of TFLearn models (DNN or Trainer class).
    """
    # The 'is_training' collection stores the training mode variable; when
    # it already has an entry there is nothing left to set up.
    if len(tf.get_collection('is_training')) != 0:
        return

    training_flag = variable(
        "is_training",
        dtype=tf.bool,
        shape=[],
        initializer=tf.constant_initializer(False),
        trainable=False)
    tf.add_to_collection('is_training', training_flag)

    # The 'is_training_ops' collection stores the ops that update the
    # training mode variable: the assign-True op is added first, then the
    # assign-False op.
    enable_op = tf.assign(training_flag, True)
    disable_op = tf.assign(training_flag, False)
    tf.add_to_collection('is_training_ops', enable_op)
    tf.add_to_collection('is_training_ops', disable_op)
# NOTE: web-page navigation residue ("Comment list", "Table of contents"); not part of the code.