networks.py source code

Project: Sisyphus, Author: davidbrandfonbrener
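
```python
import tensorflow as tf  # TensorFlow 1.x graph-mode APIs (tf.gradients, tf.Print)


def dOmega_dWrec(self):
    """Gradient of the regularizer Omega with respect to self.W_rec.

    Method of the network class. Returns (out, test): `out` is the list from
    tf.gradients(Omega, self.W_rec); `test` is a debugging gradient.
    """
    # self.states: list with one tensor of shape (batch, n_rec) per time step
    states = self.states
    # dE/dx_t for every time step, each of shape (batch, n_rec)
    dxt_list = tf.gradients(self.error, states)

    # Debugging check only: in a forward-unrolled RNN the first state does not
    # depend on the last one, so this is presumably [None].
    test = tf.gradients(states[0], states[-1])

    dxt = tf.stack(dxt_list)  # (timesteps, batch, n_rec)
    xt = tf.stack(states)     # (timesteps, batch, n_rec)

    # Effective recurrent weights: magnitudes of W_rec, masked by the
    # connectivity matrix, then multiplied by the Dale's-law sign matrix.
    W_eff = tf.matmul(tf.abs(self.W_rec) * self.rec_Connectivity, self.Dale_rec)

    # Backpropagate dE/dx_t through one step of the leaky recurrence: the leak
    # term plus the recurrent term gated by the ReLU derivative (1 where x > 0).
    relu_grad = tf.where(tf.greater(xt, 0), tf.ones_like(xt), tf.zeros_like(xt))
    num = (1 - self.alpha) * dxt + \
        tf.tensordot(self.alpha * dxt, tf.transpose(W_eff), axes=1) * relu_grad
    denom = dxt

    # Squared L2 norms over hidden units -> shape (timesteps, batch)
    num = tf.reduce_sum(tf.square(num), axis=2)
    denom = tf.reduce_sum(tf.square(denom), axis=2)

    # Keep only (t, b) entries whose gradient norm is not vanishingly small.
    valid = tf.greater(denom, 1e-20)
    # Replace tiny denominators before dividing: tf.where backpropagates into
    # both branches, so an inf/NaN in the unselected branch would otherwise
    # turn the gradient into NaN.
    safe_denom = tf.where(valid, denom, tf.ones_like(denom))
    bounded = tf.where(valid, num / safe_denom, tf.ones_like(num))
    # Fraction of valid entries per time step (mean over the batch)
    nelems = tf.reduce_mean(tf.cast(valid, num.dtype), axis=1)

    # Penalize deviation of the norm ratio from 1: mean over the batch, sum
    # over time steps, normalized by the summed fraction of valid entries.
    Omega = tf.square(bounded - 1.0)
    Omega = tf.reduce_sum(tf.reduce_mean(Omega, axis=1)) / tf.reduce_sum(nelems)

    out = tf.gradients(Omega, self.W_rec)

    # Debug print and a finiteness check on the regularizer gradient
    out[0] = tf.Print(out[0], [out[0], self.W_rec, Omega], "omega grads")
    out[0] = tf.verify_tensor_all_finite(out[0], "dead omega grad")

    return out, test
```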
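
Read off the code, the scalar that `dOmega_dWrec` differentiates can be written as below. This is a reconstruction from the code, not a formula stated by the author; it appears to be a squared-norm variant of the vanishing-gradient regularizer Omega of Pascanu et al. (2013). The symbols are assumptions for notation only: g_{t,b} is dE/dx_t for batch element b, M is `rec_Connectivity`, D is `Dale_rec`, B is the batch size and ε = 1e-20.

```latex
W_{\mathrm{eff}} = \big(|W_{\mathrm{rec}}| \odot M\big)\, D,
\qquad g_{t,b} = \frac{\partial E}{\partial x_{t,b}}

v_{t,b} = (1-\alpha)\, g_{t,b} + \alpha\, \big(W_{\mathrm{eff}}\, g_{t,b}\big) \odot \mathbb{1}[x_{t,b} > 0]

r_{t,b} =
\begin{cases}
  \|v_{t,b}\|_2^2 \,/\, \|g_{t,b}\|_2^2 & \text{if } \|g_{t,b}\|_2^2 > \varepsilon \\
  1 & \text{otherwise}
\end{cases}

\Omega = \frac{\sum_{t} \frac{1}{B} \sum_{b} \big(r_{t,b} - 1\big)^2}
              {\sum_{t} \frac{1}{B} \sum_{b} \mathbb{1}\big[\|g_{t,b}\|_2^2 > \varepsilon\big]}
```

Entries that fail the threshold have r = 1 and contribute zero, so the 1/B factors cancel and Ω is simply the mean of (r − 1)² over the valid (t, b) entries. The method then returns ∂Ω/∂W_rec via `tf.gradients`.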
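
To check the shapes and the averaging without building a TensorFlow graph, the forward value of Omega can be reproduced in NumPy. This is a minimal sketch under stated assumptions: `omega_numpy` is a hypothetical helper, `W_eff` is assumed to be the already-computed `tf.matmul(tf.abs(W_rec) * rec_Connectivity, Dale_rec)`, and the sketch computes only Omega itself, not its gradient with respect to `W_rec`.

```python
import numpy as np


def omega_numpy(dxt, xt, W_eff, alpha, eps=1e-20):
    """Hypothetical NumPy mirror of the forward computation of Omega.

    dxt, xt: arrays of shape (timesteps, batch, n_rec) holding dE/dx_t and x_t.
    W_eff:   (n_rec, n_rec) effective recurrent weights (assumed precomputed).
    alpha:   leak rate of the recurrence.
    """
    relu_grad = (xt > 0).astype(dxt.dtype)
    # Same contraction as tf.tensordot(alpha * dxt, tf.transpose(W_eff), axes=1)
    num = (1 - alpha) * dxt + np.tensordot(alpha * dxt, W_eff.T, axes=1) * relu_grad
    num = np.sum(num ** 2, axis=2)    # (timesteps, batch)
    denom = np.sum(dxt ** 2, axis=2)  # (timesteps, batch)

    valid = denom > eps
    safe_denom = np.where(valid, denom, 1.0)  # avoid dividing by ~0
    bounded = np.where(valid, num / safe_denom, 1.0)
    nelems = np.mean(valid.astype(dxt.dtype), axis=1)

    omega = (bounded - 1.0) ** 2
    return np.sum(np.mean(omega, axis=1)) / np.sum(nelems)


# One step, one batch element, one unit with x > 0:
# v = 0.5 * 1 + 0.5 * 1 * 2 = 1.5, ratio = 2.25, Omega = 1.25 ** 2
print(omega_numpy(np.ones((1, 1, 1)), np.ones((1, 1, 1)), np.array([[2.0]]), 0.5))
# prints 1.5625
```

As in the TensorFlow version, if no (t, b) entry passes the threshold the final division is 0/0; the sketch keeps that behavior rather than hiding it.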