rnn_sampling.py source code

python

Project: sequence-based-recommendations  Author: rdevooght
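# Module-level imports this method relies on
import numpy as np
import theano
import theano.tensor as T
import lasagne


def _compile_test_function(self):
    '''Differs from the base test function because of the added softmax operation.'''
    print("Compiling test...")
    # Softmax over the network output; deterministic=True disables stochastic layers such as dropout
    deterministic_output = T.nnet.softmax(
        lasagne.layers.get_output(self.l_out, deterministic=True))
    if self.interactions_are_unique:
        # Zero the probability of items masked by self.exclude so they are never recommended
        deterministic_output *= (1 - self.exclude)

    theano_test_function = theano.function(
        self.theano_inputs, deterministic_output,
        allow_input_downcast=True, name="Test_function", on_unused_input='ignore')

    def precision_test_function(theano_inputs, k=10):
        output = theano_test_function(*theano_inputs)
        # Ids of the k highest-scoring items for the first sequence, in descending order
        ids = np.argpartition(-output, range(k), axis=-1)[0, :k]
        return ids

    self.test_function = precision_test_function
    print("Compilation done.")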
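The method compiles a Theano function that applies a softmax to the last layer's output, optionally multiplies it by an exclusion mask, and then picks the top-k item ids with np.argpartition. As a minimal sketch, assuming plain NumPy in place of Theano/Lasagne and a (batch, n_items) score matrix, the same post-processing can be reproduced outside the compiled graph. The names softmax, top_k_recommendations and the exclude argument below are hypothetical illustrations, not part of the rdevooght project.

```python
import numpy as np


def softmax(logits):
    """Row-wise softmax, shifted by the row max for numerical stability."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


def top_k_recommendations(logits, exclude=None, k=10):
    """Return the ids of the k best items for the first sequence in the batch.

    logits:  (batch, n_items) raw scores (assumed shape)
    exclude: optional (batch, n_items) 0/1 mask of items to suppress (hypothetical)
    """
    probs = softmax(logits)
    if exclude is not None:
        # Probabilities are non-negative, so zeroing masked items sends them to the bottom
        probs = probs * (1 - exclude)
    # With a sequence of kth values, argpartition puts each of positions 0..k-1
    # in its sorted place, so the first k ids come out in descending-score order
    ids = np.argpartition(-probs, np.arange(k), axis=-1)[0, :k]
    return ids


logits = np.array([[1.0, 5.0, 3.0, 4.0, 2.0]])
print(top_k_recommendations(logits, k=3))
mask = np.array([[0, 1, 0, 0, 0]])
print(top_k_recommendations(logits, exclude=mask, k=3))
```

Because softmax is monotonic, it does not change the ranking by itself. It matters when the mask is applied: multiplying raw scores by 0 would turn a negative score into 0 and could push a masked item above unmasked ones, whereas masked probabilities (all non-negative) always end up at or below every unmasked item.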