def loss(preturns, lambda_preturn, labels):
    """Build the training-loss tensors for the preturn predictions.

    Three mean-squared-error terms are created under the ``loss`` variable
    scope:

    * ``preturns`` against ``labels``;
    * ``lambda_preturn`` against ``labels``;
    * ``preturns`` against ``lambda_preturn`` (the consistency term), with the
      gradient stopped on ``lambda_preturn`` so that this term only produces
      gradients through ``preturns``.

    Args:
        preturns: Tensor of predictions. ``labels`` and ``lambda_preturn`` are
            both expanded at axis 1 before being compared with it, so it
            presumably carries one extra axis at position 1 -- TODO confirm
            against the caller.
        lambda_preturn: Tensor of predictions compared directly with
            ``labels``.
        labels: Tensor of regression targets.

    Returns:
        A ``(total_loss, consistency_loss)`` tuple. ``total_loss`` is the sum
        of the three MSE terms and does not include the ``'losses'``
        collection. ``consistency_loss`` is the consistency term plus the sum
        of every tensor in the ``'losses'`` graph collection.
    """
    with tf.variable_scope('loss'):
        preturns_loss = tf.reduce_mean(
            tf.squared_difference(preturns, tf.expand_dims(labels, 1)))

        lambda_preturn_loss = tf.reduce_mean(
            tf.squared_difference(lambda_preturn, labels))

        # stop_gradient makes lambda_preturn a constant target for this term.
        consistency_loss = tf.reduce_mean(
            tf.squared_difference(
                preturns,
                tf.stop_gradient(tf.expand_dims(lambda_preturn, 1))))

        # Presumably weight-decay terms registered elsewhere -- verify against
        # the model-building code.
        l2_terms = tf.get_collection('losses')

        total_loss = preturns_loss + lambda_preturn_loss + consistency_loss

        # get_collection returns a Python list. Adding the list itself to the
        # scalar would pack it into a tensor and broadcast, yielding a vector
        # (empty when nothing is registered) instead of a scalar loss, so the
        # terms are summed first and skipped entirely when there are none.
        if l2_terms:
            consistency_loss += tf.add_n(l2_terms)
    return total_loss, consistency_loss
评论列表
文章目录