main_hyperparams.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:cnn-lstm-bilstm-deepcnn-clstm-in-pytorch 作者: bamtercelboo 项目源码 文件源码
def mrs_five_mui(path, train_name, dev_name, test_name, char_data, text_field, label_field,
                 static_text_field, static_label_field, **kargs):
    """Load the MR splits for two field pairs and return iterators for the first pair.

    The train/dev/test splits are read twice through
    ``mydatasets_self_five.MR.splits``: once with ``text_field`` /
    ``label_field`` and once with ``static_text_field`` /
    ``static_label_field``. Vocabularies are then built on all four fields
    as a side effect.

    Args:
        path: Forwarded to ``MR.splits`` as the dataset location.
        train_name, dev_name, test_name: Forwarded to ``MR.splits`` to select
            the three splits.
        char_data: Forwarded unchanged to ``MR.splits``.
        text_field, label_field: Fields used for the datasets that back the
            returned iterators; their vocab is built from the train split only.
        static_text_field, static_label_field: Fields whose vocab is built
            from the train, dev and test splits together.
        **kargs: Extra keyword arguments forwarded to ``data.Iterator.splits``.

    Returns:
        A ``(train_iter, dev_iter, test_iter)`` tuple built from the
        non-static datasets only.

    Note:
        ``args`` (``min_freq``, ``batch_size``) is not a parameter; it is read
        from the enclosing module scope and must be defined before the call.
    """
    train_data, dev_data, test_data = mydatasets_self_five.MR.splits(path, train_name, dev_name, test_name,
                                                                     char_data, text_field, label_field)
    static_train_data, static_dev_data, static_test_data = mydatasets_self_five.MR.splits(path, train_name, dev_name,
                                                                                          test_name,
                                                                                          char_data,
                                                                                          static_text_field,
                                                                                          static_label_field)
    print("len(train_data) {} ".format(len(train_data)))
    # Label this line distinctly so the two size reports can be told apart.
    print("len(static_train_data) {} ".format(len(static_train_data)))
    # Non-static vocab: train split only.
    text_field.build_vocab(train_data, min_freq=args.min_freq)
    label_field.build_vocab(train_data)
    # Static vocab: train, dev and test splits together.
    static_text_field.build_vocab(static_train_data, static_dev_data, static_test_data, min_freq=args.min_freq)
    static_label_field.build_vocab(static_train_data, static_dev_data, static_test_data)
    # Dev and test batch sizes equal the dataset length, i.e. one batch each.
    train_iter, dev_iter, test_iter = data.Iterator.splits(
                                        (train_data, dev_data, test_data),
                                        batch_sizes=(args.batch_size,
                                                     len(dev_data),
                                                     len(test_data)),
                                        **kargs)
    return train_iter, dev_iter, test_iter


# load MR dataset
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号