def plotVAEplotly(self, logdir, prefix, ctable=None, reverseUtt=False, batch_size=128, debug=False):
    """Write an offline Plotly 3-D text plot of words decoded from a latent grid.

    A regular grid with five ticks per axis (-1, -0.5, 0, 0.5, 1) is built
    over the latent space, every grid point is decoded with
    ``self.decode_word``, and the reconstructed strings are drawn as text
    labels at their latent coordinates.

    Args:
        logdir: Directory part of the output path.
        prefix: File-name prefix; the plot is written to
            ``logdir + '/' + prefix + '_VAEplot.html'``.
        ctable: Passed through unchanged to ``reconstructXae``.
        reverseUtt: When truthy, the predictions are first passed through
            ``getYae(preds, reverseUtt)``.
        batch_size: Forwarded to ``self.decode_word``.
        debug: Unused; kept so existing callers keep working.

    Returns:
        None. The only effect is the HTML file written by Plotly.

    Raises:
        ValueError: If ``self.latentDim`` is not 3. The figure uses exactly
            the first three coordinates of each sample as x/y/z.
    """
    n_dims = self.latentDim
    if n_dims != 3:
        # The original hard-coded reshape(-1, 3) only produced a valid grid
        # for a 3-D latent space; fail early with an explicit message.
        raise ValueError(
            'plotVAEplotly requires a 3-dimensional latent space, got latentDim=%s' % n_dims
        )

    ticks = [[-1, -0.5, 0, 0.5, 1]] * n_dims
    # meshgrid yields one coordinate array per axis; transposing and
    # flattening gives one row per grid point with n_dims columns.
    samplePoints = np.array(np.meshgrid(*ticks)).T.reshape(-1, n_dims)

    # Second decoder input: all ones, one row per sample, middle dimensions
    # taken from self.phon.output_shape and a trailing singleton axis.
    # NOTE(review): presumably a dummy/teacher-forcing input - confirm.
    input_placeholder = np.ones(
        tuple([len(samplePoints)] + list(self.phon.output_shape[1:-1]) + [1])
    )

    preds = self.decode_word([samplePoints, input_placeholder], batch_size=batch_size)
    if reverseUtt:
        preds = getYae(preds, reverseUtt)

    # Take the argmax over the last axis and restore a trailing singleton
    # axis before handing the indices to reconstructXae.
    reconstructed = reconstructXae(np.expand_dims(preds.argmax(-1), -1), ctable, maxLen=5)

    data = [go.Scatter3d(
        x=samplePoints[:, 0],
        y=samplePoints[:, 1],
        z=samplePoints[:, 2],
        text=reconstructed,
        mode='text'
    )]
    layout = go.Layout()
    fig = go.Figure(data=data, layout=layout)
    plotly.offline.plot(fig, filename=logdir + '/' + prefix + '_VAEplot.html', auto_open=False)
# Comment list
# Article table of contents