exp3.py source code

python

Project: striatum · Author: ntucllab
    def get_action(self, context=None, n_actions=None):
        """Return the action(s) to perform.

        Parameters
        ----------
        context : {array-like, None}
            The context of the current state, or None if no context is
            available.

        n_actions : int, optional (default: None)
            Number of actions to recommend. If None, return a single
            action; if -1, return all actions.

        Returns
        -------
        history_id : int
            The history id of this recommendation.

        recommendations : recommendation object or list of them
            A single recommendation if n_actions is None, otherwise a list.
            Each recommendation holds the Action object, estimated_reward,
            uncertainty and score.
        """
        if self._action_storage.count() == 0:
            return self._get_action_with_empty_action_storage(context,
                                                              n_actions)

        probs = self._exp3_probs()
        if n_actions == -1:
            n_actions = self._action_storage.count()

        action_ids = list(six.viewkeys(probs))
        prob_array = np.asarray([probs[action_id]
                                 for action_id in action_ids])
        recommendation_ids = self.random_state.choice(
            action_ids, size=n_actions, p=prob_array, replace=False)

        if n_actions is None:
            recommendations = self._recommendation_cls(
                action=self._action_storage.get(recommendation_ids),
                estimated_reward=probs[recommendation_ids],
                uncertainty=probs[recommendation_ids],
                score=probs[recommendation_ids],
            )
        else:
            recommendations = []  # pylint: disable=redefined-variable-type
            for action_id in recommendation_ids:
                recommendations.append(self._recommendation_cls(
                    action=self._action_storage.get(action_id),
                    estimated_reward=probs[action_id],
                    uncertainty=probs[action_id],
                    score=probs[action_id],
                ))

        history_id = self._history_storage.add_history(context, recommendations)
        return history_id, recommendations
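The method relies on `self._exp3_probs()`, which this excerpt does not show. As a rough, self-contained sketch of what the selection step does, the code below assumes the textbook EXP3 mixing rule p_i = (1 - gamma) * w_i / sum_j w_j + gamma / K, with action weights kept in a plain dict. The helper names `exp3_probs` and `choose_actions` are hypothetical and not part of striatum, whose internals may differ. The `None` / `-1` / integer handling of `n_actions` mirrors `get_action` above.

```python
import numpy as np


def exp3_probs(weights, gamma):
    """Hypothetical helper: textbook EXP3 distribution over action ids.

    Mixes the normalized weights with a uniform distribution, so every
    action keeps probability at least gamma / K (assumes 0 < gamma <= 1).
    """
    action_ids = list(weights)
    w = np.array([weights[a] for a in action_ids], dtype=float)
    k = len(action_ids)
    p = (1.0 - gamma) * w / w.sum() + gamma / k
    return dict(zip(action_ids, p))


def choose_actions(probs, random_state, n_actions=None):
    """Hypothetical helper: sample action ids the way get_action does.

    n_actions=None -> one id; n_actions=-1 -> every id; otherwise a list of
    n_actions distinct ids, drawn without replacement according to probs.
    """
    action_ids = list(probs)
    if n_actions == -1:
        n_actions = len(action_ids)
    p = np.array([probs[a] for a in action_ids])
    chosen = random_state.choice(action_ids, size=n_actions, p=p,
                                 replace=False)
    if n_actions is None:
        return chosen.item()  # numpy scalar -> plain Python id
    return [a.item() for a in chosen]


if __name__ == "__main__":
    rng = np.random.RandomState(0)
    probs = exp3_probs({1: 1.0, 2: 2.0, 3: 1.0}, gamma=0.3)
    print(probs)                         # action 2 gets the most mass
    print(choose_actions(probs, rng))     # a single action id
    print(choose_actions(probs, rng, 2))  # two distinct ids
    print(choose_actions(probs, rng, -1))  # all ids, in sampled order
```

Because `replace=False`, each id appears at most once, which is why `get_action` can return a list of distinct recommendations. The `gamma / K` term keeps every probability positive, so the "no more non-zero entries than `size`" requirement of sampling without replacement is always met.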