def _compile(self, ddqn):
    """Build the Q-learning loss graph and compile the Theano functions.

    Sets two attributes:

    * ``self._learn`` -- evaluates the summed loss and applies the optimizer
      updates to the trainable parameters of ``self.network``.
    * ``self._evaluate`` -- returns the output of ``self.network``.

    When ``self.misc_state_included`` is true, both functions additionally
    take the ``"S0_misc"`` (and, for learning, ``"S1_misc"``) inputs.

    Args:
        ddqn: if truthy, the bootstrap action is chosen by the argmax of
            ``self.network`` and its value is read from
            ``self.frozen_network``; otherwise the maximum output of
            ``self.frozen_network`` is used directly.
    """
    a = self.inputs["A"]
    r = self.inputs["R"]
    nonterminal = self.inputs["Nonterminal"]

    q = ls.get_output(self.network, deterministic=True)

    if ddqn:
        # NOTE(review): alternate_input_mappings presumably rebinds the input
        # layers of self.network to the next-state variables -- confirm where
        # the mapping is built.
        q2 = ls.get_output(self.network, deterministic=True,
                           inputs=self.alternate_input_mappings)
        q2_action_ref = tensor.argmax(q2, axis=1)
        q2_frozen = ls.get_output(self.frozen_network, deterministic=True)
        # Pick, per row, the frozen-network value of the action selected above.
        q2_max = q2_frozen[tensor.arange(q2_action_ref.shape[0]), q2_action_ref]
    else:
        q2_max = tensor.max(ls.get_output(self.frozen_network, deterministic=True), axis=1)

    # The nonterminal factor zeroes the bootstrap term wherever it is 0.
    target_q = r + self.gamma * nonterminal * q2_max
    # Per row, the predicted value of the action given in `a`.
    predicted_q = q[tensor.arange(q.shape[0]), a]
    loss = self.build_loss_expression(predicted_q, target_q).sum()

    params = ls.get_all_params(self.network, trainable=True)
    # Alternative optimizer kept for reference:
    # lasagne.updates.rmsprop(loss, params, self._learning_rate, rho=0.95)
    updates = deepmind_rmsprop(loss, params, self.learning_rate)

    # TODO does FAST_RUN speed anything up?
    mode = None  # "FAST_RUN"

    s0_img = self.inputs["S0"]
    s1_img = self.inputs["S1"]
    # Only the input lists differ between the two configurations, so choose
    # them here and compile once below.
    if self.misc_state_included:
        s0_misc = self.inputs["S0_misc"]
        s1_misc = self.inputs["S1_misc"]
        learn_inputs = [s0_img, s0_misc, s1_img, s1_misc, a, r, nonterminal]
        eval_inputs = [s0_img, s0_misc]
    else:
        learn_inputs = [s0_img, s1_img, a, r, nonterminal]
        eval_inputs = [s0_img]

    # Parenthesised single-argument print emits the same text on Python 2
    # and is valid syntax on Python 3.
    print("Compiling the training function...")
    self._learn = theano.function(learn_inputs, loss, updates=updates, mode=mode,
                                  name="learn_fn")
    print("Compiling the evaluation function...")
    self._evaluate = theano.function(eval_inputs, q, mode=mode, name="eval_fn")
    print("Network compiled.")
评论列表
文章目录