blocks_test.py 文件源码

python
阅读 39 收藏 0 点赞 0 评论 0

项目:fold 作者: tensorflow 项目源码 文件源码
def test_forward_declarations(self):
    # Define a simple expression data structure
    nlit = lambda x: {'op': 'lit', 'val': x}
    nadd = lambda x, y: {'op': 'add', 'left': x, 'right': y}
    nexpr = nadd(nadd(nlit(3.0), nlit(5.0)), nlit(2.0))

    # Define a recursive block using forward declarations
    expr_fwd = tdb.ForwardDeclaration(tdt.PyObjectType(),
                                      tdt.TensorType((), 'float32'))
    lit_case = tdb.GetItem('val') >> tdb.Scalar()
    add_case = (tdb.Record({'left': expr_fwd(), 'right': expr_fwd()})
                >> tdb.Function(tf.add))
    expr = tdb.OneOf(lambda x: x['op'], {'lit': lit_case, 'add': add_case})
    expr_fwd.resolve_to(expr)

    self.assertBuilds(10.0, expr, nexpr, max_depth=2)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号