import chainer
import chainer.functions as F
import numpy as np


def compute_policy_gradient_full_correction(
        action_distrib, action_distrib_mu, action_value, v,
        truncation_threshold):
    """Compute off-policy bias correction term wrt all actions."""
    assert truncation_threshold is not None
    assert np.isscalar(v)
    # Treat the weights and advantages as constants: gradients must flow
    # only through all_log_prob below.
    with chainer.no_backprop_mode():
        # Inverse importance weights 1/rho(a) = mu(a|x) / pi(a|x) for all
        # actions, as the name rho_all_inv suggests (compute_full_importance
        # is defined alongside this function).
        rho_all_inv = compute_full_importance(action_distrib_mu,
                                              action_distrib)
        # [1 - c / rho(a)]_+ * pi(a|x): nonzero only for actions whose
        # importance weight rho(a) exceeds the truncation threshold c.
        correction_weight = (
            np.maximum(1 - truncation_threshold * rho_all_inv,
                       np.zeros_like(rho_all_inv)) *
            action_distrib.all_prob.data[0])
        # Advantage Q(x, a) - V(x) for every action.
        correction_advantage = action_value.q_values.data[0] - v
    # Negated so that minimizing this loss ascends the correction term.
    return -F.sum(correction_weight *
                  action_distrib.all_log_prob *
                  correction_advantage, axis=1)