text_cnn.py source code

Project: mxnet_tk1  Author: starimpact
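```python
from __future__ import print_function  # lets this Python 2 code also run under Python 3

import mxnet as mx
import numpy as np

import data_helpers  # project module that loads and preprocesses the sentence data

# setup_cnn_model() and train_cnn() are defined elsewhere in text_cnn.py.


def train_without_pretrained_embedding():
    # x: sentences as rows of word indices, padded to equal length; y: labels
    x, y, vocab, vocab_inv = data_helpers.load_data()
    vocab_size = len(vocab)

    # Randomly shuffle the data (fixed seed, so the split is reproducible)
    np.random.seed(10)
    shuffle_indices = np.random.permutation(np.arange(len(y)))
    x_shuffled = x[shuffle_indices]
    y_shuffled = y[shuffle_indices]

    # Split train/dev sets: the last 1000 shuffled examples form the dev set
    x_train, x_dev = x_shuffled[:-1000], x_shuffled[-1000:]
    y_train, y_dev = y_shuffled[:-1000], y_shuffled[-1000:]
    print('Train/Dev split: %d/%d' % (len(y_train), len(y_dev)))
    print('train shape:', x_train.shape)
    print('dev shape:', x_dev.shape)
    print('vocab_size', vocab_size)

    batch_size = 50
    num_embed = 300  # word-embedding dimension
    sentence_size = x_train.shape[1]  # maximum number of words per (padded) sentence

    print('batch size', batch_size)
    print('sentence max words', sentence_size)
    print('embedding size', num_embed)

    # with_embedding=False: no pretrained vectors are fed in; the network
    # learns its own word embeddings from the word indices during training.
    cnn_model = setup_cnn_model(mx.gpu(0), batch_size, sentence_size, num_embed, vocab_size,
                                dropout=0.5, with_embedding=False)
    train_cnn(cnn_model, x_train, y_train, x_dev, y_dev, batch_size)
```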
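The shuffle-and-split step in `train_without_pretrained_embedding` can be tried on its own without MXNet or the dataset. The following is a minimal sketch under stated assumptions: `shuffle_and_split` and `dev_size` are hypothetical names introduced here, the synthetic arrays only stand in for the output of `data_helpers.load_data()`, and only NumPy is needed. Like the original, it applies one seeded permutation (seed 10) to both `x` and `y` and holds out the last examples as the dev set (the original holds out 1000).

```python
import numpy as np


def shuffle_and_split(x, y, dev_size, seed=10):
    """Hypothetical helper: shuffle x and y with the same permutation, then
    hold out the last dev_size examples as a dev set."""
    if len(x) != len(y):
        raise ValueError("x and y must have the same number of examples")
    if not 0 < dev_size < len(y):
        raise ValueError("dev_size must be between 1 and len(y) - 1")
    # Local generator: reproducible without touching NumPy's global random state.
    rng = np.random.RandomState(seed)
    shuffle_indices = rng.permutation(len(y))
    # One permutation for both arrays, so every row keeps its own label.
    x_shuffled = x[shuffle_indices]
    y_shuffled = y[shuffle_indices]
    x_train, x_dev = x_shuffled[:-dev_size], x_shuffled[-dev_size:]
    y_train, y_dev = y_shuffled[:-dev_size], y_shuffled[-dev_size:]
    return x_train, y_train, x_dev, y_dev


if __name__ == "__main__":
    # Synthetic stand-in: 20 "sentences" of 5 word indices each, binary labels.
    x = np.arange(100).reshape(20, 5)
    y = np.arange(20) % 2
    x_train, y_train, x_dev, y_dev = shuffle_and_split(x, y, dev_size=4)
    print("Train/Dev split: %d/%d" % (len(y_train), len(y_dev)))
    print("train shape:", x_train.shape)
    print("dev shape:", x_dev.shape)
```

Using a local `RandomState` instead of `np.random.seed` is a design choice of this sketch: with the same seed it yields the same permutation as seeding the global generator, but it does not change random state elsewhere in the program.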