model_HighWay_CNN.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:cnn-lstm-bilstm-deepcnn-clstm-in-pytorch 作者: bamtercelboo 项目源码 文件源码
def forward(self, x):
        """Run the CNN encoder, one highway layer, and the output projection.

        Pipeline: embedding -> dropout -> parallel conv filters -> max-over-time
        pooling -> concatenation -> transform ``H`` -> highway mix with gate
        ``T = sigmoid(gate_layer(x))`` -> ``logit_layer``.

        Args:
            x: input fed to ``self.embed``; presumably a ``(N, W)`` tensor of
                token ids, since the embedding output is annotated ``(N,W,D)``
                -- TODO confirm against the caller.

        Returns:
            The output of ``self.logit_layer`` applied to the highway output
            (presumably ``(N, num_classes)`` logits -- confirm in ``__init__``).

        Note:
            The batch-norm path is taken only when
            ``self.args.batch_normalizations is True`` (an identity check), so a
            truthy non-``bool`` value such as ``1`` selects the plain ReLU path.
        """
        use_bn = self.args.batch_normalizations is True

        x = self.embed(x)  # (N,W,D)
        x = self.dropout(x)
        x = x.unsqueeze(1)  # (N,Ci,W,D) with Ci == 1

        # squeeze(3) assumes each conv output has a trailing singleton dim
        # (kernel width == D), leaving (N,Co,W') per filter size.
        if use_bn:
            # One shared self.convs1_bn is applied to every filter size's output.
            feature_maps = [
                self.convs1_bn(torch.tanh(conv(x))).squeeze(3)
                for conv in self.convs1
            ]
        else:
            feature_maps = [F.relu(conv(x)).squeeze(3) for conv in self.convs1]

        # Max-over-time pooling: kernel spans the whole length -> (N,Co) each.
        pooled = [F.max_pool1d(fm, fm.size(2)).squeeze(2) for fm in feature_maps]
        x = torch.cat(pooled, 1)  # (N,len(Ks)*Co)

        if use_bn:
            # x is replaced by fc1_bn(fc1(x)), so both the gate and the carry
            # term below see this projection rather than the pooled features.
            x = self.fc1_bn(self.fc1(x))
            transform = self.fc2_bn(self.fc2(torch.tanh(x)))
        else:
            # NOTE(review): H has no nonlinearity on this path -- confirm intent.
            transform = self.fc1(x)

        gate = torch.sigmoid(self.gate_layer(x))

        # Highway mix: y = H(x) * T(x) + x * (1 - T(x)). The carry term uses the
        # layer input x (not H(x)), as the Highway Networks formula requires;
        # transform, gate and x must therefore be shape-compatible.
        highway_output = transform * gate + (1 - gate) * x

        return self.logit_layer(highway_output)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号