DMC_query.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:QDREN 作者: andreamad8 项目源码 文件源码
def __call__(self, inputs, state, scope=None):
        """Run one step of the block-structured memory cell.

        The hidden state is treated as ``self._num_blocks`` memory blocks of
        equal width, block j being paired with ``self._keys[j]``. The
        parameters ``U``, ``V``, ``W`` and ``biasU`` are created once per
        call scope and shared by every block.

        For each block j:
          * gate:      g_j  = self.get_gate(h_j, key_j, inputs)
          * candidate: h~_j = self.get_candidate(h_j, key_j, inputs, U, V, W, b)
          * update:    h_j <- h_j + g_j * h~_j      (Equation 4)
          * forget:    h_j <- h_j / ||h_j||_2       (Equation 5)
        (Equation numbers follow the original comments; presumably they refer
        to the Recurrent Entity Network paper — confirm.)

        Args:
            inputs: input tensor for this step; forwarded unchanged to
                ``get_gate`` and ``get_candidate``.
            state: previous hidden state; split along axis 1 into
                ``self._num_blocks`` equal parts, so its axis-1 size must be
                divisible by ``self._num_blocks``.
            scope: optional variable-scope name; defaults to the class name.

        Returns:
            An ``(output, new_state)`` pair. Both entries are the same tensor:
            the updated, normalized blocks concatenated back along axis 1.
        """
        with tf.variable_scope(scope or type(self).__name__, initializer=self._initializer):
            block_size = self._num_units_per_block

            # U, V, W and the bias are shared across all memory blocks.
            U = tf.get_variable('U', [block_size, block_size])
            V = tf.get_variable('V', [block_size, block_size])
            W = tf.get_variable('W', [block_size, block_size])
            b = tf.get_variable('biasU', [block_size])

            state_blocks = tf.split(state, self._num_blocks, 1)
            next_states = []
            for j, state_j in enumerate(state_blocks):
                key_j = self._keys[j]
                gate_j = self.get_gate(state_j, key_j, inputs)
                candidate_j = self.get_candidate(state_j, key_j, inputs, U, V, W, b)

                # Equation 4: h_j <- h_j + g_j * h~_j.
                # NOTE(review): gate_j is presumably one scalar per example;
                # expanding its last axis lets it broadcast over the block's
                # units — confirm against get_gate.
                state_j_next = state_j + tf.expand_dims(gate_j, -1) * candidate_j

                # Equation 5: h_j <- h_j / ||h_j|| (forget via normalization).
                # No extra epsilon guard is needed: tf.nn.l2_normalize divides
                # by sqrt(max(sum(x**2), epsilon)) with epsilon=1e-12 by
                # default, so an all-zero block stays zero instead of NaN.
                state_j_next = tf.nn.l2_normalize(state_j_next, -1)

                next_states.append(state_j_next)
            state_next = tf.concat(next_states, 1)
        return state_next, state_next
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号