custom_layers.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:MachineComprehension 作者: sa-j 项目源码 文件源码
def __init__(self, incoming, num_units, peepholes=True,
             backwards=False, mask_input=None, only_return_final=True,
             encoder_input=None, encoder_mask_input=None, **kwargs):
    """Match-LSTM layer: an LSTM whose step attends over encoder outputs.

    Parameters
    ----------
    incoming : Layer
        Input sequence layer; its feature size must equal
        ``encoder feature size + num_units`` (checked below).
    num_units : int
        Hidden state size; must equal the encoder's feature size.
    peepholes, backwards, mask_input, only_return_final
        Forwarded to the parent LSTM layer unchanged. ``precompute_input``
        is forced to ``False`` because the per-step input depends on the
        attention over the encoder, which cannot be precomputed.
    encoder_input : Layer or None
        Layer producing the encoder outputs to attend over.
    encoder_mask_input : Layer or None
        Layer producing the encoder's sequence mask.

    Raises
    ------
    ValueError
        If the encoder/input feature sizes are inconsistent with
        ``num_units``.
    """
    super(MatchLSTM, self).__init__(incoming, num_units, peepholes=peepholes,
                                    backwards=backwards,
                                    precompute_input=False, mask_input=mask_input,
                                    only_return_final=only_return_final, **kwargs)
    # Indices of the optional encoder inputs inside self.input_layers;
    # -1 means "not provided".
    self.encoder_input_incoming_index = -1
    self.encoder_mask_incoming_index = -1

    if encoder_mask_input is not None:
        self.input_layers.append(encoder_mask_input)
        self.input_shapes.append(encoder_mask_input.output_shape)
        self.encoder_mask_incoming_index = len(self.input_layers) - 1
    if encoder_input is not None:
        self.input_layers.append(encoder_input)
        encoder_input_output_shape = encoder_input.output_shape
        self.input_shapes.append(encoder_input_output_shape)
        self.encoder_input_incoming_index = len(self.input_layers) - 1
        # Hidden state length must equal the encoder feature (embedding) size.
        # Raise explicitly instead of `assert`, which is stripped under -O.
        if encoder_input_output_shape[-1] != num_units:
            raise ValueError(
                "encoder feature size (%d) must equal num_units (%d)"
                % (encoder_input_output_shape[-1], num_units))
        # Per-step input features = embedding size + hidden state length.
        if encoder_input_output_shape[-1] + num_units != self.input_shapes[0][-1]:
            raise ValueError(
                "input feature size (%d) must equal encoder feature size + "
                "num_units (%d)"
                % (self.input_shapes[0][-1],
                   encoder_input_output_shape[-1] + num_units))

    # Attention weights. The registered names now match the attribute
    # names; the original registered them as 'V_pointer'/'v_pointer'/
    # 'W_a_pointer' — copy-paste residue from a pointer layer that made
    # lookup-by-name and checkpoint inspection misleading.
    # NOTE(review): if existing checkpoints were saved with the old names,
    # loading by name needs a migration — confirm before deploying.
    self.W_y_attend = self.add_param(init.Normal(0.1), (num_units, num_units), 'W_y_attend')
    self.W_h_attend = self.add_param(init.Normal(0.1), (num_units, num_units), 'W_h_attend')
    # Stored as a column vector (num_units, 1), so no transpose is needed
    # when it is applied to the attention energies.
    self.w_attend = self.add_param(init.Normal(0.1), (num_units, 1), 'w_attend')
    self.W_m_attend = self.add_param(init.Normal(0.1), (num_units, num_units), 'W_m_attend')
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号