def __init__(self, inputs, outputs, updates=None, defaults=None,
             training=None):
    """Collect inputs, outputs, defaults and update ops for a graph function.

    Parameters
    ----------
    inputs : Mapping, single tensor-like, or (nested) tuple/list of them
        If a Mapping, its keys are stored as ``self.inputs_name`` and its
        values are used as the inputs. Otherwise each input's name is taken
        from ``i.name`` with everything from the first ``':'`` removed.
        The inputs are flattened with ``flatten_list`` into ``self.inputs``.
    outputs : single tensor-like or (nested) tuple/list of them
        Flattened into ``self.outputs``. ``self.return_list`` is True when a
        tuple/list was passed and False when a single object was passed.
    updates : Mapping, iterable, ``tf.Operation`` or None
        If a ``tf.Operation``, it is stored unchanged as ``self.updates_ops``.
        Otherwise, each ``(p, new_p)`` tuple/list item becomes
        ``tf.assign(p, new_p)`` and any other item is used as-is. All of
        them are combined with ``tf.group`` into ``self.updates_ops``.
        Mappings are treated as their ``(key, value)`` items. ``None`` (the
        default) means no updates.
    defaults : Mapping, iterable of pairs, or None
        Copied into a new dict stored as ``self.defaults``. ``None`` (the
        default) gives an empty dict.
    training : optional
        Stored unchanged as ``self.training``.
    """
    self.training = training
    # ====== validate input ====== #
    inputs_name = None
    if isinstance(inputs, Mapping):
        # Materialize as lists: on Python 3 keys()/values() are live views,
        # not lists, and would follow later mutation of the caller's dict.
        inputs_name = list(inputs.keys())
        inputs = list(inputs.values())
    elif not isinstance(inputs, (tuple, list)):
        inputs = [inputs]
    self.inputs = flatten_list(inputs, level=None)
    if inputs_name is None:
        # Keep only the part before ':' (presumably a TF tensor name such
        # as "x:0" -> "x").
        inputs_name = [i.name.split(':')[0] for i in self.inputs]
    # NOTE(review): if Mapping values are nested lists, flatten_list may
    # yield more inputs than there are names -- confirm callers never do this.
    self.inputs_name = inputs_name
    # ====== defaults ====== #
    # Always copy so the caller's object is never aliased.
    self.defaults = {} if defaults is None else dict(defaults)
    # ====== validate outputs ====== #
    return_list = isinstance(outputs, (tuple, list))
    if not return_list:
        outputs = (outputs,)
    self.outputs = flatten_list(list(outputs), level=None)
    self.return_list = return_list
    # ====== validate updates ====== #
    if updates is None:
        updates = []
    elif isinstance(updates, Mapping):
        updates = list(updates.items())
    # Ops created inside this context (the tf.assign ops and the tf.group
    # op) get control dependencies on self.outputs.
    with tf.control_dependencies(self.outputs):
        if isinstance(updates, tf.Operation):
            # Already a TensorFlow op: use it unchanged.
            self.updates_ops = updates
        else:
            updates_ops = []
            for update in updates:
                if isinstance(update, (tuple, list)):
                    p, new_p = update
                    updates_ops.append(tf.assign(p, new_p))
                else:  # assumed to already be an assign/update op
                    updates_ops.append(update)
            self.updates_ops = tf.group(*updates_ops)
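# NOTE(review): stray web-page navigation text left over from extraction, not code:
# "评论列表" (comment list), "文章目录" (table of contents)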