def __init__(
    self,
    network,
    target_network,
    lr=0.01,
    learn_start=1000,
    batch_size=32,
    map_dim=10,
    gamma=0.95,
    replay_size=10000,
    instr_len=7,
    layout_channels=1,
    object_channels=1,
):
    """Initialize the learner: networks, hyperparameters, loss and optimizer.

    Args:
        network: Model whose ``parameters()`` are handed to the RMSprop
            optimizer.
        target_network: Second model stored alongside ``network``. It is not
            given to the optimizer. ``self._copy_net()`` is invoked right
            after both models are stored (presumably to sync the target from
            the online network -- TODO confirm).
        lr: Learning rate for RMSprop. Not stored on the instance.
        learn_start: Stored as ``self.learn_start`` (presumably the number of
            steps before updates begin -- confirm with callers).
        batch_size: Stored as ``self.batch_size``.
        map_dim: Passed as both arguments of ``self._refresh_size``; this
            method does not store it directly.
        gamma: Stored as ``self.gamma`` (presumably the discount factor).
        replay_size: Stored as ``self.replay_size`` (presumably the replay
            memory capacity).
        instr_len: Stored as ``self.instr_len`` (presumably the instruction
            length -- confirm units with callers).
        layout_channels: Stored as ``self.layout_channels``.
        object_channels: Stored as ``self.object_channels``.
    """
    # Store both models before calling _copy_net(); it presumably reads them.
    self.network, self.target_network = network, target_network
    self._copy_net()

    # Plain hyperparameters, assigned in the same order as before.
    self.learn_start, self.batch_size, self.gamma, self.replay_size = (
        learn_start,
        batch_size,
        gamma,
        replay_size,
    )
    self.instr_len, self.layout_channels, self.object_channels = (
        instr_len,
        layout_channels,
        object_channels,
    )

    # The same value is used for both size arguments (presumably height and
    # width, i.e. a square map -- TODO confirm against _refresh_size).
    self._refresh_size(map_dim, map_dim)

    # The loss function itself (not its result) is kept as the criterion.
    self.criterion = F.smooth_l1_loss

    # Only the online network is optimized; target_network is excluded.
    trainable_params = self.network.parameters()
    self.optimizer = optim.RMSprop(trainable_params, lr=lr)
评论列表
文章目录