utils.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:nn-patterns 作者: pikinder 项目源码 文件源码
def get_conv_xy_all(layer, deterministic=True):
    """Express a 2D convolution layer as a dense product over input patches.

    Every receptive-field patch of the (possibly zero-padded) layer input is
    unrolled into one row of ``x``, so the layer's pre-activation output is
    ``y = x . w + b`` with ``w`` the flattened filter bank.

    Parameters
    ----------
    layer : Lasagne 2D convolution layer
        Must expose ``W``, ``b``, ``flip_filters``, ``pad``, ``filter_size``,
        ``stride`` and ``input_layer``.
    deterministic : bool
        Forwarded to ``lasagne.layers.get_output`` when building the symbolic
        input of ``layer`` (e.g. to disable stochastic layers upstream).

    Returns
    -------
    x : Theano symbolic matrix, shape (batch * n_positions, channels * fh * fw)
        One row per (example, output position) input patch.
    y : Theano symbolic matrix, shape (batch * n_positions, n_filters)
        Response of every filter on every patch, plus the bias if present.
        The layer's nonlinearity is NOT applied.
    """
    filter_size = tuple(layer.filter_size)
    patch_size = int(np.prod(filter_size))  # fh * fw

    # Lasagne stores ``pad`` as 'same', 'full', or an explicit tuple of ints
    # ('valid' is stored as all zeros). Reproduce that zero-padding explicitly
    # so that a 'valid' patch extraction sees exactly what the layer convolves.
    if layer.pad == 'same':
        pad_width = tuple(s // 2 for s in filter_size)
    elif layer.pad == 'full':
        pad_width = tuple(s - 1 for s in filter_size)
    else:
        pad_width = tuple(layer.pad)

    input_layer = layer.input_layer
    if any(pad_width):
        input_layer = L.PadLayer(input_layer, width=pad_width, batch_ndim=2)

    # Honour the caller's flag (it used to be hard-coded to True).
    x_orig = L.get_output(input_layer, deterministic=deterministic)

    # images2neibs emits one row of length fh * fw per
    # (example, channel, position). NOTE(review): Theano's 'valid' mode may
    # reject inputs whose size does not line up with the stride — confirm.
    x = theano.tensor.nnet.neighbours.images2neibs(x_orig,
                                                   neib_shape=filter_size,
                                                   neib_step=layer.stride,
                                                   mode='valid')
    # Infer the number of positions symbolically so unknown (None) spatial
    # dimensions are supported: (batch, channels, positions, fh*fw).
    x = T.reshape(x, (x_orig.shape[0], x_orig.shape[1], -1, patch_size))
    # -> (batch, positions, channels, fh*fw), then one row per patch with the
    # channel-major layout that matches the flattened filters below.
    x = T.transpose(x, (0, 2, 1, 3))
    x = T.reshape(x, (-1, T.prod(x.shape[2:])))

    # With flip_filters Lasagne performs a true convolution; flipping W back
    # gives the kernel that applies to patches as a plain dot product.
    w = layer.W
    if layer.flip_filters:
        w = w[:, :, ::-1, ::-1]
    w = T.flatten(w, outdim=2).T  # (channels * fh * fw, n_filters)

    y = T.dot(x, w)  # (batch * n_positions, n_filters)
    if layer.b is not None:
        # NOTE(review): assumes tied biases, b of shape (n_filters,); a layer
        # built with untie_biases=True would need a per-position bias — confirm.
        y += T.shape_padaxis(layer.b, axis=0)
    return x, y
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号