dynamic_memory_cell.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:recurrent-entity-networks 作者: jimfleming 项目源码 文件源码
def __call__(self, inputs, state, scope=None):
    """Advance every memory block by one step and return the new state.

    The incoming state is cut along axis 1 into ``self._num_blocks`` pieces.
    Each piece is updated with a gated candidate (Equation 4) and then
    rescaled by its Euclidean norm along the last axis (Equation 5). The
    updated pieces are concatenated back along axis 1.

    Args:
        inputs: Forwarded unchanged to ``self.get_gate`` and
            ``self.get_candidate`` for every block.
        state: Previous hidden state; presumably shaped
            ``(batch, num_blocks * num_units_per_block)`` -- TODO confirm
            against the caller.
        scope: Optional variable scope; when falsy the class name is used.

    Returns:
        A pair ``(state_next, state_next)``: the same tensor is returned
        twice.
    """
    with tf.variable_scope(scope or type(self).__name__, initializer=self._initializer):
        # U, V and W have the same square shape and the same initializer;
        # they are created once, in this order, and reused for every block.
        block_shape = [self._num_units_per_block, self._num_units_per_block]
        U, V, W = [
            tf.get_variable(name, block_shape, initializer=self._recurrent_initializer)
            for name in ('U', 'V', 'W')
        ]
        U_bias = tf.get_variable('U_bias', [self._num_units_per_block])

        # One tensor per memory block.
        blocks = tf.split(state, self._num_blocks, axis=1)

        updated_blocks = []
        for index, block in enumerate(blocks):
            # Give the block's key a leading axis of size one.
            key = tf.expand_dims(self._keys[index], axis=0)
            gate = self.get_gate(block, key, inputs)
            candidate = self.get_candidate(block, key, inputs, U, V, W, U_bias)

            # Equation 4: h_j <- h_j + g_j * h_j^~
            # The gate receives a trailing axis so it can scale the candidate.
            updated = block + tf.expand_dims(gate, -1) * candidate

            # Equation 5: h_j <- h_j / \norm{h_j}
            # Normalizing is what lets older memories fade.
            norm = tf.norm(tensor=updated, ord='euclidean', axis=-1, keep_dims=True)
            # Wherever the norm is not greater than zero, divide by one instead.
            safe_norm = tf.where(tf.greater(norm, 0.0), norm, tf.ones_like(norm))
            updated_blocks.append(updated / safe_norm)

        state_next = tf.concat(updated_blocks, axis=1)
    return state_next, state_next
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号