test_basic.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:Theano-Deep-learning 作者: GeekLiB 项目源码 文件源码
def test_batched_dot():
    """Check ``theano.tensor.basic.batched_dot`` on 3-d and 2-d inputs.

    3-d case: inputs of shape ``(b, n, k)`` and ``(b, k, m)`` must give a
    result of shape ``(b, n, m)`` equal to the per-batch matrix product.

    2-d case: two ``(b, k)`` matrices must give a 1-d vector of length ``b``
    holding the row-wise dot products.
    """
    # --- 3-d case: a batch of matrix-matrix products ---
    first = theano.tensor.tensor3("first")
    second = theano.tensor.tensor3("second")
    output = theano.tensor.basic.batched_dot(first, second)
    first_val = numpy.random.rand(10, 10, 20).astype(config.floatX)
    second_val = numpy.random.rand(10, 20, 5).astype(config.floatX)
    result_fn = theano.function([first, second], output)
    result = result_fn(first_val, second_val)

    assert result.shape == (first_val.shape[0],
                            first_val.shape[1],
                            second_val.shape[2])
    # Reference: an independent matrix product for every batch index b.
    expected = numpy.einsum("bij,bjk->bik", first_val, second_val)
    # Loose tolerances so the comparison also holds when floatX is float32.
    assert numpy.allclose(result, expected, rtol=1e-4, atol=1e-5)

    # --- 2-d case: a batch of vector-vector dot products ---
    # Use floatX-typed symbolic matrices to match the floatX-cast inputs
    # below; a dmatrix would force float64 regardless of config.floatX.
    first_mat = theano.tensor.matrix("first")
    second_mat = theano.tensor.matrix("second")
    output = theano.tensor.basic.batched_dot(first_mat, second_mat)
    first_mat_val = numpy.random.rand(10, 10).astype(config.floatX)
    second_mat_val = numpy.random.rand(10, 10).astype(config.floatX)
    result_fn = theano.function([first_mat, second_mat], output)
    result = result_fn(first_mat_val, second_mat_val)

    # One scalar per batch row, so the result must be exactly 1-d.
    assert result.shape == (first_mat_val.shape[0],)
    # Reference: dot product of row i of each input.
    expected = (first_mat_val * second_mat_val).sum(axis=1)
    assert numpy.allclose(result, expected, rtol=1e-4, atol=1e-5)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号