from keras.layers import ELU, GRU, LSTM, concatenate

# AttentionWrapper (taking an RNN layer, an attention tensor, and single_attention_param)
# is a custom layer assumed to be defined elsewhere in this project; it is not part of Keras.


def AttentionDecoder(hidden_size, activation=None, return_sequences=True, bidirectional=False, use_gru=True):
    """Return a decoder closure that runs an attention-wrapped GRU/LSTM over its input."""
    if activation is None:
        activation = ELU()

    if use_gru:
        def _decoder(x, attention):
            if bidirectional:
                # Give each direction hidden_size/2 units so the concatenated output is hidden_size wide.
                branch_1 = AttentionWrapper(GRU(int(hidden_size / 2), activation='linear',
                                                return_sequences=return_sequences, go_backwards=False),
                                            attention, single_attention_param=True)(x)
                branch_2 = AttentionWrapper(GRU(int(hidden_size / 2), activation='linear',
                                                return_sequences=return_sequences, go_backwards=True),
                                            attention, single_attention_param=True)(x)
                x = concatenate([branch_1, branch_2])
                return activation(x)
            else:
                x = AttentionWrapper(GRU(hidden_size, activation='linear', return_sequences=return_sequences),
                                     attention, single_attention_param=True)(x)
                return activation(x)
    else:
        def _decoder(x, attention):
            if bidirectional:
                # As with the GRU branch, each direction gets hidden_size/2 units.
                branch_1 = AttentionWrapper(LSTM(int(hidden_size / 2), activation='linear',
                                                 return_sequences=return_sequences, go_backwards=False),
                                            attention, single_attention_param=True)(x)
                branch_2 = AttentionWrapper(LSTM(int(hidden_size / 2), activation='linear',
                                                 return_sequences=return_sequences, go_backwards=True),
                                            attention, single_attention_param=True)(x)
                x = concatenate([branch_1, branch_2])
                return activation(x)
            else:
                x = AttentionWrapper(LSTM(hidden_size, activation='linear', return_sequences=return_sequences),
                                     attention, single_attention_param=True)(x)
                return activation(x)

    return _decoder
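

# --- Minimal usage sketch. The tensor shapes, the simple GRU encoder, and the choice of
# --- passing the encoder output as the attention tensor are illustrative assumptions,
# --- not taken from the original code; AttentionWrapper's interface is used only as shown above.
from keras.layers import Input
from keras.models import Model

encoder_in = Input(shape=(20, 128))                     # assumed: 20 timesteps, 128 features
encoded = GRU(256, return_sequences=True)(encoder_in)   # assumed encoder

decode = AttentionDecoder(hidden_size=256, bidirectional=True, use_gru=True)
decoded = decode(encoded, attention=encoded)            # attention source is an assumption
model = Model(encoder_in, decoded)
model.summary()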