def __init__(self, inputs, outputs, updates=None):
    """Store graph inputs/outputs and build a single grouped update op.

    Args:
        inputs: list or tuple of input tensors; stored as a new list in
            ``self.inputs``.
        outputs: list or tuple of output tensors; stored as a new list in
            ``self.outputs``.
        updates: optional list or tuple of updates (defaults to none). Each
            element is either a ``(variable, new_value)`` tuple, which is
            turned into a ``tf.assign`` op, or any other object, which is
            assumed to already be an op and is passed to ``tf.group`` as-is.

    Raises:
        TypeError: if ``inputs``, ``outputs`` or ``updates`` is not a list
            or tuple (subclasses are accepted).
        ValueError: if a tuple update does not have exactly two elements.
    """
    if updates is None:
        # ``None`` default instead of ``[]`` so no list object is shared
        # between calls.
        updates = []
    # Explicit raises instead of ``assert`` so validation survives ``-O``.
    if not isinstance(inputs, (list, tuple)):
        raise TypeError('Input to a TensorFlow backend function should be a list or tuple.')
    if not isinstance(outputs, (list, tuple)):
        raise TypeError('Output to a TensorFlow backend function should be a list or tuple.')
    if not isinstance(updates, (list, tuple)):
        raise TypeError('Updates in a TensorFlow backend function should be a list or tuple.')
    self.inputs = list(inputs)
    self.outputs = list(outputs)
    # Ops created inside this context (the assign ops and the final group
    # op) get a control dependency on the outputs, so the updates only run
    # after every output has been computed.
    with tf.control_dependencies(self.outputs):
        updates_ops = []
        for update in updates:
            if isinstance(update, tuple):
                # (variable, new_value) pair -> explicit assignment op.
                p, new_p = update
                updates_ops.append(tf.assign(p, new_p))
            else:
                # Assumed to already be an op; not validated here.
                updates_ops.append(update)
        self.updates_op = tf.group(*updates_ops)
评论列表
文章目录