trainer.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:TerpreT 作者: 51alg 项目源码 文件源码
def build_computation_graphs(self):
    """Build forward, loss, gradient and update graphs for each usable model.

    Declares the model parameters once, then, for every entry of
    ``self.model_hypers_to_build_graph`` whose key is listed by
    ``self.data.get_hypers_names()``, calls the entry's graph builder and
    constructs the loss, the gradients and the update op.  The wall-clock
    time of each construction stage is printed.

    Side effects:
        * ``self.tf_nodes`` is replaced by a fresh dict mapping each built
          model-hypers key to a dict with the keys ``"inputs"``,
          ``"outputs"``, ``"placeholders"``, ``"loss_nodes"``, ``"loss"``,
          ``"display_loss"``, ``"grads_and_vars"`` and ``"train_op"``.
        * When ``self.do_debug`` is truthy, one entry per built graph is
          appended to ``self.check_ops``.
        * When ``self.make_log`` is truthy, the ``"train"``, ``"validate"``
          and ``"params"`` entries of ``self.summary_nodes`` are written.

    Returns:
        None.
    """
    self.model.declare_params(self.param_init_function)

    self.tf_nodes = {}

    # Ask the data set for its hypers names once instead of once per key.
    available_hypers = self.data.get_hypers_names()
    # ``.items()`` rather than ``.iteritems()``: the latter does not exist on
    # Python 3, while ``.items()`` iterates the same pairs on 2 and 3.
    to_build = {k: v
                for k, v in self.model_hypers_to_build_graph.items()
                if k in available_hypers}

    for model_hypers, build_graph in to_build.items():
        print ("Construct forward graph... ", end="")

        forward_time_start = time.time()
        inputs, outputs = build_graph(self.model)
        loss, display_loss, output_placeholders, mask_placeholders, loss_nodes = \
            self.construct_loss(outputs)
        print ("done in %.2fs." % (time.time() - forward_time_start))

        # A new optimizer is requested for every graph that is built.
        optimizer = self.make_optimizer()

        gradient_time_start = time.time()
        print ("Construct gradient graph... ", end="")
        grads_and_vars = self.compute_gradients(optimizer, loss)
        print ("done in %.2fs." % (time.time() - gradient_time_start))

        gradient_apply_time_start = time.time()
        print ("Construct apply gradient graph... ", end="")
        train_op = self.apply_update(optimizer, grads_and_vars)
        print ("done in %.2fs." % (time.time() - gradient_apply_time_start))

        if self.do_debug:
            check_time_start = time.time()
            print ("Construct check numerics graph... ", end="")
            # NOTE(review): presumably this covers every op currently in the
            # default graph, not only the ones added in this iteration --
            # confirm against the TensorFlow version in use.
            self.check_ops.append(tf.add_check_numerics_ops())
            print ("done in %.2fs." % (time.time() - check_time_start))

        if self.make_log:
            # NOTE(review): these keys do not depend on ``model_hypers``, so
            # each loop iteration overwrites the entries written by the
            # previous one.  "train" and "validate" both wrap the same
            # ``display_loss`` node and differ only in their tag.
            self.summary_nodes["train"] = tf.scalar_summary('train_loss', display_loss)
            self.summary_nodes["validate"] = tf.scalar_summary('validate_loss', display_loss)
            self.summary_nodes["params"] = []
            for p_name, p_node in self.model.params.items():
                # One scalar summary per index along the first dimension.
                n_elements = p_node.get_shape()[0].value
                for i in range(n_elements):
                    self.summary_nodes["params"].append(
                        tf.scalar_summary('%s/%i' % (p_name, i), p_node[i]))

        # Merge everything that may need feeding into a single lookup; on a
        # key clash later updates win (inputs < outputs < masks).
        placeholders = {}
        placeholders.update(inputs)
        placeholders.update(output_placeholders)
        placeholders.update(mask_placeholders)
        self.tf_nodes[model_hypers] = {
            "inputs": inputs,
            "outputs": outputs,
            "placeholders": placeholders,
            "loss_nodes": loss_nodes,
            "loss": loss,
            "display_loss": display_loss,
            "grads_and_vars": grads_and_vars,
            "train_op": train_op
        }
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号