def _initialize_updated_shapes(self, session):
    """Rebuild the shape-dependent optimizer state if variable shapes changed.

    Evaluates the current shapes of ``self._vars`` through ``session``. When
    shapes were recorded before and every recorded shape equals the freshly
    evaluated one, the method returns without touching any state. Otherwise
    it stores the new shapes and rebuilds:

    * ``self._packed_bounds`` -- flat list of ``(lower, upper)`` pairs, one
      per scalar element of every variable, or ``None`` when
      ``self._var_to_bounds`` is ``None``. Variables missing from
      ``self._var_to_bounds`` get ``(-inf, inf)``.
    * ``self._update_placeholders`` / ``self._var_updates`` -- one
      placeholder per variable and an assign op that reshapes the
      placeholder to the variable's shape.
    * ``self._packed_var``, ``self._packed_loss_grad``,
      ``self._packed_equality_grads``, ``self._packed_inequality_grads`` --
      results of ``self._pack`` on the variables and on the gradients of the
      loss and of each constraint.
    * ``self._packing_slices`` -- one ``slice`` per variable into the packed
      vector.

    Args:
        session: Object whose ``run`` method evaluates the shape tensors
            (presumably a TensorFlow session -- confirm against callers).
    """
    shapes = array_ops.shape_n(self._vars)
    var_shapes = list(map(tuple, session.run(shapes)))

    # Nothing to rebuild when every previously recorded shape still matches.
    if self._var_shapes is not None:
        new_old_shapes = zip(self._var_shapes, var_shapes)
        if all(old == new for old, new in new_old_shapes):
            return

    self._var_shapes = var_shapes
    # Materialize the pairs: they are consumed twice (dict construction and
    # the bounds loop), and a bare zip() iterator is exhausted after one pass.
    vars_and_shapes = list(zip(self._vars, self._var_shapes))
    vars_and_shapes_dict = dict(vars_and_shapes)

    packed_bounds = None
    if self._var_to_bounds is not None:
        left_packed_bounds = []
        right_packed_bounds = []
        for var, var_shape in vars_and_shapes:
            shape = list(var_shape)
            # Unbounded unless the caller supplied bounds for this variable.
            bounds = (-np.inf, np.inf)
            if var in self._var_to_bounds:
                bounds = self._var_to_bounds[var]
            # Broadcast each bound to the variable's shape, then flatten so
            # there is one entry per scalar element.
            left_packed_bounds.extend(np.broadcast_to(bounds[0], shape).flat)
            right_packed_bounds.extend(np.broadcast_to(bounds[1], shape).flat)
        packed_bounds = list(zip(left_packed_bounds, right_packed_bounds))
    self._packed_bounds = packed_bounds

    self._update_placeholders = [
        array_ops.placeholder(var.dtype) for var in self._vars
    ]
    self._var_updates = [
        var.assign(array_ops.reshape(placeholder, vars_and_shapes_dict[var]))
        for var, placeholder in zip(self._vars, self._update_placeholders)
    ]

    loss_grads = _compute_gradients(self._loss, self._vars)
    equalities_grads = [
        _compute_gradients(equality, self._vars)
        for equality in self._equalities
    ]
    inequalities_grads = [
        _compute_gradients(inequality, self._vars)
        for inequality in self._inequalities
    ]

    self._packed_var = self._pack(self._vars)
    self._packed_loss_grad = self._pack(loss_grads)
    self._packed_equality_grads = [
        self._pack(equality_grads) for equality_grads in equalities_grads
    ]
    self._packed_inequality_grads = [
        self._pack(inequality_grads) for inequality_grads in inequalities_grads
    ]

    # Element count of each variable, then consecutive offsets into the
    # packed vector.
    # NOTE(review): assumes _accumulate yields running totals starting at 0,
    # so adjacent pairs are (start, end) offsets -- confirm against its
    # definition.
    dims = [_prod(vars_and_shapes_dict[var]) for var in self._vars]
    accumulated_dims = list(_accumulate(dims))
    self._packing_slices = [
        slice(start, end)
        for start, end in zip(accumulated_dims[:-1], accumulated_dims[1:])
    ]
# Comment list
# Table of contents