import theano
import theano.tensor as T

def __theano_build_train__(self):
    params = self.params
    params_names = self.param_names
    hidden_dim = self.hidden_dim
    batch_size = self.batch_size
    # inputs is a 3 x seq_len x batch_size int tensor:
    #   inputs[0]: first sentence (encoder input)
    #   inputs[1]: second sentence (decoder input)
    #   inputs[2]: target sequence for the decoder (labels for the loss)
    # masks holds the matching 0/1 padding masks.
    inputs = T.itensor3("inputs")
    masks = T.ftensor3("masks")
    def rnn_cell(x, mx, ph, Wh):
        # One recurrence step: new hidden state from the projected input x
        # and the previous hidden state ph.
        h = T.tanh(ph.dot(Wh) + x)
        # Where the mask is 0 (padding), carry the previous state forward.
        h = mx[:, None] * h + (1 - mx[:, None]) * ph
        return [h]  # shape: (batch_size, hidden_dim)
    # Encoding the first sentence: look up embeddings, project, then scan.
    _state = params["E"][inputs[0].flatten(), :].reshape(
        [inputs[0].shape[0], inputs[0].shape[1], hidden_dim])
    _state = _state.dot(params["W"][0]) + params["B"][0]
    [h1], updates = theano.scan(
        fn=rnn_cell,
        sequences=[_state, masks[0]],
        truncate_gradient=self.truncate,
        outputs_info=[dict(initial=T.zeros([batch_size, hidden_dim]))],
        non_sequences=[params["W"][1]])
    # Decoding the second sentence, starting from the encoder's final state h1[-1].
    _state = params["E"][inputs[1].flatten(), :].reshape(
        [inputs[1].shape[0], inputs[1].shape[1], hidden_dim])
    _state = _state.dot(params["W"][2]) + params["B"][1]
    [h2], updates = theano.scan(
        fn=rnn_cell,
        sequences=[_state, masks[1]],
        truncate_gradient=self.truncate,
        outputs_info=[dict(initial=h1[-1])],
        non_sequences=[params["W"][3]])
    # Loss: project hidden states to vocabulary scores, flatten time and batch,
    # and sum the cross-entropy over the non-padding positions.
    _s = h2.dot(params["DecodeW"]) + params["DecodeB"]
    _s = _s.reshape([_s.shape[0] * _s.shape[1], _s.shape[2]])
    _s = T.nnet.softmax(_s)
    _cost = T.nnet.categorical_crossentropy(_s, inputs[2].flatten())
    _cost = T.sum(_cost * masks[2].flatten())
    # SGD hyper-parameters
    learning_rate = T.scalar("learning_rate")
    decay = T.scalar("decay")
    _grads, _updates = rms_prop(_cost, params_names, params, learning_rate, decay)
    # Compile the Theano functions.
    self.bptt = theano.function([inputs, masks], _grads)
    self.loss = theano.function([inputs, masks], _cost)
    self.weights = theano.function([inputs, masks], _s)
    self.sgd_step = theano.function(
        [inputs, masks, learning_rate, decay],  # theano.In(decay, value=0.9)],
        updates=_updates)
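
The rms_prop helper is called above but not defined in this section. Below is a minimal sketch of an RMSProp rule matching that call signature; the dict-of-shared-variables layout of params, the accumulator naming, and the 1e-6 epsilon are assumptions for illustration, not the project's actual implementation:

import numpy as np

def rms_prop(cost, params_names, params, learning_rate, decay, eps=1e-6):
    # Symbolic gradients of the cost w.r.t. every parameter.
    grads = [T.grad(cost, params[name]) for name in params_names]
    updates = []
    for name, grad in zip(params_names, grads):
        param = params[name]
        # Running average of squared gradients, one accumulator per parameter
        # (hypothetical layout: params maps names to Theano shared variables).
        acc = theano.shared(
            np.zeros(param.get_value().shape, dtype=theano.config.floatX),
            name=name + "_acc")
        new_acc = decay * acc + (1.0 - decay) * grad ** 2
        updates.append((acc, new_acc))
        updates.append((param,
                        param - learning_rate * grad / T.sqrt(new_acc + eps)))
    return grads, updates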
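
For context, here is a hypothetical smoke test of the compiled functions. The variable model, vocab_size, and the chosen shapes are made up for illustration; inputs stacks the two sentences and the target along axis 0, each as a seq_len x batch_size matrix of token ids:

import numpy as np

seq_len, vocab_size = 5, 100           # illustrative sizes only
batch_size = model.batch_size          # model: an instance of the class above
inputs = np.random.randint(0, vocab_size,
                           size=(3, seq_len, batch_size)).astype("int32")
masks = np.ones((3, seq_len, batch_size), dtype="float32")  # no padding here

print("loss before:", model.loss(inputs, masks))
model.sgd_step(inputs, masks, 0.001, 0.9)   # learning_rate=0.001, decay=0.9
print("loss after: ", model.loss(inputs, masks))

With all-ones masks every position contributes to the loss; real batches would zero out the mask entries past each sentence's end so padded steps neither move the hidden state nor add to the cost.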