def get_output(self, train=False):
    """Build the symbolic attention-weighted output of this layer.

    The layer input is projected with ``self.att_proj`` (contracting the
    input's axis 3) and passed through ``self.activation``.  Attention
    scores are then computed from that projection in one of three ways,
    selected by ``self.context``:

    * ``'word'``   -- contract the projection's axis 3 with axis 0 of
      ``self.att_scorer``.
    * ``'clause'`` -- scan a tanh recurrence over axis 2 of the
      projection and score each hidden state with ``self.att_scorer``.
    * ``'para'``   -- contract the projection's axis 3 with axis 2 of
      ``self.att_scorer`` and sum the result over axes 1 and 2.

    Finally the scores are softmax-normalised and used to take weighted
    combinations of the original input.

    NOTE(review): the input is indexed as a rank-4 tensor; it presumably
    has a layout like (samples, td1, td2, features) -- confirm against
    the layer's constructor.

    Args:
        train: Forwarded unchanged to ``self.get_input``.

    Returns:
        The Theano expression produced by scanning the attention
        combination over axis 0 of the input.

    Raises:
        ValueError: If ``self.context`` is not one of ``'word'``,
            ``'clause'`` or ``'para'``.
    """
    context = self.context
    if context not in ('word', 'clause', 'para'):
        # Previously an unrecognised context left ``att_scores`` unbound and
        # only surfaced later as an UnboundLocalError with no hint of the
        # cause; fail early with an explicit message instead.
        raise ValueError(
            "Unknown attention context %r; expected 'word', 'clause' or 'para'"
            % (context,))

    # Named ``layer_input`` rather than ``input`` to avoid shadowing the builtin.
    layer_input = self.get_input(train)
    proj_input = self.activation(
        T.tensordot(layer_input, self.att_proj, axes=(3, 0)))

    if context == 'word':
        att_scores = T.tensordot(proj_input, self.att_scorer, axes=(3, 0))
    elif context == 'clause':
        def step(a_t, h_tm1, W_in, W, sc):
            # One recurrence step: new hidden state from the current slice
            # and the previous hidden state, plus the score of that state.
            h_t = T.tanh(T.tensordot(a_t, W_in, axes=(2, 0))
                         + T.tensordot(h_tm1, W, axes=(2, 0)))
            s_t = T.tensordot(h_t, sc, axes=(2, 0))
            return h_t, s_t

        # scan iterates over the leading axis, so move axis 2 to the front.
        initial_hidden = T.zeros(
            (proj_input.shape[0], self.td1, self.rec_hid_dim))
        [_, scores], _ = theano.scan(
            step,
            sequences=[proj_input.dimshuffle(2, 0, 1, 3)],
            outputs_info=[initial_hidden, None],
            non_sequences=[self.rec_in_weights, self.rec_hid_weights,
                           self.att_scorer])
        # Move the scanned axis from the front to the last position.
        att_scores = scores.dimshuffle(1, 2, 0)
    else:  # context == 'para'
        att_scores = T.tensordot(
            proj_input, self.att_scorer, axes=(3, 2)).sum(axis=(1, 2))

    # Nested scans (outer over axis 0, inner over the next axis) -- not
    # pretty, but it keeps the per-slice dot products simple.
    def get_sample_att(sample_input, sample_att):
        def weighted_sum(s_att_i, s_input_i):
            return T.dot(s_att_i, s_input_i)

        sample_att_inp, _ = theano.scan(
            fn=weighted_sum,
            sequences=[T.nnet.softmax(sample_att), sample_input])
        return sample_att_inp

    att_input, _ = theano.scan(
        fn=get_sample_att, sequences=[layer_input, att_scores])
    return att_input
评论列表
文章目录