def _centered_bias_step(logits_dimension, weight_collection, labels,
                        train_loss_fn):
  """Builds the training op that fits the centered-bias variable on its own.

  The first variable in `weight_collection` is broadcast to every example in
  the batch and used directly as the logits. The mean of `train_loss_fn` over
  those logits is then minimized with Adagrad.

  Args:
    logits_dimension: Number of logits per example. This is the second
      dimension the tiled bias is reshaped to. Because `tile` is called with a
      single multiple and the result is reshaped to
      `[batch_size, logits_dimension]`, the bias variable must be rank-1 with
      exactly `logits_dimension` elements.
    weight_collection: Key of the graph collection holding the bias
      variable(s). Only element 0 contributes to the logits.
    labels: Label tensor. Only its leading dimension is used to determine
      the batch size; it is also forwarded to `train_loss_fn`.
    train_loss_fn: Callable invoked as `train_loss_fn(logits, labels)`; its
      output is averaged with `reduce_mean`.

  Returns:
    The op returned by `AdagradOptimizer(0.1).minimize(...)`.

  Raises:
    IndexError: If `weight_collection` is empty.
  """
  bias_vars = ops.get_collection(weight_collection)
  # Number of rows in this batch, taken from the labels' first dimension.
  num_examples = array_ops.shape(labels)[0]
  bias = bias_vars[0]
  # Repeat the bias once per example, then fold it back into a
  # [num_examples, logits_dimension] matrix of logits.
  tiled_bias = array_ops.tile(bias, [num_examples])
  bias_logits = array_ops.reshape(tiled_bias, [num_examples, logits_dimension])
  with ops.name_scope(None, "centered_bias", (labels, bias_logits)):
    per_example_loss = train_loss_fn(bias_logits, labels)
    bias_loss = math_ops.reduce_mean(per_example_loss, name="training_loss")
  # Train the centered bias with its own optimizer. A learning rate of 0.1 is
  # a conservative choice when updating a single variable. Every variable in
  # the collection is passed as var_list, although only element 0 feeds the
  # loss above.
  optimizer = training.AdagradOptimizer(0.1)
  return optimizer.minimize(bias_loss, var_list=bias_vars)
评论列表
文章目录