def set_param_values(self, flattened_params, **tags):
    """Write new values into the parameters selected by ``tags``.

    ``flattened_params`` is split back into per-parameter values by
    ``unflatten_tensors`` using the shapes from ``self.get_param_shapes``.
    Every parameter gets a placeholder and an assign op that are created
    on first use and cached on the instance, so later calls reuse the same
    graph nodes. All assignments run in one ``run`` call on the default
    TensorFlow session.

    Args:
        flattened_params: Flat container holding all selected parameter
            values (presumably a 1-D numpy array -- confirm against
            ``unflatten_tensors``).
        **tags: Forwarded to ``get_params``, ``get_param_shapes`` and
            ``get_param_dtypes`` to select parameters. The special key
            ``debug`` is removed before forwarding; when truthy, the name
            of each parameter is printed as its value is staged.
    """
    # "debug" is a control flag, not a selection tag, so strip it first.
    verbose = tags.pop("debug", False)

    shapes = self.get_param_shapes(**tags)
    unflattened = unflatten_tensors(flattened_params, shapes)
    variables = self.get_params(**tags)
    target_dtypes = self.get_param_dtypes(**tags)

    assign_ops = []
    feeds = {}
    for var, target_dtype, new_value in zip(
            variables, target_dtypes, unflattened):
        if var not in self._cached_assign_ops:
            # First time this variable is seen: build its placeholder and
            # assign op once and remember both for subsequent calls.
            placeholder = tf.placeholder(dtype=var.dtype.base_dtype)
            self._cached_assign_ops[var] = tf.assign(var, placeholder)
            self._cached_assign_placeholders[var] = placeholder
        assign_ops.append(self._cached_assign_ops[var])
        # Cast the value to the dtype reported for this parameter.
        feeds[self._cached_assign_placeholders[var]] = (
            new_value.astype(target_dtype))
        if verbose:
            print("setting value of %s" % var.name)

    # Apply every staged assignment in a single session call.
    session = tf.get_default_session()
    session.run(assign_ops, feed_dict=feeds)
评论列表
文章目录