# Assumes the usual module-level aliases for this code base:
# import theano; import theano.tensor as T
# import lasagne; import lasagne.layers as L
def _init_model(self, in_size, out_size, n_hid=10, learning_rate_sl=0.005,
                learning_rate_rl=0.005, batch_size=32, ment=0.1):
    # Policy network: a GRU recurrent layer followed by a softmax output layer.
    self.in_size = in_size        # input features, e.g. x and y coordinates
    self.out_size = out_size      # actions: up, down, right, left
    self.batch_size = batch_size
    self.learning_rate = learning_rate_rl
    self.n_hid = n_hid
    # Symbolic inputs. Shape letters: B = batch size, H = max episode length
    # in turns, A = number of actions, D = hidden size (n_hid).
    input_var = T.ftensor3('in')    # B x H x in_size observations
    turn_mask = T.imatrix('tm')     # B x H, 1 for real turns, 0 for padding
    act_mask = T.itensor3('am')     # B x H x A one-hot taken actions
    reward_var = T.fvector('r')     # one return per episode, length B
    l_mask_in = L.InputLayer(shape=(None, None), input_var=turn_mask)
    pol_in = T.fmatrix('pol-h')     # initial GRU hidden state, B x D
    l_in = L.InputLayer(shape=(None, None, self.in_size), input_var=input_var)
    l_pol_rnn = L.GRULayer(l_in, n_hid, hid_init=pol_in, mask_input=l_mask_in)  # B x H x D
    pol_out = L.get_output(l_pol_rnn)[:, -1, :]  # final hidden state, B x D
    l_den_in = L.ReshapeLayer(l_pol_rnn, (turn_mask.shape[0] * turn_mask.shape[1], n_hid))  # BH x D
    l_out = L.DenseLayer(l_den_in, self.out_size, nonlinearity=lasagne.nonlinearities.softmax)
    self.network = l_out
    self.params = L.get_all_params(self.network)
    # RL objective: REINFORCE with an entropy bonus weighted by `ment`.
    probs = L.get_output(self.network)  # BH x A
    out_probs = T.reshape(probs, (input_var.shape[0], input_var.shape[1], self.out_size))  # B x H x A
    log_probs = T.log(out_probs)
    act_probs = (log_probs * act_mask).sum(axis=2)   # log-prob of taken action, B x H
    ep_probs = (act_probs * turn_mask).sum(axis=1)   # episode log-likelihood, B
    # Policy entropy summed over turns (note: padded turns are not masked out here).
    H_probs = -T.sum(T.sum(out_probs * log_probs, axis=2), axis=1)  # B
    self.loss = -T.mean(ep_probs * reward_var + ment * H_probs)
    updates = lasagne.updates.rmsprop(self.loss, self.params,
                                      learning_rate=learning_rate_rl, epsilon=1e-4)
    self.inps = [input_var, turn_mask, act_mask, reward_var, pol_in]
    self.train_fn = theano.function(self.inps, self.loss, updates=updates)  # one RL update step
    self.obj_fn = theano.function(self.inps, self.loss)                     # evaluate RL loss only
    # Action probabilities plus the final GRU state, for acting turn by turn.
    self.act_fn = theano.function([input_var, turn_mask, pol_in], [out_probs, pol_out])
    # SL objective: maximum likelihood of the demonstrated actions (imitation),
    # i.e. the negative mean episode log-likelihood, with no reward weighting.
    sl_loss = -T.mean(ep_probs)
    sl_updates = lasagne.updates.rmsprop(sl_loss, self.params,
                                         learning_rate=learning_rate_sl, epsilon=1e-4)
    self.sl_train_fn = theano.function([input_var, turn_mask, act_mask, pol_in],
                                       sl_loss, updates=sl_updates)
    self.sl_obj_fn = theano.function([input_var, turn_mask, act_mask, pol_in], sl_loss)
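All of the compiled Theano functions take padded, batched episode tensors. Below is a minimal usage sketch; the `agent` object, its configuration (in_size=2, out_size=4, matching the coordinate/action comments above), and the dummy data are illustrative assumptions, not part of the original code.

import numpy as np

B, H = 32, 5                                      # hypothetical: 32 episodes, padded to 5 turns
obs = np.zeros((B, H, 2), dtype='float32')        # (x, y) observation per turn
turns = np.ones((B, H), dtype='int32')            # 1 = real turn, 0 = padding
acts = np.zeros((B, H, 4), dtype='int32')         # one-hot taken actions
acts[:, :, 0] = 1                                 # e.g. "up" at every turn
rewards = np.random.randn(B).astype('float32')    # one return per episode
h0 = np.zeros((B, agent.n_hid), dtype='float32')  # initial GRU state

loss = agent.train_fn(obs, turns, acts, rewards, h0)  # one REINFORCE update
probs, h_next = agent.act_fn(obs, turns, h0)          # B x H x A probs + final GRU state
sl = agent.sl_train_fn(obs, turns, acts, h0)          # one supervised update

Carrying h_next back in as the next call's pol_in is what lets the policy act turn by turn while keeping its recurrent state across calls.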