def __init__(self, S, A, maxlen=1000, mode=None, embedding_dim=1, **kwargs):
    """Set up the Q-value storage for the selected mode.

    Args:
        S: State space. When ``mode`` is None it is tested against
            ``gym.spaces.Discrete`` to pick a storage mode.
        A: Action space. Every mode except 'array' reads ``A.n``.
        maxlen: Number of rows preallocated in 'tables' mode. In 'array'
            mode it is replaced by ``get_space_dim(S)``.
        mode: 'array', 'dictionary', 'tables' or 'action_tables'. When
            None, the mode is inferred: 'array' if both spaces are
            Discrete, 'dictionary' if only ``A`` is Discrete.
        embedding_dim: Width of each stored state row in 'tables' mode.
        **kwargs: Forwarded unchanged to the parent initializer.

    Raises:
        NotImplementedError: If ``mode`` is not one of the four supported
            values, or it was None and could not be inferred.
    """
    super(TableQ2, self).__init__(**kwargs)
    self.S = S
    self.A = A

    # Resolve the mode exactly once: an explicit argument always wins,
    # inference only runs when the caller passed nothing.
    if mode is None:
        discrete_actions = isinstance(A, gym.spaces.Discrete)
        if discrete_actions and isinstance(S, gym.spaces.Discrete):
            mode = 'array'
        elif discrete_actions:
            mode = 'dictionary'
        # Otherwise mode stays None and is rejected by the dispatch below.
    self.mode = mode

    self.maxlen = maxlen
    self.embedding_dim = embedding_dim

    if self.mode == 'array':
        s_dim = get_space_dim(S)
        a_dim = get_space_dim(A)
        self.table = np.zeros((s_dim, a_dim))
        # The dense table has one row per state, so that is the capacity.
        self.maxlen = s_dim
    elif self.mode == 'dictionary':
        self.table = {0: np.zeros(self.A.n)}
    elif self.mode == 'tables':
        self.k = 4
        self.neigh = KNeighborsRegressor(n_neighbors=self.k)
        self.states = np.zeros((self.maxlen, self.embedding_dim))
        self.values = np.zeros((self.maxlen, self.A.n))
        self.recency = np.zeros((self.maxlen,))
        self.i = 0
    elif self.mode == 'action_tables':
        self.k = 4
        # One [states, values, regressor, recency] record per action.
        # The lists start empty and each regressor is created unfitted.
        self.action_tables = [
            [[], [], KNeighborsRegressor(n_neighbors=self.k), []]
            for _ in range(self.A.n)
        ]
    else:
        raise NotImplementedError(
            "Sorry, TableQ only supports four modes: 'array', 'dictionary', "
            "'tables' and 'action_tables' (got %r)." % (self.mode,)
        )
# (Web-page residue, not part of the source: "Comment list" / "Table of contents".)