ptrnets.py source file

python

Project: deeptravel  Author: keon
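import numpy
from theano import config  # provides config.floatX, Theano's default float dtype


def gen_hull(p, p_mask, f_encode, f_probi, options):
    """Greedily decode one pointer sequence per sample; index 0 ends a sequence."""
    # p: n_sizes * n_samples * data_dim
    n_sizes = p.shape[0]
    n_samples = p.shape[1] if p.ndim == 3 else 1
    hprev = f_encode(p_mask, p)  # encoder states: n_sizes * n_samples * dim_proj
    points = numpy.zeros((n_samples, n_sizes), dtype='int64')
    h = hprev[-1]  # the last encoder state initializes the decoder
    c = numpy.zeros((n_samples, options['dim_proj']), dtype=config.floatX)  # initial cell state
    xi = numpy.zeros((n_samples,), dtype='int64')  # previously chosen index per sample
    xi_mask = numpy.ones((n_samples,), dtype=config.floatX)  # 1 while a sample is still decoding
    for i in range(n_sizes):
        h, c, probi = f_probi(p_mask[i], xi, h, c, hprev, p_mask, p)
        xi = probi.argmax(axis=0)  # most probable input position per sample
        # Cast the mask to int64: since NumPy 1.10, multiplying an int64 array
        # in place by a float array raises a casting error.
        xi *= xi_mask.astype(numpy.int64)
        xi_mask = (numpy.not_equal(xi, 0)).astype(config.floatX)  # a sample that outputs 0 is done
        if numpy.equal(xi_mask, 0).all():  # every sample has finished
            break
        points[:, i] = xi
    return points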
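The loop in gen_hull is greedy pointer decoding: at each step it takes the most probable input position for every sample, and once a sample emits index 0 the mask pins all of its later outputs to 0; decoding stops early when every sample has emitted 0. A minimal sketch of just that loop in plain NumPy follows, so it runs without Theano or a trained model. The names `greedy_pointer_decode` and `make_scripted_step` are hypothetical, and `step(i, xi)` is an assumed stand-in for f_probi with the hidden-state arguments dropped.

```python
import numpy as np


def greedy_pointer_decode(step, n_sizes, n_samples):
    """Greedy decoding loop modeled on gen_hull (illustrative sketch only).

    `step(i, xi)` is a hypothetical stand-in for f_probi: it returns an array
    of shape (n_sizes, n_samples) scoring each input position per sample.
    """
    points = np.zeros((n_samples, n_sizes), dtype=np.int64)
    xi = np.zeros((n_samples,), dtype=np.int64)        # previous pointer per sample
    xi_mask = np.ones((n_samples,), dtype=np.float32)  # 1 while a sample is still decoding
    for i in range(n_sizes):
        probi = step(i, xi)
        xi = probi.argmax(axis=0)                      # best position per sample
        xi *= xi_mask.astype(np.int64)                 # finished samples keep emitting 0
        xi_mask = np.not_equal(xi, 0).astype(np.float32)
        if np.equal(xi_mask, 0).all():                 # every sample has emitted 0
            break
        points[:, i] = xi
    return points


def make_scripted_step(targets, n_sizes):
    """Hypothetical step function that puts all mass on a scripted index per step."""
    targets = np.asarray(targets)  # shape (n_samples, n_steps)
    n_samples = targets.shape[0]

    def step(i, xi):
        probs = np.zeros((n_sizes, n_samples))
        probs[targets[:, i], np.arange(n_samples)] = 1.0  # one-hot column per sample
        return probs

    return step


# Sample 0 emits 0 at step 2; its scripted index 2 at step 3 is masked to 0.
step = make_scripted_step([[3, 1, 0, 2, 4], [2, 4, 1, 0, 3]], n_sizes=5)
points = greedy_pointer_decode(step, n_sizes=5, n_samples=2)
print(points.tolist())  # [[3, 1, 0, 0, 0], [2, 4, 1, 0, 0]]
```

Because the mask is updated after the multiplication, a sample that has emitted 0 stays at 0 even if the model would later point elsewhere, which is why the break condition only needs to check the mask.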