def model_with_buckets(encoder_inputs, decoder_inputs, targets, weights, buckets, vocab_size, batch_size, seq2seq,
                       output_projection=None, softmax_loss_function=None, per_example_loss=False, name=None):
    """Build a seq2seq model and its loss once per bucket.

    For every ``(encoder_size, decoder_size)`` entry in ``buckets`` the
    ``seq2seq`` callable is invoked on the matching prefixes of
    ``encoder_inputs`` and ``decoder_inputs``, and a loss is computed against
    the matching prefixes of ``targets`` and ``weights``.

    Args:
        encoder_inputs: Sequence of encoder inputs; must be at least as long
            as the encoder size of the last bucket.
        decoder_inputs: Sequence of decoder inputs; sliced to each bucket's
            decoder size (its length is not validated here).
        targets: Sequence of targets; must be at least as long as the decoder
            size of the last bucket.
        weights: Sequence of weights; must be at least as long as the decoder
            size of the last bucket.
        buckets: Sequence of ``(encoder_size, decoder_size)`` pairs. Only the
            last entry is used for the length checks, so it is presumably the
            largest bucket -- verify against the caller.
        vocab_size: Accepted for interface compatibility; not used by the
            current body.
        batch_size: Accepted for interface compatibility; not used by the
            current body.
        seq2seq: Callable taking ``(encoder_inputs, decoder_inputs)`` and
            returning a 3-tuple ``(outputs, decoder_states, encoder_state)``.
        output_projection: Accepted for interface compatibility; not used by
            the current body.
        softmax_loss_function: Forwarded unchanged to the loss function.
        per_example_loss: If true, use ``sequence_loss_by_example``; otherwise
            use ``sequence_loss``.
        name: Optional name passed to ``ops.name_scope``.

    Returns:
        A tuple ``(outputs, losses, encoder_states)`` of three lists, each
        holding one entry per bucket, in bucket order.

    Raises:
        ValueError: If ``encoder_inputs``, ``targets`` or ``weights`` is
            shorter than the corresponding size of the last bucket.
    """
    last_bucket = buckets[-1]
    if len(encoder_inputs) < last_bucket[0]:
        raise ValueError("Length of encoder_inputs (%d) must be at least that of "
                         "last bucket (%d)." % (len(encoder_inputs), last_bucket[0]))
    if len(targets) < last_bucket[1]:
        raise ValueError("Length of targets (%d) must be at least that of "
                         "last bucket (%d)." % (len(targets), last_bucket[1]))
    if len(weights) < last_bucket[1]:
        raise ValueError("Length of weights (%d) must be at least that of "
                         "last bucket (%d)." % (len(weights), last_bucket[1]))

    all_inputs = encoder_inputs + decoder_inputs + targets + weights
    losses = []
    outputs = []
    encoder_states = []
    with ops.name_scope(name, "model_with_buckets", all_inputs):
        for j, bucket in enumerate(buckets):
            # Re-enter the current variable scope; every bucket after the
            # first is built with reuse=True.
            with variable_scope.variable_scope(variable_scope.get_variable_scope(),
                                               reuse=True if j > 0 else None):
                # The decoder states returned by seq2seq are not needed here.
                bucket_outputs, _, encoder_state = seq2seq(encoder_inputs[:bucket[0]],
                                                           decoder_inputs[:bucket[1]])
                outputs.append(bucket_outputs)
                encoder_states.append(encoder_state)

                bucket_targets = targets[:bucket[1]]
                bucket_weights = weights[:bucket[1]]
                if per_example_loss:
                    losses.append(sequence_loss_by_example(
                        bucket_outputs, bucket_targets, bucket_weights,
                        softmax_loss_function=softmax_loss_function))
                else:
                    losses.append(sequence_loss(
                        bucket_outputs, bucket_targets, bucket_weights,
                        softmax_loss_function=softmax_loss_function))
    return outputs, losses, encoder_states
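# Comment list (web-page navigation text left over from the source article)
# Article table of contents (web-page navigation text left over from the source article)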