helpers.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:multimodal_varinf 作者: tmoer 项目源码 文件源码
def kl_preds_v2(model, sess, s_test, a_test, n_rep_per_item=200, n_values=7, eps=1e-5):
    """Compare the model's sampled next-state distribution to the environment's.

    For every test (state, action) pair, the ground-truth next-state
    distribution is estimated by stepping a ``grid_env`` (reset to that state
    before each step) ``n_rep_per_item`` times. The model's distribution is
    estimated from ``n_rep_per_item`` samples of ``model.y_sample``. Both sample
    sets are histogrammed with one bin per integer value 0 .. n_values-1 in
    every state dimension, and the two histograms are compared.

    Args:
        model: object exposing the placeholders fed below (x, y, a, Qtarget,
            lr, lamb, temp, is_training, k) and the ``y_sample`` output.
        sess: session used to evaluate ``model.y_sample``.
        s_test: 2-D array (n_test_items, state_size) of test states. Each row
            is cast to int32 before use.
        a_test: 2-D array (n_test_items, action_size) of test actions. Each row
            is rounded and then cast to int32.
        n_rep_per_item: number of ground-truth transitions and of model samples
            drawn per test item.
        n_values: number of integer values per state dimension covered by the
            histogram bins. Values outside [0, n_values-1] fall outside the bin
            range and are not counted, so the estimated probabilities can then
            sum to less than 1. Defaults to 7, the previously hard-coded range.
        eps: smoothing constant added inside the logarithms of the KL terms so
            that empty bins do not produce log(0).

    Returns:
        1-D array of length 3. It holds the mean over test items of
        [KL(p || p_tilde), KL(p_tilde || p), Hellinger distance], where p is
        the ground-truth histogram and p_tilde is the model-sample histogram.
    """
    Env = grid_env(False)
    n_test_items, state_size = s_test.shape
    distances = np.empty([n_test_items, 3])

    # Loop-invariant quantities, built once instead of once per test item.
    # Edges k-0.5 / k+0.5 give each integer value k its own bin.
    bin_edges = [np.arange(n_values + 1) - 0.5] * state_size
    # model.y and model.Qtarget are fed with zeros here. Presumably they are
    # not needed to draw y_sample; TODO confirm against the model definition.
    y_dummy = np.zeros([n_rep_per_item, state_size])
    q_dummy = np.zeros([n_rep_per_item, a_test.shape[1]])

    for i in range(n_test_items):
        state = s_test[i, :].astype('int32')
        action = np.round(a_test[i, :]).astype('int32')

        # Ground truth: reset the environment to `state` before every step,
        # so each row is an independent draw of the transition from (state, action).
        state_truth = np.empty([n_rep_per_item, state_size])
        for o in range(n_rep_per_item):
            Env.set_state(state.flatten())
            s1, _, _ = Env.step(action.flatten())
            state_truth[o, :] = s1
        truth_prob = np.histogramdd(state_truth, bins=bin_edges)[0] / n_rep_per_item

        # Model predictions: a single batch of n_rep_per_item copies of the
        # same (state, action). lr is fed as 0 and is_training as False. The
        # very small temperature presumably pushes samples towards discrete
        # values; TODO confirm in the model.
        y_sample = sess.run(model.y_sample, {model.x: state[None, :].repeat(n_rep_per_item, axis=0),
                                             model.y: y_dummy,
                                             model.a: action[None, :].repeat(n_rep_per_item, axis=0),
                                             model.Qtarget: q_dummy,
                                             model.lr: 0,
                                             model.lamb: 1,
                                             model.temp: 0.00001,
                                             model.is_training: False,
                                             model.k: 1})
        sample_prob = np.histogramdd(y_sample, bins=bin_edges)[0] / n_rep_per_item

        log_truth = np.log(truth_prob + eps)
        log_sample = np.log(sample_prob + eps)
        distances[i, 0] = np.sum(truth_prob * (log_truth - log_sample))   # KL(p || p_tilde)
        distances[i, 1] = np.sum(sample_prob * (log_sample - log_truth))  # inverse KL(p_tilde || p)
        # Hellinger distance: ||sqrt(p) - sqrt(p_tilde)||_2 / sqrt(2)
        distances[i, 2] = norm(np.sqrt(truth_prob) - np.sqrt(sample_prob)) / np.sqrt(2)

    return np.mean(distances, axis=0)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号