def test_convolutional_embedding_encoder(config, out_data_shape, out_data_length, out_seq_len):
    """Check the outputs of ``ConvolutionalEmbeddingEncoder.encode``.

    Builds the encoder from ``config`` and feeds it random input embeddings.
    It then checks three things against the caller-supplied expectations:
    the encoded data shape, the encoded per-sequence lengths, and the
    returned sequence length. The expectations presumably come from a pytest
    parametrization outside this view (TODO confirm).

    :param config: Configuration passed to ``sockeye.encoder.ConvolutionalEmbeddingEncoder``.
    :param out_data_shape: Expected shape of the encoded data output.
    :param out_data_length: Expected encoded lengths (array-like), compared element-wise
        and by shape against the encoded length output.
    :param out_seq_len: Expected sequence length returned by ``encode``.
    """
    conv_embed = sockeye.encoder.ConvolutionalEmbeddingEncoder(config)

    # Random inputs: only the output *shape* of the encoded data is asserted below,
    # so the actual embedding values are irrelevant.
    data_nd = mx.nd.random_normal(shape=(_BATCH_SIZE, _SEQ_LEN, _NUM_EMBED))
    data = mx.sym.Variable("data", shape=data_nd.shape)
    data_length = mx.sym.Variable("data_length", shape=_DATA_LENGTH_ND.shape)

    encoded_data, encoded_data_length, encoded_seq_len = conv_embed.encode(
        data=data, data_length=data_length, seq_len=_SEQ_LEN)

    # Each output symbol is bound and run in its own executor, given only the
    # single input shape passed here.
    exe = encoded_data.simple_bind(mx.cpu(), data=data_nd.shape)
    exe.forward(data=data_nd)
    assert exe.outputs[0].shape == out_data_shape

    exe = encoded_data_length.simple_bind(mx.cpu(), data_length=_DATA_LENGTH_ND.shape)
    exe.forward(data_length=_DATA_LENGTH_ND)
    # np.array_equal also requires identical shapes, so a mis-sized expectation
    # cannot pass through broadcasting as it could with np.equal(...).all().
    assert np.array_equal(exe.outputs[0].asnumpy(), np.asarray(out_data_length))

    assert encoded_seq_len == out_seq_len
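# 评论列表 (Comment list) -- web-page navigation residue, not code
# 文章目录 (Table of contents) -- web-page navigation residue, not code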