def discount(rewards, gamma):
    """Compute discounted returns ``G_t = r_t + gamma * G_{t+1}``.

    Walks ``rewards`` from the last element to the first, accumulating the
    running discounted sum, and returns one return per input reward in the
    original (forward) order.

    Args:
        rewards: Either a Python ``list`` of rewards, or any non-list object
            supporting ``.split(1)`` (presumably a torch tensor; it is split
            into size-1 chunks along its first dimension).
        gamma: Discount factor applied once per step back in time.

    Returns:
        For a non-list input, the per-step returns concatenated with
        ``th.cat`` and flattened via ``.view(-1)``. For a list input, the
        per-step returns passed to ``T`` (a helper defined elsewhere in the
        file -- presumably a tensor constructor; confirm).
    """
    is_tensor = not isinstance(rewards, list)
    if is_tensor:
        # Size-1 chunks keep each step as a tensor so th.cat can reassemble
        # them after discounting.
        rewards = rewards.split(1)
    R = 0.0
    discounted = []
    for r in reversed(rewards):
        R = r + gamma * R
        discounted.append(R)
    # Returns were accumulated back-to-front. One reverse is O(n), whereas
    # insert(0, ...) on every step made the loop O(n^2).
    discounted.reverse()
    if is_tensor:
        return th.cat(discounted).view(-1)
    return T(discounted)
评论列表
文章目录