def forward(self, x, *args, **kwargs):
    """Sample a discrete action and attach probability helpers to it.

    Calls the parent class's ``forward`` and treats ``action.raw`` as
    unnormalized logits over actions. Because sampled indices are read with
    ``[:, 0]``, ``action.raw`` must be 2-D. It is presumably
    ``(batch, num_actions)`` — TODO confirm against the parent class.

    Attributes set on the returned ``action``:
        value: ``(batch, 1)`` tensor of sampled action indices. It is detached
            so no gradient flows through the sampling step.
        prob: zero-argument callable returning the ``(batch,)`` probabilities
            of the sampled actions, one per row.
        compute_log_prob: callable ``a -> (batch,)`` returning, for each
            row ``i``, the log-probability of action ``a[i, 0]``.
        log_prob: ``compute_log_prob(value)``.
        entropy: ``(batch,)`` entropy of each row's action distribution.

    Returns:
        The ``action`` object produced by the parent ``forward``, augmented
        with the attributes above.
    """
    action = super(DiscretePolicy, self).forward(x, *args, **kwargs)
    # Explicit dim: the implicit-dim softmax is deprecated and warns. For 2-D
    # input the legacy implicit choice was also the last axis.
    probs = F.softmax(action.raw, dim=-1)
    # Computed once and shared by every compute_log_prob call.
    log_probs = F.log_softmax(action.raw, dim=-1)
    # num_samples is mandatory in current PyTorch; the legacy default was 1.
    action.value = probs.multinomial(1).detach()

    def _pick(table, a):
        # Select table[i, a[i, 0]] for every row i -> shape (batch,).
        # The previous ``table.t()[a[:, 0]].mean(1)`` built a (batch, batch)
        # matrix and averaged across the batch, so it was only correct for
        # a batch of size 1.
        return table.gather(-1, a[:, :1].long()).squeeze(-1)

    action.prob = lambda: _pick(probs, action.value)
    action.compute_log_prob = lambda a: _pick(log_probs, a)
    action.log_prob = action.compute_log_prob(action.value)
    # True per-row entropy H = -sum_k p_k * log p_k. The previous version
    # returned only the sampled-action term -p(a) * log p(a).
    action.entropy = -(probs * log_probs).sum(-1)
    return action
评论列表
文章目录