def decode_model_with_buckets(encoder_inputs, decoder_inputs, targets, weights, buckets, seq2seq,
                              softmax_loss_function=None,
                              per_example_loss=False,
                              name=None):
    """Build a sequence-to-sequence model for every bucket and return its outputs and states.

    Unlike a training-time bucketed model, this function computes no loss: for
    each bucket it calls ``seq2seq`` on the inputs truncated to that bucket's
    sizes and collects whatever pair the call returns.

    The seq2seq argument is a function that defines a sequence-to-sequence model,
    e.g., seq2seq = lambda x, y: basic_rnn_seq2seq(x, y, rnn_cell.GRUCell(24))

    Args:
      encoder_inputs: A list of Tensors to feed the encoder; first seq2seq input.
      decoder_inputs: A list of Tensors to feed the decoder; second seq2seq input.
      targets: A list of Tensors (desired output sequence). Only its length is
        validated here; it is otherwise used just as a name-scope input.
      weights: A list of Tensors to weight the targets. Only its length is
        validated here; it is otherwise used just as a name-scope input.
      buckets: A list of pairs of (input size, output size) for each bucket.
        Only the last bucket is used for length validation, so it is expected
        to be the largest one.
      seq2seq: A sequence-to-sequence model function; it takes 2 inputs that
        agree with encoder_inputs and decoder_inputs, and returns a pair
        consisting of outputs and states (as, e.g., basic_rnn_seq2seq).
      softmax_loss_function: Not used by this function; kept in the signature
        for compatibility with callers.
      per_example_loss: Not used by this function; kept in the signature for
        compatibility with callers.
      name: Optional name for this operation, defaults to "model_with_buckets".

    Returns:
      A tuple of the form (outputs, states), where:
        outputs: List with one entry per bucket; the j'th entry is the first
          element of the pair returned by seq2seq for bucket j.
        states: List with one entry per bucket; the j'th entry is the second
          element of the pair returned by seq2seq for bucket j.

    Raises:
      ValueError: If the length of encoder_inputs, targets, or weights is
        smaller than required by the last bucket.
    """
    # Validate against the last bucket only (it is assumed to be the largest).
    if len(encoder_inputs) < buckets[-1][0]:
        raise ValueError("Length of encoder_inputs (%d) must be at least that of "
                         "last bucket (%d)." % (len(encoder_inputs), buckets[-1][0]))
    if len(targets) < buckets[-1][1]:
        raise ValueError("Length of targets (%d) must be at least that of last "
                         "bucket (%d)." % (len(targets), buckets[-1][1]))
    if len(weights) < buckets[-1][1]:
        raise ValueError("Length of weights (%d) must be at least that of last "
                         "bucket (%d)." % (len(weights), buckets[-1][1]))

    all_inputs = encoder_inputs + decoder_inputs + targets + weights
    states = []
    outputs = []
    with ops.name_scope(name, "model_with_buckets", all_inputs):
        for j, bucket in enumerate(buckets):
            # The first bucket runs with reuse=None; every later bucket runs
            # with reuse=True so the same variables are shared across buckets.
            with variable_scope.variable_scope(variable_scope.get_variable_scope(),
                                               reuse=True if j > 0 else None):
                bucket_outputs, bucket_states = seq2seq(encoder_inputs[:bucket[0]],
                                                        decoder_inputs[:bucket[1]])
                states.append(bucket_states)
                outputs.append(bucket_outputs)
    return outputs, states
# Source file: grl_seq2seq.py (Python)
# Scraped page metadata: reads 23, favorites 0, likes 0, comments 0.