seq2seq_utils.py source code

Language: Python

Project: tefla | Author: openAGI
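
The snippet below is the next_inputs method of the ScheduledOutputTrainingHelper class: at each decoding step it decides, per batch row, whether to feed back the base next input produced by the parent helper (teacher forcing) or the decoder's own output, optionally passed through a projection layer (self._next_input_layer) and concatenated with auxiliary inputs. The excerpt is not self-contained; it assumes a TensorFlow 1.x-style module with import tensorflow as tf and a nest utility such as tensorflow.python.util.nest in scope.
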
def next_inputs(self, time, outputs, state, sample_ids, name=None):
        with tf.name_scope(name, "ScheduledOutputTrainingHelperNextInputs",
                           [time, outputs, state, sample_ids]):
            (finished, base_next_inputs, state) = (
                super(ScheduledOutputTrainingHelper, self).next_inputs(
                    time=time,
                    outputs=outputs,
                    state=state,
                    sample_ids=sample_ids,
                    name=name))

            def maybe_sample():
                """Perform scheduled sampling."""

                def maybe_concatenate_auxiliary_inputs(outputs_, indices=None):
                    """Concatenate outputs with auxiliary inputs, if they exist."""
                    if self._auxiliary_input_tas is None:
                        return outputs_

                    next_time = time + 1
                    auxiliary_inputs = nest.map_structure(
                        lambda ta: ta.read(next_time), self._auxiliary_input_tas)
                    if indices is not None:
                        auxiliary_inputs = tf.gather_nd(
                            auxiliary_inputs, indices)
                    return nest.map_structure(
                        lambda x, y: tf.concat((x, y), -1),
                        outputs_, auxiliary_inputs)

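                # No projection layer: wherever sample_ids is True, feed the decoder
                # outputs (plus any auxiliary inputs) back as the next inputs.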
                if self._next_input_layer is None:
                    return tf.where(
                        sample_ids, maybe_concatenate_auxiliary_inputs(outputs),
                        base_next_inputs)

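                # With a projection layer: project only the rows flagged for sampling,
                # then rebuild the full batch below with scatter_nd so the remaining
                # rows keep the base next inputs.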
                where_sampling = tf.cast(
                    tf.where(sample_ids), tf.int32)
                where_not_sampling = tf.cast(
                    tf.where(tf.logical_not(sample_ids)), tf.int32)
                outputs_sampling = tf.gather_nd(outputs, where_sampling)
                inputs_not_sampling = tf.gather_nd(base_next_inputs,
                                                   where_not_sampling)
                sampled_next_inputs = maybe_concatenate_auxiliary_inputs(
                    self._next_input_layer(outputs_sampling), where_sampling)

                base_shape = tf.shape(base_next_inputs)
                return (tf.scatter_nd(indices=where_sampling,
                                      updates=sampled_next_inputs,
                                      shape=base_shape)
                        + tf.scatter_nd(indices=where_not_sampling,
                                        updates=inputs_not_sampling,
                                        shape=base_shape))

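            # Skip sampling entirely once every sequence in the batch has finished.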
            all_finished = tf.reduce_all(finished)
            next_inputs = tf.cond(
                all_finished, lambda: base_next_inputs, maybe_sample)
            return (finished, next_inputs, state)
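
For reference, below is a minimal standalone sketch of the scatter_nd recombination used in the projection-layer branch above. The tensors (sample_ids, sampled_next_inputs, base_next_inputs) are hypothetical toy values, not part of the original file; the snippet runs eagerly in TF 2 or inside a session in TF 1.x.

import tensorflow as tf

# Toy stand-ins: a 4-row batch with 2-dimensional inputs.
sample_ids = tf.constant([True, False, True, False])      # per-row sampling decision
sampled_next_inputs = tf.constant([[1., 1.], [3., 3.]])   # replacement inputs for the sampled rows
base_next_inputs = tf.constant([[9., 9.]] * 4)            # default (teacher-forced) next inputs

where_sampling = tf.cast(tf.where(sample_ids), tf.int32)
where_not_sampling = tf.cast(tf.where(tf.logical_not(sample_ids)), tf.int32)
inputs_not_sampling = tf.gather_nd(base_next_inputs, where_not_sampling)

base_shape = tf.shape(base_next_inputs)
next_inputs = (tf.scatter_nd(where_sampling, sampled_next_inputs, base_shape) +
               tf.scatter_nd(where_not_sampling, inputs_not_sampling, base_shape))
# next_inputs == [[1., 1.], [9., 9.], [3., 3.], [9., 9.]]

The sum is exact because each row index lands in exactly one of the two scatters, so the other scatter contributes zeros for that row.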