grl_seq2seq.py source code


Project: Deep-Reinforcement-Learning-for-Dialogue-Generation-in-tensorflow    Author: liuyuemaicha
# Imports required by the function body (TensorFlow 1.x internal modules).
from tensorflow.python.framework import ops
from tensorflow.python.ops import variable_scope


def decode_model_with_buckets(encoder_inputs, decoder_inputs, targets, weights, buckets, seq2seq,
                              softmax_loss_function=None,
                              per_example_loss=False,
                              name=None):
    """Create a sequence-to-sequence model with support for bucketing.

    The seq2seq argument is a function that defines a sequence-to-sequence model,
    e.g., seq2seq = lambda x, y: basic_rnn_seq2seq(x, y, rnn_cell.GRUCell(24))

    Args:
      encoder_inputs: A list of Tensors to feed the encoder; first seq2seq input.
      decoder_inputs: A list of Tensors to feed the decoder; second seq2seq input.
      targets: A list of 1D batch-sized int32 Tensors (desired output sequence).
      weights: List of 1D batch-sized float-Tensors to weight the targets.
      buckets: A list of pairs of (input size, output size) for each bucket.
      seq2seq: A sequence-to-sequence model function; it takes two inputs that
        agree with encoder_inputs and decoder_inputs, and returns a pair
        consisting of outputs and states (as, e.g., basic_rnn_seq2seq).
      softmax_loss_function: Function (inputs-batch, labels-batch) -> loss-batch
        to be used instead of the standard softmax (the default if this is None).
        Not used here, since this decoding variant computes no loss.
      per_example_loss: Boolean. Not used here; retained for interface
        compatibility.
      name: Optional name for this operation, defaults to "model_with_buckets".

    Returns:
      A tuple of the form (outputs, states), where:
        outputs: The outputs for each bucket. Its j'th element consists of a list
          of 2D Tensors of shape [batch_size x num_decoder_symbols].
        states: The states for each bucket, as returned by the seq2seq function.

    Raises:
      ValueError: If the length of encoder_inputs, targets, or weights is smaller
        than the largest (last) bucket.
    """
    if len(encoder_inputs) < buckets[-1][0]:
        raise ValueError("Length of encoder_inputs (%d) must be at least that of la"
                         "st bucket (%d)." % (len(encoder_inputs), buckets[-1][0]))
    if len(targets) < buckets[-1][1]:
        raise ValueError("Length of targets (%d) must be at least that of last "
                         "bucket (%d)." % (len(targets), buckets[-1][1]))
    if len(weights) < buckets[-1][1]:
        raise ValueError("Length of weights (%d) must be at least that of last "
                         "bucket (%d)." % (len(weights), buckets[-1][1]))

    all_inputs = encoder_inputs + decoder_inputs + targets + weights
    states = []
    outputs = []
    with ops.name_scope(name, "model_with_buckets", all_inputs):
        for j, bucket in enumerate(buckets):
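            # Reuse variables after the first bucket so that every bucket
            # shares the same set of model parameters.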
            with variable_scope.variable_scope(variable_scope.get_variable_scope(), reuse=True if j > 0 else None):
                bucket_outputs, bucket_states = seq2seq(encoder_inputs[:bucket[0]],
                                                        decoder_inputs[:bucket[1]])
                states.append(bucket_states)
                outputs.append(bucket_outputs)

    return outputs, states
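
Below is a minimal usage sketch of decode_model_with_buckets. It assumes TensorFlow 1.x; the bucket list, vocabulary size, cell size, and the choice of tf.contrib.legacy_seq2seq.embedding_rnn_seq2seq as the seq2seq function are illustrative assumptions, not taken from the original project.

import tensorflow as tf
from tensorflow.contrib import legacy_seq2seq, rnn

# Illustrative hyperparameters (hypothetical, not from the original project).
buckets = [(5, 10), (10, 15)]   # (encoder length, decoder length) per bucket
vocab_size = 1000
cell_size = 64

# One placeholder per time step, sized for the largest (last) bucket.
encoder_inputs = [tf.placeholder(tf.int32, [None], name="encoder%d" % i)
                  for i in range(buckets[-1][0])]
decoder_inputs = [tf.placeholder(tf.int32, [None], name="decoder%d" % i)
                  for i in range(buckets[-1][1])]
targets = [tf.placeholder(tf.int32, [None], name="target%d" % i)
           for i in range(buckets[-1][1])]
weights = [tf.placeholder(tf.float32, [None], name="weight%d" % i)
           for i in range(buckets[-1][1])]

cell = rnn.GRUCell(cell_size)

def seq2seq_f(enc, dec):
    # Any function returning an (outputs, states) pair works here;
    # embedding_rnn_seq2seq is used purely as an example.
    return legacy_seq2seq.embedding_rnn_seq2seq(
        enc, dec, cell,
        num_encoder_symbols=vocab_size,
        num_decoder_symbols=vocab_size,
        embedding_size=cell_size,
        feed_previous=True)

outputs, states = decode_model_with_buckets(
    encoder_inputs, decoder_inputs, targets, weights, buckets, seq2seq_f)
# outputs[j] and states[j] belong to buckets[j].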