def __call__(self, learning_rate):
    """Update the learning rate according to the exponential decay schedule.

    On the first call, the current value held by ``learning_rate`` is
    recorded as the base rate. Every call then increments an internal step
    counter and sets the learning rate to
    ``base_lr * decay_factor ** (-count)``, clamped from below at
    ``self.min_lr``. Once the floor has been hit, the power is no longer
    evaluated and ``self.min_lr`` is used on every subsequent call.

    Parameters
    ----------
    learning_rate : shared variable
        Object exposing ``get_value()`` / ``set_value()``; updated in place.
        Presumably a Theano shared variable — confirm against callers.
    """
    if self._count == 0:
        # Capture the starting rate once; every later step decays from it.
        self._base_lr = learning_rate.get_value()
    self._count += 1
    if self._min_reached:
        new_lr = self.min_lr
    else:
        new_lr = self._base_lr * (self.decay_factor ** (-self._count))
        if new_lr <= self.min_lr:
            # Clamp at the floor (previously this assigned the boolean flag,
            # i.e. 1.0) and skip the power computation on later calls.
            self._min_reached = True
            new_lr = self.min_lr
    # np.cast was removed in NumPy 2.0; asarray(..., dtype=...) yields the
    # same 0-d array of the configured float type.
    learning_rate.set_value(np.asarray(new_lr, dtype=theano.config.floatX))
评论列表
文章目录