def __init__(self, cell, num_layers, in_dim, hid_dim,
             dropout=0.0, **kwargs):
    """Stack ``num_layers`` recurrent cells on top of each other.

    Args:
        cell: a cell class, or the name of an attribute of ``nn`` that is
            resolved with ``getattr`` (e.g. ``"GRUCell"``). Every layer is
            built as ``cell(input_size, hid_dim, **kwargs)``.
        num_layers: how many cells to stack.
        in_dim: input size passed to the first cell.
        hid_dim: hidden size of every cell; cells after the first also use
            ``hid_dim`` as their input size.
        dropout: dropout probability. When falsy (the default ``0.0``),
            ``has_dropout`` is False and this method does not set
            ``self.dropout``.
        **kwargs: forwarded unchanged to every cell constructor.
    """
    super(BaseStackedRNN, self).__init__()
    self.in_dim = in_dim
    self.hid_dim = hid_dim

    # The Dropout submodule is created only when dropout is enabled, and it
    # is registered before ``layers`` so the child-module order is
    # dropout first, then the stacked cells.
    self.has_dropout = bool(dropout)
    if self.has_dropout:
        self.dropout = nn.Dropout(dropout)

    self.num_layers = num_layers
    cell_cls = getattr(nn, cell) if isinstance(cell, str) else cell
    # Layer 0 consumes in_dim; each deeper layer consumes the previous
    # layer's hid_dim-sized output.
    input_sizes = [hid_dim if depth else in_dim for depth in range(num_layers)]
    self.layers = nn.ModuleList(
        [cell_cls(size, hid_dim, **kwargs) for size in input_sizes])
评论列表
文章目录