lightsout_twisted.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:latplan 作者: guicho271828 项目源码 文件源码
def build_errors(states,base,pad,dim,size):
    # address the numerical viscosity in swirling
    s = K.round(states+viscosity_adjustment)
    s = Reshape((dim+2*pad,dim+2*pad,1))(s)
    s = Cropping2D(((pad,pad),(pad,pad)))(s)
    s = K.reshape(s,[-1,size,base,size,base])
    s = K.permute_dimensions(s, [0,1,3,2,4])
    s = K.reshape(s,[-1,size,size,1,base,base])
    s = K.tile   (s,[1, 1, 1, 2, 1, 1,]) # number of panels : 2

    allpanels = K.variable(panels)
    allpanels = K.reshape(allpanels, [1,1,1,2,base,base])
    allpanels = K.tile(allpanels, [K.shape(s)[0], size,size, 1, 1, 1])

    def hash(x):
        ## 2x2 average hashing
        x = K.reshape(x, [-1,size,size,2, base//3, 3, base//3, 3])
        x = K.mean(x, axis=(5,7))
        return K.round(x)
        ## diff hashing (horizontal diff)
        # x1 = x[:,:,:,:,:,:-1]
        # x2 = x[:,:,:,:,:,1:]
        # d = x1 - x2
        # return K.round(d)
        ## just rounding
        # return K.round(x)
        ## do nothing
        # return x

    # s         = hash(s)
    # allpanels = hash(allpanels)

    # error = K.binary_crossentropy(s, allpanels)
    error = K.abs(s - allpanels)
    error = hash(error)
    error = K.mean(error, axis=(4,5))
    return error
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号