def __init__(self, incoming, num_units, ingate=Gate(), forgetgate=Gate(),
             cell=Gate(W_cell=None, nonlinearity=nonlinearities.tanh),
             outgate=Gate(), nonlinearity=nonlinearities.tanh,
             cell_init=init.Constant(0.), hid_init=init.Constant(0.),
             backwards=False, learn_init=False, peepholes=True,
             gradient_steps=-1, grad_clipping=0, precompute_input=True,
             mask_input=None, encoder_mask_input=None, attention=False,
             word_by_word=False, **kwargs):
    """Build an LSTM decoder seeded by an encoder, optionally with attention.

    All parameters not listed below are forwarded unchanged to the parent
    LSTM layer constructor.  For this layer ``unroll_scan`` is always
    ``False`` and ``only_return_final`` is always ``True``.

    Parameters
    ----------
    cell_init : CustomLSTMEncoder
        Must be a ``CustomLSTMEncoder`` whose ``num_units`` equals this
        layer's ``num_units``; anything else (including the default
        constant) raises ``ValueError``.
    encoder_mask_input : lasagne.layers.Layer or None
        Optional mask layer for the encoder side.  When given it is appended
        to ``self.input_layers`` / ``self.input_shapes`` and its position is
        stored in ``self.encoder_mask_incoming_index`` (``-1`` when absent).
        Required whenever attention is enabled.
    attention : bool
        If True, create the attention weight parameters.
    word_by_word : bool
        If True, create two extra attention weights (``W_r_attend``,
        ``W_t_attend``).  Forces ``attention`` to True.

    Raises
    ------
    ValueError
        If ``cell_init`` is not a ``CustomLSTMEncoder`` with matching
        ``num_units``, or if attention is enabled without an encoder mask
        layer.
    """
    # Pass the trailing parent options by keyword: the original positional
    # bare False/True literals were easy to misalign.
    super(CustomLSTMDecoder, self).__init__(
        incoming, num_units, ingate, forgetgate, cell, outgate, nonlinearity,
        cell_init, hid_init, backwards, learn_init, peepholes,
        gradient_steps, grad_clipping,
        unroll_scan=False,
        precompute_input=precompute_input,
        mask_input=mask_input,
        only_return_final=True,
        **kwargs)
    self.attention = attention
    self.word_by_word = word_by_word

    # Register the optional encoder mask as one more input layer;
    # an index of -1 means "no encoder mask supplied".
    self.encoder_mask_incoming_index = -1
    if encoder_mask_input is not None:
        self.input_layers.append(encoder_mask_input)
        self.input_shapes.append(encoder_mask_input.output_shape)
        self.encoder_mask_incoming_index = len(self.input_layers) - 1

    # The decoder's initial cell state must come from an encoder of the
    # same width.
    if (not isinstance(self.cell_init, CustomLSTMEncoder)
            or self.num_units != self.cell_init.num_units):
        raise ValueError('cell_init must be CustomLSTMEncoder'
                         ' and num_units should equal')

    # Fixed all-zero (1, num_units) parameter, excluded from training and
    # regularization.  NOTE(review): presumably the initial attention
    # readout used by get_output_for -- confirm there.
    self.r_init = self.add_param(init.Constant(0.), (1, num_units),
                                 name="r_init", trainable=False,
                                 regularizable=False)

    # word_by_word implies attention.
    if self.word_by_word:
        self.attention = True

    if self.attention:
        if not isinstance(encoder_mask_input, lasagne.layers.Layer):
            raise ValueError('Attention mechanism needs encoder mask layer')
        # Attention weights, all drawn from init.Normal(0.1).
        # NOTE(review): the registered names 'V_pointer' / 'v_pointer' do not
        # match the attributes W_y_attend / w_attend; left unchanged in case
        # anything keys on the registered parameter names.
        self.W_y_attend = self.add_param(
            init.Normal(0.1), (num_units, num_units), 'V_pointer')
        self.W_h_attend = self.add_param(
            init.Normal(0.1), (num_units, num_units), 'W_h_attend')
        # Stored as a (num_units, 1) column, so it needs no transpose.
        self.w_attend = self.add_param(
            init.Normal(0.1), (num_units, 1), 'v_pointer')
        self.W_p_attend = self.add_param(
            init.Normal(0.1), (num_units, num_units), 'W_p_attend')
        self.W_x_attend = self.add_param(
            init.Normal(0.1), (num_units, num_units), 'W_x_attend')
        if self.word_by_word:
            self.W_r_attend = self.add_param(
                init.Normal(0.1), (num_units, num_units), 'W_r_attend')
            self.W_t_attend = self.add_param(
                init.Normal(0.1), (num_units, num_units), 'W_t_attend')
# (Residue from the source web page's navigation: "Comment list", "Table of contents".)