custom.py 文件源码

python
阅读 34 收藏 0 点赞 0 评论 0

项目:seqmod 作者: emanjavacas 项目源码 文件源码
def __init__(self, cell, num_layers, in_dim, hid_dim,
                 dropout=0.0, **kwargs):
        """
        cell: str or custom cell class
        """
        super(BaseStackedRNN, self).__init__()
        self.in_dim = in_dim
        self.hid_dim = hid_dim
        self.has_dropout = False
        if dropout:
            self.has_dropout = True
            self.dropout = nn.Dropout(dropout)
        self.num_layers = num_layers
        self.layers = nn.ModuleList()

        if isinstance(cell, str):
            cell = getattr(nn, cell)
        for i in range(num_layers):
            self.layers.append(cell(in_dim, hid_dim, **kwargs))
            in_dim = hid_dim
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号