main.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:cnn-text-classification 作者: marevol 项目源码 文件源码
def __call__(self, x, train=True):
        hlist = []
        h_0 = self['embed'](x)
        if not self.non_static:
            h_0 = Variable(h_0.data)
        h_1 = F.reshape(h_0, (h_0.shape[0], 1, h_0.shape[1], h_0.shape[2]))
        for filter_h in self.filter_sizes:
            pool_size = (self.doc_length - filter_h + 1, 1)
            h = F.max_pooling_2d(F.relu(self['conv' + str(filter_h)](h_1)), pool_size)
            hlist.append(h)
        h = F.concat(hlist)
        pos = 0
        while pos < len(self.hidden_units) - 1:
            h = F.dropout(F.relu(self['l' + str(pos)](h)))
            pos += 1
        y = F.relu(self['l' + str(pos)](h))
        return y
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号