def update_target_network(source_network, target_network, update_rate):
    """Build an op that soft-updates target_network's variables toward source_network's.

    For every pair of corresponding variables the op performs
    ``target = (1 - update_rate) * target + update_rate * source``,
    implemented in place as ``target -= update_rate * (target - source)``
    (the two forms are algebraically identical).

    Args:
        source_network: object whose ``variables()`` returns the source
            variables, in the same order as ``target_network.variables()``.
        target_network: object whose ``variables()`` returns the target
            variables; each must support ``assign_sub``.
        update_rate: interpolation factor (often written alpha/tau).
            Presumably a float in [0, 1] -- it is not validated here.

    Returns:
        A ``tf.group`` op that runs all the per-variable updates.

    Raises:
        ValueError: if the two networks expose different numbers of variables.
    """
    source_vars = list(source_network.variables())
    target_vars = list(target_network.variables())
    # zip() would silently stop at the shorter list, leaving some target
    # variables never updated; fail loudly instead.
    if len(source_vars) != len(target_vars):
        raise ValueError(
            "source and target networks have different numbers of "
            "variables: %d vs %d" % (len(source_vars), len(target_vars))
        )
    # Pairing by position assumes both networks create their variables in
    # the same order -- TODO confirm for every network passed in here.
    update_ops = [
        v_target.assign_sub(update_rate * (v_target - v_source))
        for v_source, v_target in zip(source_vars, target_vars)
    ]
    return tf.group(*update_ops)
# def concat_nn_input(self, input1, input2):
# return tf.concat(1, [input1, input2])
# def add_pow_values(self, values):
# return self.concat_nn_input(values, 0.01 * tf.pow(values, [2 for i in range(self.action_size)]))
# 评论列表 (web-page residue: "comment list")
# 文章目录 (web-page residue: "article table of contents")