def loss(logits, depths, invalid_depths, output_size=55 * 74):
    """Build the masked depth-regression loss and add it to the 'losses' collection.

    Per example, with ``n = output_size`` and
    ``d = logits * invalid_depths - depths * invalid_depths`` (flattened),
    this computes::

        sum(d_i ** 2) / n  -  0.5 * (sum(d_i)) ** 2 / n ** 2

    and then averages the per-example values over the batch. This has the
    form of the scale-invariant error of Eigen et al. (2014) with
    lambda = 0.5.

    NOTE(review): the paper applies that error to log depths; here ``d`` is
    the raw difference of whatever values are passed in -- confirm callers
    pass log depths if scale invariance is intended.
    NOTE(review): ``n`` is the full map size, so masked-out pixels still
    count toward the normalisation (the paper's n is the number of valid
    pixels). Heavily masked examples therefore get a smaller loss --
    confirm this is intended.

    Args:
        logits: Predicted depth map(s); reshaped to ``[-1, output_size]``,
            so the leading (batch) dimension is inferred.
        depths: Ground-truth depth map(s); reshaped the same way.
        invalid_depths: Per-pixel multiplier applied to both prediction and
            target, so pixels where it is 0 contribute ``d_i == 0``.
            Despite the name, it presumably holds 1 for valid and 0 for
            invalid pixels -- confirm against the input pipeline.
        output_size: Number of pixels in one output map. Defaults to
            55 * 74. The original comments map a 304*228 input to a 55*74
            output and a 576*172 input to a 27*142 output.

    Returns:
        The scalar loss tensor. The same tensor is also added to the
        'losses' graph collection, e.g. for
        ``tf.add_n(tf.get_collection('losses'), name='total_loss')``.
    """
    n = output_size
    logits_flat = tf.reshape(logits, [-1, n])
    depths_flat = tf.reshape(depths, [-1, n])
    mask_flat = tf.reshape(invalid_depths, [-1, n])

    # Mask prediction and target separately (same op order as the original)
    # so masked pixels give d_i == 0 and drop out of both sums below.
    predict = tf.multiply(logits_flat, mask_flat)
    target = tf.multiply(depths_flat, mask_flat)
    d = tf.subtract(predict, target)

    sum_square_d = tf.reduce_sum(tf.square(d), 1)  # per-example sum of d_i^2
    square_sum_d = tf.square(tf.reduce_sum(d, 1))  # per-example (sum of d_i)^2
    cost = tf.reduce_mean(
        sum_square_d / n - 0.5 * square_sum_d / math.pow(n, 2)
    )
    tf.add_to_collection('losses', cost)
    return cost
# (Scraped web-page residue from the source article: "Comment list", "Table of contents".)