distributions.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:CausalGAN 作者: mkocaoglu 项目源码 文件源码
def get_interv_table(model,intrv=True):

    n_batches=25
    table_outputs=[]
    d_vals=np.linspace(TINY,0.6,n_batches)
    for name in model.cc.node_names:
        outputs=[]
        for d_val in d_vals:
            do_dict={model.cc.node_dict[name].label_logit : d_val*np.ones((model.batch_size,1))}
            outputs.append(model.sess.run(model.fake_labels,do_dict))

        out=np.vstack(outputs)
        table_outputs.append(out)

    table=np.stack(table_outputs,axis=2)

    np.mean(np.round(table),axis=0)

    return table

#dT=pd.DataFrame(index=p_names, data=T, columns=do_names)
#T=np.mean(np.round(table),axis=0)
#table=get_interv_table(model)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号