def __init__(self, output_dim, hidden_dim, output_length, depth=1, bidirectional=True, dropout=0.1, **kwargs):
    if bidirectional and hidden_dim % 2 != 0:
        raise Exception("hidden_dim for AttentionSeq2seq should be even (because of the bidirectional RNN).")
    super(AttentionSeq2seq, self).__init__()
    # depth may be a single int or a (encoder_depth, decoder_depth) pair
    if type(depth) not in [list, tuple]:
        depth = (depth, depth)
    # Resolve the input shape from whichever keyword argument was supplied
    if 'batch_input_shape' in kwargs:
        shape = kwargs['batch_input_shape']
        del kwargs['batch_input_shape']
    elif 'input_shape' in kwargs:
        shape = (None,) + tuple(kwargs['input_shape'])
        del kwargs['input_shape']
    elif 'input_dim' in kwargs:
        if 'input_length' in kwargs:
            input_length = kwargs['input_length']
        else:
            input_length = None
        shape = (None, input_length, kwargs['input_dim'])
        del kwargs['input_dim']
    self.add(Layer(batch_input_shape=shape))
    # Encoder: an LSTM that returns the full output sequence; when bidirectional,
    # each direction gets hidden_dim / 2 units so the concatenated output is hidden_dim wide.
    if bidirectional:
        self.add(Bidirectional(LSTMEncoder(output_dim=int(hidden_dim / 2), state_input=False, return_sequences=True, **kwargs)))
    else:
        self.add(LSTMEncoder(output_dim=hidden_dim, state_input=False, return_sequences=True, **kwargs))
    # Stack further encoder layers (depth[0] in total), with dropout between them
    for i in range(0, depth[0] - 1):
        self.add(Dropout(dropout))
        if bidirectional:
            self.add(Bidirectional(LSTMEncoder(output_dim=int(hidden_dim / 2), state_input=False, return_sequences=True, **kwargs)))
        else:
            self.add(LSTMEncoder(output_dim=hidden_dim, state_input=False, return_sequences=True, **kwargs))
    encoder = self.layers[-1]
    self.add(Dropout(dropout))
    self.add(TimeDistributed(Dense(hidden_dim if depth[1] > 1 else output_dim)))
    # Decoder: attends over the encoder's output sequence and emits output_length steps
    decoder = AttentionDecoder(hidden_dim=hidden_dim, output_length=output_length, state_input=False, **kwargs)
    self.add(Dropout(dropout))
    self.add(decoder)
    # Stack further decoder-side LSTM layers (depth[1] in total), with dropout between them
    for i in range(0, depth[1] - 1):
        self.add(Dropout(dropout))
        self.add(LSTMEncoder(output_dim=hidden_dim, state_input=False, return_sequences=True, **kwargs))
    self.add(Dropout(dropout))
    # Project every time step onto the output vocabulary with a softmax
    self.add(TimeDistributed(Dense(output_dim, activation='softmax')))
    self.encoder = encoder
    self.decoder = decoder
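For context, a minimal usage sketch follows. It assumes this `__init__` belongs to the `AttentionSeq2seq` model class of the seq2seq Keras add-on (a `Sequential` subclass); the import path, shape values, loss, and optimizer below are illustrative assumptions, not taken from the snippet itself.

# Minimal usage sketch (assumed import path and hyper-parameters; adjust to your installation)
from seq2seq.models import AttentionSeq2seq

model = AttentionSeq2seq(input_dim=5, input_length=7, hidden_dim=10,
                         output_length=8, output_dim=20, depth=2)
model.compile(loss='categorical_crossentropy', optimizer='rmsprop')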