codes.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:neural-decoder 作者: Krastanov 项目源码 文件源码
def find_threshold(Lsmall=3, Llarge=5, p=0.8, high=1, low=0.79, samples=1000, logfile=None):
    """Bisect on ``p`` to find where the two code sizes give matching estimates.

    At each probed ``p`` both sizes are sampled via ``sample`` and summarised
    with ``stat_estimator``.  The search halts once the central estimate of
    one size lies strictly between entries ``[1]`` and ``[2]`` of the other
    size's summary (presumably its lower/upper confidence bounds -- confirm
    against ``stat_estimator``).  Per the original author, this is a binary
    search between two sizes of codes for the threshold of the toric code.

    Args:
        Lsmall: first argument handed to ``sample`` for the smaller code.
        Llarge: first argument handed to ``sample`` for the larger code.
        p: first value to probe.
        high: upper end of the bisection bracket.
        low: lower end of the bisection bracket.
        samples: forwarded to ``sample`` as the ``samples`` keyword.
        logfile: if truthy, a path that is rewritten (mode ``'w'``) after
            every step with ``str((table, y, x))``; otherwise progress is
            drawn on a matplotlib figure and shown through ``display``.

    Returns:
        A tuple ``(ps, samples_small, samples_large)`` of lists, in probing
        order: the probed values and the ``stat_estimator`` results for the
        small and the large code.

    Note:
        There is no iteration cap; the loop ends only when the stopping
        condition above is met.
    """
    probed = []
    stats_small = []
    stats_large = []

    def step(p):
        # Record p, then sample the small code before the large one.
        probed.append(p)
        for size, bucket in ((Lsmall, stats_small), (Llarge, stats_large)):
            bucket.append(stat_estimator(sample(size, p, samples=samples)))

    def intersection(xs, y1s, y2s, log=True):
        # Crossing point of the two straight lines through (xs, y1s) and
        # (xs, y2s), each given by two points, using 2x2 determinants.
        # With log=True the lines are fitted to log(y) and y is mapped back.
        det = np.linalg.det
        if log:
            y1s, y2s = np.log([y1s, y2s])
        unit = np.array([1., 1.])
        span_x = det([xs, unit])
        span_1 = det([y1s, unit])
        span_2 = det([y2s, unit])
        cross_1 = det([xs, y1s])
        cross_2 = det([xs, y2s])
        x = (cross_1 - cross_2) / (span_2 - span_1)
        y = (cross_1 * span_2 - cross_2 * span_1) / span_x / (span_2 - span_1)
        if log:
            return x, np.exp(y)
        return x, y

    def within(estimate, summary):
        # True when the estimate lies strictly between summary[1] and summary[2].
        return summary[1] < estimate < summary[2]

    def ordered(stats, order):
        # Central column and absolute offsets of the remaining columns,
        # both rearranged by `order`; offsets come out with shape (2, n).
        table = np.array(stats)
        centre = table[order, 0]
        return centre, np.abs(table[order, 1:].T - centre)

    step(p)

    if logfile:
        with open(logfile, 'w') as out:
            first_small = stats_small[0]
            first_large = stats_large[0]
            rows = [probed]
            for stat in (first_small, first_large):
                # Signed offsets here; later writes store absolute offsets.
                rows.extend([[stat[0]], [stat[1] - stat[0]], [stat[2] - stat[0]]])
            out.write(str((np.vstack(rows), (first_small[0] + first_large[0]) / 2, probed[0])))
    else:
        fig = plt.figure()
        ax = fig.add_subplot(1, 1, 1)

    while True:
        small_now = stats_small[-1]
        large_now = stats_large[-1]
        if within(small_now[0], large_now) or within(large_now[0], small_now):
            break

        # Shrink the bracket towards the side where the curves should cross.
        previous = probed[-1]
        if small_now[0] < large_now[0]:
            high = p
            p = low + (previous - low) / 2
        else:
            low = p
            p = previous + (high - previous) / 2
        step(p)

        order = np.argsort(probed)
        sorted_ps = np.array(probed)[order]
        small_est, small_err = ordered(stats_small, order)
        large_est, large_err = ordered(stats_large, order)

        # Extrapolate the crossing from the two most recent probes only.
        ix, iy = intersection(
            probed[-2:],
            [stat[0] for stat in stats_small[-2:]],
            [stat[0] for stat in stats_large[-2:]],
        )

        if logfile:
            with open(logfile, 'w') as out:
                out.write(str((np.vstack([sorted_ps, small_est, small_err, large_est, large_err]), iy, ix)))
        else:
            ax.clear()
            ax.errorbar(sorted_ps, small_est, yerr=small_err, alpha=0.6, label=str(Lsmall))
            ax.errorbar(sorted_ps, large_est, yerr=large_err, alpha=0.6, label=str(Llarge))
            ax.plot([ix], [iy], 'ro', alpha=0.5)
            ax.set_title('intersection at p = %f'%ix)
            ax.set_yscale('log')
            display.clear_output(wait=True)
            display.display(fig)

    return probed, stats_small, stats_large
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号