optimisation.py source code

Project: agent · Author: sintefneodroid
from collections import namedtuple

import torch
import torch.nn.functional as F
from torch.autograd import Variable

# Field layout inferred from the attribute accesses below: (S, A, R, S', T).
Transition = namedtuple('Transition',
                        ('state', 'action', 'reward', 'successor_state',
                         'non_terminal'))


def calculate_loss(model, target_model, transitions, configuration):
  ValueTensorType = configuration.VALUE_TENSOR_TYPE

  # Inverse of zip: transpose the batch, http://stackoverflow.com/a/19343/3343043
  # The * operator unpacks a collection into positional arguments, turning
  # (S,A,R,S',T)^n into (S^n, A^n, R^n, S'^n, T^n); a toy example follows
  # the function.
  batch = Transition(*zip(*transitions))

  states = Variable(torch.cat(batch.state))
  action_indices = Variable(torch.cat(batch.action))
  rewards = Variable(torch.cat(batch.reward))
  non_terminals = Variable(torch.cat(batch.non_terminal))

  # Keep only the successor states of non-terminal transitions.
  non_terminal_successor_states = [
      state for (state, non_terminal)
      in zip(batch.successor_state, non_terminals.data) if non_terminal]
  if len(non_terminal_successor_states) == 0:
    return 0  # every transition in the batch is terminal; nothing to bootstrap
  non_terminal_successor_states = Variable(
      torch.cat(non_terminal_successor_states))

  # Q(s, a) for the actions that were actually taken.
  Q_states = model(states).gather(1, action_indices)
  Q_successors = model(non_terminal_successor_states)

  if configuration.DOUBLE_DQN:
    # Evaluate successor states with the target network instead.
    Q_successors = target_model(non_terminal_successor_states)

  # V(s') = max_a Q(s', a) for non-terminal successors, 0 for terminal ones.
  V_successors = Variable(
      torch.zeros(configuration.BATCH_SIZE).type(ValueTensorType))
  V_successors[non_terminals] = Q_successors.detach().max(1)[0]

  # TD target: r + gamma * V(s').
  Q_expected = rewards + (configuration.DISCOUNT_FACTOR * V_successors)

  return F.smooth_l1_loss(Q_states, Q_expected)
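
The transpose in Transition(*zip(*transitions)) is easiest to see on a toy batch; the placeholder strings here are purely illustrative:

# Toy illustration of the zip transpose used above.
pair = (Transition('s0', 'a0', 'r0', "s0'", 't0'),
        Transition('s1', 'a1', 'r1', "s1'", 't1'))
batch = Transition(*zip(*pair))  # (S,A,R,S',T)^2 -> (S^2, A^2, R^2, S'^2, T^2)
assert batch.state == ('s0', 's1')
assert batch.reward == ('r0', 'r1')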
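
For context, a minimal driver sketch. The _Config class, the linear networks, and the tensor shapes are hypothetical stand-ins, not taken from the agent project; the only contract the function imposes is that each Transition field concatenates along dimension 0, with actions as n x 1 LongTensors (for gather) and non-terminal flags as ByteTensors (for masking).

import torch
import torch.nn as nn


class _Config:  # hypothetical configuration object, not from the project
  VALUE_TENSOR_TYPE = torch.FloatTensor
  BATCH_SIZE = 2
  DISCOUNT_FACTOR = 0.99
  DOUBLE_DQN = False


state_size, num_actions = 4, 3
model = nn.Linear(state_size, num_actions)         # stand-in online network
target_model = nn.Linear(state_size, num_actions)  # stand-in target network

# Each field is sized so that torch.cat stacks the batch along dim 0.
transitions = [
    Transition(state=torch.randn(1, state_size),
               action=torch.LongTensor([[0]]),
               reward=torch.FloatTensor([1.0]),
               successor_state=torch.randn(1, state_size),
               non_terminal=torch.ByteTensor([1]))
    for _ in range(_Config.BATCH_SIZE)]

loss = calculate_loss(model, target_model, transitions, _Config)

Note that the snippet targets the pre-0.4 Variable API; on recent PyTorch versions the (n, 1)-versus-(n,) shapes passed to smooth_l1_loss trigger a broadcasting warning.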