convert_to_records.py source code

python

Project: LSTM_PIT    Author: snsun
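import os

import numpy as np
import tensorflow as tf  # TF 1.x API: tf.logging was removed in TF 2.x

# FLAGS (providing data_dir and output_dir) and read_binary_file() are assumed
# to be defined or imported elsewhere in convert_to_records.py.


def convert_cmvn_to_numpy(inputs_cmvn, labels_cmvn):
    """Convert global Kaldi CMVN stats (binary ark) to mean/stddev arrays in .npz format."""
    tf.logging.info("Convert %s and %s to numpy format" % (
        inputs_cmvn, labels_cmvn))
    inputs_filename = os.path.join(FLAGS.data_dir, inputs_cmvn + '.cmvn')
    labels_filename = os.path.join(FLAGS.data_dir, labels_cmvn + '.cmvn')

    # Each stats matrix is 2 x (D+1): row 0 holds per-dimension feature sums
    # with the frame count in its last column, row 1 holds squared sums.
    inputs = read_binary_file(inputs_filename, 0)
    labels = read_binary_file(labels_filename, 0)

    inputs_frame = inputs[0][-1]
    labels_frame = labels[0][-1]

    # Inputs and labels must have been accumulated over the same frames.
    assert inputs_frame == labels_frame

    # Drop the trailing count column, keeping only the D feature dimensions.
    cmvn_inputs = np.hsplit(inputs, [inputs.shape[1]-1])[0]
    cmvn_labels = np.hsplit(labels, [labels.shape[1]-1])[0]

    # mean = sum / N and stddev = sqrt(E[x^2] - mean^2). The variance is
    # clamped at 0 so rounding error on near-constant dims cannot yield NaN.
    mean_inputs = cmvn_inputs[0] / inputs_frame
    stddev_inputs = np.sqrt(
        np.maximum(cmvn_inputs[1] / inputs_frame - mean_inputs ** 2, 0.0))
    mean_labels = cmvn_labels[0] / labels_frame
    stddev_labels = np.sqrt(
        np.maximum(cmvn_labels[1] / labels_frame - mean_labels ** 2, 0.0))

    cmvn_name = os.path.join(FLAGS.output_dir, "train_cmvn.npz")
    np.savez(cmvn_name,
             mean_inputs=mean_inputs,
             stddev_inputs=stddev_inputs,
             mean_labels=mean_labels,
             stddev_labels=stddev_labels)

    tf.logging.info("Write to %s" % cmvn_name)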
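The function turns a global CMVN accumulator into per-dimension mean and standard deviation. The sketch below reproduces that arithmetic without Kaldi or TensorFlow, assuming the standard Kaldi 2 x (D+1) stats layout (row 0: feature sums plus frame count in the last column; row 1: squared sums). `cmvn_stats_to_mean_std` and `apply_cmvn` are hypothetical helper names, not part of LSTM_PIT, and how the project actually applies the saved statistics downstream is an assumption here.

```python
import io

import numpy as np


def cmvn_stats_to_mean_std(stats):
    """Hypothetical helper: Kaldi-style 2 x (D+1) stats -> (mean, stddev)."""
    stats = np.asarray(stats, dtype=np.float64)
    count = stats[0, -1]      # number of accumulated frames
    sums = stats[0, :-1]      # sum_t x_t per dimension
    sq_sums = stats[1, :-1]   # sum_t x_t**2 per dimension
    mean = sums / count
    # Var[x] = E[x^2] - E[x]^2, clamped so rounding cannot give NaN.
    var = np.maximum(sq_sums / count - mean ** 2, 0.0)
    return mean, np.sqrt(var)


def apply_cmvn(feats, mean, stddev, eps=1e-8):
    """Hypothetical helper: normalize (frames x D) features to zero mean, unit variance."""
    return (feats - mean) / (stddev + eps)


# Build stats the way a global accumulator would, from synthetic features.
rng = np.random.default_rng(0)
feats = rng.normal(loc=3.0, scale=2.0, size=(500, 4))
stats = np.zeros((2, feats.shape[1] + 1))
stats[0, :-1] = feats.sum(axis=0)
stats[0, -1] = feats.shape[0]
stats[1, :-1] = (feats ** 2).sum(axis=0)

mean, stddev = cmvn_stats_to_mean_std(stats)

# Round-trip through .npz with the same key names the function writes.
buf = io.BytesIO()
np.savez(buf, mean_inputs=mean, stddev_inputs=stddev)
buf.seek(0)
loaded = np.load(buf)
normed = apply_cmvn(feats, loaded["mean_inputs"], loaded["stddev_inputs"])
```

The stddev computed this way is the population (ddof=0) standard deviation, so it matches `np.std` on the raw frames; the small `eps` in `apply_cmvn` only guards against division by zero on constant dimensions.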