def score_batch(self, e1, e2, r_index):
    """Build the symbolic score for a batch of (e1, relation, e2) triples.

    The shapes below are carried over from the original annotations; nothing
    in this method checks them.

    :param e1: first entity embeddings, (batch, entity_dim)
    :param e2: second entity embeddings, (batch, entity_dim)
    :param r_index: relation index of each example, (batch, )
    :return: one score per example, (batch, )
    """
    # Relation-specific bilinear weights selected per example.
    relation_weights = self.W[r_index]

    # One scan step per element of self.slice_seq; each step is annotated as
    # (batch, entity_dim) dot (batch, entity_dim, entity_dim, hidden)
    # dot (batch, entity_dim), giving hidden * (batch, ) overall.
    per_slice, _ = theano.scan(
        fn=self.step_batch,
        sequences=[self.slice_seq],
        non_sequences=[e1, e2, relation_weights],
        name='batch_scan',
    )

    # hidden * (batch, ) -> (batch, hidden)
    # NOTE(review): concatenating a one-element list looks like a no-op ahead
    # of the transpose -- confirm before simplifying.
    bilinear_term = T.concatenate([per_slice], axis=1).transpose()

    if not self.keep_normal:
        # Bilinear-only variant: no linear layer and no bias.
        pre_activation = bilinear_term
    else:
        # (batch, 2 * entity_dim) dot (batch, 2 * entity_dim, hidden)
        # -> (batch, hidden)
        entity_pair = T.concatenate([e1, e2], axis=1)
        linear_term = T.batched_dot(entity_pair, self.V[r_index])
        # (batch, hidden) + (batch, hidden) + (batch, hidden) -> (batch, hidden)
        pre_activation = bilinear_term + linear_term + self.b[r_index]

    # Element-wise non-linearity; the shape stays (batch, hidden).
    activated = self.act.activate(pre_activation)

    # Per-example dot product with the relation's output vector:
    # (batch, hidden) * (batch, hidden), summed over hidden -> (batch, )
    return T.sum(activated * self.U[r_index], axis=1)
评论列表
文章目录