def create_ctc_loss(logits, labels, timesteps, label_seq_lengths):
    """Build the mean CTC loss op inside the 'CTC_Loss' variable scope.

    Args:
        logits: Forwarded to ``tf.nn.ctc_loss`` as its second positional
            argument.
        labels: Forwarded to ``tf.nn.ctc_loss`` as its first positional
            argument.
        timesteps: Forwarded to ``tf.nn.ctc_loss`` as its third positional
            argument (presumably the input sequence lengths -- confirm
            against the installed TensorFlow version).
        label_seq_lengths: Only printed for debugging; it does not take
            part in the loss computation.

    Returns:
        The ``tf.reduce_mean`` of the ``tf.nn.ctc_loss`` result, created
        with the op name ``'ctc'``.
    """
    debug_dump = (
        ("Labels shape", labels),
        ("Logits shape", logits),
        ("Labels len shape", label_seq_lengths),
    )
    with tf.variable_scope('CTC_Loss'):
        # Despite the captions, each argument is printed as-is (the object
        # itself, not its ``.shape``), preceded by a blank line.
        for caption, value in debug_dump:
            print()
            print(caption)
            print(value)

        raw_loss = tf.nn.ctc_loss(labels, logits, timesteps)
        return tf.reduce_mean(raw_loss, name='ctc')
# Page text, not code: "评论列表" (comment list)
# Page text, not code: "文章目录" (table of contents)