import torch
import torch.nn as nn


def calc_loss_dqn(batch, net, tgt_net, gamma, cuda=False):
    states, actions, rewards, dones, next_states = unpack_batch(batch)
    device = torch.device("cuda" if cuda else "cpu")

    states_v = torch.from_numpy(states).to(device)
    next_states_v = torch.from_numpy(next_states).to(device)
    actions_v = torch.from_numpy(actions).to(device)
    rewards_v = torch.from_numpy(rewards).to(device)
    done_mask = torch.tensor(dones, dtype=torch.bool, device=device)

    # Q(s, a) of the online network for the actions actually taken.
    state_action_values = net(states_v).gather(
        1, actions_v.unsqueeze(-1)).squeeze(-1)

    # max_a' Q(s', a') from the target network, computed without gradients;
    # torch.no_grad() replaces the deprecated Variable/volatile idiom.
    with torch.no_grad():
        next_state_values = tgt_net(next_states_v).max(1)[0]
        # Terminal transitions contribute no future value.
        next_state_values[done_mask] = 0.0

    # Bellman target: r + gamma * max_a' Q(s', a').
    expected_state_action_values = next_state_values * gamma + rewards_v
    return nn.MSELoss()(state_action_values, expected_state_action_values)
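A minimal usage sketch follows. The real unpack_batch and batch format come from the surrounding code and are not shown here, so the stub below, the toy network sizes, and the transition tuples are illustrative assumptions only:

import numpy as np
import torch
import torch.nn as nn

def unpack_batch(batch):
    # Hypothetical stand-in: split (state, action, reward, done, next_state)
    # tuples into the NumPy arrays calc_loss_dqn expects.
    states, actions, rewards, dones, next_states = zip(*batch)
    return (np.asarray(states, dtype=np.float32),
            np.asarray(actions, dtype=np.int64),
            np.asarray(rewards, dtype=np.float32),
            np.asarray(dones, dtype=np.bool_),
            np.asarray(next_states, dtype=np.float32))

obs_size, n_actions = 4, 2
net = nn.Sequential(nn.Linear(obs_size, 64), nn.ReLU(), nn.Linear(64, n_actions))
tgt_net = nn.Sequential(nn.Linear(obs_size, 64), nn.ReLU(), nn.Linear(64, n_actions))
tgt_net.load_state_dict(net.state_dict())  # start with the target in sync

# Fake transitions standing in for samples from a replay buffer.
batch = [(np.random.rand(obs_size), np.random.randint(n_actions),
          1.0, False, np.random.rand(obs_size)) for _ in range(32)]
loss = calc_loss_dqn(batch, net, tgt_net, gamma=0.99)
loss.backward()  # gradients flow only into net, never into tgt_net

Because the target-network pass runs under torch.no_grad(), only the online network's parameters receive gradients; the target network is updated separately, typically by periodically copying the online weights.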