def transition(h):
    """Compute the locally-linear transition parameters (A, B, o) from ``h``.

    ``h`` is first passed through two ``ReLU(h, 100, ...)`` layers. ``ReLU``
    is defined elsewhere; presumably it is a 100-unit dense layer with ReLU
    activation (confirm). Three heads then read the resulting features:

    * ``A``: ``I + v r^T``. This is a per-example rank-one update of the
      ``z_dim x z_dim`` identity, with shape ``(batch, z_dim, z_dim)``.
    * ``B``: ``linear(h, z_dim * u_dim)`` reshaped to ``[-1, z_dim, u_dim]``.
    * ``o``: ``linear(h, z_dim)``.

    Everything is built under the ``"trans"`` variable scope. The heads live
    in the ``"A"``, ``"B"`` and ``"o"`` sub-scopes.

    Relies on the module-level names ``tf``, ``ReLU``, ``linear``, ``z_dim``
    and ``u_dim``.

    NOTE: ``tf.split(1, 2, x)`` (split_dim, num_split, value) and
    ``tf.batch_matmul`` are pre-1.0 TensorFlow APIs. TF >= 1.0 spells them
    ``tf.split(x, 2, axis=1)`` and ``tf.matmul``.

    Args:
        h: Input feature tensor, assumed to be ``(batch, features)``.
            TODO: confirm against callers.

    Returns:
        Tuple ``(A, B, o, v, r)``. ``v`` and ``r`` are the two halves of the
        "A" head's ``2 * z_dim``-wide projection that are combined into ``A``.
    """
    with tf.variable_scope("trans"):
        # Shared hidden layers. The "aggregate_loss<i>" names are unusual for
        # hidden layers but are kept verbatim. Presumably ReLU uses this string
        # to name its variables, so changing it would rename them and break
        # existing checkpoints.
        for depth in range(2):
            h = ReLU(h, 100, "aggregate_loss" + str(depth))

        with tf.variable_scope("A"):
            # One projection of width 2*z_dim, cut into two halves along axis 1.
            v, r = tf.split(1, 2, linear(h, z_dim * 2))
            column = tf.expand_dims(v, -1)  # (batch, z_dim, 1)
            row = tf.expand_dims(r, 1)  # (batch, 1, z_dim)
            eye = tf.diag([1.] * z_dim)  # (z_dim, z_dim) identity
            # Batched outer product v r^T with shape (batch, z_dim, z_dim).
            # eye broadcasts across the batch axis.
            A = eye + tf.batch_matmul(column, row)

        with tf.variable_scope("B"):
            B = tf.reshape(linear(h, z_dim * u_dim), [-1, z_dim, u_dim])

        with tf.variable_scope("o"):
            o = linear(h, z_dim)

    return A, B, o, v, r
评论列表
文章目录