net.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:chainer-cf-nade 作者: dsanno 项目源码 文件源码
def __call__(self, h, train=True):
        """
        in_type:
            h: float32
        in_shape:
            h: (batch_size, hidden_num)
        out_type: float32
        out_shape: (batch_size, rating_num, predicted_item_num)
        """

        xp = cuda.get_array_module(h.data)
        h = self.p(h)
        if hasattr(self, 'q'):
            h = self.q(h)
        h = F.reshape(h, (-1, self.rating_num, self.item_num, 1))
        w = chainer.Variable(xp.asarray(np.tri(self.rating_num, dtype=np.float32).reshape(self.rating_num, self.rating_num, 1, 1)), volatile=h.volatile)
        h = F.convolution_2d(h, w)
        return F.reshape(h, (-1, self.rating_num, self.item_num))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号