model.py 文件源码

python
阅读 34 收藏 0 点赞 0 评论 0

项目:python-utils 作者: zhijian-liu 项目源码 文件源码
def __init__(self, input_size, feature_size = 128, hidden_size = 256, num_layers = 1, dropout = 0.9):
        super(SeqEncoder, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # set up modules for recurrent neural networks
        self.rnn = nn.LSTM(input_size = input_size,
                           hidden_size = hidden_size,
                           num_layers = num_layers,
                           batch_first = True,
                           dropout = dropout,
                           bidirectional = True)
        self.rnn.apply(weights_init)

        # set up modules to compute features
        self.feature = nn.Linear(hidden_size * 2, feature_size)
        self.feature.apply(weights_init)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号