build.py source code

python

Project: deeppavlov  Author: deepmipt
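import numpy as np
import pandas as pd


def balance_dataset(dataset_0, labels_0, dataset_1, labels_1, ratio=1):
    """Balance dataset_0 with class 1 samples from dataset_1 up to the given ratio.

    Args:
        dataset_0: pandas Series/DataFrame of text samples
        labels_0: pandas Series of labels for dataset_0
        dataset_1: pandas Series/DataFrame of text samples
        labels_1: pandas Series of labels for dataset_1
        ratio: target ratio of class 1 samples to class 0 samples (default 1)

    Returns:
        balanced collection of text samples, corresponding labels
    """
    initial_train_size = dataset_0.shape[0]
    # Positions of the class 1 (insult) samples in dataset_1.
    insult_inds = np.nonzero(np.asarray(labels_1))[0]
    num_insults_0 = np.count_nonzero(np.asarray(labels_0))
    num_insults_1 = len(insult_inds)
    # Class 1 samples needed: ratio * (class 0 count) - (class 1 samples already present).
    # Cast to int because ratio may be a float, and clamp at zero because
    # np.random.randint rejects a negative size.
    num_to_add = max(int(ratio * (initial_train_size - num_insults_0) - num_insults_0), 0)
    # Draw positions with replacement, so the same sample may be added more than once.
    insult_inds_to_add = insult_inds[np.random.randint(low=0, high=num_insults_1, size=num_to_add)]
    # Series.append / DataFrame.append were removed in pandas 2.0; pd.concat replaces them.
    result = pd.concat([dataset_0, dataset_1.iloc[insult_inds_to_add]])
    result_labels = pd.concat([labels_0, labels_1.iloc[insult_inds_to_add]])
    return result, result_labels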
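
The sample-count arithmetic inside balance_dataset can be checked on toy data. The sketch below is a minimal illustration, not part of the deeppavlov source: the helper name num_class1_to_add, the toy label values, and the seeded generator are all assumptions made for the example. It computes how many class 1 samples must be added to reach the requested ratio, draws that many positions with replacement, and confirms the resulting class counts.

```python
import numpy as np
import pandas as pd


def num_class1_to_add(n_total_0, n_class1_0, ratio=1.0):
    """Hypothetical helper: class 1 samples needed so that class1 / class0 == ratio."""
    n_class0 = n_total_0 - n_class1_0
    # Clamp at zero when the first dataset already meets or exceeds the ratio.
    return max(int(ratio * n_class0 - n_class1_0), 0)


# Toy labels (made up for illustration): 8 class 0 and 2 class 1 samples.
labels_0 = pd.Series([0] * 8 + [1] * 2)
# Second dataset with 3 class 1 samples to draw from.
labels_1 = pd.Series([1, 0, 1, 1, 0])

n_add = num_class1_to_add(len(labels_0), int(labels_0.sum()), ratio=1.0)

# Positions of class 1 samples in the second dataset, sampled with replacement.
rng = np.random.default_rng(0)
positive_positions = np.flatnonzero(labels_1.to_numpy())
picked = rng.choice(positive_positions, size=n_add, replace=True)

balanced_labels = pd.concat([labels_0, labels_1.iloc[picked]])
n_class1 = int(balanced_labels.sum())
n_class0 = len(balanced_labels) - n_class1
print(n_add, n_class1, n_class0)  # 6 8 8
```

With ratio=1.0 the 8 class 0 samples require 8 class 1 samples in total, so 6 are added to the 2 already present. Because the draw is with replacement, the second dataset may hold fewer class 1 samples than are needed, at the cost of duplicated rows.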