utils.py 文件源码

python
阅读 18 收藏 0 点赞 0 评论 0

项目:nn-patterns 作者: pikinder 项目源码 文件源码
def get_conv_xy(layer, deterministic=True):
    """Build symbolic (input patch, filter response) pairs for a conv layer.

    One filter-sized window is cut out of the layer's input at a randomly
    drawn spatial offset (the same offset for every sample in the batch),
    and the layer's filters and bias are applied to that window as a plain
    matrix product.

    Args:
        layer: Layer object exposing ``W`` (a shared variable whose value
            has at least four axes; axes 2 and 3 are treated as the spatial
            filter axes), ``b`` (may be ``None``), ``pad``, ``flip_filters``
            and ``input_layer``.
        deterministic: Forwarded to ``L.get_output`` when building the
            expression for the layer's input.

    Returns:
        Tuple ``(x, y)`` of symbolic expressions: ``x`` is the flattened
        input patch (N, D) and ``y`` is ``x`` multiplied by the flattened
        filters plus the bias, if any (N, O).

    Notes:
        Only ``pad == 'same'`` receives explicit padding here; every other
        value of ``layer.pad`` samples from the unpadded input.
        ``layer.stride`` is not consulted, so offsets are drawn from all
        positions where the window fits rather than only stride-aligned ones.
        NOTE(review): confirm whether other padding modes need handling.
    """
    w_np = layer.W.get_value()
    filter_x, filter_y = w_np.shape[2], w_np.shape[3]

    input_layer = layer.input_layer
    if layer.pad == 'same':
        # Floor division is required: with true division (Python 3, or
        # ``from __future__ import division``) a 3x3 filter would yield a
        # pad width of 1.5, which is not a valid padding amount.
        pad_width = tuple(int(size) // 2 for size in w_np.shape[2:])
        input_layer = L.PadLayer(layer.input_layer,
                                 width=pad_width,
                                 batch_ndim=2)

    input_shape = L.get_output_shape(input_layer)
    # Largest offset at which a filter-sized window still fits the input.
    max_x = input_shape[2] - filter_x
    max_y = input_shape[3] - filter_y

    srng = RandomStreams()
    patch_x = srng.random_integers(low=0, high=max_x)
    patch_y = srng.random_integers(low=0, high=max_y)

    x = L.get_output(input_layer, deterministic=deterministic)
    x = x[:, :,
          patch_x:patch_x + filter_x, patch_y:patch_y + filter_y]
    x = T.flatten(x, 2)  # N,D

    w = layer.W
    if layer.flip_filters:
        # Reverse both spatial axes so the matrix product below matches a
        # convolution performed with flipped filters.
        w = w[:, :, ::-1, ::-1]
    w = T.flatten(w, outdim=2).T  # D,O
    y = T.dot(x, w)  # N,O
    if layer.b is not None:
        # Add a leading broadcastable axis so the bias is added per row.
        y += T.shape_padaxis(layer.b, axis=0)
    return x, y
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号