def interpolate(self, x1, x2, n):
    """Spherically interpolate (slerp) between two inputs in latent space.

    Both inputs are encoded. ``n`` latent codes are placed along the great
    arc between the two codes (endpoints included) and decoded. The decoded
    samples are then framed by the original inputs.

    Args:
        x1: First input. A batch axis is added with ``tf.expand_dims(x1, 0)``,
            so this is presumably a single example without a batch
            dimension -- TODO confirm against callers.
        x2: Second input, presumably the same shape as ``x1``.
        n: Number of interpolation steps (int). ``tf.linspace`` spaces them
            evenly over [0, 1], so the first and last decoded samples are
            decodings of ``z1`` and ``z2`` themselves.

    Returns:
        Tensor concatenated along axis 0 as ``[x1, <n decoded samples>, x2]``,
        i.e. ``n + 2`` rows.
    """
    # Below this |sin(theta)|, the slerp weights would divide by ~0. In that
    # case use linear weights, which are the limit of slerp as theta -> 0.
    # This also covers anti-parallel codes (theta ~ pi), where slerp is
    # undefined.
    eps = 1e-6

    x1 = tf.expand_dims(x1, 0)
    x2 = tf.expand_dims(x2, 0)
    # Only the first output of _encode is used as the latent code
    # (presumably the posterior mean -- verify against _encode).
    z1, _ = self._encode(x1, is_training=False)
    z2, _ = self._encode(x2, is_training=False)

    def l2norm(x):
        # Euclidean norm over the last axis.
        return tf.sqrt(tf.reduce_sum(tf.square(x), -1))

    # Cosine of the angle between the two codes (1 x 1 via matmul).
    # Rounding can push it slightly outside [-1, 1], where acos is NaN,
    # so clip it.
    cos_theta = tf.matmul(z1 / l2norm(z1), z2 / l2norm(z2), transpose_b=True)
    cos_theta = tf.clip_by_value(cos_theta, -1., 1.)

    a = tf.reshape(tf.linspace(0., 1., n), [n, 1])  # n x 1 interpolation weights

    # Slerp needs the angle itself, not its cosine. Broadcast theta from
    # 1 x 1 to n x 1 so every tf.where below sees equal shapes (required by
    # TF1's tf.where; harmless on TF2).
    theta = tf.acos(cos_theta) + tf.zeros_like(a)
    sin_theta = tf.sin(theta)
    degenerate = tf.abs(sin_theta) < eps
    # Replace tiny denominators with 1. The degenerate rows select the
    # linear weights anyway, so no inf/NaN is produced.
    safe_sin = tf.where(degenerate, tf.ones_like(sin_theta), sin_theta)
    a1 = tf.where(degenerate, 1. - a, tf.sin((1. - a) * theta) / safe_sin)
    a2 = tf.where(degenerate, a, tf.sin(a * theta) / safe_sin)

    # (n x 1) * (1 x d) -> n x d, assuming z1/z2 are 1 x d codes.
    z = a1 * z1 + a2 * z2
    xh = self._generate(z, is_training=False)
    # TF >= 1.0 signature: tf.concat(values, axis).
    return tf.concat([x1, xh, x2], 0)
评论列表
文章目录