mng.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:dyfunconn 作者: makism 项目源码 文件源码
def fit(self, data):
    """Train the Merge Neural Gas prototypes and contexts on ``data``.

    Prototypes (``self.protos``, the ``w`` vectors) are initialised from
    randomly drawn rows of ``data``; contexts (``self.context``, the ``c``
    vectors) start at zero.  Each of ``self.iterations`` steps draws one
    random row, builds the merged context of the previous winner, ranks
    every prototype by a blend of sample distance and context distance, and
    moves all prototypes/contexts toward the sample/context with a
    rank-based exponential neighbourhood.  The learning rate and the
    neighbourhood range both decay geometrically from their initial
    (``*_i``) to final (``*_f``) values over the run.

    :param data: 2-D array-like of shape ``(n_samples, n_obs)``; each row is
        one observation.  Non-floating input is converted to ``float`` so the
        in-place prototype updates do not fail on integer arrays.
    :return: ``self``, with ``protos`` and ``context`` fitted.
    """
    data = np.asarray(data)
    if not np.issubdtype(data.dtype, np.floating):
        # Integer prototypes cannot absorb the fractional in-place updates below.
        data = data.astype(float)

    n_samples, n_obs = data.shape
    self.protos = data[self.rng.choice(n_samples, self.n_protos), ]  # w
    self.context = np.zeros(self.protos.shape)                        # c

    # Prototype (wr) and context (cr) of the previous step's winner; both are
    # zero before the first step.
    wr = np.zeros((1, n_obs))
    cr = wr

    for iteration in range(self.iterations):
        sample = data[self.rng.choice(n_samples, 1), ]  # shape (1, n_obs)

        # Merged context: a convex blend of the previous winner's prototype
        # and context, weighted by ``b`` (the weights must sum to one).
        ct = (1 - self.b) * wr + self.b * cr

        # Geometric decay from the initial to the final value over training.
        t = iteration / float(self.iterations)
        lrate = self.lrate_i * (self.lrate_f / float(self.lrate_i)) ** t
        # The neighbourhood range decays between its own endpoints.
        # NOTE(review): assumes the estimator stores ``epsilon_f`` alongside
        # ``epsilon_i``, as it does ``lrate_f`` alongside ``lrate_i``.
        epsilon = self.epsilon_i * (self.epsilon_f / float(self.epsilon_i)) ** t

        # Blended distance of every prototype: ``a`` weights the context term.
        # NOTE(review): pairwise_distances defaults to plain (not squared)
        # Euclidean distance; confirm this matches the intended MNG metric.
        dist = ((1 - self.a) * pairwise_distances(sample, self.protos)
                + self.a * pairwise_distances(ct, self.context)).ravel()

        # Rank of each prototype (0 = closest); a permutation of 0..n_protos-1.
        ranks = np.argsort(np.argsort(dist))
        # Index of the rank-0 prototype.  The rank array is flattened first:
        # on the original (1, n_protos) array, np.where(...)[0] returned the
        # row index, which is always 0.
        winner = int(np.argmin(ranks))

        # Neighbourhood factor per prototype, as a column for broadcasting.
        h = np.exp(-ranks / epsilon)[:, np.newaxis]

        # Move every prototype/context toward the sample/merged context.
        # Each right-hand side is fully evaluated before the in-place add.
        self.protos += lrate * h * (sample - self.protos)
        self.context += lrate * h * (ct - self.context)

        # Fancy indexing copies, keeping shape (1, n_obs) for the next step.
        wr = self.protos[[winner]]
        cr = self.context[[winner]]

    return self
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号