def __init__(self, incoming, num_units, peepholes=True,
             backwards=False, mask_input=None, only_return_final=True,
             encoder_input=None, encoder_mask_input=None, **kwargs):
    super(MatchLSTM, self).__init__(incoming, num_units, peepholes=peepholes,
                                    backwards=backwards,
                                    precompute_input=False, mask_input=mask_input,
                                    only_return_final=only_return_final, **kwargs)
    # indices of the optional encoder input / encoder mask in self.input_layers
    self.encoder_input_incoming_index = -1
    self.encoder_mask_incoming_index = -1
    if encoder_mask_input is not None:
        self.input_layers.append(encoder_mask_input)
        self.input_shapes.append(encoder_mask_input.output_shape)
        self.encoder_mask_incoming_index = len(self.input_layers) - 1
    if encoder_input is not None:
        self.input_layers.append(encoder_input)
        encoder_input_output_shape = encoder_input.output_shape
        self.input_shapes.append(encoder_input_output_shape)
        self.encoder_input_incoming_index = len(self.input_layers) - 1
        # the encoder feature size must equal the hidden state size (num_units)
        assert encoder_input_output_shape[-1] == num_units
        # the incoming feature size must equal the encoder feature size plus the hidden state size
        assert encoder_input_output_shape[-1] + num_units == self.input_shapes[0][-1]
    # initialize the attention weight matrices and the scoring vector
    self.W_y_attend = self.add_param(init.Normal(0.1), (num_units, num_units), 'W_y_attend')
    self.W_h_attend = self.add_param(init.Normal(0.1), (num_units, num_units), 'W_h_attend')
    # scoring vector; stored as a column vector (num_units, 1), so no transpose is needed
    self.w_attend = self.add_param(init.Normal(0.1), (num_units, 1), 'w_attend')
    self.W_m_attend = self.add_param(init.Normal(0.1), (num_units, num_units), 'W_m_attend')
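As a sanity check on the shape constraints enforced by the two asserts above, here is a minimal construction sketch. It assumes a standard Lasagne environment (`lasagne.layers.InputLayer`, with `from lasagne import init` at module level) and uses hypothetical batch and sequence-length constants; `MatchLSTM` is the class whose `__init__` is shown above, so the rest of that class must be defined for this to run.

```python
from lasagne.layers import InputLayer

NUM_UNITS = 300                    # hidden size; must equal the encoder feature size
BATCH, P_LEN, Q_LEN = 32, 50, 20   # hypothetical batch size and sequence lengths

# Encoder side: one NUM_UNITS-dimensional feature vector per time step,
# plus a (batch, seq_len) mask for variable-length sequences.
encoder_l = InputLayer((BATCH, P_LEN, NUM_UNITS))
encoder_mask_l = InputLayer((BATCH, P_LEN))

# Decoder side: each time step concatenates a NUM_UNITS-dimensional input
# with a NUM_UNITS-dimensional attended encoder state, giving 2 * NUM_UNITS
# features, which satisfies the second assert in __init__.
decoder_in_l = InputLayer((BATCH, Q_LEN, 2 * NUM_UNITS))
decoder_mask_l = InputLayer((BATCH, Q_LEN))

match_l = MatchLSTM(decoder_in_l, NUM_UNITS,
                    mask_input=decoder_mask_l,
                    encoder_input=encoder_l,
                    encoder_mask_input=encoder_mask_l)
```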