test_encoder.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:sockeye 作者: awslabs 项目源码 文件源码
def test_convolutional_embedding_encoder(config, out_data_shape, out_data_length, out_seq_len):
    conv_embed = sockeye.encoder.ConvolutionalEmbeddingEncoder(config)

    data_nd = mx.nd.random_normal(shape=(_BATCH_SIZE, _SEQ_LEN, _NUM_EMBED))

    data = mx.sym.Variable("data", shape=data_nd.shape)
    data_length = mx.sym.Variable("data_length", shape=_DATA_LENGTH_ND.shape)

    (encoded_data,
     encoded_data_length,
     encoded_seq_len) = conv_embed.encode(data=data, data_length=data_length, seq_len=_SEQ_LEN)

    exe = encoded_data.simple_bind(mx.cpu(), data=data_nd.shape)
    exe.forward(data=data_nd)
    assert exe.outputs[0].shape == out_data_shape

    exe = encoded_data_length.simple_bind(mx.cpu(), data_length=_DATA_LENGTH_ND.shape)
    exe.forward(data_length=_DATA_LENGTH_ND)
    assert np.equal(exe.outputs[0].asnumpy(), np.asarray(out_data_length)).all()

    assert encoded_seq_len == out_seq_len
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号