model_DeepCNN.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:cnn-lstm-bilstm-deepcnn-clstm-in-pytorch 作者: bamtercelboo 项目源码 文件源码
def __init__(self, args):
        """Build the stacked two-stage convolutional text classifier.

        Creates the embedding, two parallel lists of convolutions (one conv
        per kernel size in each stage), dropout, and a two-layer linear head.

        Args:
            args: configuration namespace. Attributes read here:
                ``embed_num`` (embedding rows), ``embed_dim`` (embedding
                width, also the channel count of the first conv stage),
                ``class_num`` (output size of ``fc2``), ``kernel_num``
                (output channels of each second-stage conv),
                ``kernel_sizes`` (iterable of kernel heights), ``max_norm``
                (passed to ``nn.Embedding``; ``None`` disables it),
                ``word_Embedding`` / ``pretrained_weight`` (optional initial
                embedding matrix), ``init_weight`` / ``init_weight_value``
                (enable Xavier-normal conv init with gain
                ``sqrt(init_weight_value)``) and ``dropout`` (probability
                for ``nn.Dropout``).
        """
        super(DEEP_CNN, self).__init__()
        self.args = args

        V = args.embed_num
        D = args.embed_dim
        C = args.class_num
        Ci = 1
        Co = args.kernel_num
        Ks = args.kernel_sizes

        # Word embedding. max_norm=None is nn.Embedding's default (no
        # renormalisation), so one constructor call covers both cases.
        print("max_norm = {} ".format(args.max_norm))
        self.embed = nn.Embedding(V, D, max_norm=args.max_norm, scale_grad_by_freq=True)
        if args.word_Embedding:
            pretrained_weight = np.array(args.pretrained_weight)
            self.embed.weight.data.copy_(torch.from_numpy(pretrained_weight))
            # The pretrained vectors are NOT frozen: they stay trainable
            # and are fine-tuned with the rest of the model.
            self.embed.weight.requires_grad = True

        # Convolution layers, one per kernel size. nn.ModuleList (not a plain
        # list) registers each conv as a submodule so its weights appear in
        # model.parameters() (and thus get trained), in state_dict(), and
        # follow .cuda()/.to().
        # Both stages take Ci=1 input channel and a kernel spanning width D;
        # presumably forward() reshapes the D-channel first-stage output back
        # into a single channel of width D -- TODO confirm against forward().
        self.convs1 = nn.ModuleList(
            [nn.Conv2d(Ci, D, (K, D), stride=1, padding=(K // 2, 0), bias=True) for K in Ks])
        self.convs2 = nn.ModuleList(
            [nn.Conv2d(Ci, Co, (K, D), stride=1, padding=(K // 2, 0), bias=True) for K in Ks])
        print(self.convs1)
        print(self.convs2)

        if args.init_weight:
            print("Initing W .......")
            gain = np.sqrt(args.init_weight_value)
            # Interleave stage-1 / stage-2 convs per kernel size, as before.
            for conv1, conv2 in zip(self.convs1, self.convs2):
                for conv in (conv1, conv2):
                    init.xavier_normal_(conv.weight, gain=gain)
                    init.constant_(conv.bias, 0)

        # dropout
        self.dropout = nn.Dropout(args.dropout)
        # Linear head: one Co-wide feature group per kernel size, halved by
        # fc1, then projected to the class scores by fc2.
        in_fea = len(Ks) * Co
        self.fc1 = nn.Linear(in_features=in_fea, out_features=in_fea // 2, bias=True)
        self.fc2 = nn.Linear(in_features=in_fea // 2, out_features=C, bias=True)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号