def loss_wrapper(y, y_, loss_function, transitions=None, nums_tags=None, batch_size=None, weights=None, average_cross_steps=True):
    """Apply ``loss_function`` to each (prediction, target) pair and stack the results.

    Dispatches on the *identity* of ``loss_function`` so each supported loss
    receives exactly the extra arguments it needs; any other callable is
    invoked as ``loss_function(sy, sy_)``.

    Args:
        y: sequence of predicted score tensors, one per output head.
        y_: sequence of gold-label tensors, parallel to ``y``.
        loss_function: one of ``crf_loss``, ``cross_entropy``,
            ``sparse_cross_entropy``, ``sparse_cross_entropy_with_weights``,
            or any ``(sy, sy_)`` callable.
        transitions: per-head CRF transition matrices (``crf_loss`` only).
        nums_tags: per-head tag-set sizes (``crf_loss`` / ``cross_entropy``).
        batch_size: batch size; required by ``crf_loss``.
        weights: optional weights forwarded to
            ``sparse_cross_entropy_with_weights``.
        average_cross_steps: forwarded to ``sparse_cross_entropy_with_weights``.

    Returns:
        ``tf.stack`` of the per-head losses.
    """
    assert len(y) == len(y_)
    total_loss = []
    if loss_function is crf_loss:
        assert len(y) == len(transitions) == len(nums_tags) and batch_size is not None
        for sy, sy_, s_transition, s_nums_tags in zip(y, y_, transitions, nums_tags):
            total_loss.append(loss_function(sy, sy_, s_transition, s_nums_tags, batch_size))
    elif loss_function is cross_entropy:
        assert len(y) == len(nums_tags)
        for sy, sy_, s_nums_tags in zip(y, y_, nums_tags):
            total_loss.append(loss_function(sy, sy_, s_nums_tags))
    elif loss_function is sparse_cross_entropy:
        for sy, sy_ in zip(y, y_):
            total_loss.append(loss_function(sy, sy_))
    elif loss_function is sparse_cross_entropy_with_weights:
        # BUG FIX: the original unpacked three names from a two-sequence zip
        # (`for sy, sy_, snums_tags in zip(y, y_)`), raising ValueError on the
        # first iteration. nums_tags is not used by this loss, so iterate pairs
        # only; the stray `assert len(y) == len(nums_tags)` (copy-paste from the
        # cross_entropy branch) is dropped with it.
        for sy, sy_ in zip(y, y_):
            total_loss.append(tf.reshape(
                loss_function(sy, sy_, weights=weights,
                              average_cross_steps=average_cross_steps), [-1]))
    else:
        # Generic fallback: loss takes only (predictions, targets).
        for sy, sy_ in zip(y, y_):
            total_loss.append(tf.reshape(loss_function(sy, sy_), [-1]))
    return tf.stack(total_loss)
# (removed: non-code web-page navigation residue — "comment list" / "article table of contents")