def predict_seq2seq(sess, model, decoding, data, decode_len, dr=None, argmax=True,
                    batch_size=100, ensemble=False, verbose=False):
    """Run batched inference for a seq2seq graph and collect the predictions.

    ``model`` is a flat, sliceable sequence laid out as: ``num_items * in_len``
    encoder input nodes, then ``decode_len`` decoder input nodes, then the
    prediction nodes that are fetched from ``sess``.

    Args:
        sess: Object exposing ``run(fetches, feed_dict=...)`` (presumably a
            TensorFlow 1.x ``tf.Session`` -- confirm with caller).
        model: Sequence of graph nodes in the layout described above. It is
            not modified.
        decoding: Input node that is fed ``True`` (assumed to switch the graph
            into decoding mode -- only the fed value is visible here).
        data: List of ``num_items`` parallel input streams. ``data[k][s]`` is
            sample ``s`` of stream ``k``, a sequence of token ids. Every sample
            is assumed to have the length of ``data[0][0]``.
        decode_len: Number of decoder input nodes to feed with zeros.
        dr: Optional input node fed ``0.0`` (presumably the dropout rate).
        argmax: If true, take the per-row argmax of each fetched array and
            return one token-id sequence per sample, with zeros trimmed from
            both ends. Otherwise return the raw fetched arrays.
        batch_size: Maximum number of samples per ``sess.run`` call. Must be
            positive.
        ensemble: Unused; kept for interface compatibility.
        verbose: If true, print percentage progress before each batch.

    Returns:
        list: With ``argmax``, one token-id sequence per sample in input order.
        Otherwise, for every batch, the arrays returned by ``sess.run`` are
        appended in fetch order (grouped per batch, not per sample). Returns
        ``[]`` when ``data`` holds no samples.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        # A non-positive step would never advance through the samples.
        raise ValueError('batch_size must be positive, got %r' % (batch_size,))
    # Materialize so len() and slicing work on both Python 2 and Python 3.
    samples = list(zip(*data))
    if not samples:
        return []
    num_items = len(data)
    in_len = len(data[0][0])
    n_inputs = num_items * in_len + decode_len
    # Copy the slice so appending never touches the caller's ``model``.
    input_v = list(model[:n_inputs])
    input_v.append(decoding)
    if dr is not None:
        input_v.append(dr)
    predictions = model[n_inputs:]

    n_samples = len(samples)
    output = []
    for start in range(0, n_samples, batch_size):
        if verbose:
            print('%d' % (start * 100 // n_samples) + '%')
        batch = samples[start:start + batch_size]
        cur_size = len(batch)  # the final batch may be shorter than batch_size
        holders = []
        # Per stream: (cur_size, in_len) -> transpose -> one row of
        # cur_size ids per time step, matching the encoder input order.
        for stream in zip(*batch):
            holders.extend(np.asarray(stream).T)
        # Decoder inputs are fed as all-zero id vectors at inference time.
        holders.extend(np.zeros(cur_size, dtype='int32') for _ in range(decode_len))
        holders.append(True)
        if dr is not None:
            holders.append(0.0)
        pre = sess.run(predictions, feed_dict=dict(zip(input_v, holders)))
        if argmax:
            # Assumes each fetched array is (cur_size, vocab) -- TODO confirm.
            # Stack per-step ids, then transpose to one sequence per sample.
            ids = np.asarray([np.argmax(step, axis=1) for step in pre]).T.tolist()
            output.extend(np.trim_zeros(seq) for seq in ids)
        else:
            output += pre
    return output
评论列表
文章目录