def build_train_graph(loss, learning_rate=0.001, clip_norm=5.0):
    """Build the training ops for ``loss`` and return them in a dict.

    Args:
        loss: Loss tensor handed to ``layers.optimize_loss`` (presumably a
            scalar — TODO confirm with callers).
        learning_rate: Default learning rate. It becomes the default value of
            a scalar (shape ``[]``) placeholder named ``"learning_rate"``, so
            the rate can be fed per step instead of being fixed in the graph.
        clip_norm: Forwarded to ``layers.optimize_loss`` as ``clip_gradients``.

    Returns:
        A dict with keys:
            ``"global_step"``: non-trainable variable named ``global_step``,
                initialised to 0.
            ``"train_op"``: the op returned by ``layers.optimize_loss`` using
                the ``"Adam"`` optimizer.
            ``"learning_rate"``: the learning-rate placeholder tensor.
            ``"train_args"``: the raw ``learning_rate`` / ``clip_norm``
                arguments as passed in (plain Python values, not tensors).
    """
    # Snapshot the hyper-parameters as plain values before a tensor is built
    # from learning_rate, so they can be logged and returned as-is.
    hyper_params = dict(learning_rate=learning_rate, clip_norm=clip_norm)
    logger.debug("building training graph: %s.", hyper_params)

    lr_tensor = tf.placeholder_with_default(
        learning_rate, shape=[], name="learning_rate")
    step_var = tf.Variable(0, name="global_step", trainable=False)

    # NOTE(review): assumes `layers` is tf.contrib.layers, where a float
    # clip_gradients means clipping by global norm — confirm at the import.
    optimise = layers.optimize_loss(
        loss, step_var, lr_tensor, "Adam", clip_gradients=clip_norm)

    return dict(
        global_step=step_var,
        train_op=optimise,
        learning_rate=lr_tensor,
        train_args=hyper_params,
    )
# 评论列表 (scraped web-page residue: "comment list")
# 文章目录 (scraped web-page residue: "table of contents")