def create_eval_metric(metric_name: AnyStr) -> mx.metric.EvalMetric:
    """
    Creates an EvalMetric given a metric name.

    :param metric_name: Name of the metric to build; must equal C.ACCURACY or C.PERPLEXITY.
    :return: utils.Accuracy for C.ACCURACY, mx.metric.Perplexity for C.PERPLEXITY.
    :raises ValueError: If metric_name matches neither supported metric.
    """
    # output_names refers to the list of outputs this metric should use to update itself,
    # e.g. the softmax output. ignore_label=C.PAD_ID is passed to both metrics
    # (presumably so padded positions do not count towards the score -- confirm).
    if metric_name == C.ACCURACY:
        return utils.Accuracy(ignore_label=C.PAD_ID, output_names=[C.SOFTMAX_OUTPUT_NAME])
    elif metric_name == C.PERPLEXITY:
        return mx.metric.Perplexity(ignore_label=C.PAD_ID, output_names=[C.SOFTMAX_OUTPUT_NAME])
    else:
        # Include the offending value so misconfigured metric names are easy to diagnose.
        raise ValueError("unknown metric name: {!r}".format(metric_name))
评论列表
文章目录