def decode_to_probs(self, activations, relative_position, low_bound, high_bound):
    """Turn raw activations into probabilities laid out over [low_bound, high_bound).

    Each row along the last axis of ``activations`` (width
    ``self.RAW_ENCODING_WIDTH``) is softmax-normalized. The relative part of the
    row is cyclically shifted according to the matching entry of
    ``relative_position`` and then tiled so that it covers every output index
    in the range.

    Args:
        activations: Theano tensor whose last axis has length
            ``self.RAW_ENCODING_WIDTH``.
        relative_position: Theano tensor with one position per activation row.
            It is flattened before use, so it presumably has shape
            ``activations.shape[:-1]`` -- confirm against callers.
        low_bound: Output index 0 corresponds to this position. Presumably a
            plain Python int, since it feeds Python-level integer arithmetic.
        high_bound: Exclusive upper end of the output range. Presumably a
            plain Python int, for the same reason.

    Returns:
        Tensor of shape ``activations.shape[:-1] + (2 + high_bound - low_bound,)``.
        The first two entries of the last axis are the first two softmax
        outputs when ``self.with_artic`` is true, and 1 otherwise. The remaining
        entries are the shifted, tiled relative probabilities.
    """
    span = high_bound - low_bound
    window = self.WINDOW_SIZE
    # Exact integer ceiling division. int(math.ceil(span / window)) under-counts
    # under Python 2's truncating `/`, which would leave the tiled vector
    # shorter than `span` and break the final reshape.
    num_tile = -(-span // window)

    flat_activations = T.reshape(activations, (-1, self.RAW_ENCODING_WIDTH))
    probs = T.nnet.softmax(flat_activations)

    def _decode_row(cprobs, cpos):
        """Decode one softmax row given its relative position ``cpos``."""
        if self.with_artic:
            # The first two softmax entries are copied to the output unchanged.
            abs_probs = cprobs[:2]
            rel_probs = cprobs[2:]
        else:
            # Every entry is relative here. Emit two ones so the output always
            # has the same two leading slots.
            abs_probs = T.ones((2,))
            rel_probs = cprobs
        # Roll so relative entry 0 lands at output index (cpos - low_bound)
        # modulo the window, i.e. at position cpos. The modulus must match the
        # tiling period below. Assumes len(rel_probs) == WINDOW_SIZE -- confirm.
        aligned = T.roll(rel_probs, (cpos - low_bound) % window)
        # Repeat the window until it covers the span, then trim the excess.
        tiled = T.tile(aligned, (num_tile,))[:span]
        return T.concatenate([abs_probs, tiled], 0)

    from_scan, _ = theano.map(
        fn=_decode_row,
        sequences=[probs, T.flatten(relative_position)],
    )
    # Restore the leading axes of `activations`. The last axis now holds the
    # two leading slots plus one entry per position in [low_bound, high_bound).
    newshape = T.concatenate([activations.shape[:-1], [2 + span]], 0)
    return T.reshape(from_scan, newshape, ndim=activations.ndim)
评论列表
文章目录