value.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:KerasRL 作者: aejax 项目源码 文件源码
def __init__(self, S, A, maxlen=1000, mode=None, embedding_dim=1, **kwargs):
        """Set up the storage that backs the tabular Q-function.

        Args:
            S: State space. Passed to ``get_space_dim`` in 'array' mode.
            A: Action space. Passed to ``get_space_dim`` in 'array' mode; its
                ``n`` attribute is read in every other mode.
            maxlen: Number of rows pre-allocated for the buffers in 'tables'
                mode. Overwritten with the state dimension in 'array' mode.
            mode: One of 'array', 'dictionary', 'tables' or 'action_tables'.
                When None, the mode is inferred from the types of ``S`` and
                ``A`` (see below).
            embedding_dim: Number of columns of the state buffer in 'tables'
                mode.
            **kwargs: Forwarded unchanged to the parent initializer.

        Raises:
            NotImplementedError: If ``mode`` is not one of the handled values,
                or is None and cannot be inferred from the spaces.
        """
        super(TableQ2, self).__init__(**kwargs)
        self.S = S
        self.A = A

        if mode is None:
            # Infer the layout: discrete states and actions -> dense array,
            # discrete actions only -> dictionary. Anything else leaves mode
            # as None, which is rejected at the end of this method.
            if isinstance(A, gym.spaces.Discrete):
                if isinstance(S, gym.spaces.Discrete):
                    mode = 'array'
                else:
                    mode = 'dictionary'
        # Assign once, after inference, so an inferred mode is not clobbered
        # by the None default.
        self.mode = mode
        self.maxlen = maxlen
        self.embedding_dim = embedding_dim

        if self.mode == 'array':
            s_dim = get_space_dim(S)
            a_dim = get_space_dim(A)
            self.table = np.zeros((s_dim, a_dim))
            # The dense table has exactly one row per state.
            self.maxlen = s_dim
        elif self.mode == 'dictionary':
            # Seeded with a single zero-valued entry under key 0.
            self.table = {0: np.zeros(self.A.n)}
        elif self.mode == 'tables':
            self.k = 4
            self.neigh = KNeighborsRegressor(n_neighbors=self.k)
            self.states = np.zeros((self.maxlen, self.embedding_dim))
            self.values = np.zeros((self.maxlen, self.A.n))
            self.recency = np.zeros((self.maxlen,))
            self.i = 0
        elif self.mode == 'action_tables':
            self.k = 4
            # One [states, values, regressor, recency] record per action.
            # The lists start empty and the regressors are not fitted here.
            self.action_tables = [
                [[], [], KNeighborsRegressor(n_neighbors=self.k), []]
                for _ in range(self.A.n)
            ]
        else:
            # NOTE(review): four modes are handled above although the message
            # says three; the text is kept unchanged for callers that match it.
            raise NotImplementedError('Sorry, TableQ only supports three modes.')
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号