sample.py source code

python

Project: CausalGAN  Author: mkocaoglu
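import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder_with_default)

# Note: take_product is not shown in this excerpt; it must be defined elsewhere in the project.
def interpret_dict(a_dict, model, n_times=1, on_logits=True):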
    '''
    Pass either a do_dict or a cond_dict.
    The rules for converting the arguments to numpy arrays that are
    passed to tensorflow are identical for both.
    '''
    if not a_dict:  # covers both None and an empty dict
        return {}

    if n_times > 1:
        # shape is a required argument of tf.placeholder_with_default; [] means scalar
        token = tf.placeholder_with_default(2.22, shape=[])
        a_dict[token] = -2.22

    p_a_dict=take_product(a_dict)

    # Most models need a divisible batch_size
    if len(p_a_dict) > 0:
        # dict views are not indexable in Python 3, so fetch the first value via an iterator
        L = len(next(iter(p_a_dict.values())))
    else:
        L = 0
    print("L is " + str(L))
    print(p_a_dict)

    # Check compatibility of batch_size and L
    if L >= model.batch_size:
        if L % model.batch_size != 0:
            raise ValueError('a_dict must be divisible by batch_size, '
                             'but instead product of inputs was of length %d' % L)
    elif L > 0 and model.batch_size % L == 0:
        # Integer division: np.repeat needs an integer repeat count
        reps = model.batch_size // L
        p_a_dict = {key: np.repeat(value, reps, axis=0) for key, value in p_a_dict.items()}
    else:
        # Also reached when L == 0, which would otherwise raise ZeroDivisionError
        raise ValueError('No. of intervened values must divide batch_size.')
    return p_a_dict
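
The batch-size handling at the end of `interpret_dict` can be tried in isolation. The following is a minimal sketch, not part of the CausalGAN source: `tile_to_batch` is a hypothetical helper that reproduces only the divisibility check and the `np.repeat` step, and the `"Smiling"` key and its values are made-up example data. It assumes every value in the dict already has the same length along axis 0.

```python
import numpy as np


def tile_to_batch(p_a_dict, batch_size):
    """Hypothetical helper mirroring the last branch of interpret_dict.

    Assumes all values share the same length L along axis 0.
    """
    if not p_a_dict:
        return {}
    L = len(next(iter(p_a_dict.values())))
    if L >= batch_size:
        # Enough rows already: L only has to be a multiple of batch_size.
        if L % batch_size != 0:
            raise ValueError('length %d is not a multiple of batch_size' % L)
        return dict(p_a_dict)
    if L == 0 or batch_size % L != 0:
        raise ValueError('No. of intervened values must divide batch_size.')
    reps = batch_size // L  # integer repeat count for np.repeat
    return {key: np.repeat(value, reps, axis=0)
            for key, value in p_a_dict.items()}


# Made-up input: two settings of an illustrative label.
example = {"Smiling": np.array([[0.0], [1.0]])}
tiled = tile_to_batch(example, batch_size=4)
print(tiled["Smiling"].ravel())  # [0. 0. 1. 1.]
```

Note that `np.repeat` duplicates each row consecutively (0, 0, 1, 1) rather than cycling through the rows (0, 1, 0, 1) as `np.tile` would, so each setting occupies a contiguous block of the batch.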