def calculate_loss(model, target_model, transitions, configuration):
    """Compute the (Double) DQN temporal-difference loss for a batch of transitions.

    Parameters
    ----------
    model : online Q-network; maps a batch of states to per-action Q-values.
    target_model : target Q-network used to evaluate successor-state values.
    transitions : iterable of Transition(state, action, reward, successor_state,
        non_terminal) tuples; each field is assumed to be a tensor with a
        leading batch dimension suitable for torch.cat — TODO confirm against
        the replay-memory producer.
    configuration : object exposing VALUE_TENSOR_TYPE, BATCH_SIZE,
        DISCOUNT_FACTOR and DOUBLE_DQN.

    Returns
    -------
    A scalar Variable holding the smooth-L1 (Huber) loss, or the plain int 0
    when every transition in the batch is terminal (callers must check for
    this sentinel before calling .backward(), as in the original code).
    """
    value_tensor_type = configuration.VALUE_TENSOR_TYPE

    # Transpose the batch: (S,A,R,S',T)^n -> (S^n, A^n, R^n, S'^n, T^n).
    # http://stackoverflow.com/a/19343/3343043
    batch = Transition(*zip(*transitions))

    states = Variable(torch.cat(batch.state))
    action_indices = Variable(torch.cat(batch.action))
    rewards = Variable(torch.cat(batch.reward))
    non_terminals = Variable(torch.cat(batch.non_terminal))

    # Successor states only contribute a bootstrap term when the transition
    # is non-terminal, so filter the terminal ones out before batching.
    non_terminal_successor_states = [
        state
        for (state, non_terminal) in zip(batch.successor_state,
                                         non_terminals.data)
        if non_terminal
    ]
    if not non_terminal_successor_states:
        # Entire batch is terminal: there is no successor term to learn from.
        # Preserve the original sentinel return value.
        return 0
    non_terminal_successor_states = Variable(
        torch.cat(non_terminal_successor_states))

    # Q(s, a) for the actions actually taken in the batch.
    Q_states = model(states).gather(1, action_indices)

    # Successor values for the bootstrap target. FIX: the original always ran
    # model(non_terminal_successor_states) and then, when DOUBLE_DQN was set,
    # discarded that result and ran target_model instead — a wasted forward
    # pass. FIX: the original DOUBLE_DQN branch took the plain max of the
    # target network, which is standard target-network DQN; Double DQN
    # (van Hasselt et al., 2016) selects the greedy action with the ONLINE
    # network and evaluates it with the TARGET network.
    if configuration.DOUBLE_DQN:
        greedy_actions = (model(non_terminal_successor_states)
                          .detach().max(1)[1].unsqueeze(1))
        successor_values = (target_model(non_terminal_successor_states)
                            .detach().gather(1, greedy_actions).squeeze(1))
    else:
        successor_values = (model(non_terminal_successor_states)
                            .detach().max(1)[0])

    # V(s') is zero for terminal transitions and the detached bootstrap
    # value elsewhere (masked assignment via the non_terminal byte mask).
    V_successors = Variable(
        torch.zeros(configuration.BATCH_SIZE).type(value_tensor_type))
    V_successors[non_terminals] = successor_values

    # TD target: r + gamma * V(s'). Huber (smooth L1) loss is robust to
    # outlier TD errors, per the original DQN recipe.
    Q_expected = rewards + (configuration.DISCOUNT_FACTOR * V_successors)
    return F.smooth_l1_loss(Q_states, Q_expected)