def q_value(q_dist, num_atoms, num_actions, V_max, delta_z):
    """Reduce per-action atom weights to one expected value per action.

    The second axis of ``q_dist`` is read as ``num_actions`` consecutive
    groups of ``num_atoms`` columns. Each group is multiplied element-wise
    by the support ``z_i = -V_max + i * delta_z`` for
    ``i = 0 .. num_atoms - 1`` and summed over its atoms.

    Args:
        q_dist: 2-D tensor (or tensor-like) indexed as ``[:, column]``;
            presumably shaped ``(batch, num_actions * num_atoms)`` and
            holding probabilities -- confirm against the caller.
        num_atoms: Number of support points per action.
        num_actions: Number of column groups to reduce.
        V_max: Magnitude of the lowest support value; the support starts
            at ``-V_max``.
        delta_z: Spacing between neighbouring support values.

    Returns:
        A tensor with one column per action (second axis has length
        ``num_actions``) containing the weighted sums.
    """
    q_dist = tf.convert_to_tensor(q_dist)
    V_min = -V_max

    # Build the support from integer indices rather than a float-stepped
    # tf.range(V_min, V_max + delta_z, delta_z): with a float step the
    # element count depends on rounding of (end - start) / delta and may not
    # equal num_atoms, which breaks the element-wise product below.
    atom_index = tf.cast(tf.range(num_atoms), q_dist.dtype)
    z = atom_index * delta_z + V_min

    # Reduce without keeping the summed axis and stack the resulting vectors
    # as columns. This is equivalent to concatenating kept-dimension columns
    # but avoids the `keep_dims` keyword, which newer TensorFlow no longer
    # accepts.
    q_as = [
        tf.reduce_sum(
            tf.multiply(
                q_dist[:, num_atoms * action:num_atoms * (action + 1)], z
            ),
            axis=1,
        )
        for action in range(num_actions)
    ]
    return tf.stack(q_as, axis=1)
build_graph.py 文件源码
python
阅读 26
收藏 0
点赞 0
评论 0
评论列表
文章目录