model.py source file

python

Project: cnn-lstm-bilstm-deepcnn-clstm-in-pytorch  Author: bamtercelboo
def __init__(self, args):
        super(CNN_Text,self).__init__()
        self.args = args

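        V = args.embed_num      # vocabulary size
        D = args.embed_dim      # embedding dimension
        C = args.class_num      # number of output classes
        Ci = 1                  # input channels (one embedding "image" per sentence)
        Co = args.kernel_num    # feature maps per kernel size
        Ks = args.kernel_sizes  # kernel heights, in tokens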

        self.embed = nn.Embedding(V, D)
        # Initialize the embedding table from pretrained vectors (expected shape: (V, D)).
        pretrained_weight = np.array(args.pretrained_weight)
        self.embed.weight.data.copy_(torch.from_numpy(pretrained_weight))

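        # One convolution per kernel size. nn.ModuleList registers the layers as
        # submodules; a plain Python list would hide their parameters from
        # model.parameters(), state_dict(), and .cuda().
        self.convs1 = nn.ModuleList([nn.Conv2d(Ci, Co, (K, D)) for K in Ks])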
        # Explicit per-size form, for kernel sizes 3, 4, 5:
        # self.conv13 = nn.Conv2d(Ci, Co, (3, D))
        # self.conv14 = nn.Conv2d(Ci, Co, (4, D))
        # self.conv15 = nn.Conv2d(Ci, Co, (5, D))
        self.dropout = nn.Dropout(args.dropout)
        self.fc1 = nn.Linear(len(Ks)*Co, C)
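
The `len(Ks)*Co` input size of `fc1` follows from how layers like these are usually combined: each `(K, D)` convolution slides over a `(W, D)` sentence matrix and yields `Co` feature maps of length `W - K + 1`, and max-over-time pooling then reduces every map to a single value. The forward pass is not shown in this file, so the pooling step is an assumption. The sketch below traces only the shapes in plain Python; the function name `trace_shapes` and the example values for `W`, `D`, `Co`, and `Ks` are hypothetical.

```python
def trace_shapes(seq_len, embed_dim, kernel_num, kernel_sizes):
    """Trace per-sample shapes through the layers built in __init__.

    Assumes stride 1, no padding, and max-over-time pooling after each
    convolution (an assumption: the forward pass is not in the snippet).
    """
    conv_shapes = []
    for k in kernel_sizes:
        if k > seq_len:
            raise ValueError(f"kernel size {k} exceeds sequence length {seq_len}")
        # Conv2d(1, Co, (K, D)) applied to a (1, W, D) input gives (Co, W - K + 1, 1):
        # the kernel spans the full embedding width, so the last dimension is 1.
        conv_shapes.append((kernel_num, seq_len - k + 1, 1))

    # Max-over-time pooling keeps one value per feature map: Co values per kernel size.
    pooled_sizes = [kernel_num for _ in kernel_sizes]

    # Concatenating the pooled vectors gives the input width of fc1: len(Ks) * Co.
    fc_in_features = sum(pooled_sizes)
    return conv_shapes, fc_in_features


# Example values (hypothetical): 20 tokens, 300-d embeddings, 100 maps per size.
conv_shapes, fc_in_features = trace_shapes(
    seq_len=20, embed_dim=300, kernel_num=100, kernel_sizes=[3, 4, 5]
)
print(conv_shapes)
print(fc_in_features)
```

Any input shorter than the largest kernel size would need padding before the convolutions, which is why the sketch rejects that case.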