def generate_(self, inputs, mode='display', return_attend=False, return_all=False):
    """Encode ``inputs``, sample from the decoder and return the chosen output.

    Args:
        inputs: Batch passed to ``self.encoder.gtenc`` and forwarded unchanged
            to ``self.decoder.get_sample``.
        mode: When equal to ``'display'`` the ``sample_stoch`` and
            ``sample_argmax`` config values are forwarded to the decoder;
            for any other value both are passed as ``None``.
        return_attend: Forwarded to ``self.decoder.get_sample``.
        return_all: If true, return the decoder's ``(sample, score, ppp)``
            exactly as produced, without normalisation or selection.

    Returns:
        A tuple ``(sample, np.exp(score), ppp)``.

        * Non-stochastic decoding: every score is divided by the length of
          its sample, and the entry with the smallest normalised score is
          selected from ``sample``, ``score`` and ``ppp``.
        * Stochastic decoding: ``score`` is divided by ``len(sample)`` and
          ``sample`` / ``ppp`` are returned as-is.
    """
    display = mode == 'display'
    args = dict(
        k=self.config['sample_beam'],
        maxlen=self.config['max_len'],
        stochastic=self.config['sample_stoch'] if display else None,
        argmax=self.config['sample_argmax'] if display else None,
        return_attend=return_attend,
    )

    # gtenc returns six values; only the context and its mask are needed here.
    context, _, c_mask, _, _, _ = self.encoder.gtenc(inputs)

    if 'explicit_loc' in self.config and self.config['explicit_loc']:
        # Append a one-hot position code along the last axis of the context:
        # row i of the eye matrix marks position i, padded/truncated to
        # `encode_max_len` columns, and replicated along axis 0.
        batch_size, n_positions = context.shape[0], context.shape[1]
        loc = np.eye(n_positions, self.config['encode_max_len'],
                     dtype='float32')[None, :, :]
        loc = np.repeat(loc, batch_size, axis=0)
        context = np.concatenate([context, loc], axis=2)

    sample, score, ppp = self.decoder.get_sample(context, c_mask, inputs, **args)
    if return_all:
        return sample, score, ppp

    if not args['stochastic']:
        # Several candidates: length-normalise each score and keep the
        # candidate with the smallest normalised value.
        score = score / np.array([len(s) for s in sample])
        idz = score.argmin()
        sample = sample[idz]
        score = score[idz]
        ppp = ppp[idz]
    else:
        # Out-of-place division: leaves the decoder's score object untouched
        # and also works when it is an integer-typed array.
        score = score / float(len(sample))

    return sample, np.exp(score), ppp
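# (scraped page residue, not part of the code) "评论列表" — comment list
# (scraped page residue, not part of the code) "文章目录" — article table of contents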