def calfilter(X):
    """Build 4-neighbour adjacency and Gaussian affinity matrices for image batches.

    Parameters
    ----------
    X : array_like, shape (nbatch, boxheight, boxwidth)
        Batch of single-channel images.

    Returns
    -------
    k1 : ndarray, shape (nbatch, boxheight*boxwidth, boxheight*boxwidth)
        Binary adjacency: ``k1[b, p, q] == 1`` when pixels ``p`` and ``q``
        (row-major flattened indices) are vertical or horizontal neighbours,
        and 0 otherwise. It is identical for every batch element.
    k2 : ndarray, same shape as ``k1``
        Affinity ``exp(-(X[b, p] - X[b, q])**2)`` on the same neighbour pairs,
        and 0 everywhere else.

    Only the 4-neighbourhood (up, down, left, right) is considered. A pixel is
    never its own neighbour, so the diagonals of both outputs are 0.
    """
    X = np.asarray(X)
    nbatch, height, width = X.shape
    k1 = np.zeros((nbatch, height, width, height, width))
    k2 = np.zeros((nbatch, height, width, height, width))

    # Row/column index of every pixel. Each neighbour direction is applied to
    # all pixels at once instead of through a per-pixel Python loop.
    rows, cols = np.meshgrid(np.arange(height), np.arange(width), indexing="ij")
    for drow, dcol in ((-1, 0), (1, 0), (0, -1), (0, 1)):
        nrows, ncols = rows + drow, cols + dcol
        # Skip pixels whose neighbour in this direction lies outside the image.
        inside = (nrows >= 0) & (nrows < height) & (ncols >= 0) & (ncols < width)
        src_r, src_c = rows[inside], cols[inside]
        dst_r, dst_c = nrows[inside], ncols[inside]
        # A leading slice plus four 1-D index arrays selects shape (nbatch, m).
        # This matches X[:, src_r, src_c].
        k1[:, src_r, src_c, dst_r, dst_c] = 1
        k2[:, src_r, src_c, dst_r, dst_c] = np.exp(
            -(X[:, src_r, src_c] - X[:, dst_r, dst_c]) ** 2
        )

    npix = height * width
    return k1.reshape((nbatch, npix, npix)), k2.reshape((nbatch, npix, npix))
utils_combine.py 文件源码
python
阅读 29
收藏 0
点赞 0
评论 0
评论列表
文章目录