def generate_multiple(self, inputs, mode='display', return_attend=False, return_all=True, return_encoding=False):
    """Encode ``inputs`` and sample output sequences from the decoder.

    The input is encoded with ``self.encoder.gtenc``, which unlike
    ``encoder.encode()`` also returns gate values. The resulting context is
    optionally augmented with explicit one-hot position features. It is then
    passed to ``self.decoder.get_sample`` together with sampling settings read
    from ``self.config``.

    Args:
        inputs: Batch of source sequences, passed unchanged to
            ``self.encoder.gtenc`` and ``self.decoder.get_sample``.
            Presumably an index array of shape (nb_sample, max_len) -- confirm.
        mode: When ``'display'``, the ``stochastic`` and ``argmax`` sampling
            flags come from ``config['sample_stoch']`` and
            ``config['sample_argmax']``. For any other value, both are passed
            to the decoder as ``None``.
        return_attend: Forwarded to the decoder as ``return_attend``.
        return_all: Together with ``return_encoding``, selects the 4-tuple
            return. On its own it has no effect.
        return_encoding: If true *and* ``return_all`` is true, also return the
            encoder context and the decoder's output encoding. Ignored when
            ``return_all`` is false.

    Returns:
        ``(sample, score)`` by default, or
        ``(context, sample, score, output_encoding)`` when both ``return_all``
        and ``return_encoding`` are true. ``context`` includes the explicit
        position features if they were appended.
    """
    args = dict(
        k=self.config['sample_beam'],
        maxlen=self.config['max_len'],
        stochastic=self.config['sample_stoch'] if mode == 'display' else None,
        argmax=self.config['sample_argmax'] if mode == 'display' else None,
        return_attend=return_attend,
        type=self.config['predict_type'],  # keyword name expected by decoder.get_sample
    )

    # gtenc returns six values. Per the original author's notes (gtenc defaults
    # listed there: with_context=False, return_sequence=True, return_embed=True):
    #   context: [nb_sample, max_len, 2*enc_hidden_dim] -- per-step encoding,
    #            forward and backward RNN states concatenated
    #   (unused) embedding of the input [nb_sample, max_len, enc_embedd_dim]
    #   c_mask:  [nb_sample, max_len] -- marks which context positions are real
    #   (unused) encoding of the end of the input (head+tail for the bi-RNN)
    #   Z, R:    gate values (the notes label both "update gate"; R is
    #            presumably the reset gate -- confirm). Not used here.
    context, _, c_mask, _, _z, _r = self.encoder.gtenc(inputs)

    if 'explicit_loc' in self.config and self.config['explicit_loc']:
        # Append a one-hot position feature to every time step. Step i gets a 1
        # at extra feature index i (all zeros once i >= encode_max_len), so the
        # last axis grows by config['encode_max_len'].
        max_len = context.shape[1]
        exp_loc = np.eye(max_len, self.config['encode_max_len'], dtype='float32')[None, :, :]
        exp_loc = np.repeat(exp_loc, context.shape[0], axis=0)
        context = np.concatenate([context, exp_loc], axis=2)

    # The third value returned by the decoder is not used by this method.
    sample, score, _unused, output_encoding = self.decoder.get_sample(context, c_mask, inputs, **args)

    if return_all and return_encoding:
        return context, sample, score, output_encoding
    return sample, score
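# Comment list ("评论列表") -- stray navigation text from the web page this code was copied from
# Table of contents ("文章目录") -- stray navigation text from the web page this code was copied from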