seq2seq_model.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:seq2seq 作者: eske 项目源码 文件源码
def step(self, data, update_model=True, align=False, use_sgd=False, **kwargs):
        if update_model:
            self.dropout_on.run()
        else:
            self.dropout_off.run()

        encoder_inputs, targets, input_length = self.get_batch(data)
        input_feed = {self.targets: targets}

        for i in range(len(self.encoders)):
            input_feed[self.encoder_inputs[i]] = encoder_inputs[i]
            input_feed[self.encoder_input_length[i]] = input_length[i]

        output_feed = {'loss': self.xent_loss}
        if update_model:
            output_feed['update'] = self.update_ops.xent[1] if use_sgd else self.update_ops.xent[0]
        if align:
            output_feed['weights'] = self.attention_weights

        res = tf.get_default_session().run(output_feed, input_feed)
        return namedtuple('output', 'loss weights')(res['loss'], res.get('weights'))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号