import numpy as np
import torch

# lookup_sentence (index sequence -> string via a reverse lookup table) is
# assumed to be provided by the surrounding module.

def show_examples_pytorch(model, es, rlut1, rlut2, embed2, mxlen, sample, prob_clip, max_examples, reverse):
    # Pick a random batch from the evaluation stream
    si = np.random.randint(0, len(es))
    batch_dict = es[si]
    src_array = batch_dict['src']
    tgt_array = batch_dict['dst']
    src_len = batch_dict['src_len']
    if max_examples > 0:
        max_examples = min(max_examples, src_array.size(0))
        src_array = src_array[0:max_examples]
        tgt_array = tgt_array[0:max_examples]
        src_len = src_len[0:max_examples]

    GO = embed2.vocab['<GO>']
    EOS = embed2.vocab['<EOS>']
    # Move inputs to the GPU only when one is actually available
    use_gpu = torch.cuda.is_available()
    if use_gpu:
        src_array = src_array.cuda()
    # Loop variable renamed so it no longer shadows src_len
    for src_len_i, src_i, tgt_i in zip(src_len, src_array, tgt_array):
        print('========================================================================')
        sent = lookup_sentence(rlut1, src_i.cpu().numpy(), reverse=reverse)
        print('[OP] %s' % sent)
        sent = lookup_sentence(rlut2, tgt_i.cpu().numpy())
        print('[Actual] %s' % sent)
        # Decode step by step, feeding the growing target prefix back in
        dst_i = torch.zeros(1, mxlen).long()
        if use_gpu:
            dst_i = dst_i.cuda()
        next_value = GO
        src_i = src_i.view(1, -1)
        for j in range(mxlen):
            dst_i[0, j] = next_value
            probv = model((torch.autograd.Variable(src_i), torch.autograd.Variable(dst_i)))
            output = probv.squeeze()[j]
            if sample is False:
                # Greedy decoding: take the argmax over the vocabulary
                _, next_value = torch.max(output, 0)
                next_value = int(next_value.data[0])
            else:
                # Sampled decoding: zero out low-probability events so they
                # are not sampled from, keeping only the top prob_clip
                # entries, then renormalize and draw one token
                probs = output.data.exp()
                best, ids = probs.topk(prob_clip, 0, largest=True, sorted=True)
                probs.zero_()
                probs.index_copy_(0, ids, best)
                probs.div_(torch.sum(probs))
                next_value = torch.multinomial(probs, 1)[0]
            if next_value == EOS:
                break
        sent = lookup_sentence(rlut2, dst_i.squeeze().cpu().numpy())
        print('Guess: %s' % sent)
        print('------------------------------------------------------------------------')
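
# The reverse lookup tables rlut1/rlut2 are assumed to map integer indices
# back to word strings. A minimal sketch of what lookup_sentence is assumed
# to do; the real helper ships with the surrounding library, and this name
# is illustrative only:
def _lookup_sentence_sketch(rlut, seq, reverse=False):
    seq = list(seq)
    if reverse:
        seq = seq[::-1]
    words = []
    for idx in seq:
        word = rlut[int(idx)]
        if word == '<EOS>':
            break
        if word not in ('<PAD>', '<GO>'):
            words.append(word)
    return ' '.join(words)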
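
# A hedged usage sketch. The names below (load_seq2seq_model, valid_stream,
# src_rlut, tgt_rlut, tgt_embeddings) are hypothetical stand-ins for whatever
# a training script would already have in scope; only show_examples_pytorch
# itself comes from the code above.
#
# model = load_seq2seq_model('checkpoint.pth')   # hypothetical loader
# show_examples_pytorch(model,
#                       valid_stream,            # batch dicts with 'src', 'dst', 'src_len'
#                       src_rlut, tgt_rlut,      # reverse lookup tables: index -> word
#                       tgt_embeddings,          # provides .vocab with '<GO>'/'<EOS>'
#                       mxlen=100,
#                       sample=True,             # sample instead of greedy argmax
#                       prob_clip=5,             # keep only the top 5 tokens when sampling
#                       max_examples=5,
#                       reverse=False)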