bricks.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:lmkit 作者: jiangnanhugo 项目源码 文件源码
def apply(self, inputs, gate_inputs, mask=None):
        """Run the gated recurrent transition over a whole sequence.

        The recurrence is unrolled with ``theano.scan``, which iterates over
        the leading axis of ``inputs``, ``gate_inputs`` and (if given)
        ``mask``. At every step the previous state is combined with the
        step's inputs through an update gate and a reset gate (GRU-style)::

            gates  = gate_activation(h_prev . state_to_gates + gate_inputs_t)
            z, r   = gates[:, :dim], gates[:, dim:]
            h_cand = activation((h_prev * r) . state_to_state + inputs_t)
            h_t    = z * h_cand + (1 - z) * h_prev

        Parameters
        ----------
        inputs : symbolic tensor
            Per-step inputs to the candidate state. Presumably shaped
            (time, batch, dim); axis 1 is used as the batch size when the
            initial state is tiled -- TODO confirm against callers.
        gate_inputs : symbolic tensor
            Per-step inputs to the gates. The gate activations are split at
            ``self.dim`` into update (first part) and reset (second part),
            so the last axis presumably has size ``2 * self.dim``.
        mask : symbolic tensor, optional
            Per-step mask over batch rows (each step's slice is used as
            ``mask_t[:, None]``). Where it is 0 the previous state is carried
            through unchanged; where it is 1 the new state is used.
            ``None`` disables masking.

        Returns
        -------
        The states for every step, i.e. the first output of ``theano.scan``
        (one state per step, stacked along a new leading axis).
        """
        def step(inputs, gate_inputs, states, state_to_gates, state_to_state):
            # Use the weight matrices scan passes in (the `non_sequences`
            # below) instead of reaching back to `self`.
            gate_values = self.gate_activation.apply(
                states.dot(state_to_gates) + gate_inputs)
            update_values = gate_values[:, :self.dim]
            reset_values = gate_values[:, self.dim:]
            states_reset = states * reset_values
            next_states = self.activation.apply(
                states_reset.dot(state_to_state) + inputs)
            # Interpolate between the candidate state and the previous state.
            return (next_states * update_values +
                    states * (1 - update_values))

        def step_mask(inputs, gate_inputs, mask_input, states,
                      state_to_gates, state_to_state):
            next_states = step(inputs, gate_inputs, states,
                               state_to_gates, state_to_state)
            # scan always supplies mask_input here, and a Python truth test
            # on a symbolic slice says nothing about its values, so the mask
            # is applied unconditionally: keep the old state where it is 0.
            mask_column = mask_input[:, None]
            return mask_column * next_states + (1 - mask_column) * states

        # Compare against None explicitly: truth-testing a symbolic variable
        # raises TypeError for comparison-derived masks (e.g. tensor.neq).
        if mask is not None:
            func = step_mask
            sequences = [inputs, gate_inputs, mask]
        else:
            func = step
            sequences = [inputs, gate_inputs]

        # Tile the 1-D initial state (self.params[2]) across the batch axis so
        # every row of the batch starts from the same state.
        initial_states = tensor.repeat(
            self.params[2].dimshuffle('x', 0), inputs.shape[1], axis=0)
        states_output, _ = theano.scan(
            fn=func,
            sequences=sequences,
            outputs_info=[initial_states],
            non_sequences=[self.state_to_gates, self.state_to_state],
            # strict=True: every shared variable used by the step function
            # must be passed explicitly (here via non_sequences).
            strict=True)

        return states_output
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号