sketch_rnn.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:Pytorch-Sketch-RNN 作者: alexis-jacq 项目源码 文件源码
def sample_bivariate_normal(mu_x,mu_y,sigma_x,sigma_y,rho_xy, greedy=False):
    # inputs must be floats
    if greedy:
      return mu_x,mu_y
    mean = [mu_x, mu_y]
    sigma_x *= np.sqrt(hp.temperature)
    sigma_y *= np.sqrt(hp.temperature)
    cov = [[sigma_x * sigma_x, rho_xy * sigma_x * sigma_y],\
        [rho_xy * sigma_x * sigma_y, sigma_y * sigma_y]]
    x = np.random.multivariate_normal(mean, cov, 1)
    return x[0][0], x[0][1]
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号