def __init__(self, output_dim, hidden_dim, output_length, depth=1, dropout=0.25, **kwargs):
    """Build a simple LSTM encoder-decoder (seq2seq) layer stack.

    Layers are appended via ``self.add`` in this order:

    1. ``depth[0] - 1`` LSTMs with ``return_sequences=True``, each followed
       by ``Dropout(dropout)``;
    2. ``self.encoder`` (an LSTM constructed without ``return_sequences``),
       followed by ``Dropout(dropout)``;
    3. ``RepeatVector(output_length)``;
    4. ``self.decoder`` (an LSTM with ``return_sequences=True``);
    5. ``depth[1] - 1`` further LSTMs with ``return_sequences=True``, each
       followed by ``Dropout(dropout)``;
    6. ``TimeDistributedDense(output_dim, activation='softmax')``.

    Args:
        output_dim: Units of the final softmax ``TimeDistributedDense`` layer.
        hidden_dim: Units of every LSTM layer.
        output_length: Repeat count given to ``RepeatVector``, i.e. how many
            timesteps the decoder receives.
        depth: Either an int applied to both sides, or a list/tuple
            ``(encoder_depth, decoder_depth)``. With an int ``d``, each side
            gets ``max(d, 1)`` LSTMs. Elements beyond the first two are
            ignored.
        dropout: Rate passed to every ``Dropout`` layer.
        **kwargs: Forwarded unchanged to every ``LSTM`` constructor.
    """
    super(SimpleSeq2seq, self).__init__()
    # isinstance (not an exact type() match) so list/tuple subclasses, e.g.
    # namedtuples, are treated as per-side depths instead of being re-wrapped.
    if not isinstance(depth, (list, tuple)):
        depth = (depth, depth)
    encoder_depth, decoder_depth = depth[0], depth[1]

    # NOTE(review): kwargs reach every LSTM, including decoder-side ones. If
    # they carry input-shape arguments meant only for the first layer, confirm
    # the backend tolerates them on later layers.
    self.encoder = LSTM(hidden_dim, **kwargs)
    self.decoder = LSTM(hidden_dim, return_sequences=True, **kwargs)

    # Extra encoder-side LSTMs keep the full sequence so the next LSTM can consume it.
    for _ in range(1, encoder_depth):
        self.add(LSTM(hidden_dim, return_sequences=True, **kwargs))
        self.add(Dropout(dropout))
    self.add(self.encoder)
    self.add(Dropout(dropout))
    # Repeat the encoder's output vector once per output timestep for the decoder.
    self.add(RepeatVector(output_length))
    self.add(self.decoder)
    for _ in range(1, decoder_depth):
        self.add(LSTM(hidden_dim, return_sequences=True, **kwargs))
        self.add(Dropout(dropout))
    self.add(TimeDistributedDense(output_dim, activation='softmax'))
# Scraped web-page residue (not code): "评论列表" ("Comment list"),
# "文章目录" ("Table of contents").