def forward(self, x, seq):
    """Compose the rows of ``x`` into one vector by scanning ``self.encode`` over ``seq``.

    :param x: (length, dim)
    :param seq: (length - 1, 3) -- one row per scan step. Presumably these
        are index triples into the padded workspace built below; TODO
        confirm against ``self.encode``.
    :return: tuple ``(comp_vec, comp_rec)``.
        When ``length > 1``:
            ``comp_vec`` is the last row of the workspace after the final
            scan step.
            ``comp_rec`` is the sum of the second scan output over all
            steps; presumably a reconstruction cost -- confirm.
        When ``length == 1``:
            ``comp_vec`` is ``x[0]``, L2-normalized if ``self.normalize``.
            ``comp_rec`` is ``shared_zero_scalar()``.
    """
    is_multi_row = x.shape[0] > 1

    # Append (length - 1) zero rows: (length, dim) -> (2 * length - 1, dim).
    # Presumably one extra slot per scan step for ``self.encode`` to fill.
    zero_rows = T.zeros_like(x)[:-1, :]
    workspace = T.concatenate([x, zero_rows], axis=0)

    # One step per row of ``seq`` (length - 1 steps).
    # The carried state is the workspace plus a scalar starting at 0.
    scan_outputs, _updates = theano.scan(
        fn=self.encode,
        sequences=seq,
        outputs_info=[workspace, shared_scalar(0)],
        name="compose_phrase",
    )
    workspace_history = scan_outputs[0]
    rec_history = scan_outputs[1]
    last_workspace = workspace_history[-1]
    scanned_vec = last_workspace[-1]
    scanned_rec = T.sum(rec_history)

    first_row = x[0]
    single_vec = first_row / first_row.norm(2) if self.normalize else first_row

    # ``ifelse`` evaluates only the selected branch. So for a one-row input
    # (where ``seq`` has zero rows per the docstring), the scan-derived
    # values are never computed.
    comp_vec = ifelse(is_multi_row, scanned_vec, single_vec)
    comp_rec = ifelse(is_multi_row, scanned_rec, shared_zero_scalar())
    return comp_vec, comp_rec
评论列表
文章目录