def retrain(max_episodes=300):
    """Fine-tune the supervised policy network with REINFORCE and save it.

    Resets the TensorFlow graph and builds a ``PolicyNetwork`` under the
    ``'supervised_policy'`` scope. It then restores the checkpoint at
    ``'models/policy_supervised_' + relation`` and runs ``REINFORCE`` over the
    lines read from ``relationPath``. Finally it saves the session to
    ``'models/policy_retrained' + relation``.

    Relies on module-level names defined elsewhere in this file:
    ``tf``, ``PolicyNetwork``, ``REINFORCE``, ``relationPath``, ``relation``.

    Args:
        max_episodes: Upper bound on the number of REINFORCE episodes. The
            count actually used is ``min(len(training_pairs), max_episodes)``.
            Defaults to 300, the value previously hard-coded here.
    """
    print('Start retraining')
    tf.reset_default_graph()
    # tf.train.Saver restores variables by name, so this scope presumably has
    # to match the one used when the supervised checkpoint was written.
    policy_network = PolicyNetwork(scope='supervised_policy')

    # The context manager closes the file even if reading it raises.
    with open(relationPath) as f:
        training_pairs = f.readlines()

    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, 'models/policy_supervised_' + relation)
        print('sl_policy restored')
        # Cap the number of episodes (presumably to bound retraining time).
        episodes = min(len(training_pairs), max_episodes)
        REINFORCE(training_pairs, policy_network, episodes)
        # NOTE(review): unlike the restore path above, there is no '_' between
        # 'policy_retrained' and the relation name. Kept as-is because any
        # loader elsewhere presumably restores from this exact prefix;
        # confirm before changing.
        saver.save(sess, 'models/policy_retrained' + relation)
    print('Retrained model saved')
评论列表
文章目录