tf_model.py source code

Project: char-rnn-text-generation, author: yxtay
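import logging

import tensorflow as tf  # TensorFlow 1.x: tf.contrib was removed in 2.x
from tensorflow.contrib import layers

# module-level logger (assumed; the project may configure logging differently)
logger = logging.getLogger(__name__)


def build_train_graph(loss, learning_rate=0.001, clip_norm=5.0):
    """
    Builds the training graph: an Adam update on `loss`, with gradients
    clipped so that their global norm is at most `clip_norm`.
    """
    train_args = {"learning_rate": learning_rate, "clip_norm": clip_norm}
    logger.debug("building training graph: %s.", train_args)

    # scalar learning rate that can be overridden through feed_dict at run time
    learning_rate = tf.placeholder_with_default(learning_rate, [], "learning_rate")
    # step counter incremented by train_op; not updated by the optimizer
    global_step = tf.Variable(0, name="global_step", trainable=False)
    # a float clip_gradients makes optimize_loss clip by global norm before Adam
    train_op = layers.optimize_loss(loss, global_step, learning_rate, "Adam",
                                    clip_gradients=clip_norm)

    model = {"global_step": global_step, "train_op": train_op,
             "learning_rate": learning_rate, "train_args": train_args}
    return model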
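With optimizer "Adam" and a float `clip_gradients`, each run of `train_op` rescales all gradients so their combined L2 norm is at most `clip_norm` and then applies one Adam update. The block below is a minimal, dependency-free sketch of that single step, assuming a flat list of scalar parameters; `clip_global_norm`, `Adam` and `train_step` are hypothetical illustration names, not TensorFlow APIs. The defaults (beta1 = 0.9, beta2 = 0.999, epsilon = 1e-8) match the defaults of TensorFlow 1.x's `AdamOptimizer`, but the update is written in the textbook form from the Adam paper; TensorFlow folds the bias corrections into the step size, so epsilon enters slightly differently there.

```python
import math
from typing import List, Optional


def clip_global_norm(grads: List[float], clip_norm: float) -> List[float]:
    """Scale all gradients by one factor so their joint L2 norm is <= clip_norm (clip_norm > 0)."""
    # Global norm: L2 norm of every gradient concatenated into one vector.
    global_norm = math.sqrt(sum(g * g for g in grads))
    # Factor is 1.0 when the norm is already small enough, else clip_norm / global_norm.
    scale = clip_norm / max(global_norm, clip_norm)
    return [g * scale for g in grads]


class Adam:
    """Hypothetical plain-Python Adam over a flat list of scalar parameters."""

    def __init__(self, lr: float = 0.001, beta1: float = 0.9,
                 beta2: float = 0.999, eps: float = 1e-8) -> None:
        self.lr, self.beta1, self.beta2, self.eps = lr, beta1, beta2, eps
        self.m: Optional[List[float]] = None  # first-moment estimates
        self.v: Optional[List[float]] = None  # second-moment estimates
        self.t = 0  # number of updates applied, analogous to global_step

    def step(self, params: List[float], grads: List[float]) -> List[float]:
        if self.m is None or self.v is None:
            self.m = [0.0] * len(params)
            self.v = [0.0] * len(params)
        self.t += 1
        updated = []
        for i, (p, g) in enumerate(zip(params, grads)):
            # Exponential moving averages of the gradient and squared gradient.
            self.m[i] = self.beta1 * self.m[i] + (1 - self.beta1) * g
            self.v[i] = self.beta2 * self.v[i] + (1 - self.beta2) * g * g
            # Bias correction for the zero-initialised moments.
            m_hat = self.m[i] / (1 - self.beta1 ** self.t)
            v_hat = self.v[i] / (1 - self.beta2 ** self.t)
            updated.append(p - self.lr * m_hat / (math.sqrt(v_hat) + self.eps))
        return updated


def train_step(params: List[float], grads: List[float], optimizer: Adam,
               clip_norm: float = 5.0) -> List[float]:
    """One step in the spirit of optimize_loss: clip by global norm, then apply Adam."""
    return optimizer.step(params, clip_global_norm(grads, clip_norm))


if __name__ == "__main__":
    opt = Adam(lr=0.001)
    # Global norm of [6, -8] is 10 > 5, so Adam sees the clipped gradients [3, -4].
    params = train_step([1.0, 1.0], [6.0, -8.0], opt)
    print(params, opt.t)
```

Clipping rescales every gradient by the same factor, so the direction of the combined gradient vector is preserved; only spikes above `clip_norm` are shrunk before they enter Adam's moment estimates. With a constant gradient, each step above moves a parameter by roughly the learning rate against the sign of its gradient.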