def __call__(self):
    """Run one attentional decoder step and produce vocabulary scores.

    Reads ``self.S``, ``self.h``, ``self.attention``, ``self.context_project``,
    ``self.affine_vocab`` and ``self.use_lexicon``. When ``self.use_lexicon``
    is true, it also reads ``self.lexicon_model`` and ``self.lexicon_matrix``.

    Side effect: stores the projected [h; context] vector in ``self.ht``.
    This is the value before ``tanh``.

    Returns:
        ``nmtrain.models.decoders.Output`` with ``y`` set to the vocabulary
        scores. When ``nmtrain.environment.is_train()`` is false, ``a`` is
        also set to the attention weights.
    """
    # NOTE(review): the effect of level=1 is defined in
    # nmtrain.optimization, which is not visible here.
    optimize_memory = nmtrain.optimization.chainer_mem_optimize

    # Attention weights over the source states, conditioned on the decoder state.
    attn_weights = self.attention(self.S, self.h)

    # Context vector: batched S^T @ attn_weights. The product has a size-1
    # trailing axis (axis 2), which is squeezed away.
    # Assumes S is (batch, src_len, hidden) -- TODO confirm against encoder.
    weighted_sum = F.batch_matmul(self.S, attn_weights, transa=True)
    context = F.squeeze(weighted_sum, axis=2)

    # Concatenate decoder state and context along axis 1, then project.
    joined = F.concat((self.h, context), axis=1)
    self.ht = self.context_project(joined)

    # Vocabulary scores come from tanh(ht); self.ht itself stays un-squashed.
    activated = F.tanh(self.ht)
    scores = optimize_memory(self.affine_vocab, activated, level=1)

    if self.use_lexicon:
        # The lexicon model rescales the scores using the attention weights
        # and the pre-tanh hidden state.
        scores = self.lexicon_model(scores, attn_weights, self.ht,
                                    self.lexicon_matrix)

    if not nmtrain.environment.is_train():
        # Outside training, the attention weights are returned as well.
        # Presumably callers use them downstream; verify.
        return nmtrain.models.decoders.Output(y=scores, a=attn_weights)
    return nmtrain.models.decoders.Output(y=scores)
评论列表
文章目录