def compile_inference(self):
    """Compile the hidden-action inference function into ``self.inference_``.

    Builds a symbolic graph that encodes a batch of padded integer input
    sequences with ``self.encoder``, evaluates the posterior ``Q(a|y)``
    (``self.Post``) on the encoded memory, and evaluates the prior
    (``self.Prior``). The compiled Theano function maps the input matrix to
    ``[memorybook, q_dis, p_dis]``.

    Raises:
        NotImplementedError: if ``self.config['mode']`` is neither ``'RNN'``
            nor ``'NTM'``.
    """
    # Padded input word-id sequences as an int32 matrix. Its row count is used
    # as the batch size below (layout assumed to be (batch, time) -- TODO
    # confirm against the encoder). The original note said "for training";
    # here it feeds the inference function.
    inputs = T.imatrix()
    batch_size = inputs.shape[0]

    # Read the mode once so the error message can report exactly what was seen.
    mode = self.config['mode']
    if mode == 'RNN':
        # RNN mode: context matrix from alloc_zeros_matrix, one row per sequence.
        context = alloc_zeros_matrix(batch_size, self.config['enc_contxt_dim'])
    elif mode == 'NTM':
        # NTM mode: self.memory is 2-D; add a leading axis and tile it so each
        # sequence in the batch gets its own copy of the memory.
        context = T.repeat(self.memory[None, :, :], batch_size, axis=0)
    else:
        raise NotImplementedError(
            "compile_inference: unsupported mode %r (expected 'RNN' or 'NTM')"
            % (mode,))

    # encoding
    memorybook = self.encoder.build_encoder(inputs, context)
    # get Q(a|y) = sigmoid(.|Posterior * encoded)
    q_dis = self.Post(memorybook)
    p_dis = self.Prior()

    self.inference_ = theano.function([inputs], [memorybook, q_dis, p_dis])
    logger.info("inference function compile done.")
评论列表
文章目录