def build_functions(self):
    """Build the symbolic value function and the training function.

    Creates input placeholders for a batch of transitions, calls
    ``self.build_model()``, and stores two compiled backend functions:

    * ``self.value_fn``: maps ``[S]`` to the model output for ``S``.
    * ``self.train_fn``: maps ``[S, NS, A, R, T]`` to the mean squared
      difference between the model output for the selected action and
      the bootstrapped target ``R + discount * (1 - T) * max(model(NS))``,
      applying RMSprop updates to the model's trainable weights.
    """
    # Placeholders: S / NS share self.state_size; A, R and T are one value
    # per sample. Presumably state, next state, action index, reward and
    # terminal flag -- confirm against the caller of train_fn.
    S = Input(shape=self.state_size)
    NS = Input(shape=self.state_size)
    A = Input(shape=(1,), dtype='int32')
    R = Input(shape=(1,), dtype='float32')
    T = Input(shape=(1,), dtype='int32')

    self.build_model()

    # Model output for the current states, shared by both functions.
    VS = self.model(S)
    self.value_fn = K.function([S], VS)

    # Next-state values feed the target only; no gradient flows through them.
    VNS = disconnected_grad(self.model(NS))
    # (1 - T) zeroes the future value wherever T == 1.
    future_value = (1 - T) * VNS.max(axis=1, keepdims=True)
    discounted_future_value = self.discount * future_value
    target = R + discounted_future_value

    # Pick each row's own action value. Indexing with ``VS[:, A]`` is
    # advanced indexing along axis 1 and yields a (batch, batch, 1) tensor
    # pairing every state with every sample's action; a one-hot mask keeps
    # the selection per row and preserves the (batch, 1) shape of ``target``.
    action_mask = K.one_hot(K.flatten(A), VS.shape[1])
    chosen_value = K.sum(VS * action_mask, axis=1, keepdims=True)
    cost = ((chosen_value - target) ** 2).mean()

    opt = RMSprop(0.0001)
    params = self.model.trainable_weights
    updates = opt.get_updates(params, [], cost)
    self.train_fn = K.function([S, NS, A, R, T], cost, updates=updates)
评论列表
文章目录