def grid_data(source, grid=32, crop=16, expand=12):
    """Cut axes 3 and 4 of ``source`` into overlapping square tiles.

    A border of ``crop`` pixels is discarded on every side of axes 3 (height)
    and 4 (width).  The remaining area is divided into whole ``grid`` x ``grid``
    cells; any leftover strip that is narrower than ``grid`` is ignored.  Each
    cell is then grown by ``expand`` pixels on every side, so neighbouring
    tiles overlap.  Tiles are visited row by row (outer loop over rows, inner
    loop over columns) and concatenated along axis 0.

    Args:
        source: Array indexed as ``source[:, :, :, row, col]``.  Presumably
            (batch, ?, ?, height, width); TODO confirm axis meaning with callers.
        grid: Side length of a cell before expansion, in pixels.
        crop: Border width, in pixels, ignored on every side before tiling.
        expand: Extra margin, in pixels, added around each cell.  Each tile is
            ``grid + 2 * expand`` pixels square.

    Returns:
        ``(cells, gridwidth, gridheight)``.  For a 5-D ``source`` of shape
        ``(N, A, B, H, W)``, ``cells`` has shape
        ``(gridheight * gridwidth * N, A, B, grid + 2*expand, grid + 2*expand)``.
        The tile at grid row ``j``, column ``i`` for sample ``n`` is stored at
        index ``(j * gridwidth + i) * N + n``.

    Raises:
        ValueError: if ``grid`` or the tile size is not positive, if no whole
            cell fits inside the cropped area, or if an expanded tile would
            extend outside the array (this always happens when
            ``expand > crop``).
    """
    if grid <= 0:
        raise ValueError(f"grid must be positive, got {grid}")
    tile = grid + 2 * expand
    if tile <= 0:
        raise ValueError(f"grid + 2 * expand must be positive, got {tile}")

    height = source.shape[3]  # author notes: 224 for their data
    width = source.shape[4]
    gridheight = (height - 2 * crop) // grid  # author notes: 6 for their data
    gridwidth = (width - 2 * crop) // grid
    if gridheight < 1 or gridwidth < 1:
        raise ValueError(
            f"source of size {height}x{width} is too small for a {grid}-pixel "
            f"grid with a {crop}-pixel crop"
        )

    # Every expanded tile must lie inside the array.  A negative start index
    # would silently wrap to the opposite edge, and an end index past the edge
    # would be silently clipped.  Either way the tiles would be wrong or
    # mis-shaped instead of raising an error.
    if (
        crop - expand < 0
        or crop + gridheight * grid + expand > height
        or crop + gridwidth * grid + expand > width
    ):
        raise ValueError(
            f"expanded tiles (crop={crop}, expand={expand}, grid={grid}) "
            f"extend outside the {height}x{width} source; need expand <= crop"
        )

    cells = []
    for j in range(gridheight):
        top = crop + j * grid - expand
        for i in range(gridwidth):
            left = crop + i * grid - expand
            cells.append(source[:, :, :, top:top + tile, left:left + tile])
    return np.vstack(cells), gridwidth, gridheight
评论列表
文章目录