transform_test.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:DeepLearning_VirtualReality_BigData_Project 作者: rashmitripathi 项目源码 文件源码
def test_transform(self):
  """A custom op handler can splice extra ops into the transformed graph.

  The handler first copies each op with the stock
  `ge.transform.copy_op_handler`. For ops whose name starts with "Add", it
  also builds a constant "Noise" tensor of shape [10] in the destination
  graph. It adds that tensor to the copied op's first output under the name
  "AddNoise" and reports the sum as the copied op's output. The matchers
  then walk back from "AddNoise_2" in the destination graph. They expect
  each "Add_k" to consume the previous "AddNoise" result, which shows that
  downstream consumers were wired to the noisy tensor.
  """

  def noisy_copy_handler(info, op):
    # Always start from a plain copy of the source op.
    copied_op, copied_outputs = ge.transform.copy_op_handler(info, op)
    if not op.name.startswith("Add"):
      # Non-"Add" ops pass through unchanged.
      return copied_op, copied_outputs
    # Build the extra ops inside the destination graph. The constant is
    # created before the add, matching the order the matchers' names rely on.
    with info.graph_.as_default():
      noise = constant_op.constant(1.0, shape=[10], name="Noise")
      noisy_output = math_ops.add(noise, copied_op.outputs[0], name="AddNoise")
    # Expose the noisy tensor in place of the copied op's outputs.
    return copied_op, [noisy_output]

  transformer = ge.Transformer()
  transformer.transform_op_handler = noisy_copy_handler

  dst_graph = ops.Graph()
  transformer(self.graph, dst_graph, "", "")

  # Expected chain, innermost first: each stage's "Add_k" reads the previous
  # stage's "AddNoise" output.
  stage0 = ge.OpMatcher("AddNoise").input_ops(
      "Noise", ge.OpMatcher("Add").input_ops("Const", "Input"))
  stage1 = ge.OpMatcher("AddNoise_1").input_ops(
      "Noise_1", ge.OpMatcher("Add_1").input_ops("Const_1", stage0))
  stage2 = ge.OpMatcher("AddNoise_2").input_ops(
      "Noise_2", ge.OpMatcher("Add_2").input_ops("Const_2", stage1))

  top_op = ge.select_ops("^AddNoise_2$", graph=dst_graph)[0]
  self.assertTrue(stage2(top_op))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号