import numpy as np
from numpy.linalg import norm

# grid_env is the gridworld environment class used in this post; it is assumed to be
# defined/importable elsewhere in the repository.

def kl_preds_v2(model, sess, s_test, a_test, n_rep_per_item=200):
    ## Compare the model's sample distribution to the ground-truth transition distribution
    Env = grid_env(False)
    n_test_items, state_size = s_test.shape
    distances = np.empty([n_test_items, 3])
    for i in range(n_test_items):
        state = s_test[i, :].astype('int32')
        action = np.round(a_test[i, :]).astype('int32')
        # Ground truth: repeatedly step the environment from the same (state, action) pair
        state_truth = np.empty([n_rep_per_item, state_size])
        for o in range(n_rep_per_item):
            Env.set_state(state.flatten())
            s1, r, dead = Env.step(action.flatten())
            state_truth[o, :] = s1
        truth_count, bins = np.histogramdd(state_truth, bins=[np.arange(8) - 0.5] * state_size)
        truth_prob = truth_count / n_rep_per_item
        # Model predictions: draw n_rep_per_item next-state samples from the model
        y_sample = sess.run(model.y_sample, {model.x: state[None, :].repeat(n_rep_per_item, axis=0),
                                             model.y: np.zeros(np.shape(state[None, :])).repeat(n_rep_per_item, axis=0),
                                             model.a: action[None, :].repeat(n_rep_per_item, axis=0),
                                             model.Qtarget: np.zeros(np.shape(action[None, :])).repeat(n_rep_per_item, axis=0),
                                             model.lr: 0,
                                             model.lamb: 1,
                                             model.temp: 0.00001,
                                             model.is_training: False,
                                             model.k: 1})
        sample_count, bins = np.histogramdd(y_sample, bins=[np.arange(8) - 0.5] * state_size)
        sample_prob = sample_count / n_rep_per_item
        # Distances between the two empirical distributions (the 1e-5 term avoids log(0))
        distances[i, 0] = np.sum(truth_prob * (np.log(truth_prob + 1e-5) - np.log(sample_prob + 1e-5)))   # KL(p|p_tilde)
        distances[i, 1] = np.sum(sample_prob * (np.log(sample_prob + 1e-5) - np.log(truth_prob + 1e-5)))  # inverse KL(p_tilde|p)
        distances[i, 2] = norm(np.sqrt(truth_prob) - np.sqrt(sample_prob)) / np.sqrt(2)                   # Hellinger distance
    return np.mean(distances, axis=0)
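
For reference, the three entries of distances computed above are the forward KL, the inverse (reverse) KL, and the Hellinger distance between the empirical ground-truth next-state distribution p and the model's empirical sample distribution p_tilde, with a small epsilon = 1e-5 smoothing the logarithms, exactly as implemented in the code:

\mathrm{KL}(p \,\|\, \tilde{p}) = \sum_{s'} p(s')\,\bigl[\log(p(s') + \epsilon) - \log(\tilde{p}(s') + \epsilon)\bigr]
\mathrm{KL}(\tilde{p} \,\|\, p) = \sum_{s'} \tilde{p}(s')\,\bigl[\log(\tilde{p}(s') + \epsilon) - \log(p(s') + \epsilon)\bigr]
H(p, \tilde{p}) = \tfrac{1}{\sqrt{2}}\,\bigl\| \sqrt{p} - \sqrt{\tilde{p}} \bigr\|_2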
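
A minimal usage sketch (not part of the original listing): it assumes a trained TF1-style model object exposing the placeholders referenced above (x, y, a, Qtarget, lr, lamb, temp, is_training, k) and the y_sample op, plus held-out test arrays; build_model, the checkpoint path, and the .npy file names are hypothetical.

import numpy as np
import tensorflow as tf

model = build_model()                                      # hypothetical: constructs the graph described above
sess = tf.Session()
tf.train.Saver().restore(sess, 'checkpoints/model.ckpt')   # hypothetical checkpoint path

s_test = np.load('s_test.npy')                             # hypothetical held-out states,  shape [n_items, state_size]
a_test = np.load('a_test.npy')                             # hypothetical matching actions, shape [n_items, action_size]

kl_fw, kl_rev, hellinger = kl_preds_v2(model, sess, s_test, a_test, n_rep_per_item=200)
print('KL(p|p_tilde)=%.4f  inverse KL(p_tilde|p)=%.4f  Hellinger=%.4f' % (kl_fw, kl_rev, hellinger))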