def __init__(self, output_dim, hidden_dim, output_length, depth=1, broadcast_state=True, inner_broadcast_state=True, peek=False, dropout=0.1, **kwargs):
    super(Seq2seq, self).__init__()
    # depth may be an int (same depth for encoder and decoder) or an (encoder_depth, decoder_depth) pair.
    if type(depth) not in [list, tuple]:
        depth = (depth, depth)
    # Work out the input shape from whichever shape argument the caller supplied.
    if 'batch_input_shape' in kwargs:
        shape = kwargs['batch_input_shape']
        del kwargs['batch_input_shape']
    elif 'input_shape' in kwargs:
        shape = (None,) + tuple(kwargs['input_shape'])
        del kwargs['input_shape']
    elif 'input_dim' in kwargs:
        shape = (None, None, kwargs['input_dim'])
        del kwargs['input_dim']
    else:
        # Without one of the three, 'shape' would be undefined below; fail early with a clear message.
        raise TypeError('Specify input_dim, input_shape, or batch_input_shape.')
    # Encoder stack: depth[0] LSTMEncoder layers with dropout in between.
    # Only the last encoder layer collapses the sequence to a single vector (return_sequences=False).
    lstms = []
    layer = LSTMEncoder(batch_input_shape=shape, output_dim=hidden_dim, state_input=False, return_sequences=depth[0] > 1, **kwargs)
    self.add(layer)
    lstms += [layer]
    for i in range(depth[0] - 1):
        self.add(Dropout(dropout))
        layer = LSTMEncoder(output_dim=hidden_dim, state_input=inner_broadcast_state, return_sequences=i < depth[0] - 2, **kwargs)
        self.add(layer)
        lstms += [layer]
    # Optionally propagate hidden states between consecutive encoder layers.
    if inner_broadcast_state:
        for i in range(len(lstms) - 1):
            lstms[i].broadcast_state(lstms[i + 1])
    encoder = self.layers[-1]
    self.add(Dropout(dropout))
    # The 'peek' decoder (LSTMDecoder2) sees the encoder's context vector at every output step.
    decoder_type = LSTMDecoder2 if peek else LSTMDecoder
    decoder = decoder_type(hidden_dim=hidden_dim, output_length=output_length, state_input=broadcast_state, **kwargs)
    self.add(decoder)
    # Decoder stack: depth[1] - 1 further recurrent layers, all returning sequences.
    lstms = [decoder]
    for i in range(depth[1] - 1):
        self.add(Dropout(dropout))
        layer = LSTMEncoder(output_dim=hidden_dim, state_input=inner_broadcast_state, return_sequences=True, **kwargs)
        self.add(layer)
        lstms += [layer]
    if inner_broadcast_state:
        for i in range(len(lstms) - 1):
            lstms[i].broadcast_state(lstms[i + 1])
    # Optionally hand the encoder's final hidden state to the decoder as its initial state.
    if broadcast_state:
        encoder.broadcast_state(decoder)
    self.add(Dropout(dropout))
    # Project every decoder timestep to the requested output dimension.
    self.add(TimeDistributed(Dense(output_dim, **kwargs)))
    self.encoder = encoder
    self.decoder = decoder
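
For context, here is a minimal usage sketch. It assumes the import path `seq2seq.models.Seq2seq` and the old Keras 1.x `Sequential` API (`compile`/`fit` with `nb_epoch`); the exact names and arguments may differ across versions of the library, so treat the shapes and hyperparameters below as placeholders.

```python
# Minimal usage sketch (assumed import path and Keras 1.x-style API; adjust to your version).
import numpy as np
from seq2seq.models import Seq2seq

# Map 15 input timesteps of dimension 50 to 10 output timesteps of dimension 50,
# with a 2-layer encoder/decoder and the 'peek' decoder variant.
model = Seq2seq(input_dim=50, hidden_dim=128, output_length=10, output_dim=50, depth=2, peek=True)
model.compile(loss='mse', optimizer='rmsprop')

x = np.random.random((32, 15, 50))   # (batch, input timesteps, input_dim)
y = np.random.random((32, 10, 50))   # (batch, output_length, output_dim)
model.fit(x, y, nb_epoch=1)
```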