uniform_seq.py source code

python

Project: MixtureOfExperts  Author: krishnakalyan3
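from sklearn.metrics import accuracy_score  # module-level import used by the error computation

def train_model(self, x_train, y_train, x_test, y_test, x_val, y_val):
        """Train one SVC per bucket, average the experts' predictions, and
        return the error rates in percent as (train, val, test).

        Note that the return order differs from the argument order.
        """
        # Each bucket indexes the training rows assigned to one expert.
        split_buckets = self.get_random()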

        # Running sums of the experts' predictions on each data split.
        y_hat_train = 0
        y_hat_test = 0
        y_hat_val = 0
        for key in sorted(split_buckets):
            X = x_train[split_buckets[key]]
            y = y_train[split_buckets[key]]
            model = self.svc_model(X, y)
            y_hat_train += model.predict(x_train)
            y_hat_test += model.predict(x_test)
            y_hat_val += model.predict(x_val)

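        # Average the summed predictions. Plain division replaces the in-place
        # `*= (1 / self.experts)`, which truncates to 0 under Python 2 integer
        # division and fails in NumPy when the summed array has an integer dtype.
        y_hat_train = y_hat_train / float(self.experts)
        y_hat_test = y_hat_test / float(self.experts)
        y_hat_val = y_hat_val / float(self.experts)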

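        # Assuming 0/1 labels, a sample is predicted positive when more than
        # half of the experts vote 1; the error is reported in percent.
        train_error = (1 - accuracy_score(y_train, y_hat_train > 0.5)) * 100
        test_error = (1 - accuracy_score(y_test, y_hat_test > 0.5)) * 100
        val_error = (1 - accuracy_score(y_val, y_hat_val > 0.5)) * 100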

        return train_error, val_error, test_error
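
The aggregation step in train_model (sum the experts' predictions, divide by the number of experts, threshold at 0.5, report the error in percent) can be tried in isolation. The following is a minimal sketch in plain Python, assuming binary 0/1 labels; the helper names average_votes and error_percent and the sample votes are hypothetical and not part of the MixtureOfExperts project.

```python
def average_votes(expert_predictions):
    """Average per-expert 0/1 predictions sample by sample."""
    n_experts = len(expert_predictions)
    n_samples = len(expert_predictions[0])
    return [
        sum(preds[i] for preds in expert_predictions) / n_experts
        for i in range(n_samples)
    ]


def error_percent(y_true, y_score, threshold=0.5):
    """Percentage of samples whose thresholded score differs from the label."""
    y_pred = [1 if score > threshold else 0 for score in y_score]
    wrong = sum(1 for t, p in zip(y_true, y_pred) if t != p)
    return 100.0 * wrong / len(y_true)


# Three hypothetical experts voting on four samples (illustrative data only).
votes = [
    [1, 0, 1, 0],
    [1, 0, 0, 0],
    [0, 1, 1, 0],
]
y_true = [1, 0, 1, 1]

scores = average_votes(votes)
print(error_percent(y_true, scores))  # prints 25.0
```

In this sketch, an even number of experts that split evenly produces a score of exactly 0.5, which the strict `> 0.5` comparison maps to class 0.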