def load(self, name, iteration=None):
    """Restore saved agent state from ``saves/<name>/iteration_<k>/``.

    Every variable returned by ``tf.global_variables()`` is assigned from
    ``weight_<i>.npy`` (``i`` is its position in that list). Afterwards the
    scaling statistics (only when ``self.scale != 'off'``), the timestep and
    the train/test score lists are read from the same folder.

    Loading is best-effort: ordinary errors are reported on stdout and
    swallowed, so the method always returns ``None``. Anything assigned
    before the failure is not rolled back. ``KeyboardInterrupt`` and
    ``SystemExit`` are deliberately left to propagate.

    Args:
        name: Sub-folder of ``saves/`` that holds the checkpoints.
        iteration: Iteration number to load. When ``None``, the highest
            number among the ``iteration_<k>`` sub-folders is used.
    """
    prefix = 'iteration_'
    try:
        directory = 'saves/' + name + '/'
        if not os.path.exists(directory):
            print('That directory does not exist!')
            raise IOError('missing save directory: ' + directory)

        if iteration is None:
            # Only the immediate sub-folders matter, so take the first
            # os.walk() entry rather than walking the whole tree. Folders
            # that are not named iteration_<digits> are ignored instead of
            # aborting the load.
            subdirs = next(os.walk(directory), (None, [], []))[1]
            saved = [
                int(sub[len(prefix):])
                for sub in subdirs
                if sub.startswith(prefix) and sub[len(prefix):].isdigit()
            ]
            if not saved:
                raise IOError('no iteration_<k> folders in ' + directory)
            iteration = max(saved)
        directory += prefix + '{}'.format(iteration) + '/'

        for i, tensor in enumerate(tf.global_variables()):
            arr = np.load(directory + 'weight_{}.npy'.format(i))
            self.sess.run(tensor.assign(arr))

        # NOTE(review): the original indentation was lost. The three scaling
        # arrays are assumed to belong to this branch, and timestep/scores
        # are assumed to be loaded unconditionally -- confirm against the
        # matching save routine.
        if self.scale != 'off':
            self.sums = np.load(directory + 'sums.npy')
            self.sumsqrs = np.load(directory + 'sumsquares.npy')
            self.sumtime = np.load(directory + 'sumtime.npy')
        self.timestep = np.load(directory + 'timestep.npy')[0]
        self.train_scores = np.load(directory + 'train_scores.npy').tolist()
        self.test_scores = np.load(directory + 'test_scores.npy').tolist()
        print("Agent successfully loaded from folder {}".format(directory))
    except Exception as err:
        # `except Exception` (not a bare except) keeps Ctrl-C and sys.exit()
        # working while still treating load failures as non-fatal.
        print("Something is wrong, loading failed")
        print("Reason: {!r}".format(err))
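# [web-page residue, not part of the code] "Comment list"
# [web-page residue, not part of the code] "Table of contents"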