sampling.py source code

python

Project: adversarial-variational-bayes  Author: gdikov
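def sample_standard_normal_noise(inputs, **kwargs):
    """Sample standard normal noise and optionally merge it with `inputs`."""
    from keras.backend import shape, random_normal
    from keras.layers import Add, Concatenate, Dense
    # Default to the (symbolic) batch size of `inputs`.
    n_samples = kwargs.get('n_samples', shape(inputs)[0])
    n_basis_noise_vectors = kwargs.get('n_basis', -1)
    data_dim = kwargs.get('data_dim', 1)
    noise_dim = kwargs.get('noise_dim', data_dim)
    seed = kwargs.get('seed', 7)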

    if n_basis_noise_vectors > 0:
        # Draw n_basis noise vectors for every sample.
        samples_isotropic = random_normal(shape=(n_samples, n_basis_noise_vectors, noise_dim),
                                          mean=0, stddev=1, seed=seed)
    else:
        # Draw a single noise vector for every sample.
        samples_isotropic = random_normal(shape=(n_samples, noise_dim),
                                          mean=0, stddev=1, seed=seed)
    op_mode = kwargs.get('mode', 'none')
    if op_mode == 'concatenate':
        # Append the noise to the inputs along axis 1.
        concat = Concatenate(axis=1, name='enc_noise_concatenation')([inputs, samples_isotropic])
        return concat
    elif op_mode == 'add':
        # Project the noise to data_dim with a linear Dense layer, then add it to the inputs.
        resized_noise = Dense(data_dim, activation=None, name='enc_resized_noise_sampler')(samples_isotropic)
        added_noise_data = Add(name='enc_adding_noise_data')([inputs, resized_noise])
        return added_noise_data
    # Any other mode returns the raw noise tensor.
    return samples_isotropic
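
The three values of `mode` in the function above differ mainly in the shape of what they return. The following is a minimal pure-Python sketch of that behaviour, not the project's implementation: `sample_noise_sketch` is a hypothetical helper, it uses nested lists instead of Keras tensors, it infers `data_dim` from the input rows, it ignores the `n_basis` option, and in `'add'` mode it assumes `noise_dim == data_dim` and adds the noise directly instead of passing it through a trainable `Dense` layer.

```python
import random


def sample_noise_sketch(inputs, noise_dim=None, mode='none', seed=7):
    """Hypothetical list-based stand-in for sample_standard_normal_noise."""
    rng = random.Random(seed)  # fresh generator per call, so results repeat
    data_dim = len(inputs[0])
    noise_dim = data_dim if noise_dim is None else noise_dim
    # One standard normal noise vector per input row.
    noise = [[rng.gauss(0.0, 1.0) for _ in range(noise_dim)] for _ in inputs]
    if mode == 'concatenate':
        # Row width becomes data_dim + noise_dim.
        return [row + eps for row, eps in zip(inputs, noise)]
    if mode == 'add':
        # Assumption: no learned projection, so the widths must already match.
        if noise_dim != data_dim:
            raise ValueError("'add' in this sketch needs noise_dim == data_dim")
        return [[x + e for x, e in zip(row, eps)]
                for row, eps in zip(inputs, noise)]
    # Any other mode returns the raw noise.
    return noise


batch = [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
plain = sample_noise_sketch(batch, noise_dim=4)
concat = sample_noise_sketch(batch, noise_dim=4, mode='concatenate')
added = sample_noise_sketch(batch, mode='add')
print(len(plain), len(plain[0]))    # 2 4
print(len(concat), len(concat[0]))  # 2 7
print(len(added), len(added[0]))    # 2 3
```

Because the sketch reseeds its generator on every call, repeated calls with the same arguments return identical noise. This is a property of the sketch only and says nothing about how the `seed` argument behaves in the Keras backend.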