fops.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:shuttleNet 作者: shiyemin 项目源码 文件源码
def __call__(self, inputs, state, scope=None):
        """Run one MemGrid step: gate the memory, write to it, then read it.

        The step rebinds ``self._memory`` to the newly written memory and
        returns a read of that memory weighted by the same line weights that
        were used for the write.

        Args:
            inputs: input tensor for this step.
            state: passed through and returned unchanged; the evolving memory
                is kept on ``self._memory`` instead.
            scope: optional variable-scope name; defaults to the class name.

        Returns:
            A ``(output, state)`` pair where ``output`` is reshaped to
            ``[batch_size, mem_dim]``.

        NOTE(review): the axis-2 split/concat and the axis-1 reduction suggest
        ``self._memory`` is ``[batch_size, mem_size, mem_dim]`` -- confirm
        against the constructor.
        """
        cell_scope = scope or type(self).__name__  # "MemGrid"
        with tf.variable_scope(cell_scope):
            with tf.variable_scope("Gates"):
                # Reset and update gates share one projection. The bias starts
                # at 1.0 so that initially nothing is reset or updated.
                gate_logits = self.unbalance_linear(
                    [inputs, self._memory], 2 * self._mem_dim, True, 1.0)
                reset_logits, update_logits = tf.split(gate_logits, 2, 2)
                reset_gate = sigmoid(reset_logits)
                update_gate = sigmoid(update_logits)

            with tf.variable_scope("Candidate"):
                # Candidate content is computed from the reset-gated memory.
                candidate_logits = self.unbalance_linear(
                    [inputs, reset_gate * self._memory], self._mem_dim, True)
                candidate = self._activation(candidate_logits)

            # Line weights decide which memory lines get written.
            write_keys = tf.concat([candidate, self._memory], 2)
            line_weights = att_weight(
                inputs, write_keys, self.echocell, scope="Line_weights")
            line_weights = tf.reshape(
                line_weights, [self._batch_size, self._mem_size, 1])

            # GRU-style blend of old memory and candidate, computed per line.
            retained = update_gate * self._memory
            refreshed = (1 - update_gate) * candidate
            blended_memory = retained + refreshed

            # Only the selected lines take the blended value; others keep
            # their previous content.
            untouched = self._memory * (1 - line_weights)
            written = blended_memory * line_weights
            self._memory = untouched + written

            # The read reuses the write weights; a separate "hidden_lw"
            # attention over the memory was left disabled by the author.
            weighted_read = tf.reduce_sum(line_weights * self._memory, 1)
            output = tf.reshape(
                weighted_read, [self._batch_size, self._mem_dim])

            return output, state
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号