seq2seq_utils.py 文件源码

python
阅读 42 收藏 0 点赞 0 评论 0

项目:tefla 作者: openAGI 项目源码 文件源码
def next_inputs(self, time, outputs, state, sample_ids, name=None):
        """Build the next decoder inputs, substituting embedded samples.

        The parent helper's ``next_inputs`` supplies the baseline
        ``(finished, next_inputs, state)`` triple. Unless every entry of
        ``finished`` is true, entries whose ``sample_ids`` value is greater
        than -1 are replaced by ``self._embedding_fn`` applied to those ids;
        entries with a value of -1 or lower keep the baseline input.

        Args:
            time: Forwarded unchanged to the parent helper.
            outputs: Forwarded unchanged to the parent helper.
            state: Forwarded to the parent helper; the state it returns is
                what this method hands back.
            sample_ids: Integer tensor of sampled ids, where values <= -1
                mean "do not sample here".
                NOTE(review): presumably rank 1 (one id per batch entry),
                since the ``tf.where`` indices are flattened before the
                gather -- confirm against the caller.
            name: Optional name for the enclosing name scope; also
                forwarded to the parent call.

        Returns:
            A ``(finished, next_inputs, state)`` tuple.
        """
        scope_values = [time, outputs, state, sample_ids]
        with tf.name_scope(name, "ScheduledEmbeddingTrainingHelperSample",
                           scope_values):
            parent_helper = super(ScheduledEmbeddingTrainingHelper, self)
            finished, base_next_inputs, state = parent_helper.next_inputs(
                time=time,
                outputs=outputs,
                state=state,
                sample_ids=sample_ids,
                name=name)

            def keep_base_inputs():
                """Return the parent's inputs untouched."""
                return base_next_inputs

            def maybe_sample():
                """Merge embedded sampled ids with the untouched base inputs."""
                # Index tensors (as produced by tf.where) of the entries
                # that were sampled and of those that were not.
                sampled_idx = tf.cast(tf.where(sample_ids > -1), tf.int32)
                untouched_idx = tf.cast(tf.where(sample_ids <= -1), tf.int32)

                # Flattened copies are used for gathering; the unflattened
                # index tensors are what scatter_nd consumes below.
                sampled_rows = tf.reshape(sampled_idx, [-1])
                untouched_rows = tf.reshape(untouched_idx, [-1])

                chosen_ids = tf.gather(sample_ids, sampled_rows)
                untouched_inputs = tf.gather(base_next_inputs, untouched_rows)
                embedded_samples = self._embedding_fn(chosen_ids)

                target_shape = tf.shape(base_next_inputs)
                # The two index sets come from complementary conditions, so
                # adding the two scattered tensors fills each position from
                # exactly one of them.
                from_samples = tf.scatter_nd(indices=sampled_idx,
                                             updates=embedded_samples,
                                             shape=target_shape)
                from_base = tf.scatter_nd(indices=untouched_idx,
                                          updates=untouched_inputs,
                                          shape=target_shape)
                return from_samples + from_base

            every_sequence_done = tf.reduce_all(finished)
            # When everything is finished, skip sampling and pass the
            # parent's inputs through.
            next_inputs = tf.cond(
                every_sequence_done, keep_base_inputs, maybe_sample)
            return (finished, next_inputs, state)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号