def initial_states(self, batch_size, *args, **kwargs):
    """Return the initial state selected by ``self.init_strategy``.

    Supported strategies:

    * ``'constant'`` -- repeat ``self.parameters[2]`` ``batch_size`` times
      along a new leading axis. ``attended`` is not consulted.
    * ``'last'`` -- feed ``attended[0, :, -self.attended_dim:]`` through
      ``self.initial_transformer``.
    * ``'average'`` -- feed the mean over axis 0 of
      ``attended[:, :, -self.attended_dim:]`` through
      ``self.initial_transformer``.

    Parameters
    ----------
    batch_size
        Number of repetitions used by the ``'constant'`` strategy.
    **kwargs
        Must contain ``attended`` for the ``'last'`` and ``'average'``
        strategies. It is indexed with three subscripts, so it is
        presumably a rank-3 (time, batch, features) tensor -- TODO confirm
        the axis meaning against the caller.

    Returns
    -------
    The initial state. NOTE(review): ``'constant'`` yields a one-element
    list while the other strategies yield whatever
    ``initial_transformer.apply`` returns; this asymmetry is kept as-is
    because callers may rely on it.

    Raises
    ------
    ValueError
        If ``self.init_strategy`` is not one of the supported values.
    KeyError
        If ``attended`` is required by the strategy but was not passed.
    """
    strategy = self.init_strategy

    if strategy == 'constant':
        # presumably parameters[2] is the learned initial-state vector;
        # verify against the brick's parameter allocation.
        return [tensor.repeat(self.parameters[2][None, :], batch_size, 0)]

    if strategy not in ('last', 'average'):
        # Previously this case only logged and then fell through to
        # ``return initial_state``, which failed with an unrelated
        # UnboundLocalError. Log as before, then fail explicitly.
        logging.critical("dec_init parameter %s invalid", strategy)
        raise ValueError("dec_init parameter %s invalid" % (strategy,))

    # Only the annotation-based strategies need the attended tensor.
    attended = kwargs['attended']
    if strategy == 'last':
        summary = attended[0, :, -self.attended_dim:]
    else:  # 'average'
        summary = attended[:, :, -self.attended_dim:].mean(0)
    return self.initial_transformer.apply(summary)
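# Stray web-page navigation text (not part of the code): "Comment list"
# Stray web-page navigation text (not part of the code): "Table of contents"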