decoders.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:text-gan-tensorflow 作者: tokestermw 项目源码 文件源码
def gumbel_decoder_fn(encoder_state, embedding_matrix, output_projections, maximum_length,
                      start_of_sequence_id=2, end_of_sequence_id=3, temperature=1.0,
                      name=None):
    """Build an inference-time decoder function that feeds back soft Gumbel-softmax embeddings.

    The returned closure has the signature
    ``(time, cell_state, cell_input, cell_output, context_state) ->
    (done, next_cell_state, next_cell_input, emit_output, next_context_state)``
    (presumably for ``tf.contrib.seq2seq.dynamic_rnn_decoder`` -- confirm with caller).

    On the first call (``cell_output is None``) every sequence starts from the
    start-of-sequence embedding. On later calls the cell output is projected to
    vocabulary logits, a soft (``hard=False``) Gumbel-softmax sample is drawn, and
    the probability-weighted mixture of embedding rows becomes the next input.

    Args:
        encoder_state: initial decoder cell state. Its first dimension is read as
            the batch size, so it must be a single tensor (not a state tuple).
        embedding_matrix: ``[V, E]`` embedding table; row ``i`` embeds token ``i``.
        output_projections: ``(W, b)`` pair used as ``cell_output @ W + b`` to
            produce ``[B, V]`` logits.
        maximum_length: once ``time`` is greater than this value, every sequence
            is marked done.
        start_of_sequence_id: token id whose embedding seeds the first input.
        end_of_sequence_id: a sequence is marked done at a step where the argmax
            of its logits equals this id.
        temperature: Gumbel-softmax temperature. Assuming ``gumbel_softmax``
            scales logits by ``1 / temperature``, lower values give samples closer
            to one-hot.
        name: optional name scope.

    Returns:
        The ``_decoder_fn`` closure described above.

    Raises:
        ValueError: raised by the closure if ``cell_input`` is not None; this
            decoder function is for inference only.
    """
    with tf.name_scope(name, "gumbel_decoder_fn", [
            encoder_state, embedding_matrix, output_projections, maximum_length,
            start_of_sequence_id, end_of_sequence_id, temperature]):
        batch_size = tf.shape(encoder_state)[0]
        W, b = output_projections
        start_of_sequence_id = ops.convert_to_tensor(start_of_sequence_id, tf.int32)
        end_of_sequence_id = ops.convert_to_tensor(end_of_sequence_id, tf.int32)
        temperature = ops.convert_to_tensor(temperature, tf.float32)
        maximum_length = ops.convert_to_tensor(maximum_length, tf.int32)

    def _decoder_fn(time, cell_state, cell_input, cell_output, context_state):
        with ops.name_scope(name, "gumbel_decoder_fn",
                            [time, cell_state, cell_input, cell_output, context_state]):
            if cell_input is not None:
                # -- cell_input is only supplied in training mode; this fn is inference-only
                raise ValueError("Expected cell_input to be None, but saw: %s" %
                                 cell_input)
            if cell_output is None:
                # -- initial call: nothing is done yet, start from the encoder state
                next_done = array_ops.zeros([batch_size, ], dtype=dtypes.bool)
                next_cell_state = encoder_state
                # -- one start-of-sequence embedding row per batch element: [B, E].
                #    The shape comes from the embedding table rather than the cell
                #    state, so E may differ from the state size and stays static.
                start_ids = tf.fill([batch_size], start_of_sequence_id)
                next_cell_input = tf.gather(embedding_matrix, start_ids)
                emit_output = cell_output
                next_context_state = context_state

            else:
                # -- transition function
                with ops.name_scope(name, "gumbel_output_fn", [W, b, cell_output, end_of_sequence_id, temperature]):
                    # -- output projection parameters usually used for output logits prior to softmax
                    output_logits = tf.add(tf.matmul(cell_output, W), b)  # [B, H] * [H, V] + [V] -> [B, V]

                    # -- stopping criterion: a row is done if the argmax of its logits
                    #    is the end-of-sequence token
                    output_argmax = tf.cast(tf.argmax(output_logits, axis=1), tf.int32)
                    next_done = tf.equal(output_argmax, end_of_sequence_id)

                    # -- sample from the gumbel softmax (aka concrete) distribution;
                    #    the lower the temperature, the spikier (closer to one-hot)
                    output_probs = gumbel_softmax(output_logits, temperature=temperature, hard=False)

                    # -- soft embeddings for the next input
                    next_cell_input = tf.matmul(output_probs, embedding_matrix)  # [B, V] * [V, E] -> [B, E]

                next_cell_state = cell_state
                emit_output = cell_output
                next_context_state = context_state

            # -- force every sequence to finish once time exceeds maximum_length
            next_done = control_flow_ops.cond(math_ops.greater(time, maximum_length),
                                              lambda: array_ops.ones([batch_size, ], dtype=dtypes.bool),
                                              lambda: next_done)

        return next_done, next_cell_state, next_cell_input, emit_output, next_context_state

    return _decoder_fn
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号