def test_forward_declarations(self):
    """Check a self-referential block wired up via a forward declaration.

    Builds the expression tree (3.0 + 5.0) + 2.0 out of plain dicts, defines
    a block that refers to itself through ``tdb.ForwardDeclaration``, and
    hands both to ``assertBuilds`` with 10.0 as the expected value.
    """
    # Tiny expression language: every node is a dict tagged by its 'op' key.
    def literal(value):
        return {'op': 'lit', 'val': value}

    def addition(lhs, rhs):
        return {'op': 'add', 'left': lhs, 'right': rhs}

    tree = addition(addition(literal(3.0), literal(5.0)), literal(2.0))

    # The forward declaration is created first and only resolved at the end,
    # so the 'add' case can refer to the full expression block before that
    # block has been constructed.
    py_input = tdt.PyObjectType()
    scalar_output = tdt.TensorType((), 'float32')
    forward = tdb.ForwardDeclaration(py_input, scalar_output)

    # 'lit' nodes: pull out the 'val' entry and feed it to Scalar.
    leaf_block = tdb.GetItem('val') >> tdb.Scalar()

    # 'add' nodes: both operands go through the forward declaration, and the
    # resulting record is piped into tf.add.
    operands = tdb.Record({'left': forward(), 'right': forward()})
    sum_block = operands >> tdb.Function(tf.add)

    # Dispatch on the node's 'op' tag, then close the recursive loop.
    cases = {'lit': leaf_block, 'add': sum_block}
    evaluator = tdb.OneOf(lambda node: node['op'], cases)
    forward.resolve_to(evaluator)

    self.assertBuilds(10.0, evaluator, tree, max_depth=2)
评论列表
文章目录