def fit(self, state, action, q, **fit_params):
    """
    Fit the model.

    The samples are split by action: regressor ``self.model[i]`` is fitted
    only on the samples whose action equals ``i``. A regressor that receives
    no samples is not fitted at all. Samples whose action is not in
    ``range(len(self.model))`` are not used by any regressor.

    Args:
        state (np.ndarray): states, one sample per row (after
            ``self._preprocess`` the state is indexed as ``state[idxs, :]``);
        action (np.ndarray): actions, either as a column array of shape
            ``(n, 1)`` (only the first column is read) or as a flat array
            of shape ``(n,)``;
        q (np.ndarray): target q-values;
        **fit_params (dict): other parameters used by the fit method
            of each regressor.

    """
    state, q = self._preprocess(state, q)

    # The action id lives in the first column of a 2-D action array; a flat
    # 1-D array is accepted as-is. Extracted once, outside the loop.
    action = np.asarray(action)
    actions = action[:, 0] if action.ndim > 1 else action

    # `range`/`enumerate` instead of the Python-2-only `xrange`.
    for i, regressor in enumerate(self.model):
        idxs = np.flatnonzero(actions == i)
        if idxs.size:
            regressor.fit(state[idxs, :], q[idxs], **fit_params)
评论列表
文章目录