def lr_grad_norm_avg(self):
    """Refresh the smoothed, log-domain statistics of ``lr * ||grad||``.

    This is a safeguard that keeps ``lr * grad_norm`` from increasing
    dramatically when training becomes unstable; it is not required for
    basic use.

    Reads ``grad_norm_squared`` from ``self._global_state`` and updates
    these entries of the same dict in place:

    * ``grad_norm_squared_avg_log`` -- exponential moving average (decay
      ``self._beta``) of ``log(grad_norm_squared + eps)``.
    * ``lr_grad_norm_avg`` -- exponential moving average of
      ``log(self._lr * sqrt(grad_norm_squared) + eps)``.
    * ``lr_grad_norm_avg_min`` -- the smallest value observed so far of
      ``exp(lr_grad_norm_avg / self.zero_debias_factor())``.

    Returns nothing. ``np`` and ``eps`` are expected to be module-level
    names defined elsewhere in the file.
    """
    state = self._global_state
    decay = self._beta

    # The absence of "lr_grad_norm_avg" marks the very first invocation;
    # the same flag seeds both running averages.
    first_call = "lr_grad_norm_avg" not in state

    if first_call:
        state['grad_norm_squared_avg_log'] = 0.0
    smoothed_log_sq = state['grad_norm_squared_avg_log'] * decay
    smoothed_log_sq += (1 - decay) * np.log(state['grad_norm_squared'] + eps)
    state['grad_norm_squared_avg_log'] = smoothed_log_sq

    # On the first call the previous average is taken to be 0.0.
    previous_avg = 0.0 if first_call else state["lr_grad_norm_avg"]
    state["lr_grad_norm_avg"] = (
        previous_avg * decay
        + (1 - decay)
        * np.log(self._lr * np.sqrt(state['grad_norm_squared']) + eps)
    )

    # We monitor the minimal smoothed ||lr * grad||. The debias factor is
    # queried only after the average above has been stored.
    debiased = np.exp(state["lr_grad_norm_avg"] / self.zero_debias_factor())
    if first_call:
        state["lr_grad_norm_avg_min"] = debiased
    else:
        state["lr_grad_norm_avg_min"] = min(
            state["lr_grad_norm_avg_min"], debiased)
评论列表
文章目录