policy_agent.py source code

python

Project: DeepPath  Author: xwhan
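import tensorflow as tf  # TensorFlow 1.x API (tf.Session, tf.train.Saver)

# PolicyNetwork, REINFORCE, relation and relationPath are used as in the
# original excerpt; their definitions are not shown here.


def retrain():
    """Fine-tune the supervised policy with REINFORCE and save the result."""
    print('Start retraining')
    tf.reset_default_graph()
    policy_network = PolicyNetwork(scope='supervised_policy')

    # One training pair per line.
    with open(relationPath) as f:
        training_pairs = f.readlines()

    saver = tf.train.Saver()
    with tf.Session() as sess:
        # Start from the supervised-learning checkpoint for this relation.
        saver.restore(sess, 'models/policy_supervised_' + relation)
        print('sl_policy restored')
        # Cap retraining at 300 episodes.
        episodes = min(len(training_pairs), 300)
        REINFORCE(training_pairs, policy_network, episodes)
        saver.save(sess, 'models/policy_retrained' + relation)
    print('Retrained model saved')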