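# swarm.py — source file (文件源码)
#
# Language: python
# Views 34 · Favorites 0 · Likes 0 · Comments 0 (阅读 34 收藏 0 点赞 0 评论 0)
#
# Project: gmdh · Author: parrt (项目:gmdh 作者: parrt 项目源码 文件源码)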
def _global_best(particles):
    """Return ``(best, best_score)`` of the particle with the lowest ``best_score``.

    Ties keep the earliest particle (same as a strict ``<`` scan from index 0).

    Raises:
        ValueError: if ``particles`` is empty.
    """
    if not particles:
        raise ValueError("swarm is empty")
    leader = min(particles, key=lambda p: p.best_score)
    return leader.best, leader.best_score


def BBPSO_cost(ITERATIONS, SWARM_SIZE, layer_sizes=None):
    """Search Network parameters with bare-bones particle swarm optimisation on cost.

    In every round each particle builds a new position
    ``Network(sizes, mu=pmu, sigma=psigma)``. Here ``pmu`` is the element-wise
    midpoint of the particle's personal best and the swarm's global best
    (biases, weights), and ``psigma`` is their absolute difference. The draw
    itself happens inside ``Network`` and is presumably Gaussian — TODO confirm.
    A particle's personal best is replaced only when the new position has a
    strictly lower ``cost(X, labels)``, so lower cost is better.

    Relies on module-level ``X``, ``labels`` and ``np``, and on ``Particle`` /
    ``Network`` defined elsewhere in the project.

    Args:
        ITERATIONS: number of sampling rounds. 0 is allowed; the best of the
            initial particles is then returned.
        SWARM_SIZE: number of particles; must be >= 1.
        layer_sizes: layer sizes passed to every ``Network``; defaults to
            ``[784, 30, 10]``.

    Returns:
        The personal-best Network with the lowest cost across the whole swarm,
        taken after the final round.

    Raises:
        ValueError: if ``SWARM_SIZE`` < 1.
    """
    if SWARM_SIZE < 1:
        raise ValueError("SWARM_SIZE must be >= 1")
    sizes = [784, 30, 10] if layer_sizes is None else list(layer_sizes)

    # NOTE(review): assumes Particle(...) initialises p.best from the Network it
    # is given — the code below reads p.best before ever assigning it.
    particles = [Particle(Network(list(sizes))) for _ in range(SWARM_SIZE)]
    for p in particles:
        # Sentinel "worse than any real cost" so the first sample replaces it.
        p.best_score = 1e20

    for it in range(ITERATIONS):
        # Refresh the global best from every particle's personal best; it stays
        # fixed for all particles within this round.
        gbest, gbest_score = _global_best(particles)
        if it % 100 == 0:
            # Fitness is only needed for the progress report, so compute it lazily.
            fit = gbest.fitness(X, labels)
            print(str(it) + ": global best score " + str(gbest_score) + " correct " + str(fit))

        for p in particles:
            # NOTE(review): the arithmetic below assumes biases/weights support
            # element-wise +, - and /2.0 (e.g. numpy arrays) — plain Python lists
            # would concatenate on '+'. Confirm against Network.
            pmu = ((p.best.biases + gbest.biases) / 2.0,
                   (p.best.weights + gbest.weights) / 2.0)
            psigma = (np.abs(p.best.biases - gbest.biases),
                      np.abs(p.best.weights - gbest.weights))
            pos = Network(list(sizes), mu=pmu, sigma=psigma)
            p.pos = pos
            c = pos.cost(X, labels)
            if c < p.best_score:
                p.best = pos
                p.best_score = c

    # Re-scan after the last round so its improvements are reflected in the
    # report and the return value. This also defines gbest when ITERATIONS == 0.
    gbest, gbest_score = _global_best(particles)
    fit = gbest.fitness(X, labels)
    print("final best score " + str(gbest_score) + " correct " + str(fit))
    return gbest

# BBPSO()
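# --- scraped page navigation (not part of the program) ---
# Comment list (评论列表)
# Table of contents (文章目录)
#
# Questions (问题)
#
# Interview experiences (面经)
#
# Articles (文章)
#
# WeChat official account (微信 公众号)
#
# Scan the QR code to follow the official account (扫码关注公众号)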