def analyse_(self, inputs, outputs, idx2word, inputs_unk=None, return_attend=False, name=None, display=False):
    """Decode one source sequence and compare the result with the reference.

    The sequence is passed to ``self.generate_`` with a leading axis added
    (``[None, :]``). The source, the reference and the generated ids are each
    rendered as a space-separated string of words, and the reference string
    is compared with the generated one.

    Args:
        inputs: Source token ids. Must support ``[None, :]`` indexing and
            ``.tolist()`` (presumably a 1-D numpy array -- confirm against
            the caller). Also used to resolve ids ``>= Lmax`` in
            ``cut_zero``.
        outputs: Reference token ids; must support ``.tolist()``.
        idx2word: Lookup from token id to word; each word must provide
            ``.encode('utf-8')``.
        inputs_unk: Optional alternative source sequence. When given it is
            fed to ``self.generate_`` instead of ``inputs``; word lookups
            for ids ``>= Lmax`` still go through ``inputs``.
        return_attend: Forwarded unchanged to ``self.generate_``.
        name: Forwarded unchanged to ``visualize_`` when ``display`` is set.
        display: When true, print the three strings and plot the two arrays
            unpacked from the third value returned by ``self.generate_``.

    Returns:
        bool: ``True`` when the rendered reference equals the rendered
        generated sequence.
    """
    def cut_zero(sample, idx2word, ppp=None, Lmax=None):
        """Render ``sample`` as a list of words, stopping before the first 0.

        Ids below ``Lmax`` are looked up directly in ``idx2word``. Ids at or
        above ``Lmax`` are treated as positions into the enclosing ``inputs``
        (``inputs[w - Lmax]``) whose id is then looked up -- this looks like
        a copy-from-source mechanism; confirm against the model code.
        When ``ppp`` is given, each word is suffixed with its paired value
        formatted as ``(x.x)``.
        """
        if Lmax is None:
            Lmax = self.config['dec_voc_size']

        # Id 0 acts as the end marker: keep only what precedes the first 0,
        # or the whole sequence when no 0 is present.
        if 0 in sample:
            end = sample.index(0)
            sample = sample[:end]
            if ppp is not None:
                ppp = ppp[:end]

        def word(w):
            token = w if w < Lmax else inputs[w - Lmax]
            return idx2word[token].encode('utf-8')

        if ppp is None:
            return ['{}'.format(word(w)) for w in sample]
        return ['{0} ({1:1.1f})'.format(word(w), p)
                for w, p in zip(sample, ppp)]

    model_input = inputs if inputs_unk is None else inputs_unk
    result, _, ppp = self.generate_(model_input[None, :],
                                    return_attend=return_attend)

    # Source and reference use Lmax=len(idx2word), so any id that is a valid
    # position/key below len(idx2word) is looked up directly.
    source = ' '.join(cut_zero(inputs.tolist(), idx2word, Lmax=len(idx2word)))
    target = ' '.join(cut_zero(outputs.tolist(), idx2word, Lmax=len(idx2word)))
    decode = ' '.join(cut_zero(result, idx2word))

    if display:
        # Single-argument print(...) prints the same text under the Python 2
        # print statement and the Python 3 print function.
        print(source)
        print(target)
        print(decode)

        # The generated sequence may lack a 0 terminator (cut_zero tolerates
        # that as well); use the full length rather than letting list.index
        # raise ValueError.
        idz = result.index(0) if 0 in result else len(result)
        # ppp is assumed to be a sequence of per-step pairs whose halves
        # stack into two 2-D arrays -- TODO confirm against generate_.
        p1, p2 = [np.asarray(p) for p in zip(*ppp)]
        print(p1.shape)

        # Imported lazily so plotting is only required when displaying.
        import pylab as plt
        visualize_(plt.subplots(), 1 - p1[:idz, :].T, grid=True, name=name)
        visualize_(plt.subplots(), 1 - p2[:idz, :].T, name=name)

    return target == decode
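# 评论列表  ("Comment list" -- navigation label left over from the scraped web page, not code)
# 文章目录  ("Table of contents" -- navigation label left over from the scraped web page, not code)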