def visualize_activations(self, x):
    """
    Visualizes the activations in the mdn caused by a given data minibatch.

    First delegates to ``self.net.visualize_activations(x)``. It then compiles a
    Theano function from ``self.input`` to three mixture-head quantities and
    draws each one as a grayscale image in its own figure (rows = data points):

    * ``self.a`` -- the mixing coefficients;
    * the tensors in ``self.ms`` (means), concatenated along axis 1;
    * the tensors in ``self.Us`` (scale matrices), each reshaped to
      ``(U.shape[0], -1)`` and then concatenated along axis 1.

    :param x: a minibatch of data; it is cast with ``x.astype(dtype)`` (module-level
        ``dtype``) before being fed to the compiled function
    :return: none
    """
    self.net.visualize_activations(x)

    means = tt.concatenate(self.ms, axis=1)
    # Flatten each scale-matrix tensor along all non-leading axes so every
    # component fits side by side in a single 2-D image.
    scales = tt.concatenate([tt.reshape(U, [U.shape[0], -1]) for U in self.Us], axis=1)

    # Means and scale matrices must be separate outputs so that each one gets
    # its own figure and lines up with its title below. Adding them together
    # (``means + scales``) would either fail on mismatched widths or silently
    # mix two unrelated quantities.
    forwprop = theano.function(
        inputs=[self.input],
        outputs=[self.a, means, scales],
    )
    activations = forwprop(x.astype(dtype))

    titles = ['mixing coefficients', 'means', 'scale matrices']
    for a, title in zip(activations, titles):
        fig = plt.figure()
        ax = fig.add_subplot(1, 1, 1)
        ax.imshow(a, cmap='gray', interpolation='none')
        ax.set_title(title)
        ax.set_xlabel('layer units')
        ax.set_ylabel('data points')

    plt.show(block=False)
# (web-page residue from the scraped source: "Comment list" / "Table of contents")