def transition(h, share=None):
    """Compute the locally-linear transition matrices A, B and offset o.

    ``h`` is passed through two ``ReLU`` layers (``ReLU(h, 100, ...)``). Three
    ``linear`` heads are then applied, each in its own variable scope:

    * ``A = I + v r^T`` per batch element. ``v`` and ``r`` are the two halves
      of a ``linear(h, 2 * z_dim)`` output, and ``I`` is the
      ``(z_dim, z_dim)`` identity, broadcast across the batch.
    * ``B`` is ``linear(h, z_dim * u_dim)`` reshaped to
      ``(batch, z_dim, u_dim)``.
    * ``o`` is ``linear(h, z_dim)``.

    Relies on the module-level names ``tf``, ``ReLU``, ``linear``, ``z_dim``
    and ``u_dim``.

    Args:
        h: Input tensor. Assumed to be ``(batch, features)`` -- TODO confirm
            against ``ReLU``/``linear``.
        share: Forwarded as ``reuse`` to ``tf.variable_scope("trans")``. Pass a
            truthy value to reuse variables created by an earlier call.

    Returns:
        Tuple ``(A, B, o, v, r)``. ``v`` and ``r`` (each with ``z_dim``
        columns) are the rank-one factors used to build ``A``.
    """
    with tf.variable_scope("trans", reuse=share):
        # NOTE(review): "aggregate_loss" is an odd prefix for these layers.
        # Renaming it would presumably change the created variable names and
        # break restoring existing checkpoints, so it is kept verbatim.
        for layer_idx in range(2):
            h = ReLU(h, 100, "aggregate_loss" + str(layer_idx))
        with tf.variable_scope("A"):
            vr = linear(h, z_dim * 2)
            # Slice instead of tf.split. tf.split reversed its positional
            # argument order in TF 1.0, while slicing behaves the same on
            # every version.
            v = vr[:, :z_dim]
            r = vr[:, z_dim:]
            v1 = tf.expand_dims(v, -1)  # (batch, z_dim, 1)
            rT = tf.expand_dims(r, 1)  # (batch, 1, z_dim)
            eye = tf.diag([1.] * z_dim)  # (z_dim, z_dim) identity
            # Batched outer product v r^T via broadcasting. With an inner
            # dimension of 1 this gives the same values as
            # tf.batch_matmul(v1, rT), which no longer exists in TF >= 1.0.
            # eye (z_dim, z_dim) broadcasts against (batch, z_dim, z_dim).
            A = eye + v1 * rT
        with tf.variable_scope("B"):
            B = linear(h, z_dim * u_dim)
            B = tf.reshape(B, [-1, z_dim, u_dim])
        with tf.variable_scope("o"):
            o = linear(h, z_dim)
        return A, B, o, v, r
# (scraped web-page residue) Comment list
# (scraped web-page residue) Table of contents