def optimize_model():
    """Run one optimization step of ``model`` on a sampled replay batch.

    Samples ``BATCH_SIZE`` transitions from ``memory``, computes the Huber
    (smooth L1) loss between Q(s_t, a_t) and
    ``reward + GAMMA * max_a Q(s_{t+1}, a)`` (the next-state value stays 0
    for transitions whose ``next_state`` is ``None``), backpropagates,
    clamps each parameter gradient to [-1, 1] and calls ``optimizer.step()``.

    Returns:
        None. Returns early, without touching the model, while ``memory``
        holds fewer than ``BATCH_SIZE`` transitions.
    """
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    # Transpose the batch (see http://stackoverflow.com/a/19343/3343043 for
    # detailed explanation).
    batch = Transition(*zip(*transitions))

    # Mask of non-final states, plus the matching list of next states.
    non_final_mask = ByteTensor(
        tuple(s is not None for s in batch.next_state))
    non_final_next = [s for s in batch.next_state if s is not None]

    state_batch = Variable(torch.cat(batch.state))
    action_batch = Variable(torch.cat(batch.action))
    reward_batch = Variable(torch.cat(batch.reward))

    # Compute Q(s_t, a) - the model computes Q(s_t), then we select the
    # columns of actions taken.
    state_action_values = model(state_batch).gather(1, action_batch)

    # Compute V(s_{t+1}) for all next states; entries for final states
    # keep their initial value of 0.
    next_state_values = Variable(torch.zeros(BATCH_SIZE).type(Tensor))
    # Skip the forward pass when every sampled transition is final:
    # torch.cat() raises on an empty list.
    if non_final_next:
        # We don't want to backprop through the expected action values;
        # volatile=True saves us from temporarily changing the model
        # parameters' requires_grad to False.
        non_final_next_states = Variable(torch.cat(non_final_next),
                                         volatile=True)
        next_state_values[non_final_mask] = \
            model(non_final_next_states).max(1)[0]
    # Now, we don't want to mess up the loss with a volatile flag, so let's
    # clear it. After this, we'll just end up with a Variable that has
    # requires_grad=False.
    next_state_values.volatile = False

    # Compute the expected Q values.
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    # Compute Huber loss. The gather(1, ...) result keeps a column
    # dimension while the target built from torch.zeros(BATCH_SIZE) is 1-D;
    # add the column so both operands have the same shape instead of
    # relying on implicit shape handling inside the loss.
    # NOTE(review): assumes reward_batch is 1-D of length BATCH_SIZE and
    # action_batch is (BATCH_SIZE, 1) - confirm against the code that
    # pushes transitions into memory.
    loss = F.smooth_l1_loss(state_action_values,
                            expected_state_action_values.unsqueeze(1))

    # Optimize the model.
    optimizer.zero_grad()
    loss.backward()
    for param in model.parameters():
        # Parameters that did not take part in the loss have no gradient.
        if param.grad is not None:
            param.grad.data.clamp_(-1, 1)
    optimizer.step()
######################################################################
#
# Below, you can find the main training loop. At the beginning we reset
# the environment and initialize the ``state`` variable. Then, we sample
# an action, execute it, observe the next screen and the reward (always
# 1), and optimize our model once. When the episode ends (our model
# fails), we restart the loop.
#
# Below, `num_episodes` is set small. You should download
# the notebook and run a lot more episodes.
评论列表
文章目录