trainer.py source code

python

Project: TerpreT   Author: 51alg
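import tensorflow as tf  # TensorFlow 1.x graph-mode API


def make_test_node(self, hypers_name):
    # Builds an accuracy node for the hyperparameter setting `hypers_name`.
    # A test example counts as correct only if, for every output variable,
    # the argmax of the predicted scores equals the integer label fed
    # through that variable's placeholder (assumed to be int32).
    outputs = self.tf_nodes[hypers_name]["outputs"]

    deltas = []
    for var_name, output_node in outputs.items():
        data_node = self.tf_nodes[hypers_name]["placeholders"][var_name]
        output_rank = output_node.get_shape().ndims
        if output_rank == 1:
            # An unbatched output (one distribution shared by all examples)
            # is tiled to [batch_size, n_values] so it lines up with the labels.
            output_node = tf.tile(tf.expand_dims(output_node, 0),
                                  [tf.shape(data_node)[0], 1])
        # Predicted index minus label: 0 where this variable is predicted correctly.
        deltas.append(
            tf.cast(tf.argmax(output_node, axis=1), tf.int32) - data_node)

    # Sum absolute deltas so that errors of opposite sign in different
    # variables cannot cancel out and make a wrong example look correct.
    zero_if_correct = tf.reduce_sum(tf.abs(tf.stack(deltas)), axis=0)
    zero_elements = tf.equal(zero_if_correct, tf.zeros_like(zero_if_correct))
    n_correct = tf.reduce_sum(tf.cast(zero_elements, tf.int32))
    n_total = tf.shape(zero_if_correct)[0]
    accuracy = tf.truediv(n_correct, n_total)  # int32 / int32 -> float64
    self.summary_nodes["test"] = tf.summary.scalar('test_accuracy', accuracy)
    self.tf_nodes[hypers_name]["accuracy"] = accuracy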