def loss(self, inf_targets, inf_vads, targets, vads, mtl_fac):
    '''Build the training loss graph.

    Only the speech-inference (regression) term is used. ``inf_vads``,
    ``vads`` and ``mtl_fac`` are accepted for interface compatibility but
    are currently ignored. A VAD cross-entropy term could be added here
    (presumably weighted by ``mtl_fac`` -- confirm against callers).

    Args:
        inf_targets: Inferred (predicted) targets tensor.
        inf_vads: Inferred VAD tensor (unused).
        targets: Reference targets tensor; must be subtractable from
            ``inf_targets``.
        vads: Reference VAD tensor (unused).
        mtl_fac: Multi-task factor (unused).

    Returns:
        A tuple ``(loss_v, loss_o)``:
        ``loss_v`` is the data loss plus every tensor in the graph's
        REGULARIZATION_LOSSES collection (if any), and ``loss_o`` is the
        data loss alone.
    '''
    # tf.nn.l2_loss computes sum(t ** 2) / 2, so this is half the summed
    # squared error, normalized by the batch size.
    loss_o = tf.nn.l2_loss(inf_targets - targets) / self.batch_size

    # Regularization losses registered in the graph collection.
    reg_loss = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
    # tf.add_n rejects an empty list, so only add the penalty when at least
    # one regularization loss was registered.
    if reg_loss:
        loss_v = loss_o + tf.add_n(reg_loss)
    else:
        loss_v = loss_o

    # tf.scalar_summary was removed in TF 1.0 in favour of
    # tf.summary.scalar; use whichever this TensorFlow build provides.
    scalar_summary = getattr(tf, 'scalar_summary', None) or tf.summary.scalar
    scalar_summary('loss', loss_v)
    return loss_v, loss_o
SENN.py 文件源码
python
阅读 32
收藏 0
点赞 0
评论 0
评论列表
文章目录