net.py 文件源码

python
阅读 32 收藏 0 点赞 0 评论 0

项目:bigan 作者: jeffdonahue 项目源码 文件源码
def test_multifc(n=5, b=100, d_in=500, d_out=1000, n_trials=1000):
    """Benchmark two ways of feeding ``n`` inputs through FC layers.

    Method A concatenates all inputs along axis 1 and applies a single FC
    layer; method B applies one FC layer per input and combines the
    results with an element-wise sum.  Each variant is built in its own
    ``Net``, compiled with ``theano.function``, and timed on the same
    random samples.  The two total timings are printed; nothing is
    returned.

    Args:
        n: number of input matrices.
        b: first dimension of each declared input shape (presumably the
            batch size -- confirm against ``Output``).
        d_in: second dimension of each declared input shape.
        d_out: ``nout`` passed to every FC layer.
        n_trials: number of calls timed for each method.
    """
    # ``range`` instead of ``xrange``: same iteration, works on Python 2 and 3.
    x = [Output(T.matrix(), shape=(b, d_in)) for _ in range(n)]
    x_in = [xi.value for xi in x]
    # Random inputs matching each declared shape, cast to the dtype of the
    # symbolic variable so the compiled functions accept them.
    x_sample = [np.asarray(np.random.rand(*xi.shape), dtype=xi.value.dtype)
                for xi in x]
    # method A: concat then multiply
    N_a = Net(name='A')
    x_cat = N_a.Concat(*x, axis=1)
    y_a = N_a.FC(x_cat, nout=d_out)
    # Compilation happens here, outside the timed call.
    f_a = theano.function(x_in, y_a.value)
    time_a = Timer(partial(f_a, *x_sample))
    # method B: multiply each one then sum results
    N_b = Net(name='B')
    ys = [N_b.FC(xi, nout=d_out) for xi in x]
    y_b = N_b.EltwiseSum(*ys)
    f_b = theano.function(x_in, y_b.value)
    time_b = Timer(partial(f_b, *x_sample))
    # time them
    # Single-argument print call with %s formatting: identical output to the
    # old ``print 'Time A:', value`` statement on Python 2, and valid on
    # Python 3 where the statement form is a SyntaxError.
    print('Time A: %s' % (time_a.timeit(number=n_trials),))
    print('Time B: %s' % (time_b.timeit(number=n_trials),))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号