def __compute_bnn_training_error(self):
    """Compute BNN training error on most recent episode.

    Builds network inputs/targets from ``self.episode_buffer_bnn``, scores
    them with ``self.network.get_td_error`` and stores the mean and standard
    deviation of the returned errors in ``self.mean_episode_errors`` and
    ``self.std_episode_errors`` at index
    ``[self.instance_iter, self.episode_iter]``.

    Each buffer entry is flattened into one row. Columns 0 and 1 are
    concatenated to form the input; column 3 is the target. (Presumably
    state, action, ..., next state -- TODO confirm against the code that
    fills the buffer.)
    """
    exp = np.reshape(self.episode_buffer_bnn, (len(self.episode_buffer_bnn), -1))
    # Iterate rows directly (instead of Python-2-only ``xrange``) so this
    # runs on both Python 2 and Python 3.
    episode_X = np.array([np.hstack([row[0], row[1]]) for row in exp])
    episode_Y = np.array([row[3] for row in exp])
    if self.state_diffs:
        # Subtract the previous state (the first num_dims input columns).
        # Promote the target dtype first: subtracting float inputs in place
        # from integer targets would otherwise raise a casting error.
        episode_Y = episode_Y.astype(np.promote_types(episode_Y.dtype, episode_X.dtype))
        episode_Y -= episode_X[:, :self.num_dims]
    # Append self.weight_set, repeated once per row, to every input row.
    weights = np.tile(self.weight_set, (episode_X.shape[0], 1))
    l2_errors = self.network.get_td_error(np.hstack([episode_X, weights]), episode_Y, 0.0, 1.0)
    idx = (self.instance_iter, self.episode_iter)
    self.mean_episode_errors[idx] = np.mean(l2_errors)
    self.std_episode_errors[idx] = np.std(l2_errors)
    if self.print_output:
        print('BNN Error: {}'.format(self.mean_episode_errors[idx]))
评论列表
文章目录