def dOmega_dWrec(self):
    """Build the gradient of the Omega penalty with respect to ``self.W_rec``.

    The penalty is assembled from the gradient of ``self.error`` with respect
    to every entry of ``self.states`` (``dxt``):

    * ``num`` is ``(1 - alpha) * dxt`` plus ``alpha * dxt`` contracted with the
      transpose of ``|W_rec| * rec_Connectivity @ Dale_rec``, the contracted
      term being zeroed wherever the corresponding state is not positive.
    * ``num`` and ``dxt`` are squared and summed over axis 2, giving one value
      per (axis 0, axis 1) entry; their ratio minus one, squared, is the
      per-entry penalty.
    * Entries whose squared ``dxt`` norm is not greater than ``1e-20`` get a
      ratio of exactly 1 (zero penalty) and are excluded from the normaliser.

    The division is performed against a "safe" denominator (1 at excluded
    entries).  ``tf.where`` only masks the forward value; dividing by the raw
    denominator would still inject NaN into the gradient wherever it is 0.

    NOTE(review): if *no* entry passes the threshold the normaliser is 0 and
    Omega is 0/0, so the finite check below is expected to fail when run.

    Returns:
        A tuple ``(out, test)``.  ``out`` is the list returned by
        ``tf.gradients(Omega, self.W_rec)`` whose first element is wrapped in
        a ``tf.Print`` and a ``tf.verify_tensor_all_finite`` op.  ``test`` is
        the result of ``tf.gradients(states[0], states[-1])``.
    """
    # states in shape timesteps, batch, n_rec (per the original author's note)
    states = self.states
    dxt_list = tf.gradients(self.error, states)
    # Diagnostic handed back to the caller alongside the real gradient.
    test = tf.gradients(states[0], states[-1])
    dxt = tf.stack(dxt_list)
    xt = tf.stack(states)

    # Recurrent weights as used in the contraction below.
    w_eff = tf.matmul(tf.abs(self.W_rec) * self.rec_Connectivity,
                      self.Dale_rec)
    # 1 where the state is strictly positive, 0 elsewhere.
    active = tf.where(tf.greater(xt, 0), tf.ones_like(xt), tf.zeros_like(xt))
    propagated = tf.tensordot(self.alpha * dxt, tf.transpose(w_eff), axes=1)
    num = (1 - self.alpha) * dxt + propagated * active
    denom = dxt

    # sum over hidden units
    num = tf.reduce_sum(tf.square(num), axis=2)
    denom = tf.reduce_sum(tf.square(denom), axis=2)

    valid = tf.greater(denom, 1e-20)
    # Divide by 1 at masked entries so the unused branch stays finite in both
    # the forward pass and its gradient; valid entries are unaffected.
    safe_denom = tf.where(valid, denom, tf.ones_like(denom))
    bounded = tf.where(valid,
                       tf.div(num, 1.0 * safe_denom),
                       tf.ones_like(num))
    # Fraction of valid entries along axis 1, one value per axis-0 index.
    nelems = tf.reduce_mean(
        tf.where(valid, 1.0 * tf.ones_like(num), 1.0 * tf.zeros_like(num)),
        axis=1)

    # sum mean over each batch by time steps
    Omega = tf.square(bounded - 1.0)
    Omega = (tf.reduce_sum(tf.reduce_mean(Omega, axis=1))
             / (1.0 * tf.reduce_sum(nelems)))

    out = tf.gradients(Omega, self.W_rec)
    out[0] = tf.Print(out[0], [out[0], self.W_rec, Omega], "omega grads")
    out[0] = tf.verify_tensor_all_finite(out[0], "dead omega grad")
    return out, test
评论列表
文章目录