model.py source code

python

Project: tfplus  Author: renmengye
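import tensorflow as tf  # TF 1.x graph-mode API (tf.variable_scope, tf.no_op)


def build_all(self, param_avg=False):
    """Build all nodes: input, model output, loss and optimizer step.

    If param_avg is True, also attach a parameter-averaging update.
    """
    # Guard against building the graph more than once.
    if self._has_built_all:
        raise Exception('Only call build_all or build_eval once.')
    self._has_built_all = True
    with tf.device(self.get_device_fn()):
        with tf.variable_scope(self.name):
            inp_var = self.build_input()
            output_var = self.build(inp_var)
            loss_var = self.build_loss(inp_var, output_var)
            train_step = self.build_optim(loss_var)
            if param_avg:
                # ema_op presumably updates the averaged copies;
                # avg_var holds those averaged variables.
                ema_op, avg_var = self.get_average_var()
                self._avg_var = avg_var
                # Replace train_step with a no-op that only runs after both
                # the optimizer step and the averaging update have run.
                with tf.control_dependencies([train_step, ema_op]):
                    train_step = tf.no_op(name='train_step')
            self.register_var('train_step', train_step)
    return self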
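The param_avg branch calls self.get_average_var(), but its source is not shown here. A plausible reading, which is an assumption, is that it wraps tf.train.ExponentialMovingAverage and returns an update op together with the averaged variables. The sketch below avoids TensorFlow so that it runs anywhere. It illustrates two ideas from build_all: the build-once guard, and bundling a moving-average update into every training step. All names here are hypothetical and not part of tfplus: ParamAverager, TinyModel, train_step(target), and the decay=0.9 and lr=0.25 values. The update rule shadow = decay * shadow + (1 - decay) * value matches the one ExponentialMovingAverage uses when num_updates is not set.

```python
class ParamAverager:
    """Exponential moving average (EMA) of named scalar parameters.

    Same update rule as tf.train.ExponentialMovingAverage without its
    optional num_updates schedule:
        shadow = decay * shadow + (1 - decay) * value
    """

    def __init__(self, params, decay=0.9):
        self.decay = decay
        # Shadow copies start from the parameters' current values.
        self.shadow = dict(params)

    def apply(self, params):
        for name, value in params.items():
            self.shadow[name] = (self.decay * self.shadow[name]
                                 + (1.0 - self.decay) * value)


class TinyModel:
    """Hypothetical model mimicking the structure of build_all."""

    def __init__(self, lr=0.25):
        self.lr = lr
        self.params = {"w": 0.0}
        self.averager = None
        self._has_built_all = False

    def build_all(self, param_avg=False):
        # Same guard as the original: building is allowed only once.
        if self._has_built_all:
            raise Exception("Only call build_all or build_eval once.")
        self._has_built_all = True
        if param_avg:
            self.averager = ParamAverager(self.params, decay=0.9)
        return self

    def train_step(self, target):
        # One gradient-descent step on loss = (w - target) ** 2.
        grad = 2.0 * (self.params["w"] - target)
        self.params["w"] -= self.lr * grad
        # The EMA update is bundled into the same step, playing the role of
        # ema_op in tf.control_dependencies([train_step, ema_op]). Running it
        # after the parameter update is a choice made by this sketch.
        if self.averager is not None:
            self.averager.apply(self.params)


model = TinyModel().build_all(param_avg=True)
for _ in range(3):
    model.train_step(target=1.0)
print(model.params["w"], model.averager.shadow["w"])
```

After three steps, w reaches 0.875 while its averaged copy trails at about 0.1955. The averaged copy is smoother but lags behind the raw parameters, which is the usual trade-off when evaluating with averaged weights such as self._avg_var.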