def glimpseSensor(normalLocation, inputPlaceholder):
    """Extract multi-scale glimpses around a location for every image in a batch.

    For each of the ``batchSize`` images, ``glimpseDepth`` glimpses are taken
    with ``getGlipmse``. The radius starts at ``glimpseRadius`` and doubles at
    every depth level (r, 2r, 4r, ...).

    Args:
        normalLocation: Per-image location tensor (indexed as ``location[k]``).
            Presumably normalized to [-1, 1], since it is mapped through
            ``(x + 1) / 2 * InputImageSize`` -- TODO confirm against caller.
        inputPlaceholder: Image batch. It is reshaped to
            ``(batchSize, InputImageSize[0], InputImageSize[1], InputImageSize[2])``.

    Returns:
        Tensor of shape ``(batchSize, glimpseDepth, glimpseBandwidth,
        glimpseBandwidth, K)``, where ``K`` is inferred from the number of
        elements in each glimpse.
    """
    # Map normalized coordinates to integer pixel coordinates.
    # NOTE(review): InputImageSize has three entries (see the reshape below).
    # If normalLocation carries only two coordinates per image, this multiply
    # will not broadcast, and InputImageSize[:2] is probably intended -- confirm.
    # NOTE(review): a coordinate of +1 maps to InputImageSize itself, one past
    # the last valid index; presumably getGlipmse pads/clips -- verify.
    location = tf.round(tf.multiply((normalLocation + 1) / 2.0, InputImageSize))
    location = tf.cast(location, tf.int32)
    images = tf.reshape(inputPlaceholder, (batchSize,
                                           InputImageSize[0],
                                           InputImageSize[1],
                                           InputImageSize[2]))
    zooms = []
    for k in range(batchSize):
        img = images[k]
        loc = location[k]
        imgZooms = []
        for i in range(glimpseDepth):
            # Each depth level doubles the sampling radius.
            radius = int(glimpseRadius * (2 ** i))
            glimpse = getGlipmse(img, loc, radius)
            # Infer the trailing axis instead of hard-coding glimpseBandwidth.
            # A fixed (bw, bw, bw) shape only fits when a glimpse holds exactly
            # bw**3 elements (e.g. channel count == bw). With -1 that case is
            # unchanged, and (bw, bw, C) glimpses work for any channel count C.
            glimpse = tf.reshape(glimpse, (glimpseBandwidth, glimpseBandwidth, -1))
            imgZooms.append(glimpse)
        zooms.append(tf.pack(imgZooms))
    return tf.pack(zooms)
评论列表
文章目录