ntm.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:ntm_keras 作者: flomlo 项目源码 文件源码
def build(self, input_shape):
    """Set up the controller, the shift matrix and the recurrent state layout.

    If no controller was supplied, a linear ``Dense`` controller is created
    and built here. The method then stores the shift matrix, exposes the
    controller's trainable weights as this layer's own, declares four
    recurrent states and finally defers to the parent ``build``.

    Args:
        input_shape: 3-tuple unpacked as (batch size, input length, input dim).
    """
    batch_dim, input_length, input_dim = input_shape

    # Derive the controller's input/output widths from the layer configuration.
    controller_dims = controller_input_output_shape(
        input_dim,
        self.units,
        self.m_depth,
        self.n_slots,
        self.shift_range,
        self.read_heads,
        self.write_heads,
    )
    self.controller_input_dim, self.controller_output_dim = controller_dims

    # With the controller shape known, a default controller can be created
    # and attached to the layer when the caller did not provide one.
    if self.controller is None:
        # NOTE(review): the constructor receives the batch size unpacked from
        # input_shape, whereas build() below receives self.batch_size --
        # confirm that both are intended.
        declared_shape = (batch_dim, input_length, self.controller_input_dim)
        self.controller = Dense(
            name="controller",
            activation='linear',
            bias_initializer='zeros',
            units=self.controller_output_dim,
            input_shape=declared_shape,
        )
        build_shape = (self.batch_size, input_length, self.controller_input_dim)
        self.controller.build(input_shape=build_shape)
        self.controller_with_state = False

    # Fixed (non-trainable) shift matrix.
    self.C = _circulant(self.n_slots, self.shift_range)

    # The only trainable parameters of this layer are the controller's.
    self.trainable_weights = self.controller.trainable_weights

    # Declare the states carried from step to step. According to the original
    # author the count is 6 (LSTM), 5 (GRU) or 4 (feed-forward), see
    # self.get_initial_states:
    #   [old_ntm_output] + [init_M, init_wr, init_ww]
    #   + [init_h] (LSTM and GRU) + [init_c] (LSTM only)
    # old_ntm_output carries no meaning here, but the step function this
    # layer is meant to be used with requires it.
    # WARNING (original author): the role of self.state_spec is only poorly
    # understood; it was copied from keras/recurrent.py.
    self.states = [None] * 4
    state_shapes = (
        (None, self.output_dim),                  # old_ntm_output
        (None, self.n_slots, self.m_depth),       # memory
        (None, self.read_heads, self.n_slots),    # weights_read
        (None, self.write_heads, self.n_slots),   # weights_write
    )
    self.state_spec = [InputSpec(shape=shape) for shape in state_shapes]

    super(NeuralTuringMachine, self).build(input_shape)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号