gater_seq.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:MixtureOfExperts 作者: krishnakalyan3 项目源码 文件源码
def train_model(self, X_train, y_train, X_test, y_test, X_val, y_val):
        """Alternate between fitting the experts and fitting the gater.

        Runs ``self.iters`` rounds. In each round ``i`` the method:

        1. calls ``self.bucket_function(i)`` to obtain a mapping whose keys
           are expert column indices and whose values select the training
           rows assigned to that expert;
        2. fits one model per bucket with ``self.svc_model`` and writes that
           model's predictions on the *full* train/test/validation sets into
           the column given by the bucket key;
        3. builds a new gater with ``self.gater()`` and fits it on
           ``[X_train, experts_out_train]`` against ``y_train``, validating
           on ``[X_val, experts_out_val]``;
        4. scores the mixture on all three splits with ``self.moe_eval``,
           prints the accuracies and logs ``100 - accuracy`` for each split;
        5. stores the output of the gater layer named ``'layer_op_2'`` on the
           training inputs in ``self.wm_xi``.

        Args:
            X_train: Training features; indexed with each bucket's selection.
            y_train: Training targets; indexed the same way as ``X_train``.
            X_test: Test features.
            y_test: Test targets.
            X_val: Validation features.
            y_val: Validation targets.

        Returns:
            None. Results are reported through ``print``, ``logging.info``,
            the TensorBoard log directory ``self.tf_log + str(i)`` and the
            ``self.wm_xi`` attribute.
        """
        for i in range(self.iters):
            split_buckets = self.bucket_function(i)

            # One column of predictions per expert, one row per sample.
            # NOTE(review): np.empty leaves memory uninitialised, so this
            # assumes the bucket keys cover every column 0..self.experts-1;
            # any column without a bucket would hold garbage -- TODO confirm.
            # Presumably self.*_dim are the shapes of the matching datasets.
            experts_out_train = np.empty((self.train_dim[0], self.experts), dtype='float64')
            experts_out_test = np.empty((self.test_dim[0], self.experts), dtype='float64')
            experts_out_val = np.empty((self.val_dim[0], self.experts), dtype='float64')

            # ``j`` is the expert's position in sorted-key order (used for
            # reporting and passed to svc_model); ``expert_index`` is the
            # bucket key and selects the output column.
            for j, expert_index in enumerate(sorted(split_buckets)):
                print("############################# Expert {} Iter {} ################################".format(j, i))
                bucket = split_buckets[expert_index]
                X = X_train[bucket]
                y = y_train[bucket]
                model = self.svc_model(X, y, X_test, y_test, X_val, y_val, i, j)

                # Every expert predicts on the complete splits, not only on
                # the rows it was trained on.
                experts_out_train[:, expert_index] = model.predict(X_train)
                experts_out_test[:, expert_index] = model.predict(X_test)
                experts_out_val[:, expert_index] = model.predict(X_val)

            gater_model = self.gater()
            early_callback = CustomCallback()
            tb_callback = TensorBoard(log_dir=self.tf_log + str(i))
            # The epoch count is only an upper bound handed to fit();
            # presumably CustomCallback ends training earlier -- verify.
            gater_model.fit(
                [X_train, experts_out_train], y_train,
                shuffle=True,
                batch_size=self.batch_size,
                verbose=1,
                validation_data=([X_val, experts_out_val], y_val),
                epochs=1000,
                callbacks=[tb_callback, early_callback],
            )

            train_accuracy = self.moe_eval(gater_model, X_train, y_train, experts_out_train)
            test_accuracy = self.moe_eval(gater_model, X_test, y_test, experts_out_test)
            val_accuracy = self.moe_eval(gater_model, X_val, y_val, experts_out_val)

            print('Train Accuracy', train_accuracy)
            print('Test Accuracy', test_accuracy)
            print('Val Accuracy', val_accuracy)

            # Errors as the complement of the accuracies to 100.
            tre = 100 - train_accuracy
            tte = 100 - test_accuracy
            vale = 100 - val_accuracy

            # Sub-model that exposes the 'layer_op_2' activations of the
            # fitted gater for the same inputs.
            expert_units = Model(inputs=gater_model.input,
                                 outputs=gater_model.get_layer('layer_op_2').output)

            self.wm_xi = expert_units.predict([X_train, experts_out_train])

            # Logged order: iteration, train error, validation error, test error.
            logging.info('{}, {}, {}, {}'.format(i, tre, vale, tte))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号