def _size_or_dynamic(tensor, static_size):
    """Return ``static_size`` if known, else ``tensor``'s runtime leading size."""
    # A ``None`` static size cannot be fed to tf.assert_equal (it cannot be
    # converted to a tensor), so defer the comparison to graph run time.
    if static_size is not None:
        return static_size
    return tf.shape(tensor)[0]


def tf_static_adem_score(context, model_response, reference_response):
    """Compute ADEM scores after asserting that all inputs share a leading size.

    Each input must have a rank-2 shape. The tuple unpacking of
    ``get_shape().as_list()`` raises ``ValueError`` for any other rank.
    The second dimension of each input is read from the static shape and
    forwarded to ``compute_adem_score``.

    The leading dimensions are compared with ``tf.assert_equal``, and the
    scoring ops are built under a control dependency on those assertions.
    A leading dimension that is statically unknown (``None``) is compared
    using its runtime value instead.

    Args:
        context: rank-2 tensor; presumably (batch, ct_dim) encodings — confirm.
        model_response: rank-2 tensor; presumably (batch, mr_dim) — confirm.
        reference_response: rank-2 tensor; presumably (batch, rr_dim) — confirm.

    Returns:
        The ``(score, M, N)`` triple produced by ``compute_adem_score``.
        NOTE(review): M and N look like the ADEM projection matrices; verify
        against compute_adem_score.
    """
    rr_size, rr_dim = reference_response.get_shape().as_list()
    mr_size, mr_dim = model_response.get_shape().as_list()
    ct_size, ct_dim = context.get_shape().as_list()

    rr_size = _size_or_dynamic(reference_response, rr_size)
    mr_size = _size_or_dynamic(model_response, mr_size)
    ct_size = _size_or_dynamic(context, ct_size)

    checks = [
        tf.assert_equal(rr_size, mr_size, message='responses size not equal'),
        tf.assert_equal(ct_size, mr_size, message='context response size not equal'),
    ]
    # Ops created inside this scope run only after the size checks pass.
    with tf.control_dependencies(checks):
        score, M, N = compute_adem_score(
            context, model_response, reference_response, mr_dim, ct_dim, rr_dim)
    return score, M, N
评论列表
文章目录