bricks.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:lmkit 作者: jiangnanhugo 项目源码 文件源码
def apply(self, inputs, update_inputs, reset_inputs, mask=None):
        """Run the gated recurrent transition over a sequence with ``theano.scan``.

        At every step the previous state ``h`` is combined with the per-step
        inputs as follows (``.`` is a matrix product)::

            r   = gate_activation(h . state_to_reset  + reset_inputs_t)
            z   = gate_activation(h . state_to_update + update_inputs_t)
            h~  = activation((h * r) . state_to_state + inputs_t)
            h'  = z * h~ + (1 - z) * h

        Args:
            inputs: sequence fed to the candidate-state pre-activation; scan
                iterates over its first axis. ``inputs.shape[1]`` is passed to
                ``self.initial_state`` -- presumably the batch size (TODO confirm).
            update_inputs: sequence added to the update-gate pre-activation.
            reset_inputs: sequence added to the reset-gate pre-activation.
            mask: optional sequence iterated alongside the inputs. Each step's
                slice is indexed as ``[:, None]``, so it is assumed to be 1-D per
                step (TODO confirm). Where it is 1 the new state is used; where
                it is 0 the previous state is carried over. ``None`` (default)
                disables masking.

        Returns:
            The first output of ``theano.scan``: the state produced at each step.
        """
        def step(inputs, update_inputs, reset_inputs, states,
                 state_to_reset, state_to_update, state_to_state):
            # scan calls fn with: sequence slices, previous state, then the
            # non_sequences in the order they are listed below.
            reset_values = self.gate_activation.apply(
                states.dot(state_to_reset) + reset_inputs)
            update_values = self.gate_activation.apply(
                states.dot(state_to_update) + update_inputs)
            next_states_proposed = self.activation.apply(
                (states * reset_values).dot(state_to_state) + inputs)
            return (next_states_proposed * update_values +
                    states * (1 - update_values))

        def step_mask(inputs, update_inputs, reset_inputs, mask_input, states,
                      state_to_reset, state_to_update, state_to_state):
            next_states = step(inputs, update_inputs, reset_inputs, states,
                               state_to_reset, state_to_update, state_to_state)
            # mask_input is always supplied on this path (it is a scan
            # sequence), so apply it unconditionally; testing the truthiness
            # of a symbolic slice carries no information.
            keep = mask_input[:, None]
            return keep * next_states + (1 - keep) * states

        # Order must match the trailing parameters of step / step_mask.
        non_sequences = [self.state_to_reset, self.state_to_update,
                         self.state_to_state]

        # Compare against None explicitly: the truth value of a tensor
        # (symbolic or numeric) is not a reliable "was a mask given" test.
        if mask is not None:
            func = step_mask
            sequences = [inputs, update_inputs, reset_inputs, mask]
        else:
            func = step
            sequences = [inputs, update_inputs, reset_inputs]

        states_output, _ = theano.scan(
            fn=func,
            sequences=sequences,
            outputs_info=[self.initial_state('initial_state', inputs.shape[1])],
            non_sequences=non_sequences,
            strict=True,
            allow_gc=False)

        return states_output
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号