def testOpsCopy(self):
    """Check that an op copied into another graph computes the same value.

    Builds ``y = a * x + b`` in ``graph1`` (``a`` a Variable initialized to
    3.0, ``b`` a constant 4.0, ``x`` a float placeholder). It then copies
    ``a`` and ``y`` into ``graph2`` with ``copy_elements`` and checks that
    both graphs produce the same result when fed ``x = 5``.

    Both sessions are closed on exit, even if copying or evaluation fails.

    Raises:
        AssertionError: if the copied expression evaluates differently
            from the original.
    """
    sess1 = None
    sess2 = None
    try:
        with graph1.as_default():
            # Build the basic expression y = a*x + b in graph1.
            x = array_ops.placeholder("float")
            a = variables.Variable(3.0)
            b = constant_op.constant(4.0)
            ax = math_ops.multiply(x, a)
            y = math_ops.add(ax, b)
            # Created while graph1 is the default graph, so it runs graph1.
            sess1 = session_lib.Session()
            variables.global_variables_initializer().run(session=sess1)

        # Copy the variable a into graph2 first; the copy (a1) is handed to
        # copy_op_to_graph below along with y.
        a1 = copy_elements.copy_variable_to_graph(a, graph2)

        with graph2.as_default():
            # Created while graph2 is the default graph, so it runs graph2.
            sess2 = session_lib.Session()
            # Initialize the copied variable a1 in graph2.
            variables.global_variables_initializer().run(session=sess2)

        # Copy y into graph2, supplying a1 as the already-copied variable.
        y1 = copy_elements.copy_op_to_graph(y, graph2, [a1])

        # Copying y also copied its input placeholder x; fetch that copy so
        # it can be fed.
        x1 = copy_elements.get_copied_op(x, graph2)

        # Evaluate the original and the copy for the same sample input.
        v1 = y.eval({x: 5}, session=sess1)
        v2 = y1.eval({x1: 5}, session=sess2)
    finally:
        # Release session resources even if the test fails part-way.
        for sess in (sess1, sess2):
            if sess is not None:
                sess.close()

    # Raise explicitly rather than using a bare ``assert``, which is removed
    # under ``python -O`` and would silently disable this check.
    if v1 != v2:
        raise AssertionError(
            "copied op evaluated to %r but the original evaluated to %r"
            % (v2, v1))
copy_test.py 文件源码
python
阅读 18
收藏 0
点赞 0
评论 0
评论列表
文章目录